view libervia/backend/memory/sqla_mapping.py @ 4306:94e0968987cd

plugin XEP-0033: code modernisation, improve delivery, data validation: - Code has been rewritten using Pydantic models and `async` coroutines for data validation and cleaner element parsing/generation. - Delivery has been completely rewritten. It now works even if the server doesn't support multicast, and sends to the local multicast service first. Delivering to the local multicast service first is due to bad support of XEP-0033 in servers (notably Prosody, which has an incomplete implementation), and the current impossibility to detect if a sub-domain service fully handles multicast or only handles it for local domains. This is a workaround to have a good balance between backward compatibility and use of bandwidth, and to make it work with the incoming email gateway implementation (the gateway will only deliver to entities of its own domain). - disco feature checking now uses `async` coroutines. `host` implementation still uses Deferred return values for compatibility with legacy code. rel 450
author Goffi <goffi@goffi.org>
date Thu, 26 Sep 2024 16:12:01 +0200
parents 0d7bb4df2343
children
line wrap: on
line source

#!/usr/bin/env python3

# Libervia: an XMPP client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from datetime import datetime
import enum
import json
import time
from typing import Any, Dict

from sqlalchemy import (
    Boolean,
    Column,
    DDL,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Index,
    Integer,
    JSON,
    MetaData,
    text,
    Text,
    UniqueConstraint,
    event,
)
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.sql.functions import now
from sqlalchemy.types import TypeDecorator
from twisted.words.protocols.jabber import jid
from wokkel import generic

from libervia.backend.core.constants import Const as C


# Declarative base shared by every mapped class of this module. The naming convention
# gives the name templates used for indexes ("ix"), unique ("uq"), check ("ck"),
# foreign key ("fk") and primary key ("pk") constraints.
Base = declarative_base(
    metadata=MetaData(
        naming_convention={
            "ix": "ix_%(column_0_label)s",
            "uq": "uq_%(table_name)s_%(column_0_name)s",
            "ck": "ck_%(table_name)s_%(constraint_name)s",
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            "pk": "pk_%(table_name)s",
        }
    )
)
# Keys which are present in the ``extra`` of message data but are not stored in the
# ``extra`` field, because their values are stored in separate fields.
NOT_IN_EXTRA = ("origin_id", "stanza_id", "received_timestamp", "update_uid")


class Profiles(dict):
    """Dict which maintains a reverse ``value -> key`` index in ``id_to_profile``.

    Every mutation going through the methods below keeps ``id_to_profile`` in sync
    with the dict content, so a value can be resolved back to its key.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.id_to_profile = {v: k for k, v in self.items()}

    def _drop_reverse(self, key):
        """Remove the reverse entry of ``key``, if this key is the one owning it.

        @raise KeyError: ``key`` is not in the dict
        """
        value = self[key]
        # when several keys share a value, the reverse entry may belong to another
        # key: in that case it must be left untouched
        if self.id_to_profile.get(value) == key:
            del self.id_to_profile[value]

    def __setitem__(self, key, value):
        # when a key is overwritten, its previous value must not resolve to it anymore
        if key in self:
            self._drop_reverse(key)
        super().__setitem__(key, value)
        self.id_to_profile[value] = key

    def __delitem__(self, key):
        self._drop_reverse(key)
        super().__delitem__(key)

    def pop(self, key, *default):
        """Same as ``dict.pop``, with the reverse index kept in sync"""
        if key in self:
            self._drop_reverse(key)
        return super().pop(key, *default)

    def popitem(self):
        """Same as ``dict.popitem``, with the reverse index kept in sync"""
        key, value = super().popitem()
        if self.id_to_profile.get(value) == key:
            del self.id_to_profile[value]
        return key, value

    def setdefault(self, key, default=None):
        """Same as ``dict.setdefault``, with the reverse index kept in sync"""
        if key not in self:
            self[key] = default
        return self[key]

    def update(self, *args, **kwargs):
        super().update(*args, **kwargs)
        # ``dict.update`` doesn't go through ``__setitem__``: rebuild the whole index
        self.id_to_profile = {v: k for k, v in self.items()}

    def clear(self):
        super().clear()
        self.id_to_profile.clear()


# Module-level Profiles instance; get_profile_by_id() reads its reverse index.
# presumably maps profile names to profile ids — TODO confirm against the code
# filling it
profiles = Profiles()


def get_profile_by_id(profile_id):
    """Return the key registered in ``profiles`` for ``profile_id``, None if unknown"""
    reverse_index = profiles.id_to_profile
    return reverse_index.get(profile_id)


class SyncState(enum.Enum):
    """Synchronisation state, used for the ``sync_state`` column of PubsubNode"""

    #: synchronisation is currently in progress
    IN_PROGRESS = 1
    #: synchronisation is done
    COMPLETED = 2
    #: something wrong happened during synchronisation, won't sync
    ERROR = 3
    #: synchronisation won't be done even if a syncing analyser matches
    NO_SYNC = 4


class SubscriptionState(enum.Enum):
    """Subscription state, used for the ``state`` column of PubsubSub"""

    SUBSCRIBED = 1
    PENDING = 2


class NotificationType(enum.Enum):
    """Kind of notification, used for the ``type`` column of Notification"""

    chat = "chat"
    blog = "blog"
    calendar = "calendar"
    file = "file"
    call = "call"
    service = "service"
    other = "other"


class NotificationStatus(enum.Enum):
    """Status of a notification, used for the ``status`` column of Notification"""

    new = "new"
    read = "read"


class NotificationPriority(enum.IntEnum):
    """Priority levels; ``MEDIUM`` is the default of Notification ``priority`` column"""

    LOW = 10
    MEDIUM = 20
    HIGH = 30
    URGENT = 40


class Json(TypeDecorator):
    """Handle JSON field in a DB independent way

    Values are dumped to a JSON string when bound, and parsed back when loaded.
    """

    # Blob is used on SQLite but gives errors when used here, while Text works fine
    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Dump ``value`` to a JSON string, None is passed through"""
        return None if value is None else json.dumps(value, ensure_ascii=False)

    def process_result_value(self, value, dialect):
        """Parse the stored JSON string, None is passed through"""
        return None if value is None else json.loads(value)


