diff sat/plugins/plugin_xep_0313.py @ 2701:2ea2369ae7de

plugin XEP-0313: implementation of MAM for messages: - (core/xmpp): new messageGetBridgeArgs to easily retrieve arguments used in bridge from message data - : parseMessage is not static anymore - : new "message_parse" trigger point - (xep-0313) : new "MAMGet" bridge method to retrieve history from MAM instead of local one - : on profileConnected, if a previous MAM message is found (i.e. a message with a stanza_id), messages received while offline are retrieved and injected in message workflow. In other words, one2one history is synchronised on connection. - : new "parseExtra" method which parses MAM (and optionally RSM) options from the extra dictionary used in bridge.
author Goffi <goffi@goffi.org>
date Sat, 01 Dec 2018 10:33:43 +0100
parents 56f94936df1e
children 19000c506d0c
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0313.py	Sat Dec 01 10:10:25 2018 +0100
+++ b/sat/plugins/plugin_xep_0313.py	Sat Dec 01 10:33:43 2018 +0100
@@ -21,15 +21,15 @@
 from sat.core.constants import Const as C
 from sat.core.i18n import _
 from sat.core.log import getLogger
-
-log = getLogger(__name__)
 from sat.core import exceptions
-
+from sat.tools.common import data_format
 from twisted.words.protocols.jabber import jid
-
+from twisted.internet import defer
 from zope.interface import implements
-
+from datetime import datetime
+from dateutil import tz
 from wokkel import disco
+from wokkel import data_form
 import uuid
 
 # XXX: mam and rsm come from sat_tmp.wokkel
@@ -37,28 +37,178 @@
 from wokkel import mam
 
 
+log = getLogger(__name__)
+
+
 MESSAGE_RESULT = "/message/result[@xmlns='{mam_ns}' and @queryid='{query_id}']"
 
 PLUGIN_INFO = {
-    C.PI_NAME: "Message Archive Management",
-    C.PI_IMPORT_NAME: "XEP-0313",
-    C.PI_TYPE: "XEP",
-    C.PI_PROTOCOLS: ["XEP-0313"],
-    C.PI_MAIN: "XEP_0313",
-    C.PI_HANDLER: "yes",
-    C.PI_DESCRIPTION: _("""Implementation of Message Archive Management"""),
+    C.PI_NAME: u"Message Archive Management",
+    C.PI_IMPORT_NAME: u"XEP-0313",
+    C.PI_TYPE: u"XEP",
+    C.PI_PROTOCOLS: [u"XEP-0313"],
+    C.PI_DEPENDENCIES: [u"XEP-0059", u"XEP-0359"],
+    C.PI_MAIN: u"XEP_0313",
+    C.PI_HANDLER: u"yes",
+    C.PI_DESCRIPTION: _(u"""Implementation of Message Archive Management"""),
 }
 
+MAM_PREFIX = u"mam_"
+FILTER_PREFIX = MAM_PREFIX + "filter_"
+
 
 class XEP_0313(object):
     def __init__(self, host):
         log.info(_("Message Archive Management plugin initialization"))
         self.host = host
+        self.host.registerNamespace(u"mam", mam.NS_MAM)
+        self._rsm = host.plugins[u"XEP-0059"]  # Result Set Management plugin
+        self._sid = host.plugins[u"XEP-0359"]  # Unique and Stable Stanza IDs plugin
+        host.bridge.addMethod(
+            "MAMGet", ".plugin", in_sign='sss', out_sign='(a(sdssa{ss}a{ss}sa{ss})s)', method=self._getArchives,
+            async=True)
+
+    @defer.inlineCallbacks
+    def profileConnected(self, client):
+        last_mess = yield self.host.memory.historyGet(
+            None, None, limit=1, filters={u'last_stanza_id': True},
+            profile=client.profile)
+        if not last_mess:
+            log.info(_(u"It seems that we have no MAM history yet"))
+            return
+        stanza_id = last_mess[0][-1][u'stanza_id']
+        # request only what was archived after the last stanza id found in local
+        # history; the query below is sent to our own bare jid
+        rsm_req = rsm.RSMRequest(after=stanza_id)
+        mam_req = mam.MAMRequest(rsm_=rsm_req)
+        mam_data = yield self.getArchives(client, mam_req,
+                                         service=client.jid.userhostJID())
+        elt_list, rsm_response = mam_data
+        if not elt_list:
+            log.info(_(u"We have received no message while offline"))
+            return
+        else:
+            log.info(_(u"We have received {num_mess} message(s) while offline.").format(
+                num_mess=len(elt_list)))
+
+        for mess_elt in elt_list:
+            try:
+                fwd_message_elt = self.getMessageFromResult(client, mess_elt, mam_req)
+            except exceptions.DataError:
+                continue
+
+            client.messageProt.onMessage(fwd_message_elt)
 
     def getHandler(self, client):
         mam_client = client._mam = SatMAMClient()
         return mam_client
 
