view libervia/backend/plugins/plugin_xep_0444.py @ 4334:111dce64dcb5

plugins XEP-0300, XEP-0446, XEP-0447, XEP-0448 and others: Refactoring to use Pydantic: Pydantic models are used more and more in Libervia, for the bridge API, and also to convert `domish.Element` to internal representation. Type hints have also been added in many places. rel 453
author Goffi <goffi@goffi.org>
date Tue, 03 Dec 2024 00:12:38 +0100
parents 5d2de6c1156d
children
line wrap: on
line source

#!/usr/bin/env python3

# Libervia plugin for XEP-0444
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from typing import Iterable, List

from twisted.internet import defer
from twisted.words.protocols.jabber import jid, xmlstream
from twisted.words.xish import domish
from wokkel import disco, iwokkel
from zope.interface import implementer

import emoji
from libervia.backend.core import exceptions
from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import SatXMPPEntity
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.memory.sqla_mapping import History
from libervia.backend.models.core import MessageReactionData
from libervia.backend.tools.utils import aio
from libervia.backend.memory.sqla import select
from sqlalchemy.orm.attributes import flag_modified

# Module-level logger, named after this module.
log = getLogger(__name__)

# Plugin metadata. Keys are constants of the core ``Const`` class (imported as ``C``).
# NOTE(review): this dict is presumably read by the plugin loader — confirm there.
PLUGIN_INFO = {
    C.PI_NAME: "Message Reactions",
    C.PI_IMPORT_NAME: "XEP-0444",
    C.PI_TYPE: C.PLUG_TYPE_XEP,
    C.PI_MODES: C.PLUG_MODE_BOTH,
    C.PI_PROTOCOLS: ["XEP-0444"],
    # both plugins are retrieved from ``host.plugins`` in ``XEP_0444.__init__``
    C.PI_DEPENDENCIES: ["XEP-0045", "XEP-0334"],
    # name of the main class defined below
    C.PI_MAIN: "XEP_0444",
    C.PI_HANDLER: "yes",
    C.PI_DESCRIPTION: _("""Message Reactions implementation"""),
}

# XML namespace of the <reactions/> element, also advertised through service discovery
# by ``XEP_0444_Handler``.
NS_REACTIONS = "urn:xmpp:reactions:0"