class JsonDefaultDict(Json):
    """Json type which converts NULL to an empty dict instead of None"""

    def process_result_value(self, value, dialect):
        """Parse the stored JSON string, a NULL value gives a new empty dict"""
        return {} if value is None else json.loads(value)


class Xml(TypeDecorator):
    """Store XML elements in text fields

    The value is serialised with its ``toXml`` method when bound, and parsed back
    with ``generic.parseXml`` when loaded.
    """

    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Serialise the element to its XML string, None is passed through"""
        return None if value is None else value.toXml()

    def process_result_value(self, value, dialect):
        """Parse the stored XML string, None is passed through"""
        return None if value is None else generic.parseXml(value.encode())


class JID(TypeDecorator):
    """Store twisted JID in text fields"""

    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Convert the JID to its full string form, None is passed through"""
        return None if value is None else value.full()

    def process_result_value(self, value, dialect):
        """Build a ``jid.JID`` from the stored string, None is passed through"""
        return None if value is None else jid.JID(value)


class Profile(Base):
    """Mapping of the "profiles" table: a profile with a unique name"""

    __tablename__ = "profiles"

    # NOTE(review): ``nullable=True`` on a primary key is unusual, presumably kept to
    # match an already existing schema — confirm before changing
    id = Column(
        Integer,
        primary_key=True,
        nullable=True,
    )
    name = Column(Text, unique=True)

    # ``passive_deletes=True``: deletion of the related rows is left to the database
    # (the foreign keys of the related tables use ON DELETE CASCADE)
    params = relationship("ParamInd", back_populates="profile", passive_deletes=True)
    private_data = relationship(
        "PrivateInd", back_populates="profile", passive_deletes=True
    )
    private_bin_data = relationship(
        "PrivateIndBin", back_populates="profile", passive_deletes=True
    )


class Component(Base):
    """Mapping of the "components" table: entry point associated with a profile"""

    __tablename__ = "components"

    # the profile id is at the same time the primary key and a foreign key
    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE"), nullable=True, primary_key=True
    )
    entry_point = Column(Text, nullable=False)
    profile = relationship("Profile")


