Mercurial > libervia-backend
diff libervia/backend/plugins/plugin_xep_0308.py @ 4163:3b3cd9453d9b
plugin XEP-0308: implement Last Message Correction
author | Goffi <goffi@goffi.org> |
---|---|
date | Tue, 28 Nov 2023 17:38:31 +0100 |
parents | |
children | a1f7040b5a15 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libervia/backend/plugins/plugin_xep_0308.py Tue Nov 28 17:38:31 2023 +0100 @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 + +# Copyright (C) 2009-2023 Jérôme Poisson (goffi@goffi.org) + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. + +import time + +from sqlalchemy.orm.attributes import flag_modified +from twisted.internet import defer +from twisted.words.protocols.jabber import xmlstream +from twisted.words.protocols.jabber import jid +from twisted.words.xish import domish +from wokkel import disco +from zope.interface import implementer + +from libervia.backend.core import exceptions +from libervia.backend.core.constants import Const as C +from libervia.backend.core.core_types import SatXMPPEntity +from libervia.backend.core.i18n import _ +from libervia.backend.core.log import getLogger +from libervia.backend.memory.sqla import History, Message, Subject, joinedload, select +from libervia.backend.models.core import MessageEditData, MessageEdition +from libervia.backend.tools.common import data_format +from libervia.backend.tools.utils import aio +log = getLogger(__name__) + + +PLUGIN_INFO = { + C.PI_NAME: "Last Message Correction", + C.PI_IMPORT_NAME: "XEP-0308", + C.PI_TYPE: "XEP", + C.PI_PROTOCOLS: ["XEP-0308"], + C.PI_DEPENDENCIES: ["XEP-0334"], + C.PI_MAIN: "XEP_0308", + C.PI_HANDLER: "yes", + C.PI_DESCRIPTION: _("""Implementation of XEP-0308 (Last Message Correction)"""), +} + +NS_MESSAGE_CORRECT = "urn:xmpp:message-correct:0" + + +class XEP_0308: + def __init__(self, host): + log.info(f"plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization") + self.host = host + host.register_namespace("message_correct", NS_MESSAGE_CORRECT) + host.trigger.add("message_received", self._message_received_trigger) + host.bridge.add_method( + "message_edit", + ".plugin", + in_sign="sss", + out_sign="", + method=self._message_edit, + async_=True, + ) + self._h = host.plugins["XEP-0334"] + + def get_handler(self, client): + return XEP_0308_handler() + + @aio + async def get_last_history( + self, client: SatXMPPEntity, message_elt: domish.Element + ) -> History | None: + profile_id = self.host.memory.storage.profiles[client.profile] + from_jid = jid.JID(message_elt["from"]) + message_type = message_elt.getAttribute("type", C.MESS_TYPE_NORMAL) + async with self.host.memory.storage.session() as session: + stmt = ( + select(History) + .where( + History.profile_id == profile_id, + History.source == from_jid.userhost(), + History.type == message_type, + ) + .options(joinedload(History.messages)) + .options(joinedload(History.subjects)) + .options(joinedload(History.thread)) + ) + if message_elt.type == C.MESS_TYPE_GROUPCHAT: + stmt = stmt.where(History.source_res == from_jid.resource) + + # we want last message + stmt = stmt.order_by(History.timestamp.desc()).limit(1) + result = await session.execute(stmt) + history = result.unique().scalar_one_or_none() + return history + + async def update_history( + self, + client: SatXMPPEntity, + edited_history: History, + edit_timestamp: float, + new_message: dict[str, str], + new_subject: dict[str, str], + new_extra: dict, + previous_message: dict[str, str], + previous_subject: dict[str, str], + previous_extra: dict | None, + store: bool = True, + ) -> None: + # FIXME: new_extra is not handled by now + edited_history.messages = [ + Message(message=mess, language=lang) for lang, mess in new_message.items() + ] + edited_history.subjects = [ + Subject(subject=mess, language=lang) for lang, mess in new_subject.items() + ] + previous_version = { + # this is the timestamp when this version was published + "timestamp": edited_history.extra.get("updated", edited_history.timestamp), + "message": previous_message, + "subject": previous_subject, + } + edited_history.extra["updated"] = edit_timestamp + if previous_extra: + previous_extra = previous_extra.copy() + # we must not have editions in each edition + try: + del previous_extra[C.MESS_EXTRA_EDITIONS] + except KeyError: + pass + # extra may be important for rich content + previous_version["extra"] = previous_extra + + if store: + flag_modified(edited_history, "extra") + edited_history.extra.setdefault(C.MESS_EXTRA_EDITIONS, []).append(previous_version) + await self.host.memory.storage.add(edited_history) + + edit_data = MessageEditData(edited_history.serialise()) + self.host.bridge.message_update( + edited_history.uid, + C.MESS_UPDATE_EDIT, + data_format.serialise(edit_data), + client.profile, + ) + + async def _message_received_trigger( + self, + client: SatXMPPEntity, + message_elt: domish.Element, + post_treat: defer.Deferred, + ) -> bool: + replace_elt = next(message_elt.elements(NS_MESSAGE_CORRECT, "replace"), None) + if not replace_elt: + return True + try: + replace_id = replace_elt["id"].strip() + if not replace_id: + raise KeyError + except KeyError: + log.warning(f"Invalid message correction: {message_elt.toXml()}") + else: + edited_history = await self.get_last_history(client, message_elt) + if edited_history is None: + log.warning( + f"No message found from {message_elt['from']}, can't correct " + f"anything: {message_elt.toXml()}" + ) + return False + if edited_history.extra.get("message_id") != replace_id: + log.warning( + "Can't apply correction: it doesn't reference the last one: " + f"{message_elt.toXml}" + ) + return False + previous_message_data = edited_history.serialise() + message_data = client.messageProt.parse_message(message_elt) + if not message_data["message"] and not message_data["subject"]: + log.warning( + "Message correction doesn't have body not subject, we can't edit " + "anything" + ) + return False + + await self.update_history( + client, + edited_history, + message_data.get("received_timestamp") or message_data["timestamp"], + message_data["message"], + message_data["subject"], + message_data["extra"], + previous_message_data["message"], + previous_message_data["subject"], + previous_message_data.get("extra"), + ) + + return False + + async def message_edit( + self, + client: SatXMPPEntity, + message_id: str, + edit_data: MessageEdition, + ) -> None: + """Edit a message + + The message can only be edited if it's the last one of the discussion. + @param client: client instance + @param message_id: UID of the message to edit + @param edit_data: data to update in the message + """ + timestamp = time.time() + edited_history = await self.host.memory.storage.get( + client, + History, + History.uid, + message_id, + joined_loads=[History.messages, History.subjects, History.thread], + ) + if edited_history is None: + raise exceptions.NotFound( + f"message to edit not found in database ({message_id})" + ) + if edited_history.type == C.MESS_TYPE_GROUPCHAT: + is_group_chat = True + peer_jid = edited_history.dest_jid + else: + is_group_chat = False + peer_jid = jid.JID(edited_history.dest) + history_data = await self.host.memory.history_get( + client.jid, peer_jid, limit=1, profile=client.profile + ) + if not history_data: + raise exceptions.NotFound( + "No message found in conversation with {peer_jid.full()}" + ) + last_mess = history_data[0] + if last_mess[0] != message_id: + raise ValueError( + f"{message_id} is not the last message of the discussion, we can't edit " + "it" + ) + + await self.update_history( + client, + edited_history, + timestamp, + edit_data.message, + edit_data.subject, + edit_data.extra, + last_mess[4], + last_mess[5], + last_mess[-1], + # message will be updated and signal sent on reception in group chat + store = not is_group_chat + ) + + serialised = edited_history.serialise() + serialised["from"] = jid.JID(serialised["from"]) + serialised["to"] = jid.JID(serialised["to"]) + + message_elt = client.generate_message_xml(serialised)["xml"] + replace_elt = message_elt.addElement((NS_MESSAGE_CORRECT, "replace")) + replace_elt["id"] = message_id + self._h.add_hint_elements(message_elt, [self._h.HINT_STORE]) + client.send(message_elt) + + def _message_edit(self, message_id: str, edit_data_s: str, profile: str) -> None: + client = self.host.get_client(profile) + edit_data = MessageEdition.model_validate_json(edit_data_s) + defer.ensureDeferred(self.message_edit(client, message_id, edit_data)) + + +@implementer(disco.IDisco) +class XEP_0308_handler(xmlstream.XMPPHandler): + def getDiscoInfo(self, __, target, nodeIdentifier=""): + return [disco.DiscoFeature(NS_MESSAGE_CORRECT)] + + def getDiscoItems(self, requestor, target, nodeIdentifier=""): + return []