Mercurial > libervia-backend
diff sat/plugins/plugin_xep_0060.py @ 3584:edc79cefe968
plugin XEP-0060: `getItem(s)`, `publish` and `(un)subscribe` are now coroutines
author | Goffi <goffi@goffi.org> |
---|---|
date | Wed, 30 Jun 2021 16:19:14 +0200 |
parents | 02eec2a5b5f9 |
children | 5f65f4e9f8cb |
line wrap: on
line diff
--- a/sat/plugins/plugin_xep_0060.py Sun Jun 27 00:15:40 2021 +0200 +++ b/sat/plugins/plugin_xep_0060.py Wed Jun 30 16:19:14 2021 +0200 @@ -17,7 +17,7 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. -from typing import Optional +from typing import Optional, List, Tuple from collections import namedtuple import urllib.request, urllib.parse, urllib.error from functools import reduce @@ -35,7 +35,7 @@ from sat.core.i18n import _ from sat.core.constants import Const as C from sat.core.log import getLogger -from sat.core.xmpp import SatXMPPEntity +from sat.core.core_types import SatXMPPEntity from sat.core import exceptions from sat.tools import sat_defer from sat.tools import xml_tools @@ -471,9 +471,9 @@ service = None if not service else jid.JID(service) payload = xml_tools.parse(payload) extra = data_format.deserialise(extra_ser) - d = self.sendItem( + d = defer.ensureDeferred(self.sendItem( client, service, nodeIdentifier, payload, item_id or None, extra - ) + )) d.addCallback(lambda ret: ret or "") return d @@ -487,23 +487,13 @@ raise exceptions.DataError(_("Can't parse items: {msg}").format( msg=e)) extra = data_format.deserialise(extra_ser) - d = self.sendItems( + return defer.ensureDeferred(self.sendItems( client, service, nodeIdentifier, items, extra - ) - return d - - def _getPublishedItemId(self, published_ids, original_id): - """Return item of published id if found in answer + )) - if not found original_id is returned, which may be None - """ - try: - return published_ids[0] - except IndexError: - return original_id - - def sendItem(self, client, service, nodeIdentifier, payload, item_id=None, - extra=None): + async def sendItem( + self, client, service, nodeIdentifier, payload, item_id=None, extra=None + ): """High level method to send one item @param service(jid.JID, None): service to send the item to @@ -519,15 +509,17 @@ if item_id is not None: item_elt['id'] = item_id item_elt.addChild(payload) - d = defer.ensureDeferred(self.sendItems( + published_ids = await self.sendItems( client, service, nodeIdentifier, [item_elt], extra - )) - d.addCallback(self._getPublishedItemId, item_id) - return d + ) + try: + return published_ids[0] + except IndexError: + return item_id async def sendItems(self, client, service, nodeIdentifier, items, extra=None): """High level method to send several items at once @@ -593,8 +585,8 @@ except AttributeError: return [] - def publish(self, client, service, nodeIdentifier, items=None, options=None): - return client.pubsub_client.publish( + async def publish(self, client, service, nodeIdentifier, items=None, options=None): + return await client.pubsub_client.publish( service, nodeIdentifier, items, client.pubsub_client.parent.jid, options=options ) @@ -630,7 +622,7 @@ service = jid.JID(service) if service else None max_items = None if max_items == C.NO_LIMIT else max_items extra = self.parseExtra(extra_dict) - d = self.getItems( + d = defer.ensureDeferred(self.getItems( client, service, node or None, @@ -639,13 +631,22 @@ sub_id or None, extra.rsm_request, extra.extra, - ) + )) d.addCallback(self.transItemsData) d.addCallback(self.serialiseItems) return d - def getItems(self, client, service, node, max_items=None, item_ids=None, sub_id=None, - rsm_request=None, extra=None): + async def getItems( + self, + client: SatXMPPEntity, + service: Optional[jid.JID], + node: str, + max_items: Optional[int] = None, + item_ids: Optional[List[str]] = None, + sub_id: Optional[str] = None, + rsm_request: Optional[rsm.RSMRequest] = None, + extra: Optional[dict] = None + ) -> Tuple[List[dict], dict]: """Retrieve pubsub items from a node. @param service (JID, None): pubsub service. @@ -682,9 +683,10 @@ rsm_request = rsm_request ) # we have no MAM data here, so we add None - d.addCallback(lambda data: data + (None,)) d.addErrback(sat_defer.stanza2NotFound) d.addTimeout(TIMEOUT, reactor) + items, rsm_response = await d + mam_response = None else: # if mam is requested, we have to do a totally different query if self._mam is None: @@ -706,61 +708,49 @@ raise exceptions.DataError( "Conflict between RSM request and MAM's RSM request" ) - d = self._mam.getArchives(client, mam_query, service, self._unwrapMAMMessage) + items, rsm_response, mam_response = await self._mam.getArchives( + client, mam_query, service, self._unwrapMAMMessage + ) try: subscribe = C.bool(extra["subscribe"]) except KeyError: subscribe = False - def subscribeEb(failure, service, node): - failure.trap(error.StanzaError) - log.warning( - "Could not subscribe to node {} on service {}: {}".format( - node, str(service), str(failure.value) + if subscribe: + try: + await self.subscribe(client, service, node) + except error.StanzaError as e: + log.warning( + f"Could not subscribe to node {node} on service {service}: {e}" ) - ) - - def doSubscribe(data): - self.subscribe(client, service, node).addErrback( - subscribeEb, service, node - ) - return data - - if subscribe: - d.addCallback(doSubscribe) - def addMetadata(result): - # TODO: handle the third argument (mam_response) - items, rsm_response, mam_response = result - service_jid = service if service else client.jid.userhostJID() - metadata = { - "service": service_jid, - "node": node, - "uri": self.getNodeURI(service_jid, node), - } - if mam_response is not None: - # mam_response is a dict with "complete" and "stable" keys - # we can put them directly in metadata - metadata.update(mam_response) - if rsm_request is not None and rsm_response is not None: - metadata['rsm'] = rsm_response.toDict() - if mam_response is None: - index = rsm_response.index - count = rsm_response.count - if index is None or count is None: - # we don't have enough information to know if the data is complete - # or not - metadata["complete"] = None - else: - # normally we have a strict equality here but XEP-0059 states - # that index MAY be approximative, so just in case… - metadata["complete"] = index + len(items) >= count + # TODO: handle mam_response + service_jid = service if service else client.jid.userhostJID() + metadata = { + "service": service_jid, + "node": node, + "uri": self.getNodeURI(service_jid, node), + } + if mam_response is not None: + # mam_response is a dict with "complete" and "stable" keys + # we can put them directly in metadata + metadata.update(mam_response) + if rsm_request is not None and rsm_response is not None: + metadata['rsm'] = rsm_response.toDict() + if mam_response is None: + index = rsm_response.index + count = rsm_response.count + if index is None or count is None: + # we don't have enough information to know if the data is complete + # or not + metadata["complete"] = None + else: + # normally we have a strict equality here but XEP-0059 states + # that index MAY be approximative, so just in case… + metadata["complete"] = index + len(items) >= count - return (items, metadata) - - d.addCallback(addMetadata) - return d + return (items, metadata) # @defer.inlineCallbacks # def getItemsFromMany(self, service, data, max_items=None, sub_id=None, rsm=None, profile_key=C.PROF_KEY_NONE): @@ -1100,22 +1090,24 @@ def _subscribe(self, service, nodeIdentifier, options, profile_key=C.PROF_KEY_NONE): client = self.host.getClient(profile_key) service = None if not service else jid.JID(service) - d = self.subscribe(client, service, nodeIdentifier, options=options or None) + d = defer.ensureDeferred( + self.subscribe(client, service, nodeIdentifier, options=options or None) + ) d.addCallback(lambda subscription: subscription.subscriptionIdentifier or "") return d - def subscribe(self, client, service, nodeIdentifier, sub_jid=None, options=None): + async def subscribe(self, client, service, nodeIdentifier, sub_jid=None, options=None): # TODO: reimplement a subscribtion cache, checking that we have not subscription before trying to subscribe - return client.pubsub_client.subscribe( + return await client.pubsub_client.subscribe( service, nodeIdentifier, sub_jid or client.jid.userhostJID(), options=options ) def _unsubscribe(self, service, nodeIdentifier, profile_key=C.PROF_KEY_NONE): client = self.host.getClient(profile_key) service = None if not service else jid.JID(service) - return self.unsubscribe(client, service, nodeIdentifier) + return defer.ensureDeferred(self.unsubscribe(client, service, nodeIdentifier)) - def unsubscribe( + async def unsubscribe( self, client, service, @@ -1124,7 +1116,7 @@ subscriptionIdentifier=None, sender=None, ): - return client.pubsub_client.unsubscribe( + return await client.pubsub_client.unsubscribe( service, nodeIdentifier, sub_jid or client.jid.userhostJID(), @@ -1394,8 +1386,10 @@ client = self.host.getClient(profile_key) deferreds = {} for service, node in node_data: - deferreds[(service, node)] = client.pubsub_client.subscribe( - service, node, subscriber, options=options + deferreds[(service, node)] = defer.ensureDeferred( + client.pubsub_client.subscribe( + service, node, subscriber, options=options + ) ) return self.rt_sessions.newSession(deferreds, client.profile) # found_nodes = yield self.listNodes(service, profile=client.profile) @@ -1475,9 +1469,9 @@ client = self.host.getClient(profile_key) deferreds = {} for service, node in node_data: - deferreds[(service, node)] = self.getItems( + deferreds[(service, node)] = defer.ensureDeferred(self.getItems( client, service, node, max_item, rsm_request=rsm_request, extra=extra - ) + )) return self.rt_sessions.newSession(deferreds, client.profile)