class History(Base):
    """Mapping of the "history" table: an entry of the message history

    Bare JIDs and resources are stored in separate columns (``source``/``source_res``
    and ``dest``/``dest_res``); the ``source_jid`` and ``dest_jid`` properties expose
    them as ``jid.JID`` instances.
    """

    __tablename__ = "history"
    __table_args__ = (
        UniqueConstraint("profile_id", "stanza_id", "source", "dest"),
        UniqueConstraint("profile_id", "origin_id", "source", name="uq_origin_id"),
        Index("history__profile_id_timestamp", "profile_id", "timestamp"),
        Index(
            "history__profile_id_received_timestamp", "profile_id", "received_timestamp"
        ),
    )

    uid = Column(Text, primary_key=True)
    # FIXME: version_id is only needed for changes in `extra` column. It would maybe be
    # better to use separate table for `extra` data instead.
    version_id = Column(Integer, nullable=False, server_default=text("1"))
    origin_id = Column(Text)
    stanza_id = Column(Text)
    update_uid = Column(Text)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))
    source = Column(Text)
    dest = Column(Text)
    source_res = Column(Text)
    dest_res = Column(Text)
    timestamp = Column(Float, nullable=False)
    received_timestamp = Column(Float)
    type = Column(
        Enum(
            "chat",
            "error",
            "groupchat",
            "headline",
            "normal",
            # info is not XMPP standard, but used to keep track of info like join/leave
            # in a MUC
            "info",
            name="message_type",
            create_constraint=True,
        ),
        nullable=False,
    )
    extra = Column(JSON)

    profile = relationship("Profile")
    messages = relationship(
        "Message", backref="history", cascade="all, delete-orphan", passive_deletes=True
    )
    subjects = relationship(
        "Subject", backref="history", cascade="all, delete-orphan", passive_deletes=True
    )
    thread = relationship(
        "Thread",
        uselist=False,
        back_populates="history",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )
    __mapper_args__ = {"version_id_col": version_id}

    def __init__(self, *args, **kwargs):
        """Accept ``source_jid`` and ``dest_jid`` in addition to the mapped columns

        When given (and not None), each JID is split into its bare JID and resource
        columns before being passed to the declarative constructor.
        """
        source_jid = kwargs.pop("source_jid", None)
        if source_jid is not None:
            kwargs["source"] = source_jid.userhost()
            kwargs["source_res"] = source_jid.resource
        dest_jid = kwargs.pop("dest_jid", None)
        if dest_jid is not None:
            kwargs["dest"] = dest_jid.userhost()
            kwargs["dest_res"] = dest_jid.resource
        super().__init__(*args, **kwargs)

    @property
    def source_jid(self) -> jid.JID:
        """JID built from the ``source`` and ``source_res`` columns"""
        return jid.JID(f"{self.source}/{self.source_res or ''}")

    @source_jid.setter
    def source_jid(self, source_jid: jid.JID) -> None:
        # ``userhost`` is a method: it must be called, otherwise the bound method
        # itself would be stored in the text column
        self.source = source_jid.userhost()
        self.source_res = source_jid.resource

    @property
    def dest_jid(self) -> jid.JID:
        """JID built from the ``dest`` and ``dest_res`` columns"""
        return jid.JID(f"{self.dest}/{self.dest_res or ''}")

    @dest_jid.setter
    def dest_jid(self, dest_jid: jid.JID) -> None:
        # same as for ``source_jid``: ``userhost`` must be called
        self.dest = dest_jid.userhost()
        self.dest_res = dest_jid.resource

    def __repr__(self):
        dt = datetime.fromtimestamp(self.timestamp)
        return f"History<{self.source_jid.full()}->{self.dest_jid.full()} [{dt}]>"

    def serialise(self):
        """Convert this entry to a dict

        The values stored in separate columns (``origin_id``, ``stanza_id``,
        ``update_uid``, ``received_timestamp``) and the thread data are added to the
        returned ``extra`` when they are set.

        @return: dict with the "from", "to", "uid", "message", "subject", "type",
            "extra" and "timestamp" keys
        """
        # work on a copy: the keys added below are stored in their own columns and
        # must not end up in the ``extra`` column value itself
        extra = dict(self.extra) if self.extra else {}
        if self.origin_id is not None:
            extra["origin_id"] = self.origin_id
        if self.stanza_id is not None:
            extra["stanza_id"] = self.stanza_id
        if self.update_uid is not None:
            extra["update_uid"] = self.update_uid
        if self.received_timestamp is not None:
            extra["received_timestamp"] = self.received_timestamp
        if self.thread is not None:
            extra["thread"] = self.thread.thread_id
            if self.thread.parent_id is not None:
                extra["thread_parent"] = self.thread.parent_id

        return {
            "from": (
                f"{self.source}/{self.source_res}" if self.source_res else self.source
            ),
            "to": f"{self.dest}/{self.dest_res}" if self.dest_res else self.dest,
            "uid": self.uid,
            # a missing language is serialised with an empty string as key
            "message": {m.language or "": m.message for m in self.messages},
            "subject": {m.language or "": m.subject for m in self.subjects},
            "type": self.type,
            "extra": extra,
            "timestamp": self.timestamp,
        }

    def as_tuple(self):
        """Return the serialised data as a tuple

        @return: (uid, timestamp, from, to, message, subject, type, extra)
        """
        d = self.serialise()
        return (
            d["uid"],
            d["timestamp"],
            d["from"],
            d["to"],
            d["message"],
            d["subject"],
            d["type"],
            d["extra"],
        )

    @staticmethod
    def debug_collection(history_collection):
        """Print the messages of each History of the collection, with its index"""
        for idx, history in enumerate(history_collection):
            history.debug_msg(idx)

    def debug_msg(self, idx=None):
        """Print messages

        @param idx: index prepended to the date when not None
        """
        dt = datetime.fromtimestamp(self.timestamp)
        if idx is not None:
            dt = f"({idx}) {dt}"
        parts = []
        parts.append(f"[{dt}]<{self.source_jid.full()}->{self.dest_jid.full()}> ")
        for message in self.messages:
            if message.language:
                parts.append(f"[{message.language}] ")
            parts.append(f"{message.message}\n")
        print("".join(parts))