class XEP_0444:
    # TODO: implement and use occupant-ID (XEP-0421), and check sender (see
    #   https://xmpp.org/extensions/xep-0444.html#acceptable-reactions).
    def __init__(self, host):
        """Register the namespace, the bridge method and the message trigger.

        @param host: object giving access to ``plugins``, ``bridge`` and ``trigger``.
        """
        plugin_name = PLUGIN_INFO[C.PI_NAME]
        log.info(f"plugin {plugin_name!r} initialization")
        host.register_namespace("reactions", NS_REACTIONS)
        self.host = host
        # plugins we depend on, see C.PI_DEPENDENCIES in PLUGIN_INFO
        plugins = host.plugins
        self._m = plugins["XEP-0045"]
        self._h = plugins["XEP-0334"]
        # bridge signature: message_id (s), reactions (as), update_type (s), profile (s)
        bridge_options = {
            "in_sign": "sasss",
            "out_sign": "",
            "method": self._reactions_set,
            "async_": True,
        }
        host.bridge.add_method("message_reactions_set", ".plugin", **bridge_options)
        host.trigger.add("message_received", self._message_received_trigger)

    def get_handler(self, client):
        """Build a new service discovery handler for this plugin.

        @param client: unused here.
        @return: a fresh ``XEP_0444_Handler`` instance.
        """
        handler = XEP_0444_Handler()
        return handler

    @aio
    async def get_history_from_reaction_id(
        self, client: SatXMPPEntity, message_elt: domish.Element, reaction_id: str
    ) -> History | None:
        """Retrieve the first history row targeted by a reaction ID.

        Which ID is compared depends on the message type, according to XEP-0444:
        the stanza ID for ``groupchat`` messages, the origin ID for all other types.

        @param message_elt: The message element, used to determine the type of message.
        @param reaction_id: The reaction ID to match in the history.
        @return: the first matching row of this profile, or None if nothing matches.
        """
        storage = self.host.memory.storage
        profile_id = storage.profiles[client.profile]
        async with storage.session() as session:
            is_groupchat = message_elt.getAttribute("type") == C.MESS_TYPE_GROUPCHAT
            id_column = History.stanza_id if is_groupchat else History.origin_id
            query = select(History).where(
                History.profile_id == profile_id, id_column == reaction_id
            )
            rows = await session.execute(query)
            return rows.scalars().first()

    async def _message_received_trigger(
        self,
        client: SatXMPPEntity,
        message_elt: domish.Element,
        post_treat: defer.Deferred,
    ) -> bool:
        """Handle an incoming message which contains a ``<reactions/>`` element.

        Valid reactions (exactly one emoji each) are merged into the history of the
        target message through [update_history_reactions].

        @param message_elt: received ``<message/>`` stanza.
        @param post_treat: unused here.
        @return: True if the message holds no ``<reactions/>`` element, False otherwise
            (whether or not the reactions could be applied).
        """
        reactions_elt = next(message_elt.elements(NS_REACTIONS, "reactions"), None)
        if reactions_elt is None:
            return True

        reaction_id = reactions_elt.getAttribute("id")
        if not reaction_id:
            # The ID is validated before anything else: without it there is nothing to
            # look for, so we avoid a storage lookup with a missing or empty ID.
            log.warning(f"Invalid reaction: {reactions_elt.toXml()}")
            return False

        history = await self.get_history_from_reaction_id(
            client, message_elt, reaction_id
        )
        if history is None:
            log.warning(
                f"Can't find matching message for reaction: {reactions_elt.toXml()}"
            )
            return False

        from_jid = jid.JID(message_elt["from"])
        reactions = set()
        for reaction_elt in reactions_elt.elements("reaction"):
            reaction = str(reaction_elt)
            # a reaction must be made of exactly one emoji, and nothing else
            if not emoji.purely_emoji(reaction) or emoji.emoji_count(reaction) != 1:
                log.warning(f"Ignoring invalid reaction: {reaction_elt.toXml()}.")
                continue
            reactions.add(reaction)
        await self.update_history_reactions(client, history, from_jid, reactions)

        return False

    def _reactions_set(
        self,
        message_id: str,
        reactions: List[str],
        update_type: str = "replace",
        profile: str = C.PROF_KEY_NONE,
    ) -> defer.Deferred[None]:
        """Bridge method wrapping [set_reactions] in a Deferred.

        @param message_id: internal ID of the message to react to.
        @param reactions: reactions to apply.
        @param update_type: how to apply the reactions, see [set_reactions].
        @param profile: profile used to retrieve the client.
        @return: Deferred firing when [set_reactions] is done.
        """
        pending = self.set_reactions(
            self.host.get_client(profile), message_id, reactions, update_type
        )
        return defer.ensureDeferred(pending)

    async def get_history_from_uid(
        self, client: SatXMPPEntity, message_id: str
    ) -> History:
        """Retrieve the history row associated with a given message ID.

        @param message_id: The Libervia specific identifier of the message

        @return: The History instance whose ``uid`` is ``message_id``.

        @raises exceptions.NotFound: The history corresponding to the given message ID is
            not found in the database.
        """
        history = await self.host.memory.storage.get(
            client,
            History,
            History.uid,
            message_id,
        )
        if history is None:
            # this plugin handles reactions, the message must say so (and not mention
            # retraction) to avoid misleading whoever reads the error
            raise exceptions.NotFound(
                f"message to react to not found in database ({message_id})"
            )
        return history

    def send_reactions(
        self,
        client: SatXMPPEntity,
        dest_jid: jid.JID,
        message_type: str,
        message_id: str,
        reactions: Iterable[str],
    ) -> None:
        """Build and send the <message> stanza carrying the reactions.

        @param dest_jid: recipient of the reaction
        @param message_type: value used for the ``type`` attribute of the stanza
        @param message_id: either origin ID or stanza ID
            see https://xmpp.org/extensions/xep-0444.html#business-id
        @param reactions: reactions to send, duplicates are sent only once
        """
        message_elt = domish.Element((None, "message"))
        stanza_attributes = (
            ("from", client.jid.full()),
            ("to", dest_jid.full()),
            ("type", message_type),
        )
        for attribute, value in stanza_attributes:
            message_elt[attribute] = value

        reactions_elt = message_elt.addElement((NS_REACTIONS, "reactions"))
        reactions_elt["id"] = message_id
        # going through a set removes duplicated reactions
        for unique_reaction in set(reactions):
            reactions_elt.addElement("reaction", content=unique_reaction)

        # HINT_STORE comes from the XEP-0334 plugin retrieved in __init__
        hints = [self._h.HINT_STORE]
        self._h.add_hint_elements(message_elt, hints)
        client.send(message_elt)

    def convert_to_replace_update(
        self,
        current_reactions: set[str],
        reactions_new: Iterable[str],
        update_type: str,
    ) -> set[str]:
        """Convert the given update to a replace update.

        @param current_reactions: reaction of reacting JID before update
        @param reactions_new: New reactions to be updated.
        @param update_type: Original type of update (add, remove, toggle, replace).
        @return: reactions for a replace operation.

        @raise ValueError: invalid ``update_type``.
        """
        new_reactions_set = set(reactions_new)

        if update_type == "replace":
            return new_reactions_set

        if update_type == "add":
            replace_reactions = current_reactions | new_reactions_set
        elif update_type == "remove":
            replace_reactions = current_reactions - new_reactions_set
        elif update_type == "toggle":
            replace_reactions = current_reactions ^ new_reactions_set
        else:
            raise ValueError(f"Invalid update type: {update_type!r}")

        return replace_reactions

    async def update_history_reactions(
        self,
        client: SatXMPPEntity,
        history: History,
        reacting_jid: jid.JID,
        reactions_new: Iterable[str],
        update_type: str = "replace",
        store: bool = True,
    ) -> set[str]:
        """Apply a reaction update to a History instance.

        The reactions are kept in ``history.extra["reactions"]``, a mapping from each
        reaction to the list of JIDs (as strings) which have used it.

        @param history: storage History instance to be updated
        @param reacting_jid: author of the reactions
        @param reactions_new: new reactions to update
        @param update_type: Original type of update (add, remove, toggle, replace).
        @param store: if True, update history in storage, and send a `message_update`
            signal with the new reactions.
        @return: set of reactions for this JID for a "replace" update
        """
        # FIXME: handle race conditions
        # the full JID is used as key for groupchat messages, the bare JID otherwise
        is_groupchat = history.type == C.MESS_TYPE_GROUPCHAT
        entity_jid_s = reacting_jid.full() if is_groupchat else reacting_jid.userhost()

        if history.extra is None:
            history.extra = {}
        extra = history.extra
        reactions = extra.get("reactions", {})

        # reactions already set by this entity
        previous = {
            symbol for symbol, reacting in reactions.items() if entity_jid_s in reacting
        }
        wanted = self.convert_to_replace_update(previous, reactions_new, update_type)

        # reactions which are not wanted anymore
        for dropped in previous - wanted:
            reacting = reactions[dropped]
            reacting.remove(entity_jid_s)
            if not reacting:
                # nobody uses this reaction anymore
                del reactions[dropped]

        # reactions which were not set before
        for added in wanted - previous:
            reactions.setdefault(added, []).append(entity_jid_s)

        # we want to have a constant order in reactions
        extra["reactions"] = dict(sorted(reactions.items()))

        if not store:
            return wanted

        # FIXME: this is not a clean way to flag "extra" as modified, but a deepcopy
        # is for some reason not working here.
        flag_modified(history, "extra")
        await self.host.memory.storage.add(history)
        # we send the signal for frontends update
        data = MessageReactionData(reactions=extra["reactions"])
        self.host.bridge.message_update(
            history.uid,
            C.MESS_UPDATE_REACTION,
            data.model_dump_json(),
            client.profile,
        )
        return wanted

    async def set_reactions(
        self,
        client: SatXMPPEntity,
        message_id: str,
        reactions: Iterable[str],
        update_type: str = "replace",
    ) -> None:
        """Update our reactions to a message, and send them to the recipient.

        @param message_id: internal ID of the message
        @param reactions: list of emojis to use to react to the message
            use empty list to remove all reactions
        @param update_type: how to use the reaction to make the update, can be "replace",
            "add", "remove" or "toggle".

        @raise ValueError: ``message_id`` is empty, or ``update_type`` is invalid.
        @raise exceptions.NotFound: no message with this ID is found in history.
        @raise exceptions.DataError: the target message lacks the ID needed to
            reference it in a reaction.
        """
        if not message_id:
            raise ValueError("message_id can't be empty")

        history = await self.get_history_from_uid(client, message_id)
        is_groupchat = history.type == C.MESS_TYPE_GROUPCHAT

        # The ID used to reference the message depends on the message type, see
        # https://xmpp.org/extensions/xep-0444.html#business-id
        if is_groupchat:
            id_name, mess_id = "stanza-id", history.stanza_id
        else:
            id_name, mess_id = "origin-id", history.origin_id

        if mess_id is None:
            # we name the ID which is actually missing for this message type
            raise exceptions.DataError(
                f"target message has no {id_name}, we can't send a reaction"
            )

        # the reaction is sent to the other party of the conversation
        if history.source == client.jid.userhost():
            dest_jid = history.dest_jid
        else:
            dest_jid = history.source_jid

        if is_groupchat:
            # the reaction is for the chat room itself
            dest_jid = dest_jid.userhostJID()
            reacting_jid = self._m.get_room_user_jid(client, dest_jid)
            # we don't store reactions for MUC, as this will be done when we receive back
            # the reaction
            store = False
        else:
            # we use bare JIDs for one2one message, except if the recipient is from a MUC
            if self._m.is_room(client, dest_jid):
                # this is a private message in a room, we need to use the MUC JID
                reacting_jid = self._m.get_room_user_jid(client, dest_jid.userhostJID())
            else:
                # this is a classic one2one message, we need the bare JID
                reacting_jid = client.jid.userhostJID()
                dest_jid = dest_jid.userhostJID()
            store = True

        # whatever the update type is, the full resulting set is sent
        reaction_replace_set = await self.update_history_reactions(
            client, history, reacting_jid, reactions, update_type, store
        )

        self.send_reactions(client, dest_jid, history.type, mess_id, reaction_replace_set)


@implementer(iwokkel.IDisco)
class XEP_0444_Handler(xmlstream.XMPPHandler):
    """Service discovery handler advertising the reactions namespace."""

    def getDiscoInfo(self, requestor, service, nodeIdentifier=""):
        """Return the disco features: the reactions namespace only.

        The arguments are not used.
        """
        reactions_feature = disco.DiscoFeature(NS_REACTIONS)
        return [reactions_feature]

    def getDiscoItems(self, requestor, service, nodeIdentifier=""):
        """Return the disco items: there are none.

        The arguments are not used.
        """
        no_items = []
        return no_items