#!/usr/bin/env python3

# Libervia: an XMPP client
# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from typing import Dict, Any
from datetime import datetime
import enum
import json
import pickle
import time

from sqlalchemy import (
    Boolean,
    Column,
    DDL,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Index,
    Integer,
    JSON,
    MetaData,
    Text,
    UniqueConstraint,
    event,
)
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy.sql.functions import now
from sqlalchemy.types import TypeDecorator
from twisted.words.protocols.jabber import jid
from wokkel import generic


Base = declarative_base(
    metadata=MetaData(
        naming_convention={
            "ix": 'ix_%(column_0_label)s',
            "uq": "uq_%(table_name)s_%(column_0_name)s",
            "ck": "ck_%(table_name)s_%(constraint_name)s",
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            "pk": "pk_%(table_name)s"
        }
    )
)
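
# With the naming convention above, constraint names are generated
# deterministically: e.g. the unique constraint on Profile.name below is named
# "uq_profiles_name". Deterministic names are notably useful for schema
# migrations, where anonymous constraints can't easily be referenced.
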
# Keys which are in message data "extra" but are not stored in the "extra"
# column: those values are stored in separate columns instead.
NOT_IN_EXTRA = ('origin_id', 'stanza_id', 'received_timestamp', 'update_uid')


class SyncState(enum.Enum):
    #: synchronisation is currently in progress
    IN_PROGRESS = 1
    #: synchronisation is done
    COMPLETED = 2
    #: something went wrong during synchronisation; syncing won't be attempted again
    ERROR = 3
    #: synchronisation won't be done even if a syncing analyser matches
    NO_SYNC = 4


class SubscriptionState(enum.Enum):
    SUBSCRIBED = 1
    PENDING = 2


class LegacyPickle(TypeDecorator):
    """Handle troubles with data pickled by former version of SàT

    This type is temporary until we do migration to a proper data type
    """
    # Blob is used on SQLite but gives errors when used here, while Text works fine
    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        # protocol 0 produces ASCII-compatible output, suiting the Text impl above
        return pickle.dumps(value, 0)

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        # value types are inconsistent (probably a consequence of Python 2/3 port
        # and/or SQLite dynamic typing)
        try:
            value = value.encode()
        except AttributeError:
            pass
        # "utf-8" encoding is needed to handle Python 2 pickled data
        return pickle.loads(value, encoding="utf-8")
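
# The decorator above is symmetric: unpickling what was pickled returns the
# original value. A quick illustrative check (the dialect argument is unused
# here, so None is fine):
#
#     lp = LegacyPickle()
#     data = lp.process_bind_param({"a": 1}, None)
#     assert lp.process_result_value(data, None) == {"a": 1}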


class Json(TypeDecorator):
    """Handle JSON field in DB independant way"""
    # Blob is used on SQLite but gives errors when used here, while Text works fine
    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        return json.dumps(value)

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        return json.loads(value)


class JsonDefaultDict(Json):
    """Json type which convert NULL to empty dict instead of None"""

    def process_result_value(self, value, dialect):
        if value is None:
            return {}
        return json.loads(value)


class Xml(TypeDecorator):
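    """Store XML elements (twisted domish.Element) as serialised text"""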
    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        return value.toXml()

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        return generic.parseXml(value.encode())
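
# Illustrative round-trip for the Xml decorator, using the wokkel helper already
# imported above (this mirrors what process_result_value does internally):
#
#     elt = generic.parseXml(b"<message><body>hello</body></message>")
#     Xml().process_bind_param(elt, None)
#     # => '<message><body>hello</body></message>'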


class JID(TypeDecorator):
    """Store twisted JID in text fields"""
    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        if value is None:
            return None
        return value.full()

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        return jid.JID(value)
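
# Illustrative: the JID decorator keeps the full JID (including resource) as
# text, and rebuilds a jid.JID instance when loading:
#
#     JID().process_bind_param(jid.JID("louise@example.org/mobile"), None)
#     # => "louise@example.org/mobile"
#     JID().process_result_value("louise@example.org/mobile", None)
#     # => JID("louise@example.org/mobile")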


class Profile(Base):
    __tablename__ = "profiles"

    id = Column(
        Integer,
        primary_key=True,
        nullable=True,
    )
    name = Column(Text, unique=True)

    params = relationship("ParamInd", back_populates="profile", passive_deletes=True)
    private_data = relationship(
        "PrivateInd", back_populates="profile", passive_deletes=True
    )
    private_bin_data = relationship(
        "PrivateIndBin", back_populates="profile", passive_deletes=True
    )


class Component(Base):
    __tablename__ = "components"

    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE"),
        nullable=True,
        primary_key=True
    )
    entry_point = Column(Text, nullable=False)
    profile = relationship("Profile")


class History(Base):
    __tablename__ = "history"
    __table_args__ = (
        UniqueConstraint("profile_id", "stanza_id", "source", "dest"),
        UniqueConstraint("profile_id", "origin_id", "source", name="uq_origin_id"),
        Index("history__profile_id_timestamp", "profile_id", "timestamp"),
        Index(
            "history__profile_id_received_timestamp", "profile_id", "received_timestamp"
        )
    )

    uid = Column(Text, primary_key=True)
    origin_id = Column(Text)
    stanza_id = Column(Text)
    update_uid = Column(Text)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))
    source = Column(Text)
    dest = Column(Text)
    source_res = Column(Text)
    dest_res = Column(Text)
    timestamp = Column(Float, nullable=False)
    received_timestamp = Column(Float)
    type = Column(
        Enum(
            "chat",
            "error",
            "groupchat",
            "headline",
            "normal",
            # "info" is not a standard XMPP type, but is used to keep track of
            # events such as joins/leaves in a MUC
            "info",
            name="message_type",
            create_constraint=True,
        ),
        nullable=False,
    )
    extra = Column(LegacyPickle)

    profile = relationship("Profile")
    messages = relationship("Message", backref="history", passive_deletes=True)
    subjects = relationship("Subject", backref="history", passive_deletes=True)
    thread = relationship(
        "Thread", uselist=False, back_populates="history", passive_deletes=True
    )

    def __init__(self, *args, **kwargs):
        source_jid = kwargs.pop("source_jid", None)
        if source_jid is not None:
            kwargs["source"] = source_jid.userhost()
            kwargs["source_res"] = source_jid.resource
        dest_jid = kwargs.pop("dest_jid", None)
        if dest_jid is not None:
            kwargs["dest"] = dest_jid.userhost()
            kwargs["dest_res"] = dest_jid.resource
        super().__init__(*args, **kwargs)
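
    # Illustrative: History(source_jid=jid.JID("louise@example.org/mobile"))
    # stores source="louise@example.org" and source_res="mobile" (and the same
    # applies to dest_jid/dest/dest_res).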

    @property
    def source_jid(self) -> jid.JID:
        return jid.JID(f"{self.source}/{self.source_res or ''}")

    @source_jid.setter
    def source_jid(self, source_jid: jid.JID) -> None:
        self.source = source_jid.userhost()
        self.source_res = source_jid.resource

    @property
    def dest_jid(self):
        return jid.JID(f"{self.dest}/{self.dest_res or ''}")

    @dest_jid.setter
    def dest_jid(self, dest_jid: jid.JID) -> None:
        self.dest = dest_jid.userhost()
        self.dest_res = dest_jid.resource

    def __repr__(self):
        dt = datetime.fromtimestamp(self.timestamp)
        return f"History<{self.source_jid.full()}->{self.dest_jid.full()} [{dt}]>"

    def serialise(self):
        # "extra" can be NULL in database, fall back to an empty dict in that case
        extra = self.extra or {}
        if self.origin_id is not None:
            extra["origin_id"] = self.origin_id
        if self.stanza_id is not None:
            extra["stanza_id"] = self.stanza_id
        if self.update_uid is not None:
            extra["update_uid"] = self.update_uid
        if self.received_timestamp is not None:
            extra["received_timestamp"] = self.received_timestamp
        if self.thread is not None:
            extra["thread"] = self.thread.thread_id
            if self.thread.parent_id is not None:
                extra["thread_parent"] = self.thread.parent_id

        return {
            "from": f"{self.source}/{self.source_res}" if self.source_res
                else self.source,
            "to": f"{self.dest}/{self.dest_res}" if self.dest_res else self.dest,
            "uid": self.uid,
            "message": {m.language or '': m.message for m in self.messages},
            "subject": {m.language or '': m.subject for m in self.subjects},
            "type": self.type,
            "extra": extra,
            "timestamp": self.timestamp,
        }
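
    # serialise() returns a plain dict, e.g. (illustrative values):
    #     {"from": "louise@example.org/mobile", "to": "pierre@example.org",
    #      "uid": "...", "message": {"": "hello"}, "subject": {}, "type": "chat",
    #      "extra": {"origin_id": "..."}, "timestamp": 1678464165.0}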

    def as_tuple(self):
        d = self.serialise()
        return (
            d['uid'], d['timestamp'], d['from'], d['to'], d['message'], d['subject'],
            d['type'], d['extra']
        )

    @staticmethod
    def debug_collection(history_collection):
        for idx, history in enumerate(history_collection):
            history.debug_msg(idx)

    def debug_msg(self, idx=None):
        """Print messages"""
        dt = datetime.fromtimestamp(self.timestamp)
        if idx is not None:
            dt = f"({idx}) {dt}"
        parts = []
        parts.append(f"[{dt}]<{self.source_jid.full()}->{self.dest_jid.full()}> ")
        for message in self.messages:
            if message.language:
                parts.append(f"[{message.language}] ")
            parts.append(f"{message.message}\n")
        print("".join(parts))


class Message(Base):
    __tablename__ = "message"
    __table_args__ = (
        Index("message__history_uid", "history_uid"),
    )

    id = Column(
        Integer,
        primary_key=True,
    )
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"), nullable=False)
    message = Column(Text, nullable=False)
    language = Column(Text)

    def serialise(self) -> Dict[str, Any]:
        s = {}
        if self.message:
            s["message"] = str(self.message)
        if self.language:
            s["language"] = str(self.language)
        return s

    def __repr__(self):
        lang_str = f"[{self.language}]" if self.language else ""
        msg = f"{self.message[:20]}…" if len(self.message)>20 else self.message
        content = f"{lang_str}{msg}"
        return f"Message<{content}>"


class Subject(Base):
    __tablename__ = "subject"
    __table_args__ = (
        Index("subject__history_uid", "history_uid"),
    )

    id = Column(
        Integer,
        primary_key=True,
    )
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"), nullable=False)
    subject = Column(Text, nullable=False)
    language = Column(Text)

    def serialise(self) -> Dict[str, Any]:
        s = {}
        if self.subject:
            s["subject"] = str(self.subject)
        if self.language:
            s["language"] = str(self.language)
        return s

    def __repr__(self):
        lang_str = f"[{self.language}]" if self.language else ""
        msg = f"{self.subject[:20]}…" if len(self.subject)>20 else self.subject
        content = f"{lang_str}{msg}"
        return f"Subject<{content}>"


class Thread(Base):
    __tablename__ = "thread"
    __table_args__ = (
        Index("thread__history_uid", "history_uid"),
    )

    id = Column(
        Integer,
        primary_key=True,
    )
    history_uid = Column(ForeignKey("history.uid", ondelete="CASCADE"))
    thread_id = Column(Text)
    parent_id = Column(Text)

    history = relationship("History", uselist=False, back_populates="thread")

    def __repr__(self):
        return f"Thread<{self.thread_id} [parent: {self.parent_id}]>"


class ParamGen(Base):
    __tablename__ = "param_gen"

    category = Column(Text, primary_key=True)
    name = Column(Text, primary_key=True)
    value = Column(Text)


class ParamInd(Base):
    __tablename__ = "param_ind"

    category = Column(Text, primary_key=True)
    name = Column(Text, primary_key=True)
    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True
    )
    value = Column(Text)

    profile = relationship("Profile", back_populates="params")


class PrivateGen(Base):
    __tablename__ = "private_gen"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    value = Column(Text)


class PrivateInd(Base):
    __tablename__ = "private_ind"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True
    )
    value = Column(Text)

    profile = relationship("Profile", back_populates="private_data")


class PrivateGenBin(Base):
    __tablename__ = "private_gen_bin"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    value = Column(LegacyPickle)


class PrivateIndBin(Base):
    __tablename__ = "private_ind_bin"

    namespace = Column(Text, primary_key=True)
    key = Column(Text, primary_key=True)
    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True
    )
    value = Column(LegacyPickle)

    profile = relationship("Profile", back_populates="private_bin_data")


class File(Base):
    __tablename__ = "files"
    __table_args__ = (
        Index("files__profile_id_owner_parent", "profile_id", "owner", "parent"),
        Index(
            "files__profile_id_owner_media_type_media_subtype",
            "profile_id",
            "owner",
            "media_type",
            "media_subtype"
        )
    )

    id = Column(Text, primary_key=True)
    public_id = Column(Text, unique=True)
    version = Column(Text, primary_key=True)
    parent = Column(Text, nullable=False)
    type = Column(
        Enum(
            "file", "directory",
            name="file_type",
            create_constraint=True
        ),
        nullable=False,
        server_default="file",
    )
    file_hash = Column(Text)
    hash_algo = Column(Text)
    name = Column(Text, nullable=False)
    size = Column(Integer)
    namespace = Column(Text)
    media_type = Column(Text)
    media_subtype = Column(Text)
    created = Column(Float, nullable=False)
    modified = Column(Float)
    owner = Column(JID)
    access = Column(JsonDefaultDict)
    extra = Column(JsonDefaultDict)
    profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"))

    profile = relationship("Profile")


class PubsubNode(Base):
    __tablename__ = "pubsub_nodes"
    __table_args__ = (
        UniqueConstraint("profile_id", "service", "name"),
    )

    id = Column(Integer, primary_key=True)
    profile_id = Column(
        ForeignKey("profiles.id", ondelete="CASCADE")
    )
    service = Column(JID)
    name = Column(Text, nullable=False)
    subscribed = Column(
        Boolean(create_constraint=True, name="subscribed_bool"), nullable=False
    )
    analyser = Column(Text)
    sync_state = Column(
        Enum(
            SyncState,
            name="sync_state",
            create_constraint=True,
        ),
        nullable=True
    )
    sync_state_updated = Column(
        Float,
        nullable=False,
        # the callable itself is passed (not its result), so the timestamp is
        # evaluated at insert time rather than once at module import
        default=time.time
    )
    type_ = Column(
        Text, name="type", nullable=True
    )
    subtype = Column(
        Text, nullable=True
    )
    extra = Column(JSON)

    items = relationship("PubsubItem", back_populates="node", passive_deletes=True)
    subscriptions = relationship("PubsubSub", back_populates="node", passive_deletes=True)

    def __str__(self):
        return f"Pubsub node {self.name!r} at {self.service}"


class PubsubSub(Base):
    """Subscriptions to pubsub nodes

    Used by components managing a pubsub service
    """
    __tablename__ = "pubsub_subs"
    __table_args__ = (
        UniqueConstraint("node_id", "subscriber"),
    )

    id = Column(Integer, primary_key=True)
    node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False)
    subscriber = Column(JID)
    state = Column(
        Enum(
            SubscriptionState,
            name="state",
            create_constraint=True,
        ),
        nullable=True
    )

    node = relationship("PubsubNode", back_populates="subscriptions")


class PubsubItem(Base):
    __tablename__ = "pubsub_items"
    __table_args__ = (
        UniqueConstraint("node_id", "name"),
    )
    id = Column(Integer, primary_key=True)
    node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False)
    name = Column(Text, nullable=False)
    data = Column(Xml, nullable=False)
    created = Column(DateTime, nullable=False, server_default=now())
    updated = Column(DateTime, nullable=False, server_default=now(), onupdate=now())
    parsed = Column(JSON)

    node = relationship("PubsubNode", back_populates="items")


## Full-Text Search

# create

@event.listens_for(PubsubItem.__table__, "after_create")
def fts_create(target, connection, **kw):
    """Full-Text Search table creation"""
    if connection.engine.name == "sqlite":
        # Using SQLite FTS5
        queries = [
            "CREATE VIRTUAL TABLE pubsub_items_fts "
            "USING fts5(data, content=pubsub_items, content_rowid=id)",
            "CREATE TRIGGER pubsub_items_fts_sync_ins AFTER INSERT ON pubsub_items BEGIN"
            "  INSERT INTO pubsub_items_fts(rowid, data) VALUES (new.id, new.data);"
            "END",
            "CREATE TRIGGER pubsub_items_fts_sync_del AFTER DELETE ON pubsub_items BEGIN"
            "  INSERT INTO pubsub_items_fts(pubsub_items_fts, rowid, data) "
            "VALUES('delete', old.id, old.data);"
            "END",
            "CREATE TRIGGER pubsub_items_fts_sync_upd AFTER UPDATE ON pubsub_items BEGIN"
            "  INSERT INTO pubsub_items_fts(pubsub_items_fts, rowid, data) VALUES"
            "('delete', old.id, old.data);"
            "  INSERT INTO pubsub_items_fts(rowid, data) VALUES(new.id, new.data);"
            "END"
        ]
        for q in queries:
            connection.execute(DDL(q))
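
# With the FTS5 virtual table created above, a full-text query can be issued
# with raw SQL (illustrative sketch; "connection" stands for any SQLAlchemy
# connection on the SQLite engine):
#
#     from sqlalchemy import text
#     rows = connection.execute(
#         text(
#             "SELECT pubsub_items.* FROM pubsub_items "
#             "JOIN pubsub_items_fts ON pubsub_items_fts.rowid = pubsub_items.id "
#             "WHERE pubsub_items_fts MATCH :q "
#             "ORDER BY pubsub_items_fts.rank"
#         ),
#         {"q": "some search terms"},
#     )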

# drop

@event.listens_for(PubsubItem.__table__, "before_drop")
def fts_drop(target, connection, **kw):
    "Full-Text Search table drop" ""
    if connection.engine.name == "sqlite":
        # Using SQLite FTS5
        queries = [
            "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_ins",
            "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_del",
            "DROP TRIGGER IF EXISTS pubsub_items_fts_sync_upd",
            "DROP TABLE IF EXISTS pubsub_items_fts",
        ]
        for q in queries:
            connection.execute(DDL(q))
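
# Illustrative sketch of how the listeners above are fired: they run
# automatically on schema creation/removal (the application configures its own
# engine elsewhere; this only shows the mechanism):
#
#     from sqlalchemy import create_engine
#     engine = create_engine("sqlite://")
#     Base.metadata.create_all(engine)  # also creates pubsub_items_fts + triggers
#     Base.metadata.drop_all(engine)    # fts_drop removes them beforehand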