view libervia/backend/memory/sqla_mapping.py @ 4183:6784d07b99c8

plugin XEP-053, component AP gateway: use the new `trigger.add_with_check` method
author Goffi <goffi@goffi.org>
date Sat, 09 Dec 2023 19:20:13 +0100
parents 2074b2bbe616
children 5f2d496c633f
line wrap: on
line source

#!/usr/bin/env python3

# Libervia: an XMPP client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from datetime import datetime
import enum
import json
import pickle
import time
from typing import Any, Dict

from sqlalchemy import (
    Boolean,
    Column,
    DDL,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Index,
    Integer,
    JSON,
    MetaData,
    text,
    Text,
    UniqueConstraint,
    event,
)
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.sql.functions import now
from sqlalchemy.types import TypeDecorator
from twisted.words.protocols.jabber import jid
from wokkel import generic

from libervia.backend.core.constants import Const as C


# Declarative base shared by every mapped class of this module. The naming convention
# below provides the name templates used for indexes ("ix"), unique constraints ("uq"),
# check constraints ("ck"), foreign keys ("fk") and primary keys ("pk").
Base = declarative_base(
    metadata=MetaData(
        naming_convention={
            "ix": "ix_%(column_0_label)s",
            "uq": "uq_%(table_name)s_%(column_0_name)s",
            "ck": "ck_%(table_name)s_%(constraint_name)s",
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            "pk": "pk_%(table_name)s",
        }
    )
)
# Keys which are present in the "extra" of message data but are not stored in the
# `extra` field, because those values are stored in separate fields.
NOT_IN_EXTRA = ("origin_id", "stanza_id", "received_timestamp", "update_uid")


