Mercurial > libervia-backend
view libervia/backend/plugins/plugin_misc_jid_search.py @ 4369:b74a76a8e168
plugin XEP-0045: Fix `_message_parse_trigger` which was incorrectly breaking the trigger workflow:
`None` was returned in some cases instead of `True`, breaking the trigger workflow.
author | Goffi <goffi@goffi.org> |
---|---|
date | Tue, 06 May 2025 00:34:01 +0200 |
parents | c8d089b0e478 |
children |
line wrap: on
line source
#!/usr/bin/env python3 # Libervia plugin to handle XMPP entities search # Copyright (C) 2009-2023 Jérôme Poisson (goffi@goffi.org) # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. from collections import OrderedDict import difflib from typing import Annotated, Iterator, Literal, NamedTuple from pydantic import BaseModel, Field, RootModel from twisted.internet import defer from twisted.words.protocols.jabber import jid from libervia.backend.core.constants import Const as C from libervia.backend.core.core_types import SatXMPPClient, SatXMPPEntity from libervia.backend.core.i18n import _ from libervia.backend.core.log import getLogger from libervia.backend.models.types import JIDType log = getLogger(__name__) PLUGIN_INFO = { C.PI_NAME: "JID Search", C.PI_IMPORT_NAME: "JID_SEARCH", C.PI_TYPE: C.PLUG_TYPE_MISC, C.PI_MODES: C.PLUG_MODE_BOTH, C.PI_PROTOCOLS: [], C.PI_DEPENDENCIES: [], C.PI_RECOMMENDATIONS: [], C.PI_MAIN: "JidSearch", C.PI_HANDLER: "no", C.PI_DESCRIPTION: _("""Search for XMPP entities"""), } RATIO_CUTOFF = 0.6 # Used when a search term matches a substring of a relevant data. 
PARTIAL_MATCH_RATIO = 0.8
# Maximum number of search terms kept in a client's search cache.
MAX_CACHE_SIZE = 10


class JidSearchItem(BaseModel):
    """Base model for a single search result."""

    entity: JIDType
    name: str = ""
    # True when the search term is equal to the item's name or bare JID.
    exact_match: bool = False
    # Relevance score set by the matching code; 1 is used for an exact match.
    relevance: float | None = None


class EntitySearchItem(JidSearchItem):
    """Search result for an entity, discriminated by ``type == "entity"``."""

    type: Literal["entity"] = "entity"
    in_roster: bool = False
    groups: list[str] | None = None


class RoomSearchItem(JidSearchItem):
    """Search result for a room, discriminated by ``type == "room"``."""

    type: Literal["room"] = "room"
    local: bool = Field(
        description="True if the room comes from a server local component."
    )
    service_type: str | None = None
    description: str | None = None
    language: str | None = None
    nusers: int | None = Field(default=None, ge=0)
    anonymity_mode: str | None = None
    is_open: bool | None = None


# Union of the concrete result models, selected through their ``type`` field.
SearchItem = Annotated[EntitySearchItem | RoomSearchItem, Field(discriminator="type")]


class SearchItems(RootModel):
    """List of search results exposing a small list-like interface."""

    root: list[SearchItem]

    def __iter__(self) -> Iterator[SearchItem]:  # type: ignore
        return iter(self.root)

    def __getitem__(self, item) -> SearchItem:
        return self.root[item]

    def __len__(self) -> int:
        return len(self.root)

    def append(self, item: SearchItem) -> None:
        """Append ``item`` to the underlying list."""
        self.root.append(item)

    def sort(self, key=None, reverse=False) -> None:
        """Sort the underlying list in place."""
        self.root.sort(key=key, reverse=reverse)  # type: ignore


class Options(BaseModel):
    """Options selecting what a search should look for."""

    entities: bool = Field(
        default=False, description="Search for entities for direct chat"
    )
    groupchat: bool = Field(
        default=False, description="Search for group chats."
    )
    allow_external: bool = Field(
        default=False, description="Authorise doing request to external services."
    )


class CachedSearch(NamedTuple):
    """Result of a previous search, together with the options it was run with."""

    search_items: SearchItems
    options: Options


# Per-client cache mapping a normalised search term to its cached result. The
# ordering is used for eviction: hits are moved to the end and the first entry is
# dropped when the cache grows beyond MAX_CACHE_SIZE.
JidSearchCache = OrderedDict[str, CachedSearch]


class JidSearch:
    """Plugin searching for XMPP entities matching a search term."""

    def __init__(self, host) -> None:
        """Store the host and register the ``jid_search`` bridge method.

        @param host: Object providing ``bridge``, ``get_client`` and ``trigger``.
        """
        log.info(f"plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization")
        self.host = host
        host.bridge.add_method(
            "jid_search",
            ".plugin",
            in_sign="sss",
            out_sign="s",
            method=self._search,
            async_=True,
        )

    def profile_connecting(self, client: SatXMPPEntity) -> None:
        """Attach a fresh, empty search cache to ``client``."""
        client._jid_search_cache = JidSearchCache()

    def _search(self, search_term: str, options_s: str, profile: str) -> defer.Deferred:
        """Bridge entry point: run L{search} with JSON in and JSON out.

        @param search_term: The query to be searched.
        @param options_s: JSON serialisation of L{Options}.
        @param profile: Profile used to retrieve the client.
        @return: Deferred firing with the JSON serialisation of the results.
        """
        client = self.host.get_client(profile)
        options = Options.model_validate_json(options_s)
        d = defer.ensureDeferred(self.search(client, search_term, options))
        d.addCallback(lambda search_items: search_items.model_dump_json())
        return d

    async def search(
        self,
        client: SatXMPPEntity,
        search_term: str,
        options: Options | dict | None = None,
    ) -> SearchItems:
        """Searches for entities in various locations.

        The search term is stripped and lower-cased, then looked up in the client
        cache before a new search is performed.

        @param client: The SatXMPPEntity client where the search is to be performed.
        @param search_term: The query to be searched.
        @param options: Additional search options.
        @return: A list of matches found, most relevant first.
        """
        if options is None:
            options = Options()
        elif isinstance(options, dict):
            options = Options(**options)

        search_term = search_term.strip().lower()
        sequence_matcher = difflib.SequenceMatcher()
        sequence_matcher.set_seq1(search_term)

        # FIXME: cache can give different results due to the filtering mechanism (if a
        #   cached search term matches the beginning of the current search term, its
        #   results are re-used and filtered, and sometimes items can be missing in
        #   comparison to the results without caching). This may need to be fixed.
        cache: JidSearchCache = client._jid_search_cache

        # Look for a match in the cache
        for cache_key in cache:
            is_exact_key = search_term == cache_key
            if not is_exact_key:
                # An empty key is a prefix of every term: using it would make all
                # later searches a filtered view of the empty-term results, so it
                # is only ever used for an exact match.
                if not cache_key or not search_term.startswith(cache_key):
                    continue

            cached_data = cache[cache_key]
            if cached_data.options != options:
                log.debug("Ignoring cached data due to incompatible options.")
                continue

            log.debug(
                f"Match found in cache for {search_term!r} in [{client.profile}]."
            )
            if is_exact_key:
                # If an exact match is found, return the results as is
                log.debug("Exact match found in cache, reusing results.")
                matches = cached_data.search_items
            else:
                # If only the beginning matches, filter the cache results
                log.debug("Prefix match found in cache, filtering results.")
                matches = SearchItems([])
                for jid_search_item in cached_data.search_items:
                    # process_matching overwrites "exact_match" and "relevance":
                    # work on a copy so the cached entry keeps the scores computed
                    # for its own search term.
                    self.process_matching(
                        search_term,
                        sequence_matcher,
                        matches,
                        jid_search_item.model_copy(),
                    )
            # Safe while iterating: the loop is left right after the move.
            cache.move_to_end(cache_key)
            break
        else:
            # If no match is found in the cache, perform a new search
            matches = await self.perform_search(
                client, search_term, options, sequence_matcher
            )
            cache[search_term] = CachedSearch(matches, options)
            # The key may already exist (entry stored with other options); make
            # sure the refreshed entry is the most recent one.
            cache.move_to_end(search_term)
            if len(cache) > MAX_CACHE_SIZE:
                cache.popitem(last=False)

        # If no exact match is found, but the search term is a valid JID, we add the
        # JID as a result
        exact_match = any(m.exact_match for m in matches)
        if not exact_match and "@" in search_term:
            try:
                search_jid = jid.JID(search_term)
            except jid.InvalidFormat:
                pass
            else:
                matches.append(
                    EntitySearchItem(
                        entity=search_jid,
                        in_roster=False,
                        exact_match=True,
                        relevance=1,
                    )
                )

        # Exact matches first, then by relevance, then roster entities.
        matches.sort(
            key=lambda item: (
                item.exact_match,
                item.relevance or 0,
                getattr(item, "in_roster", False),
            ),
            reverse=True,
        )

        return matches

    def process_matching(
        self,
        search_term: str,
        sequence_matcher: difflib.SequenceMatcher,
        matches: SearchItems,
        item: SearchItem,
    ) -> None:
        """Process the matching of an item against a search term.

        This method checks if the given item is an exact match or if it has any
        significant similarity to the search term. If a match is found, the item's
        relevance score is set and the item is added to the matches list.

        The checks are tried in order and the first successful one wins: exact
        match, substring match (terms of 3 characters or more), then fuzzy ratio
        against the name, the bare JID, the local part, the groups and the domain.

        Note that ``item.exact_match`` and ``item.relevance`` are modified in place.

        @param search_term: The search term, already stripped and lower-cased.
        @param sequence_matcher: The sequence matcher used for comparing strings;
            its first sequence must already be set to the search term.
        @param matches: A list where matched items will be appended.
        @param item: The item to be compared against the search term.
        """
        item_name_lower = item.name.strip().lower()
        item_entity_lower = item.entity.userhost().lower()

        if search_term in (item_name_lower, item_entity_lower):
            item.exact_match = True
            item.relevance = 1
            matches.append(item)
            return

        item.exact_match = False

        # Check if search_term is a substring of item_name_lower or item_entity_lower
        if len(search_term) >= 3:
            if item_name_lower and search_term in item_name_lower:
                item.relevance = PARTIAL_MATCH_RATIO
                matches.append(item)
                return

            if search_term in item_entity_lower:
                item.relevance = PARTIAL_MATCH_RATIO
                matches.append(item)
                return

        if item_name_lower:
            sequence_matcher.set_seq2(item_name_lower)
            name_ratio = sequence_matcher.ratio()
            if name_ratio >= RATIO_CUTOFF:
                item.relevance = name_ratio
                matches.append(item)
                return

        sequence_matcher.set_seq2(item_entity_lower)
        jid_ratio = sequence_matcher.ratio()
        if jid_ratio >= RATIO_CUTOFF:
            item.relevance = jid_ratio
            matches.append(item)
            return

        localpart = item.entity.user.lower() if item.entity.user else ""
        if localpart:
            sequence_matcher.set_seq2(localpart)
            localpart_ratio = sequence_matcher.ratio()
            if localpart_ratio >= RATIO_CUTOFF:
                item.relevance = localpart_ratio
                matches.append(item)
                return

        if isinstance(item, EntitySearchItem) and item.groups:
            group_ratios = []
            for group in item.groups:
                sequence_matcher.set_seq2(group.lower())
                group_ratios.append(sequence_matcher.ratio())
            # Only the best matching group is considered.
            group_ratio = max(group_ratios)
            if group_ratio >= RATIO_CUTOFF:
                item.relevance = group_ratio
                matches.append(item)
                return

        domain = item.entity.host.lower()
        sequence_matcher.set_seq2(domain)
        domain_ratio = sequence_matcher.ratio()
        if domain_ratio >= RATIO_CUTOFF:
            item.relevance = domain_ratio
            matches.append(item)
            return

    async def perform_search(
        self,
        client: SatXMPPEntity,
        search_term: str,
        options: Options,
        sequence_matcher: difflib.SequenceMatcher,
    ) -> SearchItems:
        """Performs a new search.

        Cache is not used here. Roster items are checked when ``options.entities``
        is set, then the ``JID_SEARCH_perform_search`` trigger point is run with the
        current matches.

        @param client: The client whose roster is searched.
        @param search_term: The query to be searched.
        @param options: Search options.
        @param sequence_matcher: The SequenceMatcher object to be used for matching.
        @return: A list of matches found.
        """
        matches = SearchItems([])

        if options.entities:
            # NOTE(review): this assert rejects any client that is not a
            #   SatXMPPClient, which seems to make the "components have no roster"
            #   fallback below unreachable — confirm the intended behaviour.
            assert isinstance(client, SatXMPPClient)
            try:
                client.roster
            except AttributeError:
                # components have no roster
                roster = []
            else:
                roster = client.roster.get_items()

            for roster_item in roster:
                jid_search_item = EntitySearchItem(
                    entity=roster_item.entity,
                    name=roster_item.name,
                    in_roster=True,
                    groups=list(roster_item.groups),
                )

                self.process_matching(
                    search_term, sequence_matcher, matches, jid_search_item
                )

        await self.host.trigger.async_point(
            "JID_SEARCH_perform_search",
            client,
            search_term,
            options,
            sequence_matcher,
            matches,
        )

        return matches