+    def parseExtra(self, extra, with_rsm=True):
+        """Parse extra dictionary to retrieve MAM arguments
+
+        @param extra(dict): data to parse (modified in place: most keys used are popped)
+        @param with_rsm(bool): if True, RSM data will be parsed too
+        @return (mam.MAMRequest, None): request with parsed arguments
+            or None if no MAM arguments have been found
+        """
+        mam_args = {}
+        form_args = {}
+        for arg in (u"start", u"end"):
+            try:
+                value = extra.pop(MAM_PREFIX + arg)
+                form_args[arg] = datetime.fromtimestamp(float(value), tz.tzutc())
+            except (TypeError, ValueError):
+                log.warning(u"Bad value for {arg} filter ({value}), ignoring".format(
+                    arg=arg, value=value))
+            except KeyError:
+                continue
+
+        try:
+            form_args[u"with_jid"] = jid.JID(extra.pop(
+                MAM_PREFIX + u"with"))
+        except (jid.InvalidFormat):
+            log.warning(u"Bad value for jid filter")
+        except KeyError:
+            pass
+
+        for name, value in extra.iteritems():
+            if name.startswith(FILTER_PREFIX):
+                var = name[len(FILTER_PREFIX) :]
+                extra_fields = form_args.setdefault(u"extra_fields", [])
+                extra_fields.append(data_form.Field(var=var, value=value))
+
+        for arg in (u"node", u"query_id"):
+            try:
+                value = extra.pop(MAM_PREFIX + arg)
+                mam_args[arg] = value
+            except KeyError:
+                continue
+
+        if with_rsm:
+            rsm_request = self._rsm.parseExtra(extra)
+            if rsm_request is not None:
+                mam_args["rsm_"] = rsm_request
+
+        if form_args:
+            mam_args["form"] = mam.buildForm(**form_args)
+
+        return mam.MAMRequest(**mam_args) if mam_args else None
+
+    def getMessageFromResult(self, client, mess_elt, mam_req):
+        """Extract usable <message/> from MAM query result
+
+        The message will be validated, and stanza-id/delay will be added if necessary.
+        @param mess_elt(domish.Element): result <message/> wrapping the archived message
+        @param mam_req(mam.MAMRequest): request used
+        @return (domish.Element): <message/> that can be used directly with onMessage
+        @raise exceptions.DataError: mess_elt is not a valid result for mam_req
+        """
+        if mess_elt.name != u"message":
+            log.warning(u"unexpected stanza in archive: {xml}".format(
+                xml=mess_elt.toXml()))
+            raise exceptions.DataError(u"Invalid element")
+        # "from" may be omitted by the server, the sender is then our own bare jid
+        mess_from = mess_elt.getAttribute(u"from") or client.jid.userhost()
+        if mess_from != client.jid.host and mess_from != client.jid.userhost():
+            log.error(u"Message is not from our server, something went wrong: "
+                      u"{xml}".format(xml=mess_elt.toXml()))
+            raise exceptions.DataError(u"Invalid element")
+        try:
+            result_elt = next(mess_elt.elements(mam.NS_MAM, u"result"))
+            forwarded_elt = next(result_elt.elements(C.NS_FORWARD, u"forwarded"))
+            try:
+                delay_elt = next(forwarded_elt.elements(C.NS_DELAY, u"delay"))
+            except StopIteration:
+                # delay_elt is not mandatory
+                delay_elt = None
+            fwd_message_elt = next(forwarded_elt.elements(C.NS_CLIENT, u"message"))
+        except StopIteration:
+            log.warning(u"Invalid message received from MAM: {xml}".format(
+                xml=mess_elt.toXml()))
+            raise exceptions.DataError(u"Invalid element")
+        else:
+            if result_elt.getAttribute(u"queryid") != mam_req.query_id:
+                log.error(u"Unexpected query id (was expecting {query_id}): {xml}"
+                    .format(query_id=mam_req.query_id, xml=mess_elt.toXml()))
+                raise exceptions.DataError(u"Invalid element")
+            stanza_id = self._sid.getStanzaId(fwd_message_elt, client.jid.userhostJID())
+            if stanza_id is None:
+                # no stanza-id element is present, we add one so message
+                # will be archived with it, and we won't request several times
+                # the same MAM archive
+                try:
+                    stanza_id = result_elt[u"id"]
+                except KeyError:
+                    log.warning(u'Invalid MAM result: missing "id" attribute: {xml}'
+                                .format(xml=result_elt.toXml()))
+                    raise exceptions.DataError(u"Invalid element")
+                self._sid.addStanzaId(client, fwd_message_elt, stanza_id)
+
+            if delay_elt is not None:
+                fwd_message_elt.addChild(delay_elt)
+
+            return fwd_message_elt
+
     def queryFields(self, client, service=None):
         """Ask the server about supported fields.
 
@@ -67,15 +217,15 @@
         """
         return client._mam.queryFields(service)
 