class Message(Base):
    """Mapping of the "message" table: a message body attached to a History entry"""

    __tablename__ = "message"
    __table_args__ = (Index("message__history_uid", "history_uid"),)

    id = Column(Integer, primary_key=True)
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"), nullable=False)
    message = Column(Text, nullable=False)
    language = Column(Text)

    def serialise(self) -> Dict[str, Any]:
        """Return a dict with the "message" and "language" keys, when they are set

        Empty or unset values are skipped.
        """
        candidates = (("message", self.message), ("language", self.language))
        return {key: str(value) for key, value in candidates if value}

    def __repr__(self):
        excerpt = self.message
        # long messages are truncated to their first 20 characters
        if len(excerpt) > 20:
            excerpt = f"{excerpt[:20]}…"
        prefix = f"[{self.language}]" if self.language else ""
        return f"Message<{prefix}{excerpt}>"


class Subject(Base):
    """Mapping of the "subject" table: a subject attached to a History entry"""

    __tablename__ = "subject"
    __table_args__ = (Index("subject__history_uid", "history_uid"),)

    id = Column(Integer, primary_key=True)
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"), nullable=False)
    subject = Column(Text, nullable=False)
    language = Column(Text)

    def serialise(self) -> Dict[str, Any]:
        """Return a dict with the "subject" and "language" keys, when they are set

        Empty or unset values are skipped.
        """
        candidates = (("subject", self.subject), ("language", self.language))
        return {key: str(value) for key, value in candidates if value}

    def __repr__(self):
        excerpt = self.subject
        # long subjects are truncated to their first 20 characters
        if len(excerpt) > 20:
            excerpt = f"{excerpt[:20]}…"
        prefix = f"[{self.language}]" if self.language else ""
        return f"Subject<{prefix}{excerpt}>"


class Thread(Base):
    """Mapping of the "thread" table: thread data attached to a History entry"""

    __tablename__ = "thread"
    __table_args__ = (Index("thread__history_uid", "history_uid"),)

    id = Column(
        Integer,
        primary_key=True,
    )
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"))
    thread_id = Column(Text)
    parent_id = Column(Text)

    # scalar counterpart of the ``History.thread`` relationship
    history = relationship("History", uselist=False, back_populates="thread")

    def __repr__(self):
        return f"Thread<{self.thread_id} [parent: {self.parent_id}]>"