class Profiles(dict):
    """Dictionary which also maintains a reverse (value to key) index.

    The reverse index is available as the ``id_to_profile`` attribute and is kept in
    sync by every mutating method overridden here.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dict.__init__ doesn't go through __setitem__, so the index is built here
        self.id_to_profile = {v: k for k, v in self.items()}

    def _forget(self, key):
        """Remove the reverse entry of ``key``'s current value.

        @raise KeyError: ``key`` is not in the mapping
        """
        value = self[key]
        # only drop the reverse entry if it really points to this key
        if self.id_to_profile.get(value) == key:
            del self.id_to_profile[value]

    def __setitem__(self, key, value):
        if key in self:
            # the key is overwritten: its former value must not stay in the index
            self._forget(key)
        super().__setitem__(key, value)
        self.id_to_profile[value] = key

    def __delitem__(self, key):
        self._forget(key)
        super().__delitem__(key)

    def update(self, *args, **kwargs):
        super().update(*args, **kwargs)
        # dict.update doesn't go through __setitem__, the index is fully rebuilt
        self.id_to_profile = {v: k for k, v in self.items()}

    def pop(self, key, *default):
        if key in self:
            self._forget(key)
        return super().pop(key, *default)

    def popitem(self):
        key, value = super().popitem()
        if self.id_to_profile.get(value) == key:
            del self.id_to_profile[value]
        return key, value

    def setdefault(self, key, default=None):
        if key not in self:
            self[key] = default
        return self[key]

    def clear(self):
        super().clear()
        self.id_to_profile.clear()


# Module-level mapping shared by this module; its reverse index (`id_to_profile`) is
# what `get_profile_by_id` looks into.
profiles = Profiles()


def get_profile_by_id(profile_id):
    """Return the key of ``profiles`` mapped to ``profile_id``

    @param profile_id: value to look for in the reverse index of ``profiles``
    @return: matching key, or None if ``profile_id`` is not known
    """
    reverse_index = profiles.id_to_profile
    return reverse_index.get(profile_id)


class SyncState(enum.Enum):
    """Synchronisation states, used for the ``sync_state`` column of ``PubsubNode``"""

    #: synchronisation is currently in progress
    IN_PROGRESS = 1
    #: synchronisation is done
    COMPLETED = 2
    #: something wrong happened during synchronisation, won't sync
    ERROR = 3
    #: synchronisation won't be done even if a syncing analyser matches
    NO_SYNC = 4


class SubscriptionState(enum.Enum):
    """Subscription states, used for the ``state`` column of ``PubsubSub``"""

    SUBSCRIBED = 1
    PENDING = 2


class NotificationType(enum.Enum):
    """Kinds of notification, used for the ``type`` column of ``Notification``"""

    chat = "chat"
    blog = "blog"
    calendar = "calendar"
    file = "file"
    call = "call"
    service = "service"
    other = "other"


class NotificationStatus(enum.Enum):
    """Notification statuses, used for the ``status`` column of ``Notification``

    ``new`` is the default value of that column.
    """

    new = "new"
    read = "read"


class NotificationPriority(enum.IntEnum):
    """Notification priority levels (integer values, higher value for higher level)

    ``MEDIUM`` is the default of the ``priority`` column of ``Notification``.
    """

    LOW = 10
    MEDIUM = 20
    HIGH = 30
    URGENT = 40


class LegacyPickle(TypeDecorator):
    """Column type storing pickled Python values, coping with data from older versions

    This type is temporary, until a migration to a proper data type is done.
    """

    # Blob is used on SQLite but gives errors when used here, while Text works fine
    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Pickle ``value`` (with pickle protocol 0) before storing it; None is kept"""
        return None if value is None else pickle.dumps(value, 0)

    def process_result_value(self, value, dialect):
        """Unpickle the stored ``value``; None is returned as is

        NOTE(review): ``pickle.loads`` must only be used with trusted data, this
        assumes that the database content is trusted.
        """
        if value is None:
            return None
        # stored values are not always of the same type (probably a consequence of
        # the Python 2/3 port and/or of SQLite dynamic typing): anything with an
        # ``encode`` method is encoded, other values are used unchanged
        try:
            raw = value.encode()
        except AttributeError:
            raw = value
        # "utf-8" encoding is needed to handle Python 2 pickled data
        try:
            return pickle.loads(raw, encoding="utf-8")
        except ModuleNotFoundError:
            # FIXME: workaround due to package renaming, need to move all pickle code
            #   to JSON
            renamed = raw.replace(b"sat.plugins", b"libervia.backend.plugins")
            return pickle.loads(renamed, encoding="utf-8")


class Json(TypeDecorator):
    """Store JSON serialised values in a text field, in a DB independent way"""

    # Blob is used on SQLite but gives errors when used here, while Text works fine
    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Serialise ``value`` to a JSON string; None is kept as None"""
        return None if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        """Parse the stored JSON string; None is kept as None"""
        return None if value is None else json.loads(value)


class JsonDefaultDict(Json):
    """Json type where a NULL column is loaded as an empty dict instead of None"""

    def process_result_value(self, value, dialect):
        """Parse the stored JSON string, a new empty dict is returned for None"""
        return {} if value is None else json.loads(value)


class Xml(TypeDecorator):
    """Store XML elements in text fields

    Values are serialised with their ``toXml`` method, and parsed back with
    ``generic.parseXml``.
    """

    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Serialise the element with ``toXml``; None is kept as None"""
        return None if value is None else value.toXml()

    def process_result_value(self, value, dialect):
        """Parse the stored text (encoded first); None is kept as None"""
        if value is None:
            return None
        encoded = value.encode()
        return generic.parseXml(encoded)


