diff libervia/backend/plugins/plugin_misc_jid_search.py @ 4358:c8d089b0e478

plugin JID Search: Use Pydantic models + better handling of partial search.
author Goffi <goffi@goffi.org>
date Fri, 11 Apr 2025 18:19:28 +0200
parents 0d7bb4df2343
children
line wrap: on
line diff
--- a/libervia/backend/plugins/plugin_misc_jid_search.py	Fri Apr 11 18:19:28 2025 +0200
+++ b/libervia/backend/plugins/plugin_misc_jid_search.py	Fri Apr 11 18:19:28 2025 +0200
@@ -17,18 +17,18 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 from collections import OrderedDict
-from dataclasses import dataclass, asdict
 import difflib
-from typing import List, Optional
+from typing import Annotated, Iterator, Literal, NamedTuple
 
+from pydantic import BaseModel, Field, RootModel
 from twisted.internet import defer
 from twisted.words.protocols.jabber import jid
 
 from libervia.backend.core.constants import Const as C
-from libervia.backend.core.core_types import SatXMPPEntity
+from libervia.backend.core.core_types import SatXMPPClient, SatXMPPEntity
 from libervia.backend.core.i18n import _
 from libervia.backend.core.log import getLogger
-from libervia.backend.tools.common import data_format
+from libervia.backend.models.types import JIDType
 
 log = getLogger(__name__)
 
@@ -46,20 +46,77 @@
     C.PI_DESCRIPTION: _("""Search for XMPP entities"""),
 }
 RATIO_CUTOFF = 0.6
+# Relevance given to an item whose name or JID contains the search term.
+PARTIAL_MATCH_RATIO = 0.8
 MAX_CACHE_SIZE = 10
 
 
-@dataclass
-class JidSearchItem:
-    entity: jid.JID
+class JidSearchItem(BaseModel):
+    entity: JIDType
     name: str = ""
-    in_roster: bool = False
-    groups: list[str] | None = None
     exact_match: bool = False
     relevance: float | None = None
 
 
-JidSearchCache = OrderedDict[str, list[JidSearchItem]]
+class EntitySearchItem(JidSearchItem):
+    type: Literal["entity"] = "entity"  # tag discriminating SearchItem variants
+    in_roster: bool = False  # set to True for items built from the roster
+    groups: list[str] | None = None  # roster groups, also matched against the term
+
+
+class RoomSearchItem(JidSearchItem):
+    type: Literal["room"] = "room"  # tag discriminating SearchItem variants
+    local: bool = Field(  # required: no default is provided
+        description="True if the room comes from a server local component."
+    )
+    service_type: str | None = None
+    description: str | None = None
+    language: str | None = None
+    nusers: int | None = Field(default=None, ge=0)  # validated as >= 0 when set
+    anonymity_mode: str | None = None
+    is_open: bool | None = None
+
+
+SearchItem = Annotated[EntitySearchItem | RoomSearchItem, Field(discriminator="type")]
+
+
+class SearchItems(RootModel):
+    root: list[SearchItem]  # wrapped list; the helpers below delegate to it
+
+    def __iter__(self) -> Iterator[SearchItem]:  # type: ignore
+        return iter(self.root)
+
+    def __getitem__(self, item) -> SearchItem:
+        return self.root[item]
+
+    def __len__(self) -> int:
+        return len(self.root)
+
+    def append(self, item: SearchItem) -> None:
+        self.root.append(item)
+
+    def sort(self, key=None, reverse=False) -> None:
+        self.root.sort(key=key, reverse=reverse)  # type: ignore
+
+
+class Options(BaseModel):  # search options; every flag defaults to False
+    entities: bool = Field(
+        default=False, description="Search for entities for direct chat"
+    )
+    groupchat: bool = Field(
+        default=False, description="Search for group chats."
+    )
+    allow_external: bool = Field(
+        default=False, description="Authorise doing request to external services."
+    )
+
+
+class CachedSearch(NamedTuple):  # value type stored in JidSearchCache
+    search_items: SearchItems  # results stored for a search term
+    options: Options  # options of that search; reuse requires equal options
+
+
+JidSearchCache = OrderedDict[str, CachedSearch]
 
 
 class JidSearch:
@@ -80,17 +137,17 @@
 
     def _search(self, search_term: str, options_s: str, profile: str) -> defer.Deferred:
         client = self.host.get_client(profile)
-        d = defer.ensureDeferred(
-            self.search(client, search_term, data_format.deserialise(options_s))
-        )
-        d.addCallback(
-            lambda search_items: data_format.serialise([asdict(i) for i in search_items])
-        )
+        options = Options.model_validate_json(options_s)
+        d = defer.ensureDeferred(self.search(client, search_term, options))
+        d.addCallback(lambda search_items: search_items.model_dump_json())
         return d
 
     async def search(
-        self, client: SatXMPPEntity, search_term: str, options: Optional[dict] = None
-    ) -> List[JidSearchItem]:
+        self,
+        client: SatXMPPEntity,
+        search_term: str,
+        options: Options | dict | None = None,
+    ) -> SearchItems:
         """Searches for entities in various locations.
 
         @param client: The SatXMPPEntity client where the search is to be performed.
@@ -98,6 +155,10 @@
         @param options: Additional search options.
         @return: A list of matches found.
         """
+        if options is None:
+            options = Options()
+        elif isinstance(options, dict):
+            options = Options(**options)
         search_term = search_term.strip().lower()
         sequence_matcher = difflib.SequenceMatcher()
         sequence_matcher.set_seq1(search_term)
@@ -110,27 +171,33 @@
         # Look for a match in the cache
         for cache_key in cache:
             if search_term.startswith(cache_key):
+                cached_data = cache[cache_key]
+                if cached_data.options != options:
+                    log.debug("Ignoring cached data due to incompatible options.")
+                    continue
                 log.debug(
                     f"Match found in cache for {search_term!r} in [{client.profile}]."
                 )
                 # If an exact match is found, return the results as is
                 if search_term == cache_key:
                     log.debug("Exact match found in cache, reusing results.")
-                    matches = cache[cache_key]
+                    matches = cached_data.search_items
                 else:
                     # If only the beginning matches, filter the cache results
                     log.debug("Prefix match found in cache, filtering results.")