class Notification(Base):
    """Mapping of the "notifications" table"""

    __tablename__ = "notifications"
    __table_args__ = (Index("notifications_profile_id_status", "profile_id", "status"),)

    id = Column(Integer, primary_key=True, autoincrement=True)
    # the callable is evaluated at each insertion
    timestamp = Column(Float, nullable=False, default=time.time)
    expire_at = Column(Float, nullable=True)

    # a NULL profile_id is serialised as ``C.PROF_KEY_ALL`` (see ``serialise``)
    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE"), index=True, nullable=True
    )
    profile = relationship("Profile")

    type = Column(Enum(NotificationType), nullable=False)

    title = Column(Text, nullable=True)
    body_plain = Column(Text, nullable=False)
    body_rich = Column(Text, nullable=True)

    requires_action = Column(Boolean, default=False)
    priority = Column(Integer, default=NotificationPriority.MEDIUM.value)

    extra_data = Column(JSON)
    status = Column(Enum(NotificationStatus), default=NotificationStatus.new)

    def serialise(self) -> dict[str, str | float | bool | int | dict]:
        """
        Serialises the Notification instance to a dictionary.

        Columns with a None value are skipped, except ``profile_id`` which is always
        serialised under the "profile" key: ``C.PROF_KEY_ALL`` when it is None, the
        result of ``get_profile_by_id`` otherwise. ``type`` and ``status`` are
        serialised with the name of their enum member, and ``id`` as a string.
        """
        result = {}
        for column in self.__table__.columns:
            name = column.name
            value = getattr(self, name)
            if name == "profile_id":
                # this must be checked before the generic None filter below, otherwise
                # the ``C.PROF_KEY_ALL`` case can never be reached
                if value is None:
                    result["profile"] = C.PROF_KEY_ALL
                else:
                    result["profile"] = get_profile_by_id(value)
            elif value is None:
                continue
            elif name in ("type", "status"):
                result[name] = value.name
            elif name == "id":
                result[name] = str(value)
            else:
                result[name] = value
        return result


class ParamGen(Base):
    """Mapping of the "param_gen" table: one value per (category, name)"""

    __tablename__ = "param_gen"

    category = Column(Text, primary_key=True)
    name = Column(Text, primary_key=True)
    value = Column(Text)


class ParamInd(Base):
    """Mapping of the "param_ind" table: one value per (category, name, profile)"""

    __tablename__ = "param_ind"

    category = Column(Text, primary_key=True)
    name = Column(Text, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True)
    value = Column(Text)

    # counterpart of the ``Profile.params`` relationship
    profile = relationship("Profile", back_populates="params")


class PrivateGen(Base):
    """Mapping of the "private_gen" table: one text value per (namespace, key)"""

    __tablename__ = "private_gen"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    value = Column(Text)


class PrivateInd(Base):
    """Mapping of the "private_ind" table: one text value per (namespace, key, profile)"""

    __tablename__ = "private_ind"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True)
    value = Column(Text)

    # counterpart of the ``Profile.private_data`` relationship
    profile = relationship("Profile", back_populates="private_data")


class PrivateGenBin(Base):
    """Mapping of the "private_gen_bin" table: one JSON value per (namespace, key)"""

    __tablename__ = "private_gen_bin"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    value = Column(JSON)


class PrivateIndBin(Base):
    """Mapping of the "private_ind_bin" table

    One JSON value per (namespace, key, profile).
    """

    __tablename__ = "private_ind_bin"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True)
    value = Column(JSON)

    # counterpart of the ``Profile.private_bin_data`` relationship
    profile = relationship("Profile", back_populates="private_bin_data")


class File(Base):
    """Mapping of the "files" table: metadata of a file or a directory

    The primary key is the (id, version) pair.
    """

    __tablename__ = "files"
    __table_args__ = (
        Index("files__profile_id_owner_parent", "profile_id", "owner", "parent"),
        Index(
            "files__profile_id_owner_media_type_media_subtype",
            "profile_id",
            "owner",
            "media_type",
            "media_subtype",
        ),
    )

    id = Column(Text, primary_key=True)
    public_id = Column(Text, unique=True)
    version = Column(Text, primary_key=True)
    parent = Column(Text, nullable=False)
    type = Column(
        Enum("file", "directory", name="file_type", create_constraint=True),
        nullable=False,
        server_default="file",
    )
    file_hash = Column(Text)
    hash_algo = Column(Text)
    name = Column(Text, nullable=False)
    size = Column(Integer)
    namespace = Column(Text)
    media_type = Column(Text)
    media_subtype = Column(Text)
    created = Column(Float, nullable=False)
    modified = Column(Float)
    owner = Column(JID)
    # JsonDefaultDict: a NULL value is loaded as an empty dict
    access = Column(JsonDefaultDict)
    extra = Column(JsonDefaultDict)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))

    profile = relationship("Profile")


