diff sat/plugins/plugin_misc_identity.py @ 3338:203a491fcd86

plugin identity: new methods `identitiesGet` and `identitiesBaseGet` `identitiesGet` retrieve several identities at once, in parallel. `identitiesBaseGet` retrieve essential identities (roster + own)
author Goffi <goffi@goffi.org>
date Thu, 13 Aug 2020 23:46:18 +0200
parents 9e1ba1e1179f
children be6d91572633
line wrap: on
line diff
--- a/sat/plugins/plugin_misc_identity.py	Thu Aug 13 23:46:18 2020 +0200
+++ b/sat/plugins/plugin_misc_identity.py	Thu Aug 13 23:46:18 2020 +0200
@@ -15,7 +15,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
-from typing import Dict, Union, Coroutine, Any, Optional
+from typing import Dict, List, Union, Coroutine, Any, Optional
 from collections import namedtuple
 from pathlib import Path
 from twisted.internet import defer
@@ -94,6 +94,22 @@
             async_=True,
         )
         host.bridge.addMethod(
+            "identitiesGet",
+            ".plugin",
+            in_sign="asass",
+            out_sign="s",
+            method=self._getIdentities,
+            async_=True,
+        )
+        host.bridge.addMethod(
+            "identitiesBaseGet",
+            ".plugin",
+            in_sign="s",
+            out_sign="s",
+            method=self._getBaseIdentities,
+            async_=True,
+        )
+        host.bridge.addMethod(
             "identitySet",
             ".plugin",
             in_sign="ss",
@@ -178,6 +194,7 @@
             defer.ensureDeferred(
                 self.update(
                     client,
+                    IMPORT_NAME,
                     "nicknames",
                     [roster_item.name],
                     roster_item.jid
@@ -561,7 +578,7 @@
         client = self.host.getClient(profile)
         d = defer.ensureDeferred(
             self.getIdentity(client, entity, metadata_filter, use_cache))
-        d.addCallback(lambda data: data_format.serialise(data))
+        d.addCallback(data_format.serialise)
         return d
 
     async def getIdentity(
@@ -588,6 +605,71 @@
 
         return id_data
 
+    def _getIdentities(self, entities_s, metadata_filter, profile):
+        entities = [jid.JID(e) for e in entities_s]
+        client = self.host.getClient(profile)
+        d = defer.ensureDeferred(self.getIdentities(client, entities, metadata_filter))
+        d.addCallback(lambda d: data_format.serialise({str(j):i for j, i in d.items()}))
+        return d
+
+    async def getIdentities(
+        self,
+        client: SatXMPPEntity,
+        entities: List[jid.JID],
+        metadata_filter: Optional[List[str]] = None,
+    ) -> dict:
+        """Retrieve several identities at once
+
+        @param entities: entities from which identities must be retrieved
+        @param metadata_filter: same as for [getIdentity]
+        @return: identities metadata where key is jid
+            if an error happens while retrieve a jid entity, it won't be present in the
+            result (and a warning will be logged)
+        """
+        identities = {}
+        get_identity_list = []
+        for entity_jid in entities:
+            get_identity_list.append(
+                defer.ensureDeferred(
+                    self.getIdentity(
+                        client,
+                        entity=entity_jid,
+                        metadata_filter=metadata_filter,
+                    )
+                )
+            )
+        identities_result = await defer.DeferredList(get_identity_list)
+        for idx, (success, identity) in enumerate(identities_result):
+            entity_jid = entities[idx]
+            if not success:
+                log.warning(f"Can't get identity for {entity_jid}")
+            else:
+                identities[entity_jid] = identity
+        return identities
+
+    def _getBaseIdentities(self, profile_key):
+        client = self.host.getClient(profile_key)
+        d = defer.ensureDeferred(self.getBaseIdentities(client))
+        d.addCallback(lambda d: data_format.serialise({str(j):i for j, i in d.items()}))
+        return d
+
+    async def getBaseIdentities(
+        self,
+        client: SatXMPPEntity,
+    ) -> dict:
+        """Retrieve identities for entities in roster + own identity + invitations
+
+        @param with_guests: if True, get affiliations of people invited by email
+
+        """
+        entities = client.roster.getJids() + [client.jid.userhostJID()]
+
+        return await self.getIdentities(
+            client,
+            entities,
+            ['avatar', 'nicknames']
+        )
+
     def _setIdentity(self, id_data_s, profile):
         client = self.host.getClient(profile)
         id_data = data_format.deserialise(id_data_s)