diff sat/plugins/plugin_misc_identity.py @ 3277:cf07641b764d

plugin identity: fixed infinite loop on nicknames update
author Goffi <goffi@goffi.org>
date Mon, 18 May 2020 23:52:34 +0200
parents aa71f1d40300
children 27d4b71e264a
line wrap: on
line diff
--- a/sat/plugins/plugin_misc_identity.py	Mon May 18 23:48:40 2020 +0200
+++ b/sat/plugins/plugin_misc_identity.py	Mon May 18 23:52:34 2020 +0200
@@ -15,10 +15,12 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+from typing import Dict, Union, Coroutine, Any, Optional
 from collections import namedtuple
 from pathlib import Path
 from twisted.internet import defer
 from twisted.words.protocols.jabber import jid
+from sat.core.xmpp import SatXMPPEntity
 from sat.core.i18n import _
 from sat.core.constants import Const as C
 from sat.core import exceptions
@@ -44,7 +46,7 @@
     C.PI_DESCRIPTION: _("""Identity manager"""),
 }
 
-Callback = namedtuple("Callback", ("get", "set", "priority"))
+Callback = namedtuple("Callback", ("origin", "get", "set", "priority"))
 
 
 class Identity:
@@ -114,6 +116,7 @@
         )
 
     async def profileConnecting(self, client):
+        client._identity_update_lock = []
         # we restore known identities from database
         client._identity_storage = persistent.LazyPersistentBinaryDict(
             "identity", client.profile)
@@ -176,22 +179,29 @@
             )
         return True
 
-    def register(self, metadata_name, cb_get, cb_set, priority=0):
+    def register(
+            self,
+            origin: str,
+            metadata_name: str,
+            cb_get: Union[Coroutine, defer.Deferred],
+            cb_set: Union[Coroutine, defer.Deferred],
+            priority: int=0):
         """Register callbacks to handle identity metadata
 
-        @param metadata_name(str): name of metadata can be:
+        @param origin: namespace of the plugin managing this metadata
+        @param metadata_name: name of metadata can be:
             - avatar
             - nicknames
-        @param cb_get(coroutine, Deferred): method to retrieve a metadata
+        @param cb_get: method to retrieve a metadata
             the method will get client and metadata names to retrieve as arguments.
-        @param cb_set(coroutine, Deferred): method to set a metadata
+        @param cb_set: method to set a metadata
             the method will get client, metadata name to set, and value as argument.
-        @param priority(int): priority of this method for the given metadata.
+        @param priority: priority of this method for the given metadata.
             methods with bigger priorities will be called first
         """
         if not metadata_name in self.metadata.keys():
             raise ValueError(f"Invalid metadata_name: {metadata_name!r}")
-        callback = Callback(get=cb_get, set=cb_set, priority=priority)
+        callback = Callback(origin=origin, get=cb_get, set=cb_set, priority=priority)
         cb_list = self.metadata[metadata_name].setdefault('callbacks', [])
         cb_list.append(callback)
         cb_list.sort(key=lambda c: c.priority, reverse=True)
@@ -219,17 +229,25 @@
                 f"{value} has wrong type: it is {type(value)} while {value_type} was "
                 f"expected")
 
-    async def get(self, client, metadata_name, entity, use_cache=True):
+    async def get(
+            self,
+            client: SatXMPPEntity,
+            metadata_name: str,
+            entity: Optional[jid.JID],
+            use_cache: bool=True,
+            prefilled_values: Optional[Dict[str, Any]]=None
+        ):
         """Retrieve identity metadata of an entity
 
         if metadata is already in cache, it is returned. Otherwise, registered callbacks
         will be tried in priority order (bigger to lower)
-        @param metadata_name(str): name of the metadata
+        @param metadata_name: name of the metadata
             must be one of self.metadata key
             the name will also be used as entity data name in host.memory
-        @param entity(jid.JID, None): entity for which avatar is requested
+        @param entity: entity for which avatar is requested
             None to use profile's jid
-        @param use_cache(bool): if False, cache won't be checked
+        @param use_cache: if False, cache won't be checked
+        @param prefilled_values: map of origin => value to use when `get_all` is set
         """
         entity = self.getIdentityJid(client, entity)
         try:
@@ -255,10 +273,19 @@
 
         if get_all:
             all_data = []
+        elif prefilled_values is not None:
+            raise exceptions.InternalError(
+                "prefilled_values can only be used when `get_all` is set")
 
         for callback in callbacks:
             try:
-                data = await defer.ensureDeferred(callback.get(client, entity))
+                if prefilled_values is not None and callback.origin in prefilled_values:
+                    data = prefilled_values[callback.origin]
+                    log.debug(
+                        f"using prefilled values {data!r} for {metadata_name} with "
+                        f"{callback.origin}")
+                else:
+                    data = await defer.ensureDeferred(callback.get(client, entity))
             except exceptions.CancelError:
                 continue
             except Exception as e:
@@ -334,12 +361,23 @@
         if post_treatment is not None:
             await utils.asDeferred(post_treatment, client, entity, data)
 
-    async def update(self, client, metadata_name, data, entity):
+    async def update(
+            self,
+            client: SatXMPPEntity,
+            origin: str,
+            metadata_name: str,
+            data: Any,
+            entity: Optional[jid.JID]
+        ):
         """Update a metadata in cache
 
         This method may be called by plugins when an identity metadata is available.
+        @param origin: namespace of the plugin which is source of the metadata
         """
         entity = self.getIdentityJid(client, entity)
+        if (entity, metadata_name) in client._identity_update_lock:
+            log.debug(f"update is locked for {entity}'s {metadata_name}")
+            return
         metadata = self.metadata[metadata_name]
 
         try:
@@ -378,10 +416,14 @@
             # so we first delete current cache
             try:
                 self.host.memory.delEntityDatum(client, entity, metadata_name)
-            except KeyError:
+            except (KeyError, exceptions.UnknownEntityError):
                 pass
             # then fill it again by calling get, which will retrieve all values
-            await self.get(client, metadata_name, entity)
+            # we lock update to avoid infinite recursions (update can be called during
+            # get callbacks)
+            client._identity_update_lock.append((entity, metadata_name))
+            await self.get(client, metadata_name, entity, prefilled_values={origin: data})
+            client._identity_update_lock.remove((entity, metadata_name))
             return
 
         if data is not None: