view libervia/backend/memory/sqla_mapping.py @ 4180:b86912d3fd33

plugin IP: fix use of legacy URL + coroutine use: An https://salut-a-toi.org URL was used to retrieve the external IP, but it is no longer valid, resulting in an exception. This feature is currently disabled. Also moved several methods from legacy inline callbacks to coroutines.
author Goffi <goffi@goffi.org>
date Sat, 09 Dec 2023 14:30:54 +0100
parents 2074b2bbe616
children 5f2d496c633f
line wrap: on
line source

#!/usr/bin/env python3

# Libervia: an XMPP client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from datetime import datetime
import enum
import json
import pickle
import time
from typing import Any, Dict

from sqlalchemy import (
    Boolean,
    Column,
    DDL,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Index,
    Integer,
    JSON,
    MetaData,
    text,
    Text,
    UniqueConstraint,
    event,
)
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.sql.functions import now
from sqlalchemy.types import TypeDecorator
from twisted.words.protocols.jabber import jid
from wokkel import generic

from libervia.backend.core.constants import Const as C


# Declarative base shared by every mapped class of this module. The metadata
# carries an explicit naming convention for indexes ("ix"), unique constraints
# ("uq"), check constraints ("ck"), foreign keys ("fk") and primary keys ("pk").
Base = declarative_base(
    metadata=MetaData(
        naming_convention={
            "ix": "ix_%(column_0_label)s",
            "uq": "uq_%(table_name)s_%(column_0_name)s",
            "ck": "ck_%(table_name)s_%(constraint_name)s",
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            "pk": "pk_%(table_name)s",
        }
    )
)
# Keys which are present in message data "extra" but are not stored in the
# `extra` field, because those values are stored in separate fields.
NOT_IN_EXTRA = ("origin_id", "stanza_id", "received_timestamp", "update_uid")


