Mercurial > libervia-backend
changeset 3637:51983c55c5b6
merge branche "@"
author | Goffi <goffi@goffi.org> |
---|---|
date | Wed, 01 Sep 2021 13:47:17 +0200 |
parents | 7bc443253b7c (diff) 43542cf32e5a (current diff) |
children | 257135d5c5c2 |
files | sat_frontends/jp/base.py |
diffstat | 67 files changed, 5326 insertions(+), 3138 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/dev-requirements.txt Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,4 @@ +-r requirements.txt + +pytest +pytest_twisted
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/doc/_ext/docstring.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +"""Adapt Libervia docstring style to autodoc""" + + +def process_docstring(app, what, name, obj, options, lines): + lines[:] = [ + l.replace("@param", ":param").replace("@raise", ":raises") + for l in lines + ] + + +def setup(app): + app.connect("autodoc-process-docstring", process_docstring) + return { + 'version': '0.1', + 'parallel_read_safe': True, + 'parallel_write_safe': True, + }
--- a/doc/conf.py Wed Sep 01 13:46:43 2021 +0200 +++ b/doc/conf.py Wed Sep 01 13:47:17 2021 +0200 @@ -12,13 +12,15 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) +import os +import sys import os.path import re +sys.path.insert(0, os.path.abspath("./_ext")) + + # -- Project information ----------------------------------------------------- project = 'Libervia' @@ -47,6 +49,8 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ + "sphinx.ext.autodoc", + "docstring" ] # Add any paths that contain templates here, relative to this directory.
--- a/doc/configuration.rst Wed Sep 01 13:46:43 2021 +0200 +++ b/doc/configuration.rst Wed Sep 01 13:47:17 2021 +0200 @@ -108,6 +108,8 @@ background = auto # end-user facing URL, used by Web frontend public_url = example.com + # uncomment next line if you don't want to use local cache for pubsub items + ; pubsub_cache_strategy = "no_cache" [plugin account] ; where a new account must be created
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/doc/developer.rst Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,82 @@ +.. _developer: + +======================= +Developer Documentation +======================= + +This documentation is intended for people who wants to contribute or work with the +internals of the project, it is not for end-users. + +Storage +======= + +Since version 0.9, Libervia uses SQLAlchemy_ with its Object–Relational Mapping as a +backend to store persistent data, and Alembic_ is used to handle schema and data +migrations. + +SQLite_ is currently the only supported database, but it is planned to add support for +other ones (notably PostgreSQL), probably during the development of 0.9 version. + +The mapping is done in ``sat.memory.sqla_mapping`` and working with database is done +through high level methods found in ``sat.memory.sqla``. + +Before the move to SQLAlchemy, there was a strict separation between database +implementation and the rest of the code. With 0.9, objects mapping to database can be used +and manipulated directly outside of ``sat.memory.sqla`` to take profit of SQLAlchemy +possibilities. + +Database state is detected when the backend starts, and the database will be created or +migrated automatically if necessary. + +To create a new migration script, ``Alembic`` may be used directly. To do so, be sure to +have an up-to-date database (and a backup in case of troubles), then activate the virtual +environment where Libervia is installed (Alembic needs to access ORM mapping), go to +``sat/memory/migration`` directory, and enter the following command:: + + alembic revision --autogenerate -m "some revision message" + +This will create a base migration file in ``versions`` directory. Adapt it to your needs, +try to create both ``upgrade`` and ``downgrade`` method whenever possible, and be sure to +test it in both directions (``alembic upgrade head`` and ``alembic downgrade +<previous_revision>``). Please check Alembic documentation for more details. + +.. _SQLALchemy: https://www.sqlalchemy.org/ +.. _Alembic: https://alembic.sqlalchemy.org/ +.. _SQLite: https://sqlite.org + +Pubsub Cache +============ + +There is an internal cache for pubsub nodes and items, which is done in +``plugin_pubsub_cache``. The ``PubsubNode`` and ``PubsubItem`` class are the one mapping +the database. + +The cache is operated transparently to end-user, when a pubsub request is done, it uses a +trigger to check if the requested node is or must be cached, and if possible returns +result directly from database, otherwise it lets the normal workflow continue and query the +pubsub service. + +To save resources, not all nodes are fully cached. When a node is checked, a series of +analysers are checked, and the first one matching is used to determine if the node must be +synchronised or not. + +Analysers can be registered by any plugins using ``registerAnalyser`` method: + +.. automethod:: sat.plugins.plugin_pubsub_cache.PubsubCache.registerAnalyser + +If no analyser is found, ``to_sync`` is false, or an error happens during the caching, +the node won't be synchronised and the pubsub service will always be requested. + +Specifying an optional **parser** will store parsed data in addition to the raw XML of the +items. This is more space consuming, but may be desired for the following reasons: + +* the parsing is resource consuming (network call or some CPU intensive operations are + done) +* it is desirable to to queries on parsed data. Indeed the parsed data is stored in a + JSON_ field and its keys may be queried individually. + +The Raw XML is kept as the cache operates transparently, and a plugin may need raw data, or +an user may be doing a low-level pubsub request. + +.. _JSON: https://docs.sqlalchemy.org/en/14/core/type_basics.html#sqlalchemy.types.JSON +
--- a/doc/index.rst Wed Sep 01 13:46:43 2021 +0200 +++ b/doc/index.rst Wed Sep 01 13:47:17 2021 +0200 @@ -26,6 +26,7 @@ installation.rst overview.rst configuration.rst + developer.rst /libervia-cli/index.rst /libervia-tui/index.rst /contribuing/index.rst
--- a/doc/libervia-cli/common_arguments.rst Wed Sep 01 13:46:43 2021 +0200 +++ b/doc/libervia-cli/common_arguments.rst Wed Sep 01 13:47:17 2021 +0200 @@ -88,6 +88,11 @@ this case the ``-m, --max`` argument should be prefered. See below for RSM common arguments. +``-C, --no-cache`` + skip pubsub cache. By default, internal pubsub cache is used automatically if requested + items are available there. With this option set, a request to the pubsub service will + always be done, regardless of presence of node items in internal cache. + result set management ===================== @@ -191,7 +196,7 @@ manipulated by a script, or if you want only a specific element of the result. ``-O {…}, --output {…}`` - specifiy the output to use. Available options depends of the command you are using, + specify the output to use. Available options depends of the command you are using, check ``li [your command] --help`` to know them. e.g.:: @@ -209,6 +214,64 @@ Some options expect parameters, in this case they can be specified using ``=``. - e.g. specifiying a template to use:: + e.g. specifying a template to use:: $ li blog get -O template --oo browser --oo template=/tmp/my_template.html + +Time Pattern +============ + +When a command expect a date or date with time argument, you can use a "time pattern" (you +usually see the ``TIME_PATTERN`` as name of the argument in ``--help`` message when it can +be used). + +This is a flexible way to enter a date, you can enter a date in one of the following way: + +- the string ``now`` to indicate current date and time; +- an absolute date using an international format. The parser know many formats (please + check dateutil_ package documentation to have a detail of supported formats). Please + note that days are specified first and that if no time zone is specified, the local + time zone of your computer is assumed; +- a relative date (or "relative delta"), see below for details on how to construct it; +- a reference time (``now`` or absolute date as above) followed by a relative delta. If + the reference time is not specified, ``now`` is used; + +Time pattern is not case sensitive. + +.. _dateutil: https://dateutil.readthedocs.io + +Relative Delta +-------------- + +A relative delta is specified with: + +- an optional direction ``+`` for times after reference time or ``-`` for time before reference time (defaulting to ``+``); +- a number for the quantity of units +- a unit (e.g. seconds or minutes), see the bellow for details +- the word ``ago`` which is same as using ``-`` for direction (direction and ``ago`` can't + be used at the same time) + +Time Units +---------- + +The singular or plural form of following units can be used: + +- ``s``, ``sec``, ``second`` +- ``m``, ``min``, ``minute`` +- ``h``, ``hr``, ``hour`` +- ``d``, ``day`` +- ``w``, ``week`` +- ``mo``, ``month`` +- ``y``, ``yr``, ``year`` + +examples +-------- + +- ``2022-01-01``: first of January of 2022 at midnight +- ``2017-02-10T13:05:00Z``: 10th of February 2017 at 13:05 UTC +- ``2019-07-14 12:00:00 CET``: 14th of July 2019 at 12:00 CET +- ``10 min ago``: current time minus 10 minutes +- ``now - 10 m``: same as above (current time minus 10 minutes) +- ``07/08/2021 +5 hours``: 7 August 2021 at midnight (local time of the computer) + 5 + hours, i.e. 5 in the morning at local time. +
--- a/doc/libervia-cli/pubsub.rst Wed Sep 01 13:46:43 2021 +0200 +++ b/doc/libervia-cli/pubsub.rst Wed Sep 01 13:47:17 2021 +0200 @@ -33,6 +33,8 @@ $ echo '<note xmlns="http://example.net/mynotes">this is a note</note>' | li pubsub set -n "notes" +.. _li_pubsub_get: + get === @@ -43,15 +45,16 @@ Retrieve the last 5 notes from our custom notes node:: - $ li pubsub get -n notes -m 5 + $ li pubsub get -n notes -M 5 .. _li_pubsub_delete: delete ====== -Delete an item from a node. If ``-N, --notify`` is specified, subscribers will be notified -of the item retraction. +Delete an item from a node. If ``--no-notification`` is specified, subscribers wont be notified +of the item retraction (this is NOT recommended, as it will cause trouble to keep items in +sync, take caution when using this flag). By default a confirmation is requested before deletion is requested to the PubSub service, but you can override this behaviour by using ``-f, --force`` option. @@ -330,7 +333,7 @@ example ------- -Imagine that you want to replace all occurrences of "sàt" by "Libervia" in your personal blog. You first create a Python script like this: +Imagine that you want to replace all occurrences of "SàT" by "Libervia" in your personal blog. You first create a Python script like this: .. sourcecode:: python @@ -338,10 +341,10 @@ import sys item_raw = sys.stdin.read() - if not "sàt" in item_raw: + if not "SàT" in item_raw: print("SKIP") else: - print(item_raw.replace("sàt", "Libervia")) + print(item_raw.replace("SàT", "Libervia")) And save it a some location, e.g. ``~/expand_sat.py`` (don't forget to make it executable with ``chmod +x ~/expand_sat.py``). @@ -380,3 +383,8 @@ ==== Subcommands for hooks management. Please check :ref:`libervia-cli_pubsub_hook`. + +cache +===== + +Subcommands for cache management. Please check :ref:`libervia-cli_pubsub_cache`.
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/doc/libervia-cli/pubsub_cache.rst Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,98 @@ +.. _libervia-cli_pubsub_cache: + +===================================== +pubsub/cache: PubSub Cache Management +===================================== + +Libervia runs transparently a cache for pubsub. That means that according to internal +criteria, some pubsub items are stored locally. + +The ``cache`` subcommands let user inspect and manipulate the internal cache. + +get +=== + +Retrieve items from internal cache only. Most end-users won't need to use this command, as +the usual ``pubsub get`` command will use cache transparently. However, it may be useful +to inspect local cache, notably for debugging. + +The parameters are basically the same as for :ref:`li_pubsub_get`. + +example +------- + +Retrieve the last 2 cached items for personal blog:: + + $ li pubsub cache get -n urn:xmpp:microblog:0 -M 2 + +.. _li_pubsub_cache_sync: + +sync +==== + +Synchronise or resynchronise a pubsub node. If the node is already in cache, it will be +deleted then re-cached. Node will be put in cache even if internal policy doesn't request +a synchronisation for this kind of nodes. Node will be (re-)subscribed to keep cache +synchronised. + +All items of the node (up to the internal limit which is high), will be retrieved and put +in cache, even if a previous version of those items have been deleted by the +:ref:`li_pubsub_cache_purge` command. + + +example +------- + +Resynchronise personal blog:: + + $ li pubusb cache sync -n urn:xmpp:microblog:0 + +.. _li_pubsub_cache_purge: + +purge +===== + +Remove items from cache. This may be desirable to save resource, notably disk space. + +Note that once a pubsub node is cached, the cache is the source of trust. That means that +if cache is not explicitly bypassed when retrieving items of a pubsub node (notably with +the ``-C, --no-cache`` option of :ref:`li_pubsub_get`), only items found in cache will be +returned, thus purged items won't be used or returned anymore even if they still exists on +the original pubsub service. + +If you have purged items by mistake, it is possible to retrieve them either node by node +using :ref:`li_pubsub_cache_sync`, or by resetting the whole pubsub cache with +:ref:`li_pubsub_cache_reset`. + +If you have a node or a profile (e.g. a component) caching a lot of items frequently, you +may use this command using a scheduler like cron_. + +.. _cron: https://en.wikipedia.org/wiki/Cron + +examples +-------- + +Remove all blog and event items from cache if they haven't been updated since 6 months:: + + $ li pubsub cache purge -t blog -t event -b "6 months ago" + +Remove items from profile ``ap_gateway`` if they have been created more that 2 months +ago:: + + $ li pubsub cache purge -p ap_gateway --created-before "2 months ago" + +.. _li_pubsub_cache_reset: + +reset +===== + +Reset the whole pubsub cache. This means that all nodes and all them items will be removed +from cache. After this command, cache will be re-filled progressively as if it where a new +one. + +example +------- + +Reset the whole pubsub cache:: + + $ li pubsub cache reset
--- a/doc/libervia-cli/pubsub_hook.rst Wed Sep 01 13:46:43 2021 +0200 +++ b/doc/libervia-cli/pubsub_hook.rst Wed Sep 01 13:47:17 2021 +0200 @@ -4,7 +4,7 @@ pubsub/hook: PubSub hooks management ==================================== -``hook`` is a subcommands grouping all PubSub commands related to hooks management. Hooks +``hook`` is a subcommand grouping all PubSub commands related to hooks management. Hooks are user actions launched on specific events. 3 types of hooks can be used:
--- a/sat/VERSION Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/VERSION Wed Sep 01 13:47:17 2021 +0200 @@ -1,1 +1,1 @@ -0.8.0b1.post1 +0.9.0D
--- a/sat/bridge/bridge_constructor/base_constructor.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/bridge/bridge_constructor/base_constructor.py Wed Sep 01 13:47:17 2021 +0200 @@ -49,7 +49,7 @@ FRONTEND_TEMPLATE = None FRONTEND_DEST = None - # set to False if your bridge need only core + # set to False if your bridge needs only core FRONTEND_ACTIVATE = True def __init__(self, bridge_template, options): @@ -284,6 +284,7 @@ "args": self.getArguments( function["sig_in"], name=arg_doc, default=default ), + "args_no_default": self.getArguments(function["sig_in"], name=arg_doc), } extend_method = getattr(
--- a/sat/bridge/bridge_constructor/constructors/dbus/constructor.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/bridge/bridge_constructor/constructors/dbus/constructor.py Wed Sep 01 13:47:17 2021 +0200 @@ -25,20 +25,19 @@ CORE_TEMPLATE = "dbus_core_template.py" CORE_DEST = "dbus_bridge.py" CORE_FORMATS = { - "signals": """\ - @dbus.service.signal(const_INT_PREFIX+const_{category}_SUFFIX, - signature='{sig_in}') - def {name}(self, {args}): - {body}\n""", + "methods_declarations": """\ + Method('{name}', arguments='{sig_in}', returns='{sig_out}'),""", + "methods": """\ - @dbus.service.method(const_INT_PREFIX+const_{category}_SUFFIX, - in_signature='{sig_in}', out_signature='{sig_out}', - async_callbacks={async_callbacks}) - def {name}(self, {args}{async_comma}{async_args_def}): - {debug}return self._callback("{name}", {args_result}{async_comma}{async_args_call})\n""", - "signal_direct_calls": """\ + def dbus_{name}(self, {args}): + {debug}return self._callback("{name}", {args_no_default})\n""", + + "signals_declarations": """\ + Signal('{name}', '{sig_in}'),""", + + "signals": """\ def {name}(self, {args}): - self.dbus_bridge.{name}({args})\n""", + self._obj.emitSignal("{name}", {args})\n""", } FRONTEND_TEMPLATE = "dbus_frontend_template.py" @@ -68,17 +67,10 @@ def core_completion_method(self, completion, function, default, arg_doc, async_): completion.update( { - "debug": "" - if not self.args.debug - else 'log.debug ("%s")\n%s' % (completion["name"], 8 * " "), - "args_result": self.getArguments( - function["sig_in"], name=arg_doc, unicode_protect=self.args.unicode - ), - "async_comma": ", " if async_ and function["sig_in"] else "", - "async_args_def": "callback=None, errback=None" if async_ else "", - "async_args_call": "callback=callback, errback=errback" if async_ else "", - "async_callbacks": "('callback', 'errback')" if async_ else "None", - "category": completion["category"].upper(), + "debug": ( + "" if not self.args.debug + else f'log.debug ("{completion["name"]}")\n{8 * " "}' + ) } )
--- a/sat/bridge/bridge_constructor/constructors/dbus/dbus_core_template.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/bridge/bridge_constructor/constructors/dbus/dbus_core_template.py Wed Sep 01 13:47:17 2021 +0200 @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# SàT communication bridge +# Libervia communication bridge # Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) # This program is free software: you can redistribute it and/or modify @@ -16,15 +16,15 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +from types import MethodType +from functools import partialmethod +from twisted.internet import defer, reactor from sat.core.i18n import _ -import dbus -import dbus.service -import dbus.mainloop.glib -import inspect from sat.core.log import getLogger +from sat.core.exceptions import BridgeInitError from sat.tools import config -from twisted.internet.defer import Deferred -from sat.core.exceptions import BridgeInitError +from txdbus import client, objects, error +from txdbus.interface import DBusInterface, Method, Signal log = getLogger(__name__) @@ -45,251 +45,127 @@ pass -class MethodNotRegistered(dbus.DBusException): - _dbus_error_name = const_ERROR_PREFIX + ".MethodNotRegistered" - - -class InternalError(dbus.DBusException): - _dbus_error_name = const_ERROR_PREFIX + ".InternalError" +class DBusException(Exception): + pass -class AsyncNotDeferred(dbus.DBusException): - _dbus_error_name = const_ERROR_PREFIX + ".AsyncNotDeferred" +class MethodNotRegistered(DBusException): + dbusErrorName = const_ERROR_PREFIX + ".MethodNotRegistered" -class DeferredNotAsync(dbus.DBusException): - _dbus_error_name = const_ERROR_PREFIX + ".DeferredNotAsync" - - -class GenericException(dbus.DBusException): +class GenericException(DBusException): def __init__(self, twisted_error): """ @param twisted_error (Failure): instance of twisted Failure - @return: DBusException + error message is used to store a repr of message and condition in a tuple, + so it can be evaluated by the frontend bridge. """ - super(GenericException, self).__init__() try: # twisted_error.value is a class class_ = twisted_error.value().__class__ except TypeError: # twisted_error.value is an instance class_ = twisted_error.value.__class__ - message = twisted_error.getErrorMessage() + data = twisted_error.getErrorMessage() try: - self.args = (message, twisted_error.value.condition) + data = (data, twisted_error.value.condition) except AttributeError: - self.args = (message,) - self._dbus_error_name = ".".join( - [const_ERROR_PREFIX, class_.__module__, class_.__name__] + data = (data,) + else: + data = (str(twisted_error),) + self.dbusErrorName = ".".join( + (const_ERROR_PREFIX, class_.__module__, class_.__name__) ) + super(GenericException, self).__init__(repr(data)) + + @classmethod + def create_and_raise(cls, exc): + raise cls(exc) -class DbusObject(dbus.service.Object): - def __init__(self, bus, path): - dbus.service.Object.__init__(self, bus, path) - log.debug("Init DbusObject...") +class DBusObject(objects.DBusObject): + + core_iface = DBusInterface( + const_INT_PREFIX + const_CORE_SUFFIX, +##METHODS_DECLARATIONS_PART## +##SIGNALS_DECLARATIONS_PART## + ) + plugin_iface = DBusInterface( + const_INT_PREFIX + const_PLUGIN_SUFFIX + ) + + dbusInterfaces = [core_iface, plugin_iface] + + def __init__(self, path): + super().__init__(path) + log.debug("Init DBusObject...") self.cb = {} def register_method(self, name, cb): self.cb[name] = cb def _callback(self, name, *args, **kwargs): - """call the callback if it exists, raise an exception else - if the callback return a deferred, use async methods""" - if not name in self.cb: + """Call the callback if it exists, raise an exception else""" + try: + cb = self.cb[name] + except KeyError: raise MethodNotRegistered - - if "callback" in kwargs: - # we must have errback too - if not "errback" in kwargs: - log.error("errback is missing in method call [%s]" % name) - raise InternalError - callback = kwargs.pop("callback") - errback = kwargs.pop("errback") - async_ = True else: - async_ = False - result = self.cb[name](*args, **kwargs) - if async_: - if not isinstance(result, Deferred): - log.error("Asynchronous method [%s] does not return a Deferred." % name) - raise AsyncNotDeferred - result.addCallback( - lambda result: callback() if result is None else callback(result) - ) - result.addErrback(lambda err: errback(GenericException(err))) - else: - if isinstance(result, Deferred): - log.error("Synchronous method [%s] return a Deferred." % name) - raise DeferredNotAsync - return result - - ### signals ### - - @dbus.service.signal(const_INT_PREFIX + const_PLUGIN_SUFFIX, signature="") - def dummySignal(self): - # FIXME: workaround for addSignal (doesn't work if one method doensn't - # already exist for plugins), probably missing some initialisation, need - # further investigations - pass - -##SIGNALS_PART## - ### methods ### + d = defer.maybeDeferred(cb, *args, **kwargs) + d.addErrback(GenericException.create_and_raise) + return d ##METHODS_PART## - def __attributes(self, in_sign): - """Return arguments to user given a in_sign - @param in_sign: in_sign in the short form (using s,a,i,b etc) - @return: list of arguments that correspond to a in_sign (e.g.: "sss" return "arg1, arg2, arg3")""" - i = 0 - idx = 0 - attr = [] - while i < len(in_sign): - if in_sign[i] not in ["b", "y", "n", "i", "x", "q", "u", "t", "d", "s", "a"]: - raise ParseError("Unmanaged attribute type [%c]" % in_sign[i]) - attr.append("arg_%i" % idx) - idx += 1 - - if in_sign[i] == "a": - i += 1 - if ( - in_sign[i] != "{" and in_sign[i] != "(" - ): # FIXME: must manage tuples out of arrays - i += 1 - continue # we have a simple type for the array - opening_car = in_sign[i] - assert opening_car in ["{", "("] - closing_car = "}" if opening_car == "{" else ")" - opening_count = 1 - while True: # we have a dict or a list of tuples - i += 1 - if i >= len(in_sign): - raise ParseError("missing }") - if in_sign[i] == opening_car: - opening_count += 1 - if in_sign[i] == closing_car: - opening_count -= 1 - if opening_count == 0: - break - i += 1 - return attr - - def addMethod(self, name, int_suffix, in_sign, out_sign, method, async_=False): - """Dynamically add a method to Dbus Bridge""" - inspect_args = inspect.getfullargspec(method) - - _arguments = inspect_args.args - _defaults = list(inspect_args.defaults or []) - - if inspect.ismethod(method): - # if we have a method, we don't want the first argument (usually 'self') - del (_arguments[0]) - - # first arguments are for the _callback method - arguments_callback = ", ".join( - [repr(name)] - + ( - (_arguments + ["callback=callback", "errback=errback"]) - if async_ - else _arguments - ) - ) - - if async_: - _arguments.extend(["callback", "errback"]) - _defaults.extend([None, None]) +class Bridge: - # now we create a second list with default values - for i in range(1, len(_defaults) + 1): - _arguments[-i] = "%s = %s" % (_arguments[-i], repr(_defaults[-i])) - - arguments_defaults = ", ".join(_arguments) + def __init__(self): + log.info("Init DBus...") + self._obj = DBusObject(const_OBJ_PATH) - code = compile( - "def %(name)s (self,%(arguments_defaults)s): return self._callback(%(arguments_callback)s)" - % { - "name": name, - "arguments_defaults": arguments_defaults, - "arguments_callback": arguments_callback, - }, - "<DBus bridge>", - "exec", - ) - exec(code) # FIXME: to the same thing in a cleaner way, without compile/exec - method = locals()[name] - async_callbacks = ("callback", "errback") if async_ else None - setattr( - DbusObject, - name, - dbus.service.method( - const_INT_PREFIX + int_suffix, - in_signature=in_sign, - out_signature=out_sign, - async_callbacks=async_callbacks, - )(method), - ) - function = getattr(self, name) - func_table = self._dbus_class_table[ - self.__class__.__module__ + "." + self.__class__.__name__ - ][function._dbus_interface] - func_table[function.__name__] = function # Needed for introspection - - def addSignal(self, name, int_suffix, signature, doc={}): - """Dynamically add a signal to Dbus Bridge""" - attributes = ", ".join(self.__attributes(signature)) - # TODO: use doc parameter to name attributes - - # code = compile ('def '+name+' (self,'+attributes+'): log.debug ("'+name+' signal")', '<DBus bridge>','exec') #XXX: the log.debug is too annoying with xmllog - code = compile( - "def " + name + " (self," + attributes + "): pass", "<DBus bridge>", "exec" - ) - exec(code) - signal = locals()[name] - setattr( - DbusObject, - name, - dbus.service.signal(const_INT_PREFIX + int_suffix, signature=signature)( - signal - ), - ) - function = getattr(self, name) - func_table = self._dbus_class_table[ - self.__class__.__module__ + "." + self.__class__.__name__ - ][function._dbus_interface] - func_table[function.__name__] = function # Needed for introspection - - -class Bridge(object): - def __init__(self): - dbus.mainloop.glib.DBusGMainLoop(set_as_default=True) - log.info("Init DBus...") + async def postInit(self): try: - self.session_bus = dbus.SessionBus() - except dbus.DBusException as e: - if e._dbus_error_name == "org.freedesktop.DBus.Error.NotSupported": + conn = await client.connect(reactor) + except error.DBusException as e: + if e.errName == "org.freedesktop.DBus.Error.NotSupported": log.error( _( - "D-Bus is not launched, please see README to see instructions on how to launch it" + "D-Bus is not launched, please see README to see instructions on " + "how to launch it" ) ) - raise BridgeInitError - self.dbus_name = dbus.service.BusName(const_INT_PREFIX, self.session_bus) - self.dbus_bridge = DbusObject(self.session_bus, const_OBJ_PATH) + raise BridgeInitError(str(e)) -##SIGNAL_DIRECT_CALLS_PART## + conn.exportObject(self._obj) + await conn.requestBusName(const_INT_PREFIX) + +##SIGNALS_PART## def register_method(self, name, callback): - log.debug("registering DBus bridge method [%s]" % name) - self.dbus_bridge.register_method(name, callback) + log.debug(f"registering DBus bridge method [{name}]") + self._obj.register_method(name, callback) + + def emitSignal(self, name, *args): + self._obj.emitSignal(name, *args) - def addMethod(self, name, int_suffix, in_sign, out_sign, method, async_=False, doc={}): - """Dynamically add a method to Dbus Bridge""" + def addMethod( + self, name, int_suffix, in_sign, out_sign, method, async_=False, doc={} + ): + """Dynamically add a method to D-Bus Bridge""" # FIXME: doc parameter is kept only temporary, the time to remove it from calls - log.debug("Adding method [%s] to DBus bridge" % name) - self.dbus_bridge.addMethod(name, int_suffix, in_sign, out_sign, method, async_) + log.debug(f"Adding method {name!r} to D-Bus bridge") + self._obj.plugin_iface.addMethod( + Method(name, arguments=in_sign, returns=out_sign) + ) + # we have to create a method here instead of using partialmethod, because txdbus + # uses __func__ which doesn't work with partialmethod + def caller(self_, *args, **kwargs): + return self_._callback(name, *args, **kwargs) + setattr(self._obj, f"dbus_{name}", MethodType(caller, self._obj)) self.register_method(name, method) def addSignal(self, name, int_suffix, signature, doc={}): - self.dbus_bridge.addSignal(name, int_suffix, signature, doc) - setattr(Bridge, name, getattr(self.dbus_bridge, name)) + """Dynamically add a signal to D-Bus Bridge""" + log.debug(f"Adding signal {name!r} to D-Bus bridge") + self._obj.plugin_iface.addSignal(Signal(name, signature)) + setattr(Bridge, name, partialmethod(Bridge.emitSignal, name))
--- a/sat/bridge/dbus_bridge.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/bridge/dbus_bridge.py Wed Sep 01 13:47:17 2021 +0200 @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# SàT communication bridge +# Libervia communication bridge # Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) # This program is free software: you can redistribute it and/or modify @@ -16,15 +16,15 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +from types import MethodType +from functools import partialmethod +from twisted.internet import defer, reactor from sat.core.i18n import _ -import dbus -import dbus.service -import dbus.mainloop.glib -import inspect from sat.core.log import getLogger +from sat.core.exceptions import BridgeInitError from sat.tools import config -from twisted.internet.defer import Deferred -from sat.core.exceptions import BridgeInitError +from txdbus import client, objects, error +from txdbus.interface import DBusInterface, Method, Signal log = getLogger(__name__) @@ -45,772 +45,451 @@ pass -class MethodNotRegistered(dbus.DBusException): - _dbus_error_name = const_ERROR_PREFIX + ".MethodNotRegistered" - - -class InternalError(dbus.DBusException): - _dbus_error_name = const_ERROR_PREFIX + ".InternalError" +class DBusException(Exception): + pass -class AsyncNotDeferred(dbus.DBusException): - _dbus_error_name = const_ERROR_PREFIX + ".AsyncNotDeferred" +class MethodNotRegistered(DBusException): + dbusErrorName = const_ERROR_PREFIX + ".MethodNotRegistered" -class DeferredNotAsync(dbus.DBusException): - _dbus_error_name = const_ERROR_PREFIX + ".DeferredNotAsync" - - -class GenericException(dbus.DBusException): +class GenericException(DBusException): def __init__(self, twisted_error): """ @param twisted_error (Failure): instance of twisted Failure - @return: DBusException + error message is used to store a repr of message and condition in a tuple, + so it can be evaluated by the frontend bridge. """ - super(GenericException, self).__init__() try: # twisted_error.value is a class class_ = twisted_error.value().__class__ except TypeError: # twisted_error.value is an instance class_ = twisted_error.value.__class__ - message = twisted_error.getErrorMessage() + data = twisted_error.getErrorMessage() try: - self.args = (message, twisted_error.value.condition) + data = (data, twisted_error.value.condition) except AttributeError: - self.args = (message,) - self._dbus_error_name = ".".join( - [const_ERROR_PREFIX, class_.__module__, class_.__name__] + data = (data,) + else: + data = (str(twisted_error),) + self.dbusErrorName = ".".join( + (const_ERROR_PREFIX, class_.__module__, class_.__name__) ) + super(GenericException, self).__init__(repr(data)) + + @classmethod + def create_and_raise(cls, exc): + raise cls(exc) -class DbusObject(dbus.service.Object): - def __init__(self, bus, path): - dbus.service.Object.__init__(self, bus, path) - log.debug("Init DbusObject...") +class DBusObject(objects.DBusObject): + + core_iface = DBusInterface( + const_INT_PREFIX + const_CORE_SUFFIX, + Method('actionsGet', arguments='s', returns='a(a{ss}si)'), + Method('addContact', arguments='ss', returns=''), + Method('asyncDeleteProfile', arguments='s', returns=''), + Method('asyncGetParamA', arguments='sssis', returns='s'), + Method('asyncGetParamsValuesFromCategory', arguments='sisss', returns='a{ss}'), + Method('connect', arguments='ssa{ss}', returns='b'), + Method('contactGet', arguments='ss', returns='(a{ss}as)'), + Method('delContact', arguments='ss', returns=''), + Method('devicesInfosGet', arguments='ss', returns='s'), + Method('discoFindByFeatures', arguments='asa(ss)bbbbbs', returns='(a{sa(sss)}a{sa(sss)}a{sa(sss)})'), + Method('discoInfos', arguments='ssbs', returns='(asa(sss)a{sa(a{ss}as)})'), + Method('discoItems', arguments='ssbs', returns='a(sss)'), + Method('disconnect', arguments='s', returns=''), + Method('encryptionNamespaceGet', arguments='s', returns='s'), + Method('encryptionPluginsGet', arguments='', returns='s'), + Method('encryptionTrustUIGet', arguments='sss', returns='s'), + Method('getConfig', arguments='ss', returns='s'), + Method('getContacts', arguments='s', returns='a(sa{ss}as)'), + Method('getContactsFromGroup', arguments='ss', returns='as'), + Method('getEntitiesData', arguments='asass', returns='a{sa{ss}}'), + Method('getEntityData', arguments='sass', returns='a{ss}'), + Method('getFeatures', arguments='s', returns='a{sa{ss}}'), + Method('getMainResource', arguments='ss', returns='s'), + Method('getParamA', arguments='ssss', returns='s'), + Method('getParamsCategories', arguments='', returns='as'), + Method('getParamsUI', arguments='isss', returns='s'), + Method('getPresenceStatuses', arguments='s', returns='a{sa{s(sia{ss})}}'), + Method('getReady', arguments='', returns=''), + Method('getVersion', arguments='', returns='s'), + Method('getWaitingSub', arguments='s', returns='a{ss}'), + Method('historyGet', arguments='ssiba{ss}s', returns='a(sdssa{ss}a{ss}ss)'), + Method('imageCheck', arguments='s', returns='s'), + Method('imageConvert', arguments='ssss', returns='s'), + Method('imageGeneratePreview', arguments='ss', returns='s'), + Method('imageResize', arguments='sii', returns='s'), + Method('isConnected', arguments='s', returns='b'), + Method('launchAction', arguments='sa{ss}s', returns='a{ss}'), + Method('loadParamsTemplate', arguments='s', returns='b'), + Method('menuHelpGet', arguments='ss', returns='s'), + Method('menuLaunch', arguments='sasa{ss}is', returns='a{ss}'), + Method('menusGet', arguments='si', returns='a(ssasasa{ss})'), + Method('messageEncryptionGet', arguments='ss', returns='s'), + Method('messageEncryptionStart', arguments='ssbs', returns=''), + Method('messageEncryptionStop', arguments='ss', returns=''), + Method('messageSend', arguments='sa{ss}a{ss}sss', returns=''), + Method('namespacesGet', arguments='', returns='a{ss}'), + Method('paramsRegisterApp', arguments='sis', returns=''), + Method('privateDataDelete', arguments='sss', returns=''), + Method('privateDataGet', arguments='sss', returns='s'), + Method('privateDataSet', arguments='ssss', returns=''), + Method('profileCreate', arguments='sss', returns=''), + Method('profileIsSessionStarted', arguments='s', returns='b'), + Method('profileNameGet', arguments='s', returns='s'), + Method('profileSetDefault', arguments='s', returns=''), + Method('profileStartSession', arguments='ss', returns='b'), + Method('profilesListGet', arguments='bb', returns='as'), + Method('progressGet', arguments='ss', returns='a{ss}'), + Method('progressGetAll', arguments='s', returns='a{sa{sa{ss}}}'), + Method('progressGetAllMetadata', arguments='s', returns='a{sa{sa{ss}}}'), + Method('rosterResync', arguments='s', returns=''), + Method('saveParamsTemplate', arguments='s', returns='b'), + Method('sessionInfosGet', arguments='s', returns='a{ss}'), + Method('setParam', arguments='sssis', returns=''), + Method('setPresence', arguments='ssa{ss}s', returns=''), + Method('subscription', arguments='sss', returns=''), + Method('updateContact', arguments='ssass', returns=''), + Signal('_debug', 'sa{ss}s'), + Signal('actionNew', 'a{ss}sis'), + Signal('connected', 'ss'), + Signal('contactDeleted', 'ss'), + Signal('disconnected', 's'), + Signal('entityDataUpdated', 'ssss'), + Signal('messageEncryptionStarted', 'sss'), + Signal('messageEncryptionStopped', 'sa{ss}s'), + Signal('messageNew', 'sdssa{ss}a{ss}sss'), + Signal('newContact', 'sa{ss}ass'), + Signal('paramUpdate', 'ssss'), + Signal('presenceUpdate', 'ssia{ss}s'), + Signal('progressError', 'sss'), + Signal('progressFinished', 'sa{ss}s'), + Signal('progressStarted', 'sa{ss}s'), + Signal('subscribe', 'sss'), + ) + plugin_iface = DBusInterface( + const_INT_PREFIX + const_PLUGIN_SUFFIX + ) + + dbusInterfaces = [core_iface, plugin_iface] + + def __init__(self, path): + super().__init__(path) + log.debug("Init DBusObject...") self.cb = {} def register_method(self, name, cb): self.cb[name] = cb def _callback(self, name, *args, **kwargs): - """call the callback if it exists, raise an exception else - if the callback return a deferred, use async methods""" - if not name in self.cb: + """Call the callback if it exists, raise an exception else""" + try: + cb = self.cb[name] + except KeyError: raise MethodNotRegistered - - if "callback" in kwargs: - # we must have errback too - if not "errback" in kwargs: - log.error("errback is missing in method call [%s]" % name) - raise InternalError - callback = kwargs.pop("callback") - errback = kwargs.pop("errback") - async_ = True else: - async_ = False - result = self.cb[name](*args, **kwargs) - if async_: - if not isinstance(result, Deferred): - log.error("Asynchronous method [%s] does not return a Deferred." % name) - raise AsyncNotDeferred - result.addCallback( - lambda result: callback() if result is None else callback(result) - ) - result.addErrback(lambda err: errback(GenericException(err))) - else: - if isinstance(result, Deferred): - log.error("Synchronous method [%s] return a Deferred." % name) - raise DeferredNotAsync - return result - - ### signals ### - - @dbus.service.signal(const_INT_PREFIX + const_PLUGIN_SUFFIX, signature="") - def dummySignal(self): - # FIXME: workaround for addSignal (doesn't work if one method doensn't - # already exist for plugins), probably missing some initialisation, need - # further investigations - pass - - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='sa{ss}s') - def _debug(self, action, params, profile): - pass - - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='a{ss}sis') - def actionNew(self, action_data, id, security_limit, profile): - pass + d = defer.maybeDeferred(cb, *args, **kwargs) + d.addErrback(GenericException.create_and_raise) + return d - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='ss') - def connected(self, jid_s, profile): - pass - - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='ss') - def contactDeleted(self, entity_jid, profile): - pass + def dbus_actionsGet(self, profile_key="@DEFAULT@"): + return self._callback("actionsGet", profile_key) - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='s') - def disconnected(self, profile): - pass - - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='ssss') - def entityDataUpdated(self, jid, name, value, profile): - pass - - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='sss') - def messageEncryptionStarted(self, to_jid, encryption_data, profile_key): - pass + def dbus_addContact(self, entity_jid, profile_key="@DEFAULT@"): + return self._callback("addContact", entity_jid, profile_key) - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='sa{ss}s') - def messageEncryptionStopped(self, to_jid, encryption_data, profile_key): - pass - - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='sdssa{ss}a{ss}sss') - def messageNew(self, uid, timestamp, from_jid, to_jid, message, subject, mess_type, extra, profile): - pass - - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='sa{ss}ass') - def newContact(self, contact_jid, attributes, groups, profile): - pass + def dbus_asyncDeleteProfile(self, profile): + return self._callback("asyncDeleteProfile", profile) - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='ssss') - def paramUpdate(self, name, value, category, profile): - pass + def dbus_asyncGetParamA(self, name, category, attribute="value", security_limit=-1, profile_key="@DEFAULT@"): + return self._callback("asyncGetParamA", name, category, attribute, security_limit, profile_key) - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='ssia{ss}s') - def presenceUpdate(self, entity_jid, show, priority, statuses, profile): - pass - - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='sss') - def progressError(self, id, error, profile): - pass + def dbus_asyncGetParamsValuesFromCategory(self, category, security_limit=-1, app="", extra="", profile_key="@DEFAULT@"): + return self._callback("asyncGetParamsValuesFromCategory", category, security_limit, app, extra, profile_key) - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='sa{ss}s') - def progressFinished(self, id, metadata, profile): - pass - - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='sa{ss}s') - def progressStarted(self, id, metadata, profile): - pass + def dbus_connect(self, profile_key="@DEFAULT@", password='', options={}): + return self._callback("connect", profile_key, password, options) - @dbus.service.signal(const_INT_PREFIX+const_CORE_SUFFIX, - signature='sss') - def subscribe(self, sub_type, entity_jid, profile): - pass - - ### methods ### - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='a(a{ss}si)', - async_callbacks=None) - def actionsGet(self, profile_key="@DEFAULT@"): - return self._callback("actionsGet", str(profile_key)) + def dbus_contactGet(self, arg_0, profile_key="@DEFAULT@"): + return self._callback("contactGet", arg_0, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='', - async_callbacks=None) - def addContact(self, entity_jid, profile_key="@DEFAULT@"): - return self._callback("addContact", str(entity_jid), str(profile_key)) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='', - async_callbacks=('callback', 'errback')) - def asyncDeleteProfile(self, profile, callback=None, errback=None): - return self._callback("asyncDeleteProfile", str(profile), callback=callback, errback=errback) + def dbus_delContact(self, entity_jid, profile_key="@DEFAULT@"): + return self._callback("delContact", entity_jid, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sssis', out_signature='s', - async_callbacks=('callback', 'errback')) - def asyncGetParamA(self, name, category, attribute="value", security_limit=-1, profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("asyncGetParamA", str(name), str(category), str(attribute), security_limit, str(profile_key), callback=callback, errback=errback) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sisss', out_signature='a{ss}', - async_callbacks=('callback', 'errback')) - def asyncGetParamsValuesFromCategory(self, category, security_limit=-1, app="", extra="", profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("asyncGetParamsValuesFromCategory", str(category), security_limit, str(app), str(extra), str(profile_key), callback=callback, errback=errback) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ssa{ss}', out_signature='b', - async_callbacks=('callback', 'errback')) - def connect(self, profile_key="@DEFAULT@", password='', options={}, callback=None, errback=None): - return self._callback("connect", str(profile_key), str(password), options, callback=callback, errback=errback) + def dbus_devicesInfosGet(self, bare_jid, profile_key): + return self._callback("devicesInfosGet", bare_jid, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='(a{ss}as)', - async_callbacks=('callback', 'errback')) - def contactGet(self, arg_0, profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("contactGet", str(arg_0), str(profile_key), callback=callback, errback=errback) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='', - async_callbacks=('callback', 'errback')) - def delContact(self, entity_jid, profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("delContact", str(entity_jid), str(profile_key), callback=callback, errback=errback) + def dbus_discoFindByFeatures(self, namespaces, identities, bare_jid=False, service=True, roster=True, own_jid=True, local_device=False, profile_key="@DEFAULT@"): + return self._callback("discoFindByFeatures", namespaces, identities, bare_jid, service, roster, own_jid, local_device, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='s', - async_callbacks=('callback', 'errback')) - def devicesInfosGet(self, bare_jid, profile_key, callback=None, errback=None): - return self._callback("devicesInfosGet", str(bare_jid), str(profile_key), callback=callback, errback=errback) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='asa(ss)bbbbbs', out_signature='(a{sa(sss)}a{sa(sss)}a{sa(sss)})', - async_callbacks=('callback', 'errback')) - def discoFindByFeatures(self, namespaces, identities, bare_jid=False, service=True, roster=True, own_jid=True, local_device=False, profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("discoFindByFeatures", namespaces, identities, bare_jid, service, roster, own_jid, local_device, str(profile_key), callback=callback, errback=errback) + def dbus_discoInfos(self, entity_jid, node=u'', use_cache=True, profile_key="@DEFAULT@"): + return self._callback("discoInfos", entity_jid, node, use_cache, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ssbs', out_signature='(asa(sss)a{sa(a{ss}as)})', - async_callbacks=('callback', 'errback')) - def discoInfos(self, entity_jid, node=u'', use_cache=True, profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("discoInfos", str(entity_jid), str(node), use_cache, str(profile_key), callback=callback, errback=errback) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ssbs', out_signature='a(sss)', - async_callbacks=('callback', 'errback')) - def discoItems(self, entity_jid, node=u'', use_cache=True, profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("discoItems", str(entity_jid), str(node), use_cache, str(profile_key), callback=callback, errback=errback) + def dbus_discoItems(self, entity_jid, node=u'', use_cache=True, profile_key="@DEFAULT@"): + return self._callback("discoItems", entity_jid, node, use_cache, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='', - async_callbacks=('callback', 'errback')) - def disconnect(self, profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("disconnect", str(profile_key), callback=callback, errback=errback) + def dbus_disconnect(self, profile_key="@DEFAULT@"): + return self._callback("disconnect", profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='s', - async_callbacks=None) - def encryptionNamespaceGet(self, arg_0): - return self._callback("encryptionNamespaceGet", str(arg_0)) + def dbus_encryptionNamespaceGet(self, arg_0): + return self._callback("encryptionNamespaceGet", arg_0) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='', out_signature='s', - async_callbacks=None) - def encryptionPluginsGet(self, ): + def dbus_encryptionPluginsGet(self, ): return self._callback("encryptionPluginsGet", ) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sss', out_signature='s', - async_callbacks=('callback', 'errback')) - def encryptionTrustUIGet(self, to_jid, namespace, profile_key, callback=None, errback=None): - return self._callback("encryptionTrustUIGet", str(to_jid), str(namespace), str(profile_key), callback=callback, errback=errback) + def dbus_encryptionTrustUIGet(self, to_jid, namespace, profile_key): + return self._callback("encryptionTrustUIGet", to_jid, namespace, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='s', - async_callbacks=None) - def getConfig(self, section, name): - return self._callback("getConfig", str(section), str(name)) + def dbus_getConfig(self, section, name): + return self._callback("getConfig", section, name) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='a(sa{ss}as)', - async_callbacks=('callback', 'errback')) - def getContacts(self, profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("getContacts", str(profile_key), callback=callback, errback=errback) + def dbus_getContacts(self, profile_key="@DEFAULT@"): + return self._callback("getContacts", profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='as', - async_callbacks=None) - def getContactsFromGroup(self, group, profile_key="@DEFAULT@"): - return self._callback("getContactsFromGroup", str(group), str(profile_key)) + def dbus_getContactsFromGroup(self, group, profile_key="@DEFAULT@"): + return self._callback("getContactsFromGroup", group, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='asass', out_signature='a{sa{ss}}', - async_callbacks=None) - def getEntitiesData(self, jids, keys, profile): - return self._callback("getEntitiesData", jids, keys, str(profile)) + def dbus_getEntitiesData(self, jids, keys, profile): + return self._callback("getEntitiesData", jids, keys, profile) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sass', out_signature='a{ss}', - async_callbacks=None) - def getEntityData(self, jid, keys, profile): - return self._callback("getEntityData", str(jid), keys, str(profile)) + def dbus_getEntityData(self, jid, keys, profile): + return self._callback("getEntityData", jid, keys, profile) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='a{sa{ss}}', - async_callbacks=('callback', 'errback')) - def getFeatures(self, profile_key, callback=None, errback=None): - return self._callback("getFeatures", str(profile_key), callback=callback, errback=errback) + def dbus_getFeatures(self, profile_key): + return self._callback("getFeatures", profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='s', - async_callbacks=None) - def getMainResource(self, contact_jid, profile_key="@DEFAULT@"): - return self._callback("getMainResource", str(contact_jid), str(profile_key)) + def dbus_getMainResource(self, contact_jid, profile_key="@DEFAULT@"): + return self._callback("getMainResource", contact_jid, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ssss', out_signature='s', - async_callbacks=None) - def getParamA(self, name, category, attribute="value", profile_key="@DEFAULT@"): - return self._callback("getParamA", str(name), str(category), str(attribute), str(profile_key)) + def dbus_getParamA(self, name, category, attribute="value", profile_key="@DEFAULT@"): + return self._callback("getParamA", name, category, attribute, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='', out_signature='as', - async_callbacks=None) - def getParamsCategories(self, ): + def dbus_getParamsCategories(self, ): return self._callback("getParamsCategories", ) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='isss', out_signature='s', - async_callbacks=('callback', 'errback')) - def getParamsUI(self, security_limit=-1, app='', extra='', profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("getParamsUI", security_limit, str(app), str(extra), str(profile_key), callback=callback, errback=errback) + def dbus_getParamsUI(self, security_limit=-1, app='', extra='', profile_key="@DEFAULT@"): + return self._callback("getParamsUI", security_limit, app, extra, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='a{sa{s(sia{ss})}}', - async_callbacks=None) - def getPresenceStatuses(self, profile_key="@DEFAULT@"): - return self._callback("getPresenceStatuses", str(profile_key)) + def dbus_getPresenceStatuses(self, profile_key="@DEFAULT@"): + return self._callback("getPresenceStatuses", profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='', out_signature='', - async_callbacks=('callback', 'errback')) - def getReady(self, callback=None, errback=None): - return self._callback("getReady", callback=callback, errback=errback) + def dbus_getReady(self, ): + return self._callback("getReady", ) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='', out_signature='s', - async_callbacks=None) - def getVersion(self, ): + def dbus_getVersion(self, ): return self._callback("getVersion", ) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='a{ss}', - async_callbacks=None) - def getWaitingSub(self, profile_key="@DEFAULT@"): - return self._callback("getWaitingSub", str(profile_key)) + def dbus_getWaitingSub(self, profile_key="@DEFAULT@"): + return self._callback("getWaitingSub", profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ssiba{ss}s', out_signature='a(sdssa{ss}a{ss}ss)', - async_callbacks=('callback', 'errback')) - def historyGet(self, from_jid, to_jid, limit, between=True, filters='', profile="@NONE@", callback=None, errback=None): - return self._callback("historyGet", str(from_jid), str(to_jid), limit, between, filters, str(profile), callback=callback, errback=errback) + def dbus_historyGet(self, from_jid, to_jid, limit, between=True, filters='', profile="@NONE@"): + return self._callback("historyGet", from_jid, to_jid, limit, between, filters, profile) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='s', - async_callbacks=None) - def imageCheck(self, arg_0): - return self._callback("imageCheck", str(arg_0)) + def dbus_imageCheck(self, arg_0): + return self._callback("imageCheck", arg_0) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ssss', out_signature='s', - async_callbacks=('callback', 'errback')) - def imageConvert(self, source, dest, arg_2, extra, callback=None, errback=None): - return self._callback("imageConvert", str(source), str(dest), str(arg_2), str(extra), callback=callback, errback=errback) + def dbus_imageConvert(self, source, dest, arg_2, extra): + return self._callback("imageConvert", source, dest, arg_2, extra) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='s', - async_callbacks=('callback', 'errback')) - def imageGeneratePreview(self, image_path, profile_key, callback=None, errback=None): - return self._callback("imageGeneratePreview", str(image_path), str(profile_key), callback=callback, errback=errback) + def dbus_imageGeneratePreview(self, image_path, profile_key): + return self._callback("imageGeneratePreview", image_path, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sii', out_signature='s', - async_callbacks=('callback', 'errback')) - def imageResize(self, image_path, width, height, callback=None, errback=None): - return self._callback("imageResize", str(image_path), width, height, callback=callback, errback=errback) + def dbus_imageResize(self, image_path, width, height): + return self._callback("imageResize", image_path, width, height) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='b', - async_callbacks=None) - def isConnected(self, profile_key="@DEFAULT@"): - return self._callback("isConnected", str(profile_key)) + def dbus_isConnected(self, profile_key="@DEFAULT@"): + return self._callback("isConnected", profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sa{ss}s', out_signature='a{ss}', - async_callbacks=('callback', 'errback')) - def launchAction(self, callback_id, data, profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("launchAction", str(callback_id), data, str(profile_key), callback=callback, errback=errback) + def dbus_launchAction(self, callback_id, data, profile_key="@DEFAULT@"): + return self._callback("launchAction", callback_id, data, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='b', - async_callbacks=None) - def loadParamsTemplate(self, filename): - return self._callback("loadParamsTemplate", str(filename)) + def dbus_loadParamsTemplate(self, filename): + return self._callback("loadParamsTemplate", filename) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='s', - async_callbacks=None) - def menuHelpGet(self, menu_id, language): - return self._callback("menuHelpGet", str(menu_id), str(language)) + def dbus_menuHelpGet(self, menu_id, language): + return self._callback("menuHelpGet", menu_id, language) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sasa{ss}is', out_signature='a{ss}', - async_callbacks=('callback', 'errback')) - def menuLaunch(self, menu_type, path, data, security_limit, profile_key, callback=None, errback=None): - return self._callback("menuLaunch", str(menu_type), path, data, security_limit, str(profile_key), callback=callback, errback=errback) + def dbus_menuLaunch(self, menu_type, path, data, security_limit, profile_key): + return self._callback("menuLaunch", menu_type, path, data, security_limit, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='si', out_signature='a(ssasasa{ss})', - async_callbacks=None) - def menusGet(self, language, security_limit): - return self._callback("menusGet", str(language), security_limit) + def dbus_menusGet(self, language, security_limit): + return self._callback("menusGet", language, security_limit) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='s', - async_callbacks=None) - def messageEncryptionGet(self, to_jid, profile_key): - return self._callback("messageEncryptionGet", str(to_jid), str(profile_key)) + def dbus_messageEncryptionGet(self, to_jid, profile_key): + return self._callback("messageEncryptionGet", to_jid, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ssbs', out_signature='', - async_callbacks=('callback', 'errback')) - def messageEncryptionStart(self, to_jid, namespace='', replace=False, profile_key="@NONE@", callback=None, errback=None): - return self._callback("messageEncryptionStart", str(to_jid), str(namespace), replace, str(profile_key), callback=callback, errback=errback) + def dbus_messageEncryptionStart(self, to_jid, namespace='', replace=False, profile_key="@NONE@"): + return self._callback("messageEncryptionStart", to_jid, namespace, replace, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='', - async_callbacks=('callback', 'errback')) - def messageEncryptionStop(self, to_jid, profile_key, callback=None, errback=None): - return self._callback("messageEncryptionStop", str(to_jid), str(profile_key), callback=callback, errback=errback) + def dbus_messageEncryptionStop(self, to_jid, profile_key): + return self._callback("messageEncryptionStop", to_jid, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sa{ss}a{ss}sss', out_signature='', - async_callbacks=('callback', 'errback')) - def messageSend(self, to_jid, message, subject={}, mess_type="auto", extra={}, profile_key="@NONE@", callback=None, errback=None): - return self._callback("messageSend", str(to_jid), message, subject, str(mess_type), str(extra), str(profile_key), callback=callback, errback=errback) + def dbus_messageSend(self, to_jid, message, subject={}, mess_type="auto", extra={}, profile_key="@NONE@"): + return self._callback("messageSend", to_jid, message, subject, mess_type, extra, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='', out_signature='a{ss}', - async_callbacks=None) - def namespacesGet(self, ): + def dbus_namespacesGet(self, ): return self._callback("namespacesGet", ) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sis', out_signature='', - async_callbacks=None) - def paramsRegisterApp(self, xml, security_limit=-1, app=''): - return self._callback("paramsRegisterApp", str(xml), security_limit, str(app)) + def dbus_paramsRegisterApp(self, xml, security_limit=-1, app=''): + return self._callback("paramsRegisterApp", xml, security_limit, app) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sss', out_signature='', - async_callbacks=('callback', 'errback')) - def privateDataDelete(self, namespace, key, arg_2, callback=None, errback=None): - return self._callback("privateDataDelete", str(namespace), str(key), str(arg_2), callback=callback, errback=errback) + def dbus_privateDataDelete(self, namespace, key, arg_2): + return self._callback("privateDataDelete", namespace, key, arg_2) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sss', out_signature='s', - async_callbacks=('callback', 'errback')) - def privateDataGet(self, namespace, key, profile_key, callback=None, errback=None): - return self._callback("privateDataGet", str(namespace), str(key), str(profile_key), callback=callback, errback=errback) + def dbus_privateDataGet(self, namespace, key, profile_key): + return self._callback("privateDataGet", namespace, key, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ssss', out_signature='', - async_callbacks=('callback', 'errback')) - def privateDataSet(self, namespace, key, data, profile_key, callback=None, errback=None): - return self._callback("privateDataSet", str(namespace), str(key), str(data), str(profile_key), callback=callback, errback=errback) + def dbus_privateDataSet(self, namespace, key, data, profile_key): + return self._callback("privateDataSet", namespace, key, data, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sss', out_signature='', - async_callbacks=('callback', 'errback')) - def profileCreate(self, profile, password='', component='', callback=None, errback=None): - return self._callback("profileCreate", str(profile), str(password), str(component), callback=callback, errback=errback) + def dbus_profileCreate(self, profile, password='', component=''): + return self._callback("profileCreate", profile, password, component) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='b', - async_callbacks=None) - def profileIsSessionStarted(self, profile_key="@DEFAULT@"): - return self._callback("profileIsSessionStarted", str(profile_key)) + def dbus_profileIsSessionStarted(self, profile_key="@DEFAULT@"): + return self._callback("profileIsSessionStarted", profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='s', - async_callbacks=None) - def profileNameGet(self, profile_key="@DEFAULT@"): - return self._callback("profileNameGet", str(profile_key)) + def dbus_profileNameGet(self, profile_key="@DEFAULT@"): + return self._callback("profileNameGet", profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='', - async_callbacks=None) - def profileSetDefault(self, profile): - return self._callback("profileSetDefault", str(profile)) + def dbus_profileSetDefault(self, profile): + return self._callback("profileSetDefault", profile) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='b', - async_callbacks=('callback', 'errback')) - def profileStartSession(self, password='', profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("profileStartSession", str(password), str(profile_key), callback=callback, errback=errback) + def dbus_profileStartSession(self, password='', profile_key="@DEFAULT@"): + return self._callback("profileStartSession", password, profile_key) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='bb', out_signature='as', - async_callbacks=None) - def profilesListGet(self, clients=True, components=False): + def dbus_profilesListGet(self, clients=True, components=False): return self._callback("profilesListGet", clients, components) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ss', out_signature='a{ss}', - async_callbacks=None) - def progressGet(self, id, profile): - return self._callback("progressGet", str(id), str(profile)) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='a{sa{sa{ss}}}', - async_callbacks=None) - def progressGetAll(self, profile): - return self._callback("progressGetAll", str(profile)) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='a{sa{sa{ss}}}', - async_callbacks=None) - def progressGetAllMetadata(self, profile): - return self._callback("progressGetAllMetadata", str(profile)) + def dbus_progressGet(self, id, profile): + return self._callback("progressGet", id, profile) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='', - async_callbacks=('callback', 'errback')) - def rosterResync(self, profile_key="@DEFAULT@", callback=None, errback=None): - return self._callback("rosterResync", str(profile_key), callback=callback, errback=errback) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='b', - async_callbacks=None) - def saveParamsTemplate(self, filename): - return self._callback("saveParamsTemplate", str(filename)) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='s', out_signature='a{ss}', - async_callbacks=('callback', 'errback')) - def sessionInfosGet(self, profile_key, callback=None, errback=None): - return self._callback("sessionInfosGet", str(profile_key), callback=callback, errback=errback) + def dbus_progressGetAll(self, profile): + return self._callback("progressGetAll", profile) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sssis', out_signature='', - async_callbacks=None) - def setParam(self, name, value, category, security_limit=-1, profile_key="@DEFAULT@"): - return self._callback("setParam", str(name), str(value), str(category), security_limit, str(profile_key)) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ssa{ss}s', out_signature='', - async_callbacks=None) - def setPresence(self, to_jid='', show='', statuses={}, profile_key="@DEFAULT@"): - return self._callback("setPresence", str(to_jid), str(show), statuses, str(profile_key)) - - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='sss', out_signature='', - async_callbacks=None) - def subscription(self, sub_type, entity, profile_key="@DEFAULT@"): - return self._callback("subscription", str(sub_type), str(entity), str(profile_key)) + def dbus_progressGetAllMetadata(self, profile): + return self._callback("progressGetAllMetadata", profile) - @dbus.service.method(const_INT_PREFIX+const_CORE_SUFFIX, - in_signature='ssass', out_signature='', - async_callbacks=None) - def updateContact(self, entity_jid, name, groups, profile_key="@DEFAULT@"): - return self._callback("updateContact", str(entity_jid), str(name), groups, str(profile_key)) + def dbus_rosterResync(self, profile_key="@DEFAULT@"): + return self._callback("rosterResync", profile_key) - def __attributes(self, in_sign): - """Return arguments to user given a in_sign - @param in_sign: in_sign in the short form (using s,a,i,b etc) - @return: list of arguments that correspond to a in_sign (e.g.: "sss" return "arg1, arg2, arg3")""" - i = 0 - idx = 0 - attr = [] - while i < len(in_sign): - if in_sign[i] not in ["b", "y", "n", "i", "x", "q", "u", "t", "d", "s", "a"]: - raise ParseError("Unmanaged attribute type [%c]" % in_sign[i]) - - attr.append("arg_%i" % idx) - idx += 1 + def dbus_saveParamsTemplate(self, filename): + return self._callback("saveParamsTemplate", filename) - if in_sign[i] == "a": - i += 1 - if ( - in_sign[i] != "{" and in_sign[i] != "(" - ): # FIXME: must manage tuples out of arrays - i += 1 - continue # we have a simple type for the array - opening_car = in_sign[i] - assert opening_car in ["{", "("] - closing_car = "}" if opening_car == "{" else ")" - opening_count = 1 - while True: # we have a dict or a list of tuples - i += 1 - if i >= len(in_sign): - raise ParseError("missing }") - if in_sign[i] == opening_car: - opening_count += 1 - if in_sign[i] == closing_car: - opening_count -= 1 - if opening_count == 0: - break - i += 1 - return attr + def dbus_sessionInfosGet(self, profile_key): + return self._callback("sessionInfosGet", profile_key) - def addMethod(self, name, int_suffix, in_sign, out_sign, method, async_=False): - """Dynamically add a method to Dbus Bridge""" - inspect_args = inspect.getfullargspec(method) - - _arguments = inspect_args.args - _defaults = list(inspect_args.defaults or []) - - if inspect.ismethod(method): - # if we have a method, we don't want the first argument (usually 'self') - del (_arguments[0]) - - # first arguments are for the _callback method - arguments_callback = ", ".join( - [repr(name)] - + ( - (_arguments + ["callback=callback", "errback=errback"]) - if async_ - else _arguments - ) - ) - - if async_: - _arguments.extend(["callback", "errback"]) - _defaults.extend([None, None]) - - # now we create a second list with default values - for i in range(1, len(_defaults) + 1): - _arguments[-i] = "%s = %s" % (_arguments[-i], repr(_defaults[-i])) + def dbus_setParam(self, name, value, category, security_limit=-1, profile_key="@DEFAULT@"): + return self._callback("setParam", name, value, category, security_limit, profile_key) - arguments_defaults = ", ".join(_arguments) + def dbus_setPresence(self, to_jid='', show='', statuses={}, profile_key="@DEFAULT@"): + return self._callback("setPresence", to_jid, show, statuses, profile_key) - code = compile( - "def %(name)s (self,%(arguments_defaults)s): return self._callback(%(arguments_callback)s)" - % { - "name": name, - "arguments_defaults": arguments_defaults, - "arguments_callback": arguments_callback, - }, - "<DBus bridge>", - "exec", - ) - exec(code) # FIXME: to the same thing in a cleaner way, without compile/exec - method = locals()[name] - async_callbacks = ("callback", "errback") if async_ else None - setattr( - DbusObject, - name, - dbus.service.method( - const_INT_PREFIX + int_suffix, - in_signature=in_sign, - out_signature=out_sign, - async_callbacks=async_callbacks, - )(method), - ) - function = getattr(self, name) - func_table = self._dbus_class_table[ - self.__class__.__module__ + "." + self.__class__.__name__ - ][function._dbus_interface] - func_table[function.__name__] = function # Needed for introspection + def dbus_subscription(self, sub_type, entity, profile_key="@DEFAULT@"): + return self._callback("subscription", sub_type, entity, profile_key) - def addSignal(self, name, int_suffix, signature, doc={}): - """Dynamically add a signal to Dbus Bridge""" - attributes = ", ".join(self.__attributes(signature)) - # TODO: use doc parameter to name attributes - - # code = compile ('def '+name+' (self,'+attributes+'): log.debug ("'+name+' signal")', '<DBus bridge>','exec') #XXX: the log.debug is too annoying with xmllog - code = compile( - "def " + name + " (self," + attributes + "): pass", "<DBus bridge>", "exec" - ) - exec(code) - signal = locals()[name] - setattr( - DbusObject, - name, - dbus.service.signal(const_INT_PREFIX + int_suffix, signature=signature)( - signal - ), - ) - function = getattr(self, name) - func_table = self._dbus_class_table[ - self.__class__.__module__ + "." + self.__class__.__name__ - ][function._dbus_interface] - func_table[function.__name__] = function # Needed for introspection + def dbus_updateContact(self, entity_jid, name, groups, profile_key="@DEFAULT@"): + return self._callback("updateContact", entity_jid, name, groups, profile_key) -class Bridge(object): +class Bridge: + def __init__(self): - dbus.mainloop.glib.DBusGMainLoop(set_as_default=True) log.info("Init DBus...") + self._obj = DBusObject(const_OBJ_PATH) + + async def postInit(self): try: - self.session_bus = dbus.SessionBus() - except dbus.DBusException as e: - if e._dbus_error_name == "org.freedesktop.DBus.Error.NotSupported": + conn = await client.connect(reactor) + except error.DBusException as e: + if e.errName == "org.freedesktop.DBus.Error.NotSupported": log.error( _( - "D-Bus is not launched, please see README to see instructions on how to launch it" + "D-Bus is not launched, please see README to see instructions on " + "how to launch it" ) ) - raise BridgeInitError - self.dbus_name = dbus.service.BusName(const_INT_PREFIX, self.session_bus) - self.dbus_bridge = DbusObject(self.session_bus, const_OBJ_PATH) + raise BridgeInitError(str(e)) + + conn.exportObject(self._obj) + await conn.requestBusName(const_INT_PREFIX) def _debug(self, action, params, profile): - self.dbus_bridge._debug(action, params, profile) + self._obj.emitSignal("_debug", action, params, profile) def actionNew(self, action_data, id, security_limit, profile): - self.dbus_bridge.actionNew(action_data, id, security_limit, profile) + self._obj.emitSignal("actionNew", action_data, id, security_limit, profile) def connected(self, jid_s, profile): - self.dbus_bridge.connected(jid_s, profile) + self._obj.emitSignal("connected", jid_s, profile) def contactDeleted(self, entity_jid, profile): - self.dbus_bridge.contactDeleted(entity_jid, profile) + self._obj.emitSignal("contactDeleted", entity_jid, profile) def disconnected(self, profile): - self.dbus_bridge.disconnected(profile) + self._obj.emitSignal("disconnected", profile) def entityDataUpdated(self, jid, name, value, profile): - self.dbus_bridge.entityDataUpdated(jid, name, value, profile) + self._obj.emitSignal("entityDataUpdated", jid, name, value, profile) def messageEncryptionStarted(self, to_jid, encryption_data, profile_key): - self.dbus_bridge.messageEncryptionStarted(to_jid, encryption_data, profile_key) + self._obj.emitSignal("messageEncryptionStarted", to_jid, encryption_data, profile_key) def messageEncryptionStopped(self, to_jid, encryption_data, profile_key): - self.dbus_bridge.messageEncryptionStopped(to_jid, encryption_data, profile_key) + self._obj.emitSignal("messageEncryptionStopped", to_jid, encryption_data, profile_key) def messageNew(self, uid, timestamp, from_jid, to_jid, message, subject, mess_type, extra, profile): - self.dbus_bridge.messageNew(uid, timestamp, from_jid, to_jid, message, subject, mess_type, extra, profile) + self._obj.emitSignal("messageNew", uid, timestamp, from_jid, to_jid, message, subject, mess_type, extra, profile) def newContact(self, contact_jid, attributes, groups, profile): - self.dbus_bridge.newContact(contact_jid, attributes, groups, profile) + self._obj.emitSignal("newContact", contact_jid, attributes, groups, profile) def paramUpdate(self, name, value, category, profile): - self.dbus_bridge.paramUpdate(name, value, category, profile) + self._obj.emitSignal("paramUpdate", name, value, category, profile) def presenceUpdate(self, entity_jid, show, priority, statuses, profile): - self.dbus_bridge.presenceUpdate(entity_jid, show, priority, statuses, profile) + self._obj.emitSignal("presenceUpdate", entity_jid, show, priority, statuses, profile) def progressError(self, id, error, profile): - self.dbus_bridge.progressError(id, error, profile) + self._obj.emitSignal("progressError", id, error, profile) def progressFinished(self, id, metadata, profile): - self.dbus_bridge.progressFinished(id, metadata, profile) + self._obj.emitSignal("progressFinished", id, metadata, profile) def progressStarted(self, id, metadata, profile): - self.dbus_bridge.progressStarted(id, metadata, profile) + self._obj.emitSignal("progressStarted", id, metadata, profile) def subscribe(self, sub_type, entity_jid, profile): - self.dbus_bridge.subscribe(sub_type, entity_jid, profile) + self._obj.emitSignal("subscribe", sub_type, entity_jid, profile) def register_method(self, name, callback): - log.debug("registering DBus bridge method [%s]" % name) - self.dbus_bridge.register_method(name, callback) + log.debug(f"registering DBus bridge method [{name}]") + self._obj.register_method(name, callback) + + def emitSignal(self, name, *args): + self._obj.emitSignal(name, *args) - def addMethod(self, name, int_suffix, in_sign, out_sign, method, async_=False, doc={}): - """Dynamically add a method to Dbus Bridge""" + def addMethod( + self, name, int_suffix, in_sign, out_sign, method, async_=False, doc={} + ): + """Dynamically add a method to D-Bus Bridge""" # FIXME: doc parameter is kept only temporary, the time to remove it from calls - log.debug("Adding method [%s] to DBus bridge" % name) - self.dbus_bridge.addMethod(name, int_suffix, in_sign, out_sign, method, async_) + log.debug(f"Adding method {name!r} to D-Bus bridge") + self._obj.plugin_iface.addMethod( + Method(name, arguments=in_sign, returns=out_sign) + ) + # we have to create a method here instead of using partialmethod, because txdbus + # uses __func__ which doesn't work with partialmethod + def caller(self_, *args, **kwargs): + return self_._callback(name, *args, **kwargs) + setattr(self._obj, f"dbus_{name}", MethodType(caller, self._obj)) self.register_method(name, method) def addSignal(self, name, int_suffix, signature, doc={}): - self.dbus_bridge.addSignal(name, int_suffix, signature, doc) - setattr(Bridge, name, getattr(self.dbus_bridge, name)) \ No newline at end of file + """Dynamically add a signal to D-Bus Bridge""" + log.debug(f"Adding signal {name!r} to D-Bus bridge") + self._obj.plugin_iface.addSignal(Signal(name, signature)) + setattr(Bridge, name, partialmethod(Bridge.emitSignal, name)) \ No newline at end of file
--- a/sat/core/constants.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/core/constants.py Wed Sep 01 13:47:17 2021 +0200 @@ -216,6 +216,7 @@ PLUG_TYPE_EXP = "EXP" PLUG_TYPE_SEC = "SEC" PLUG_TYPE_SYNTAXE = "SYNTAXE" + PLUG_TYPE_PUBSUB = "PUBSUB" PLUG_TYPE_BLOG = "BLOG" PLUG_TYPE_IMPORT = "IMPORT" PLUG_TYPE_ENTRY_POINT = "ENTRY_POINT" @@ -237,9 +238,10 @@ PS_PUBLISH = "publish" PS_RETRACT = "retract" # used for items PS_DELETE = "delete" # used for nodes + PS_PURGE = "purge" # used for nodes PS_ITEM = "item" PS_ITEMS = "items" # Can contain publish and retract items - PS_EVENTS = (PS_ITEMS, PS_DELETE) + PS_EVENTS = (PS_ITEMS, PS_DELETE, PS_PURGE) ## MESSAGE/NOTIFICATION LEVELS ## @@ -366,6 +368,7 @@ ## Common extra keys/values ## KEY_ORDER_BY = "order_by" + KEY_USE_CACHE = "use_cache" ORDER_BY_CREATION = 'creation' ORDER_BY_MODIFICATION = 'modification'
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/core/core_types.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +# Libervia types +# Copyright (C) 2011 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + + +class SatXMPPEntity: + pass
--- a/sat/core/sat_main.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/core/sat_main.py Wed Sep 01 13:47:17 2021 +0200 @@ -58,6 +58,7 @@ log = getLogger(__name__) class SAT(service.Service): + def _init(self): # we don't use __init__ to avoid doule initialisation with twistd # this _init is called in startService @@ -80,7 +81,7 @@ self.memory = memory.Memory(self) - # trigger are used to change SàT behaviour + # trigger are used to change Libervia behaviour self.trigger = ( trigger.TriggerManager() ) @@ -92,14 +93,53 @@ bridge_module = dynamic_import.bridge(bridge_name) if bridge_module is None: - log.error("Can't find bridge module of name {}".format(bridge_name)) + log.error(f"Can't find bridge module of name {bridge_name}") sys.exit(1) - log.info("using {} bridge".format(bridge_name)) + log.info(f"using {bridge_name} bridge") try: self.bridge = bridge_module.Bridge() except exceptions.BridgeInitError: - log.error("Bridge can't be initialised, can't start SàT core") + log.error("Bridge can't be initialised, can't start Libervia Backend") sys.exit(1) + + defer.ensureDeferred(self._postInit()) + + @property + def version(self): + """Return the short version of Libervia""" + return C.APP_VERSION + + @property + def full_version(self): + """Return the full version of Libervia + + In developement mode, release name and extra data are returned too + """ + version = self.version + if version[-1] == "D": + # we are in debug version, we add extra data + try: + return self._version_cache + except AttributeError: + self._version_cache = "{} « {} » ({})".format( + version, C.APP_RELEASE_NAME, utils.getRepositoryData(sat) + ) + return self._version_cache + else: + return version + + @property + def bridge_name(self): + return os.path.splitext(os.path.basename(self.bridge.__file__))[0] + + async def _postInit(self): + try: + bridge_pi = self.bridge.postInit + except AttributeError: + pass + else: + await bridge_pi() + self.bridge.register_method("getReady", lambda: self.initialised) self.bridge.register_method("getVersion", lambda: self.full_version) self.bridge.register_method("getFeatures", self.getFeatures) @@ -179,38 +219,8 @@ self.bridge.register_method("imageGeneratePreview", self._imageGeneratePreview) self.bridge.register_method("imageConvert", self._imageConvert) - self.memory.initialized.addCallback(lambda __: defer.ensureDeferred(self._postMemoryInit())) - @property - def version(self): - """Return the short version of SàT""" - return C.APP_VERSION - - @property - def full_version(self): - """Return the full version of SàT - - In developement mode, release name and extra data are returned too - """ - version = self.version - if version[-1] == "D": - # we are in debug version, we add extra data - try: - return self._version_cache - except AttributeError: - self._version_cache = "{} « {} » ({})".format( - version, C.APP_RELEASE_NAME, utils.getRepositoryData(sat) - ) - return self._version_cache - else: - return version - - @property - def bridge_name(self): - return os.path.splitext(os.path.basename(self.bridge.__file__))[0] - - async def _postMemoryInit(self): - """Method called after memory initialization is done""" + await self.memory.initialise() self.common_cache = cache.Cache(self, None) log.info(_("Memory initialised")) try: @@ -449,7 +459,7 @@ except AttributeError: continue else: - defers_list.append(defer.maybeDeferred(unload)) + defers_list.append(utils.asDeferred(unload)) return defers_list def _connect(self, profile_key, password="", options=None): @@ -463,7 +473,7 @@ Retrieve the individual parameters, authenticate the profile and initiate the connection to the associated XMPP server. @param profile: %(doc_profile)s - @param password (string): the SàT profile password + @param password (string): the Libervia profile password @param options (dict): connection options. Key can be: - @param max_retries (int): max number of connection retries @@ -526,7 +536,7 @@ features = [] for import_name, plugin in self.plugins.items(): try: - features_d = defer.maybeDeferred(plugin.getFeatures, profile_key) + features_d = utils.asDeferred(plugin.getFeatures, profile_key) except AttributeError: features_d = defer.succeed({}) features.append(features_d) @@ -574,7 +584,7 @@ attr = client.roster.getAttributes(item) # we use full() and not userhost() because jid with resources are allowed # in roster, even if it's not common. - ret.append([item.entity.full(), attr, item.groups]) + ret.append([item.entity.full(), attr, list(item.groups)]) return ret return client.roster.got_roster.addCallback(got_roster) @@ -1101,6 +1111,7 @@ def _findByFeatures(self, namespaces, identities, bare_jids, service, roster, own_jid, local_device, profile_key): client = self.getClient(profile_key) + identities = [tuple(i) for i in identities] if identities else None return defer.ensureDeferred(self.findByFeatures( client, namespaces, identities, bare_jids, service, roster, own_jid, local_device))
--- a/sat/core/xmpp.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/core/xmpp.py Wed Sep 01 13:47:17 2021 +0200 @@ -42,6 +42,7 @@ from wokkel import delay from sat.core.log import getLogger from sat.core import exceptions +from sat.core import core_types from sat.memory import encryption from sat.memory import persistent from sat.tools import xml_tools @@ -83,7 +84,7 @@ return partial(getattr(self.plugin, attr), self.client) -class SatXMPPEntity: +class SatXMPPEntity(core_types.SatXMPPEntity): """Common code for Client and Component""" # profile is added there when startConnection begins and removed when it is finished profiles_connecting = set() @@ -714,7 +715,7 @@ or mess_data["type"] == C.MESS_TYPE_INFO ) - def messageAddToHistory(self, data): + async def messageAddToHistory(self, data): """Store message into database (for local history) @param data: message data dictionnary @@ -726,7 +727,7 @@ # we need a message to store if self.isMessagePrintable(data): - self.host_app.memory.addToHistory(self, data) + await self.host_app.memory.addToHistory(self, data) else: log.warning( "No message found" @@ -876,7 +877,9 @@ def addPostXmlCallbacks(self, post_xml_treatments): post_xml_treatments.addCallback(self.messageProt.completeAttachments) - post_xml_treatments.addCallback(self.messageAddToHistory) + post_xml_treatments.addCallback( + lambda ret: defer.ensureDeferred(self.messageAddToHistory(ret)) + ) post_xml_treatments.addCallback(self.messageSendToBridge) def send(self, obj): @@ -1061,7 +1064,9 @@ def addPostXmlCallbacks(self, post_xml_treatments): if self.sendHistory: - post_xml_treatments.addCallback(self.messageAddToHistory) + post_xml_treatments.addCallback( + lambda ret: defer.ensureDeferred(self.messageAddToHistory(ret)) + ) def getOwnerFromJid(self, to_jid: jid.JID) -> jid.JID: """Retrieve "owner" of a component resource from the destination jid of the request @@ -1212,7 +1217,9 @@ data = self.parseMessage(message_elt) post_treat.addCallback(self.completeAttachments) post_treat.addCallback(self.skipEmptyMessage) - post_treat.addCallback(self.addToHistory) + post_treat.addCallback( + lambda ret: defer.ensureDeferred(self.addToHistory(ret)) + ) post_treat.addCallback(self.bridgeSignal, data) post_treat.addErrback(self.cancelErrorTrap) post_treat.callback(data) @@ -1253,14 +1260,14 @@ raise failure.Failure(exceptions.CancelError("Cancelled empty message")) return data - def addToHistory(self, data): + async def addToHistory(self, data): if data.pop("history", None) == C.HISTORY_SKIP: log.debug("history is skipped as requested") data["extra"]["history"] = C.HISTORY_SKIP else: # we need a message to store if self.parent.isMessagePrintable(data): - return self.host.memory.addToHistory(self.parent, data) + return await self.host.memory.addToHistory(self.parent, data) else: log.debug("not storing empty message to history: {data}" .format(data=data)) @@ -1478,7 +1485,8 @@ self._jids[entity] = item self._registerItem(item) self.host.bridge.newContact( - entity.full(), self.getAttributes(item), item.groups, self.parent.profile + entity.full(), self.getAttributes(item), list(item.groups), + self.parent.profile ) def removeReceived(self, request): @@ -1544,7 +1552,7 @@ f"a JID is expected, not {type(entity_jid)}: {entity_jid!r}") return entity_jid in self._jids - def isPresenceAuthorised(self, entity_jid): + def isSubscribedFrom(self, entity_jid: jid.JID) -> bool: """Return True if entity is authorised to see our presence""" try: item = self._jids[entity_jid.userhostJID()] @@ -1552,6 +1560,14 @@ return False return item.subscriptionFrom + def isSubscribedTo(self, entity_jid: jid.JID) -> bool: + """Return True if we are subscribed to entity""" + try: + item = self._jids[entity_jid.userhostJID()] + except KeyError: + return False + return item.subscriptionTo + def getItems(self): """Return all items of the roster""" return list(self._jids.values())
--- a/sat/memory/memory.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/memory/memory.py Wed Sep 01 13:47:17 2021 +0200 @@ -33,7 +33,7 @@ from sat.core.log import getLogger from sat.core import exceptions from sat.core.constants import Const as C -from sat.memory.sqlite import SqliteStorage +from sat.memory.sqla import Storage from sat.memory.persistent import PersistentDict from sat.memory.params import Params from sat.memory.disco import Discovery @@ -223,12 +223,11 @@ ) -class Memory(object): +class Memory: """This class manage all the persistent information""" def __init__(self, host): log.info(_("Memory manager init")) - self.initialized = defer.Deferred() self.host = host self._entities_cache = {} # XXX: keep presence/last resource/other data in cache # /!\ an entity is not necessarily in roster @@ -240,19 +239,19 @@ self.disco = Discovery(host) self.config = tools_config.parseMainConf(log_filenames=True) self._cache_path = Path(self.getConfig("", "local_dir"), C.CACHE_DIR) - database_file = os.path.expanduser( - os.path.join(self.getConfig("", "local_dir"), C.SAVEFILE_DATABASE) - ) - self.storage = SqliteStorage(database_file, host.version) + + async def initialise(self): + self.storage = Storage() + await self.storage.initialise() PersistentDict.storage = self.storage - self.params = Params(host, self.storage) + self.params = Params(self.host, self.storage) log.info(_("Loading default params template")) self.params.load_default_params() - d = self.storage.initialized.addCallback(lambda ignore: self.load()) + await self.load() self.memory_data = PersistentDict("memory") - d.addCallback(lambda ignore: self.memory_data.load()) - d.addCallback(lambda ignore: self.disco.load()) - d.chainDeferred(self.initialized) + await self.memory_data.load() + await self.disco.load() + ## Configuration ## @@ -1129,12 +1128,12 @@ ) def asyncGetStringParamA( - self, name, category, attr="value", security_limit=C.NO_SECURITY_LIMIT, + self, name, category, attribute="value", security_limit=C.NO_SECURITY_LIMIT, profile_key=C.PROF_KEY_NONE): profile = self.getProfileName(profile_key) return defer.ensureDeferred(self.params.asyncGetStringParamA( - name, category, attr, security_limit, profile + name, category, attribute, security_limit, profile )) def _getParamsUI(self, security_limit, app, extra_s, profile_key): @@ -1168,20 +1167,22 @@ client = self.host.getClient(profile_key) # we accept any type data = data_format.deserialise(data_s, type_check=None) - return self.storage.setPrivateValue( - namespace, key, data, binary=True, profile=client.profile) + return defer.ensureDeferred(self.storage.setPrivateValue( + namespace, key, data, binary=True, profile=client.profile)) def _privateDataGet(self, namespace, key, profile_key): client = self.host.getClient(profile_key) - d = self.storage.getPrivates( - namespace, [key], binary=True, profile=client.profile) + d = defer.ensureDeferred( + self.storage.getPrivates( + namespace, [key], binary=True, profile=client.profile) + ) d.addCallback(lambda data_dict: data_format.serialise(data_dict.get(key))) return d def _privateDataDelete(self, namespace, key, profile_key): client = self.host.getClient(profile_key) - return self.storage.delPrivateValue( - namespace, key, binary=True, profile=client.profile) + return defer.ensureDeferred(self.storage.delPrivateValue( + namespace, key, binary=True, profile=client.profile)) ## Files ## @@ -1249,8 +1250,7 @@ _("unknown access type: {type}").format(type=perm_type) ) - @defer.inlineCallbacks - def checkPermissionToRoot(self, client, file_data, peer_jid, perms_to_check): + async def checkPermissionToRoot(self, client, file_data, peer_jid, perms_to_check): """do checkFilePermission on file_data and all its parents until root""" current = file_data while True: @@ -1258,7 +1258,7 @@ parent = current["parent"] if not parent: break - files_data = yield self.getFiles( + files_data = await self.getFiles( client, peer_jid=None, file_id=parent, perms_to_check=None ) try: @@ -1266,8 +1266,7 @@ except IndexError: raise exceptions.DataError("Missing parent") - @defer.inlineCallbacks - def _getParentDir( + async def _getParentDir( self, client, path, parent, namespace, owner, peer_jid, perms_to_check ): """Retrieve parent node from a path, or last existing directory @@ -1291,7 +1290,7 @@ # non existing directories will be created parent = "" for idx, path_elt in enumerate(path_elts): - directories = yield self.storage.getFiles( + directories = await self.storage.getFiles( client, parent=parent, type_=C.FILE_TYPE_DIRECTORY, @@ -1300,7 +1299,7 @@ owner=owner, ) if not directories: - defer.returnValue((parent, path_elts[idx:])) + return (parent, path_elts[idx:]) # from this point, directories don't exist anymore, we have to create them elif len(directories) > 1: raise exceptions.InternalError( @@ -1310,7 +1309,7 @@ directory = directories[0] self.checkFilePermission(directory, peer_jid, perms_to_check) parent = directory["id"] - defer.returnValue((parent, [])) + return (parent, []) def getFileAffiliations(self, file_data: dict) -> Dict[jid.JID, str]: """Convert file access to pubsub like affiliations""" @@ -1482,8 +1481,7 @@ ) return peer_jid.userhostJID() - @defer.inlineCallbacks - def getFiles( + async def getFiles( self, client, peer_jid, file_id=None, version=None, parent=None, path=None, type_=None, file_hash=None, hash_algo=None, name=None, namespace=None, mime_type=None, public_id=None, owner=None, access=None, projection=None, @@ -1534,7 +1532,7 @@ if path is not None: path = str(path) # permission are checked by _getParentDir - parent, remaining_path_elts = yield self._getParentDir( + parent, remaining_path_elts = await self._getParentDir( client, path, parent, namespace, owner, peer_jid, perms_to_check ) if remaining_path_elts: @@ -1544,16 +1542,16 @@ if parent and peer_jid: # if parent is given directly and permission check is requested, # we need to check all the parents - parent_data = yield self.storage.getFiles(client, file_id=parent) + parent_data = await self.storage.getFiles(client, file_id=parent) try: parent_data = parent_data[0] except IndexError: raise exceptions.DataError("mising parent") - yield self.checkPermissionToRoot( + await self.checkPermissionToRoot( client, parent_data, peer_jid, perms_to_check ) - files = yield self.storage.getFiles( + files = await self.storage.getFiles( client, file_id=file_id, version=version, @@ -1576,15 +1574,16 @@ to_remove = [] for file_data in files: try: - self.checkFilePermission(file_data, peer_jid, perms_to_check, set_affiliation=True) + self.checkFilePermission( + file_data, peer_jid, perms_to_check, set_affiliation=True + ) except exceptions.PermissionError: to_remove.append(file_data) for file_data in to_remove: files.remove(file_data) - defer.returnValue(files) + return files - @defer.inlineCallbacks - def setFile( + async def setFile( self, client, name, file_id=None, version="", parent=None, path=None, type_=C.FILE_TYPE_FILE, file_hash=None, hash_algo=None, size=None, namespace=None, mime_type=None, public_id=None, created=None, modified=None, @@ -1666,13 +1665,13 @@ if path is not None: path = str(path) # _getParentDir will check permissions if peer_jid is set, so we use owner - parent, remaining_path_elts = yield self._getParentDir( + parent, remaining_path_elts = await self._getParentDir( client, path, parent, namespace, owner, owner, perms_to_check ) # if remaining directories don't exist, we have to create them for new_dir in remaining_path_elts: new_dir_id = shortuuid.uuid() - yield self.storage.setFile( + await self.storage.setFile( client, name=new_dir, file_id=new_dir_id, @@ -1689,7 +1688,7 @@ elif parent is None: parent = "" - yield self.storage.setFile( + await self.storage.setFile( client, file_id=file_id, version=version,
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/memory/migration/README Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,3 @@ +This directory and subdirectories contains Alembic migration scripts. + +Please check Libervia documentation for details.
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/memory/migration/alembic.ini Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,89 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = %(here)s + +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +# prepend_sys_path = . + +# timezone to use when rendering the date +# within the migration file as well as the filename. +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; this defaults +# to migration/versions. When using multiple version +# directories, initial revisions must be specified with --version-path +# version_locations = %(here)s/bar %(here)s/bat migration/versions + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/memory/migration/env.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,85 @@ +import asyncio +from logging.config import fileConfig +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import create_async_engine +from alembic import context +from sat.memory import sqla_config +from sat.memory.sqla_mapping import Base + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + db_config = sqla_config.getDbConfig() + context.configure( + url=db_config["url"], + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection): + context.configure( + connection=connection, + target_metadata=target_metadata, + render_as_batch=True + ) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + db_config = sqla_config.getDbConfig() + engine = create_async_engine( + db_config["url"], + poolclass=pool.NullPool, + future=True, + ) + + async with engine.connect() as connection: + await connection.run_sync(do_run_migrations) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + asyncio.run(run_migrations_online())
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/memory/migration/script.py.mako Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"}
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/memory/migration/versions/602caf848068_drop_message_types_table_fix_nullable.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,410 @@ +"""drop message_types table + fix nullable + rename constraints + +Revision ID: 602caf848068 +Revises: +Create Date: 2021-06-26 12:42:54.148313 + +""" +from alembic import op +from sqlalchemy import ( + Table, + Column, + MetaData, + TEXT, + INTEGER, + Text, + Integer, + Float, + Enum, + ForeignKey, + Index, + PrimaryKeyConstraint, +) +from sqlalchemy.sql import table, column + + +# revision identifiers, used by Alembic. +revision = "602caf848068" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # we have to recreate former tables for batch_alter_table's reflexion, otherwise the + # database will be used, and this will keep unammed UNIQUE constraints in addition + # to the named ones that we create + metadata = MetaData( + naming_convention={ + "ix": "ix_%(column_0_label)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", + }, + ) + + old_profiles_table = Table( + "profiles", + metadata, + Column("id", Integer, primary_key=True, nullable=True, autoincrement=False), + Column("name", Text, unique=True), + ) + + old_components_table = Table( + "components", + metadata, + Column( + "profile_id", + ForeignKey("profiles.id", ondelete="CASCADE"), + nullable=True, + primary_key=True, + ), + Column("entry_point", Text, nullable=False), + ) + + old_message_table = Table( + "message", + metadata, + Column("id", Integer, primary_key=True, nullable=True, autoincrement=False), + Column("history_uid", ForeignKey("history.uid", ondelete="CASCADE")), + Column("message", Text), + Column("language", Text), + Index("message__history_uid", "history_uid"), + ) + + old_subject_table = Table( + "subject", + metadata, + Column("id", Integer, primary_key=True, nullable=True, autoincrement=False), + Column("history_uid", ForeignKey("history.uid", ondelete="CASCADE")), + Column("subject", Text), + Column("language", Text), + Index("subject__history_uid", "history_uid"), + ) + + old_thread_table = Table( + "thread", + metadata, + Column("id", Integer, primary_key=True, nullable=True, autoincrement=False), + Column("history_uid", ForeignKey("history.uid", ondelete="CASCADE")), + Column("thread_id", Text), + Column("parent_id", Text), + Index("thread__history_uid", "history_uid"), + ) + + old_history_table = Table( + "history", + metadata, + Column("uid", Text, primary_key=True, nullable=True), + Column("stanza_id", Text), + Column("update_uid", Text), + Column("profile_id", Integer, ForeignKey("profiles.id", ondelete="CASCADE")), + Column("source", Text), + Column("dest", Text), + Column("source_res", Text), + Column("dest_res", Text), + Column("timestamp", Float, nullable=False), + Column("received_timestamp", Float), + Column("type", Text, ForeignKey("message_types.type")), + Column("extra", Text), + Index("history__profile_id_timestamp", "profile_id", "timestamp"), + Index( + "history__profile_id_received_timestamp", "profile_id", "received_timestamp" + ), + ) + + old_param_gen_table = Table( + "param_gen", + metadata, + Column("category", Text, primary_key=True), + Column("name", Text, primary_key=True), + Column("value", Text), + ) + + old_param_ind_table = Table( + "param_ind", + metadata, + Column("category", Text, primary_key=True), + Column("name", Text, primary_key=True), + Column( + "profile_id", ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True + ), + Column("value", Text), + ) + + old_private_gen_table = Table( + "private_gen", + metadata, + Column("namespace", Text, primary_key=True), + Column("key", Text, primary_key=True), + Column("value", Text), + ) + + old_private_ind_table = Table( + "private_ind", + metadata, + Column("namespace", Text, primary_key=True), + Column("key", Text, primary_key=True), + Column( + "profile_id", ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True + ), + Column("value", Text), + ) + + old_private_gen_bin_table = Table( + "private_gen_bin", + metadata, + Column("namespace", Text, primary_key=True), + Column("key", Text, primary_key=True), + Column("value", Text), + ) + + old_private_ind_bin_table = Table( + "private_ind_bin", + metadata, + Column("namespace", Text, primary_key=True), + Column("key", Text, primary_key=True), + Column( + "profile_id", ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True + ), + Column("value", Text), + ) + + old_files_table = Table( + "files", + metadata, + Column("id", Text, primary_key=True), + Column("public_id", Text, unique=True), + Column("version", Text, primary_key=True), + Column("parent", Text, nullable=False), + Column( + "type", + Enum("file", "directory", name="file_type", create_constraint=True), + nullable=False, + server_default="file", + ), + Column("file_hash", Text), + Column("hash_algo", Text), + Column("name", Text, nullable=False), + Column("size", Integer), + Column("namespace", Text), + Column("media_type", Text), + Column("media_subtype", Text), + Column("created", Float, nullable=False), + Column("modified", Float), + Column("owner", Text), + Column("access", Text), + Column("extra", Text), + Column("profile_id", ForeignKey("profiles.id", ondelete="CASCADE")), + Index("files__profile_id_owner_parent", "profile_id", "owner", "parent"), + Index( + "files__profile_id_owner_media_type_media_subtype", + "profile_id", + "owner", + "media_type", + "media_subtype", + ), + ) + + op.drop_table("message_types") + + with op.batch_alter_table( + "profiles", copy_from=old_profiles_table, schema=None + ) as batch_op: + batch_op.create_unique_constraint(batch_op.f("uq_profiles_name"), ["name"]) + + with op.batch_alter_table( + "components", + copy_from=old_components_table, + naming_convention={ + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + }, + schema=None, + ) as batch_op: + batch_op.create_unique_constraint(batch_op.f("uq_profiles_name"), ["name"]) + + with op.batch_alter_table( + "history", + copy_from=old_history_table, + naming_convention={ + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + }, + schema=None, + ) as batch_op: + batch_op.alter_column("uid", existing_type=TEXT(), nullable=False) + batch_op.alter_column( + "type", + type_=Enum( + "chat", + "error", + "groupchat", + "headline", + "normal", + "info", + name="message_type", + create_constraint=True, + ), + existing_type=TEXT(), + nullable=False, + ) + batch_op.create_unique_constraint( + batch_op.f("uq_history_profile_id"), + ["profile_id", "stanza_id", "source", "dest"], + ) + batch_op.drop_constraint("fk_history_type_message_types", type_="foreignkey") + + with op.batch_alter_table( + "message", copy_from=old_message_table, schema=None + ) as batch_op: + batch_op.alter_column( + "id", existing_type=INTEGER(), nullable=False, autoincrement=False + ) + + with op.batch_alter_table( + "subject", copy_from=old_subject_table, schema=None + ) as batch_op: + batch_op.alter_column( + "id", existing_type=INTEGER(), nullable=False, autoincrement=False + ) + + with op.batch_alter_table( + "thread", copy_from=old_thread_table, schema=None + ) as batch_op: + batch_op.alter_column( + "id", existing_type=INTEGER(), nullable=False, autoincrement=False + ) + + with op.batch_alter_table( + "param_gen", copy_from=old_param_gen_table, schema=None + ) as batch_op: + batch_op.alter_column("category", existing_type=TEXT(), nullable=False) + batch_op.alter_column("name", existing_type=TEXT(), nullable=False) + + with op.batch_alter_table( + "param_ind", copy_from=old_param_ind_table, schema=None + ) as batch_op: + batch_op.alter_column("category", existing_type=TEXT(), nullable=False) + batch_op.alter_column("name", existing_type=TEXT(), nullable=False) + batch_op.alter_column("profile_id", existing_type=INTEGER(), nullable=False) + + with op.batch_alter_table( + "private_gen", copy_from=old_private_gen_table, schema=None + ) as batch_op: + batch_op.alter_column("namespace", existing_type=TEXT(), nullable=False) + batch_op.alter_column("key", existing_type=TEXT(), nullable=False) + + with op.batch_alter_table( + "private_ind", copy_from=old_private_ind_table, schema=None + ) as batch_op: + batch_op.alter_column("namespace", existing_type=TEXT(), nullable=False) + batch_op.alter_column("key", existing_type=TEXT(), nullable=False) + batch_op.alter_column("profile_id", existing_type=INTEGER(), nullable=False) + + with op.batch_alter_table( + "private_gen_bin", copy_from=old_private_gen_bin_table, schema=None + ) as batch_op: + batch_op.alter_column("namespace", existing_type=TEXT(), nullable=False) + batch_op.alter_column("key", existing_type=TEXT(), nullable=False) + + # found some invalid rows in local database, maybe old values made during development, + # but in doubt we have to delete them + op.execute("DELETE FROM private_ind_bin WHERE namespace IS NULL") + + with op.batch_alter_table( + "private_ind_bin", copy_from=old_private_ind_bin_table, schema=None + ) as batch_op: + batch_op.alter_column("namespace", existing_type=TEXT(), nullable=False) + batch_op.alter_column("key", existing_type=TEXT(), nullable=False) + batch_op.alter_column("profile_id", existing_type=INTEGER(), nullable=False) + + with op.batch_alter_table( + "files", copy_from=old_files_table, schema=None + ) as batch_op: + batch_op.create_unique_constraint(batch_op.f("uq_files_public_id"), ["public_id"]) + batch_op.alter_column( + "type", + type_=Enum("file", "directory", name="file_type", create_constraint=True), + existing_type=Text(), + nullable=False, + ) + + +def downgrade(): + # downgrade doesn't restore the exact same state as before upgrade, as it + # would be useless and waste of resource to restore broken things such as + # anonymous constraints + with op.batch_alter_table("thread", schema=None) as batch_op: + batch_op.alter_column( + "id", existing_type=INTEGER(), nullable=True, autoincrement=False + ) + + with op.batch_alter_table("subject", schema=None) as batch_op: + batch_op.alter_column( + "id", existing_type=INTEGER(), nullable=True, autoincrement=False + ) + + with op.batch_alter_table("private_ind_bin", schema=None) as batch_op: + batch_op.alter_column("profile_id", existing_type=INTEGER(), nullable=True) + batch_op.alter_column("key", existing_type=TEXT(), nullable=True) + batch_op.alter_column("namespace", existing_type=TEXT(), nullable=True) + + with op.batch_alter_table("private_ind", schema=None) as batch_op: + batch_op.alter_column("profile_id", existing_type=INTEGER(), nullable=True) + batch_op.alter_column("key", existing_type=TEXT(), nullable=True) + batch_op.alter_column("namespace", existing_type=TEXT(), nullable=True) + + with op.batch_alter_table("private_gen_bin", schema=None) as batch_op: + batch_op.alter_column("key", existing_type=TEXT(), nullable=True) + batch_op.alter_column("namespace", existing_type=TEXT(), nullable=True) + + with op.batch_alter_table("private_gen", schema=None) as batch_op: + batch_op.alter_column("key", existing_type=TEXT(), nullable=True) + batch_op.alter_column("namespace", existing_type=TEXT(), nullable=True) + + with op.batch_alter_table("param_ind", schema=None) as batch_op: + batch_op.alter_column("profile_id", existing_type=INTEGER(), nullable=True) + batch_op.alter_column("name", existing_type=TEXT(), nullable=True) + batch_op.alter_column("category", existing_type=TEXT(), nullable=True) + + with op.batch_alter_table("param_gen", schema=None) as batch_op: + batch_op.alter_column("name", existing_type=TEXT(), nullable=True) + batch_op.alter_column("category", existing_type=TEXT(), nullable=True) + + with op.batch_alter_table("message", schema=None) as batch_op: + batch_op.alter_column( + "id", existing_type=INTEGER(), nullable=True, autoincrement=False + ) + + op.create_table( + "message_types", + Column("type", TEXT(), nullable=True), + PrimaryKeyConstraint("type"), + ) + message_types_table = table("message_types", column("type", TEXT())) + op.bulk_insert( + message_types_table, + [ + {"type": "chat"}, + {"type": "error"}, + {"type": "groupchat"}, + {"type": "headline"}, + {"type": "normal"}, + {"type": "info"}, + ], + ) + + with op.batch_alter_table("history", schema=None) as batch_op: + batch_op.alter_column( + "type", + type_=TEXT(), + existing_type=TEXT(), + nullable=True, + ) + batch_op.create_foreign_key( + batch_op.f("fk_history_type_message_types"), + "message_types", + ["type"], + ["type"], + ) + batch_op.alter_column("uid", existing_type=TEXT(), nullable=True)
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/memory/migration/versions/8974efc51d22_create_tables_for_pubsub_caching.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,57 @@ +"""create tables for Pubsub caching + +Revision ID: 8974efc51d22 +Revises: 602caf848068 +Create Date: 2021-07-27 16:38:54.658212 + +""" +from alembic import op +import sqlalchemy as sa +from sat.memory.sqla_mapping import JID, Xml + + +# revision identifiers, used by Alembic. +revision = '8974efc51d22' +down_revision = '602caf848068' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('pubsub_nodes', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('profile_id', sa.Integer(), nullable=True), + sa.Column('service', JID(), nullable=True), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('subscribed', sa.Boolean(create_constraint=True, name='subscribed_bool'), nullable=False), + sa.Column('analyser', sa.Text(), nullable=True), + sa.Column('sync_state', sa.Enum('IN_PROGRESS', 'COMPLETED', 'ERROR', 'NO_SYNC', name='sync_state', create_constraint=True), nullable=True), + sa.Column('sync_state_updated', sa.Float(), nullable=False), + sa.Column('type', sa.Text(), nullable=True), + sa.Column('subtype', sa.Text(), nullable=True), + sa.Column('extra', sa.JSON(), nullable=True), + sa.ForeignKeyConstraint(['profile_id'], ['profiles.id'], name=op.f('fk_pubsub_nodes_profile_id_profiles'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_pubsub_nodes')), + sa.UniqueConstraint('profile_id', 'service', 'name', name=op.f('uq_pubsub_nodes_profile_id')) + ) + op.create_table('pubsub_items', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('node_id', sa.Integer(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('data', Xml(), nullable=False), + sa.Column('created', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + sa.Column('updated', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False), + sa.Column('parsed', sa.JSON(), nullable=True), + sa.ForeignKeyConstraint(['node_id'], ['pubsub_nodes.id'], name=op.f('fk_pubsub_items_node_id_pubsub_nodes'), ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id', name=op.f('pk_pubsub_items')), + sa.UniqueConstraint('node_id', 'name', name=op.f('uq_pubsub_items_node_id')) + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('pubsub_items') + op.drop_table('pubsub_nodes') + # ### end Alembic commands ###
--- a/sat/memory/persistent.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/memory/persistent.py Wed Sep 01 13:47:17 2021 +0200 @@ -17,10 +17,13 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +from twisted.internet import defer +from twisted.python import failure from sat.core.i18n import _ from sat.core.log import getLogger + + log = getLogger(__name__) -from twisted.python import failure class MemoryNotInitializedError(Exception): @@ -57,7 +60,9 @@ need to be called before any other operation @return: defers the PersistentDict instance itself """ - d = self.storage.getPrivates(self.namespace, binary=self.binary, profile=self.profile) + d = defer.ensureDeferred(self.storage.getPrivates( + self.namespace, binary=self.binary, profile=self.profile + )) d.addCallback(self._setCache) d.addCallback(lambda __: self) return d @@ -111,8 +116,11 @@ return self._cache.__getitem__(key) def __setitem__(self, key, value): - self.storage.setPrivateValue(self.namespace, key, value, self.binary, - self.profile) + defer.ensureDeferred( + self.storage.setPrivateValue( + self.namespace, key, value, self.binary, self.profile + ) + ) return self._cache.__setitem__(key, value) def __delitem__(self, key): @@ -130,8 +138,11 @@ def aset(self, key, value): """Async set, return a Deferred fired when value is actually stored""" self._cache.__setitem__(key, value) - return self.storage.setPrivateValue(self.namespace, key, value, - self.binary, self.profile) + return defer.ensureDeferred( + self.storage.setPrivateValue( + self.namespace, key, value, self.binary, self.profile + ) + ) def adel(self, key): """Async del, return a Deferred fired when value is actually deleted""" @@ -151,8 +162,11 @@ @return: deferred fired when data is actually saved """ - return self.storage.setPrivateValue(self.namespace, name, self._cache[name], - self.binary, self.profile) + return defer.ensureDeferred( + self.storage.setPrivateValue( + self.namespace, name, self._cache[name], self.binary, self.profile + ) + ) class PersistentBinaryDict(PersistentDict): @@ -178,12 +192,16 @@ raise NotImplementedError def items(self): - d = self.storage.getPrivates(self.namespace, binary=self.binary, profile=self.profile) + d = defer.ensureDeferred(self.storage.getPrivates( + self.namespace, binary=self.binary, profile=self.profile + )) d.addCallback(lambda data_dict: data_dict.items()) return d def all(self): - return self.storage.getPrivates(self.namespace, binary=self.binary, profile=self.profile) + return defer.ensureDeferred(self.storage.getPrivates( + self.namespace, binary=self.binary, profile=self.profile + )) def __repr__(self): raise NotImplementedError @@ -234,14 +252,18 @@ def __getitem__(self, key): """get the value as a Deferred""" - d = self.storage.getPrivates(self.namespace, keys=[key], binary=self.binary, - profile=self.profile) + d = defer.ensureDeferred(self.storage.getPrivates( + self.namespace, keys=[key], binary=self.binary, profile=self.profile + )) d.addCallback(self._data2value, key) return d def __setitem__(self, key, value): - self.storage.setPrivateValue(self.namespace, key, value, self.binary, - self.profile) + defer.ensureDeferred( + self.storage.setPrivateValue( + self.namespace, key, value, self.binary, self.profile + ) + ) def __delitem__(self, key): self.storage.delPrivateValue(self.namespace, key, self.binary, self.profile) @@ -259,8 +281,11 @@ """Async set, return a Deferred fired when value is actually stored""" # FIXME: redundant with force, force must be removed # XXX: similar as PersistentDict.aset, but doesn't use cache - return self.storage.setPrivateValue(self.namespace, key, value, - self.binary, self.profile) + return defer.ensureDeferred( + self.storage.setPrivateValue( + self.namespace, key, value, self.binary, self.profile + ) + ) def adel(self, key): """Async del, return a Deferred fired when value is actually deleted""" @@ -277,7 +302,11 @@ @param value(object): value is needed for LazyPersistentBinaryDict @return: deferred fired when data is actually saved """ - return self.storage.setPrivateValue(self.namespace, name, value, self.binary, self.profile) + return defer.ensureDeferred( + self.storage.setPrivateValue( + self.namespace, name, value, self.binary, self.profile + ) + ) def remove(self, key): """Delete a key from sotrage, and return a deferred called when it's done
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/memory/sqla.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,1357 @@ +#!/usr/bin/env python3 + +# Libervia: an XMPP client +# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import sys +import time +import asyncio +from datetime import datetime +from asyncio.subprocess import PIPE +from pathlib import Path +from typing import Dict, List, Tuple, Iterable, Any, Callable, Optional +from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine +from sqlalchemy.exc import IntegrityError, NoResultFound +from sqlalchemy.orm import ( + sessionmaker, subqueryload, joinedload, contains_eager # , aliased +) +from sqlalchemy.orm.decl_api import DeclarativeMeta +from sqlalchemy.future import select +from sqlalchemy.engine import Engine, Connection +from sqlalchemy import update, delete, and_, or_, event, func +from sqlalchemy.sql.functions import coalesce, sum as sum_, now, count +from sqlalchemy.dialects.sqlite import insert +from alembic import script as al_script, config as al_config +from alembic.runtime import migration as al_migration +from twisted.internet import defer +from twisted.words.protocols.jabber import jid +from twisted.words.xish import domish +from sat.core.i18n import _ +from sat.core import exceptions +from sat.core.log import getLogger +from sat.core.constants import Const as C +from sat.core.core_types import SatXMPPEntity +from sat.tools.utils import aio +from sat.tools.common import uri +from sat.memory import migration +from sat.memory import sqla_config +from sat.memory.sqla_mapping import ( + NOT_IN_EXTRA, + SyncState, + Base, + Profile, + Component, + History, + Message, + Subject, + Thread, + ParamGen, + ParamInd, + PrivateGen, + PrivateInd, + PrivateGenBin, + PrivateIndBin, + File, + PubsubNode, + PubsubItem, +) + + +log = getLogger(__name__) +migration_path = Path(migration.__file__).parent + + +@event.listens_for(Engine, "connect") +def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + +class Storage: + + def __init__(self): + self.initialized = defer.Deferred() + # we keep cache for the profiles (key: profile name, value: profile id) + # profile id to name + self.profiles: Dict[int, str] = {} + # profile id to component entry point + self.components: Dict[int, str] = {} + + async def migrateApply(self, *args: str, log_output: bool = False) -> None: + """Do a migration command + + Commands are applied by running Alembic in a subprocess. + Arguments are alembic executables commands + + @param log_output: manage stdout and stderr: + - if False, stdout and stderr are buffered, and logged only in case of error + - if True, stdout and stderr will be logged during the command execution + @raise exceptions.DatabaseError: something went wrong while running the + process + """ + stdout, stderr = 2 * (None,) if log_output else 2 * (PIPE,) + proc = await asyncio.create_subprocess_exec( + sys.executable, "-m", "alembic", *args, + stdout=stdout, stderr=stderr, cwd=migration_path + ) + log_out, log_err = await proc.communicate() + if proc.returncode != 0: + msg = _( + "Can't {operation} database (exit code {exit_code})" + ).format( + operation=args[0], + exit_code=proc.returncode + ) + if log_out or log_err: + msg += f":\nstdout: {log_out.decode()}\nstderr: {log_err.decode()}" + log.error(msg) + + raise exceptions.DatabaseError(msg) + + async def createDB(self, engine: AsyncEngine, db_config: dict) -> None: + """Create a new database + + The database is generated from SQLAlchemy model, then stamped by Alembic + """ + # the dir may not exist if it's not the XDG recommended one + db_config["path"].parent.mkdir(0o700, True, True) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + log.debug("stamping the database") + await self.migrateApply("stamp", "head") + log.debug("stamping done") + + def _checkDBIsUpToDate(self, conn: Connection) -> bool: + al_ini_path = migration_path / "alembic.ini" + al_cfg = al_config.Config(al_ini_path) + directory = al_script.ScriptDirectory.from_config(al_cfg) + context = al_migration.MigrationContext.configure(conn) + return set(context.get_current_heads()) == set(directory.get_heads()) + + async def checkAndUpdateDB(self, engine: AsyncEngine, db_config: dict) -> None: + """Check that database is up-to-date, and update if necessary""" + async with engine.connect() as conn: + up_to_date = await conn.run_sync(self._checkDBIsUpToDate) + if up_to_date: + log.debug("Database is up-to-date") + else: + log.info("Database needs to be updated") + log.info("updating…") + await self.migrateApply("upgrade", "head", log_output=True) + log.info("Database is now up-to-date") + + @aio + async def initialise(self) -> None: + log.info(_("Connecting database")) + + db_config = sqla_config.getDbConfig() + engine = create_async_engine( + db_config["url"], + future=True, + ) + + new_base = not db_config["path"].exists() + if new_base: + log.info(_("The database is new, creating the tables")) + await self.createDB(engine, db_config) + else: + await self.checkAndUpdateDB(engine, db_config) + + self.session = sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession + ) + + async with self.session() as session: + result = await session.execute(select(Profile)) + for p in result.scalars(): + self.profiles[p.name] = p.id + result = await session.execute(select(Component)) + for c in result.scalars(): + self.components[c.profile_id] = c.entry_point + + self.initialized.callback(None) + + ## Generic + + @aio + async def add(self, db_obj: DeclarativeMeta) -> None: + """Add an object to database""" + async with self.session() as session: + async with session.begin(): + session.add(db_obj) + + @aio + async def delete(self, db_obj: DeclarativeMeta) -> None: + """Delete an object from database""" + async with self.session() as session: + async with session.begin(): + await session.delete(db_obj) + await session.commit() + + ## Profiles + + def getProfilesList(self) -> List[str]: + """"Return list of all registered profiles""" + return list(self.profiles.keys()) + + def hasProfile(self, profile_name: str) -> bool: + """return True if profile_name exists + + @param profile_name: name of the profile to check + """ + return profile_name in self.profiles + + def profileIsComponent(self, profile_name: str) -> bool: + try: + return self.profiles[profile_name] in self.components + except KeyError: + raise exceptions.NotFound("the requested profile doesn't exists") + + def getEntryPoint(self, profile_name: str) -> str: + try: + return self.components[self.profiles[profile_name]] + except KeyError: + raise exceptions.NotFound("the requested profile doesn't exists or is not a component") + + @aio + async def createProfile(self, name: str, component_ep: Optional[str] = None) -> None: + """Create a new profile + + @param name: name of the profile + @param component: if not None, must point to a component entry point + """ + async with self.session() as session: + profile = Profile(name=name) + async with session.begin(): + session.add(profile) + self.profiles[profile.id] = profile.name + if component_ep is not None: + async with session.begin(): + component = Component(profile=profile, entry_point=component_ep) + session.add(component) + self.components[profile.id] = component_ep + return profile + + @aio + async def deleteProfile(self, name: str) -> None: + """Delete profile + + @param name: name of the profile + """ + async with self.session() as session: + result = await session.execute(select(Profile).where(Profile.name == name)) + profile = result.scalar() + await session.delete(profile) + await session.commit() + del self.profiles[profile.id] + if profile.id in self.components: + del self.components[profile.id] + log.info(_("Profile {name!r} deleted").format(name = name)) + + ## Params + + @aio + async def loadGenParams(self, params_gen: dict) -> None: + """Load general parameters + + @param params_gen: dictionary to fill + """ + log.debug(_("loading general parameters from database")) + async with self.session() as session: + result = await session.execute(select(ParamGen)) + for p in result.scalars(): + params_gen[(p.category, p.name)] = p.value + + @aio + async def loadIndParams(self, params_ind: dict, profile: str) -> None: + """Load individual parameters + + @param params_ind: dictionary to fill + @param profile: a profile which *must* exist + """ + log.debug(_("loading individual parameters from database")) + async with self.session() as session: + result = await session.execute( + select(ParamInd).where(ParamInd.profile_id == self.profiles[profile]) + ) + for p in result.scalars(): + params_ind[(p.category, p.name)] = p.value + + @aio + async def getIndParam(self, category: str, name: str, profile: str) -> Optional[str]: + """Ask database for the value of one specific individual parameter + + @param category: category of the parameter + @param name: name of the parameter + @param profile: %(doc_profile)s + """ + async with self.session() as session: + result = await session.execute( + select(ParamInd.value) + .filter_by( + category=category, + name=name, + profile_id=self.profiles[profile] + ) + ) + return result.scalar_one_or_none() + + @aio + async def getIndParamValues(self, category: str, name: str) -> Dict[str, str]: + """Ask database for the individual values of a parameter for all profiles + + @param category: category of the parameter + @param name: name of the parameter + @return dict: profile => value map + """ + async with self.session() as session: + result = await session.execute( + select(ParamInd) + .filter_by( + category=category, + name=name + ) + .options(subqueryload(ParamInd.profile)) + ) + return {param.profile.name: param.value for param in result.scalars()} + + @aio + async def setGenParam(self, category: str, name: str, value: Optional[str]) -> None: + """Save the general parameters in database + + @param category: category of the parameter + @param name: name of the parameter + @param value: value to set + """ + async with self.session() as session: + stmt = insert(ParamGen).values( + category=category, + name=name, + value=value + ).on_conflict_do_update( + index_elements=(ParamGen.category, ParamGen.name), + set_={ + ParamGen.value: value + } + ) + await session.execute(stmt) + await session.commit() + + @aio + async def setIndParam( + self, + category:str, + name: str, + value: Optional[str], + profile: str + ) -> None: + """Save the individual parameters in database + + @param category: category of the parameter + @param name: name of the parameter + @param value: value to set + @param profile: a profile which *must* exist + """ + async with self.session() as session: + stmt = insert(ParamInd).values( + category=category, + name=name, + profile_id=self.profiles[profile], + value=value + ).on_conflict_do_update( + index_elements=(ParamInd.category, ParamInd.name, ParamInd.profile_id), + set_={ + ParamInd.value: value + } + ) + await session.execute(stmt) + await session.commit() + + def _jid_filter(self, jid_: jid.JID, dest: bool = False): + """Generate condition to filter on a JID, using relevant columns + + @param dest: True if it's the destinee JID, otherwise it's the source one + @param jid_: JID to filter by + """ + if jid_.resource: + if dest: + return and_( + History.dest == jid_.userhost(), + History.dest_res == jid_.resource + ) + else: + return and_( + History.source == jid_.userhost(), + History.source_res == jid_.resource + ) + else: + if dest: + return History.dest == jid_.userhost() + else: + return History.source == jid_.userhost() + + @aio + async def historyGet( + self, + from_jid: Optional[jid.JID], + to_jid: Optional[jid.JID], + limit: Optional[int] = None, + between: bool = True, + filters: Optional[Dict[str, str]] = None, + profile: Optional[str] = None, + ) -> List[Tuple[ + str, int, str, str, Dict[str, str], Dict[str, str], str, str, str] + ]: + """Retrieve messages in history + + @param from_jid: source JID (full, or bare for catchall) + @param to_jid: dest JID (full, or bare for catchall) + @param limit: maximum number of messages to get: + - 0 for no message (returns the empty list) + - None for unlimited + @param between: confound source and dest (ignore the direction) + @param filters: pattern to filter the history results + @return: list of messages as in [messageNew], minus the profile which is already + known. + """ + # we have to set a default value to profile because it's last argument + # and thus follow other keyword arguments with default values + # but None should not be used for it + assert profile is not None + if limit == 0: + return [] + if filters is None: + filters = {} + + stmt = ( + select(History) + .filter_by( + profile_id=self.profiles[profile] + ) + .outerjoin(History.messages) + .outerjoin(History.subjects) + .outerjoin(History.thread) + .options( + contains_eager(History.messages), + contains_eager(History.subjects), + contains_eager(History.thread), + ) + .order_by( + # timestamp may be identical for 2 close messages (specially when delay is + # used) that's why we order ties by received_timestamp. We'll reverse the + # order when returning the result. We use DESC here so LIMIT keep the last + # messages + History.timestamp.desc(), + History.received_timestamp.desc() + ) + ) + + + if not from_jid and not to_jid: + # no jid specified, we want all one2one communications + pass + elif between: + if not from_jid or not to_jid: + # we only have one jid specified, we check all messages + # from or to this jid + jid_ = from_jid or to_jid + stmt = stmt.where( + or_( + self._jid_filter(jid_), + self._jid_filter(jid_, dest=True) + ) + ) + else: + # we have 2 jids specified, we check all communications between + # those 2 jids + stmt = stmt.where( + or_( + and_( + self._jid_filter(from_jid), + self._jid_filter(to_jid, dest=True), + ), + and_( + self._jid_filter(to_jid), + self._jid_filter(from_jid, dest=True), + ) + ) + ) + else: + # we want one communication in specific direction (from somebody or + # to somebody). + if from_jid is not None: + stmt = stmt.where(self._jid_filter(from_jid)) + if to_jid is not None: + stmt = stmt.where(self._jid_filter(to_jid, dest=True)) + + if filters: + if 'timestamp_start' in filters: + stmt = stmt.where(History.timestamp >= float(filters['timestamp_start'])) + if 'before_uid' in filters: + # orignially this query was using SQLITE's rowid. This has been changed + # to use coalesce(received_timestamp, timestamp) to be SQL engine independant + stmt = stmt.where( + coalesce( + History.received_timestamp, + History.timestamp + ) < ( + select(coalesce(History.received_timestamp, History.timestamp)) + .filter_by(uid=filters["before_uid"]) + ).scalar_subquery() + ) + if 'body' in filters: + # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html + stmt = stmt.where(Message.message.like(f"%{filters['body']}%")) + if 'search' in filters: + search_term = f"%{filters['search']}%" + stmt = stmt.where(or_( + Message.message.like(search_term), + History.source_res.like(search_term) + )) + if 'types' in filters: + types = filters['types'].split() + stmt = stmt.where(History.type.in_(types)) + if 'not_types' in filters: + types = filters['not_types'].split() + stmt = stmt.where(History.type.not_in(types)) + if 'last_stanza_id' in filters: + # this request get the last message with a "stanza_id" that we + # have in history. This is mainly used to retrieve messages sent + # while we were offline, using MAM (XEP-0313). + if (filters['last_stanza_id'] is not True + or limit != 1): + raise ValueError("Unexpected values for last_stanza_id filter") + stmt = stmt.where(History.stanza_id.is_not(None)) + + if limit is not None: + stmt = stmt.limit(limit) + + async with self.session() as session: + result = await session.execute(stmt) + + result = result.scalars().unique().all() + result.reverse() + return [h.as_tuple() for h in result] + + @aio + async def addToHistory(self, data: dict, profile: str) -> None: + """Store a new message in history + + @param data: message data as build by SatMessageProtocol.onMessage + """ + extra = {k: v for k, v in data["extra"].items() if k not in NOT_IN_EXTRA} + messages = [Message(message=mess, language=lang) + for lang, mess in data["message"].items()] + subjects = [Subject(subject=mess, language=lang) + for lang, mess in data["subject"].items()] + if "thread" in data["extra"]: + thread = Thread(thread_id=data["extra"]["thread"], + parent_id=data["extra"].get["thread_parent"]) + else: + thread = None + try: + async with self.session() as session: + async with session.begin(): + session.add(History( + uid=data["uid"], + stanza_id=data["extra"].get("stanza_id"), + update_uid=data["extra"].get("update_uid"), + profile_id=self.profiles[profile], + source_jid=data["from"], + dest_jid=data["to"], + timestamp=data["timestamp"], + received_timestamp=data.get("received_timestamp"), + type=data["type"], + extra=extra, + messages=messages, + subjects=subjects, + thread=thread, + )) + except IntegrityError as e: + if "unique" in str(e.orig).lower(): + log.debug( + f"message {data['uid']!r} is already in history, not storing it again" + ) + else: + log.error(f"Can't store message {data['uid']!r} in history: {e}") + except Exception as e: + log.critical( + f"Can't store message, unexpected exception (uid: {data['uid']}): {e}" + ) + + ## Private values + + def _getPrivateClass(self, binary, profile): + """Get ORM class to use for private values""" + if profile is None: + return PrivateGenBin if binary else PrivateGen + else: + return PrivateIndBin if binary else PrivateInd + + + @aio + async def getPrivates( + self, + namespace:str, + keys: Optional[Iterable[str]] = None, + binary: bool = False, + profile: Optional[str] = None + ) -> Dict[str, Any]: + """Get private value(s) from databases + + @param namespace: namespace of the values + @param keys: keys of the values to get None to get all keys/values + @param binary: True to deserialise binary values + @param profile: profile to use for individual values + None to use general values + @return: gotten keys/values + """ + if keys is not None: + keys = list(keys) + log.debug( + f"getting {'general' if profile is None else 'individual'}" + f"{' binary' if binary else ''} private values from database for namespace " + f"{namespace}{f' with keys {keys!r}' if keys is not None else ''}" + ) + cls = self._getPrivateClass(binary, profile) + stmt = select(cls).filter_by(namespace=namespace) + if keys: + stmt = stmt.where(cls.key.in_(list(keys))) + if profile is not None: + stmt = stmt.filter_by(profile_id=self.profiles[profile]) + async with self.session() as session: + result = await session.execute(stmt) + return {p.key: p.value for p in result.scalars()} + + @aio + async def setPrivateValue( + self, + namespace: str, + key:str, + value: Any, + binary: bool = False, + profile: Optional[str] = None + ) -> None: + """Set a private value in database + + @param namespace: namespace of the values + @param key: key of the value to set + @param value: value to set + @param binary: True if it's a binary values + binary values need to be serialised, used for everything but strings + @param profile: profile to use for individual value + if None, it's a general value + """ + cls = self._getPrivateClass(binary, profile) + + values = { + "namespace": namespace, + "key": key, + "value": value + } + index_elements = [cls.namespace, cls.key] + + if profile is not None: + values["profile_id"] = self.profiles[profile] + index_elements.append(cls.profile_id) + + async with self.session() as session: + await session.execute( + insert(cls).values(**values).on_conflict_do_update( + index_elements=index_elements, + set_={ + cls.value: value + } + ) + ) + await session.commit() + + @aio + async def delPrivateValue( + self, + namespace: str, + key: str, + binary: bool = False, + profile: Optional[str] = None + ) -> None: + """Delete private value from database + + @param category: category of the privateeter + @param key: key of the private value + @param binary: True if it's a binary values + @param profile: profile to use for individual value + if None, it's a general value + """ + cls = self._getPrivateClass(binary, profile) + + stmt = delete(cls).filter_by(namespace=namespace, key=key) + + if profile is not None: + stmt = stmt.filter_by(profile_id=self.profiles[profile]) + + async with self.session() as session: + await session.execute(stmt) + await session.commit() + + @aio + async def delPrivateNamespace( + self, + namespace: str, + binary: bool = False, + profile: Optional[str] = None + ) -> None: + """Delete all data from a private namespace + + Be really cautious when you use this method, as all data with given namespace are + removed. + Params are the same as for delPrivateValue + """ + cls = self._getPrivateClass(binary, profile) + + stmt = delete(cls).filter_by(namespace=namespace) + + if profile is not None: + stmt = stmt.filter_by(profile_id=self.profiles[profile]) + + async with self.session() as session: + await session.execute(stmt) + await session.commit() + + ## Files + + @aio + async def getFiles( + self, + client: Optional[SatXMPPEntity], + file_id: Optional[str] = None, + version: Optional[str] = '', + parent: Optional[str] = None, + type_: Optional[str] = None, + file_hash: Optional[str] = None, + hash_algo: Optional[str] = None, + name: Optional[str] = None, + namespace: Optional[str] = None, + mime_type: Optional[str] = None, + public_id: Optional[str] = None, + owner: Optional[jid.JID] = None, + access: Optional[dict] = None, + projection: Optional[List[str]] = None, + unique: bool = False + ) -> List[dict]: + """Retrieve files with with given filters + + @param file_id: id of the file + None to ignore + @param version: version of the file + None to ignore + empty string to look for current version + @param parent: id of the directory containing the files + None to ignore + empty string to look for root files/directories + @param projection: name of columns to retrieve + None to retrieve all + @param unique: if True will remove duplicates + other params are the same as for [setFile] + @return: files corresponding to filters + """ + if projection is None: + projection = [ + 'id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name', + 'size', 'namespace', 'media_type', 'media_subtype', 'public_id', + 'created', 'modified', 'owner', 'access', 'extra' + ] + + stmt = select(*[getattr(File, f) for f in projection]) + + if unique: + stmt = stmt.distinct() + + if client is not None: + stmt = stmt.filter_by(profile_id=self.profiles[client.profile]) + else: + if public_id is None: + raise exceptions.InternalError( + "client can only be omitted when public_id is set" + ) + if file_id is not None: + stmt = stmt.filter_by(id=file_id) + if version is not None: + stmt = stmt.filter_by(version=version) + if parent is not None: + stmt = stmt.filter_by(parent=parent) + if type_ is not None: + stmt = stmt.filter_by(type=type_) + if file_hash is not None: + stmt = stmt.filter_by(file_hash=file_hash) + if hash_algo is not None: + stmt = stmt.filter_by(hash_algo=hash_algo) + if name is not None: + stmt = stmt.filter_by(name=name) + if namespace is not None: + stmt = stmt.filter_by(namespace=namespace) + if mime_type is not None: + if '/' in mime_type: + media_type, media_subtype = mime_type.split("/", 1) + stmt = stmt.filter_by(media_type=media_type, media_subtype=media_subtype) + else: + stmt = stmt.filter_by(media_type=mime_type) + if public_id is not None: + stmt = stmt.filter_by(public_id=public_id) + if owner is not None: + stmt = stmt.filter_by(owner=owner) + if access is not None: + raise NotImplementedError('Access check is not implemented yet') + # a JSON comparison is needed here + + async with self.session() as session: + result = await session.execute(stmt) + + return [dict(r) for r in result] + + @aio + async def setFile( + self, + client: SatXMPPEntity, + name: str, + file_id: str, + version: str = "", + parent: str = "", + type_: str = C.FILE_TYPE_FILE, + file_hash: Optional[str] = None, + hash_algo: Optional[str] = None, + size: int = None, + namespace: Optional[str] = None, + mime_type: Optional[str] = None, + public_id: Optional[str] = None, + created: Optional[float] = None, + modified: Optional[float] = None, + owner: Optional[jid.JID] = None, + access: Optional[dict] = None, + extra: Optional[dict] = None + ) -> None: + """Set a file metadata + + @param client: client owning the file + @param name: name of the file (must not contain "/") + @param file_id: unique id of the file + @param version: version of this file + @param parent: id of the directory containing this file + Empty string if it is a root file/directory + @param type_: one of: + - file + - directory + @param file_hash: unique hash of the payload + @param hash_algo: algorithm used for hashing the file (usually sha-256) + @param size: size in bytes + @param namespace: identifier (human readable is better) to group files + for instance, namespace could be used to group files in a specific photo album + @param mime_type: media type of the file, or None if not known/guessed + @param public_id: ID used to server the file publicly via HTTP + @param created: UNIX time of creation + @param modified: UNIX time of last modification, or None to use created date + @param owner: jid of the owner of the file (mainly useful for component) + @param access: serialisable dictionary with access rules. See [memory.memory] for + details + @param extra: serialisable dictionary of any extra data + will be encoded to json in database + """ + if mime_type is None: + media_type = media_subtype = None + elif '/' in mime_type: + media_type, media_subtype = mime_type.split('/', 1) + else: + media_type, media_subtype = mime_type, None + + async with self.session() as session: + async with session.begin(): + session.add(File( + id=file_id, + version=version.strip(), + parent=parent, + type=type_, + file_hash=file_hash, + hash_algo=hash_algo, + name=name, + size=size, + namespace=namespace, + media_type=media_type, + media_subtype=media_subtype, + public_id=public_id, + created=time.time() if created is None else created, + modified=modified, + owner=owner, + access=access, + extra=extra, + profile_id=self.profiles[client.profile] + )) + + @aio + async def fileGetUsedSpace(self, client: SatXMPPEntity, owner: jid.JID) -> int: + async with self.session() as session: + result = await session.execute( + select(sum_(File.size)).filter_by( + owner=owner, + type=C.FILE_TYPE_FILE, + profile_id=self.profiles[client.profile] + )) + return result.scalar_one_or_none() or 0 + + @aio + async def fileDelete(self, file_id: str) -> None: + """Delete file metadata from the database + + @param file_id: id of the file to delete + NOTE: file itself must still be removed, this method only handle metadata in + database + """ + async with self.session() as session: + await session.execute(delete(File).filter_by(id=file_id)) + await session.commit() + + @aio + async def fileUpdate( + self, + file_id: str, + column: str, + update_cb: Callable[[dict], None] + ) -> None: + """Update a column value using a method to avoid race conditions + + the older value will be retrieved from database, then update_cb will be applied to + update it, and file will be updated checking that older value has not been changed + meanwhile by an other user. If it has changed, it tries again a couple of times + before failing + @param column: column name (only "access" or "extra" are allowed) + @param update_cb: method to update the value of the colum + the method will take older value as argument, and must update it in place + update_cb must not care about serialization, + it get the deserialized data (i.e. a Python object) directly + @raise exceptions.NotFound: there is not file with this id + """ + if column not in ('access', 'extra'): + raise exceptions.InternalError('bad column name') + orm_col = getattr(File, column) + + for i in range(5): + async with self.session() as session: + try: + value = (await session.execute( + select(orm_col).filter_by(id=file_id) + )).scalar_one() + except NoResultFound: + raise exceptions.NotFound + update_cb(value) + stmt = update(orm_col).filter_by(id=file_id) + if not value: + # because JsonDefaultDict convert NULL to an empty dict, we have to + # test both for empty dict and None when we have and empty dict + stmt = stmt.where((orm_col == None) | (orm_col == value)) + else: + stmt = stmt.where(orm_col == value) + result = await session.execute(stmt) + await session.commit() + + if result.rowcount == 1: + break + + log.warning( + _("table not updated, probably due to race condition, trying again " + "({tries})").format(tries=i+1) + ) + + else: + raise exceptions.DatabaseError( + _("Can't update file {file_id} due to race condition") + .format(file_id=file_id) + ) + + @aio + async def getPubsubNode( + self, + client: SatXMPPEntity, + service: jid.JID, + name: str, + with_items: bool = False, + ) -> Optional[PubsubNode]: + """ + """ + async with self.session() as session: + stmt = ( + select(PubsubNode) + .filter_by( + service=service, + name=name, + profile_id=self.profiles[client.profile], + ) + ) + if with_items: + stmt = stmt.options( + joinedload(PubsubNode.items) + ) + result = await session.execute(stmt) + return result.unique().scalar_one_or_none() + + @aio + async def setPubsubNode( + self, + client: SatXMPPEntity, + service: jid.JID, + name: str, + analyser: Optional[str] = None, + type_: Optional[str] = None, + subtype: Optional[str] = None, + ) -> PubsubNode: + node = PubsubNode( + profile_id=self.profiles[client.profile], + service=service, + name=name, + subscribed=False, + analyser=analyser, + type_=type_, + subtype=subtype, + ) + async with self.session() as session: + async with session.begin(): + session.add(node) + return node + + @aio + async def updatePubsubNodeSyncState( + self, + node: PubsubNode, + state: SyncState + ) -> None: + async with self.session() as session: + async with session.begin(): + await session.execute( + update(PubsubNode) + .filter_by(id=node.id) + .values( + sync_state=state, + sync_state_updated=time.time(), + ) + ) + + @aio + async def deletePubsubNode( + self, + profiles: Optional[List[str]], + services: Optional[List[jid.JID]], + names: Optional[List[str]] + ) -> None: + """Delete items cached for a node + + @param profiles: profile names from which nodes must be deleted. + None to remove nodes from ALL profiles + @param services: JIDs of pubsub services from which nodes must be deleted. + None to remove nodes from ALL services + @param names: names of nodes which must be deleted. + None to remove ALL nodes whatever is their names + """ + stmt = delete(PubsubNode) + if profiles is not None: + stmt = stmt.where( + PubsubNode.profile.in_( + [self.profiles[p] for p in profiles] + ) + ) + if services is not None: + stmt = stmt.where(PubsubNode.service.in_(services)) + if names is not None: + stmt = stmt.where(PubsubNode.name.in_(names)) + async with self.session() as session: + await session.execute(stmt) + await session.commit() + + @aio + async def cachePubsubItems( + self, + client: SatXMPPEntity, + node: PubsubNode, + items: List[domish.Element], + parsed_items: Optional[List[dict]] = None, + ) -> None: + """Add items to database, using an upsert taking care of "updated" field""" + if parsed_items is not None and len(items) != len(parsed_items): + raise exceptions.InternalError( + "parsed_items must have the same lenght as items" + ) + async with self.session() as session: + async with session.begin(): + for idx, item in enumerate(items): + parsed = parsed_items[idx] if parsed_items else None + stmt = insert(PubsubItem).values( + node_id = node.id, + name = item["id"], + data = item, + parsed = parsed, + ).on_conflict_do_update( + index_elements=(PubsubItem.node_id, PubsubItem.name), + set_={ + PubsubItem.data: item, + PubsubItem.parsed: parsed, + PubsubItem.updated: now() + } + ) + await session.execute(stmt) + await session.commit() + + @aio + async def deletePubsubItems( + self, + node: PubsubNode, + items_names: Optional[List[str]] = None + ) -> None: + """Delete items cached for a node + + @param node: node from which items must be deleted + @param items_names: names of items to delete + if None, ALL items will be deleted + """ + stmt = delete(PubsubItem) + if node is not None: + if isinstance(node, list): + stmt = stmt.where(PubsubItem.node_id.in_([n.id for n in node])) + else: + stmt = stmt.filter_by(node_id=node.id) + if items_names is not None: + stmt = stmt.where(PubsubItem.name.in_(items_names)) + async with self.session() as session: + await session.execute(stmt) + await session.commit() + + @aio + async def purgePubsubItems( + self, + services: Optional[List[jid.JID]] = None, + names: Optional[List[str]] = None, + types: Optional[List[str]] = None, + subtypes: Optional[List[str]] = None, + profiles: Optional[List[str]] = None, + created_before: Optional[datetime] = None, + updated_before: Optional[datetime] = None, + ) -> None: + """Delete items cached for a node + + @param node: node from which items must be deleted + @param items_names: names of items to delete + if None, ALL items will be deleted + """ + stmt = delete(PubsubItem) + node_fields = { + "service": services, + "name": names, + "type_": types, + "subtype": subtypes, + } + if any(x is not None for x in node_fields.values()): + sub_q = select(PubsubNode.id) + for col, values in node_fields.items(): + if values is None: + continue + sub_q = sub_q.where(getattr(PubsubNode, col).in_(values)) + stmt = ( + stmt + .where(PubsubItem.node_id.in_(sub_q)) + .execution_options(synchronize_session=False) + ) + + if profiles is not None: + stmt = stmt.where( + PubsubItem.profile_id.in_([self.profiles[p] for p in profiles]) + ) + + if created_before is not None: + stmt = stmt.where(PubsubItem.created < created_before) + + if updated_before is not None: + stmt = stmt.where(PubsubItem.updated < updated_before) + + async with self.session() as session: + await session.execute(stmt) + await session.commit() + + @aio + async def getItems( + self, + node: PubsubNode, + max_items: Optional[int] = None, + before: Optional[str] = None, + after: Optional[str] = None, + from_index: Optional[int] = None, + order_by: Optional[List[str]] = None, + desc: bool = True, + force_rsm: bool = False, + ) -> Tuple[List[PubsubItem], dict]: + """Get Pubsub Items from cache + + @param node: retrieve items from this node (must be synchronised) + @param max_items: maximum number of items to retrieve + @param before: get items which are before the item with this name in given order + empty string is not managed here, use desc order to reproduce RSM + behaviour. + @param after: get items which are after the item with this name in given order + @param from_index: get items with item index (as defined in RSM spec) + starting from this number + @param order_by: sorting order of items (one of C.ORDER_BY_*) + @param desc: direction or ordering + @param force_rsm: if True, force the use of RSM worklow. + RSM workflow is automatically used if any of before, after or + from_index is used, but if only RSM max_items is used, it won't be + used by default. This parameter let's use RSM workflow in this + case. Note that in addition to RSM metadata, the result will not be + the same (max_items without RSM will returns most recent items, + i.e. last items in modification order, while max_items with RSM + will return the oldest ones (i.e. first items in modification + order). + to be used when max_items is used from RSM + """ + + metadata = { + "service": node.service, + "node": node.name, + "uri": uri.buildXMPPUri( + "pubsub", + path=node.service.full(), + node=node.name, + ), + } + if max_items is None: + max_items = 20 + + use_rsm = any((before, after, from_index is not None)) + if force_rsm and not use_rsm: + # + use_rsm = True + from_index = 0 + + stmt = ( + select(PubsubItem) + .filter_by(node_id=node.id) + .limit(max_items) + ) + + if not order_by: + order_by = [C.ORDER_BY_MODIFICATION] + + order = [] + for order_type in order_by: + if order_type == C.ORDER_BY_MODIFICATION: + if desc: + order.extend((PubsubItem.updated.desc(), PubsubItem.id.desc())) + else: + order.extend((PubsubItem.updated.asc(), PubsubItem.id.asc())) + elif order_type == C.ORDER_BY_CREATION: + if desc: + order.append(PubsubItem.id.desc()) + else: + order.append(PubsubItem.id.asc()) + else: + raise exceptions.InternalError(f"Unknown order type {order_type!r}") + + stmt = stmt.order_by(*order) + + if use_rsm: + # CTE to have result row numbers + row_num_q = select( + PubsubItem.id, + PubsubItem.name, + # row_number starts from 1, but RSM index must start from 0 + (func.row_number().over(order_by=order)-1).label("item_index") + ).filter_by(node_id=node.id) + + row_num_cte = row_num_q.cte() + + if max_items > 0: + # as we can't simply use PubsubItem.id when we order by modification, + # we need to use row number + item_name = before or after + row_num_limit_q = ( + select(row_num_cte.c.item_index) + .where(row_num_cte.c.name==item_name) + ).scalar_subquery() + + stmt = ( + select(row_num_cte.c.item_index, PubsubItem) + .join(row_num_cte, PubsubItem.id == row_num_cte.c.id) + .limit(max_items) + ) + if before: + stmt = ( + stmt + .where(row_num_cte.c.item_index<row_num_limit_q) + .order_by(row_num_cte.c.item_index.desc()) + ) + elif after: + stmt = ( + stmt + .where(row_num_cte.c.item_index>row_num_limit_q) + .order_by(row_num_cte.c.item_index.asc()) + ) + else: + stmt = ( + stmt + .where(row_num_cte.c.item_index>=from_index) + .order_by(row_num_cte.c.item_index.asc()) + ) + # from_index is used + + async with self.session() as session: + if max_items == 0: + items = result = [] + else: + result = await session.execute(stmt) + result = result.all() + if before: + result.reverse() + items = [row[-1] for row in result] + rows_count = ( + await session.execute(row_num_q.with_only_columns(count())) + ).scalar_one() + + try: + index = result[0][0] + except IndexError: + index = None + + try: + first = result[0][1].name + except IndexError: + first = None + last = None + else: + last = result[-1][1].name + + + metadata["rsm"] = { + "index": index, + "count": rows_count, + "first": first, + "last": last, + } + metadata["complete"] = index + len(result) == rows_count + + return items, metadata + + async with self.session() as session: + result = await session.execute(stmt) + + result = result.scalars().all() + if desc: + result.reverse() + return result, metadata
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/memory/sqla_config.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 + +# Libervia: an XMPP client +# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +from pathlib import Path +from urllib.parse import quote +from sat.core.constants import Const as C +from sat.tools import config + + +def getDbConfig() -> dict: + """Get configuration for database + + @return: dict with following keys: + - type: only "sqlite" for now + - path: path to the sqlite DB + """ + main_conf = config.parseMainConf() + local_dir = Path(config.getConfig(main_conf, "", "local_dir")) + database_path = (local_dir / C.SAVEFILE_DATABASE).expanduser() + url = f"sqlite+aiosqlite:///{quote(str(database_path))}" + return { + "type": "sqlite", + "path": database_path, + "url": url, + }
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/memory/sqla_mapping.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,540 @@ +#!/usr/bin/env python3 + +# Libervia: an XMPP client +# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import pickle +import json +from datetime import datetime +import time +import enum +from sqlalchemy import ( + MetaData, Column, Integer, Text, Float, Boolean, DateTime, Enum, JSON, ForeignKey, + UniqueConstraint, Index +) + +from sqlalchemy.orm import declarative_base, relationship +from sqlalchemy.types import TypeDecorator +from sqlalchemy.sql.functions import now +from twisted.words.protocols.jabber import jid +from wokkel import generic + + +Base = declarative_base( + metadata=MetaData( + naming_convention={ + "ix": 'ix_%(column_0_label)s', + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s" + } + ) +) +# keys which are in message data extra but not stored in extra field this is +# because those values are stored in separate fields +NOT_IN_EXTRA = ('stanza_id', 'received_timestamp', 'update_uid') + + +class SyncState(enum.Enum): + #: synchronisation is currently in progress + IN_PROGRESS = 1 + #: synchronisation is done + COMPLETED = 2 + #: something wrong happened during synchronisation, won't sync + ERROR = 3 + #: synchronisation won't be done even if a syncing analyser match + NO_SYNC = 4 + + +class LegacyPickle(TypeDecorator): + """Handle troubles with data pickled by former version of SàT + + This type is temporary until we do migration to a proper data type + """ + # Blob is used on SQLite but gives errors when used here, while Text works fine + impl = Text + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + return pickle.dumps(value, 0) + + def process_result_value(self, value, dialect): + if value is None: + return None + # value types are inconsistent (probably a consequence of Python 2/3 port + # and/or SQLite dynamic typing) + try: + value = value.encode() + except AttributeError: + pass + # "utf-8" encoding is needed to handle Python 2 pickled data + return pickle.loads(value, encoding="utf-8") + + +class Json(TypeDecorator): + """Handle JSON field in DB independant way""" + # Blob is used on SQLite but gives errors when used here, while Text works fine + impl = Text + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + return json.dumps(value) + + def process_result_value(self, value, dialect): + if value is None: + return None + return json.loads(value) + + +class JsonDefaultDict(Json): + """Json type which convert NULL to empty dict instead of None""" + + def process_result_value(self, value, dialect): + if value is None: + return {} + return json.loads(value) + + +class Xml(TypeDecorator): + impl = Text + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + return value.toXml() + + def process_result_value(self, value, dialect): + if value is None: + return None + return generic.parseXml(value.encode()) + + +class JID(TypeDecorator): + """Store twisted JID in text fields""" + impl = Text + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is None: + return None + return value.full() + + def process_result_value(self, value, dialect): + if value is None: + return None + return jid.JID(value) + + +class Profile(Base): + __tablename__ = "profiles" + + id = Column( + Integer, + primary_key=True, + nullable=True, + # settings autoincrement would have negative performance impact + # cf. https://sqlite.org/autoinc.html + autoincrement=False + ) + name = Column(Text, unique=True) + + params = relationship("ParamInd", back_populates="profile", passive_deletes=True) + private_data = relationship( + "PrivateInd", back_populates="profile", passive_deletes=True + ) + private_bin_data = relationship( + "PrivateIndBin", back_populates="profile", passive_deletes=True + ) + + +class Component(Base): + __tablename__ = "components" + + profile_id = Column( + ForeignKey("profiles.id", ondelete="CASCADE"), + nullable=True, + primary_key=True + ) + entry_point = Column(Text, nullable=False) + profile = relationship("Profile") + + +class History(Base): + __tablename__ = "history" + __table_args__ = ( + UniqueConstraint("profile_id", "stanza_id", "source", "dest"), + Index("history__profile_id_timestamp", "profile_id", "timestamp"), + Index( + "history__profile_id_received_timestamp", "profile_id", "received_timestamp" + ) + ) + + uid = Column(Text, primary_key=True) + stanza_id = Column(Text) + update_uid = Column(Text) + profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE")) + source = Column(Text) + dest = Column(Text) + source_res = Column(Text) + dest_res = Column(Text) + timestamp = Column(Float, nullable=False) + received_timestamp = Column(Float) + type = Column( + Enum( + "chat", + "error", + "groupchat", + "headline", + "normal", + # info is not XMPP standard, but used to keep track of info like join/leave + # in a MUC + "info", + name="message_type", + create_constraint=True, + ), + nullable=False, + ) + extra = Column(LegacyPickle) + + profile = relationship("Profile") + messages = relationship("Message", backref="history", passive_deletes=True) + subjects = relationship("Subject", backref="history", passive_deletes=True) + thread = relationship( + "Thread", uselist=False, back_populates="history", passive_deletes=True + ) + + def __init__(self, *args, **kwargs): + source_jid = kwargs.pop("source_jid", None) + if source_jid is not None: + kwargs["source"] = source_jid.userhost() + kwargs["source_res"] = source_jid.resource + dest_jid = kwargs.pop("dest_jid", None) + if dest_jid is not None: + kwargs["dest"] = dest_jid.userhost() + kwargs["dest_res"] = dest_jid.resource + super().__init__(*args, **kwargs) + + @property + def source_jid(self) -> jid.JID: + return jid.JID(f"{self.source}/{self.source_res or ''}") + + @source_jid.setter + def source_jid(self, source_jid: jid.JID) -> None: + self.source = source_jid.userhost + self.source_res = source_jid.resource + + @property + def dest_jid(self): + return jid.JID(f"{self.dest}/{self.dest_res or ''}") + + @dest_jid.setter + def dest_jid(self, dest_jid: jid.JID) -> None: + self.dest = dest_jid.userhost + self.dest_res = dest_jid.resource + + def __repr__(self): + dt = datetime.fromtimestamp(self.timestamp) + return f"History<{self.source_jid.full()}->{self.dest_jid.full()} [{dt}]>" + + def serialise(self): + extra = self.extra + if self.stanza_id is not None: + extra["stanza_id"] = self.stanza_id + if self.update_uid is not None: + extra["update_uid"] = self.update_uid + if self.received_timestamp is not None: + extra["received_timestamp"] = self.received_timestamp + if self.thread is not None: + extra["thread"] = self.thread.thread_id + if self.thread.parent_id is not None: + extra["thread_parent"] = self.thread.parent_id + + + return { + "from": f"{self.source}/{self.source_res}" if self.source_res + else self.source, + "to": f"{self.dest}/{self.dest_res}" if self.dest_res else self.dest, + "uid": self.uid, + "message": {m.language or '': m.message for m in self.messages}, + "subject": {m.language or '': m.subject for m in self.subjects}, + "type": self.type, + "extra": extra, + "timestamp": self.timestamp, + } + + def as_tuple(self): + d = self.serialise() + return ( + d['uid'], d['timestamp'], d['from'], d['to'], d['message'], d['subject'], + d['type'], d['extra'] + ) + + @staticmethod + def debug_collection(history_collection): + for idx, history in enumerate(history_collection): + history.debug_msg(idx) + + def debug_msg(self, idx=None): + """Print messages""" + dt = datetime.fromtimestamp(self.timestamp) + if idx is not None: + dt = f"({idx}) {dt}" + parts = [] + parts.append(f"[{dt}]<{self.source_jid.full()}->{self.dest_jid.full()}> ") + for message in self.messages: + if message.language: + parts.append(f"[{message.language}] ") + parts.append(f"{message.message}\n") + print("".join(parts)) + + +class Message(Base): + __tablename__ = "message" + __table_args__ = ( + Index("message__history_uid", "history_uid"), + ) + + id = Column( + Integer, + primary_key=True, + # cf. note for Profile.id + autoincrement=False + ) + history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE")) + message = Column(Text) + language = Column(Text) + + def __repr__(self): + lang_str = f"[{self.language}]" if self.language else "" + msg = f"{self.message[:20]}…" if len(self.message)>20 else self.message + content = f"{lang_str}{msg}" + return f"Message<{content}>" + + +class Subject(Base): + __tablename__ = "subject" + __table_args__ = ( + Index("subject__history_uid", "history_uid"), + ) + + id = Column( + Integer, + primary_key=True, + # cf. note for Profile.id + autoincrement=False, + ) + history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE")) + subject = Column(Text) + language = Column(Text) + + def __repr__(self): + lang_str = f"[{self.language}]" if self.language else "" + msg = f"{self.subject[:20]}…" if len(self.subject)>20 else self.subject + content = f"{lang_str}{msg}" + return f"Subject<{content}>" + + +class Thread(Base): + __tablename__ = "thread" + __table_args__ = ( + Index("thread__history_uid", "history_uid"), + ) + + id = Column( + Integer, + primary_key=True, + # cf. note for Profile.id + autoincrement=False, + ) + history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE")) + thread_id = Column(Text) + parent_id = Column(Text) + + history = relationship("History", uselist=False, back_populates="thread") + + def __repr__(self): + return f"Thread<{self.thread_id} [parent: {self.parent_id}]>" + + +class ParamGen(Base): + __tablename__ = "param_gen" + + category = Column(Text, primary_key=True) + name = Column(Text, primary_key=True) + value = Column(Text) + + +class ParamInd(Base): + __tablename__ = "param_ind" + + category = Column(Text, primary_key=True) + name = Column(Text, primary_key=True) + profile_id = Column( + ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True + ) + value = Column(Text) + + profile = relationship("Profile", back_populates="params") + + +class PrivateGen(Base): + __tablename__ = "private_gen" + + namespace = Column(Text, primary_key=True) + key = Column(Text, primary_key=True) + value = Column(Text) + + +class PrivateInd(Base): + __tablename__ = "private_ind" + + namespace = Column(Text, primary_key=True) + key = Column(Text, primary_key=True) + profile_id = Column( + ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True + ) + value = Column(Text) + + profile = relationship("Profile", back_populates="private_data") + + +class PrivateGenBin(Base): + __tablename__ = "private_gen_bin" + + namespace = Column(Text, primary_key=True) + key = Column(Text, primary_key=True) + value = Column(LegacyPickle) + + +class PrivateIndBin(Base): + __tablename__ = "private_ind_bin" + + namespace = Column(Text, primary_key=True) + key = Column(Text, primary_key=True) + profile_id = Column( + ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True + ) + value = Column(LegacyPickle) + + profile = relationship("Profile", back_populates="private_bin_data") + + +class File(Base): + __tablename__ = "files" + __table_args__ = ( + Index("files__profile_id_owner_parent", "profile_id", "owner", "parent"), + Index( + "files__profile_id_owner_media_type_media_subtype", + "profile_id", + "owner", + "media_type", + "media_subtype" + ) + ) + + id = Column(Text, primary_key=True) + public_id = Column(Text, unique=True) + version = Column(Text, primary_key=True) + parent = Column(Text, nullable=False) + type = Column( + Enum( + "file", "directory", + name="file_type", + create_constraint=True + ), + nullable=False, + server_default="file", + ) + file_hash = Column(Text) + hash_algo = Column(Text) + name = Column(Text, nullable=False) + size = Column(Integer) + namespace = Column(Text) + media_type = Column(Text) + media_subtype = Column(Text) + created = Column(Float, nullable=False) + modified = Column(Float) + owner = Column(JID) + access = Column(JsonDefaultDict) + extra = Column(JsonDefaultDict) + profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE")) + + profile = relationship("Profile") + + +class PubsubNode(Base): + __tablename__ = "pubsub_nodes" + __table_args__ = ( + UniqueConstraint("profile_id", "service", "name"), + ) + + id = Column(Integer, primary_key=True) + profile_id = Column( + ForeignKey("profiles.id", ondelete="CASCADE") + ) + service = Column(JID) + name = Column(Text, nullable=False) + subscribed = Column( + Boolean(create_constraint=True, name="subscribed_bool"), nullable=False + ) + analyser = Column(Text) + sync_state = Column( + Enum( + SyncState, + name="sync_state", + create_constraint=True, + ), + nullable=True + ) + sync_state_updated = Column( + Float, + nullable=False, + default=time.time() + ) + type_ = Column( + Text, name="type", nullable=True + ) + subtype = Column( + Text, nullable=True + ) + extra = Column(JSON) + + items = relationship("PubsubItem", back_populates="node", passive_deletes=True) + + def __str__(self): + return f"Pubsub node {self.name!r} at {self.service}" + + +class PubsubItem(Base): + __tablename__ = "pubsub_items" + __table_args__ = ( + UniqueConstraint("node_id", "name"), + ) + id = Column(Integer, primary_key=True) + node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False) + name = Column(Text, nullable=False) + data = Column(Xml, nullable=False) + created = Column(DateTime, nullable=False, server_default=now()) + updated = Column(DateTime, nullable=False, server_default=now(), onupdate=now()) + parsed = Column(JSON) + + node = relationship("PubsubNode", back_populates="items")
--- a/sat/memory/sqlite.py Wed Sep 01 13:46:43 2021 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,1766 +0,0 @@ -#!/usr/bin/env python3 - - -# SAT: a jabber client -# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) - -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. - -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. - -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see <http://www.gnu.org/licenses/>. - -from sat.core.i18n import _ -from sat.core.constants import Const as C -from sat.core import exceptions -from sat.core.log import getLogger -from sat.memory.crypto import BlockCipher, PasswordHasher -from sat.tools.config import fixConfigOption -from twisted.enterprise import adbapi -from twisted.internet import defer -from twisted.words.protocols.jabber import jid -from twisted.python import failure -from collections import OrderedDict -import sys -import re -import os.path -import pickle as pickle -import hashlib -import sqlite3 -import json - -log = getLogger(__name__) - -CURRENT_DB_VERSION = 9 - -# XXX: DATABASE schemas are used in the following way: -# - 'current' key is for the actual database schema, for a new base -# - x(int) is for update needed between x-1 and x. All number are needed between y and z to do an update -# e.g.: if CURRENT_DB_VERSION is 6, 'current' is the actuel DB, and to update from version 3, numbers 4, 5 and 6 are needed -# a 'current' data dict can contains the keys: -# - 'CREATE': it contains an Ordered dict with table to create as keys, and a len 2 tuple as value, where value[0] are the columns definitions and value[1] are the table constraints -# - 'INSERT': it contains an Ordered dict with table where values have to be inserted, and many tuples containing values to insert in the order of the rows (#TODO: manage named columns) -# - 'INDEX': -# an update data dict (the ones with a number) can contains the keys 'create', 'delete', 'cols create', 'cols delete', 'cols modify', 'insert' or 'specific'. See Updater.generateUpdateData for more infos. This method can be used to autogenerate update_data, to ease the work of the developers. -# TODO: indexes need to be improved - -DATABASE_SCHEMAS = { - "current": {'CREATE': OrderedDict(( - ('profiles', (("id INTEGER PRIMARY KEY ASC", "name TEXT"), - ("UNIQUE (name)",))), - ('components', (("profile_id INTEGER PRIMARY KEY", "entry_point TEXT NOT NULL"), - ("FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE",))), - ('message_types', (("type TEXT PRIMARY KEY",), - ())), - ('history', (("uid TEXT PRIMARY KEY", "stanza_id TEXT", "update_uid TEXT", "profile_id INTEGER", "source TEXT", "dest TEXT", "source_res TEXT", "dest_res TEXT", - "timestamp DATETIME NOT NULL", "received_timestamp DATETIME", # XXX: timestamp is the time when the message was emitted. If received time stamp is not NULL, the message was delayed and timestamp is the declared value (and received_timestamp the time of reception) - "type TEXT", "extra BLOB"), - ("FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE", "FOREIGN KEY(type) REFERENCES message_types(type)", - "UNIQUE (profile_id, stanza_id, source, dest)" # avoid storing 2 times the same message - ))), - ('message', (("id INTEGER PRIMARY KEY ASC", "history_uid INTEGER", "message TEXT", "language TEXT"), - ("FOREIGN KEY(history_uid) REFERENCES history(uid) ON DELETE CASCADE",))), - ('subject', (("id INTEGER PRIMARY KEY ASC", "history_uid INTEGER", "subject TEXT", "language TEXT"), - ("FOREIGN KEY(history_uid) REFERENCES history(uid) ON DELETE CASCADE",))), - ('thread', (("id INTEGER PRIMARY KEY ASC", "history_uid INTEGER", "thread_id TEXT", "parent_id TEXT"),("FOREIGN KEY(history_uid) REFERENCES history(uid) ON DELETE CASCADE",))), - ('param_gen', (("category TEXT", "name TEXT", "value TEXT"), - ("PRIMARY KEY (category, name)",))), - ('param_ind', (("category TEXT", "name TEXT", "profile_id INTEGER", "value TEXT"), - ("PRIMARY KEY (profile_id, category, name)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE"))), - ('private_gen', (("namespace TEXT", "key TEXT", "value TEXT"), - ("PRIMARY KEY (namespace, key)",))), - ('private_ind', (("namespace TEXT", "key TEXT", "profile_id INTEGER", "value TEXT"), - ("PRIMARY KEY (profile_id, namespace, key)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE"))), - ('private_gen_bin', (("namespace TEXT", "key TEXT", "value BLOB"), - ("PRIMARY KEY (namespace, key)",))), - ('private_ind_bin', (("namespace TEXT", "key TEXT", "profile_id INTEGER", "value BLOB"), - ("PRIMARY KEY (profile_id, namespace, key)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE"))), - ('files', (("id TEXT NOT NULL", "public_id TEXT", "version TEXT NOT NULL", - "parent TEXT NOT NULL", - "type TEXT CHECK(type in ('{file}', '{directory}')) NOT NULL DEFAULT '{file}'".format( - file=C.FILE_TYPE_FILE, directory=C.FILE_TYPE_DIRECTORY), - "file_hash TEXT", "hash_algo TEXT", "name TEXT NOT NULL", "size INTEGER", - "namespace TEXT", "media_type TEXT", "media_subtype TEXT", - "created DATETIME NOT NULL", "modified DATETIME", - "owner TEXT", "access TEXT", "extra TEXT", "profile_id INTEGER"), - ("PRIMARY KEY (id, version)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE", - "UNIQUE (public_id)"))), - )), - 'INSERT': OrderedDict(( - ('message_types', (("'chat'",), - ("'error'",), - ("'groupchat'",), - ("'headline'",), - ("'normal'",), - ("'info'",) # info is not standard, but used to keep track of info like join/leave in a MUC - )), - )), - 'INDEX': (('history', (('profile_id', 'timestamp'), - ('profile_id', 'received_timestamp'))), - ('message', ('history_uid',)), - ('subject', ('history_uid',)), - ('thread', ('history_uid',)), - ('files', (('profile_id', 'owner', 'media_type', 'media_subtype'), - ('profile_id', 'owner', 'parent'))), - ) - }, - 9: {'specific': 'update_v9' - }, - 8: {'specific': 'update_v8' - }, - 7: {'specific': 'update_v7' - }, - 6: {'cols create': {'history': ('stanza_id TEXT',)}, - }, - 5: {'create': {'files': (("id TEXT NOT NULL", "version TEXT NOT NULL", "parent TEXT NOT NULL", - "type TEXT CHECK(type in ('{file}', '{directory}')) NOT NULL DEFAULT '{file}'".format( - file=C.FILE_TYPE_FILE, directory=C.FILE_TYPE_DIRECTORY), - "file_hash TEXT", "hash_algo TEXT", "name TEXT NOT NULL", "size INTEGER", - "namespace TEXT", "mime_type TEXT", - "created DATETIME NOT NULL", "modified DATETIME", - "owner TEXT", "access TEXT", "extra TEXT", "profile_id INTEGER"), - ("PRIMARY KEY (id, version)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE"))}, - }, - 4: {'create': {'components': (('profile_id INTEGER PRIMARY KEY', 'entry_point TEXT NOT NULL'), ('FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE',))} - }, - 3: {'specific': 'update_v3' - }, - 2: {'specific': 'update2raw_v2' - }, - 1: {'cols create': {'history': ('extra BLOB',)}, - }, - } - -NOT_IN_EXTRA = ('stanza_id', 'received_timestamp', 'update_uid') # keys which are in message data extra but not stored in sqlite's extra field - # this is specific to this sqlite storage and for now only used for received_timestamp - # because this value is stored in a separate field - - -class ConnectionPool(adbapi.ConnectionPool): - def _runQuery(self, trans, *args, **kw): - retry = kw.pop('query_retry', 6) - try: - trans.execute(*args, **kw) - except sqlite3.IntegrityError as e: - # Workaround to avoid IntegrityError causing (i)pdb to be - # launched in debug mode - raise failure.Failure(e) - except Exception as e: - # FIXME: in case of error, we retry a couple of times - # this is a workaround, we need to move to better - # Sqlite integration, probably with high level library - retry -= 1 - if retry == 0: - log.error(_('too many db tries, we abandon! Error message: {msg}\n' - 'query was {query}' - .format(msg=e, query=' '.join([str(a) for a in args])))) - raise e - log.warning( - _('exception while running query, retrying ({try_}): {msg}').format( - try_ = 6 - retry, - msg = e)) - kw['query_retry'] = retry - return self._runQuery(trans, *args, **kw) - return trans.fetchall() - - def _runInteraction(self, interaction, *args, **kw): - # sometimes interaction may fail while committing in _runInteraction - # and it may be due to a db lock. So we work around it in a similar way - # as for _runQuery but with only 3 tries - retry = kw.pop('interaction_retry', 4) - try: - return adbapi.ConnectionPool._runInteraction(self, interaction, *args, **kw) - except Exception as e: - retry -= 1 - if retry == 0: - log.error( - _('too many interaction tries, we abandon! Error message: {msg}\n' - 'interaction method was: {interaction}\n' - 'interaction arguments were: {args}' - .format(msg=e, interaction=interaction, - args=', '.join([str(a) for a in args])))) - raise e - log.warning( - _('exception while running interaction, retrying ({try_}): {msg}') - .format(try_ = 4 - retry, msg = e)) - kw['interaction_retry'] = retry - return self._runInteraction(interaction, *args, **kw) - - -class SqliteStorage(object): - """This class manage storage with Sqlite database""" - - def __init__(self, db_filename, sat_version): - """Connect to the given database - - @param db_filename: full path to the Sqlite database - """ - # triggered when memory is fully initialised and ready - self.initialized = defer.Deferred() - # we keep cache for the profiles (key: profile name, value: profile id) - self.profiles = {} - - log.info(_("Connecting database")) - new_base = not os.path.exists(db_filename) # do we have to create the database ? - if new_base: # the dir may not exist if it's not the XDG recommended one - dir_ = os.path.dirname(db_filename) - if not os.path.exists(dir_): - os.makedirs(dir_, 0o700) - - def foreignKeysOn(sqlite): - sqlite.execute('PRAGMA foreign_keys = ON') - - self.dbpool = ConnectionPool("sqlite3", db_filename, cp_openfun=foreignKeysOn, check_same_thread=False, timeout=15) - - def getNewBaseSql(): - log.info(_("The database is new, creating the tables")) - database_creation = ["PRAGMA user_version=%d" % CURRENT_DB_VERSION] - database_creation.extend(Updater.createData2Raw(DATABASE_SCHEMAS['current']['CREATE'])) - database_creation.extend(Updater.insertData2Raw(DATABASE_SCHEMAS['current']['INSERT'])) - database_creation.extend(Updater.indexData2Raw(DATABASE_SCHEMAS['current']['INDEX'])) - return database_creation - - def getUpdateSql(): - updater = Updater(self, sat_version) - return updater.checkUpdates() - - # init_defer is the initialisation deferred, initialisation is ok when all its callbacks have been done - - init_defer = defer.succeed(None) - - init_defer.addCallback(lambda ignore: getNewBaseSql() if new_base else getUpdateSql()) - init_defer.addCallback(self.commitStatements) - - def fillProfileCache(ignore): - return self.dbpool.runQuery("SELECT profile_id, entry_point FROM components").addCallback(self._cacheComponentsAndProfiles) - - init_defer.addCallback(fillProfileCache) - init_defer.chainDeferred(self.initialized) - - def commitStatements(self, statements): - - if statements is None: - return defer.succeed(None) - log.debug("\n===== COMMITTING STATEMENTS =====\n%s\n============\n\n" % '\n'.join(statements)) - d = self.dbpool.runInteraction(self._updateDb, tuple(statements)) - return d - - def _updateDb(self, interaction, statements): - for statement in statements: - interaction.execute(statement) - - ## Profiles - - def _cacheComponentsAndProfiles(self, components_result): - """Get components results and send requests profiles - - they will be both put in cache in _profilesCache - """ - return self.dbpool.runQuery("SELECT name,id FROM profiles").addCallback( - self._cacheComponentsAndProfiles2, components_result) - - def _cacheComponentsAndProfiles2(self, profiles_result, components): - """Fill the profiles cache - - @param profiles_result: result of the sql profiles query - """ - self.components = dict(components) - for profile in profiles_result: - name, id_ = profile - self.profiles[name] = id_ - - def getProfilesList(self): - """"Return list of all registered profiles""" - return list(self.profiles.keys()) - - def hasProfile(self, profile_name): - """return True if profile_name exists - - @param profile_name: name of the profile to check - """ - return profile_name in self.profiles - - def profileIsComponent(self, profile_name): - try: - return self.profiles[profile_name] in self.components - except KeyError: - raise exceptions.NotFound("the requested profile doesn't exists") - - def getEntryPoint(self, profile_name): - try: - return self.components[self.profiles[profile_name]] - except KeyError: - raise exceptions.NotFound("the requested profile doesn't exists or is not a component") - - def createProfile(self, name, component=None): - """Create a new profile - - @param name(unicode): name of the profile - @param component(None, unicode): if not None, must point to a component entry point - @return: deferred triggered once profile is actually created - """ - - def getProfileId(ignore): - return self.dbpool.runQuery("SELECT (id) FROM profiles WHERE name = ?", (name, )) - - def setComponent(profile_id): - id_ = profile_id[0][0] - d_comp = self.dbpool.runQuery("INSERT INTO components(profile_id, entry_point) VALUES (?, ?)", (id_, component)) - d_comp.addCallback(lambda __: profile_id) - return d_comp - - def profile_created(profile_id): - id_= profile_id[0][0] - self.profiles[name] = id_ # we synchronise the cache - - d = self.dbpool.runQuery("INSERT INTO profiles(name) VALUES (?)", (name, )) - d.addCallback(getProfileId) - if component is not None: - d.addCallback(setComponent) - d.addCallback(profile_created) - return d - - def deleteProfile(self, name): - """Delete profile - - @param name: name of the profile - @return: deferred triggered once profile is actually deleted - """ - def deletionError(failure_): - log.error(_("Can't delete profile [%s]") % name) - return failure_ - - def delete(txn): - profile_id = self.profiles.pop(name) - txn.execute("DELETE FROM profiles WHERE name = ?", (name,)) - # FIXME: the following queries should be done by the ON DELETE CASCADE - # but it seems they are not, so we explicitly do them by security - # this need more investigation - txn.execute("DELETE FROM history WHERE profile_id = ?", (profile_id,)) - txn.execute("DELETE FROM param_ind WHERE profile_id = ?", (profile_id,)) - txn.execute("DELETE FROM private_ind WHERE profile_id = ?", (profile_id,)) - txn.execute("DELETE FROM private_ind_bin WHERE profile_id = ?", (profile_id,)) - txn.execute("DELETE FROM components WHERE profile_id = ?", (profile_id,)) - return None - - d = self.dbpool.runInteraction(delete) - d.addCallback(lambda ignore: log.info(_("Profile [%s] deleted") % name)) - d.addErrback(deletionError) - return d - - ## Params - def loadGenParams(self, params_gen): - """Load general parameters - - @param params_gen: dictionary to fill - @return: deferred - """ - - def fillParams(result): - for param in result: - category, name, value = param - params_gen[(category, name)] = value - log.debug(_("loading general parameters from database")) - return self.dbpool.runQuery("SELECT category,name,value FROM param_gen").addCallback(fillParams) - - def loadIndParams(self, params_ind, profile): - """Load individual parameters - - @param params_ind: dictionary to fill - @param profile: a profile which *must* exist - @return: deferred - """ - - def fillParams(result): - for param in result: - category, name, value = param - params_ind[(category, name)] = value - log.debug(_("loading individual parameters from database")) - d = self.dbpool.runQuery("SELECT category,name,value FROM param_ind WHERE profile_id=?", (self.profiles[profile], )) - d.addCallback(fillParams) - return d - - def getIndParam(self, category, name, profile): - """Ask database for the value of one specific individual parameter - - @param category: category of the parameter - @param name: name of the parameter - @param profile: %(doc_profile)s - @return: deferred - """ - d = self.dbpool.runQuery( - "SELECT value FROM param_ind WHERE category=? AND name=? AND profile_id=?", - (category, name, self.profiles[profile])) - d.addCallback(self.__getFirstResult) - return d - - async def getIndParamValues(self, category, name): - """Ask database for the individual values of a parameter for all profiles - - @param category: category of the parameter - @param name: name of the parameter - @return dict: profile => value map - """ - result = await self.dbpool.runQuery( - "SELECT profiles.name, param_ind.value FROM param_ind JOIN profiles ON " - "param_ind.profile_id = profiles.id WHERE param_ind.category=? " - "and param_ind.name=?", - (category, name)) - return dict(result) - - def setGenParam(self, category, name, value): - """Save the general parameters in database - - @param category: category of the parameter - @param name: name of the parameter - @param value: value to set - @return: deferred""" - d = self.dbpool.runQuery("REPLACE INTO param_gen(category,name,value) VALUES (?,?,?)", (category, name, value)) - d.addErrback(lambda ignore: log.error(_("Can't set general parameter (%(category)s/%(name)s) in database" % {"category": category, "name": name}))) - return d - - def setIndParam(self, category, name, value, profile): - """Save the individual parameters in database - - @param category: category of the parameter - @param name: name of the parameter - @param value: value to set - @param profile: a profile which *must* exist - @return: deferred - """ - d = self.dbpool.runQuery("REPLACE INTO param_ind(category,name,profile_id,value) VALUES (?,?,?,?)", (category, name, self.profiles[profile], value)) - d.addErrback(lambda ignore: log.error(_("Can't set individual parameter (%(category)s/%(name)s) for [%(profile)s] in database" % {"category": category, "name": name, "profile": profile}))) - return d - - ## History - - def _addToHistoryCb(self, __, data): - # Message metadata were successfuly added to history - # now we can add message and subject - uid = data['uid'] - d_list = [] - for key in ('message', 'subject'): - for lang, value in data[key].items(): - if not value.strip(): - # no need to store empty messages - continue - d = self.dbpool.runQuery( - "INSERT INTO {key}(history_uid, {key}, language) VALUES (?,?,?)" - .format(key=key), - (uid, value, lang or None)) - d.addErrback(lambda __: log.error( - _("Can't save following {key} in history (uid: {uid}, lang:{lang}):" - " {value}").format( - key=key, uid=uid, lang=lang, value=value))) - d_list.append(d) - try: - thread = data['extra']['thread'] - except KeyError: - pass - else: - thread_parent = data['extra'].get('thread_parent') - d = self.dbpool.runQuery( - "INSERT INTO thread(history_uid, thread_id, parent_id) VALUES (?,?,?)", - (uid, thread, thread_parent)) - d.addErrback(lambda __: log.error( - _("Can't save following thread in history (uid: {uid}): thread: " - "{thread}), parent:{parent}").format( - uid=uid, thread=thread, parent=thread_parent))) - d_list.append(d) - return defer.DeferredList(d_list) - - def _addToHistoryEb(self, failure_, data): - failure_.trap(sqlite3.IntegrityError) - sqlite_msg = failure_.value.args[0] - if "UNIQUE constraint failed" in sqlite_msg: - log.debug("message {} is already in history, not storing it again" - .format(data['uid'])) - if 'received_timestamp' not in data: - log.warning( - "duplicate message is not delayed, this is maybe a bug: data={}" - .format(data)) - # we cancel message to avoid sending duplicate message to frontends - raise failure.Failure(exceptions.CancelError("Cancelled duplicated message")) - else: - log.error("Can't store message in history: {}".format(failure_)) - - def _logHistoryError(self, failure_, from_jid, to_jid, data): - if failure_.check(exceptions.CancelError): - # we propagate CancelError to avoid sending message to frontends - raise failure_ - log.error(_( - "Can't save following message in history: from [{from_jid}] to [{to_jid}] " - "(uid: {uid})") - .format(from_jid=from_jid.full(), to_jid=to_jid.full(), uid=data['uid'])) - - def addToHistory(self, data, profile): - """Store a new message in history - - @param data(dict): message data as build by SatMessageProtocol.onMessage - """ - extra = pickle.dumps({k: v for k, v in data['extra'].items() - if k not in NOT_IN_EXTRA}, 0) - from_jid = data['from'] - to_jid = data['to'] - d = self.dbpool.runQuery( - "INSERT INTO history(uid, stanza_id, update_uid, profile_id, source, dest, " - "source_res, dest_res, timestamp, received_timestamp, type, extra) VALUES " - "(?,?,?,?,?,?,?,?,?,?,?,?)", - (data['uid'], data['extra'].get('stanza_id'), data['extra'].get('update_uid'), - self.profiles[profile], data['from'].userhost(), to_jid.userhost(), - from_jid.resource, to_jid.resource, data['timestamp'], - data.get('received_timestamp'), data['type'], sqlite3.Binary(extra))) - d.addCallbacks(self._addToHistoryCb, - self._addToHistoryEb, - callbackArgs=[data], - errbackArgs=[data]) - d.addErrback(self._logHistoryError, from_jid, to_jid, data) - return d - - def sqliteHistoryToList(self, query_result): - """Get SQL query result and return a list of message data dicts""" - result = [] - current = {'uid': None} - for row in reversed(query_result): - (uid, stanza_id, update_uid, source, dest, source_res, dest_res, timestamp, - received_timestamp, type_, extra, message, message_lang, subject, - subject_lang, thread, thread_parent) = row - if uid != current['uid']: - # new message - try: - extra = self._load_pickle(extra or b"") - except EOFError: - extra = {} - current = { - 'from': "%s/%s" % (source, source_res) if source_res else source, - 'to': "%s/%s" % (dest, dest_res) if dest_res else dest, - 'uid': uid, - 'message': {}, - 'subject': {}, - 'type': type_, - 'extra': extra, - 'timestamp': timestamp, - } - if stanza_id is not None: - current['extra']['stanza_id'] = stanza_id - if update_uid is not None: - current['extra']['update_uid'] = update_uid - if received_timestamp is not None: - current['extra']['received_timestamp'] = str(received_timestamp) - result.append(current) - - if message is not None: - current['message'][message_lang or ''] = message - - if subject is not None: - current['subject'][subject_lang or ''] = subject - - if thread is not None: - current_extra = current['extra'] - current_extra['thread'] = thread - if thread_parent is not None: - current_extra['thread_parent'] = thread_parent - else: - if thread_parent is not None: - log.error( - "Database inconsistency: thread parent without thread (uid: " - "{uid}, thread_parent: {parent})" - .format(uid=uid, parent=thread_parent)) - - return result - - def listDict2listTuple(self, messages_data): - """Return a list of tuple as used in bridge from a list of messages data""" - ret = [] - for m in messages_data: - ret.append((m['uid'], m['timestamp'], m['from'], m['to'], m['message'], m['subject'], m['type'], m['extra'])) - return ret - - def historyGet(self, from_jid, to_jid, limit=None, between=True, filters=None, profile=None): - """Retrieve messages in history - - @param from_jid (JID): source JID (full, or bare for catchall) - @param to_jid (JID): dest JID (full, or bare for catchall) - @param limit (int): maximum number of messages to get: - - 0 for no message (returns the empty list) - - None for unlimited - @param between (bool): confound source and dest (ignore the direction) - @param filters (dict[unicode, unicode]): pattern to filter the history results - @param profile (unicode): %(doc_profile)s - @return: list of tuple as in [messageNew] - """ - assert profile - if filters is None: - filters = {} - if limit == 0: - return defer.succeed([]) - - query_parts = ["SELECT uid, stanza_id, update_uid, source, dest, source_res, dest_res, timestamp, received_timestamp,\ - type, extra, message, message.language, subject, subject.language, thread_id, thread.parent_id\ - FROM history LEFT JOIN message ON history.uid = message.history_uid\ - LEFT JOIN subject ON history.uid=subject.history_uid\ - LEFT JOIN thread ON history.uid=thread.history_uid\ - WHERE profile_id=?"] # FIXME: not sure if it's the best request, messages and subjects can appear several times here - values = [self.profiles[profile]] - - def test_jid(type_, jid_): - values.append(jid_.userhost()) - if jid_.resource: - values.append(jid_.resource) - return '({type_}=? AND {type_}_res=?)'.format(type_=type_) - return '{type_}=?'.format(type_=type_) - - if not from_jid and not to_jid: - # not jid specified, we want all one2one communications - pass - elif between: - if not from_jid or not to_jid: - # we only have one jid specified, we check all messages - # from or to this jid - jid_ = from_jid or to_jid - query_parts.append("AND ({source} OR {dest})".format( - source=test_jid('source', jid_), - dest=test_jid('dest' , jid_))) - else: - # we have 2 jids specified, we check all communications between - # those 2 jids - query_parts.append( - "AND (({source_from} AND {dest_to}) " - "OR ({source_to} AND {dest_from}))".format( - source_from=test_jid('source', from_jid), - dest_to=test_jid('dest', to_jid), - source_to=test_jid('source', to_jid), - dest_from=test_jid('dest', from_jid))) - else: - # we want one communication in specific direction (from somebody or - # to somebody). - q = [] - if from_jid is not None: - q.append(test_jid('source', from_jid)) - if to_jid is not None: - q.append(test_jid('dest', to_jid)) - query_parts.append("AND " + " AND ".join(q)) - - if filters: - if 'timestamp_start' in filters: - query_parts.append("AND timestamp>= ?") - values.append(float(filters['timestamp_start'])) - if 'before_uid' in filters: - query_parts.append("AND history.rowid<(select rowid from history where uid=?)") - values.append(filters['before_uid']) - if 'body' in filters: - # TODO: use REGEXP (function to be defined) instead of GLOB: https://www.sqlite.org/lang_expr.html - query_parts.append("AND message LIKE ?") - values.append("%{}%".format(filters['body'])) - if 'search' in filters: - query_parts.append("AND (message LIKE ? OR source_res LIKE ?)") - values.extend(["%{}%".format(filters['search'])] * 2) - if 'types' in filters: - types = filters['types'].split() - query_parts.append("AND type IN ({})".format(','.join("?"*len(types)))) - values.extend(types) - if 'not_types' in filters: - types = filters['not_types'].split() - query_parts.append("AND type NOT IN ({})".format(','.join("?"*len(types)))) - values.extend(types) - if 'last_stanza_id' in filters: - # this request get the last message with a "stanza_id" that we - # have in history. This is mainly used to retrieve messages sent - # while we were offline, using MAM (XEP-0313). - if (filters['last_stanza_id'] is not True - or limit != 1): - raise ValueError("Unexpected values for last_stanza_id filter") - query_parts.append("AND stanza_id IS NOT NULL") - - - # timestamp may be identical for 2 close messages (specially when delay is - # used) that's why we order ties by received_timestamp - # We'll reverse the order in sqliteHistoryToList - # we use DESC here so LIMIT keep the last messages - query_parts.append("ORDER BY timestamp DESC, history.received_timestamp DESC") - if limit is not None: - query_parts.append("LIMIT ?") - values.append(limit) - - d = self.dbpool.runQuery(" ".join(query_parts), values) - d.addCallback(self.sqliteHistoryToList) - d.addCallback(self.listDict2listTuple) - return d - - ## Private values - - def _privateDataEb(self, failure_, operation, namespace, key=None, profile=None): - """generic errback for data queries""" - log.error(_("Can't {operation} data in database for namespace {namespace}{and_key}{for_profile}: {msg}").format( - operation = operation, - namespace = namespace, - and_key = (" and key " + key) if key is not None else "", - for_profile = (' [' + profile + ']') if profile is not None else '', - msg = failure_)) - - def _load_pickle(self, v): - # FIXME: workaround for Python 3 port, some pickled data are bytes while other are strings - try: - return pickle.loads(v, encoding="utf-8") - except TypeError: - data = pickle.loads(v.encode('utf-8'), encoding="utf-8") - log.debug(f"encoding issue in pickled data: {data}") - return data - - def _generateDataDict(self, query_result, binary): - if binary: - return {k: self._load_pickle(v) for k,v in query_result} - else: - return dict(query_result) - - def _getPrivateTable(self, binary, profile): - """Get table to use for private values""" - table = ['private'] - - if profile is None: - table.append('gen') - else: - table.append('ind') - - if binary: - table.append('bin') - - return '_'.join(table) - - def getPrivates(self, namespace, keys=None, binary=False, profile=None): - """Get private value(s) from databases - - @param namespace(unicode): namespace of the values - @param keys(iterable, None): keys of the values to get - None to get all keys/values - @param binary(bool): True to deserialise binary values - @param profile(unicode, None): profile to use for individual values - None to use general values - @return (dict[unicode, object]): gotten keys/values - """ - log.debug(_("getting {type}{binary} private values from database for namespace {namespace}{keys}".format( - type = "general" if profile is None else "individual", - binary = " binary" if binary else "", - namespace = namespace, - keys = " with keys {}".format(", ".join(keys)) if keys is not None else ""))) - table = self._getPrivateTable(binary, profile) - query_parts = ["SELECT key,value FROM", table, "WHERE namespace=?"] - args = [namespace] - - if keys is not None: - placeholders = ','.join(len(keys) * '?') - query_parts.append('AND key IN (' + placeholders + ')') - args.extend(keys) - - if profile is not None: - query_parts.append('AND profile_id=?') - args.append(self.profiles[profile]) - - d = self.dbpool.runQuery(" ".join(query_parts), args) - d.addCallback(self._generateDataDict, binary) - d.addErrback(self._privateDataEb, "get", namespace, profile=profile) - return d - - def setPrivateValue(self, namespace, key, value, binary=False, profile=None): - """Set a private value in database - - @param namespace(unicode): namespace of the values - @param key(unicode): key of the value to set - @param value(object): value to set - @param binary(bool): True if it's a binary values - binary values need to be serialised, used for everything but strings - @param profile(unicode, None): profile to use for individual value - if None, it's a general value - """ - table = self._getPrivateTable(binary, profile) - query_values_names = ['namespace', 'key', 'value'] - query_values = [namespace, key] - - if binary: - value = sqlite3.Binary(pickle.dumps(value, 0)) - - query_values.append(value) - - if profile is not None: - query_values_names.append('profile_id') - query_values.append(self.profiles[profile]) - - query_parts = ["REPLACE INTO", table, '(', ','.join(query_values_names), ')', - "VALUES (", ",".join('?'*len(query_values_names)), ')'] - - d = self.dbpool.runQuery(" ".join(query_parts), query_values) - d.addErrback(self._privateDataEb, "set", namespace, key, profile=profile) - return d - - def delPrivateValue(self, namespace, key, binary=False, profile=None): - """Delete private value from database - - @param category: category of the privateeter - @param key: key of the private value - @param binary(bool): True if it's a binary values - @param profile(unicode, None): profile to use for individual value - if None, it's a general value - """ - table = self._getPrivateTable(binary, profile) - query_parts = ["DELETE FROM", table, "WHERE namespace=? AND key=?"] - args = [namespace, key] - if profile is not None: - query_parts.append("AND profile_id=?") - args.append(self.profiles[profile]) - d = self.dbpool.runQuery(" ".join(query_parts), args) - d.addErrback(self._privateDataEb, "delete", namespace, key, profile=profile) - return d - - def delPrivateNamespace(self, namespace, binary=False, profile=None): - """Delete all data from a private namespace - - Be really cautious when you use this method, as all data with given namespace are - removed. - Params are the same as for delPrivateValue - """ - table = self._getPrivateTable(binary, profile) - query_parts = ["DELETE FROM", table, "WHERE namespace=?"] - args = [namespace] - if profile is not None: - query_parts.append("AND profile_id=?") - args.append(self.profiles[profile]) - d = self.dbpool.runQuery(" ".join(query_parts), args) - d.addErrback(self._privateDataEb, "delete namespace", namespace, profile=profile) - return d - - ## Files - - @defer.inlineCallbacks - def getFiles(self, client, file_id=None, version='', parent=None, type_=None, - file_hash=None, hash_algo=None, name=None, namespace=None, mime_type=None, - public_id=None, owner=None, access=None, projection=None, unique=False): - """retrieve files with with given filters - - @param file_id(unicode, None): id of the file - None to ignore - @param version(unicode, None): version of the file - None to ignore - empty string to look for current version - @param parent(unicode, None): id of the directory containing the files - None to ignore - empty string to look for root files/directories - @param projection(list[unicode], None): name of columns to retrieve - None to retrieve all - @param unique(bool): if True will remove duplicates - other params are the same as for [setFile] - @return (list[dict]): files corresponding to filters - """ - query_parts = ["SELECT"] - if unique: - query_parts.append('DISTINCT') - if projection is None: - projection = ['id', 'version', 'parent', 'type', 'file_hash', 'hash_algo', 'name', - 'size', 'namespace', 'media_type', 'media_subtype', 'public_id', 'created', 'modified', 'owner', - 'access', 'extra'] - query_parts.append(','.join(projection)) - query_parts.append("FROM files WHERE") - if client is not None: - filters = ['profile_id=?'] - args = [self.profiles[client.profile]] - else: - if public_id is None: - raise exceptions.InternalError( - "client can only be omitted when public_id is set") - filters = [] - args = [] - - if file_id is not None: - filters.append('id=?') - args.append(file_id) - if version is not None: - filters.append('version=?') - args.append(version) - if parent is not None: - filters.append('parent=?') - args.append(parent) - if type_ is not None: - filters.append('type=?') - args.append(type_) - if file_hash is not None: - filters.append('file_hash=?') - args.append(file_hash) - if hash_algo is not None: - filters.append('hash_algo=?') - args.append(hash_algo) - if name is not None: - filters.append('name=?') - args.append(name) - if namespace is not None: - filters.append('namespace=?') - args.append(namespace) - if mime_type is not None: - if '/' in mime_type: - filters.extend('media_type=?', 'media_subtype=?') - args.extend(mime_type.split('/', 1)) - else: - filters.append('media_type=?') - args.append(mime_type) - if public_id is not None: - filters.append('public_id=?') - args.append(public_id) - if owner is not None: - filters.append('owner=?') - args.append(owner.full()) - if access is not None: - raise NotImplementedError('Access check is not implemented yet') - # a JSON comparison is needed here - - filters = ' AND '.join(filters) - query_parts.append(filters) - query = ' '.join(query_parts) - - result = yield self.dbpool.runQuery(query, args) - files_data = [dict(list(zip(projection, row))) for row in result] - to_parse = {'access', 'extra'}.intersection(projection) - to_filter = {'owner'}.intersection(projection) - if to_parse or to_filter: - for file_data in files_data: - for key in to_parse: - value = file_data[key] - file_data[key] = {} if value is None else json.loads(value) - owner = file_data.get('owner') - if owner is not None: - file_data['owner'] = jid.JID(owner) - defer.returnValue(files_data) - - def setFile(self, client, name, file_id, version='', parent=None, type_=C.FILE_TYPE_FILE, - file_hash=None, hash_algo=None, size=None, namespace=None, mime_type=None, - public_id=None, created=None, modified=None, owner=None, access=None, extra=None): - """set a file metadata - - @param client(SatXMPPClient): client owning the file - @param name(str): name of the file (must not contain "/") - @param file_id(str): unique id of the file - @param version(str): version of this file - @param parent(str): id of the directory containing this file - None if it is a root file/directory - @param type_(str): one of: - - file - - directory - @param file_hash(str): unique hash of the payload - @param hash_algo(str): algorithm used for hashing the file (usually sha-256) - @param size(int): size in bytes - @param namespace(str, None): identifier (human readable is better) to group files - for instance, namespace could be used to group files in a specific photo album - @param mime_type(str): media type of the file, or None if not known/guessed - @param public_id(str): ID used to server the file publicly via HTTP - @param created(int): UNIX time of creation - @param modified(int,None): UNIX time of last modification, or None to use created date - @param owner(jid.JID, None): jid of the owner of the file (mainly useful for component) - @param access(dict, None): serialisable dictionary with access rules. See [memory.memory] for details - @param extra(dict, None): serialisable dictionary of any extra data - will be encoded to json in database - """ - if extra is not None: - assert isinstance(extra, dict) - - if mime_type is None: - media_type = media_subtype = None - elif '/' in mime_type: - media_type, media_subtype = mime_type.split('/', 1) - else: - media_type, media_subtype = mime_type, None - - query = ('INSERT INTO files(id, version, parent, type, file_hash, hash_algo, name, size, namespace, ' - 'media_type, media_subtype, public_id, created, modified, owner, access, extra, profile_id) VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)') - d = self.dbpool.runQuery(query, (file_id, version.strip(), parent, type_, - file_hash, hash_algo, - name, size, namespace, - media_type, media_subtype, public_id, created, modified, - owner.full() if owner is not None else None, - json.dumps(access) if access else None, - json.dumps(extra) if extra else None, - self.profiles[client.profile])) - d.addErrback(lambda failure: log.error(_("Can't save file metadata for [{profile}]: {reason}".format(profile=client.profile, reason=failure)))) - return d - - async def fileGetUsedSpace(self, client, owner): - """Get space used by owner of file""" - query = "SELECT SUM(size) FROM files WHERE owner=? AND type='file' AND profile_id=?" - ret = await self.dbpool.runQuery( - query, - (owner.userhost(), self.profiles[client.profile]) - ) - return ret[0][0] or 0 - - def _fileUpdate(self, cursor, file_id, column, update_cb): - query = 'SELECT {column} FROM files where id=?'.format(column=column) - for i in range(5): - cursor.execute(query, [file_id]) - try: - older_value_raw = cursor.fetchone()[0] - except TypeError: - raise exceptions.NotFound - if older_value_raw is None: - value = {} - else: - value = json.loads(older_value_raw) - update_cb(value) - value_raw = json.dumps(value) - if older_value_raw is None: - update_query = 'UPDATE files SET {column}=? WHERE id=? AND {column} is NULL'.format(column=column) - update_args = (value_raw, file_id) - else: - update_query = 'UPDATE files SET {column}=? WHERE id=? AND {column}=?'.format(column=column) - update_args = (value_raw, file_id, older_value_raw) - try: - cursor.execute(update_query, update_args) - except sqlite3.Error: - pass - else: - if cursor.rowcount == 1: - break; - log.warning(_("table not updated, probably due to race condition, trying again ({tries})").format(tries=i+1)) - else: - log.error(_("Can't update file table")) - - def fileUpdate(self, file_id, column, update_cb): - """Update a column value using a method to avoid race conditions - - the older value will be retrieved from database, then update_cb will be applied - to update it, and file will be updated checking that older value has not been changed meanwhile - by an other user. If it has changed, it tries again a couple of times before failing - @param column(str): column name (only "access" or "extra" are allowed) - @param update_cb(callable): method to update the value of the colum - the method will take older value as argument, and must update it in place - update_cb must not care about serialization, - it get the deserialized data (i.e. a Python object) directly - Note that the callable must be thread-safe - @raise exceptions.NotFound: there is not file with this id - """ - if column not in ('access', 'extra'): - raise exceptions.InternalError('bad column name') - return self.dbpool.runInteraction(self._fileUpdate, file_id, column, update_cb) - - def fileDelete(self, file_id): - """Delete file metadata from the database - - @param file_id(unicode): id of the file to delete - NOTE: file itself must still be removed, this method only handle metadata in - database - """ - return self.dbpool.runQuery("DELETE FROM files WHERE id = ?", (file_id,)) - - ##Helper methods## - - def __getFirstResult(self, result): - """Return the first result of a database query - Useful when we are looking for one specific value""" - return None if not result else result[0][0] - - -class Updater(object): - stmnt_regex = re.compile(r"[\w/' ]+(?:\(.*?\))?[^,]*") - clean_regex = re.compile(r"^ +|(?<= ) +|(?<=,) +| +$") - CREATE_SQL = "CREATE TABLE %s (%s)" - INSERT_SQL = "INSERT INTO %s VALUES (%s)" - INDEX_SQL = "CREATE INDEX %s ON %s(%s)" - DROP_SQL = "DROP TABLE %s" - ALTER_SQL = "ALTER TABLE %s ADD COLUMN %s" - RENAME_TABLE_SQL = "ALTER TABLE %s RENAME TO %s" - - CONSTRAINTS = ('PRIMARY', 'UNIQUE', 'CHECK', 'FOREIGN') - TMP_TABLE = "tmp_sat_update" - - def __init__(self, sqlite_storage, sat_version): - self._sat_version = sat_version - self.sqlite_storage = sqlite_storage - - @property - def dbpool(self): - return self.sqlite_storage.dbpool - - def getLocalVersion(self): - """ Get local database version - - @return: version (int) - """ - return self.dbpool.runQuery("PRAGMA user_version").addCallback(lambda ret: int(ret[0][0])) - - def _setLocalVersion(self, version): - """ Set local database version - - @param version: version (int) - @return: deferred - """ - return self.dbpool.runOperation("PRAGMA user_version=%d" % version) - - def getLocalSchema(self): - """ return raw local schema - - @return: list of strings with CREATE sql statements for local database - """ - d = self.dbpool.runQuery("select sql from sqlite_master where type = 'table'") - d.addCallback(lambda result: [row[0] for row in result]) - return d - - @defer.inlineCallbacks - def checkUpdates(self): - """ Check if a database schema/content update is needed, according to DATABASE_SCHEMAS - - @return: deferred which fire a list of SQL update statements, or None if no update is needed - """ - # TODO: only "table" type (i.e. "CREATE" statements) is checked, - # "index" should be checked too. - # This may be not relevant is we move to a higher level library (alchimia?) - local_version = yield self.getLocalVersion() - raw_local_sch = yield self.getLocalSchema() - - local_sch = self.rawStatements2data(raw_local_sch) - current_sch = DATABASE_SCHEMAS['current']['CREATE'] - local_hash = self.statementHash(local_sch) - current_hash = self.statementHash(current_sch) - - # Force the update if the schemas are unchanged but a specific update is needed - force_update = local_hash == current_hash and local_version < CURRENT_DB_VERSION \ - and {'index', 'specific'}.intersection(DATABASE_SCHEMAS[CURRENT_DB_VERSION]) - - if local_hash == current_hash and not force_update: - if local_version != CURRENT_DB_VERSION: - log.warning(_("Your local schema is up-to-date, but database versions mismatch, fixing it...")) - yield self._setLocalVersion(CURRENT_DB_VERSION) - else: - # an update is needed - - if local_version == CURRENT_DB_VERSION: - # Database mismatch and we have the latest version - if self._sat_version.endswith('D'): - # we are in a development version - update_data = self.generateUpdateData(local_sch, current_sch, False) - log.warning(_("There is a schema mismatch, but as we are on a dev version, database will be updated")) - update_raw = yield self.update2raw(update_data, True) - defer.returnValue(update_raw) - else: - log.error(_("schema version is up-to-date, but local schema differ from expected current schema")) - update_data = self.generateUpdateData(local_sch, current_sch, True) - update_raw = yield self.update2raw(update_data) - log.warning(_("Here are the commands that should fix the situation, use at your own risk (do a backup before modifying database), you can go to SàT's MUC room at sat@chat.jabberfr.org for help\n### SQL###\n%s\n### END SQL ###\n") % '\n'.join("%s;" % statement for statement in update_raw)) - raise exceptions.DatabaseError("Database mismatch") - else: - if local_version > CURRENT_DB_VERSION: - log.error(_( - "You database version is higher than the one used in this SàT " - "version, are you using several version at the same time? We " - "can't run SàT with this database.")) - sys.exit(1) - - # Database is not up-to-date, we'll do the update - if force_update: - log.info(_("Database content needs a specific processing, local database will be updated")) - else: - log.info(_("Database schema has changed, local database will be updated")) - update_raw = [] - for version in range(local_version + 1, CURRENT_DB_VERSION + 1): - try: - update_data = DATABASE_SCHEMAS[version] - except KeyError: - raise exceptions.InternalError("Missing update definition (version %d)" % version) - if "specific" in update_data and update_raw: - # if we have a specific, we must commit current statements - # because a specific may modify database itself, and the database - # must be in the expected state of the previous version. - yield self.sqlite_storage.commitStatements(update_raw) - del update_raw[:] - update_raw_step = yield self.update2raw(update_data) - if update_raw_step is not None: - # can be None with specifics - update_raw.extend(update_raw_step) - update_raw.append("PRAGMA user_version=%d" % CURRENT_DB_VERSION) - defer.returnValue(update_raw) - - @staticmethod - def createData2Raw(data): - """ Generate SQL statements from statements data - - @param data: dictionary with table as key, and statements data in tuples as value - @return: list of strings with raw statements - """ - ret = [] - for table in data: - defs, constraints = data[table] - assert isinstance(defs, tuple) - assert isinstance(constraints, tuple) - ret.append(Updater.CREATE_SQL % (table, ', '.join(defs + constraints))) - return ret - - @staticmethod - def insertData2Raw(data): - """ Generate SQL statements from statements data - - @param data: dictionary with table as key, and statements data in tuples as value - @return: list of strings with raw statements - """ - ret = [] - for table in data: - values_tuple = data[table] - assert isinstance(values_tuple, tuple) - for values in values_tuple: - assert isinstance(values, tuple) - ret.append(Updater.INSERT_SQL % (table, ', '.join(values))) - return ret - - @staticmethod - def indexData2Raw(data): - """ Generate SQL statements from statements data - - @param data: dictionary with table as key, and statements data in tuples as value - @return: list of strings with raw statements - """ - ret = [] - assert isinstance(data, tuple) - for table, col_data in data: - assert isinstance(table, str) - assert isinstance(col_data, tuple) - for cols in col_data: - if isinstance(cols, tuple): - assert all([isinstance(c, str) for c in cols]) - indexed_cols = ','.join(cols) - elif isinstance(cols, str): - indexed_cols = cols - else: - raise exceptions.InternalError("unexpected index columns value") - index_name = table + '__' + indexed_cols.replace(',', '_') - ret.append(Updater.INDEX_SQL % (index_name, table, indexed_cols)) - return ret - - def statementHash(self, data): - """ Generate hash of template data - - useful to compare schemas - @param data: dictionary of "CREATE" statement, with tables names as key, - and tuples of (col_defs, constraints) as values - @return: hash as string - """ - hash_ = hashlib.sha1() - tables = list(data.keys()) - tables.sort() - - def stmnts2str(stmts): - return ','.join([self.clean_regex.sub('',stmt) for stmt in sorted(stmts)]) - - for table in tables: - col_defs, col_constr = data[table] - hash_.update( - ("%s:%s:%s" % (table, stmnts2str(col_defs), stmnts2str(col_constr))) - .encode('utf-8')) - return hash_.digest() - - def rawStatements2data(self, raw_statements): - """ separate "CREATE" statements into dictionary/tuples data - - @param raw_statements: list of CREATE statements as strings - @return: dictionary with table names as key, and a (col_defs, constraints) tuple - """ - schema_dict = {} - for create_statement in raw_statements: - if not create_statement.startswith("CREATE TABLE "): - log.warning("Unexpected statement, ignoring it") - continue - _create_statement = create_statement[13:] - table, raw_col_stats = _create_statement.split(' ',1) - if raw_col_stats[0] != '(' or raw_col_stats[-1] != ')': - log.warning("Unexpected statement structure, ignoring it") - continue - col_stats = [stmt.strip() for stmt in self.stmnt_regex.findall(raw_col_stats[1:-1])] - col_defs = [] - constraints = [] - for col_stat in col_stats: - name = col_stat.split(' ',1)[0] - if name in self.CONSTRAINTS: - constraints.append(col_stat) - else: - col_defs.append(col_stat) - schema_dict[table] = (tuple(col_defs), tuple(constraints)) - return schema_dict - - def generateUpdateData(self, old_data, new_data, modify=False): - """ Generate data for automatic update between two schema data - - @param old_data: data of the former schema (which must be updated) - @param new_data: data of the current schema - @param modify: if True, always use "cols modify" table, else try to ALTER tables - @return: update data, a dictionary with: - - 'create': dictionary of tables to create - - 'delete': tuple of tables to delete - - 'cols create': dictionary of columns to create (table as key, tuple of columns to create as value) - - 'cols delete': dictionary of columns to delete (table as key, tuple of columns to delete as value) - - 'cols modify': dictionary of columns to modify (table as key, tuple of old columns to transfert as value). With this table, a new table will be created, and content from the old table will be transfered to it, only cols specified in the tuple will be transfered. - """ - - create_tables_data = {} - create_cols_data = {} - modify_cols_data = {} - delete_cols_data = {} - old_tables = set(old_data.keys()) - new_tables = set(new_data.keys()) - - def getChanges(set_olds, set_news): - to_create = set_news.difference(set_olds) - to_delete = set_olds.difference(set_news) - to_check = set_news.intersection(set_olds) - return tuple(to_create), tuple(to_delete), tuple(to_check) - - tables_to_create, tables_to_delete, tables_to_check = getChanges(old_tables, new_tables) - - for table in tables_to_create: - create_tables_data[table] = new_data[table] - - for table in tables_to_check: - old_col_defs, old_constraints = old_data[table] - new_col_defs, new_constraints = new_data[table] - for obj in old_col_defs, old_constraints, new_col_defs, new_constraints: - if not isinstance(obj, tuple): - raise exceptions.InternalError("Columns definitions must be tuples") - defs_create, defs_delete, ignore = getChanges(set(old_col_defs), set(new_col_defs)) - constraints_create, constraints_delete, ignore = getChanges(set(old_constraints), set(new_constraints)) - created_col_names = set([name.split(' ',1)[0] for name in defs_create]) - deleted_col_names = set([name.split(' ',1)[0] for name in defs_delete]) - if (created_col_names.intersection(deleted_col_names or constraints_create or constraints_delete) or - (modify and (defs_create or constraints_create or defs_delete or constraints_delete))): - # we have modified columns, we need to transfer table - # we determinate which columns are in both schema so we can transfer them - old_names = set([name.split(' ',1)[0] for name in old_col_defs]) - new_names = set([name.split(' ',1)[0] for name in new_col_defs]) - modify_cols_data[table] = tuple(old_names.intersection(new_names)); - else: - if defs_create: - create_cols_data[table] = (defs_create) - if defs_delete or constraints_delete: - delete_cols_data[table] = (defs_delete) - - return {'create': create_tables_data, - 'delete': tables_to_delete, - 'cols create': create_cols_data, - 'cols delete': delete_cols_data, - 'cols modify': modify_cols_data - } - - @defer.inlineCallbacks - def update2raw(self, update, dev_version=False): - """ Transform update data to raw SQLite statements - - @param update: update data as returned by generateUpdateData - @param dev_version: if True, update will be done in dev mode: no deletion will be done, instead a message will be shown. This prevent accidental lost of data while working on the code/database. - @return: list of string with SQL statements needed to update the base - """ - ret = self.createData2Raw(update.get('create', {})) - drop = [] - for table in update.get('delete', tuple()): - drop.append(self.DROP_SQL % table) - if dev_version: - if drop: - log.info("Dev version, SQL NOT EXECUTED:\n--\n%s\n--\n" % "\n".join(drop)) - else: - ret.extend(drop) - - cols_create = update.get('cols create', {}) - for table in cols_create: - for col_def in cols_create[table]: - ret.append(self.ALTER_SQL % (table, col_def)) - - cols_delete = update.get('cols delete', {}) - for table in cols_delete: - log.info("Following columns in table [%s] are not needed anymore, but are kept for dev version: %s" % (table, ", ".join(cols_delete[table]))) - - cols_modify = update.get('cols modify', {}) - for table in cols_modify: - ret.append(self.RENAME_TABLE_SQL % (table, self.TMP_TABLE)) - main, extra = DATABASE_SCHEMAS['current']['CREATE'][table] - ret.append(self.CREATE_SQL % (table, ', '.join(main + extra))) - common_cols = ', '.join(cols_modify[table]) - ret.append("INSERT INTO %s (%s) SELECT %s FROM %s" % (table, common_cols, common_cols, self.TMP_TABLE)) - ret.append(self.DROP_SQL % self.TMP_TABLE) - - insert = update.get('insert', {}) - ret.extend(self.insertData2Raw(insert)) - - index = update.get('index', tuple()) - ret.extend(self.indexData2Raw(index)) - - specific = update.get('specific', None) - if specific: - cmds = yield getattr(self, specific)() - ret.extend(cmds or []) - defer.returnValue(ret) - - def update_v9(self): - """Update database from v8 to v9 - - (public_id on file with UNIQUE constraint, files indexes fix, media_type split) - """ - # we have to do a specific update because we can't set UNIQUE constraint when adding a column - # (see https://sqlite.org/lang_altertable.html#alter_table_add_column) - log.info("Database update to v9") - - create = { - 'files': (("id TEXT NOT NULL", "public_id TEXT", "version TEXT NOT NULL", - "parent TEXT NOT NULL", - "type TEXT CHECK(type in ('{file}', '{directory}')) NOT NULL DEFAULT '{file}'".format( - file=C.FILE_TYPE_FILE, directory=C.FILE_TYPE_DIRECTORY), - "file_hash TEXT", "hash_algo TEXT", "name TEXT NOT NULL", "size INTEGER", - "namespace TEXT", "media_type TEXT", "media_subtype TEXT", - "created DATETIME NOT NULL", "modified DATETIME", - "owner TEXT", "access TEXT", "extra TEXT", "profile_id INTEGER"), - ("PRIMARY KEY (id, version)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE", - "UNIQUE (public_id)")), - - } - index = tuple({'files': (('profile_id', 'owner', 'media_type', 'media_subtype'), - ('profile_id', 'owner', 'parent'))}.items()) - # XXX: Sqlite doc recommends to do the other way around (i.e. create new table, - # copy, drop old table then rename), but the RENAME would then add - # "IF NOT EXISTS" which breaks the (admittely fragile) schema comparison. - # TODO: rework sqlite update management, don't try to automatically detect - # update, the database version is now enough. - statements = ["ALTER TABLE files RENAME TO files_old"] - statements.extend(Updater.createData2Raw(create)) - cols = ','.join([col_stmt.split()[0] for col_stmt in create['files'][0] if "public_id" not in col_stmt]) - old_cols = cols[:] - # we need to split mime_type to the new media_type and media_subtype - old_cols = old_cols.replace( - 'media_type,media_subtype', - "substr(mime_type, 0, instr(mime_type,'/')),substr(mime_type, instr(mime_type,'/')+1)" - ) - statements.extend([ - f"INSERT INTO files({cols}) SELECT {old_cols} FROM files_old", - "DROP TABLE files_old", - ]) - statements.extend(Updater.indexData2Raw(index)) - return statements - - def update_v8(self): - """Update database from v7 to v8 (primary keys order changes + indexes)""" - log.info("Database update to v8") - statements = ["PRAGMA foreign_keys = OFF"] - - # here is a copy of create and index data, we can't use "current" table - # because it may change in a future version, which would break the update - # when doing v8 - create = { - 'param_gen': ( - ("category TEXT", "name TEXT", "value TEXT"), - ("PRIMARY KEY (category, name)",)), - 'param_ind': ( - ("category TEXT", "name TEXT", "profile_id INTEGER", "value TEXT"), - ("PRIMARY KEY (profile_id, category, name)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE")), - 'private_ind': ( - ("namespace TEXT", "key TEXT", "profile_id INTEGER", "value TEXT"), - ("PRIMARY KEY (profile_id, namespace, key)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE")), - 'private_ind_bin': ( - ("namespace TEXT", "key TEXT", "profile_id INTEGER", "value BLOB"), - ("PRIMARY KEY (profile_id, namespace, key)", "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE")), - } - index = ( - ('history', (('profile_id', 'timestamp'), - ('profile_id', 'received_timestamp'))), - ('message', ('history_uid',)), - ('subject', ('history_uid',)), - ('thread', ('history_uid',)), - ('files', ('profile_id', 'mime_type', 'owner', 'parent'))) - - for table in ('param_gen', 'param_ind', 'private_ind', 'private_ind_bin'): - statements.append("ALTER TABLE {0} RENAME TO {0}_old".format(table)) - schema = {table: create[table]} - cols = [d.split()[0] for d in schema[table][0]] - statements.extend(Updater.createData2Raw(schema)) - statements.append("INSERT INTO {table}({cols}) " - "SELECT {cols} FROM {table}_old".format( - table=table, - cols=','.join(cols))) - statements.append("DROP TABLE {}_old".format(table)) - - statements.extend(Updater.indexData2Raw(index)) - statements.append("PRAGMA foreign_keys = ON") - return statements - - @defer.inlineCallbacks - def update_v7(self): - """Update database from v6 to v7 (history unique constraint change)""" - log.info("Database update to v7, this may be long depending on your history " - "size, please be patient.") - - log.info("Some cleaning first") - # we need to fix duplicate stanza_id, as it can result in conflicts with the new schema - # normally database should not contain any, but better safe than sorry. - rows = yield self.dbpool.runQuery( - "SELECT stanza_id, COUNT(*) as c FROM history WHERE stanza_id is not NULL " - "GROUP BY stanza_id HAVING c>1") - if rows: - count = sum([r[1] for r in rows]) - len(rows) - log.info("{count} duplicate stanzas found, cleaning".format(count=count)) - for stanza_id, count in rows: - log.info("cleaning duplicate stanza {stanza_id}".format(stanza_id=stanza_id)) - row_uids = yield self.dbpool.runQuery( - "SELECT uid FROM history WHERE stanza_id = ? LIMIT ?", - (stanza_id, count-1)) - uids = [r[0] for r in row_uids] - yield self.dbpool.runQuery( - "DELETE FROM history WHERE uid IN ({})".format(",".join("?"*len(uids))), - uids) - - def deleteInfo(txn): - # with foreign_keys on, the delete takes ages, so we deactivate it here - # the time to delete info messages from history. - txn.execute("PRAGMA foreign_keys = OFF") - txn.execute("DELETE FROM message WHERE history_uid IN (SELECT uid FROM history WHERE " - "type='info')") - txn.execute("DELETE FROM subject WHERE history_uid IN (SELECT uid FROM history WHERE " - "type='info')") - txn.execute("DELETE FROM thread WHERE history_uid IN (SELECT uid FROM history WHERE " - "type='info')") - txn.execute("DELETE FROM message WHERE history_uid IN (SELECT uid FROM history WHERE " - "type='info')") - txn.execute("DELETE FROM history WHERE type='info'") - # not sure that is is necessary to reactivate here, but in doubt… - txn.execute("PRAGMA foreign_keys = ON") - - log.info('Deleting "info" messages (this can take a while)') - yield self.dbpool.runInteraction(deleteInfo) - - log.info("Cleaning done") - - # we have to rename table we will replace - # tables referencing history need to be replaced to, else reference would - # be to the old table (which will be dropped at the end). This buggy behaviour - # seems to be fixed in new version of Sqlite - yield self.dbpool.runQuery("ALTER TABLE history RENAME TO history_old") - yield self.dbpool.runQuery("ALTER TABLE message RENAME TO message_old") - yield self.dbpool.runQuery("ALTER TABLE subject RENAME TO subject_old") - yield self.dbpool.runQuery("ALTER TABLE thread RENAME TO thread_old") - - # history - query = ("CREATE TABLE history (uid TEXT PRIMARY KEY, stanza_id TEXT, " - "update_uid TEXT, profile_id INTEGER, source TEXT, dest TEXT, " - "source_res TEXT, dest_res TEXT, timestamp DATETIME NOT NULL, " - "received_timestamp DATETIME, type TEXT, extra BLOB, " - "FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE, " - "FOREIGN KEY(type) REFERENCES message_types(type), " - "UNIQUE (profile_id, stanza_id, source, dest))") - yield self.dbpool.runQuery(query) - - # message - query = ("CREATE TABLE message (id INTEGER PRIMARY KEY ASC, history_uid INTEGER" - ", message TEXT, language TEXT, FOREIGN KEY(history_uid) REFERENCES " - "history(uid) ON DELETE CASCADE)") - yield self.dbpool.runQuery(query) - - # subject - query = ("CREATE TABLE subject (id INTEGER PRIMARY KEY ASC, history_uid INTEGER" - ", subject TEXT, language TEXT, FOREIGN KEY(history_uid) REFERENCES " - "history(uid) ON DELETE CASCADE)") - yield self.dbpool.runQuery(query) - - # thread - query = ("CREATE TABLE thread (id INTEGER PRIMARY KEY ASC, history_uid INTEGER" - ", thread_id TEXT, parent_id TEXT, FOREIGN KEY(history_uid) REFERENCES " - "history(uid) ON DELETE CASCADE)") - yield self.dbpool.runQuery(query) - - log.info("Now transfering old data to new tables, please be patient.") - - log.info("\nTransfering table history") - query = ("INSERT INTO history (uid, stanza_id, update_uid, profile_id, source, " - "dest, source_res, dest_res, timestamp, received_timestamp, type, extra" - ") SELECT uid, stanza_id, update_uid, profile_id, source, dest, " - "source_res, dest_res, timestamp, received_timestamp, type, extra " - "FROM history_old") - yield self.dbpool.runQuery(query) - - log.info("\nTransfering table message") - query = ("INSERT INTO message (id, history_uid, message, language) SELECT id, " - "history_uid, message, language FROM message_old") - yield self.dbpool.runQuery(query) - - log.info("\nTransfering table subject") - query = ("INSERT INTO subject (id, history_uid, subject, language) SELECT id, " - "history_uid, subject, language FROM subject_old") - yield self.dbpool.runQuery(query) - - log.info("\nTransfering table thread") - query = ("INSERT INTO thread (id, history_uid, thread_id, parent_id) SELECT id" - ", history_uid, thread_id, parent_id FROM thread_old") - yield self.dbpool.runQuery(query) - - log.info("\nRemoving old tables") - # because of foreign keys, tables referencing history_old - # must be deleted first - yield self.dbpool.runQuery("DROP TABLE thread_old") - yield self.dbpool.runQuery("DROP TABLE subject_old") - yield self.dbpool.runQuery("DROP TABLE message_old") - yield self.dbpool.runQuery("DROP TABLE history_old") - log.info("\nReducing database size (this can take a while)") - yield self.dbpool.runQuery("VACUUM") - log.info("Database update done :)") - - @defer.inlineCallbacks - def update_v3(self): - """Update database from v2 to v3 (message refactoring)""" - # XXX: this update do all the messages in one huge transaction - # this is really memory consuming, but was OK on a reasonably - # big database for tests. If issues are happening, we can cut it - # in smaller transactions using LIMIT and by deleting already updated - # messages - log.info("Database update to v3, this may take a while") - - # we need to fix duplicate timestamp, as it can result in conflicts with the new schema - rows = yield self.dbpool.runQuery("SELECT timestamp, COUNT(*) as c FROM history GROUP BY timestamp HAVING c>1") - if rows: - log.info("fixing duplicate timestamp") - fixed = [] - for timestamp, __ in rows: - ids_rows = yield self.dbpool.runQuery("SELECT id from history where timestamp=?", (timestamp,)) - for idx, (id_,) in enumerate(ids_rows): - fixed.append(id_) - yield self.dbpool.runQuery("UPDATE history SET timestamp=? WHERE id=?", (float(timestamp) + idx * 0.001, id_)) - log.info("fixed messages with ids {}".format(', '.join([str(id_) for id_ in fixed]))) - - def historySchema(txn): - log.info("History schema update") - txn.execute("ALTER TABLE history RENAME TO tmp_sat_update") - txn.execute("CREATE TABLE history (uid TEXT PRIMARY KEY, update_uid TEXT, profile_id INTEGER, source TEXT, dest TEXT, source_res TEXT, dest_res TEXT, timestamp DATETIME NOT NULL, received_timestamp DATETIME, type TEXT, extra BLOB, FOREIGN KEY(profile_id) REFERENCES profiles(id) ON DELETE CASCADE, FOREIGN KEY(type) REFERENCES message_types(type), UNIQUE (profile_id, timestamp, source, dest, source_res, dest_res))") - txn.execute("INSERT INTO history (uid, profile_id, source, dest, source_res, dest_res, timestamp, type, extra) SELECT id, profile_id, source, dest, source_res, dest_res, timestamp, type, extra FROM tmp_sat_update") - - yield self.dbpool.runInteraction(historySchema) - - def newTables(txn): - log.info("Creating new tables") - txn.execute("CREATE TABLE message (id INTEGER PRIMARY KEY ASC, history_uid INTEGER, message TEXT, language TEXT, FOREIGN KEY(history_uid) REFERENCES history(uid) ON DELETE CASCADE)") - txn.execute("CREATE TABLE thread (id INTEGER PRIMARY KEY ASC, history_uid INTEGER, thread_id TEXT, parent_id TEXT, FOREIGN KEY(history_uid) REFERENCES history(uid) ON DELETE CASCADE)") - txn.execute("CREATE TABLE subject (id INTEGER PRIMARY KEY ASC, history_uid INTEGER, subject TEXT, language TEXT, FOREIGN KEY(history_uid) REFERENCES history(uid) ON DELETE CASCADE)") - - yield self.dbpool.runInteraction(newTables) - - log.info("inserting new message type") - yield self.dbpool.runQuery("INSERT INTO message_types VALUES (?)", ('info',)) - - log.info("messages update") - rows = yield self.dbpool.runQuery("SELECT id, timestamp, message, extra FROM tmp_sat_update") - total = len(rows) - - def updateHistory(txn, queries): - for query, args in iter(queries): - txn.execute(query, args) - - queries = [] - for idx, row in enumerate(rows, 1): - if idx % 1000 == 0 or total - idx == 0: - log.info("preparing message {}/{}".format(idx, total)) - id_, timestamp, message, extra = row - try: - extra = self._load_pickle(extra or b"") - except EOFError: - extra = {} - except Exception: - log.warning("Can't handle extra data for message id {}, ignoring it".format(id_)) - extra = {} - - queries.append(("INSERT INTO message(history_uid, message) VALUES (?,?)", (id_, message))) - - try: - subject = extra.pop('subject') - except KeyError: - pass - else: - try: - subject = subject - except UnicodeEncodeError: - log.warning("Error while decoding subject, ignoring it") - del extra['subject'] - else: - queries.append(("INSERT INTO subject(history_uid, subject) VALUES (?,?)", (id_, subject))) - - received_timestamp = extra.pop('timestamp', None) - try: - del extra['archive'] - except KeyError: - # archive was not used - pass - - queries.append(("UPDATE history SET received_timestamp=?,extra=? WHERE uid=?",(id_, received_timestamp, sqlite3.Binary(pickle.dumps(extra, 0))))) - - yield self.dbpool.runInteraction(updateHistory, queries) - - log.info("Dropping temporary table") - yield self.dbpool.runQuery("DROP TABLE tmp_sat_update") - log.info("Database update finished :)") - - def update2raw_v2(self): - """Update the database from v1 to v2 (add passwords encryptions): - - - the XMPP password value is re-used for the profile password (new parameter) - - the profile password is stored hashed - - the XMPP password is stored encrypted, with the profile password as key - - as there are no other stored passwords yet, it is enough, otherwise we - would need to encrypt the other passwords as it's done for XMPP password - """ - xmpp_pass_path = ('Connection', 'Password') - - def encrypt_values(values): - ret = [] - list_ = [] - - def prepare_queries(result, xmpp_password): - try: - id_ = result[0][0] - except IndexError: - log.error("Profile of id %d is referenced in 'param_ind' but it doesn't exist!" % profile_id) - return defer.succeed(None) - - sat_password = xmpp_password - sat_cipher = PasswordHasher.hash(sat_password) - personal_key = BlockCipher.getRandomKey(base64=True) - personal_cipher = BlockCipher.encrypt(sat_password, personal_key) - xmpp_cipher = BlockCipher.encrypt(personal_key, xmpp_password) - - ret.append("INSERT INTO param_ind(category,name,profile_id,value) VALUES ('%s','%s',%s,'%s')" % - (C.PROFILE_PASS_PATH[0], C.PROFILE_PASS_PATH[1], id_, sat_cipher)) - - ret.append("INSERT INTO private_ind(namespace,key,profile_id,value) VALUES ('%s','%s',%s,'%s')" % - (C.MEMORY_CRYPTO_NAMESPACE, C.MEMORY_CRYPTO_KEY, id_, personal_cipher)) - - ret.append("REPLACE INTO param_ind(category,name,profile_id,value) VALUES ('%s','%s',%s,'%s')" % - (xmpp_pass_path[0], xmpp_pass_path[1], id_, xmpp_cipher)) - - - for profile_id, xmpp_password in values: - d = self.dbpool.runQuery("SELECT id FROM profiles WHERE id=?", (profile_id,)) - d.addCallback(prepare_queries, xmpp_password) - list_.append(d) - - d_list = defer.DeferredList(list_) - d_list.addCallback(lambda __: ret) - return d_list - - def updateLiberviaConf(values): - try: - profile_id = values[0][0] - except IndexError: - return # no profile called "libervia" - - def cb(selected): - try: - password = selected[0][0] - except IndexError: - log.error("Libervia profile exists but no password is set! Update Libervia configuration will be skipped.") - return - fixConfigOption('libervia', 'passphrase', password, False) - d = self.dbpool.runQuery("SELECT value FROM param_ind WHERE category=? AND name=? AND profile_id=?", xmpp_pass_path + (profile_id,)) - return d.addCallback(cb) - - d = self.dbpool.runQuery("SELECT id FROM profiles WHERE name='libervia'") - d.addCallback(updateLiberviaConf) - d.addCallback(lambda __: self.dbpool.runQuery("SELECT profile_id,value FROM param_ind WHERE category=? AND name=?", xmpp_pass_path)) - d.addCallback(encrypt_values) - return d
--- a/sat/plugins/plugin_comp_file_sharing.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_comp_file_sharing.py Wed Sep 01 13:47:17 2021 +0200 @@ -31,6 +31,7 @@ from sat.core.log import getLogger from sat.tools import stream from sat.tools import video +from sat.tools.utils import ensure_deferred from sat.tools.common import regex from sat.tools.common import uri from sat.tools.common import files_utils @@ -487,7 +488,7 @@ else: await self.generate_thumbnails(extra, thumb_path) - self.host.memory.setFile( + await self.host.memory.setFile( client, name=name, version="", @@ -546,8 +547,7 @@ ) return False, defer.succeed(True) - @defer.inlineCallbacks - def _retrieveFiles( + async def _retrieveFiles( self, client, session, content_data, content_name, file_data, file_elt ): """This method retrieve a file on request, and send if after checking permissions""" @@ -557,7 +557,7 @@ else: owner = peer_jid try: - found_files = yield self.host.memory.getFiles( + found_files = await self.host.memory.getFiles( client, peer_jid=peer_jid, name=file_data.get("name"), @@ -575,13 +575,13 @@ peer_jid=peer_jid, name=file_data.get("name") ) ) - defer.returnValue(False) + return False if not found_files: log.warning( _("no matching file found ({file_data})").format(file_data=file_data) ) - defer.returnValue(False) + return False # we only use the first found file found_file = found_files[0] @@ -607,7 +607,7 @@ size=size, data_cb=lambda data: hasher.update(data), ) - defer.returnValue(True) + return True def _fileSendingRequestTrigger( self, client, session, content_data, content_name, file_data, file_elt @@ -617,9 +617,9 @@ else: return ( False, - self._retrieveFiles( + defer.ensureDeferred(self._retrieveFiles( client, session, content_data, content_name, file_data, file_elt - ), + )), ) ## HTTP Upload ## @@ -757,11 +757,10 @@ raise error.StanzaError("item-not-found") return file_id - @defer.inlineCallbacks - def getFileData(self, requestor, nodeIdentifier): + async def getFileData(self, requestor, nodeIdentifier): file_id = self._getFileId(nodeIdentifier) try: - files = yield self.host.memory.getFiles(self.parent, requestor, file_id) + files = await self.host.memory.getFiles(self.parent, requestor, file_id) except (exceptions.NotFound, exceptions.PermissionError): # we don't differenciate between NotFound and PermissionError # to avoid leaking information on existing files @@ -770,7 +769,7 @@ raise error.StanzaError("item-not-found") if len(files) > 1: raise error.InternalError("there should be only one file") - defer.returnValue(files[0]) + return files[0] def commentsUpdate(self, extra, new_comments, peer_jid): """update comments (replace or insert new_comments) @@ -825,10 +824,10 @@ iq_elt = iq_elt.parent return iq_elt["from"] - @defer.inlineCallbacks - def publish(self, requestor, service, nodeIdentifier, items): + @ensure_deferred + async def publish(self, requestor, service, nodeIdentifier, items): # we retrieve file a first time to check authorisations - file_data = yield self.getFileData(requestor, nodeIdentifier) + file_data = await self.getFileData(requestor, nodeIdentifier) file_id = file_data["id"] comments = [(item["id"], self._getFrom(item), item.toXml()) for item in items] if requestor.userhostJID() == file_data["owner"]: @@ -837,24 +836,24 @@ peer_jid = requestor.userhost() update_cb = partial(self.commentsUpdate, new_comments=comments, peer_jid=peer_jid) try: - yield self.host.memory.fileUpdate(file_id, "extra", update_cb) + await self.host.memory.fileUpdate(file_id, "extra", update_cb) except exceptions.PermissionError: raise error.StanzaError("not-authorized") - @defer.inlineCallbacks - def items(self, requestor, service, nodeIdentifier, maxItems, itemIdentifiers): - file_data = yield self.getFileData(requestor, nodeIdentifier) + @ensure_deferred + async def items(self, requestor, service, nodeIdentifier, maxItems, itemIdentifiers): + file_data = await self.getFileData(requestor, nodeIdentifier) comments = file_data["extra"].get("comments", []) if itemIdentifiers: defer.returnValue( [generic.parseXml(c[2]) for c in comments if c[0] in itemIdentifiers] ) else: - defer.returnValue([generic.parseXml(c[2]) for c in comments]) + return [generic.parseXml(c[2]) for c in comments] - @defer.inlineCallbacks - def retract(self, requestor, service, nodeIdentifier, itemIdentifiers): - file_data = yield self.getFileData(requestor, nodeIdentifier) + @ensure_deferred + async def retract(self, requestor, service, nodeIdentifier, itemIdentifiers): + file_data = await self.getFileData(requestor, nodeIdentifier) file_id = file_data["id"] try: comments = file_data["extra"]["comments"] @@ -879,4 +878,4 @@ raise error.StanzaError("not-authorized") remove_cb = partial(self.commentsDelete, comments=to_remove) - yield self.host.memory.fileUpdate(file_id, "extra", remove_cb) + await self.host.memory.fileUpdate(file_id, "extra", remove_cb)
--- a/sat/plugins/plugin_comp_file_sharing_management.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_comp_file_sharing_management.py Wed Sep 01 13:47:17 2021 +0200 @@ -149,8 +149,7 @@ payload = form.toElement() return payload, status, None, None - @defer.inlineCallbacks - def _getFileData(self, client, session_data, command_form): + async def _getFileData(self, client, session_data, command_form): """Retrieve field requested in root form "found_file" will also be set in session_data @@ -177,7 +176,7 @@ # this must be managed try: - found_files = yield self.host.memory.getFiles( + found_files = await self.host.memory.getFiles( client, requestor_bare, path=parent_path, name=basename, namespace=namespace) found_file = found_files[0] @@ -193,7 +192,7 @@ session_data['found_file'] = found_file session_data['namespace'] = namespace - defer.returnValue(found_file) + return found_file def _updateReadPermission(self, access, allowed_jids): if not allowed_jids: @@ -209,29 +208,27 @@ "jids": [j.full() for j in allowed_jids] } - @defer.inlineCallbacks - def _updateDir(self, client, requestor, namespace, file_data, allowed_jids): + async def _updateDir(self, client, requestor, namespace, file_data, allowed_jids): """Recursively update permission of a directory and all subdirectories @param file_data(dict): metadata of the file @param allowed_jids(list[jid.JID]): list of entities allowed to read the file """ assert file_data['type'] == C.FILE_TYPE_DIRECTORY - files_data = yield self.host.memory.getFiles( + files_data = await self.host.memory.getFiles( client, requestor, parent=file_data['id'], namespace=namespace) for file_data in files_data: if not file_data['access'].get(C.ACCESS_PERM_READ, {}): log.debug("setting {perm} read permission for {name}".format( perm=allowed_jids, name=file_data['name'])) - yield self.host.memory.fileUpdate( + await self.host.memory.fileUpdate( file_data['id'], 'access', partial(self._updateReadPermission, allowed_jids=allowed_jids)) if file_data['type'] == C.FILE_TYPE_DIRECTORY: - yield self._updateDir(client, requestor, namespace, file_data, 'PUBLIC') + await self._updateDir(client, requestor, namespace, file_data, 'PUBLIC') - @defer.inlineCallbacks - def _onChangeFile(self, client, command_elt, session_data, action, node): + async def _onChangeFile(self, client, command_elt, session_data, action, node): try: x_elt = next(command_elt.elements(data_form.NS_X_DATA, "x")) command_form = data_form.Form.fromElement(x_elt) @@ -244,14 +241,14 @@ if command_form is None or len(command_form.fields) == 0: # root request - defer.returnValue(self._getRootArgs()) + return self._getRootArgs() elif found_file is None: # file selected, we retrieve it and ask for permissions try: - found_file = yield self._getFileData(client, session_data, command_form) + found_file = await self._getFileData(client, session_data, command_form) except WorkflowError as e: - defer.returnValue(e.err_args) + return e.err_args # management request if found_file['type'] == C.FILE_TYPE_DIRECTORY: @@ -284,7 +281,7 @@ status = self._c.STATUS.EXECUTING payload = form.toElement() - defer.returnValue((payload, status, None, None)) + return (payload, status, None, None) else: # final phase, we'll do permission change here @@ -307,7 +304,7 @@ self._c.adHocError(self._c.ERROR.BAD_PAYLOAD) if found_file['type'] == C.FILE_TYPE_FILE: - yield self.host.memory.fileUpdate( + await self.host.memory.fileUpdate( found_file['id'], 'access', partial(self._updateReadPermission, allowed_jids=allowed_jids)) else: @@ -315,7 +312,7 @@ recursive = command_form.fields['recursive'] except KeyError: self._c.adHocError(self._c.ERROR.BAD_PAYLOAD) - yield self.host.memory.fileUpdate( + await self.host.memory.fileUpdate( found_file['id'], 'access', partial(self._updateReadPermission, allowed_jids=allowed_jids)) if recursive: @@ -323,17 +320,16 @@ # already a permission set), so allowed entities of root directory # can read them. namespace = session_data['namespace'] - yield self._updateDir( + await self._updateDir( client, requestor_bare, namespace, found_file, 'PUBLIC') # job done, we can end the session status = self._c.STATUS.COMPLETED payload = None note = (self._c.NOTE.INFO, _("management session done")) - defer.returnValue((payload, status, None, note)) + return (payload, status, None, note) - @defer.inlineCallbacks - def _onDeleteFile(self, client, command_elt, session_data, action, node): + async def _onDeleteFile(self, client, command_elt, session_data, action, node): try: x_elt = next(command_elt.elements(data_form.NS_X_DATA, "x")) command_form = data_form.Form.fromElement(x_elt) @@ -346,14 +342,14 @@ if command_form is None or len(command_form.fields) == 0: # root request - defer.returnValue(self._getRootArgs()) + return self._getRootArgs() elif found_file is None: # file selected, we need confirmation before actually deleting try: - found_file = yield self._getFileData(client, session_data, command_form) + found_file = await self._getFileData(client, session_data, command_form) except WorkflowError as e: - defer.returnValue(e.err_args) + return e.err_args if found_file['type'] == C.FILE_TYPE_DIRECTORY: msg = D_("Are you sure to delete directory {name} and all files and " "directories under it?").format(name=found_file['name']) @@ -370,7 +366,7 @@ form.addField(field) status = self._c.STATUS.EXECUTING payload = form.toElement() - defer.returnValue((payload, status, None, None)) + return (payload, status, None, None) else: # final phase, we'll do deletion here @@ -382,27 +378,26 @@ note = None else: recursive = found_file['type'] == C.FILE_TYPE_DIRECTORY - yield self.host.memory.fileDelete( + await self.host.memory.fileDelete( client, requestor_bare, found_file['id'], recursive) note = (self._c.NOTE.INFO, _("file deleted")) status = self._c.STATUS.COMPLETED payload = None - defer.returnValue((payload, status, None, note)) + return (payload, status, None, note) def _updateThumbs(self, extra, thumbnails): extra[C.KEY_THUMBNAILS] = thumbnails - @defer.inlineCallbacks - def _genThumbs(self, client, requestor, namespace, file_data): + async def _genThumbs(self, client, requestor, namespace, file_data): """Recursively generate thumbnails @param file_data(dict): metadata of the file """ if file_data['type'] == C.FILE_TYPE_DIRECTORY: - sub_files_data = yield self.host.memory.getFiles( + sub_files_data = await self.host.memory.getFiles( client, requestor, parent=file_data['id'], namespace=namespace) for sub_file_data in sub_files_data: - yield self._genThumbs(client, requestor, namespace, sub_file_data) + await self._genThumbs(client, requestor, namespace, sub_file_data) elif file_data['type'] == C.FILE_TYPE_FILE: media_type = file_data['media_type'] @@ -412,7 +407,7 @@ for max_thumb_size in self._t.SIZES: try: - thumb_size, thumb_id = yield self._t.generateThumbnail( + thumb_size, thumb_id = await self._t.generateThumbnail( file_path, max_thumb_size, # we keep thumbnails for 6 months @@ -424,7 +419,7 @@ break thumbnails.append({"id": thumb_id, "size": thumb_size}) - yield self.host.memory.fileUpdate( + await self.host.memory.fileUpdate( file_data['id'], 'extra', partial(self._updateThumbs, thumbnails=thumbnails)) @@ -434,8 +429,7 @@ else: log.warning("unmanaged file type: {type_}".format(type_=file_data['type'])) - @defer.inlineCallbacks - def _onGenThumbnails(self, client, command_elt, session_data, action, node): + async def _onGenThumbnails(self, client, command_elt, session_data, action, node): try: x_elt = next(command_elt.elements(data_form.NS_X_DATA, "x")) command_form = data_form.Form.fromElement(x_elt) @@ -447,23 +441,23 @@ if command_form is None or len(command_form.fields) == 0: # root request - defer.returnValue(self._getRootArgs()) + return self._getRootArgs() elif found_file is None: # file selected, we retrieve it and ask for permissions try: - found_file = yield self._getFileData(client, session_data, command_form) + found_file = await self._getFileData(client, session_data, command_form) except WorkflowError as e: - defer.returnValue(e.err_args) + return e.err_args log.info("Generating thumbnails as requested") - yield self._genThumbs(client, requestor, found_file['namespace'], found_file) + await self._genThumbs(client, requestor, found_file['namespace'], found_file) # job done, we can end the session status = self._c.STATUS.COMPLETED payload = None note = (self._c.NOTE.INFO, _("thumbnails generated")) - defer.returnValue((payload, status, None, note)) + return (payload, status, None, note) async def _onQuota(self, client, command_elt, session_data, action, node): requestor = session_data['requestor']
--- a/sat/plugins/plugin_exp_events.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_exp_events.py Wed Sep 01 13:47:17 2021 +0200 @@ -212,8 +212,7 @@ data["creator"] = True return timestamp, data - @defer.inlineCallbacks - def getEventElement(self, client, service, node, id_): + async def getEventElement(self, client, service, node, id_): """Retrieve event element @param service(jid.JID): pubsub service @@ -224,23 +223,24 @@ """ if not id_: id_ = NS_EVENT - items, metadata = yield self._p.getItems(client, service, node, item_ids=[id_]) + items, metadata = await self._p.getItems(client, service, node, item_ids=[id_]) try: event_elt = next(items[0].elements(NS_EVENT, "event")) except StopIteration: raise exceptions.NotFound(_("No event element has been found")) except IndexError: raise exceptions.NotFound(_("No event with this id has been found")) - defer.returnValue(event_elt) + return event_elt def _eventGet(self, service, node, id_="", profile_key=C.PROF_KEY_NONE): service = jid.JID(service) if service else None node = node if node else NS_EVENT client = self.host.getClient(profile_key) - return self.eventGet(client, service, node, id_) + return defer.ensureDeferred( + self.eventGet(client, service, node, id_) + ) - @defer.inlineCallbacks - def eventGet(self, client, service, node, id_=NS_EVENT): + async def eventGet(self, client, service, node, id_=NS_EVENT): """Retrieve event data @param service(unicode, None): PubSub service @@ -253,9 +253,9 @@ image: URL of a picture to use to represent event background-image: URL of a picture to use in background """ - event_elt = yield self.getEventElement(client, service, node, id_) + event_elt = await self.getEventElement(client, service, node, id_) - defer.returnValue(self._parseEventElt(event_elt)) + return self._parseEventElt(event_elt) def _eventCreate( self, timestamp, data, service, node, id_="", profile_key=C.PROF_KEY_NONE @@ -436,10 +436,11 @@ node = node if node else NS_EVENT client = self.host.getClient(profile_key) invitee_jid = jid.JID(invitee_jid_s) if invitee_jid_s else None - return self.eventInviteeGet(client, service, node, invitee_jid) + return defer.ensureDeferred( + self.eventInviteeGet(client, service, node, invitee_jid) + ) - @defer.inlineCallbacks - def eventInviteeGet(self, client, service, node, invitee_jid=None): + async def eventInviteeGet(self, client, service, node, invitee_jid=None): """Retrieve attendance from event node @param service(unicode, None): PubSub service @@ -452,28 +453,30 @@ if invitee_jid is None: invitee_jid = client.jid try: - items, metadata = yield self._p.getItems( + items, metadata = await self._p.getItems( client, service, node, item_ids=[invitee_jid.userhost()] ) event_elt = next(items[0].elements(NS_EVENT, "invitee")) except (exceptions.NotFound, IndexError): # no item found, event data are not set yet - defer.returnValue({}) + return {} data = {} for key in ("attend", "guests"): try: data[key] = event_elt[key] except KeyError: continue - defer.returnValue(data) + return data def _eventInviteeSet(self, service, node, event_data, profile_key): service = jid.JID(service) if service else None node = node if node else NS_EVENT client = self.host.getClient(profile_key) - return self.eventInviteeSet(client, service, node, event_data) + return defer.ensureDeferred( + self.eventInviteeSet(client, service, node, event_data) + ) - def eventInviteeSet(self, client, service, node, data): + async def eventInviteeSet(self, client, service, node, data): """Set or update attendance data in event node @param service(unicode, None): PubSub service @@ -490,16 +493,17 @@ except KeyError: pass item_elt = pubsub.Item(id=client.jid.userhost(), payload=event_elt) - return self._p.publish(client, service, node, items=[item_elt]) + return await self._p.publish(client, service, node, items=[item_elt]) def _eventInviteesList(self, service, node, profile_key): service = jid.JID(service) if service else None node = node if node else NS_EVENT client = self.host.getClient(profile_key) - return self.eventInviteesList(client, service, node) + return defer.ensureDeferred( + self.eventInviteesList(client, service, node) + ) - @defer.inlineCallbacks - def eventInviteesList(self, client, service, node): + async def eventInviteesList(self, client, service, node): """Retrieve attendance from event node @param service(unicode, None): PubSub service @@ -507,7 +511,7 @@ @return (dict): a dict with current attendance status, an empty dict is returned if nothing has been answered yed """ - items, metadata = yield self._p.getItems(client, service, node) + items, metadata = await self._p.getItems(client, service, node) invitees = {} for item in items: try: @@ -525,7 +529,7 @@ except KeyError: continue invitees[item["id"]] = data - defer.returnValue(invitees) + return invitees async def invitePreflight( self,
--- a/sat/plugins/plugin_exp_invitation_pubsub.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_exp_invitation_pubsub.py Wed Sep 01 13:47:17 2021 +0200 @@ -164,6 +164,6 @@ if not name: name = extra.pop("name", "") - return self.host.plugins['LIST_INTEREST'].registerPubsub( + return await self.host.plugins['LIST_INTEREST'].registerPubsub( client, namespace, service, node, item_id, creator, name, element, extra)
--- a/sat/plugins/plugin_exp_list_of_interest.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_exp_list_of_interest.py Wed Sep 01 13:47:17 2021 +0200 @@ -99,8 +99,7 @@ if e.condition == "conflict": log.debug(_("requested node already exists")) - @defer.inlineCallbacks - def registerPubsub(self, client, namespace, service, node, item_id=None, + async def registerPubsub(self, client, namespace, service, node, item_id=None, creator=False, name=None, element=None, extra=None): """Register an interesting element in personal list @@ -120,7 +119,7 @@ """ if extra is None: extra = {} - yield self.createNode(client) + await self.createNode(client) interest_elt = domish.Element((NS_LIST_INTEREST, "interest")) interest_elt["namespace"] = namespace if name is not None: @@ -146,7 +145,7 @@ interest_uri = uri.buildXMPPUri("pubsub", **uri_kwargs) # we use URI of the interest as item id to avoid duplicates item_elt = pubsub.Item(interest_uri, payload=interest_elt) - yield self._p.publish( + await self._p.publish( client, client.jid.userhostJID(), NS_LIST_INTEREST, items=[item_elt] ) @@ -258,12 +257,11 @@ node = node or None namespace = namespace or None client = self.host.getClient(profile) - d = self.listInterests(client, service, node, namespace) + d = defer.ensureDeferred(self.listInterests(client, service, node, namespace)) d.addCallback(self._listInterestsSerialise) return d - @defer.inlineCallbacks - def listInterests(self, client, service=None, node=None, namespace=None): + async def listInterests(self, client, service=None, node=None, namespace=None): """Retrieve list of interests @param service(jid.JID, None): service to use @@ -277,7 +275,7 @@ # TODO: if a MAM filter were available, it would improve performances if not node: node = NS_LIST_INTEREST - items, metadata = yield self._p.getItems(client, service, node) + items, metadata = await self._p.getItems(client, service, node) if namespace is not None: filtered_items = [] for item in items: @@ -291,7 +289,7 @@ filtered_items.append(item) items = filtered_items - defer.returnValue((items, metadata)) + return (items, metadata) def _interestRetract(self, service_s, item_id, profile_key): d = self._p._retractItem(
--- a/sat/plugins/plugin_misc_forums.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_misc_forums.py Wed Sep 01 13:47:17 2021 +0200 @@ -189,23 +189,22 @@ service = None if not node.strip(): node = None - d=self.get(client, service, node, forums_key or None) + d = defer.ensureDeferred(self.get(client, service, node, forums_key or None)) d.addCallback(lambda data: json.dumps(data)) return d - @defer.inlineCallbacks - def get(self, client, service=None, node=None, forums_key=None): + async def get(self, client, service=None, node=None, forums_key=None): if service is None: service = client.pubsub_service if node is None: node = NS_FORUMS if forums_key is None: forums_key = 'default' - items_data = yield self._p.getItems(client, service, node, item_ids=[forums_key]) + items_data = await self._p.getItems(client, service, node, item_ids=[forums_key]) item = items_data[0][0] # we have the item and need to convert it to json forums = self._parseForums(item) - defer.returnValue(forums) + return forums def _set(self, forums, service=None, node=None, forums_key=None, profile_key=C.PROF_KEY_NONE): client = self.host.getClient(profile_key) @@ -216,10 +215,11 @@ service = None if not node.strip(): node = None - return self.set(client, forums, service, node, forums_key or None) + return defer.ensureDeferred( + self.set(client, forums, service, node, forums_key or None) + ) - @defer.inlineCallbacks - def set(self, client, forums, service=None, node=None, forums_key=None): + async def set(self, client, forums, service=None, node=None, forums_key=None): """Create or replace forums structure @param forums(list): list of dictionary as follow: @@ -242,25 +242,33 @@ node = NS_FORUMS if forums_key is None: forums_key = 'default' - forums_elt = yield self._createForums(client, forums, service, node) - yield self._p.sendItem(client, service, node, forums_elt, item_id=forums_key) + forums_elt = await self._createForums(client, forums, service, node) + return await self._p.sendItem( + client, service, node, forums_elt, item_id=forums_key + ) def _getTopics(self, service, node, extra=None, profile_key=C.PROF_KEY_NONE): client = self.host.getClient(profile_key) extra = self._p.parseExtra(extra) - d = self.getTopics(client, jid.JID(service), node, rsm_request=extra.rsm_request, extra=extra.extra) + d = defer.ensureDeferred( + self.getTopics( + client, jid.JID(service), node, rsm_request=extra.rsm_request, + extra=extra.extra + ) + ) d.addCallback( lambda topics_data: (topics_data[0], data_format.serialise(topics_data[1])) ) return d - @defer.inlineCallbacks - def getTopics(self, client, service, node, rsm_request=None, extra=None): + async def getTopics(self, client, service, node, rsm_request=None, extra=None): """Retrieve topics data Topics are simple microblog URIs with some metadata duplicated from first post """ - topics_data = yield self._p.getItems(client, service, node, rsm_request=rsm_request, extra=extra) + topics_data = await self._p.getItems( + client, service, node, rsm_request=rsm_request, extra=extra + ) topics = [] item_elts, metadata = topics_data for item_elt in item_elts: @@ -270,7 +278,7 @@ 'author': topic_elt['author'], 'title': str(title_elt)} topics.append(topic) - defer.returnValue((topics, metadata)) + return (topics, metadata) def _createTopic(self, service, node, mb_data, profile_key): client = self.host.getClient(profile_key)
--- a/sat/plugins/plugin_misc_identity.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_misc_identity.py Wed Sep 01 13:47:17 2021 +0200 @@ -142,9 +142,6 @@ stored_data = await client._identity_storage.all() - self.host.memory.storage.getPrivates( - namespace="identity", binary=True, profile=client.profile) - to_delete = [] for key, value in stored_data.items():
--- a/sat/plugins/plugin_misc_lists.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_misc_lists.py Wed Sep 01 13:47:17 2021 +0200 @@ -219,7 +219,7 @@ host.bridge.addMethod( "listGet", ".plugin", - in_sign="ssiassa{ss}s", + in_sign="ssiassss", out_sign="s", method=lambda service, node, max_items, items_ids, sub_id, extra, profile_key: self._s._get(
--- a/sat/plugins/plugin_misc_merge_requests.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_misc_merge_requests.py Wed Sep 01 13:47:17 2021 +0200 @@ -77,7 +77,7 @@ self._handlers_list = [] # handlers sorted by priority self._type_handlers = {} # data type => handler map host.bridge.addMethod("mergeRequestsGet", ".plugin", - in_sign='ssiassa{ss}s', out_sign='s', + in_sign='ssiassss', out_sign='s', method=self._get, async_=True ) @@ -149,11 +149,10 @@ }) def _get(self, service='', node='', max_items=10, item_ids=None, sub_id=None, - extra_dict=None, profile_key=C.PROF_KEY_NONE): - if extra_dict and 'parse' in extra_dict: - extra_dict['parse'] = C.bool(extra_dict['parse']) + extra="", profile_key=C.PROF_KEY_NONE): + extra = data_format.deserialise(extra) client, service, node, max_items, extra, sub_id = self._s.prepareBridgeGet( - service, node, max_items, sub_id, extra_dict, profile_key) + service, node, max_items, sub_id, extra, profile_key) d = self.get(client, service, node or None, max_items, item_ids, sub_id or None, extra.rsm_request, extra.extra) d.addCallback(self.serialise)
--- a/sat/plugins/plugin_misc_text_syntaxes.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_misc_text_syntaxes.py Wed Sep 01 13:47:17 2021 +0200 @@ -156,33 +156,35 @@ SYNTAX_XHTML = _SYNTAX_XHTML SYNTAX_MARKDOWN = "markdown" SYNTAX_TEXT = "text" - syntaxes = {} # default_syntax must be lower case default_syntax = SYNTAX_XHTML - params = """ - <params> - <individual> - <category name="%(category_name)s" label="%(category_label)s"> - <param name="%(name)s" label="%(label)s" type="list" security="0"> - %(options)s - </param> - </category> - </individual> - </params> - """ - - params_data = { - "category_name": CATEGORY, - "category_label": _(CATEGORY), - "name": NAME, - "label": _(NAME), - "syntaxes": syntaxes, - } def __init__(self, host): log.info(_("Text syntaxes plugin initialization")) self.host = host + self.syntaxes = {} + + self.params = """ + <params> + <individual> + <category name="%(category_name)s" label="%(category_label)s"> + <param name="%(name)s" label="%(label)s" type="list" security="0"> + %(options)s + </param> + </category> + </individual> + </params> + """ + + self.params_data = { + "category_name": CATEGORY, + "category_label": _(CATEGORY), + "name": NAME, + "label": _(NAME), + "syntaxes": self.syntaxes, + } + self.addSyntax( self.SYNTAX_XHTML, lambda xhtml: defer.succeed(xhtml), @@ -253,7 +255,7 @@ xml_tools.cleanXHTML = self.cleanXHTML def _updateParamOptions(self): - data_synt = TextSyntaxes.syntaxes + data_synt = self.syntaxes default_synt = TextSyntaxes.default_syntax syntaxes = [] @@ -269,8 +271,8 @@ selected = 'selected="true"' if syntax == default_synt else "" options.append('<option value="%s" %s/>' % (syntax, selected)) - TextSyntaxes.params_data["options"] = "\n".join(options) - self.host.memory.updateParams(TextSyntaxes.params % TextSyntaxes.params_data) + self.params_data["options"] = "\n".join(options) + self.host.memory.updateParams(self.params % self.params_data) def getCurrentSyntax(self, profile): """ Return the selected syntax for the given profile @@ -372,7 +374,7 @@ syntax_to = self.getCurrentSyntax(profile) else: syntax_to = syntax_to.lower().strip() - syntaxes = TextSyntaxes.syntaxes + syntaxes = self.syntaxes if syntax_from not in syntaxes: raise exceptions.NotFound(syntax_from) if syntax_to not in syntaxes: @@ -417,7 +419,7 @@ ) ) - syntaxes = TextSyntaxes.syntaxes + syntaxes = self.syntaxes key = name.lower().strip() if key in syntaxes: raise exceptions.ConflictError(
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sat/plugins/plugin_pubsub_cache.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,778 @@ +#!/usr/bin/env python3 + +# Libervia plugin for PubSub Caching +# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import time +from datetime import datetime +from typing import Optional, List, Tuple +from twisted.words.protocols.jabber import jid, error +from twisted.words.xish import domish +from twisted.internet import defer +from wokkel import pubsub, rsm +from sat.core.i18n import _ +from sat.core.constants import Const as C +from sat.core import exceptions +from sat.core.log import getLogger +from sat.core.core_types import SatXMPPEntity +from sat.tools import xml_tools, utils +from sat.tools.common import data_format +from sat.memory.sqla import PubsubNode, PubsubItem, SyncState + + +log = getLogger(__name__) + +PLUGIN_INFO = { + C.PI_NAME: "PubSub Cache", + C.PI_IMPORT_NAME: "PUBSUB_CACHE", + C.PI_TYPE: C.PLUG_TYPE_PUBSUB, + C.PI_PROTOCOLS: [], + C.PI_DEPENDENCIES: ["XEP-0059", "XEP-0060"], + C.PI_RECOMMENDATIONS: [], + C.PI_MAIN: "PubsubCache", + C.PI_HANDLER: "no", + C.PI_DESCRIPTION: _("""Local Cache for PubSub"""), +} + +ANALYSER_KEYS_TO_COPY = ("name", "type", "to_sync", "parser") +# maximum of items to cache +CACHE_LIMIT = 5000 +# number of second before a progress caching is considered failed and tried again +PROGRESS_DEADLINE = 60 * 60 * 6 + + + +class PubsubCache: + # TODO: there is currently no notification for (un)subscribe events with XEP-0060, + # but it would be necessary to have this data if some devices unsubscribe a cached + # node, as we can then get out of sync. A protoXEP could be proposed to fix this + # situation. + # TODO: handle configuration events + + def __init__(self, host): + log.info(_("PubSub Cache initialization")) + strategy = host.memory.getConfig(None, "pubsub_cache_strategy") + if strategy == "no_cache": + log.info( + _( + "Pubsub cache won't be used due to pubsub_cache_strategy={value} " + "setting." + ).format(value=repr(strategy)) + ) + self.use_cache = False + else: + self.use_cache = True + self.host = host + self._p = host.plugins["XEP-0060"] + self.analysers = {} + # map for caching in progress (node, service) => Deferred + self.in_progress = {} + self.host.trigger.add("XEP-0060_getItems", self._getItemsTrigger) + self._p.addManagedNode( + "", + items_cb=self.onItemsEvent, + delete_cb=self.onDeleteEvent, + purge_db=self.onPurgeEvent, + ) + host.bridge.addMethod( + "psCacheGet", + ".plugin", + in_sign="ssiassss", + out_sign="s", + method=self._getItemsFromCache, + async_=True, + ) + host.bridge.addMethod( + "psCacheSync", + ".plugin", + "sss", + out_sign="", + method=self._synchronise, + async_=True, + ) + host.bridge.addMethod( + "psCachePurge", + ".plugin", + "s", + out_sign="", + method=self._purge, + async_=True, + ) + host.bridge.addMethod( + "psCacheReset", + ".plugin", + "", + out_sign="", + method=self._reset, + async_=True, + ) + + def registerAnalyser(self, analyser: dict) -> None: + """Register a new pubsub node analyser + + @param analyser: An analyser is a dictionary which may have the following keys + (keys with a ``*`` are mandatory, at least one of ``node`` or ``namespace`` keys + must be used): + + :name (str)*: + a unique name for this analyser. This name will be stored in database + to retrieve the analyser when necessary (notably to get the parsing method), + thus it is recommended to use a stable name such as the source plugin name + instead of a name which may change with standard evolution, such as the + feature namespace. + + :type (str)*: + indicates what kind of items we are dealing with. Type must be a human + readable word, as it may be used in searches. Good types examples are + **blog** or **event**. + + :node (str): + prefix of a node name which may be used to identify its type. Example: + *urn:xmpp:microblog:0* (a node starting with this name will be identified as + *blog* node). + + :namespace (str): + root namespace of items. When analysing a node, the first item will be + retrieved. The analyser will be chosen its given namespace match the + namespace of the first child element of ``<item>`` element. + + :to_sync (bool): + if True, the node must be synchronised in cache. The default False value + means that the pubsub service will always be requested. + + :parser (callable): + method (which may be sync, a coroutine or a method returning a "Deferred") + to call to parse the ``domish.Element`` of the item. The result must be + dictionary which can be serialised to JSON. + + The method must have the following signature: + + .. function:: parser(client: SatXMPPEntity, item_elt: domish.Element, \ + service: Optional[jid.JID], node: Optional[str]) \ + -> dict + :noindex: + + :match_cb (callable): + method (which may be sync, a coroutine or a method returning a "Deferred") + called when the analyser matches. The method is called with the curreny + analyse which is can modify **in-place**. + + The method must have the following signature: + + .. function:: match_cb(client: SatXMPPEntity, analyse: dict) -> None + :noindex: + + @raise exceptions.Conflict: a analyser with this name already exists + """ + + name = analyser.get("name", "").strip().lower() + # we want the normalised name + analyser["name"] = name + if not name: + raise ValueError('"name" is mandatory in analyser') + if "type" not in analyser: + raise ValueError('"type" is mandatory in analyser') + type_test_keys = {"node", "namespace"} + if not type_test_keys.intersection(analyser): + raise ValueError(f'at least one of {type_test_keys} must be used') + if name in self.analysers: + raise exceptions.Conflict( + f"An analyser with the name {name!r} is already registered" + ) + self.analysers[name] = analyser + + async def cacheItems( + self, + client: SatXMPPEntity, + pubsub_node: PubsubNode, + items: List[domish.Element] + ) -> None: + try: + parser = self.analysers[pubsub_node.analyser].get("parser") + except KeyError: + parser = None + + if parser is not None: + parsed_items = [ + await utils.asDeferred( + parser, + client, + item, + pubsub_node.service, + pubsub_node.name + ) + for item in items + ] + else: + parsed_items = None + + await self.host.memory.storage.cachePubsubItems( + client, pubsub_node, items, parsed_items + ) + + async def _cacheNode( + self, + client: SatXMPPEntity, + pubsub_node: PubsubNode + ) -> None: + await self.host.memory.storage.updatePubsubNodeSyncState( + pubsub_node, SyncState.IN_PROGRESS + ) + service, node = pubsub_node.service, pubsub_node.name + try: + log.debug( + f"Caching node {node!r} at {service} for {client.profile}" + ) + if not pubsub_node.subscribed: + try: + sub = await self._p.subscribe(client, service, node) + except Exception as e: + log.warning( + _( + "Can't subscribe node {pubsub_node}, that means that " + "synchronisation can't be maintained: {reason}" + ).format(pubsub_node=pubsub_node, reason=e) + ) + else: + if sub.state == "subscribed": + sub_id = sub.subscriptionIdentifier + log.debug( + f"{pubsub_node} subscribed (subscription id: {sub_id!r})" + ) + pubsub_node.subscribed = True + await self.host.memory.storage.add(pubsub_node) + else: + log.warning( + _( + "{pubsub_node} is not subscribed, that means that " + "synchronisation can't be maintained, and you may have " + "to enforce subscription manually. Subscription state: " + "{state}" + ).format(pubsub_node=pubsub_node, state=sub.state) + ) + + try: + await self.host.checkFeatures( + client, [rsm.NS_RSM, self._p.DISCO_RSM], pubsub_node.service + ) + except exceptions.FeatureNotFound: + log.warning( + f"service {service} doesn't handle Result Set Management " + "(XEP-0059), we'll only cache latest 20 items" + ) + items, __ = await client.pubsub_client.items( + pubsub_node.service, pubsub_node.name, maxItems=20 + ) + await self.cacheItems( + client, pubsub_node, items + ) + else: + rsm_p = self.host.plugins["XEP-0059"] + rsm_request = rsm.RSMRequest() + cached_ids = set() + while True: + items, rsm_response = await client.pubsub_client.items( + service, node, rsm_request=rsm_request + ) + await self.cacheItems( + client, pubsub_node, items + ) + for item in items: + item_id = item["id"] + if item_id in cached_ids: + log.warning( + f"Pubsub node {node!r} at {service} is returning several " + f"times the same item ({item_id!r}). This is illegal " + "behaviour, and it means that Pubsub service " + f"{service} is buggy and can't be cached properly. " + f"Please report this to {service.host} administrators" + ) + rsm_request = None + break + cached_ids.add(item["id"]) + if len(cached_ids) >= CACHE_LIMIT: + log.warning( + f"Pubsub node {node!r} at {service} contains more items " + f"than the cache limit ({CACHE_LIMIT}). We stop " + "caching here, at item {item['id']!r}." + ) + rsm_request = None + break + rsm_request = rsm_p.getNextRequest(rsm_request, rsm_response) + if rsm_request is None: + break + + await self.host.memory.storage.updatePubsubNodeSyncState( + pubsub_node, SyncState.COMPLETED + ) + except Exception as e: + import traceback + tb = traceback.format_tb(e.__traceback__) + log.error( + f"Can't cache node {node!r} at {service} for {client.profile}: {e}\n{tb}" + ) + await self.host.memory.storage.updatePubsubNodeSyncState( + pubsub_node, SyncState.ERROR + ) + await self.host.memory.storage.deletePubsubItems(pubsub_node) + raise e + + def _cacheNodeClean(self, __, pubsub_node): + del self.in_progress[(pubsub_node.service, pubsub_node.name)] + + def cacheNode( + self, + client: SatXMPPEntity, + pubsub_node: PubsubNode + ) -> None: + """Launch node caching as a background task""" + d = defer.ensureDeferred(self._cacheNode(client, pubsub_node)) + d.addBoth(self._cacheNodeClean, pubsub_node=pubsub_node) + self.in_progress[(pubsub_node.service, pubsub_node.name)] = d + return d + + async def analyseNode( + self, + client: SatXMPPEntity, + service: jid.JID, + node: str, + pubsub_node : PubsubNode = None, + ) -> dict: + """Use registered analysers on a node to determine what it is used for""" + analyse = {"service": service, "node": node} + if pubsub_node is None: + try: + first_item = (await client.pubsub_client.items( + service, node, 1 + ))[0][0] + except IndexError: + pass + except error.StanzaError as e: + if e.condition == "item-not-found": + pass + else: + log.warning( + f"Can't retrieve last item on node {node!r} at service " + f"{service} for {client.profile}: {e}" + ) + else: + try: + uri = first_item.firstChildElement().uri + except Exception as e: + log.warning( + f"Can't retrieve item namespace on node {node!r} at service " + f"{service} for {client.profile}: {e}" + ) + else: + analyse["namespace"] = uri + try: + conf = await self._p.getConfiguration(client, service, node) + except Exception as e: + log.warning( + f"Can't retrieve configuration for node {node!r} at service {service} " + f"for {client.profile}: {e}" + ) + else: + analyse["conf"] = conf + + for analyser in self.analysers.values(): + try: + an_node = analyser["node"] + except KeyError: + pass + else: + if node.startswith(an_node): + for key in ANALYSER_KEYS_TO_COPY: + try: + analyse[key] = analyser[key] + except KeyError: + pass + found = True + break + try: + namespace = analyse["namespace"] + an_namespace = analyser["namespace"] + except KeyError: + pass + else: + if namespace == an_namespace: + for key in ANALYSER_KEYS_TO_COPY: + try: + analyse[key] = analyser[key] + except KeyError: + pass + found = True + break + + else: + found = False + log.debug( + f"node {node!r} at service {service} doesn't match any known type" + ) + if found: + try: + match_cb = analyser["match_cb"] + except KeyError: + pass + else: + await utils.asDeferred(match_cb, client, analyse) + return analyse + + def _getItemsFromCache( + self, service="", node="", max_items=10, item_ids=None, sub_id=None, + extra="", profile_key=C.PROF_KEY_NONE + ): + d = defer.ensureDeferred(self._aGetItemsFromCache( + service, node, max_items, item_ids, sub_id, extra, profile_key + )) + d.addCallback(self._p.transItemsData) + d.addCallback(self._p.serialiseItems) + return d + + async def _aGetItemsFromCache( + self, service, node, max_items, item_ids, sub_id, extra, profile_key + ): + client = self.host.getClient(profile_key) + service = jid.JID(service) if service else client.jid.userhostJID() + pubsub_node = await self.host.memory.storage.getPubsubNode( + client, service, node + ) + if pubsub_node is None: + raise exceptions.NotFound( + f"{node!r} at {service} doesn't exist in cache for {client.profile!r}" + ) + max_items = None if max_items == C.NO_LIMIT else max_items + extra = self._p.parseExtra(data_format.deserialise(extra)) + items, metadata = await self.getItemsFromCache( + client, + pubsub_node, + max_items, + item_ids, + sub_id or None, + extra.rsm_request, + extra.extra, + ) + return [i.data for i in items], metadata + + async def getItemsFromCache( + self, + client: SatXMPPEntity, + node: PubsubNode, + max_items: Optional[int] = None, + item_ids: Optional[List[str]] = None, + sub_id: Optional[str] = None, + rsm_request: Optional[rsm.RSMRequest] = None, + extra: Optional[dict] = None + ) -> Tuple[List[PubsubItem], dict]: + """Get items from cache, using same arguments as for external Pubsub request""" + if "mam" in extra: + raise NotImplementedError("MAM queries are not supported yet") + if max_items is None and rsm_request is None: + max_items = 20 + if max_items is not None: + if rsm_request is not None: + raise exceptions.InternalError( + "Pubsub max items and RSM must not be used at the same time" + ) + elif item_ids is None: + raise exceptions.InternalError( + "Pubsub max items and item IDs must not be used at the same time" + ) + pubsub_items, metadata = await self.host.memory.storage.getItems( + node, max_items=max_items, order_by=extra.get(C.KEY_ORDER_BY) + ) + else: + desc = False + if rsm_request.before == "": + before = None + desc = True + else: + before = rsm_request.before + pubsub_items, metadata = await self.host.memory.storage.getItems( + node, max_items=rsm_request.max, before=before, after=rsm_request.after, + from_index=rsm_request.index, order_by=extra.get(C.KEY_ORDER_BY), + desc=desc, force_rsm=True, + ) + + return pubsub_items, metadata + + async def onItemsEvent(self, client, event): + node = await self.host.memory.storage.getPubsubNode( + client, event.sender, event.nodeIdentifier + ) + if node is None: + return + if node.sync_state in (SyncState.COMPLETED, SyncState.IN_PROGRESS): + items = [] + retract_ids = [] + for elt in event.items: + if elt.name == "item": + items.append(elt) + elif elt.name == "retract": + item_id = elt.getAttribute("id") + if not item_id: + log.warning( + "Ignoring invalid retract item element: " + f"{xml_tools.pFmtElt(elt)}" + ) + continue + + retract_ids.append(elt["id"]) + else: + log.warning( + f"Unexpected Pubsub event element: {xml_tools.pFmtElt(elt)}" + ) + if items: + log.debug("caching new items received from {node}") + await self.cacheItems( + client, node, items + ) + if retract_ids: + log.debug(f"deleting retracted items from {node}") + await self.host.memory.storage.deletePubsubItems( + node, items_names=retract_ids + ) + + async def onDeleteEvent(self, client, event): + log.debug( + f"deleting node {event.nodeIdentifier} from {event.sender} for " + f"{client.profile}" + ) + await self.host.memory.storage.deletePubsubNode( + [client.profile], [event.sender], [event.nodeIdentifier] + ) + + async def onPurgeEvent(self, client, event): + node = await self.host.memory.storage.getPubsubNode( + client, event.sender, event.nodeIdentifier + ) + if node is None: + return + log.debug(f"purging node {node} for {client.profile}") + await self.host.memory.storage.deletePubsubItems(node) + + async def _getItemsTrigger( + self, + client: SatXMPPEntity, + service: Optional[jid.JID], + node: str, + max_items: Optional[int], + item_ids: Optional[List[str]], + sub_id: Optional[str], + rsm_request: Optional[rsm.RSMRequest], + extra: dict + ) -> Tuple[bool, Optional[Tuple[List[dict], dict]]]: + if not self.use_cache: + log.debug("cache disabled in settings") + return True, None + if extra.get(C.KEY_USE_CACHE) == False: + log.debug("skipping pubsub cache as requested") + return True, None + if service is None: + service = client.jid.userhostJID() + pubsub_node = await self.host.memory.storage.getPubsubNode( + client, service, node + ) + if pubsub_node is not None and pubsub_node.sync_state == SyncState.COMPLETED: + analyse = {"to_sync": True} + else: + analyse = await self.analyseNode(client, service, node) + + if pubsub_node is None: + pubsub_node = await self.host.memory.storage.setPubsubNode( + client, + service, + node, + analyser=analyse.get("name"), + type_=analyse.get("type"), + subtype=analyse.get("subtype"), + ) + + if analyse.get("to_sync"): + if pubsub_node.sync_state == SyncState.COMPLETED: + if "mam" in extra: + log.debug("MAM caching is not supported yet, skipping cache") + return True, None + pubsub_items, metadata = await self.getItemsFromCache( + client, pubsub_node, max_items, item_ids, sub_id, rsm_request, extra + ) + return False, ([i.data for i in pubsub_items], metadata) + + if pubsub_node.sync_state == SyncState.IN_PROGRESS: + if (service, node) not in self.in_progress: + log.warning( + f"{pubsub_node} is reported as being cached, but not caching is " + "in progress, this is most probably due to the backend being " + "restarted. Resetting the status, caching will be done again." + ) + pubsub_node.sync_state = None + await self.host.memory.storage.deletePubsubItems(pubsub_node) + elif time.time() - pubsub_node.sync_state_updated > PROGRESS_DEADLINE: + log.warning( + f"{pubsub_node} is in progress for too long " + f"({pubsub_node.sync_state_updated//60} minutes), " + "cancelling it and retrying." + ) + self.in_progress.pop[(service, node)].cancel() + pubsub_node.sync_state = None + await self.host.memory.storage.deletePubsubItems(pubsub_node) + else: + log.debug( + f"{pubsub_node} synchronisation is already in progress, skipping" + ) + if pubsub_node.sync_state is None: + key = (service, node) + if key in self.in_progress: + raise exceptions.InternalError( + f"There is already a caching in progress for {pubsub_node}, this " + "should not happen" + ) + self.cacheNode(client, pubsub_node) + elif pubsub_node.sync_state == SyncState.ERROR: + log.debug( + f"{pubsub_node} synchronisation has previously failed, skipping" + ) + + return True, None + + async def _subscribeTrigger( + self, + client: SatXMPPEntity, + service: jid.JID, + nodeIdentifier: str, + sub_jid: Optional[jid.JID], + options: Optional[dict], + subscription: pubsub.Subscription + ) -> None: + pass + + async def _unsubscribeTrigger( + self, + client: SatXMPPEntity, + service: jid.JID, + nodeIdentifier: str, + sub_jid, + subscriptionIdentifier, + sender, + ) -> None: + pass + + def _synchronise(self, service, node, profile_key): + client = self.host.getClient(profile_key) + service = client.jid.userhostJID() if not service else jid.JID(service) + return defer.ensureDeferred(self.synchronise(client, service, node)) + + async def synchronise( + self, + client: SatXMPPEntity, + service: jid.JID, + node: str + ) -> None: + """Synchronise a node with a pubsub service + + If the node is already synchronised, it will be resynchronised (all items will be + deleted and re-downloaded). + + The node will be synchronised even if there is no matching analyser. + + Note that when a node is synchronised, it is automatically subscribed. + """ + pubsub_node = await self.host.memory.storage.getPubsubNode( + client, service, node + ) + if pubsub_node is None: + log.info( + _( + "Synchronising the new node {node} at {service}" + ).format(node=node, service=service.full) + ) + analyse = await self.analyseNode(client, service, node) + pubsub_node = await self.host.memory.storage.setPubsubNode( + client, + service, + node, + analyser=analyse.get("name"), + type_=analyse.get("type"), + ) + + if ((pubsub_node.sync_state == SyncState.IN_PROGRESS + or (service, node) in self.in_progress)): + log.warning( + _( + "{node} at {service} is already being synchronised, can't do a new " + "synchronisation." + ).format(node=node, service=service) + ) + else: + log.info( + _( + "(Re)Synchronising the node {node} at {service} on user request" + ).format(node=node, service=service.full) + ) + # we first delete and recreate the node (will also delete its items) + await self.host.memory.storage.delete(pubsub_node) + analyse = await self.analyseNode(client, service, node) + pubsub_node = await self.host.memory.storage.setPubsubNode( + client, + service, + node, + analyser=analyse.get("name"), + type_=analyse.get("type"), + ) + # then we can put node in cache + await self.cacheNode(client, pubsub_node) + + async def purge(self, purge_filters: dict) -> None: + """Remove items according to filters + + filters can have on of the following keys, all are optional: + + :services: + list of JIDs of services from which items must be deleted + :nodes: + list of node names to delete + :types: + list of node types to delete + :subtypes: + list of node subtypes to delete + :profiles: + list of profiles from which items must be deleted + :created_before: + datetime before which items must have been created to be deleted + :created_update: + datetime before which items must have been updated last to be deleted + """ + purge_filters["names"] = purge_filters.pop("nodes", None) + await self.host.memory.storage.purgePubsubItems(**purge_filters) + + def _purge(self, purge_filters: str) -> None: + purge_filters = data_format.deserialise(purge_filters) + for key in "created_before", "updated_before": + try: + purge_filters[key] = datetime.fromtimestamp(purge_filters[key]) + except (KeyError, TypeError): + pass + return defer.ensureDeferred(self.purge(purge_filters)) + + async def reset(self) -> None: + """Remove ALL nodes and items from cache + + After calling this method, cache will be refilled progressively as if it where new + """ + await self.host.memory.storage.deletePubsubNode(None, None, None) + + def _reset(self) -> None: + return defer.ensureDeferred(self.reset())
--- a/sat/plugins/plugin_xep_0033.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0033.py Wed Sep 01 13:47:17 2021 +0200 @@ -156,7 +156,9 @@ d = defer.Deferred() if not skip_send: d.addCallback(client.sendMessageData) - d.addCallback(client.messageAddToHistory) + d.addCallback( + lambda ret: defer.ensureDeferred(client.messageAddToHistory(ret)) + ) d.addCallback(client.messageSendToBridge) d.addErrback(lambda failure: failure.trap(exceptions.CancelError)) return d.callback(mess_data)
--- a/sat/plugins/plugin_xep_0045.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0045.py Wed Sep 01 13:47:17 2021 +0200 @@ -1321,7 +1321,9 @@ except AttributeError: mess_data = self.client.messageProt.parseMessage(message.element) if mess_data['message'] or mess_data['subject']: - return self.host.memory.addToHistory(self.client, mess_data) + return defer.ensureDeferred( + self.host.memory.addToHistory(self.client, mess_data) + ) else: return defer.succeed(None)
--- a/sat/plugins/plugin_xep_0059.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0059.py Wed Sep 01 13:47:17 2021 +0200 @@ -1,7 +1,6 @@ #!/usr/bin/env python3 - -# SAT plugin for Result Set Management (XEP-0059) +# Result Set Management (XEP-0059) # Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) # Copyright (C) 2013-2016 Adrien Cossa (souliane@mailoo.org) @@ -18,19 +17,19 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. +from typing import Optional +from zope.interface import implementer +from twisted.words.protocols.jabber import xmlstream +from wokkel import disco +from wokkel import iwokkel +from wokkel import rsm from sat.core.i18n import _ from sat.core.constants import Const as C from sat.core.log import getLogger + log = getLogger(__name__) -from wokkel import disco -from wokkel import iwokkel -from wokkel import rsm - -from twisted.words.protocols.jabber import xmlstream -from zope.interface import implementer - PLUGIN_INFO = { C.PI_NAME: "Result Set Management", @@ -102,6 +101,52 @@ data["index"] = rsm_response.index return data + def getNextRequest( + self, + rsm_request: rsm.RSMRequest, + rsm_response: rsm.RSMResponse, + log_progress: bool = True, + ) -> Optional[rsm.RSMRequest]: + """Generate next request to paginate through all items + + Page will be retrieved forward + @param rsm_request: last request used + @param rsm_response: response from the last request + @return: request to retrive next page, or None if we are at the end + or if pagination is not possible + """ + if rsm_request.max == 0: + log.warning("Can't do pagination if max is 0") + return None + if rsm_response is None: + # may happen if result set it empty, or we are at the end + return None + if ( + rsm_response.count is not None + and rsm_response.index is not None + ): + next_index = rsm_response.index + rsm_request.max + if next_index >= rsm_response.count: + # we have reached the last page + return None + + if log_progress: + log.debug( + f"retrieving items {next_index} to " + f"{min(next_index+rsm_request.max, rsm_response.count)} on " + f"{rsm_response.count} ({next_index/rsm_response.count*100:.2f}%)" + ) + + if rsm_response.last is None: + if rsm_response.count: + log.warning("Can't do pagination, no \"last\" received") + return None + + return rsm.RSMRequest( + max_=rsm_request.max, + after=rsm_response.last + ) + @implementer(iwokkel.IDisco) class XEP_0059_handler(xmlstream.XMPPHandler):
--- a/sat/plugins/plugin_xep_0060.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0060.py Wed Sep 01 13:47:17 2021 +0200 @@ -17,7 +17,7 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. -from typing import Optional +from typing import Optional, List, Tuple from collections import namedtuple import urllib.request, urllib.parse, urllib.error from functools import reduce @@ -35,8 +35,9 @@ from sat.core.i18n import _ from sat.core.constants import Const as C from sat.core.log import getLogger -from sat.core.xmpp import SatXMPPEntity +from sat.core.core_types import SatXMPPEntity from sat.core import exceptions +from sat.tools import utils from sat.tools import sat_defer from sat.tools import xml_tools from sat.tools.common import data_format @@ -90,6 +91,8 @@ ID_SINGLETON = "current" EXTRA_PUBLISH_OPTIONS = "publish_options" EXTRA_ON_PRECOND_NOT_MET = "on_precondition_not_met" + # extra disco needed for RSM, cf. XEP-0060 § 6.5.4 + DISCO_RSM = "http://jabber.org/protocol/pubsub#rsm" def __init__(self, host): log.info(_("PubSub plugin initialization")) @@ -197,7 +200,7 @@ host.bridge.addMethod( "psItemsGet", ".plugin", - in_sign="ssiassa{ss}s", + in_sign="ssiassss", out_sign="s", method=self._getItems, async_=True, @@ -284,7 +287,7 @@ host.bridge.addMethod( "psGetFromMany", ".plugin", - in_sign="a(ss)ia{ss}s", + in_sign="a(ss)iss", out_sign="s", method=self._getFromMany, ) @@ -391,6 +394,7 @@ the method must be named after PubSub constants in lower case and suffixed with "_cb" e.g.: "items_cb" for C.PS_ITEMS, "delete_cb" for C.PS_DELETE + note: only C.PS_ITEMS and C.PS_DELETE are implemented so far """ assert node is not None assert kwargs @@ -471,9 +475,9 @@ service = None if not service else jid.JID(service) payload = xml_tools.parse(payload) extra = data_format.deserialise(extra_ser) - d = self.sendItem( + d = defer.ensureDeferred(self.sendItem( client, service, nodeIdentifier, payload, item_id or None, extra - ) + )) d.addCallback(lambda ret: ret or "") return d @@ -487,23 +491,13 @@ raise exceptions.DataError(_("Can't parse items: {msg}").format( msg=e)) extra = data_format.deserialise(extra_ser) - d = self.sendItems( + return defer.ensureDeferred(self.sendItems( client, service, nodeIdentifier, items, extra - ) - return d - - def _getPublishedItemId(self, published_ids, original_id): - """Return item of published id if found in answer + )) - if not found original_id is returned, which may be None - """ - try: - return published_ids[0] - except IndexError: - return original_id - - def sendItem(self, client, service, nodeIdentifier, payload, item_id=None, - extra=None): + async def sendItem( + self, client, service, nodeIdentifier, payload, item_id=None, extra=None + ): """High level method to send one item @param service(jid.JID, None): service to send the item to @@ -519,15 +513,17 @@ if item_id is not None: item_elt['id'] = item_id item_elt.addChild(payload) - d = defer.ensureDeferred(self.sendItems( + published_ids = await self.sendItems( client, service, nodeIdentifier, [item_elt], extra - )) - d.addCallback(self._getPublishedItemId, item_id) - return d + ) + try: + return published_ids[0] + except IndexError: + return item_id async def sendItems(self, client, service, nodeIdentifier, items, extra=None): """High level method to send several items at once @@ -593,12 +589,25 @@ except AttributeError: return [] - def publish(self, client, service, nodeIdentifier, items=None, options=None): - return client.pubsub_client.publish( + async def publish( + self, + client: SatXMPPEntity, + service: jid.JID, + nodeIdentifier: str, + items: Optional[List[domish.Element]] = None, + options: Optional[dict] = None + ) -> List[str]: + published_ids = await client.pubsub_client.publish( service, nodeIdentifier, items, client.pubsub_client.parent.jid, options=options ) + await self.host.trigger.asyncPoint( + "XEP-0060_publish", client, service, nodeIdentifier, items, options, + published_ids + ) + return published_ids + def _unwrapMAMMessage(self, message_elt): try: item_elt = reduce( @@ -621,7 +630,7 @@ return data_format.serialise(metadata) def _getItems(self, service="", node="", max_items=10, item_ids=None, sub_id=None, - extra_dict=None, profile_key=C.PROF_KEY_NONE): + extra="", profile_key=C.PROF_KEY_NONE): """Get items from pubsub node @param max_items(int): maximum number of item to get, C.NO_LIMIT for no limit @@ -629,23 +638,32 @@ client = self.host.getClient(profile_key) service = jid.JID(service) if service else None max_items = None if max_items == C.NO_LIMIT else max_items - extra = self.parseExtra(extra_dict) - d = self.getItems( + extra = self.parseExtra(data_format.deserialise(extra)) + d = defer.ensureDeferred(self.getItems( client, service, node or None, - max_items or None, + max_items, item_ids, sub_id or None, extra.rsm_request, extra.extra, - ) + )) d.addCallback(self.transItemsData) d.addCallback(self.serialiseItems) return d - def getItems(self, client, service, node, max_items=None, item_ids=None, sub_id=None, - rsm_request=None, extra=None): + async def getItems( + self, + client: SatXMPPEntity, + service: Optional[jid.JID], + node: str, + max_items: Optional[int] = None, + item_ids: Optional[List[str]] = None, + sub_id: Optional[str] = None, + rsm_request: Optional[rsm.RSMRequest] = None, + extra: Optional[dict] = None + ) -> Tuple[List[dict], dict]: """Retrieve pubsub items from a node. @param service (JID, None): pubsub service. @@ -668,6 +686,12 @@ raise ValueError("items_id can't be used with rsm") if extra is None: extra = {} + cont, ret = await self.host.trigger.asyncReturnPoint( + "XEP-0060_getItems", client, service, node, max_items, item_ids, sub_id, + rsm_request, extra + ) + if not cont: + return ret try: mam_query = extra["mam"] except KeyError: @@ -682,9 +706,10 @@ rsm_request = rsm_request ) # we have no MAM data here, so we add None - d.addCallback(lambda data: data + (None,)) d.addErrback(sat_defer.stanza2NotFound) d.addTimeout(TIMEOUT, reactor) + items, rsm_response = await d + mam_response = None else: # if mam is requested, we have to do a totally different query if self._mam is None: @@ -706,61 +731,49 @@ raise exceptions.DataError( "Conflict between RSM request and MAM's RSM request" ) - d = self._mam.getArchives(client, mam_query, service, self._unwrapMAMMessage) + items, rsm_response, mam_response = await self._mam.getArchives( + client, mam_query, service, self._unwrapMAMMessage + ) try: subscribe = C.bool(extra["subscribe"]) except KeyError: subscribe = False - def subscribeEb(failure, service, node): - failure.trap(error.StanzaError) - log.warning( - "Could not subscribe to node {} on service {}: {}".format( - node, str(service), str(failure.value) + if subscribe: + try: + await self.subscribe(client, service, node) + except error.StanzaError as e: + log.warning( + f"Could not subscribe to node {node} on service {service}: {e}" ) - ) - - def doSubscribe(data): - self.subscribe(client, service, node).addErrback( - subscribeEb, service, node - ) - return data - - if subscribe: - d.addCallback(doSubscribe) - def addMetadata(result): - # TODO: handle the third argument (mam_response) - items, rsm_response, mam_response = result - service_jid = service if service else client.jid.userhostJID() - metadata = { - "service": service_jid, - "node": node, - "uri": self.getNodeURI(service_jid, node), - } - if mam_response is not None: - # mam_response is a dict with "complete" and "stable" keys - # we can put them directly in metadata - metadata.update(mam_response) - if rsm_request is not None and rsm_response is not None: - metadata['rsm'] = rsm_response.toDict() - if mam_response is None: - index = rsm_response.index - count = rsm_response.count - if index is None or count is None: - # we don't have enough information to know if the data is complete - # or not - metadata["complete"] = None - else: - # normally we have a strict equality here but XEP-0059 states - # that index MAY be approximative, so just in case… - metadata["complete"] = index + len(items) >= count + # TODO: handle mam_response + service_jid = service if service else client.jid.userhostJID() + metadata = { + "service": service_jid, + "node": node, + "uri": self.getNodeURI(service_jid, node), + } + if mam_response is not None: + # mam_response is a dict with "complete" and "stable" keys + # we can put them directly in metadata + metadata.update(mam_response) + if rsm_request is not None and rsm_response is not None: + metadata['rsm'] = rsm_response.toDict() + if mam_response is None: + index = rsm_response.index + count = rsm_response.count + if index is None or count is None: + # we don't have enough information to know if the data is complete + # or not + metadata["complete"] = None + else: + # normally we have a strict equality here but XEP-0059 states + # that index MAY be approximative, so just in case… + metadata["complete"] = index + len(items) >= count - return (items, metadata) - - d.addCallback(addMetadata) - return d + return (items, metadata) # @defer.inlineCallbacks # def getItemsFromMany(self, service, data, max_items=None, sub_id=None, rsm=None, profile_key=C.PROF_KEY_NONE): @@ -1059,7 +1072,7 @@ notify=True, ): return client.pubsub_client.retractItems( - service, nodeIdentifier, itemIdentifiers, notify=True + service, nodeIdentifier, itemIdentifiers, notify=notify ) def _renameItem( @@ -1100,37 +1113,55 @@ def _subscribe(self, service, nodeIdentifier, options, profile_key=C.PROF_KEY_NONE): client = self.host.getClient(profile_key) service = None if not service else jid.JID(service) - d = self.subscribe(client, service, nodeIdentifier, options=options or None) + d = defer.ensureDeferred( + self.subscribe(client, service, nodeIdentifier, options=options or None) + ) d.addCallback(lambda subscription: subscription.subscriptionIdentifier or "") return d - def subscribe(self, client, service, nodeIdentifier, sub_jid=None, options=None): + async def subscribe( + self, + client: SatXMPPEntity, + service: jid.JID, + nodeIdentifier: str, + sub_jid: Optional[jid.JID] = None, + options: Optional[dict] = None + ) -> pubsub.Subscription: # TODO: reimplement a subscribtion cache, checking that we have not subscription before trying to subscribe - return client.pubsub_client.subscribe( + subscription = await client.pubsub_client.subscribe( service, nodeIdentifier, sub_jid or client.jid.userhostJID(), options=options ) + await self.host.trigger.asyncPoint( + "XEP-0060_subscribe", client, service, nodeIdentifier, sub_jid, options, + subscription + ) + return subscription def _unsubscribe(self, service, nodeIdentifier, profile_key=C.PROF_KEY_NONE): client = self.host.getClient(profile_key) service = None if not service else jid.JID(service) - return self.unsubscribe(client, service, nodeIdentifier) + return defer.ensureDeferred(self.unsubscribe(client, service, nodeIdentifier)) - def unsubscribe( + async def unsubscribe( self, - client, - service, - nodeIdentifier, + client: SatXMPPEntity, + service: jid.JID, + nodeIdentifier: str, sub_jid=None, subscriptionIdentifier=None, sender=None, ): - return client.pubsub_client.unsubscribe( + await client.pubsub_client.unsubscribe( service, nodeIdentifier, sub_jid or client.jid.userhostJID(), subscriptionIdentifier, sender, ) + await self.host.trigger.asyncPoint( + "XEP-0060_unsubscribe", client, service, nodeIdentifier, sub_jid, + subscriptionIdentifier, sender + ) def _subscriptions(self, service, nodeIdentifier="", profile_key=C.PROF_KEY_NONE): client = self.host.getClient(profile_key) @@ -1394,8 +1425,10 @@ client = self.host.getClient(profile_key) deferreds = {} for service, node in node_data: - deferreds[(service, node)] = client.pubsub_client.subscribe( - service, node, subscriber, options=options + deferreds[(service, node)] = defer.ensureDeferred( + client.pubsub_client.subscribe( + service, node, subscriber, options=options + ) ) return self.rt_sessions.newSession(deferreds, client.profile) # found_nodes = yield self.listNodes(service, profile=client.profile) @@ -1445,13 +1478,13 @@ return d def _getFromMany( - self, node_data, max_item=10, extra_dict=None, profile_key=C.PROF_KEY_NONE + self, node_data, max_item=10, extra="", profile_key=C.PROF_KEY_NONE ): """ @param max_item(int): maximum number of item to get, C.NO_LIMIT for no limit """ max_item = None if max_item == C.NO_LIMIT else max_item - extra = self.parseExtra(extra_dict) + extra = self.parseExtra(data_format.deserialise(extra)) return self.getFromMany( [(jid.JID(service), str(node)) for service, node in node_data], max_item, @@ -1475,9 +1508,9 @@ client = self.host.getClient(profile_key) deferreds = {} for service, node in node_data: - deferreds[(service, node)] = self.getItems( + deferreds[(service, node)] = defer.ensureDeferred(self.getItems( client, service, node, max_item, rsm_request=rsm_request, extra=extra - ) + )) return self.rt_sessions.newSession(deferreds, client.profile) @@ -1513,7 +1546,10 @@ def itemsReceived(self, event): log.debug("Pubsub items received") for callback in self._getNodeCallbacks(event.nodeIdentifier, C.PS_ITEMS): - callback(self.parent, event) + d = utils.asDeferred(callback, self.parent, event) + d.addErrback(lambda f: log.error( + f"Error while running items event callback {callback}: {f}" + )) client = self.parent if (event.sender, event.nodeIdentifier) in client.pubsub_watching: raw_items = [i.toXml() for i in event.items] @@ -1528,13 +1564,29 @@ def deleteReceived(self, event): log.debug(("Publish node deleted")) for callback in self._getNodeCallbacks(event.nodeIdentifier, C.PS_DELETE): - callback(self.parent, event) + d = utils.asDeferred(callback, self.parent, event) + d.addErrback(lambda f: log.error( + f"Error while running delete event callback {callback}: {f}" + )) client = self.parent if (event.sender, event.nodeIdentifier) in client.pubsub_watching: self.host.bridge.psEventRaw( event.sender.full(), event.nodeIdentifier, C.PS_DELETE, [], client.profile ) + def purgeReceived(self, event): + log.debug(("Publish node purged")) + for callback in self._getNodeCallbacks(event.nodeIdentifier, C.PS_PURGE): + d = utils.asDeferred(callback, self.parent, event) + d.addErrback(lambda f: log.error( + f"Error while running purge event callback {callback}: {f}" + )) + client = self.parent + if (event.sender, event.nodeIdentifier) in client.pubsub_watching: + self.host.bridge.psEventRaw( + event.sender.full(), event.nodeIdentifier, C.PS_PURGE, [], client.profile + ) + def subscriptions(self, service, nodeIdentifier, sender=None): """Return the list of subscriptions to the given service and node.
--- a/sat/plugins/plugin_xep_0184.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0184.py Wed Sep 01 13:47:17 2021 +0200 @@ -150,7 +150,7 @@ @param client: %(doc_client)s""" from_jid = jid.JID(msg_elt["from"]) - if self._isActif(client.profile) and client.roster.isPresenceAuthorised(from_jid): + if self._isActif(client.profile) and client.roster.isSubscribedFrom(from_jid): received_elt_ret = domish.Element((NS_MESSAGE_DELIVERY_RECEIPTS, "received")) try: received_elt_ret["id"] = msg_elt["id"]
--- a/sat/plugins/plugin_xep_0198.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0198.py Wed Sep 01 13:47:17 2021 +0200 @@ -448,7 +448,7 @@ d.addCallback(lambda __: client.roster.got_roster) if plg_0313 is not None: # we retrieve one2one MAM archives - d.addCallback(lambda __: plg_0313.resume(client)) + d.addCallback(lambda __: defer.ensureDeferred(plg_0313.resume(client))) # initial presence must be sent manually d.addCallback(lambda __: client.presence.available()) if plg_0045 is not None:
--- a/sat/plugins/plugin_xep_0277.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0277.py Wed Sep 01 13:47:17 2021 +0200 @@ -55,8 +55,8 @@ NS_MICROBLOG = "urn:xmpp:microblog:0" NS_ATOM = "http://www.w3.org/2005/Atom" -NS_PUBSUB_EVENT = "{}{}".format(pubsub.NS_PUBSUB, "#event") -NS_COMMENT_PREFIX = "{}:comments/".format(NS_MICROBLOG) +NS_PUBSUB_EVENT = f"{pubsub.NS_PUBSUB}#event" +NS_COMMENT_PREFIX = f"{NS_MICROBLOG}:comments/" PLUGIN_INFO = { @@ -65,7 +65,7 @@ C.PI_TYPE: "XEP", C.PI_PROTOCOLS: ["XEP-0277"], C.PI_DEPENDENCIES: ["XEP-0163", "XEP-0060", "TEXT_SYNTAXES"], - C.PI_RECOMMENDATIONS: ["XEP-0059", "EXTRA-PEP"], + C.PI_RECOMMENDATIONS: ["XEP-0059", "EXTRA-PEP", "PUBSUB_CACHE"], C.PI_MAIN: "XEP_0277", C.PI_HANDLER: "yes", C.PI_DESCRIPTION: _("""Implementation of microblogging Protocol"""), @@ -86,6 +86,19 @@ self._p = self.host.plugins[ "XEP-0060" ] # this facilitate the access to pubsub plugin + ps_cache = self.host.plugins.get("PUBSUB_CACHE") + if ps_cache is not None: + ps_cache.registerAnalyser( + { + "name": "XEP-0277", + "node": NS_MICROBLOG, + "namespace": NS_ATOM, + "type": "blog", + "to_sync": True, + "parser": self.item2mbdata, + "match_cb": self._cacheNodeMatchCb, + } + ) self.rt_sessions = sat_defer.RTDeferredSessions() self.host.plugins["XEP-0060"].addManagedNode( NS_MICROBLOG, items_cb=self._itemsReceived @@ -118,7 +131,7 @@ host.bridge.addMethod( "mbGet", ".plugin", - in_sign="ssiasa{ss}s", + in_sign="ssiasss", out_sign="s", method=self._mbGet, async_=True, @@ -180,6 +193,15 @@ def getHandler(self, client): return XEP_0277_handler() + def _cacheNodeMatchCb( + self, + client: SatXMPPEntity, + analyse: dict, + ) -> None: + """Check is analysed node is a comment and fill analyse accordingly""" + if analyse["node"].startswith(NS_COMMENT_PREFIX): + analyse["subtype"] = "comment" + def _checkFeaturesCb(self, available): return {"available": C.BOOL_TRUE} @@ -682,7 +704,7 @@ @param item_id(unicode): id of the parent item @return (unicode): comment node to use """ - return "{}{}".format(NS_COMMENT_PREFIX, item_id) + return f"{NS_COMMENT_PREFIX}{item_id}" def getCommentsService(self, client, parent_service=None): """Get prefered PubSub service to create comment node @@ -931,7 +953,7 @@ metadata['items'] = items return data_format.serialise(metadata) - def _mbGet(self, service="", node="", max_items=10, item_ids=None, extra_dict=None, + def _mbGet(self, service="", node="", max_items=10, item_ids=None, extra="", profile_key=C.PROF_KEY_NONE): """ @param max_items(int): maximum number of item to get, C.NO_LIMIT for no limit @@ -940,14 +962,15 @@ client = self.host.getClient(profile_key) service = jid.JID(service) if service else None max_items = None if max_items == C.NO_LIMIT else max_items - extra = self._p.parseExtra(extra_dict) - d = self.mbGet(client, service, node or None, max_items, item_ids, + extra = self._p.parseExtra(data_format.deserialise(extra)) + d = defer.ensureDeferred( + self.mbGet(client, service, node or None, max_items, item_ids, extra.rsm_request, extra.extra) + ) d.addCallback(self._mbGetSerialise) return d - @defer.inlineCallbacks - def mbGet(self, client, service=None, node=None, max_items=10, item_ids=None, + async def mbGet(self, client, service=None, node=None, max_items=10, item_ids=None, rsm_request=None, extra=None): """Get some microblogs @@ -955,6 +978,7 @@ None to get profile's PEP @param node(unicode, None): node to get (or microblog node if None) @param max_items(int): maximum number of item to get, None for no limit + ignored if rsm_request is set @param item_ids (list[unicode]): list of item IDs @param rsm_request (rsm.RSMRequest): RSM request data @param extra (dict): extra data @@ -963,7 +987,9 @@ """ if node is None: node = NS_MICROBLOG - items_data = yield self._p.getItems( + if rsm_request: + max_items = None + items_data = await self._p.getItems( client, service, node, @@ -972,9 +998,9 @@ rsm_request=rsm_request, extra=extra, ) - mb_data = yield self._p.transItemsDataD( + mb_data = await self._p.transItemsDataD( items_data, partial(self.item2mbdata, client, service=service, node=node)) - defer.returnValue(mb_data) + return mb_data def _mbRename(self, service, node, item_id, new_id, profile_key): return defer.ensureDeferred(self.mbRename( @@ -1372,13 +1398,15 @@ service = jid.JID(service_s) node = item["{}{}".format(prefix, "_node")] # time to get the comments - d = self._p.getItems( - client, - service, - node, - max_comments, - rsm_request=rsm_comments, - extra=extra_comments, + d = defer.ensureDeferred( + self._p.getItems( + client, + service, + node, + max_comments, + rsm_request=rsm_comments, + extra=extra_comments, + ) ) # then serialise d.addCallback( @@ -1420,9 +1448,9 @@ deferreds = {} for service, node in node_data: - d = deferreds[(service, node)] = self._p.getItems( + d = deferreds[(service, node)] = defer.ensureDeferred(self._p.getItems( client, service, node, max_items, rsm_request=rsm_request, extra=extra - ) + )) d.addCallback( lambda items_data: self._p.transItemsDataD( items_data,
--- a/sat/plugins/plugin_xep_0313.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0313.py Wed Sep 01 13:47:17 2021 +0200 @@ -75,16 +75,15 @@ out_sign='(a(sdssa{ss}a{ss}ss)ss)', method=self._getArchives, async_=True) - @defer.inlineCallbacks - def resume(self, client): + async def resume(self, client): """Retrieve one2one messages received since the last we have in local storage""" - stanza_id_data = yield self.host.memory.storage.getPrivates( + stanza_id_data = await self.host.memory.storage.getPrivates( mam.NS_MAM, [KEY_LAST_STANZA_ID], profile=client.profile) stanza_id = stanza_id_data.get(KEY_LAST_STANZA_ID) rsm_req = None if stanza_id is None: log.info("can't retrieve last stanza ID, checking history") - last_mess = yield self.host.memory.historyGet( + last_mess = await self.host.memory.historyGet( None, None, limit=1, filters={'not_types': C.MESS_TYPE_GROUPCHAT, 'last_stanza_id': True}, profile=client.profile) @@ -100,7 +99,7 @@ complete = False count = 0 while not complete: - mam_data = yield self.getArchives(client, mam_req, + mam_data = await self.getArchives(client, mam_req, service=client.jid.userhostJID()) elt_list, rsm_response, mam_response = mam_data complete = mam_response["complete"] @@ -145,7 +144,7 @@ # adding message to history mess_data = client.messageProt.parseMessage(fwd_message_elt) try: - yield client.messageProt.addToHistory(mess_data) + await client.messageProt.addToHistory(mess_data) except exceptions.CancelError as e: log.warning( "message has not been added to history: {e}".format(e=e)) @@ -160,8 +159,8 @@ log.info(_("We have received {num_mess} message(s) while offline.") .format(num_mess=count)) - def profileConnected(self, client): - return self.resume(client) + async def profileConnected(self, client): + await self.resume(client) def getHandler(self, client): mam_client = client._mam = SatMAMClient(self)
--- a/sat/plugins/plugin_xep_0329.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0329.py Wed Sep 01 13:47:17 2021 +0200 @@ -19,6 +19,7 @@ import mimetypes import json import os +import traceback from pathlib import Path from typing import Optional, Dict from zope.interface import implementer @@ -33,6 +34,7 @@ from sat.core.constants import Const as C from sat.core.log import getLogger from sat.tools import stream +from sat.tools import utils from sat.tools.common import regex @@ -453,9 +455,9 @@ iq_elt.handled = True node = iq_elt.query.getAttribute("node") if not node: - d = defer.maybeDeferred(root_nodes_cb, client, iq_elt) + d = utils.asDeferred(root_nodes_cb, client, iq_elt) else: - d = defer.maybeDeferred(files_from_node_cb, client, iq_elt, node) + d = utils.asDeferred(files_from_node_cb, client, iq_elt, node) d.addErrback( lambda failure_: log.error( _("error while retrieving files: {msg}").format(msg=failure_) @@ -589,10 +591,9 @@ @return (tuple[jid.JID, jid.JID]): peer_jid and owner """ - @defer.inlineCallbacks - def _compGetRootNodesCb(self, client, iq_elt): + async def _compGetRootNodesCb(self, client, iq_elt): peer_jid, owner = client.getOwnerAndPeer(iq_elt) - files_data = yield self.host.memory.getFiles( + files_data = await self.host.memory.getFiles( client, peer_jid=peer_jid, parent="", @@ -607,8 +608,7 @@ directory_elt["name"] = name client.send(iq_result_elt) - @defer.inlineCallbacks - def _compGetFilesFromNodeCb(self, client, iq_elt, node_path): + async def _compGetFilesFromNodeCb(self, client, iq_elt, node_path): """Retrieve files from local files repository according to permissions result stanza is then built and sent to requestor @@ -618,7 +618,7 @@ """ peer_jid, owner = client.getOwnerAndPeer(iq_elt) try: - files_data = yield self.host.memory.getFiles( + files_data = await self.host.memory.getFiles( client, peer_jid=peer_jid, path=node_path, owner=owner ) except exceptions.NotFound: @@ -628,7 +628,8 @@ self._iqError(client, iq_elt, condition='not-allowed') return except Exception as e: - log.error("internal server error: {e}".format(e=e)) + tb = traceback.format_tb(e.__traceback__) + log.error(f"internal server error: {e}\n{''.join(tb)}") self._iqError(client, iq_elt, condition='internal-server-error') return iq_result_elt = xmlstream.toResponse(iq_elt, "result")
--- a/sat/plugins/plugin_xep_0346.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0346.py Wed Sep 01 13:47:17 2021 +0200 @@ -118,7 +118,7 @@ host.bridge.addMethod( "psItemsFormGet", ".plugin", - in_sign="ssssiassa{ss}s", + in_sign="ssssiassss", out_sign="(asa{ss})", method=self._getDataFormItems, async_=True, @@ -302,7 +302,7 @@ return xml_tools.dataForm2dataDict(schema_form) def _getDataFormItems(self, form_ns="", service="", node="", schema="", max_items=10, - item_ids=None, sub_id=None, extra_dict=None, + item_ids=None, sub_id=None, extra="", profile_key=C.PROF_KEY_NONE): client = self.host.getClient(profile_key) service = jid.JID(service) if service else None @@ -313,7 +313,7 @@ else: schema = None max_items = None if max_items == C.NO_LIMIT else max_items - extra = self._p.parseExtra(extra_dict) + extra = self._p.parseExtra(data_format.deserialise(extra)) d = defer.ensureDeferred( self.getDataFormItems( client, @@ -553,7 +553,7 @@ ## Helper methods ## - def prepareBridgeGet(self, service, node, max_items, sub_id, extra_dict, profile_key): + def prepareBridgeGet(self, service, node, max_items, sub_id, extra, profile_key): """Parse arguments received from bridge *Get methods and return higher level data @return (tuple): (client, service, node, max_items, extra, sub_id) usable for @@ -566,12 +566,12 @@ max_items = None if max_items == C.NO_LIMIT else max_items if not sub_id: sub_id = None - extra = self._p.parseExtra(extra_dict) + extra = self._p.parseExtra(extra) return client, service, node, max_items, extra, sub_id def _get(self, service="", node="", max_items=10, item_ids=None, sub_id=None, - extra=None, default_node=None, form_ns=None, filters=None, + extra="", default_node=None, form_ns=None, filters=None, profile_key=C.PROF_KEY_NONE): """Bridge method to retrieve data from node with schema @@ -583,8 +583,7 @@ """ if filters is None: filters = {} - if extra is None: - extra = {} + extra = data_format.deserialise(extra) # XXX: Q&D way to get list for labels when displaying them, but text when we # have to modify them if C.bool(extra.get("labels_as_list", C.BOOL_FALSE)): @@ -627,8 +626,7 @@ extra = data_format.deserialise(extra) return client, service, node or None, schema, item_id or None, extra - @defer.inlineCallbacks - def copyMissingValues(self, client, service, node, item_id, form_ns, values): + async def copyMissingValues(self, client, service, node, item_id, form_ns, values): """Retrieve values existing in original item and missing in update Existing item will be retrieve, and values not already specified in values will @@ -643,7 +641,7 @@ """ try: # we get previous item - items_data = yield self._p.getItems( + items_data = await self._p.getItems( client, service, node, item_ids=[item_id] ) item_elt = items_data[0][0]
--- a/sat/plugins/plugin_xep_0384.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/plugins/plugin_xep_0384.py Wed Sep 01 13:47:17 2021 +0200 @@ -114,7 +114,10 @@ @return (defer.Deferred): deferred instance linked to the promise """ d = defer.Deferred() - promise_.then(d.callback, d.errback) + promise_.then( + lambda result: reactor.callLater(0, d.callback, result), + lambda exc: reactor.callLater(0, d.errback, exc) + ) return d @@ -141,6 +144,26 @@ deferred.addCallback(partial(callback, True)) deferred.addErrback(partial(callback, False)) + def _callMainThread(self, callback, method, *args, check_jid=None): + d = method(*args) + if check_jid is not None: + check_jid_d = self._checkJid(check_jid) + check_jid_d.addCallback(lambda __: d) + d = check_jid_d + if callback is not None: + d.addCallback(partial(callback, True)) + d.addErrback(partial(callback, False)) + + def _call(self, callback, method, *args, check_jid=None): + """Create Deferred and add Promise callback to it + + This method use reactor.callLater to launch Deferred in main thread + @param check_jid: run self._checkJid before method + """ + reactor.callLater( + 0, self._callMainThread, callback, method, *args, check_jid=check_jid + ) + def _checkJid(self, bare_jid): """Check if jid is known, and store it if not @@ -164,71 +187,50 @@ callback(True, None) def loadState(self, callback): - d = self.data.get(KEY_STATE) - self.setCb(d, callback) + self._call(callback, self.data.get, KEY_STATE) def storeState(self, callback, state): - d = self.data.force(KEY_STATE, state) - self.setCb(d, callback) + self._call(callback, self.data.force, KEY_STATE, state) def loadSession(self, callback, bare_jid, device_id): key = '\n'.join([KEY_SESSION, bare_jid, str(device_id)]) - d = self.data.get(key) - self.setCb(d, callback) + self._call(callback, self.data.get, key) def storeSession(self, callback, bare_jid, device_id, session): key = '\n'.join([KEY_SESSION, bare_jid, str(device_id)]) - d = self.data.force(key, session) - self.setCb(d, callback) + self._call(callback, self._data.force, key, session) def deleteSession(self, callback, bare_jid, device_id): key = '\n'.join([KEY_SESSION, bare_jid, str(device_id)]) - d = self.data.remove(key) - self.setCb(d, callback) + self._call(callback, self.data.remove, key) def loadActiveDevices(self, callback, bare_jid): key = '\n'.join([KEY_ACTIVE_DEVICES, bare_jid]) - d = self.data.get(key, {}) - if callback is not None: - self.setCb(d, callback) - return d + self._call(callback, self.data.get, key, {}) def loadInactiveDevices(self, callback, bare_jid): key = '\n'.join([KEY_INACTIVE_DEVICES, bare_jid]) - d = self.data.get(key, {}) - if callback is not None: - self.setCb(d, callback) - return d + self._call(callback, self.data.get, key, {}) def storeActiveDevices(self, callback, bare_jid, devices): key = '\n'.join([KEY_ACTIVE_DEVICES, bare_jid]) - d = self._checkJid(bare_jid) - d.addCallback(lambda _: self.data.force(key, devices)) - self.setCb(d, callback) + self._call(callback, self.data.force, key, devices, check_jid=bare_jid) def storeInactiveDevices(self, callback, bare_jid, devices): key = '\n'.join([KEY_INACTIVE_DEVICES, bare_jid]) - d = self._checkJid(bare_jid) - d.addCallback(lambda _: self.data.force(key, devices)) - self.setCb(d, callback) + self._call(callback, self.data.force, key, devices, check_jid=bare_jid) def storeTrust(self, callback, bare_jid, device_id, trust): key = '\n'.join([KEY_TRUST, bare_jid, str(device_id)]) - d = self.data.force(key, trust) - self.setCb(d, callback) + self._call(callback, self.data.force, key, trust) def loadTrust(self, callback, bare_jid, device_id): key = '\n'.join([KEY_TRUST, bare_jid, str(device_id)]) - d = self.data.get(key) - if callback is not None: - self.setCb(d, callback) - return d + self._call(callback, self.data.get, key) def listJIDs(self, callback): - d = defer.succeed(self.all_jids) if callback is not None: - self.setCb(d, callback) - return d + callback(True, self.all_jids) def _deleteJID_logResults(self, results): failed = [success for success, __ in results if not success] @@ -266,8 +268,7 @@ d.addCallback(self._deleteJID_logResults) return d - def deleteJID(self, callback, bare_jid): - """Retrieve all (in)actives devices of bare_jid, and delete all related keys""" + def _deleteJID(self, callback, bare_jid): d_list = [] key = '\n'.join([KEY_ACTIVE_DEVICES, bare_jid]) @@ -284,7 +285,10 @@ d.addCallback(self._deleteJID_gotDevices, bare_jid) if callback is not None: self.setCb(d, callback) - return d + + def deleteJID(self, callback, bare_jid): + """Retrieve all (in)actives devices of bare_jid, and delete all related keys""" + reactor.callLater(0, self._deleteJID, callback, bare_jid) class SatOTPKPolicy(omemo.DefaultOTPKPolicy): @@ -728,7 +732,7 @@ while device_id in devices: device_id = random.randint(1, 2**31-1) # and we save it - persistent_dict[KEY_DEVICE_ID] = device_id + await persistent_dict.aset(KEY_DEVICE_ID, device_id) log.debug(f"our OMEMO device id is {device_id}") @@ -788,8 +792,7 @@ devices.add(device_id) return devices - @defer.inlineCallbacks - def getDevices(self, client, entity_jid=None): + async def getDevices(self, client, entity_jid=None): """Retrieve list of registered OMEMO devices @param entity_jid(jid.JID, None): get devices from this entity @@ -799,13 +802,13 @@ if entity_jid is not None: assert not entity_jid.resource try: - items, metadata = yield self._p.getItems(client, entity_jid, NS_OMEMO_DEVICES) + items, metadata = await self._p.getItems(client, entity_jid, NS_OMEMO_DEVICES) except exceptions.NotFound: log.info(_("there is no node to handle OMEMO devices")) - defer.returnValue(set()) + return set() devices = self.parseDevices(items) - defer.returnValue(devices) + return devices async def setDevices(self, client, devices): log.debug(f"setting devices with {', '.join(str(d) for d in devices)}") @@ -827,8 +830,7 @@ # bundles - @defer.inlineCallbacks - def getBundles(self, client, entity_jid, devices_ids): + async def getBundles(self, client, entity_jid, devices_ids): """Retrieve public bundles of an entity devices @param entity_jid(jid.JID): bare jid of entity @@ -845,7 +847,7 @@ for device_id in devices_ids: node = NS_OMEMO_BUNDLE.format(device_id=device_id) try: - items, metadata = yield self._p.getItems(client, entity_jid, node) + items, metadata = await self._p.getItems(client, entity_jid, node) except exceptions.NotFound: log.warning(_("Bundle missing for device {device_id}") .format(device_id=device_id)) @@ -906,7 +908,7 @@ bundles[device_id] = ExtendedPublicBundle.parse(omemo_backend, ik, spk, spkSignature, otpks) - defer.returnValue((bundles, missing)) + return (bundles, missing) async def setBundle(self, client, bundle, device_id): """Set public bundle for this device.
--- a/sat/tools/common/async_process.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/tools/common/async_process.py Wed Sep 01 13:47:17 2021 +0200 @@ -114,8 +114,7 @@ if stdin is not None: stdin = stdin.encode('utf-8') verbose = kwargs.pop('verbose', False) - args = [a.encode('utf-8') for a in args] - kwargs = {k:v.encode('utf-8') for k,v in list(kwargs.items())} + args = list(args) d = defer.Deferred() prot = cls(d, stdin=stdin) if verbose: @@ -131,7 +130,7 @@ prot.name = name else: command = cls.command - cmd_args = [os.path.basename(command)] + args + cmd_args = [command] + args reactor.spawnProcess(prot, command, cmd_args,
--- a/sat/tools/common/date_utils.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/tools/common/date_utils.py Wed Sep 01 13:47:17 2021 +0200 @@ -32,8 +32,22 @@ import re RELATIVE_RE = re.compile(r"(?P<date>.*?)(?P<direction>[-+]?) *(?P<quantity>\d+) *" - r"(?P<unit>(second|minute|hour|day|week|month|year))s?" + r"(?P<unit>(second|sec|s|minute|min|month|mo|m|hour|hr|h|day|d" + r"|week|w|year|yr|y))s?" r"(?P<ago> +ago)?", re.I) +TIME_SYMBOL_MAP = { + "s": "second", + "sec": "second", + "m": "minute", + "min": "minute", + "h": "hour", + "hr": "hour", + "d": "day", + "w": "week", + "mo": "month", + "y": "year", + "yr": "year", +} YEAR_FIRST_RE = re.compile(r"\d{4}[^\d]+") TZ_UTC = tz.tzutc() TZ_LOCAL = tz.gettz() @@ -91,11 +105,15 @@ if not date or date == "now": dt = datetime.datetime.now(tz.tzutc()) else: - dt = default_tzinfo(parser.parse(date, dayfirst=True)) + dt = default_tzinfo(parser.parse(date, dayfirst=True), default_tz) quantity = int(m.group("quantity")) - key = m.group("unit").lower() + "s" - delta_kw = {key: direction * quantity} + unit = m.group("unit").lower() + try: + unit = TIME_SYMBOL_MAP[unit] + except KeyError: + pass + delta_kw = {f"{unit}s": direction * quantity} dt = dt + relativedelta(**delta_kw) return calendar.timegm(dt.utctimetuple())
--- a/sat/tools/utils.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat/tools/utils.py Wed Sep 01 13:47:17 2021 +0200 @@ -28,7 +28,7 @@ import inspect import textwrap import functools -from asyncio import iscoroutine +import asyncio from twisted.python import procutils, failure from twisted.internet import defer from sat.core.constants import Const as C @@ -102,7 +102,7 @@ except Exception as e: return defer.fail(failure.Failure(e)) else: - if iscoroutine(ret): + if asyncio.iscoroutine(ret): return defer.ensureDeferred(ret) elif isinstance(ret, defer.Deferred): return ret @@ -112,6 +112,31 @@ return defer.succeed(ret) +def aio(func): + """Decorator to return a Deferred from asyncio coroutine + + Functions with this decorator are run in asyncio context + """ + def wrapper(*args, **kwargs): + return defer.Deferred.fromFuture(asyncio.ensure_future(func(*args, **kwargs))) + return wrapper + + +def as_future(d): + return d.asFuture(asyncio.get_event_loop()) + + +def ensure_deferred(func): + """Decorator to apply ensureDeferred to a function + + to be used when the function is called by third party library (e.g. wokkel) + Otherwise, it's better to use ensureDeferred as early as possible. + """ + def wrapper(*args, **kwargs): + return defer.ensureDeferred(func(*args, **kwargs)) + return wrapper + + def xmpp_date(timestamp=None, with_time=True): """Return date according to XEP-0082 specification
--- a/sat_frontends/bridge/bridge_frontend.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat_frontends/bridge/bridge_frontend.py Wed Sep 01 13:47:17 2021 +0200 @@ -28,15 +28,14 @@ @param message (str): error message @param condition (str) : error condition """ - Exception.__init__(self) + super().__init__() self.fullname = str(name) self.message = str(message) self.condition = str(condition) if condition else "" self.module, __, self.classname = str(self.fullname).rpartition(".") def __str__(self): - message = (": %s" % self.message) if self.message else "" - return self.classname + message + return self.classname + (f": {self.message}" if self.message else "") def __eq__(self, other): return self.classname == other
--- a/sat_frontends/jp/base.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat_frontends/jp/base.py Wed Sep 01 13:47:17 2021 +0200 @@ -34,6 +34,7 @@ import termios from pathlib import Path from glob import iglob +from typing import Optional from importlib import import_module from sat_frontends.tools.jid import JID from sat.tools import config @@ -41,6 +42,7 @@ from sat.tools.common import uri from sat.tools.common import date_utils from sat.tools.common import utils +from sat.tools.common import data_format from sat.tools.common.ansi import ANSI as A from sat.core import exceptions import sat_frontends.jp @@ -331,7 +333,7 @@ def make_pubsub_group(self, flags, defaults): - """generate pubsub options according to flags + """Generate pubsub options according to flags @param flags(iterable[unicode]): see [CommandBase.__init__] @param defaults(dict[unicode, unicode]): help text for default value @@ -401,7 +403,7 @@ help=_("find page before this item"), metavar='ITEM_ID') rsm_page_group.add_argument( "--index", dest="rsm_index", type=int, - help=_("index of the page to retrieve")) + help=_("index of the first item to retrieve")) # MAM @@ -423,6 +425,12 @@ C.ORDER_BY_MODIFICATION], help=_("how items should be ordered")) + if flags[C.CACHE]: + pubsub_group.add_argument( + "-C", "--no-cache", dest="use_cache", action='store_false', + help=_("don't use Pubsub cache") + ) + if not flags.all_used: raise exceptions.InternalError('unknown flags: {flags}'.format( flags=', '.join(flags.unused))) @@ -1190,11 +1198,11 @@ _('trying to use output when use_output has not been set')) return self.host.output(output_type, self.args.output, self.extra_outputs, data) - def getPubsubExtra(self, extra=None): + def getPubsubExtra(self, extra: Optional[dict] = None) -> str: """Helper method to compute extra data from pubsub arguments - @param extra(None, dict): base extra dict, or None to generate a new one - @return (dict): dict which can be used directly in the bridge for pubsub + @param extra: base extra dict, or None to generate a new one + @return: dict which can be used directly in the bridge for pubsub """ if extra is None: extra = {} @@ -1235,7 +1243,17 @@ else: if order_by is not None: extra[C.KEY_ORDER_BY] = self.args.order_by - return extra + + # Cache + try: + use_cache = self.args.use_cache + except AttributeError: + pass + else: + if not use_cache: + extra[C.KEY_USE_CACHE] = use_cache + + return data_format.serialise(extra) def add_parser_options(self): try:
--- a/sat_frontends/jp/cmd_blog.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat_frontends/jp/cmd_blog.py Wed Sep 01 13:47:17 2021 +0200 @@ -251,7 +251,7 @@ "get", use_verbose=True, use_pubsub=True, - pubsub_flags={C.MULTI_ITEMS}, + pubsub_flags={C.MULTI_ITEMS, C.CACHE}, use_output=C.OUTPUT_COMPLEX, extra_outputs=extra_outputs, help=_("get blog item(s)"), @@ -600,7 +600,9 @@ items = [item] if item else [] mb_data = data_format.deserialise( - await self.host.bridge.mbGet(service, node, 1, items, {}, self.profile) + await self.host.bridge.mbGet( + service, node, 1, items, data_format.serialise({}), self.profile + ) ) item = mb_data["items"][0]
--- a/sat_frontends/jp/cmd_merge_request.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat_frontends/jp/cmd_merge_request.py Wed Sep 01 13:47:17 2021 +0200 @@ -139,7 +139,7 @@ self.args.max, self.args.items, "", - extra, + data_format.serialise(extra), self.profile, ) )
--- a/sat_frontends/jp/cmd_pubsub.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat_frontends/jp/cmd_pubsub.py Wed Sep 01 13:47:17 2021 +0200 @@ -762,6 +762,220 @@ ) +class CacheGet(base.CommandBase): + def __init__(self, host): + super().__init__( + host, + "get", + use_output=C.OUTPUT_LIST_XML, + use_pubsub=True, + pubsub_flags={C.NODE, C.MULTI_ITEMS, C.CACHE}, + help=_("get pubsub item(s) from cache"), + ) + + def add_parser_options(self): + self.parser.add_argument( + "-S", + "--sub-id", + default="", + help=_("subscription id"), + ) + + async def start(self): + try: + ps_result = data_format.deserialise( + await self.host.bridge.psCacheGet( + self.args.service, + self.args.node, + self.args.max, + self.args.items, + self.args.sub_id, + self.getPubsubExtra(), + self.profile, + ) + ) + except BridgeException as e: + if e.classname == "NotFound": + self.disp( + f"The node {self.args.node} from {self.args.service} is not in cache " + f"for {self.profile}", + error=True, + ) + self.host.quit(C.EXIT_NOT_FOUND) + else: + self.disp(f"can't get pubsub items from cache: {e}", error=True) + self.host.quit(C.EXIT_BRIDGE_ERRBACK) + except Exception as e: + self.disp(f"Internal error: {e}", error=True) + self.host.quit(C.EXIT_INTERNAL_ERROR) + else: + await self.output(ps_result["items"]) + self.host.quit(C.EXIT_OK) + + +class CacheSync(base.CommandBase): + + def __init__(self, host): + super().__init__( + host, + "sync", + use_pubsub=True, + pubsub_flags={C.NODE}, + help=_("(re)synchronise a pubsub node"), + ) + + def add_parser_options(self): + pass + + async def start(self): + try: + await self.host.bridge.psCacheSync( + self.args.service, + self.args.node, + self.profile, + ) + except BridgeException as e: + if e.condition == "item-not-found" or e.classname == "NotFound": + self.disp( + f"The node {self.args.node} doesn't exist on {self.args.service}", + error=True, + ) + self.host.quit(C.EXIT_NOT_FOUND) + else: + self.disp(f"can't synchronise pubsub node: {e}", error=True) + self.host.quit(C.EXIT_BRIDGE_ERRBACK) + except Exception as e: + self.disp(f"Internal error: {e}", error=True) + self.host.quit(C.EXIT_INTERNAL_ERROR) + else: + self.host.quit(C.EXIT_OK) + + +class CachePurge(base.CommandBase): + + def __init__(self, host): + super().__init__( + host, + "purge", + use_profile=False, + help=_("purge (delete) items from cache"), + ) + + def add_parser_options(self): + self.parser.add_argument( + "-s", "--service", action="append", metavar="JID", dest="services", + help="purge items only for these services. If not specified, items from ALL " + "services will be purged. May be used several times." + ) + self.parser.add_argument( + "-n", "--node", action="append", dest="nodes", + help="purge items only for these nodes. If not specified, items from ALL " + "nodes will be purged. May be used several times." + ) + self.parser.add_argument( + "-p", "--profile", action="append", dest="profiles", + help="purge items only for these profiles. If not specified, items from ALL " + "profiles will be purged. May be used several times." + ) + self.parser.add_argument( + "-b", "--updated-before", type=base.date_decoder, metavar="TIME_PATTERN", + help="purge items which have been last updated before given time." + ) + self.parser.add_argument( + "-C", "--created-before", type=base.date_decoder, metavar="TIME_PATTERN", + help="purge items which have been last created before given time." + ) + self.parser.add_argument( + "-t", "--type", action="append", dest="types", + help="purge items flagged with TYPE. May be used several times." + ) + self.parser.add_argument( + "-S", "--subtype", action="append", dest="subtypes", + help="purge items flagged with SUBTYPE. May be used several times." + ) + self.parser.add_argument( + "-f", "--force", action="store_true", + help=_("purge items without confirmation") + ) + + async def start(self): + if not self.args.force: + await self.host.confirmOrQuit( + _( + "Are you sure to purge items from cache? You'll have to bypass cache " + "or resynchronise nodes to access deleted items again." + ), + _("Items purgins has been cancelled.") + ) + purge_data = {} + for key in ( + "services", "nodes", "profiles", "updated_before", "created_before", + "types", "subtypes" + ): + value = getattr(self.args, key) + if value is not None: + purge_data[key] = value + try: + await self.host.bridge.psCachePurge( + data_format.serialise( + purge_data + ) + ) + except Exception as e: + self.disp(f"Internal error: {e}", error=True) + self.host.quit(C.EXIT_INTERNAL_ERROR) + else: + self.host.quit(C.EXIT_OK) + + +class CacheReset(base.CommandBase): + + def __init__(self, host): + super().__init__( + host, + "reset", + use_profile=False, + help=_("remove everything from cache"), + ) + + def add_parser_options(self): + self.parser.add_argument( + "-f", "--force", action="store_true", + help=_("reset cache without confirmation") + ) + + async def start(self): + if not self.args.force: + await self.host.confirmOrQuit( + _( + "Are you sure to reset cache? All nodes and items will be removed " + "from it, then it will be progressively refilled as if it were new." + ), + _("Pubsub cache reset has been cancelled.") + ) + try: + await self.host.bridge.psCacheReset() + except Exception as e: + self.disp(f"Internal error: {e}", error=True) + self.host.quit(C.EXIT_INTERNAL_ERROR) + else: + self.host.quit(C.EXIT_OK) + + +class Cache(base.CommandBase): + subcommands = ( + CacheGet, + CacheSync, + CachePurge, + CacheReset, + ) + + def __init__(self, host): + super(Cache, self).__init__( + host, "cache", use_profile=False, help=_("pubsub cache handling") + ) + + class Set(base.CommandBase): def __init__(self, host): base.CommandBase.__init__( @@ -823,7 +1037,7 @@ "get", use_output=C.OUTPUT_LIST_XML, use_pubsub=True, - pubsub_flags={C.NODE, C.MULTI_ITEMS}, + pubsub_flags={C.NODE, C.MULTI_ITEMS, C.CACHE}, help=_("get pubsub item(s)"), ) @@ -835,7 +1049,6 @@ help=_("subscription id"), ) # TODO: a key(s) argument to select keys to display - # TODO: add MAM filters async def start(self): try: @@ -884,7 +1097,8 @@ "-f", "--force", action="store_true", help=_("delete without confirmation") ) self.parser.add_argument( - "-N", "--notify", action="store_true", help=_("notify deletion") + "--no-notification", dest="notify", action="store_false", + help=_("do not send notification (not recommended)") ) async def start(self): @@ -955,7 +1169,7 @@ items = [item] if item else [] ps_result = data_format.deserialise( await self.host.bridge.psItemsGet( - service, node, 1, items, "", {}, self.profile + service, node, 1, items, "", data_format.serialise({}), self.profile ) ) item_raw = ps_result["items"][0] @@ -1687,7 +1901,7 @@ self.args.rsm_max, self.args.items, "", - extra, + data_format.serialise(extra), self.profile, ) ) @@ -2054,12 +2268,13 @@ Subscribe, Unsubscribe, Subscriptions, - Node, Affiliations, Search, Transform, Hook, Uri, + Node, + Cache, ) def __init__(self, host):
--- a/sat_frontends/jp/constants.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat_frontends/jp/constants.py Wed Sep 01 13:47:17 2021 +0200 @@ -62,6 +62,7 @@ SINGLE_ITEM = "single_item" # only one item is allowed MULTI_ITEMS = "multi_items" # multiple items are allowed NO_MAX = "no_max" # don't add --max option for multi items + CACHE = "cache" # add cache control flag # ANSI A_HEADER = A.BOLD + A.FG_YELLOW
--- a/sat_frontends/tools/misc.py Wed Sep 01 13:46:43 2021 +0200 +++ b/sat_frontends/tools/misc.py Wed Sep 01 13:47:17 2021 +0200 @@ -76,6 +76,9 @@ self._used_flags.add(flag) return flag in self.flags + def __getitem__(self, flag): + return getattr(self, flag) + def __len__(self): return len(self.flags)
--- a/setup.py Wed Sep 01 13:46:43 2021 +0200 +++ b/setup.py Wed Sep 01 13:47:17 2021 +0200 @@ -44,7 +44,7 @@ 'python-dateutil < 3', 'python-potr < 1.1', 'pyxdg < 0.30', - 'sat_tmp >= 0.8.0b1, < 0.9', + 'sat_tmp == 0.9.*', 'shortuuid < 1.1', 'twisted[tls] >= 20.3.0, < 21.3.0', 'treq < 22.0.0', @@ -54,6 +54,10 @@ 'omemo >= 0.11.0, < 0.13.0', 'omemo-backend-signal < 0.3', 'pyyaml < 5.5.0', + 'sqlalchemy >= 1.4', + 'alembic', + 'aiosqlite', + 'txdbus', ] extras_require = { @@ -137,6 +141,6 @@ use_scm_version=sat_dev_version if is_dev_version else False, install_requires=install_requires, extras_require=extras_require, - package_data={"sat": ["VERSION"]}, + package_data={"sat": ["VERSION", "memory/migration/alembic.ini"]}, python_requires=">=3.7", )
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tests/unit/conftest.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,117 @@ +#!/usr/bin/env python3 + +# Libervia: an XMPP client +# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +from contextlib import contextmanager +from unittest.mock import MagicMock, AsyncMock +from pytest import fixture +from twisted.internet import defer +from twisted.words.protocols.jabber import jid +from sat.core.sat_main import SAT +from sat.tools import async_trigger as trigger +from sat.core import xmpp + + +@fixture(scope="session") +def bridge(): + bridge = AsyncMock() + bridge.addSignal = MagicMock() + bridge.addMethod = MagicMock() + return bridge + + +@fixture(scope="session") +def storage(): + return AsyncMock() + + +class MockSAT(SAT): + + def __init__(self, bridge, storage): + self._cb_map = {} + self._menus = {} + self._menus_paths = {} + self._test_config = {} + self.profiles = {} + self.plugins = {} + # map for short name to whole namespace, + # extended by plugins with registerNamespace + self.ns_map = { + "x-data": xmpp.NS_X_DATA, + "disco#info": xmpp.NS_DISCO_INFO, + } + self.memory = MagicMock() + self.memory.storage = storage + self.memory.getConfig.side_effect = self.get_test_config + + self.trigger = trigger.TriggerManager() + self.bridge = bridge + defer.ensureDeferred(self._postInit()) + self.common_cache = AsyncMock() + self._import_plugins() + self._addBaseMenus() + self.initialised = defer.Deferred() + self.initialised.callback(None) + + def get_test_config(self, section, name, default=None): + return self._test_config.get((section or None, name), default) + + def set_test_config(self, section, name, value): + self._test_config[(section or None, name)] = value + + def clear_test_config(self): + self._test_config.clear() + + @contextmanager + def use_option_and_reload(self, section, name, value): + self.set_test_config(section, name, value) + self.reload_plugins() + try: + yield self + finally: + self.clear_test_config() + self.reload_plugins() + + def reload_plugins(self): + self.plugins.clear() + self.trigger._TriggerManager__triggers.clear() + self.ns_map = { + "x-data": xmpp.NS_X_DATA, + "disco#info": xmpp.NS_DISCO_INFO, + } + self._import_plugins() + + def _init(self): + pass + + async def _postInit(self): + pass + + +@fixture(scope="session") +def host(bridge, storage): + host = MockSAT(bridge=bridge, storage=storage) + return host + + +@fixture +def client(): + client = MagicMock() + client.jid = jid.JID("test_user@test.example/123") + client.pubsub_service = jid.JID("pubsub.test.example") + client.pubsub_client = AsyncMock() + return client
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tests/unit/test_pubsub-cache.py Wed Sep 01 13:47:17 2021 +0200 @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 + +# Libervia: an XMPP client +# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +from twisted.internet import defer +from pytest_twisted import ensureDeferred as ed +from unittest.mock import MagicMock, patch +from sat.memory.sqla import PubsubNode, SyncState +from sat.core.constants import Const as C + + +class TestPubsubCache: + + @ed + async def test_cache_is_used_transparently(self, host, client): + """Cache is used when a pubsub getItems operation is done""" + items_ret = defer.Deferred() + items_ret.callback(([], {})) + client.pubsub_client.items = MagicMock(return_value=items_ret) + host.memory.storage.getPubsubNode.return_value = None + pubsub_node = host.memory.storage.setPubsubNode.return_value = PubsubNode( + sync_state = None + ) + with patch.object(host.plugins["PUBSUB_CACHE"], "cacheNode") as cacheNode: + await host.plugins["XEP-0060"].getItems( + client, + None, + "urn:xmpp:microblog:0", + ) + assert cacheNode.call_count == 1 + assert cacheNode.call_args.args[-1] == pubsub_node + + @ed + async def test_cache_is_skipped_with_use_cache_false(self, host, client): + """Cache is skipped when 'use_cache' extra field is False""" + items_ret = defer.Deferred() + items_ret.callback(([], {})) + client.pubsub_client.items = MagicMock(return_value=items_ret) + host.memory.storage.getPubsubNode.return_value = None + host.memory.storage.setPubsubNode.return_value = PubsubNode( + sync_state = None + ) + with patch.object(host.plugins["PUBSUB_CACHE"], "cacheNode") as cacheNode: + await host.plugins["XEP-0060"].getItems( + client, + None, + "urn:xmpp:microblog:0", + extra = {C.KEY_USE_CACHE: False} + ) + assert not cacheNode.called + + @ed + async def test_cache_is_not_used_when_no_cache(self, host, client): + """Cache is skipped when 'pubsub_cache_strategy' is set to 'no_cache'""" + with host.use_option_and_reload(None, "pubsub_cache_strategy", "no_cache"): + items_ret = defer.Deferred() + items_ret.callback(([], {})) + client.pubsub_client.items = MagicMock(return_value=items_ret) + host.memory.storage.getPubsubNode.return_value = None + host.memory.storage.setPubsubNode.return_value = PubsubNode( + sync_state = None + ) + with patch.object(host.plugins["PUBSUB_CACHE"], "cacheNode") as cacheNode: + await host.plugins["XEP-0060"].getItems( + client, + None, + "urn:xmpp:microblog:0", + ) + assert not cacheNode.called + + + @ed + async def test_no_pubsub_get_when_cache_completed(self, host, client): + """No pubsub get is emitted when items are fully cached""" + items_ret = defer.Deferred() + items_ret.callback(([], {})) + client.pubsub_client.items = MagicMock(return_value=items_ret) + host.memory.storage.getPubsubNode.return_value = PubsubNode( + sync_state = SyncState.COMPLETED + ) + with patch.object( + host.plugins["PUBSUB_CACHE"], + "getItemsFromCache" + ) as getItemsFromCache: + getItemsFromCache.return_value = ([], {}) + await host.plugins["XEP-0060"].getItems( + client, + None, + "urn:xmpp:microblog:0", + ) + assert getItemsFromCache.call_count == 1 + assert not client.pubsub_client.items.called + + @ed + async def test_pubsub_get_when_cache_in_progress(self, host, client): + """Pubsub get is emitted when items are currently being cached""" + items_ret = defer.Deferred() + items_ret.callback(([], {})) + client.pubsub_client.items = MagicMock(return_value=items_ret) + host.memory.storage.getPubsubNode.return_value = PubsubNode( + sync_state = SyncState.IN_PROGRESS + ) + with patch.object(host.plugins["PUBSUB_CACHE"], "analyseNode") as analyseNode: + analyseNode.return_value = {"to_sync": True} + with patch.object( + host.plugins["PUBSUB_CACHE"], + "getItemsFromCache" + ) as getItemsFromCache: + getItemsFromCache.return_value = ([], {}) + assert client.pubsub_client.items.call_count == 0 + await host.plugins["XEP-0060"].getItems( + client, + None, + "urn:xmpp:microblog:0", + ) + assert not getItemsFromCache.called + assert client.pubsub_client.items.call_count == 1
--- a/twisted/plugins/sat_plugin.py Wed Sep 01 13:46:43 2021 +0200 +++ b/twisted/plugins/sat_plugin.py Wed Sep 01 13:47:17 2021 +0200 @@ -63,10 +63,11 @@ pass def makeService(self, options): - from twisted.internet import gireactor - gireactor.install() + from twisted.internet import asyncioreactor + asyncioreactor.install() self.setDebugger() - # XXX: SAT must be imported after log configuration, because it write stuff to logs + # XXX: Libervia must be imported after log configuration, + # because it write stuff to logs initialise(options.parent) from sat.core.sat_main import SAT return SAT()