-    def queryArchive(self, client, mam_query, service=None):
+    def queryArchive(self, client, mam_req, service=None):
         """Query a user, MUC or pubsub archive.
 
-        @param mam_query(mam.MAMRequest): MAM query instance
+        @param mam_req(mam.MAMRequest): MAM query instance
         @param service(jid.JID, None): entity offering the MAM service
             None for user server
         @return (D(domish.Element)): <IQ/> result
         """
-        return client._mam.queryArchive(mam_query, service)
+        return client._mam.queryArchive(mam_req, service)
 
     def _appendMessage(self, elt_list, message_cb, message_elt):
         if message_cb is not None:
@@ -97,6 +247,25 @@
 
         return (elt_list, rsm_response)
 
+    def serializeArchiveResult(self, data, client, mam_req):
+        """Convert getArchives result to the (messages, profile) tuple sent on bridge"""
+        elt_list, __ = data
+        fwd_elts = (self.getMessageFromResult(client, e, mam_req) for e in elt_list)
+        mess_list = [
+            client.messageGetBridgeArgs(client.messageProt.parseMessage(fwd_elt))
+            for fwd_elt in fwd_elts]
+        return mess_list, client.profile
+
+    def _getArchives(self, service, extra_ser, profile_key):
+        """Bridge method for MAMGet: query the archive and serialise its messages"""
+        client = self.host.getClient(profile_key)
+        service_jid = jid.JID(service) if service else None
+        mam_req = self.parseExtra(data_format.deserialise(extra_ser, {}))
+        # NOTE(review): mam_req is None if extra has no MAM argument, check it is handled
+        d = self.getArchives(client, mam_req, service=service_jid)
+        d.addCallback(self.serializeArchiveResult, client, mam_req)
+        return d
+
     def getArchives(self, client, query, service=None, message_cb=None):
         """Query archive then grab and return them all in the result
 
@@ -119,14 +288,8 @@
         # http://xmpp.org/extensions/xep-0313.html#prefs
         return client._mam.queryPrefs(service)
 
-    def _setPrefs(
-        self,
-        service_s=None,
-        default="roster",
-        always=None,
-        never=None,
-        profile_key=C.PROF_KEY_NONE,
-    ):
+    def _setPrefs(self, service_s=None, default="roster", always=None, never=None,
+                  profile_key=C.PROF_KEY_NONE):
         service = jid.JID(service_s) if service_s else None
         always_jid = [jid.JID(entity) for entity in always]
         never_jid = [jid.JID(entity) for entity in never]