class Profiles(dict):
    """Dict of profiles with a reverse lookup table.

    ``id_to_profile`` maps each value of the dict back to its key. It is kept in
    sync with the dict by every mutating method overridden here.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.id_to_profile = {v: k for k, v in self.items()}

    def _forget_reverse(self, key):
        """Drop the reverse entry pointing to ``key``, if ``key`` is set"""
        if key in self:
            previous = super().__getitem__(key)
            # only remove the entry if it really belongs to this key
            if self.id_to_profile.get(previous) == key:
                del self.id_to_profile[previous]

    def __setitem__(self, key, value):
        # re-assigning an existing key must not leave the reverse entry of its
        # former value behind
        self._forget_reverse(key)
        super().__setitem__(key, value)
        self.id_to_profile[value] = key

    def __delitem__(self, key):
        del self.id_to_profile[self[key]]
        super().__delitem__(key)

    def update(self, *args, **kwargs):
        super().update(*args, **kwargs)
        # dict.update doesn't go through __setitem__, the reverse map is rebuilt
        self.id_to_profile = {v: k for k, v in self.items()}

    def pop(self, key, *default):
        # dict.pop doesn't go through __delitem__
        self._forget_reverse(key)
        return super().pop(key, *default)

    def popitem(self):
        key, value = super().popitem()
        self.id_to_profile.pop(value, None)
        return key, value

    def setdefault(self, key, default=None):
        # dict.setdefault doesn't go through __setitem__
        if key not in self:
            self[key] = default
        return super().__getitem__(key)

    def clear(self):
        super().clear()
        self.id_to_profile.clear()


# Module-level profiles registry; starts empty, reverse lookup is available
# through its `id_to_profile` attribute.
profiles = Profiles()


def get_profile_by_id(profile_id):
    """Retrieve the key of ``profiles`` associated with ``profile_id``

    @param profile_id: value to look for in the reverse mapping of ``profiles``
    @return: the matching key, or None if ``profile_id`` is not known
    """
    reverse_map = profiles.id_to_profile
    return reverse_map.get(profile_id)


class SyncState(enum.Enum):
    """Synchronisation state, used for the ``sync_state`` column of ``PubsubNode``"""

    #: synchronisation is currently in progress
    IN_PROGRESS = 1
    #: synchronisation is done
    COMPLETED = 2
    #: something wrong happened during synchronisation, won't sync
    ERROR = 3
    #: synchronisation won't be done even if a syncing analyser matches
    NO_SYNC = 4


class SubscriptionState(enum.Enum):
    """State of a subscription, used for the ``state`` column of ``PubsubSub``"""

    SUBSCRIBED = 1
    PENDING = 2


class NotificationType(enum.Enum):
    """Kind of notification, used for the ``type`` column of ``Notification``"""

    chat = "chat"
    blog = "blog"
    calendar = "calendar"
    file = "file"
    call = "call"
    service = "service"
    other = "other"


class NotificationStatus(enum.Enum):
    """Status of a notification, used for the ``status`` column of ``Notification``

    ``new`` is the default value of that column.
    """

    new = "new"
    read = "read"


class NotificationPriority(enum.IntEnum):
    """Priority levels of notifications

    The value of ``MEDIUM`` is the default of the ``priority`` column of
    ``Notification``.
    """

    LOW = 10
    MEDIUM = 20
    HIGH = 30
    URGENT = 40


class LegacyPickle(TypeDecorator):
    """Column type for data pickled by former versions of SàT

    Values are pickled on their way to the database and unpickled when read.
    This type is temporary, until data are migrated to a proper data type.
    """

    # Blob is used on SQLite but gives errors when used here, while Text works fine
    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Pickle ``value`` with protocol 0, None is left untouched"""
        return None if value is None else pickle.dumps(value, 0)

    def process_result_value(self, value, dialect):
        """Unpickle ``value`` read from the database, None is left untouched"""
        if value is None:
            return None
        # value types are inconsistent (probably a consequence of Python 2/3 port
        # and/or SQLite dynamic typing): str is encoded, anything without
        # "encode" is used as is
        try:
            raw = value.encode()
        except AttributeError:
            raw = value
        # NOTE(review): pickle.loads executes arbitrary code, stored data must be
        #   trusted
        # "utf-8" encoding is needed to handle Python 2 pickled data
        try:
            return pickle.loads(raw, encoding="utf-8")
        except ModuleNotFoundError:
            # FIXME: workaround due to package renaming, need to move all pickle code to
            #   JSON
            renamed = raw.replace(b"sat.plugins", b"libervia.backend.plugins")
            return pickle.loads(renamed, encoding="utf-8")


class Json(TypeDecorator):
    """Handle JSON fields in a DB independent way

    Values are dumped to a JSON string when written and parsed back when read.
    """

    # Blob is used on SQLite but gives errors when used here, while Text works fine
    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Dump ``value`` to a JSON string, None is left untouched"""
        return json.dumps(value) if value is not None else None

    def process_result_value(self, value, dialect):
        """Parse the stored JSON string, None is left untouched"""
        return json.loads(value) if value is not None else None


class JsonDefaultDict(Json):
    """Json type which converts NULL to an empty dict instead of None"""

    def process_result_value(self, value, dialect):
        """Parse the stored JSON string, a NULL value gives a new empty dict"""
        return {} if value is None else json.loads(value)


class Xml(TypeDecorator):
    """Store XML elements in text fields

    Elements are serialised with their ``toXml`` method when written, and parsed
    with ``generic.parseXml`` when read.
    """

    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Serialise the element to XML, None is left untouched"""
        return value.toXml() if value is not None else None

    def process_result_value(self, value, dialect):
        """Parse the stored XML, None is left untouched"""
        if value is not None:
            # the stored text is encoded before being handed to the parser
            return generic.parseXml(value.encode())
        return None