class PubsubNode(Base):
    """Mapping of the "pubsub_nodes" table

    A node is unique for a given (profile, service, name).
    """

    __tablename__ = "pubsub_nodes"
    __table_args__ = (UniqueConstraint("profile_id", "service", "name"),)

    id = Column(Integer, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))
    service = Column(JID)
    name = Column(Text, nullable=False)
    subscribed = Column(
        Boolean(create_constraint=True, name="subscribed_bool"), nullable=False
    )
    analyser = Column(Text)
    sync_state = Column(
        Enum(
            SyncState,
            name="sync_state",
            create_constraint=True,
        ),
        nullable=True,
    )
    # the callable itself is given as default (and not its result), so that it is
    # evaluated at each insertion instead of once when this module is imported
    sync_state_updated = Column(Float, nullable=False, default=time.time)
    # the attribute is named ``type_``, the column is named "type"
    type_ = Column(Text, name="type", nullable=True)
    subtype = Column(Text, nullable=True)
    extra = Column(JSON)

    # ``passive_deletes=True``: deletion of items and subscriptions is left to the
    # database (their foreign keys use ON DELETE CASCADE)
    items = relationship("PubsubItem", back_populates="node", passive_deletes=True)
    subscriptions = relationship("PubsubSub", back_populates="node", passive_deletes=True)

    def __str__(self):
        return f"Pubsub node {self.name!r} at {self.service}"


class PubsubSub(Base):
    """Subscriptions to pubsub nodes

    Used by components managing a pubsub service
    """

    __tablename__ = "pubsub_subs"
    # a subscriber can only appear once for a given node
    __table_args__ = (UniqueConstraint("node_id", "subscriber"),)

    id = Column(Integer, primary_key=True)
    node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False)
    subscriber = Column(JID)
    state = Column(
        Enum(
            SubscriptionState,
            name="state",
            create_constraint=True,
        ),
        nullable=True,
    )

    # counterpart of the ``PubsubNode.subscriptions`` relationship
    node = relationship("PubsubNode", back_populates="subscriptions")


class PubsubItem(Base):
    """Mapping of the "pubsub_items" table: an item of a pubsub node

    The item name is unique inside a node, and the item payload is stored as XML.
    """

    __tablename__ = "pubsub_items"
    __table_args__ = (UniqueConstraint("node_id", "name"),)
    id = Column(Integer, primary_key=True)
    node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False)
    name = Column(Text, nullable=False)
    data = Column(Xml, nullable=False)
    # both dates default to the current time on the database side, ``updated`` is
    # also refreshed on each update
    created = Column(DateTime, nullable=False, server_default=now())
    updated = Column(DateTime, nullable=False, server_default=now(), onupdate=now())
    parsed = Column(JSON)

    # counterpart of the ``PubsubNode.items`` relationship
    node = relationship("PubsubNode", back_populates="items")


## Full-Text Search

# create


@event.listens_for(PubsubItem.__table__, "after_create")
def fts_create(target, connection, **kw):
    """Full-Text Search table creation

    Run after the creation of the "pubsub_items" table. On SQLite, this creates the
    FTS virtual table and the triggers keeping it in sync on insert, delete and
    update. Nothing is done with other engines.
    """
    if connection.engine.name != "sqlite":
        return

    # Using SQLite FTS5
    statements = (
        "CREATE VIRTUAL TABLE pubsub_items_fts "
        "USING fts5(data, content=pubsub_items, content_rowid=id)",
        "CREATE TRIGGER pubsub_items_fts_sync_ins AFTER INSERT ON pubsub_items BEGIN"
        "  INSERT INTO pubsub_items_fts(rowid, data) VALUES (new.id, new.data);"
        "END",
        "CREATE TRIGGER pubsub_items_fts_sync_del AFTER DELETE ON pubsub_items BEGIN"
        "  INSERT INTO pubsub_items_fts(pubsub_items_fts, rowid, data) "
        "VALUES('delete', old.id, old.data);"
        "END",
        "CREATE TRIGGER pubsub_items_fts_sync_upd AFTER UPDATE ON pubsub_items BEGIN"
        "  INSERT INTO pubsub_items_fts(pubsub_items_fts, rowid, data) VALUES"
        "('delete', old.id, old.data);"
        "  INSERT INTO pubsub_items_fts(rowid, data) VALUES(new.id, new.data);"
        "END",
    )
    for statement in statements:
        connection.execute(DDL(statement))


# drop


@event.listens_for(PubsubItem.__table__, "before_drop")
def fts_drop(target, connection, **kw):
    """Full-Text Search table drop

    Run before the "pubsub_items" table is dropped. On SQLite, this drops the sync
    triggers, then the FTS virtual table. Nothing is done with other engines.
    """
    if connection.engine.name != "sqlite":
        return

    # Using SQLite FTS5
    statements = (
        "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_ins",
        "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_del",
        "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_upd",
        "DROP TABLE IF EXISTS pubsub_items_fts",
    )
    for statement in statements:
        connection.execute(DDL(statement))