diff libervia/backend/plugins/plugin_xep_0308.py @ 4163:3b3cd9453d9b

plugin XEP-0308: implement Last Message Correction
author Goffi <goffi@goffi.org>
date Tue, 28 Nov 2023 17:38:31 +0100
parents
children a1f7040b5a15
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/libervia/backend/plugins/plugin_xep_0308.py	Tue Nov 28 17:38:31 2023 +0100
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+
+# Copyright (C) 2009-2023 Jérôme Poisson (goffi@goffi.org)
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+import time
+
+from sqlalchemy.orm.attributes import flag_modified
+from twisted.internet import defer
+from twisted.words.protocols.jabber import xmlstream
+from twisted.words.protocols.jabber import jid
+from twisted.words.xish import domish
+from wokkel import disco
+from zope.interface import implementer
+
+from libervia.backend.core import exceptions
+from libervia.backend.core.constants import Const as C
+from libervia.backend.core.core_types import SatXMPPEntity
+from libervia.backend.core.i18n import _
+from libervia.backend.core.log import getLogger
+from libervia.backend.memory.sqla import History, Message, Subject, joinedload, select
+from libervia.backend.models.core import MessageEditData, MessageEdition
+from libervia.backend.tools.common import data_format
+from libervia.backend.tools.utils import aio
+log = getLogger(__name__)
+
+
+PLUGIN_INFO = {
+    C.PI_NAME: "Last Message Correction",
+    C.PI_IMPORT_NAME: "XEP-0308",
+    C.PI_TYPE: "XEP",
+    C.PI_PROTOCOLS: ["XEP-0308"],
+    C.PI_DEPENDENCIES: ["XEP-0334"],
+    C.PI_MAIN: "XEP_0308",
+    C.PI_HANDLER: "yes",
+    C.PI_DESCRIPTION: _("""Implementation of XEP-0308 (Last Message Correction)"""),
+}
+
+NS_MESSAGE_CORRECT = "urn:xmpp:message-correct:0"
+
+
+class XEP_0308:
+    def __init__(self, host):
+        log.info(f"plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization")
+        self.host = host
+        host.register_namespace("message_correct", NS_MESSAGE_CORRECT)
+        host.trigger.add("message_received", self._message_received_trigger)
+        host.bridge.add_method(
+            "message_edit",
+            ".plugin",
+            in_sign="sss",
+            out_sign="",
+            method=self._message_edit,
+            async_=True,
+        )
+        self._h = host.plugins["XEP-0334"]
+
+    def get_handler(self, client):
+        return XEP_0308_handler()
+
+    @aio
+    async def get_last_history(
+        self, client: SatXMPPEntity, message_elt: domish.Element
+    ) -> History | None:
+        profile_id = self.host.memory.storage.profiles[client.profile]
+        from_jid = jid.JID(message_elt["from"])
+        message_type = message_elt.getAttribute("type", C.MESS_TYPE_NORMAL)
+        async with self.host.memory.storage.session() as session:
+            stmt = (
+                select(History)
+                .where(
+                    History.profile_id == profile_id,
+                    History.source == from_jid.userhost(),
+                    History.type == message_type,
+                )
+                .options(joinedload(History.messages))
+                .options(joinedload(History.subjects))
+                .options(joinedload(History.thread))
+            )
+            if message_elt.type == C.MESS_TYPE_GROUPCHAT:
+                stmt = stmt.where(History.source_res == from_jid.resource)
+
+            # we want last message
+            stmt = stmt.order_by(History.timestamp.desc()).limit(1)
+            result = await session.execute(stmt)
+            history = result.unique().scalar_one_or_none()
+        return history
+
+    async def update_history(
+        self,
+        client: SatXMPPEntity,
+        edited_history: History,
+        edit_timestamp: float,
+        new_message: dict[str, str],
+        new_subject: dict[str, str],
+        new_extra: dict,
+        previous_message: dict[str, str],
+        previous_subject: dict[str, str],
+        previous_extra: dict | None,
+        store: bool = True,
+    ) -> None:
+        # FIXME: new_extra is not handled by now
+        edited_history.messages = [
+            Message(message=mess, language=lang) for lang, mess in new_message.items()
+        ]
+        edited_history.subjects = [
+            Subject(subject=mess, language=lang) for lang, mess in new_subject.items()
+        ]
+        previous_version = {
+            # this is the timestamp when this version was published
+            "timestamp": edited_history.extra.get("updated", edited_history.timestamp),
+            "message": previous_message,
+            "subject": previous_subject,
+        }
+        edited_history.extra["updated"] = edit_timestamp
+        if previous_extra:
+            previous_extra = previous_extra.copy()
+            # we must not have editions in each edition
+            try:
+                del previous_extra[C.MESS_EXTRA_EDITIONS]
+            except KeyError:
+                pass
+            # extra may be important for rich content
+            previous_version["extra"] = previous_extra
+
+        if store:
+            flag_modified(edited_history, "extra")
+            edited_history.extra.setdefault(C.MESS_EXTRA_EDITIONS, []).append(previous_version)
+            await self.host.memory.storage.add(edited_history)
+
+            edit_data = MessageEditData(edited_history.serialise())
+            self.host.bridge.message_update(
+                edited_history.uid,
+                C.MESS_UPDATE_EDIT,
+                data_format.serialise(edit_data),
+                client.profile,
+            )
+
+    async def _message_received_trigger(
+        self,
+        client: SatXMPPEntity,
+        message_elt: domish.Element,
+        post_treat: defer.Deferred,
+    ) -> bool:
+        replace_elt = next(message_elt.elements(NS_MESSAGE_CORRECT, "replace"), None)
+        if not replace_elt:
+            return True
+        try:
+            replace_id = replace_elt["id"].strip()
+            if not replace_id:
+                raise KeyError
+        except KeyError:
+            log.warning(f"Invalid message correction: {message_elt.toXml()}")
+        else:
+            edited_history = await self.get_last_history(client, message_elt)
+            if edited_history is None:
+                log.warning(
+                    f"No message found from {message_elt['from']}, can't correct "
+                    f"anything: {message_elt.toXml()}"
+                )
+                return False
+            if edited_history.extra.get("message_id") != replace_id:
+                log.warning(
+                    "Can't apply correction: it doesn't reference the last one: "
+                    f"{message_elt.toXml}"
+                )
+                return False
+            previous_message_data = edited_history.serialise()
+            message_data = client.messageProt.parse_message(message_elt)
+            if not message_data["message"] and not message_data["subject"]:
+                log.warning(
+                    "Message correction doesn't have body not subject, we can't edit "
+                    "anything"
+                )
+                return False
+
+            await self.update_history(
+                client,
+                edited_history,
+                message_data.get("received_timestamp") or message_data["timestamp"],
+                message_data["message"],
+                message_data["subject"],
+                message_data["extra"],
+                previous_message_data["message"],
+                previous_message_data["subject"],
+                previous_message_data.get("extra"),
+            )
+
+        return False
+
+    async def message_edit(
+        self,
+        client: SatXMPPEntity,
+        message_id: str,
+        edit_data: MessageEdition,
+    ) -> None:
+        """Edit a message
+
+        The message can only be edited if it's the last one of the discussion.
+        @param client: client instance
+        @param message_id: UID of the message to edit
+        @param edit_data: data to update in the message
+        """
+        timestamp = time.time()
+        edited_history = await self.host.memory.storage.get(
+            client,
+            History,
+            History.uid,
+            message_id,
+            joined_loads=[History.messages, History.subjects, History.thread],
+        )
+        if edited_history is None:
+            raise exceptions.NotFound(
+                f"message to edit not found in database ({message_id})"
+            )
+        if edited_history.type == C.MESS_TYPE_GROUPCHAT:
+            is_group_chat = True
+            peer_jid = edited_history.dest_jid
+        else:
+            is_group_chat = False
+            peer_jid = jid.JID(edited_history.dest)
+        history_data = await self.host.memory.history_get(
+            client.jid, peer_jid, limit=1, profile=client.profile
+        )
+        if not history_data:
+            raise exceptions.NotFound(
+                "No message found in conversation with {peer_jid.full()}"
+            )
+        last_mess = history_data[0]
+        if last_mess[0] != message_id:
+            raise ValueError(
+                f"{message_id} is not the last message of the discussion, we can't edit "
+                "it"
+            )
+
+        await self.update_history(
+            client,
+            edited_history,
+            timestamp,
+            edit_data.message,
+            edit_data.subject,
+            edit_data.extra,
+            last_mess[4],
+            last_mess[5],
+            last_mess[-1],
+            # message will be updated and signal sent on reception in group chat
+            store = not is_group_chat
+        )
+
+        serialised = edited_history.serialise()
+        serialised["from"] = jid.JID(serialised["from"])
+        serialised["to"] = jid.JID(serialised["to"])
+
+        message_elt = client.generate_message_xml(serialised)["xml"]
+        replace_elt = message_elt.addElement((NS_MESSAGE_CORRECT, "replace"))
+        replace_elt["id"] = message_id
+        self._h.add_hint_elements(message_elt, [self._h.HINT_STORE])
+        client.send(message_elt)
+
+    def _message_edit(self, message_id: str, edit_data_s: str, profile: str) -> None:
+        client = self.host.get_client(profile)
+        edit_data = MessageEdition.model_validate_json(edit_data_s)
+        defer.ensureDeferred(self.message_edit(client, message_id, edit_data))
+
+
+@implementer(disco.IDisco)
+class XEP_0308_handler(xmlstream.XMPPHandler):
+    def getDiscoInfo(self, __, target, nodeIdentifier=""):
+        return [disco.DiscoFeature(NS_MESSAGE_CORRECT)]
+
+    def getDiscoItems(self, requestor, target, nodeIdentifier=""):
+        return []