class JID(TypeDecorator):
    """Column type to store twisted JID in text fields"""

    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Store the result of ``value.full()``; None is kept as None"""
        return None if value is None else value.full()

    def process_result_value(self, value, dialect):
        """Build a ``jid.JID`` from the stored text; None is kept as None"""
        return None if value is None else jid.JID(value)


class Profile(Base):
    """Row of the "profiles" table: an integer id and a unique name"""

    __tablename__ = "profiles"

    id = Column(
        Integer,
        primary_key=True,
        nullable=True,
    )
    name = Column(Text, unique=True)

    # Rows of other tables linked to this profile. Their foreign keys are declared
    # with ondelete="CASCADE", and passive_deletes=True is set on each relationship.
    params = relationship("ParamInd", back_populates="profile", passive_deletes=True)
    private_data = relationship(
        "PrivateInd", back_populates="profile", passive_deletes=True
    )
    private_bin_data = relationship(
        "PrivateIndBin", back_populates="profile", passive_deletes=True
    )


class Component(Base):
    """Row of the "components" table: associates an entry point to a profile"""

    __tablename__ = "components"

    # the profile id is both the primary key and a foreign key to "profiles"
    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE"), nullable=True, primary_key=True
    )
    entry_point = Column(Text, nullable=False)
    profile = relationship("Profile")


class History(Base):
    """Row of the "history" table: metadata of a message

    Bodies, subjects and thread data are stored in their own tables, and are
    available with the ``messages``, ``subjects`` and ``thread`` relationships.

    Source and destination are stored as a bare jid (``source``, ``dest``) plus a
    resource (``source_res``, ``dest_res``). ``source_jid`` and ``dest_jid`` can be
    used (as constructor keyword arguments or as properties) to work with
    ``jid.JID`` instances instead.
    """

    __tablename__ = "history"
    __table_args__ = (
        UniqueConstraint("profile_id", "stanza_id", "source", "dest"),
        UniqueConstraint("profile_id", "origin_id", "source", name="uq_origin_id"),
        Index("history__profile_id_timestamp", "profile_id", "timestamp"),
        Index(
            "history__profile_id_received_timestamp", "profile_id", "received_timestamp"
        ),
    )

    uid = Column(Text, primary_key=True)
    # FIXME: version_id is only needed for changes in `extra` column. It would maybe be
    # better to use separate table for `extra` data instead.
    version_id = Column(Integer, nullable=False, server_default=text("1"))
    origin_id = Column(Text)
    stanza_id = Column(Text)
    update_uid = Column(Text)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))
    source = Column(Text)
    dest = Column(Text)
    source_res = Column(Text)
    dest_res = Column(Text)
    timestamp = Column(Float, nullable=False)
    received_timestamp = Column(Float)
    type = Column(
        Enum(
            "chat",
            "error",
            "groupchat",
            "headline",
            "normal",
            # info is not XMPP standard, but used to keep track of info like join/leave
            # in a MUC
            "info",
            name="message_type",
            create_constraint=True,
        ),
        nullable=False,
    )
    extra = Column(LegacyPickle)

    profile = relationship("Profile")
    messages = relationship(
        "Message",
        backref="history",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )
    subjects = relationship(
        "Subject",
        backref="history",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )
    thread = relationship(
        "Thread",
        uselist=False,
        back_populates="history",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )
    __mapper_args__ = {"version_id_col": version_id}

    def __init__(self, *args, **kwargs):
        """Accept ``source_jid`` and ``dest_jid`` in addition to the column names

        When one of them is given (and is not None), it is split into the
        corresponding bare jid and resource columns.
        """
        source_jid = kwargs.pop("source_jid", None)
        if source_jid is not None:
            kwargs["source"] = source_jid.userhost()
            kwargs["source_res"] = source_jid.resource
        dest_jid = kwargs.pop("dest_jid", None)
        if dest_jid is not None:
            kwargs["dest"] = dest_jid.userhost()
            kwargs["dest_res"] = dest_jid.resource
        super().__init__(*args, **kwargs)

    @property
    def source_jid(self) -> jid.JID:
        """JID built from ``source`` and ``source_res``"""
        return jid.JID(f"{self.source}/{self.source_res or ''}")

    @source_jid.setter
    def source_jid(self, source_jid: jid.JID) -> None:
        # userhost is a method: it must be called (as done in __init__), otherwise
        # the bound method itself would be stored in the column
        self.source = source_jid.userhost()
        self.source_res = source_jid.resource

    @property
    def dest_jid(self) -> jid.JID:
        """JID built from ``dest`` and ``dest_res``"""
        return jid.JID(f"{self.dest}/{self.dest_res or ''}")

    @dest_jid.setter
    def dest_jid(self, dest_jid: jid.JID) -> None:
        # see source_jid setter: userhost must be called
        self.dest = dest_jid.userhost()
        self.dest_res = dest_jid.resource

    def __repr__(self):
        dt = datetime.fromtimestamp(self.timestamp)
        return f"History<{self.source_jid.full()}->{self.dest_jid.full()} [{dt}]>"

    def serialise(self) -> Dict[str, Any]:
        """Convert this history entry to a dict

        Values stored in dedicated columns (``origin_id``, ``stanza_id``,
        ``update_uid``, ``received_timestamp``) and thread data are added to the
        returned "extra" when they are set.

        @return: dict with "from", "to", "uid", "message", "subject", "type", "extra"
            and "timestamp" keys. "message" and "subject" map language (empty string
            if there is none) to text.
        """
        # a copy is used so the keys added below don't end up in the mapped `extra`
        # attribute, those values being stored in their own columns
        extra = dict(self.extra) if self.extra else {}
        if self.origin_id is not None:
            extra["origin_id"] = self.origin_id
        if self.stanza_id is not None:
            extra["stanza_id"] = self.stanza_id
        if self.update_uid is not None:
            extra["update_uid"] = self.update_uid
        if self.received_timestamp is not None:
            extra["received_timestamp"] = self.received_timestamp
        if self.thread is not None:
            extra["thread"] = self.thread.thread_id
            if self.thread.parent_id is not None:
                extra["thread_parent"] = self.thread.parent_id

        return {
            "from": f"{self.source}/{self.source_res}"
            if self.source_res
            else self.source,
            "to": f"{self.dest}/{self.dest_res}" if self.dest_res else self.dest,
            "uid": self.uid,
            "message": {m.language or "": m.message for m in self.messages},
            "subject": {m.language or "": m.subject for m in self.subjects},
            "type": self.type,
            "extra": extra,
            "timestamp": self.timestamp,
        }

    def as_tuple(self):
        """Return serialised data as a tuple

        @return: (uid, timestamp, from, to, message, subject, type, extra)
        """
        d = self.serialise()
        return (
            d["uid"],
            d["timestamp"],
            d["from"],
            d["to"],
            d["message"],
            d["subject"],
            d["type"],
            d["extra"],
        )

    @staticmethod
    def debug_collection(history_collection):
        """Print every entry of ``history_collection``, prefixed with its index"""
        for idx, history in enumerate(history_collection):
            history.debug_msg(idx)

    def debug_msg(self, idx=None):
        """Print messages

        @param idx: if not None, it is printed before the date
        """
        dt = datetime.fromtimestamp(self.timestamp)
        if idx is not None:
            dt = f"({idx}) {dt}"
        parts = []
        parts.append(f"[{dt}]<{self.source_jid.full()}->{self.dest_jid.full()}> ")
        for message in self.messages:
            if message.language:
                parts.append(f"[{message.language}] ")
            parts.append(f"{message.message}\n")
        print("".join(parts))


class Message(Base):
    """Row of the "message" table: a message body of a history entry

    ``language`` is optional.
    """

    __tablename__ = "message"
    __table_args__ = (Index("message__history_uid", "history_uid"),)

    id = Column(
        Integer,
        primary_key=True,
    )
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"), nullable=False)
    message = Column(Text, nullable=False)
    language = Column(Text)

    def serialise(self) -> Dict[str, Any]:
        """Return a dict with "message" and "language", empty values being skipped"""
        fields = (("message", self.message), ("language", self.language))
        return {name: str(content) for name, content in fields if content}

    def __repr__(self):
        prefix = f"[{self.language}]" if self.language else ""
        body = self.message
        # long bodies are truncated to their first 20 characters
        if len(body) > 20:
            body = f"{body[:20]}…"
        return f"Message<{prefix}{body}>"


class Subject(Base):
    """Row of the "subject" table: a subject of a history entry

    ``language`` is optional.
    """

    __tablename__ = "subject"
    __table_args__ = (Index("subject__history_uid", "history_uid"),)

    id = Column(
        Integer,
        primary_key=True,
    )
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"), nullable=False)
    subject = Column(Text, nullable=False)
    language = Column(Text)

    def serialise(self) -> Dict[str, Any]:
        """Return a dict with "subject" and "language", empty values being skipped"""
        fields = (("subject", self.subject), ("language", self.language))
        return {name: str(content) for name, content in fields if content}

    def __repr__(self):
        prefix = f"[{self.language}]" if self.language else ""
        title = self.subject
        # long subjects are truncated to their first 20 characters
        if len(title) > 20:
            title = f"{title[:20]}…"
        return f"Subject<{prefix}{title}>"


class Thread(Base):
    """Row of the "thread" table: thread id and parent thread id of a history entry"""

    __tablename__ = "thread"
    __table_args__ = (Index("thread__history_uid", "history_uid"),)

    id = Column(
        Integer,
        primary_key=True,
    )
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"))
    thread_id = Column(Text)
    parent_id = Column(Text)

    # other side of ``History.thread``; both sides are declared with uselist=False
    history = relationship("History", uselist=False, back_populates="thread")

    def __repr__(self):
        return f"Thread<{self.thread_id} [parent: {self.parent_id}]>"


class Notification(Base):
    """Row of the "notifications" table

    ``profile_id`` is nullable; ``serialise`` reports a NULL ``profile_id`` as
    ``C.PROF_KEY_ALL``.
    """

    __tablename__ = "notifications"
    __table_args__ = (Index("notifications_profile_id_status", "profile_id", "status"),)

    id = Column(Integer, primary_key=True, autoincrement=True)
    # time.time is given as a callable, so it is evaluated for each new row
    timestamp = Column(Float, nullable=False, default=time.time)
    expire_at = Column(Float, nullable=True)

    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE"), index=True, nullable=True
    )
    profile = relationship("Profile")

    type = Column(Enum(NotificationType), nullable=False)

    title = Column(Text, nullable=True)
    body_plain = Column(Text, nullable=False)
    body_rich = Column(Text, nullable=True)

    requires_action = Column(Boolean, default=False)
    priority = Column(Integer, default=NotificationPriority.MEDIUM.value)

    extra_data = Column(JSON)
    status = Column(Enum(NotificationStatus), default=NotificationStatus.new)

    def serialise(self) -> dict[str, str | float | bool | int | dict]:
        """Serialise the Notification instance to a dictionary.

        Columns whose value is None are skipped, except ``profile_id`` which is
        always reported, under the "profile" key.

        @return: column name to value mapping where:
            - "type" and "status" are the names of the enum members
            - "id" is converted to a string
            - "profile" replaces "profile_id": it is ``C.PROF_KEY_ALL`` when
              ``profile_id`` is None, otherwise the result of ``get_profile_by_id``
        """
        result = {}
        for column in self.__table__.columns:
            value = getattr(self, column.name)
            if column.name == "profile_id":
                # this must be checked before skipping None values, otherwise the
                # PROF_KEY_ALL case can never be reached
                result["profile"] = (
                    C.PROF_KEY_ALL if value is None else get_profile_by_id(value)
                )
            elif value is None:
                continue
            elif column.name in ("type", "status"):
                result[column.name] = value.name
            elif column.name == "id":
                result[column.name] = str(value)
            else:
                result[column.name] = value
        return result


class ParamGen(Base):
    """Row of the "param_gen" table: text value identified by (category, name)

    This table has no profile column.
    """

    __tablename__ = "param_gen"

    category = Column(Text, primary_key=True)
    name = Column(Text, primary_key=True)
    value = Column(Text)


class ParamInd(Base):
    """Row of the "param_ind" table: text value identified by (category, name, profile)"""

    __tablename__ = "param_ind"

    category = Column(Text, primary_key=True)
    name = Column(Text, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True)
    value = Column(Text)

    # other side of ``Profile.params``
    profile = relationship("Profile", back_populates="params")


class PrivateGen(Base):
    """Row of the "private_gen" table: text value identified by (namespace, key)

    This table has no profile column.
    """

    __tablename__ = "private_gen"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    value = Column(Text)


class PrivateInd(Base):
    """Row of the "private_ind" table: text value identified by (namespace, key, profile)"""

    __tablename__ = "private_ind"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True)
    value = Column(Text)

    # other side of ``Profile.private_data``
    profile = relationship("Profile", back_populates="private_data")


class PrivateGenBin(Base):
    """Row of the "private_gen_bin" table: pickled value identified by (namespace, key)

    This table has no profile column.
    """

    __tablename__ = "private_gen_bin"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    # the value is pickled/unpickled by the LegacyPickle type
    value = Column(LegacyPickle)


class PrivateIndBin(Base):
    """Row of the "private_ind_bin" table

    Pickled value identified by (namespace, key, profile).
    """

    __tablename__ = "private_ind_bin"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True)
    # the value is pickled/unpickled by the LegacyPickle type
    value = Column(LegacyPickle)

    # other side of ``Profile.private_bin_data``
    profile = relationship("Profile", back_populates="private_bin_data")


class File(Base):
    """Row of the "files" table: metadata of a file or directory

    The primary key is the (id, version) couple.
    """

    __tablename__ = "files"
    __table_args__ = (
        Index("files__profile_id_owner_parent", "profile_id", "owner", "parent"),
        Index(
            "files__profile_id_owner_media_type_media_subtype",
            "profile_id",
            "owner",
            "media_type",
            "media_subtype",
        ),
    )

    id = Column(Text, primary_key=True)
    public_id = Column(Text, unique=True)
    version = Column(Text, primary_key=True)
    parent = Column(Text, nullable=False)
    # either "file" or "directory", "file" being the server side default
    type = Column(
        Enum("file", "directory", name="file_type", create_constraint=True),
        nullable=False,
        server_default="file",
    )
    file_hash = Column(Text)
    hash_algo = Column(Text)
    name = Column(Text, nullable=False)
    size = Column(Integer)
    namespace = Column(Text)
    media_type = Column(Text)
    media_subtype = Column(Text)
    created = Column(Float, nullable=False)
    modified = Column(Float)
    owner = Column(JID)
    # JsonDefaultDict: a NULL value is loaded as an empty dict
    access = Column(JsonDefaultDict)
    extra = Column(JsonDefaultDict)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))

    profile = relationship("Profile")


class PubsubNode(Base):
    """Row of the "pubsub_nodes" table: a pubsub node and its synchronisation state

    A node is unique for a (profile_id, service, name) combination.
    """

    __tablename__ = "pubsub_nodes"
    __table_args__ = (UniqueConstraint("profile_id", "service", "name"),)

    id = Column(Integer, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))
    service = Column(JID)
    name = Column(Text, nullable=False)
    subscribed = Column(
        Boolean(create_constraint=True, name="subscribed_bool"), nullable=False
    )
    analyser = Column(Text)
    sync_state = Column(
        Enum(
            SyncState,
            name="sync_state",
            create_constraint=True,
        ),
        nullable=True,
    )
    # The callable itself must be given as default: with `time.time()`, the default
    # would be computed only once, when this module is imported, and every row would
    # then get the same value.
    sync_state_updated = Column(Float, nullable=False, default=time.time)
    # mapped as `type_` in Python, the column is named "type" in the database
    type_ = Column(Text, name="type", nullable=True)
    subtype = Column(Text, nullable=True)
    extra = Column(JSON)

    items = relationship("PubsubItem", back_populates="node", passive_deletes=True)
    subscriptions = relationship("PubsubSub", back_populates="node", passive_deletes=True)

    def __str__(self):
        return f"Pubsub node {self.name!r} at {self.service}"


class PubsubSub(Base):
    """Subscriptions to pubsub nodes

    Used by components managing a pubsub service
    """

    __tablename__ = "pubsub_subs"
    # a subscriber can only appear once per node
    __table_args__ = (UniqueConstraint("node_id", "subscriber"),)

    id = Column(Integer, primary_key=True)
    node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False)
    subscriber = Column(JID)
    # one of the SubscriptionState members, or NULL
    state = Column(
        Enum(
            SubscriptionState,
            name="state",
            create_constraint=True,
        ),
        nullable=True,
    )

    # other side of ``PubsubNode.subscriptions``
    node = relationship("PubsubNode", back_populates="subscriptions")


class PubsubItem(Base):
    """Row of the "pubsub_items" table: an item of a pubsub node

    An item name is unique inside a node.
    """

    __tablename__ = "pubsub_items"
    __table_args__ = (UniqueConstraint("node_id", "name"),)
    id = Column(Integer, primary_key=True)
    node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False)
    name = Column(Text, nullable=False)
    # XML payload, serialised/parsed by the Xml type
    data = Column(Xml, nullable=False)
    # both dates default to now() on server side, `updated` also uses now() on update
    created = Column(DateTime, nullable=False, server_default=now())
    updated = Column(DateTime, nullable=False, server_default=now(), onupdate=now())
    parsed = Column(JSON)

    # other side of ``PubsubNode.items``
    node = relationship("PubsubNode", back_populates="items")


## Full-Text Search

# create


@event.listens_for(PubsubItem.__table__, "after_create")
def fts_create(target, connection, **kw):
    """Create the Full-Text Search table once "pubsub_items" has been created

    Nothing is done if the engine is not SQLite. With SQLite, a FTS5 virtual table
    using "pubsub_items" as content table is created for the ``data`` column, with
    triggers updating it on insert, delete and update of "pubsub_items" rows.
    """
    if connection.engine.name != "sqlite":
        return
    # Using SQLite FTS5
    statements = (
        "CREATE VIRTUAL TABLE pubsub_items_fts "
        "USING fts5(data, content=pubsub_items, content_rowid=id)",
        "CREATE TRIGGER pubsub_items_fts_sync_ins AFTER INSERT ON pubsub_items BEGIN"
        "  INSERT INTO pubsub_items_fts(rowid, data) VALUES (new.id, new.data);"
        "END",
        "CREATE TRIGGER pubsub_items_fts_sync_del AFTER DELETE ON pubsub_items BEGIN"
        "  INSERT INTO pubsub_items_fts(pubsub_items_fts, rowid, data) "
        "VALUES('delete', old.id, old.data);"
        "END",
        "CREATE TRIGGER pubsub_items_fts_sync_upd AFTER UPDATE ON pubsub_items BEGIN"
        "  INSERT INTO pubsub_items_fts(pubsub_items_fts, rowid, data) VALUES"
        "('delete', old.id, old.data);"
        "  INSERT INTO pubsub_items_fts(rowid, data) VALUES(new.id, new.data);"
        "END",
    )
    for statement in statements:
        connection.execute(DDL(statement))


# drop


@event.listens_for(PubsubItem.__table__, "before_drop")
def fts_drop(target, connection, **kw):
    """Drop the Full-Text Search table before "pubsub_items" is dropped

    Nothing is done if the engine is not SQLite. With SQLite, the three
    synchronisation triggers are dropped, then the FTS table itself.
    """
    if connection.engine.name != "sqlite":
        return
    # Using SQLite FTS5
    statements = (
        "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_ins",
        "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_del",
        "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_upd",
        "DROP TABLE IF EXISTS pubsub_items_fts",
    )
    for statement in statements:
        connection.execute(DDL(statement))