class JID(TypeDecorator):
    """Store twisted JID in text fields

    The full JID is written to the database, and a ``jid.JID`` instance is built
    from the stored text when read.
    """

    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Convert the JID to its full text form, None is left untouched"""
        return value.full() if value is not None else None

    def process_result_value(self, value, dialect):
        """Build a ``jid.JID`` from the stored text, None is left untouched"""
        return jid.JID(value) if value is not None else None


class Profile(Base):
    """Row of the ``profiles`` table

    Relationships give access to the individual parameters and the private data
    (plain and binary) attached to the profile.
    """

    __tablename__ = "profiles"

    id = Column(
        Integer,
        primary_key=True,
        nullable=True,
    )
    name = Column(Text, unique=True)

    # passive_deletes: the child tables declare ``ondelete="CASCADE"`` on their
    # foreign key to profiles.id
    params = relationship("ParamInd", back_populates="profile", passive_deletes=True)
    private_data = relationship(
        "PrivateInd", back_populates="profile", passive_deletes=True
    )
    private_bin_data = relationship(
        "PrivateIndBin", back_populates="profile", passive_deletes=True
    )


class Component(Base):
    """Row of the ``components`` table: entry point associated with a profile"""

    __tablename__ = "components"

    # the profile id is also the primary key of this table
    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE"), nullable=True, primary_key=True
    )
    entry_point = Column(Text, nullable=False)
    profile = relationship("Profile")


class History(Base):
    """Row of the ``history`` table: one entry of the messages history

    Source and destination are stored split in two columns each: the bare JID
    (``source``/``dest``) and the resource (``source_res``/``dest_res``). The
    ``source_jid`` and ``dest_jid`` properties expose them as ``jid.JID``.
    """

    __tablename__ = "history"
    __table_args__ = (
        UniqueConstraint("profile_id", "stanza_id", "source", "dest"),
        UniqueConstraint("profile_id", "origin_id", "source", name="uq_origin_id"),
        Index("history__profile_id_timestamp", "profile_id", "timestamp"),
        Index(
            "history__profile_id_received_timestamp", "profile_id", "received_timestamp"
        ),
    )

    uid = Column(Text, primary_key=True)
    # FIXME: version_id is only needed for changes in `extra` column. It would maybe be
    # better to use separate table for `extra` data instead.
    version_id = Column(Integer, nullable=False, server_default=text("1"))
    origin_id = Column(Text)
    stanza_id = Column(Text)
    update_uid = Column(Text)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))
    source = Column(Text)
    dest = Column(Text)
    source_res = Column(Text)
    dest_res = Column(Text)
    timestamp = Column(Float, nullable=False)
    received_timestamp = Column(Float)
    type = Column(
        Enum(
            "chat",
            "error",
            "groupchat",
            "headline",
            "normal",
            # info is not XMPP standard, but used to keep track of info like join/leave
            # in a MUC
            "info",
            name="message_type",
            create_constraint=True,
        ),
        nullable=False,
    )
    extra = Column(LegacyPickle)

    profile = relationship("Profile")
    messages = relationship(
        "Message",
        backref="history",
        cascade="all, delete-orphan",
        passive_deletes=True
    )
    subjects = relationship(
        "Subject",
        backref="history",
        cascade="all, delete-orphan",
        passive_deletes=True
    )
    thread = relationship(
        "Thread",
        uselist=False,
        back_populates="history",
        cascade="all, delete-orphan",
        passive_deletes=True
    )
    __mapper_args__ = {"version_id_col": version_id}

    def __init__(self, *args, **kwargs):
        """Build the entry, accepting JIDs in addition to the mapped columns

        @param source_jid: optional keyword, JID split into ``source`` and
            ``source_res``
        @param dest_jid: optional keyword, JID split into ``dest`` and ``dest_res``
        """
        source_jid = kwargs.pop("source_jid", None)
        if source_jid is not None:
            kwargs["source"] = source_jid.userhost()
            kwargs["source_res"] = source_jid.resource
        dest_jid = kwargs.pop("dest_jid", None)
        if dest_jid is not None:
            kwargs["dest"] = dest_jid.userhost()
            kwargs["dest_res"] = dest_jid.resource
        super().__init__(*args, **kwargs)

    @property
    def source_jid(self) -> jid.JID:
        """JID built from the ``source`` and ``source_res`` columns"""
        return jid.JID(f"{self.source}/{self.source_res or ''}")

    @source_jid.setter
    def source_jid(self, source_jid: jid.JID) -> None:
        # userhost is a method: it must be called, like in __init__, otherwise the
        # bound method itself would be stored in the column
        self.source = source_jid.userhost()
        self.source_res = source_jid.resource

    @property
    def dest_jid(self) -> jid.JID:
        """JID built from the ``dest`` and ``dest_res`` columns"""
        return jid.JID(f"{self.dest}/{self.dest_res or ''}")

    @dest_jid.setter
    def dest_jid(self, dest_jid: jid.JID) -> None:
        # see source_jid setter: userhost must be called
        self.dest = dest_jid.userhost()
        self.dest_res = dest_jid.resource

    def __repr__(self):
        dt = datetime.fromtimestamp(self.timestamp)
        return f"History<{self.source_jid.full()}->{self.dest_jid.full()} [{dt}]>"

    def serialise(self):
        """Convert the entry to a dict

        Values stored in dedicated columns (origin_id, stanza_id, update_uid,
        received_timestamp) and thread data are added to the returned "extra".

        @return: dict with "from", "to", "uid", "message", "subject", "type",
            "extra" and "timestamp" keys
        """
        # a copy is used, so the keys added below don't end up in the dict held by
        # the ``extra`` attribute of this instance
        stored_extra = self.extra
        extra = dict(stored_extra) if stored_extra else {}
        if self.origin_id is not None:
            extra["origin_id"] = self.origin_id
        if self.stanza_id is not None:
            extra["stanza_id"] = self.stanza_id
        if self.update_uid is not None:
            extra["update_uid"] = self.update_uid
        if self.received_timestamp is not None:
            extra["received_timestamp"] = self.received_timestamp
        if self.thread is not None:
            extra["thread"] = self.thread.thread_id
            if self.thread.parent_id is not None:
                extra["thread_parent"] = self.thread.parent_id

        return {
            "from": f"{self.source}/{self.source_res}"
            if self.source_res
            else self.source,
            "to": f"{self.dest}/{self.dest_res}" if self.dest_res else self.dest,
            "uid": self.uid,
            "message": {m.language or "": m.message for m in self.messages},
            "subject": {m.language or "": m.subject for m in self.subjects},
            "type": self.type,
            "extra": extra,
            "timestamp": self.timestamp,
        }

    def as_tuple(self):
        """Return serialised data as a tuple

        @return: (uid, timestamp, from, to, message, subject, type, extra)
        """
        d = self.serialise()
        return (
            d["uid"],
            d["timestamp"],
            d["from"],
            d["to"],
            d["message"],
            d["subject"],
            d["type"],
            d["extra"],
        )

    @staticmethod
    def debug_collection(history_collection):
        """Print each entry of the collection, prefixed with its index"""
        for idx, history in enumerate(history_collection):
            history.debug_msg(idx)

    def debug_msg(self, idx=None):
        """Print messages

        @param idx: optional index, printed before the date when set
        """
        dt = datetime.fromtimestamp(self.timestamp)
        if idx is not None:
            dt = f"({idx}) {dt}"
        parts = []
        parts.append(f"[{dt}]<{self.source_jid.full()}->{self.dest_jid.full()}> ")
        for message in self.messages:
            if message.language:
                parts.append(f"[{message.language}] ")
            parts.append(f"{message.message}\n")
        print("".join(parts))


class Message(Base):
    """Row of the ``message`` table: body of a history entry, for one language"""

    __tablename__ = "message"
    __table_args__ = (Index("message__history_uid", "history_uid"),)

    id = Column(Integer, primary_key=True)
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"), nullable=False)
    message = Column(Text, nullable=False)
    language = Column(Text)

    def serialise(self) -> Dict[str, Any]:
        """Convert to a dict, holding only the non empty "message" and "language" """
        candidates = (("message", self.message), ("language", self.language))
        return {key: str(value) for key, value in candidates if value}

    def __repr__(self):
        body = self.message
        if len(body) > 20:
            # long bodies are truncated
            body = f"{body[:20]}…"
        prefix = f"[{self.language}]" if self.language else ""
        return f"Message<{prefix}{body}>"


class Subject(Base):
    """Row of the ``subject`` table: subject of a history entry, for one language"""

    __tablename__ = "subject"
    __table_args__ = (Index("subject__history_uid", "history_uid"),)

    id = Column(Integer, primary_key=True)
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"), nullable=False)
    subject = Column(Text, nullable=False)
    language = Column(Text)

    def serialise(self) -> Dict[str, Any]:
        """Convert to a dict, holding only the non empty "subject" and "language" """
        candidates = (("subject", self.subject), ("language", self.language))
        return {key: str(value) for key, value in candidates if value}

    def __repr__(self):
        title = self.subject
        if len(title) > 20:
            # long subjects are truncated
            title = f"{title[:20]}…"
        prefix = f"[{self.language}]" if self.language else ""
        return f"Subject<{prefix}{title}>"


class Thread(Base):
    """Row of the ``thread`` table: thread id and parent id of a history entry"""

    __tablename__ = "thread"
    __table_args__ = (Index("thread__history_uid", "history_uid"),)

    id = Column(Integer, primary_key=True)
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"))
    thread_id = Column(Text)
    parent_id = Column(Text)

    history = relationship("History", uselist=False, back_populates="thread")

    def __repr__(self):
        parent = f"[parent: {self.parent_id}]"
        return f"Thread<{self.thread_id} {parent}>"


class Notification(Base):
    """Row of the ``notifications`` table

    ``profile_id`` is nullable: such notifications are serialised with
    ``C.PROF_KEY_ALL`` as profile.
    """

    __tablename__ = "notifications"
    __table_args__ = (Index("notifications_profile_id_status", "profile_id", "status"),)

    id = Column(Integer, primary_key=True, autoincrement=True)
    timestamp = Column(Float, nullable=False, default=time.time)
    expire_at = Column(Float, nullable=True)

    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE"), index=True, nullable=True
    )
    profile = relationship("Profile")

    type = Column(Enum(NotificationType), nullable=False)

    title = Column(Text, nullable=True)
    body_plain = Column(Text, nullable=False)
    body_rich = Column(Text, nullable=True)

    requires_action = Column(Boolean, default=False)
    priority = Column(Integer, default=NotificationPriority.MEDIUM.value)

    extra_data = Column(JSON)
    status = Column(Enum(NotificationStatus), default=NotificationStatus.new)

    def serialise(self) -> dict[str, str | float | bool | int | dict]:
        """
        Serialises the Notification instance to a dictionary.

        Columns with a None value are skipped, except ``profile_id`` which is always
        exported, under the "profile" key. "type" and "status" are exported as the
        name of their enum member, and "id" as a string.
        """
        result = {}
        for column in self.__table__.columns:
            name = column.name
            value = getattr(self, name)
            if name == "profile_id":
                # must be handled before None values are skipped, otherwise the
                # PROF_KEY_ALL case can never be reached
                if value is None:
                    result["profile"] = C.PROF_KEY_ALL
                else:
                    result["profile"] = get_profile_by_id(value)
            elif value is None:
                continue
            elif name in ("type", "status"):
                result[name] = value.name
            elif name == "id":
                result[name] = str(value)
            else:
                result[name] = value
        return result


class ParamGen(Base):
    """Row of the ``param_gen`` table: parameter value, not linked to a profile

    The primary key is (category, name).
    """

    __tablename__ = "param_gen"

    category = Column(Text, primary_key=True)
    name = Column(Text, primary_key=True)
    value = Column(Text)


class ParamInd(Base):
    """Row of the ``param_ind`` table: parameter value of a profile

    The primary key is (category, name, profile_id).
    """

    __tablename__ = "param_ind"

    category = Column(Text, primary_key=True)
    name = Column(Text, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True)
    value = Column(Text)

    # other side of ``Profile.params``
    profile = relationship("Profile", back_populates="params")


class PrivateGen(Base):
    """Row of the ``private_gen`` table: text value, not linked to a profile

    The primary key is (namespace, key).
    """

    __tablename__ = "private_gen"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    value = Column(Text)


class PrivateInd(Base):
    """Row of the ``private_ind`` table: text value linked to a profile

    The primary key is (namespace, key, profile_id).
    """

    __tablename__ = "private_ind"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True)
    value = Column(Text)

    # other side of ``Profile.private_data``
    profile = relationship("Profile", back_populates="private_data")


class PrivateGenBin(Base):
    """Row of the ``private_gen_bin`` table: pickled value, not linked to a profile

    The primary key is (namespace, key).
    """

    __tablename__ = "private_gen_bin"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    # pickled/unpickled through LegacyPickle
    value = Column(LegacyPickle)


class PrivateIndBin(Base):
    """Row of the ``private_ind_bin`` table: pickled value linked to a profile

    The primary key is (namespace, key, profile_id).
    """

    __tablename__ = "private_ind_bin"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True)
    # pickled/unpickled through LegacyPickle
    value = Column(LegacyPickle)

    # other side of ``Profile.private_bin_data``
    profile = relationship("Profile", back_populates="private_bin_data")


class File(Base):
    """Row of the ``files`` table: metadata of a file or directory

    The primary key is (id, version).
    """

    __tablename__ = "files"
    __table_args__ = (
        Index("files__profile_id_owner_parent", "profile_id", "owner", "parent"),
        Index(
            "files__profile_id_owner_media_type_media_subtype",
            "profile_id",
            "owner",
            "media_type",
            "media_subtype",
        ),
    )

    id = Column(Text, primary_key=True)
    public_id = Column(Text, unique=True)
    version = Column(Text, primary_key=True)
    parent = Column(Text, nullable=False)
    # either "file" or "directory", "file" being the server side default
    type = Column(
        Enum("file", "directory", name="file_type", create_constraint=True),
        nullable=False,
        server_default="file",
    )
    file_hash = Column(Text)
    hash_algo = Column(Text)
    name = Column(Text, nullable=False)
    size = Column(Integer)
    namespace = Column(Text)
    media_type = Column(Text)
    media_subtype = Column(Text)
    created = Column(Float, nullable=False)
    modified = Column(Float)
    owner = Column(JID)
    # JsonDefaultDict: a NULL value is read as an empty dict
    access = Column(JsonDefaultDict)
    extra = Column(JsonDefaultDict)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))

    profile = relationship("Profile")


class PubsubNode(Base):
    """Row of the ``pubsub_nodes`` table

    A node is unique for a given (profile_id, service, name). Its items and
    subscriptions are available through the ``items`` and ``subscriptions``
    relationships.
    """

    __tablename__ = "pubsub_nodes"
    __table_args__ = (UniqueConstraint("profile_id", "service", "name"),)

    id = Column(Integer, primary_key=True)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))
    service = Column(JID)
    name = Column(Text, nullable=False)
    subscribed = Column(
        Boolean(create_constraint=True, name="subscribed_bool"), nullable=False
    )
    analyser = Column(Text)
    sync_state = Column(
        Enum(
            SyncState,
            name="sync_state",
            create_constraint=True,
        ),
        nullable=True,
    )
    # the callable itself is given, so it is evaluated for each insertion: with
    # ``time.time()`` the default would be frozen to the time of module import
    sync_state_updated = Column(Float, nullable=False, default=time.time)
    type_ = Column(Text, name="type", nullable=True)
    subtype = Column(Text, nullable=True)
    extra = Column(JSON)

    items = relationship("PubsubItem", back_populates="node", passive_deletes=True)
    subscriptions = relationship("PubsubSub", back_populates="node", passive_deletes=True)

    def __str__(self):
        return f"Pubsub node {self.name!r} at {self.service}"


class PubsubSub(Base):
    """Subscriptions to pubsub nodes

    Used by components managing a pubsub service

    A subscriber can only have one subscription per node (unique constraint on
    node_id and subscriber).
    """

    __tablename__ = "pubsub_subs"
    __table_args__ = (UniqueConstraint("node_id", "subscriber"),)

    id = Column(Integer, primary_key=True)
    node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False)
    subscriber = Column(JID)
    state = Column(
        Enum(
            SubscriptionState,
            name="state",
            create_constraint=True,
        ),
        nullable=True,
    )

    # other side of ``PubsubNode.subscriptions``
    node = relationship("PubsubNode", back_populates="subscriptions")


class PubsubItem(Base):
    """Row of the ``pubsub_items`` table: item of a pubsub node

    An item name is unique inside a node (unique constraint on node_id and name).
    """

    __tablename__ = "pubsub_items"
    __table_args__ = (UniqueConstraint("node_id", "name"),)
    id = Column(Integer, primary_key=True)
    node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False)
    name = Column(Text, nullable=False)
    # XML payload, (de)serialised through the Xml type
    data = Column(Xml, nullable=False)
    # both dates default to now() on server side, "updated" is also set to now()
    # on update
    created = Column(DateTime, nullable=False, server_default=now())
    updated = Column(DateTime, nullable=False, server_default=now(), onupdate=now())
    parsed = Column(JSON)

    # other side of ``PubsubNode.items``
    node = relationship("PubsubNode", back_populates="items")


## Full-Text Search

# create


@event.listens_for(PubsubItem.__table__, "after_create")
def fts_create(target, connection, **kw):
    """Create Full-Text Search table and its triggers

    Nothing is done when the engine is not SQLite.
    """
    if connection.engine.name != "sqlite":
        return

    # Using SQLite FTS5
    virtual_table = (
        "CREATE VIRTUAL TABLE pubsub_items_fts "
        "USING fts5(data, content=pubsub_items, content_rowid=id)"
    )
    # the triggers below keep the FTS table in sync with pubsub_items
    trigger_insert = (
        "CREATE TRIGGER pubsub_items_fts_sync_ins AFTER INSERT ON pubsub_items BEGIN"
        "  INSERT INTO pubsub_items_fts(rowid, data) VALUES (new.id, new.data);"
        "END"
    )
    trigger_delete = (
        "CREATE TRIGGER pubsub_items_fts_sync_del AFTER DELETE ON pubsub_items BEGIN"
        "  INSERT INTO pubsub_items_fts(pubsub_items_fts, rowid, data) "
        "VALUES('delete', old.id, old.data);"
        "END"
    )
    trigger_update = (
        "CREATE TRIGGER pubsub_items_fts_sync_upd AFTER UPDATE ON pubsub_items BEGIN"
        "  INSERT INTO pubsub_items_fts(pubsub_items_fts, rowid, data) VALUES"
        "('delete', old.id, old.data);"
        "  INSERT INTO pubsub_items_fts(rowid, data) VALUES(new.id, new.data);"
        "END"
    )
    for statement in (virtual_table, trigger_insert, trigger_delete, trigger_update):
        connection.execute(DDL(statement))


# drop


@event.listens_for(PubsubItem.__table__, "before_drop")
def fts_drop(target, connection, **kw):
    """Drop Full-Text Search table and its triggers

    Nothing is done when the engine is not SQLite.
    """
    if connection.engine.name != "sqlite":
        return

    # Using SQLite FTS5
    # triggers are dropped first, then the FTS table itself
    statements = (
        "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_ins",
        "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_del",
        "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_upd",
        "DROP TABLE IF EXISTS pubsub_items_fts",
    )
    for statement in statements:
        connection.execute(DDL(statement))