changeset 3744:658ddbabaf36

core (memory/sqla): new table/mapping to handle Pubsub node subscriptions: node subscriptions can now be cached, this can be useful for components which must keep track of subscibers. rel 364
author Goffi <goffi@goffi.org>
date Tue, 22 Mar 2022 17:00:42 +0100 (2022-03-22)
parents 54c249ec35ce
children a8c7e5cef0cb
files sat/memory/migration/versions/79e5f3313fa4_create_table_for_pubsub_subscriptions.py sat/memory/sqla.py sat/memory/sqla_mapping.py
diffstat 3 files changed, 76 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sat/memory/migration/versions/79e5f3313fa4_create_table_for_pubsub_subscriptions.py	Tue Mar 22 17:00:42 2022 +0100
@@ -0,0 +1,33 @@
+"""create table for pubsub subscriptions
+
+Revision ID: 79e5f3313fa4
+Revises: 129ac51807e4
+Create Date: 2022-03-14 17:15:00.689871
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sat.memory.sqla_mapping import JID
+
+
+# revision identifiers, used by Alembic.
+revision = '79e5f3313fa4'
+down_revision = '129ac51807e4'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.create_table('pubsub_subs',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('node_id', sa.Integer(), nullable=False),
+    sa.Column('subscriber', JID(), nullable=True),
+    sa.Column('state', sa.Enum('SUBSCRIBED', 'PENDING', name='state'), nullable=True),
+    sa.ForeignKeyConstraint(['node_id'], ['pubsub_nodes.id'], name=op.f('fk_pubsub_subs_node_id_pubsub_nodes'), ondelete='CASCADE'),
+    sa.PrimaryKeyConstraint('id', name=op.f('pk_pubsub_subs')),
+    sa.UniqueConstraint('node_id', 'subscriber', name=op.f('uq_pubsub_subs_node_id'))
+    )
+
+
+def downgrade():
+    op.drop_table('pubsub_subs')
--- a/sat/memory/sqla.py	Tue Mar 22 17:00:42 2022 +0100
+++ b/sat/memory/sqla.py	Tue Mar 22 17:00:42 2022 +0100
@@ -111,7 +111,7 @@
         self.initialized = defer.Deferred()
         # we keep cache for the profiles (key: profile name, value: profile id)
         # profile id to name
-        self.profiles: Dict[int, str] = {}
+        self.profiles: Dict[str, int] = {}
         # profile id to component entry point
         self.components: Dict[int, str] = {}
 
@@ -1015,6 +1015,7 @@
         service: jid.JID,
         name: str,
         with_items: bool = False,
+        with_subscriptions: bool = False,
     ) -> Optional[PubsubNode]:
         """
         """
@@ -1031,6 +1032,10 @@
                 stmt = stmt.options(
                     joinedload(PubsubNode.items)
                 )
+            if with_subscriptions:
+                stmt = stmt.options(
+                    joinedload(PubsubNode.subscriptions)
+                )
             result = await session.execute(stmt)
         return result.unique().scalar_one_or_none()
 
@@ -1043,15 +1048,17 @@
         analyser: Optional[str] = None,
         type_: Optional[str] = None,
         subtype: Optional[str] = None,
+        subscribed: bool = False,
     ) -> PubsubNode:
         node = PubsubNode(
             profile_id=self.profiles[client.profile],
             service=service,
             name=name,
-            subscribed=False,
+            subscribed=subscribed,
             analyser=analyser,
             type_=type_,
             subtype=subtype,
+            subscriptions=[],
         )
         async with self.session() as session:
             async with session.begin():
@@ -1187,6 +1194,9 @@
             "type_": types,
             "subtype": subtypes,
         }
+        if profiles is not None:
+            node_fields["profile_id"] = [self.profiles[p] for p in profiles]
+
         if any(x is not None for x in node_fields.values()):
             sub_q = select(PubsubNode.id)
             for col, values in node_fields.items():
@@ -1199,11 +1209,6 @@
                 .execution_options(synchronize_session=False)
             )
 
-        if profiles is not None:
-            stmt = stmt.where(
-                PubsubItem.profile_id.in_([self.profiles[p] for p in profiles])
-            )
-
         if created_before is not None:
             stmt = stmt.where(PubsubItem.created < created_before)
 
--- a/sat/memory/sqla_mapping.py	Tue Mar 22 17:00:42 2022 +0100
+++ b/sat/memory/sqla_mapping.py	Tue Mar 22 17:00:42 2022 +0100
@@ -60,6 +60,11 @@
     NO_SYNC = 4
 
 
+class SubscriptionState(enum.Enum):
+    SUBSCRIBED = 1
+    PENDING = 2
+
+
 class LegacyPickle(TypeDecorator):
     """Handle troubles with data pickled by former version of SàT
 
@@ -510,11 +515,37 @@
     extra = Column(JSON)
 
     items = relationship("PubsubItem", back_populates="node", passive_deletes=True)
+    subscriptions = relationship("PubsubSub", back_populates="node", passive_deletes=True)
 
     def __str__(self):
         return f"Pubsub node {self.name!r} at {self.service}"
 
 
+class PubsubSub(Base):
+    """Subscriptions to pubsub nodes
+
+    Used by components managing a pubsub service
+    """
+    __tablename__ = "pubsub_subs"
+    __table_args__ = (
+        UniqueConstraint("node_id", "subscriber"),
+    )
+
+    id = Column(Integer, primary_key=True)
+    node_id = Column(ForeignKey("pubsub_nodes.id", ondelete="CASCADE"), nullable=False)
+    subscriber = Column(JID)
+    state = Column(
+        Enum(
+            SubscriptionState,
+            name="state",
+            create_constraint=True,
+        ),
+        nullable=True
+    )
+
+    node = relationship("PubsubNode", back_populates="subscriptions")
+
+
 class PubsubItem(Base):
     __tablename__ = "pubsub_items"
     __table_args__ = (