-                    matches = []
-                    for jid_search_item in cache[cache_key]:
-                        self._process_matching(
+                    matches = SearchItems([])
+                    for jid_search_item in cached_data.search_items:
+                        self.process_matching(
                             search_term, sequence_matcher, matches, jid_search_item
                         )
                 cache.move_to_end(cache_key)
                 break
         else:
             # If no match is found in the cache, perform a new search
-            matches = await self._perform_search(client, search_term, sequence_matcher)
-            cache[search_term] = matches
+            matches = await self.perform_search(
+                client, search_term, options, sequence_matcher
+            )
+            cache[search_term] = CachedSearch(matches, options)
             if len(cache) > MAX_CACHE_SIZE:
                 cache.popitem(last=False)
 
@@ -144,7 +211,7 @@
                 pass
             else:
                 matches.append(
-                    JidSearchItem(
+                    EntitySearchItem(
                         entity=search_jid,
                         in_roster=False,
                         exact_match=True,
@@ -153,29 +220,34 @@
                 )
 
         matches.sort(
-            key=lambda item: (item.exact_match, item.relevance or 0, item.in_roster),
+            key=lambda item: (
+                item.exact_match, item.relevance or 0, getattr(item, "in_roster", False)
+            ),
             reverse=True,
         )
 
         return matches
 
-    def _process_matching(
+    def process_matching(
         self,
         search_term: str,
         sequence_matcher: difflib.SequenceMatcher,
-        matches: List[JidSearchItem],
-        item: JidSearchItem,
+        matches: SearchItems,
+        item: SearchItem,
     ) -> None:
-        """Process matching of items
+        """Process the matching of an item against a search term.
 
-        @param sequence_matcher: The sequence matcher to be used for the matching process.
-        @param matches: A list where the match is to be appended.
-        @param item: The item that to be matched.
-        @return: True if it was an exact match
+        This method checks if the given item is an exact match or if it has any
+        significant similarity to the search term. If a match is found, the item's
+        relevance score is set and the item is added to the matches list.
+
+        @param sequence_matcher: The sequence matcher used for comparing strings.
+        @param matches: A list where matched items will be appended.
+        @param item: The item to be compared against the search term.
         """
 
-        item_name_lower = item.name.lower()
-        item_entity_lower = item.entity.full().lower()
+        item_name_lower = item.name.strip().lower()
+        item_entity_lower = item.entity.userhost().lower()
 
         if search_term in (item_name_lower, item_entity_lower):
             item.exact_match = True
@@ -185,12 +257,25 @@
 
         item.exact_match = False
 
-        sequence_matcher.set_seq2(item_name_lower)
-        name_ratio = sequence_matcher.ratio()
-        if name_ratio >= RATIO_CUTOFF:
-            item.relevance = name_ratio
-            matches.append(item)
-            return
+        # Check if search_term is a substring of item_name_lower or item_entity_lower
+        if len(search_term) >= 3:
+            if item_name_lower and search_term in item_name_lower:
+                item.relevance = PARTIAL_MATCH_RATIO
+                matches.append(item)
+                return
+
+            if search_term in item_entity_lower:
+                item.relevance = PARTIAL_MATCH_RATIO
+                matches.append(item)
+                return
+
+        if item_name_lower:
+            sequence_matcher.set_seq2(item_name_lower)
+            name_ratio = sequence_matcher.ratio()
+            if name_ratio >= RATIO_CUTOFF:
+                item.relevance = name_ratio
+                matches.append(item)
+                return
 
         sequence_matcher.set_seq2(item_entity_lower)
         jid_ratio = sequence_matcher.ratio()
@@ -208,7 +293,7 @@
                 matches.append(item)
                 return
 
-        if item.groups:
+        if isinstance(item, EntitySearchItem) and item.groups:
             group_ratios = []
             for group in item.groups:
                 sequence_matcher.set_seq2(group.lower())
@@ -227,38 +312,52 @@
             matches.append(item)
             return
 
-    async def _perform_search(
+    async def perform_search(
         self,
         client: SatXMPPEntity,
         search_term: str,
+        options: Options,
         sequence_matcher: difflib.SequenceMatcher,
-    ) -> List[JidSearchItem]:
-        """Performs a new search when no match is found in the cache.
+    ) -> SearchItems:
+        """Performs a new search.
+
+        Cache is not used here.
 
         @param search_term: The query to be searched.
         @param sequence_matcher: The SequenceMatcher object to be used for matching.
         @return: A list of matches found.
         """
-        matches = []
+        matches = SearchItems([])
 
-        try:
-            roster = client.roster
-        except AttributeError:
-            # components have no roster
-            roster = []
-        else:
-            roster = client.roster.get_items()
+        if options.entities:
+            assert isinstance(client, SatXMPPClient)
+            try:
+                client.roster
+            except AttributeError:
+                # components have no roster
+                roster = []
+            else:
+                roster = client.roster.get_items()
 
-        for roster_item in roster:
-            jid_search_item = JidSearchItem(
-                entity=roster_item.entity,
-                name=roster_item.name,
-                in_roster=True,
-                groups=list(roster_item.groups),
-            )
+            for roster_item in roster:
+                jid_search_item = EntitySearchItem(
+                    entity=roster_item.entity,
+                    name=roster_item.name,
+                    in_roster=True,
+                    groups=list(roster_item.groups),
+                )
 
-            self._process_matching(
-                search_term, sequence_matcher, matches, jid_search_item
-            )
+                self.process_matching(
+                    search_term, sequence_matcher, matches, jid_search_item
+                )
+
+        await self.host.trigger.async_point(
+            "JID_SEARCH_perform_search",
+            client,
+            search_term,
+            options,
+            sequence_matcher,
+            matches,
+        )
 
         return matches