diff sat/core/xmpp.py @ 3715:b9718216a1c0 0.9

merge bookmark 0.9
author Goffi <goffi@goffi.org>
date Wed, 01 Dec 2021 16:13:31 +0100
parents e3dddf65fa88 09f5ac48ffe3
children e52de21873d3
line wrap: on
line diff
--- a/sat/core/xmpp.py	Tue Nov 30 23:31:09 2021 +0100
+++ b/sat/core/xmpp.py	Wed Dec 01 16:13:31 2021 +0100
@@ -42,6 +42,7 @@
 from wokkel import delay
 from sat.core.log import getLogger
 from sat.core import exceptions
+from sat.core import core_types
 from sat.memory import encryption
 from sat.memory import persistent
 from sat.tools import xml_tools
@@ -83,7 +84,7 @@
         return partial(getattr(self.plugin, attr), self.client)
 
 
-class SatXMPPEntity:
+class SatXMPPEntity(core_types.SatXMPPEntity):
     """Common code for Client and Component"""
     # profile is added there when startConnection begins and removed when it is finished
     profiles_connecting = set()
@@ -716,7 +717,7 @@
             or mess_data["type"] == C.MESS_TYPE_INFO
         )
 
-    def messageAddToHistory(self, data):
+    async def messageAddToHistory(self, data):
         """Store message into database (for local history)
 
         @param data: message data dictionnary
@@ -728,7 +729,7 @@
 
             # we need a message to store
             if self.isMessagePrintable(data):
-                self.host_app.memory.addToHistory(self, data)
+                await self.host_app.memory.addToHistory(self, data)
             else:
                 log.warning(
                     "No message found"
@@ -878,7 +879,9 @@
 
     def addPostXmlCallbacks(self, post_xml_treatments):
         post_xml_treatments.addCallback(self.messageProt.completeAttachments)
-        post_xml_treatments.addCallback(self.messageAddToHistory)
+        post_xml_treatments.addCallback(
+            lambda ret: defer.ensureDeferred(self.messageAddToHistory(ret))
+        )
         post_xml_treatments.addCallback(self.messageSendToBridge)
 
     def send(self, obj):
@@ -1063,7 +1066,9 @@
 
     def addPostXmlCallbacks(self, post_xml_treatments):
         if self.sendHistory:
-            post_xml_treatments.addCallback(self.messageAddToHistory)
+            post_xml_treatments.addCallback(
+                lambda ret: defer.ensureDeferred(self.messageAddToHistory(ret))
+            )
 
     def getOwnerFromJid(self, to_jid: jid.JID) -> jid.JID:
         """Retrieve "owner" of a component resource from the destination jid of the request
@@ -1214,7 +1219,9 @@
         data = self.parseMessage(message_elt)
         post_treat.addCallback(self.completeAttachments)
         post_treat.addCallback(self.skipEmptyMessage)
-        post_treat.addCallback(self.addToHistory)
+        post_treat.addCallback(
+            lambda ret: defer.ensureDeferred(self.addToHistory(ret))
+        )
         post_treat.addCallback(self.bridgeSignal, data)
         post_treat.addErrback(self.cancelErrorTrap)
         post_treat.callback(data)
@@ -1255,14 +1262,14 @@
             raise failure.Failure(exceptions.CancelError("Cancelled empty message"))
         return data
 
-    def addToHistory(self, data):
+    async def addToHistory(self, data):
         if data.pop("history", None) == C.HISTORY_SKIP:
             log.debug("history is skipped as requested")
             data["extra"]["history"] = C.HISTORY_SKIP
         else:
             # we need a message to store
             if self.parent.isMessagePrintable(data):
-                return self.host.memory.addToHistory(self.parent, data)
+                return await self.host.memory.addToHistory(self.parent, data)
             else:
                 log.debug("not storing empty message to history: {data}"
                     .format(data=data))
@@ -1480,7 +1487,8 @@
         self._jids[entity] = item
         self._registerItem(item)
         self.host.bridge.newContact(
-            entity.full(), self.getAttributes(item), item.groups, self.parent.profile
+            entity.full(), self.getAttributes(item), list(item.groups),
+            self.parent.profile
         )
 
     def removeReceived(self, request):
@@ -1546,7 +1554,7 @@
                 f"a JID is expected, not {type(entity_jid)}: {entity_jid!r}")
         return entity_jid in self._jids
 
-    def isPresenceAuthorised(self, entity_jid):
+    def isSubscribedFrom(self, entity_jid: jid.JID) -> bool:
         """Return True if entity is authorised to see our presence"""
         try:
             item = self._jids[entity_jid.userhostJID()]
@@ -1554,6 +1562,14 @@
             return False
         return item.subscriptionFrom
 
+    def isSubscribedTo(self, entity_jid: jid.JID) -> bool:
+        """Return True if we are subscribed to entity"""
+        try:
+            item = self._jids[entity_jid.userhostJID()]
+        except KeyError:
+            return False
+        return item.subscriptionTo
+
     def getItems(self):
         """Return all items of the roster"""
         return list(self._jids.values())