Mercurial > libervia-backend
changeset 4358:c8d089b0e478
plugin JID Search: Use Pydantic models + better handling of partial search.
author | Goffi <goffi@goffi.org> |
---|---|
date | Fri, 11 Apr 2025 18:19:28 +0200 |
parents | f43cbceba2a0 |
children | a987a8ce34b9 |
files | libervia/backend/plugins/plugin_misc_jid_search.py |
diffstat | 1 files changed, 163 insertions(+), 64 deletions(-) [+] |
line wrap: on
line diff
--- a/libervia/backend/plugins/plugin_misc_jid_search.py Fri Apr 11 18:19:28 2025 +0200 +++ b/libervia/backend/plugins/plugin_misc_jid_search.py Fri Apr 11 18:19:28 2025 +0200 @@ -17,18 +17,18 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. from collections import OrderedDict -from dataclasses import dataclass, asdict import difflib -from typing import List, Optional +from typing import Annotated, Iterator, Literal, NamedTuple +from pydantic import BaseModel, Field, RootModel from twisted.internet import defer from twisted.words.protocols.jabber import jid from libervia.backend.core.constants import Const as C -from libervia.backend.core.core_types import SatXMPPEntity +from libervia.backend.core.core_types import SatXMPPClient, SatXMPPEntity from libervia.backend.core.i18n import _ from libervia.backend.core.log import getLogger -from libervia.backend.tools.common import data_format +from libervia.backend.models.types import JIDType log = getLogger(__name__) @@ -46,20 +46,77 @@ C.PI_DESCRIPTION: _("""Search for XMPP entities"""), } RATIO_CUTOFF = 0.6 +# Used when a search term matches a substring of a relevant data. +PARTIAL_MATCH_RATIO = 0.8 MAX_CACHE_SIZE = 10 -@dataclass -class JidSearchItem: - entity: jid.JID +class JidSearchItem(BaseModel): + entity: JIDType name: str = "" - in_roster: bool = False - groups: list[str] | None = None exact_match: bool = False relevance: float | None = None -JidSearchCache = OrderedDict[str, list[JidSearchItem]] +class EntitySearchItem(JidSearchItem): + type: Literal["entity"] = "entity" + in_roster: bool = False + groups: list[str] | None = None + + +class RoomSearchItem(JidSearchItem): + type: Literal["room"] = "room" + local: bool = Field( + description="True if the room comes from a server local component." + ) + service_type: str | None = None + description: str | None = None + language: str | None = None + nusers: int | None = Field(default=None, ge=0) + anonymity_mode: str | None = None + is_open: bool | None = None + + +SearchItem = Annotated[EntitySearchItem | RoomSearchItem, Field(discriminator="type")] + + +class SearchItems(RootModel): + root: list[SearchItem] + + def __iter__(self) -> Iterator[SearchItem]: # type: ignore + return iter(self.root) + + def __getitem__(self, item) -> str: + return self.root[item] + + def __len__(self) -> int: + return len(self.root) + + def append(self, item: SearchItem) -> None: + self.root.append(item) + + def sort(self, key=None, reverse=False) -> None: + self.root.sort(key=key, reverse=reverse) # type: ignore + + +class Options(BaseModel): + entities: bool = Field( + default=False, description="Search for entities for direct chat" + ) + groupchat: bool = Field( + default=False, description="Search for group chats." + ) + allow_external: bool = Field( + default=False, description="Authorise doing request to external services." + ) + + +class CachedSearch(NamedTuple): + search_items: SearchItems + options: Options + + +JidSearchCache = OrderedDict[str, CachedSearch] class JidSearch: @@ -80,17 +137,17 @@ def _search(self, search_term: str, options_s: str, profile: str) -> defer.Deferred: client = self.host.get_client(profile) - d = defer.ensureDeferred( - self.search(client, search_term, data_format.deserialise(options_s)) - ) - d.addCallback( - lambda search_items: data_format.serialise([asdict(i) for i in search_items]) - ) + options = Options.model_validate_json(options_s) + d = defer.ensureDeferred(self.search(client, search_term, options)) + d.addCallback(lambda search_items: search_items.model_dump_json()) return d async def search( - self, client: SatXMPPEntity, search_term: str, options: Optional[dict] = None - ) -> List[JidSearchItem]: + self, + client: SatXMPPEntity, + search_term: str, + options: Options | dict | None = None, + ) -> SearchItems: """Searches for entities in various locations. @param client: The SatXMPPEntity client where the search is to be performed. @@ -98,6 +155,10 @@ @param options: Additional search options. @return: A list of matches found. """ + if options is None: + options = Options() + elif isinstance(options, dict): + options = Options(**options) search_term = search_term.strip().lower() sequence_matcher = difflib.SequenceMatcher() sequence_matcher.set_seq1(search_term) @@ -110,27 +171,33 @@ # Look for a match in the cache for cache_key in cache: if search_term.startswith(cache_key): + cached_data = cache[cache_key] + if cached_data.options != options: + log.debug("Ignoring cached data due to incompatible options.") + continue log.debug( f"Match found in cache for {search_term!r} in [{client.profile}]." ) # If an exact match is found, return the results as is if search_term == cache_key: log.debug("Exact match found in cache, reusing results.") - matches = cache[cache_key] + matches = cached_data.search_items else: # If only the beginning matches, filter the cache results log.debug("Prefix match found in cache, filtering results.") - matches = [] - for jid_search_item in cache[cache_key]: - self._process_matching( + matches = SearchItems([]) + for jid_search_item in cached_data.search_items: + self.process_matching( search_term, sequence_matcher, matches, jid_search_item ) cache.move_to_end(cache_key) break else: # If no match is found in the cache, perform a new search - matches = await self._perform_search(client, search_term, sequence_matcher) - cache[search_term] = matches + matches = await self.perform_search( + client, search_term, options, sequence_matcher + ) + cache[search_term] = CachedSearch(matches, options) if len(cache) > MAX_CACHE_SIZE: cache.popitem(last=False) @@ -144,7 +211,7 @@ pass else: matches.append( - JidSearchItem( + EntitySearchItem( entity=search_jid, in_roster=False, exact_match=True, @@ -153,29 +220,34 @@ ) matches.sort( - key=lambda item: (item.exact_match, item.relevance or 0, item.in_roster), + key=lambda item: ( + item.exact_match, item.relevance or 0, getattr(item, "in_roster", False) + ), reverse=True, ) return matches - def _process_matching( + def process_matching( self, search_term: str, sequence_matcher: difflib.SequenceMatcher, - matches: List[JidSearchItem], - item: JidSearchItem, + matches: SearchItems, + item: SearchItem, ) -> None: - """Process matching of items + """Process the matching of an item against a search term. - @param sequence_matcher: The sequence matcher to be used for the matching process. - @param matches: A list where the match is to be appended. - @param item: The item that to be matched. - @return: True if it was an exact match + This method checks if the given item is an exact match or if it has any + significant similarity to the search term. If a match is found, the item's + relevance score is set and the item is added to the matches list. + + @param sequence_matcher: The sequence matcher used for comparing strings. + @param matches: A list where matched items will be appended. + @param item: The item to be compared against the search term. """ - item_name_lower = item.name.lower() - item_entity_lower = item.entity.full().lower() + item_name_lower = item.name.strip().lower() + item_entity_lower = item.entity.userhost().lower() if search_term in (item_name_lower, item_entity_lower): item.exact_match = True @@ -185,12 +257,25 @@ item.exact_match = False - sequence_matcher.set_seq2(item_name_lower) - name_ratio = sequence_matcher.ratio() - if name_ratio >= RATIO_CUTOFF: - item.relevance = name_ratio - matches.append(item) - return + # Check if search_term is a substring of item_name_lower or item_entity_lower + if len(search_term) >= 3: + if item_name_lower and search_term in item_name_lower: + item.relevance = PARTIAL_MATCH_RATIO + matches.append(item) + return + + if search_term in item_entity_lower: + item.relevance = PARTIAL_MATCH_RATIO + matches.append(item) + return + + if item_name_lower: + sequence_matcher.set_seq2(item_name_lower) + name_ratio = sequence_matcher.ratio() + if name_ratio >= RATIO_CUTOFF: + item.relevance = name_ratio + matches.append(item) + return sequence_matcher.set_seq2(item_entity_lower) jid_ratio = sequence_matcher.ratio() @@ -208,7 +293,7 @@ matches.append(item) return - if item.groups: + if isinstance(item, EntitySearchItem) and item.groups: group_ratios = [] for group in item.groups: sequence_matcher.set_seq2(group.lower()) @@ -227,38 +312,52 @@ matches.append(item) return - async def _perform_search( + async def perform_search( self, client: SatXMPPEntity, search_term: str, + options: Options, sequence_matcher: difflib.SequenceMatcher, - ) -> List[JidSearchItem]: - """Performs a new search when no match is found in the cache. + ) -> SearchItems: + """Performs a new search. + + Cache is not used here. @param search_term: The query to be searched. @param sequence_matcher: The SequenceMatcher object to be used for matching. @return: A list of matches found. """ - matches = [] + matches = SearchItems([]) - try: - roster = client.roster - except AttributeError: - # components have no roster - roster = [] - else: - roster = client.roster.get_items() + if options.entities: + assert isinstance(client, SatXMPPClient) + try: + client.roster + except AttributeError: + # components have no roster + roster = [] + else: + roster = client.roster.get_items() - for roster_item in roster: - jid_search_item = JidSearchItem( - entity=roster_item.entity, - name=roster_item.name, - in_roster=True, - groups=list(roster_item.groups), - ) + for roster_item in roster: + jid_search_item = EntitySearchItem( + entity=roster_item.entity, + name=roster_item.name, + in_roster=True, + groups=list(roster_item.groups), + ) - self._process_matching( - search_term, sequence_matcher, matches, jid_search_item - ) + self.process_matching( + search_term, sequence_matcher, matches, jid_search_item + ) + + await self.host.trigger.async_point( + "JID_SEARCH_perform_search", + client, + search_term, + options, + sequence_matcher, + matches, + ) return matches