view libervia/backend/plugins/plugin_misc_jid_search.py @ 4370:0eaa50f21efb

plugin XEP-0461: Message Replies implementation: Implement message replies. Thread IDs are always added when a reply is initiated from Libervia, so that a thread can continue the reply. rel 457
author Goffi <goffi@goffi.org>
date Tue, 06 May 2025 00:34:01 +0200
parents c8d089b0e478
children
line wrap: on
line source

#!/usr/bin/env python3

# Libervia plugin to handle XMPP entities search
# Copyright (C) 2009-2023 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from collections import OrderedDict
import difflib
from typing import Annotated, Iterator, Literal, NamedTuple

from pydantic import BaseModel, Field, RootModel
from twisted.internet import defer
from twisted.words.protocols.jabber import jid

from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import SatXMPPClient, SatXMPPEntity
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.models.types import JIDType

log = getLogger(__name__)


# Plugin metadata, keyed by the ``C.PI_*`` constants.
PLUGIN_INFO = {
    C.PI_NAME: "JID Search",
    C.PI_IMPORT_NAME: "JID_SEARCH",
    C.PI_TYPE: C.PLUG_TYPE_MISC,
    C.PI_MODES: C.PLUG_MODE_BOTH,
    C.PI_PROTOCOLS: [],
    C.PI_DEPENDENCIES: [],
    C.PI_RECOMMENDATIONS: [],
    C.PI_MAIN: "JidSearch",
    C.PI_HANDLER: "no",
    C.PI_DESCRIPTION: _("""Search for XMPP entities"""),
}
# Minimum ``difflib.SequenceMatcher`` ratio for a fuzzy comparison to count as a
# match.
RATIO_CUTOFF = 0.6
# Relevance given to an item when the search term (3 characters or more) is a
# substring of its name or of its bare JID.
PARTIAL_MATCH_RATIO = 0.8
# Maximum number of search terms kept in a profile's search cache; the least
# recently used entry is evicted once this is exceeded.
MAX_CACHE_SIZE = 10


class JidSearchItem(BaseModel):
    """Base model for a single search result."""

    # JID of the matched entity.
    entity: JIDType
    # Name compared against the search term; empty string when there is none.
    name: str = ""
    # True when the search term equals the item's name or its bare JID.
    exact_match: bool = False
    # Match score (1 for an exact match); None while the item has not been scored.
    relevance: float | None = None


class EntitySearchItem(JidSearchItem):
    """Search result for an entity; its ``type`` is always ``"entity"``."""

    # Discriminator used to tell result models apart.
    type: Literal["entity"] = "entity"
    # True when the item was built from a roster entry.
    in_roster: bool = False
    # Roster groups of the entity; they are also compared to the search term.
    groups: list[str] | None = None


class RoomSearchItem(JidSearchItem):
    """Search result for a room; its ``type`` is always ``"room"``.

    Apart from ``local``, every field is optional room metadata. None of them is
    filled in by the code visible in this module.
    """

    # Discriminator used to tell result models apart.
    type: Literal["room"] = "room"
    # Required: there is no default value.
    local: bool = Field(
        description="True if the room comes from a server local component."
    )
    service_type: str | None = None
    description: str | None = None
    language: str | None = None
    # Validated as a non-negative integer when set.
    nusers: int | None = Field(default=None, ge=0)
    anonymity_mode: str | None = None
    is_open: bool | None = None


# Any concrete result model; pydantic selects the class from the ``type`` field.
SearchItem = Annotated[EntitySearchItem | RoomSearchItem, Field(discriminator="type")]


class SearchItems(RootModel):
    """List of search results exposing a small list-like interface.

    All methods simply delegate to the underlying ``root`` list.
    """

    root: list[SearchItem]

    def __iter__(self) -> Iterator[SearchItem]:  # type: ignore
        """Iterate over the contained search items."""
        return iter(self.root)

    def __getitem__(self, item) -> SearchItem:
        """Return the element(s) of ``root`` selected by ``item``."""
        return self.root[item]

    def __len__(self) -> int:
        """Return the number of contained search items."""
        return len(self.root)

    def append(self, item: SearchItem) -> None:
        """Add ``item`` at the end of the list."""
        self.root.append(item)

    def sort(self, key=None, reverse=False) -> None:
        """Sort the list in place, with the same arguments as ``list.sort``."""
        self.root.sort(key=key, reverse=reverse)  # type: ignore


class Options(BaseModel):
    """Search options; every flag is off by default.

    Two searches may only share cached results when their options are equal.
    """

    entities: bool = Field(
        default=False, description="Search for entities for direct chat"
    )
    groupchat: bool = Field(
        default=False, description="Search for group chats."
    )
    allow_external: bool = Field(
        default=False, description="Authorise doing request to external services."
    )


class CachedSearch(NamedTuple):
    """Cache entry: the results of a search and the options used to get them."""

    search_items: SearchItems
    options: Options


# Search cache keyed by normalised search term. The ordering tracks usage, so
# that the least recently used entry can be evicted first.
JidSearchCache = OrderedDict[str, CachedSearch]


class JidSearch:
    def __init__(self, host) -> None:
        """Keep a reference to the host and expose ``jid_search`` on its bridge.

        @param host: Backend host; its ``bridge`` is used to register the method.
        """
        self.host = host
        plugin_name = PLUGIN_INFO[C.PI_NAME]
        log.info(f"plugin {plugin_name!r} initialization")
        # Bridge signature: search term, JSON options and profile in, JSON out.
        bridge_kwargs = dict(
            in_sign="sss",
            out_sign="s",
            method=self._search,
            async_=True,
        )
        host.bridge.add_method("jid_search", ".plugin", **bridge_kwargs)

    def profile_connecting(self, client: SatXMPPEntity) -> None:
        """Give the connecting profile a fresh, empty search cache.

        @param client: Client whose ``_jid_search_cache`` attribute is (re)set.
        """
        empty_cache = JidSearchCache()
        client._jid_search_cache = empty_cache

    def _search(self, search_term: str, options_s: str, profile: str) -> defer.Deferred:
        """Bridge entry point for ``search``.

        @param search_term: The query to be searched.
        @param options_s: Search options, as a JSON serialised ``Options``.
        @param profile: Profile used to retrieve the client.
        @return: Deferred firing with the found items serialised to JSON.
        """

        def serialise(search_items: SearchItems) -> str:
            return search_items.model_dump_json()

        client = self.host.get_client(profile)
        options = Options.model_validate_json(options_s)
        search_coro = self.search(client, search_term, options)
        return defer.ensureDeferred(search_coro).addCallback(serialise)

    async def search(
        self,
        client: SatXMPPEntity,
        search_term: str,
        options: Options | dict | None = None,
    ) -> SearchItems:
        """Searches for entities in various locations.

        The search term is stripped and lowercased before use. Results of previous
        searches are re-used when a cached term is a prefix of the current one and
        was searched with the same options. The returned object is never the one
        kept in cache, and cached items are copied before being re-scored, so
        neither this method nor its callers alter cached results.

        If nothing matches exactly and the search term is itself a valid JID
        containing "@", this JID is added to the results as an exact match.

        @param client: The SatXMPPEntity client where the search is to be performed.
        @param search_term: The query to be searched.
        @param options: Additional search options. May be an ``Options`` instance, a
            dict of ``Options`` fields, or None to use the defaults.
        @return: A list of matches found, best matches first (exact matches, then
            by decreasing relevance, roster entities first on equal relevance).
        """
        if options is None:
            options = Options()
        elif isinstance(options, dict):
            options = Options(**options)
        search_term = search_term.strip().lower()
        sequence_matcher = difflib.SequenceMatcher()
        sequence_matcher.set_seq1(search_term)
        # FIXME: cache can give different results due to the filtering mechanism (if a
        #   cached search term matches the beginning of the current search term, its
        #   results are re-used and filtered, and sometimes items can be missing in
        #   comparison to the results without caching). This may need to be fixed.
        cache: JidSearchCache = client._jid_search_cache

        # Look for a match in the cache
        for cache_key in cache:
            if not search_term.startswith(cache_key):
                continue
            cached_data = cache[cache_key]
            if cached_data.options != options:
                log.debug("Ignoring cached data due to incompatible options.")
                continue
            log.debug(
                f"Match found in cache for {search_term!r} in [{client.profile}]."
            )
            matches = SearchItems([])
            if search_term == cache_key:
                # Same term: scores are still valid. Items are put in a new list
                # because the result is modified below (extra JID item, sorting).
                log.debug("Exact match found in cache, reusing results.")
                matches.root.extend(cached_data.search_items)
            else:
                # If only the beginning matches, filter the cache results.
                # process_matching() overwrites the score of the item it receives,
                # so it is given a copy: the cached item keeps the score computed
                # for its own search term.
                log.debug("Prefix match found in cache, filtering results.")
                for jid_search_item in cached_data.search_items:
                    self.process_matching(
                        search_term,
                        sequence_matcher,
                        matches,
                        jid_search_item.model_copy(),
                    )
            # Mutating the cache while iterating is safe only because we leave the
            # loop right away.
            cache.move_to_end(cache_key)
            break
        else:
            # If no match is found in the cache, perform a new search
            matches = await self.perform_search(
                client, search_term, options, sequence_matcher
            )
            # An empty term is not cached: every string starts with "", so such an
            # entry would be picked for all following searches.
            if search_term:
                # Keep a snapshot, so the changes done below stay out of the cache.
                snapshot = SearchItems([])
                snapshot.root.extend(matches)
                cache[search_term] = CachedSearch(snapshot, options)
                if len(cache) > MAX_CACHE_SIZE:
                    # Evict the least recently used entry.
                    cache.popitem(last=False)

        # If no exact match is found, but the search term is a valid JID, we add the JID
        # as a result
        exact_match = any(m.exact_match for m in matches)
        if not exact_match and "@" in search_term:
            try:
                search_jid = jid.JID(search_term)
            except jid.InvalidFormat:
                pass
            else:
                matches.append(
                    EntitySearchItem(
                        entity=search_jid,
                        in_roster=False,
                        exact_match=True,
                        relevance=1,
                    )
                )

        matches.sort(
            key=lambda item: (
                item.exact_match, item.relevance or 0, getattr(item, "in_roster", False)
            ),
            reverse=True,
        )

        return matches

    def process_matching(
        self,
        search_term: str,
        sequence_matcher: difflib.SequenceMatcher,
        matches: SearchItems,
        item: SearchItem,
    ) -> None:
        """Score an item against a search term and keep it if it is relevant.

        Checks are done in this order, and the first successful one wins:

        - the term equals the item's name or bare JID (exact match, relevance 1);
        - the term has at least 3 characters and is contained in the name or in the
          bare JID (relevance ``PARTIAL_MATCH_RATIO``);
        - the similarity ratio with the name, the bare JID, the JID localpart, the
          best matching roster group (entity items only) or the JID domain reaches
          ``RATIO_CUTOFF`` (relevance is that ratio).

        ``item`` is modified in place: ``exact_match`` is always set, ``relevance``
        only when the item is kept. Items which don't match are not appended.

        @param search_term: The query, already stripped and lowercased.
        @param sequence_matcher: The sequence matcher used for comparing strings; its
            first sequence must already be the search term.
        @param matches: A list where matched items will be appended.
        @param item: The item to be compared against the search term.
        """
        name_lc = item.name.strip().lower()
        bare_jid_lc = item.entity.userhost().lower()

        def keep(relevance) -> None:
            item.relevance = relevance
            matches.append(item)

        def similarity(text: str) -> float:
            sequence_matcher.set_seq2(text)
            return sequence_matcher.ratio()

        if search_term == name_lc or search_term == bare_jid_lc:
            item.exact_match = True
            keep(1)
            return

        item.exact_match = False

        # Substring matching is skipped for very short terms.
        if len(search_term) >= 3:
            in_name = bool(name_lc) and search_term in name_lc
            if in_name or search_term in bare_jid_lc:
                keep(PARTIAL_MATCH_RATIO)
                return

        def scores() -> Iterator[float]:
            # Lazy on purpose: a candidate is only compared when all the previous
            # ones were below the cutoff.
            if name_lc:
                yield similarity(name_lc)
            yield similarity(bare_jid_lc)
            node = item.entity.user
            if node:
                yield similarity(node.lower())
            if isinstance(item, EntitySearchItem) and item.groups:
                yield max(similarity(group.lower()) for group in item.groups)
            yield similarity(item.entity.host.lower())

        for score in scores():
            if score >= RATIO_CUTOFF:
                keep(score)
                return

    async def perform_search(
        self,
        client: SatXMPPEntity,
        search_term: str,
        options: Options,
        sequence_matcher: difflib.SequenceMatcher,
    ) -> SearchItems:
        """Performs a new search.

        Cache is not used here.

        Roster entries are examined when ``options.entities`` is set. The
        ``JID_SEARCH_perform_search`` trigger point is then always run with the
        matches found so far.

        @param client: Client on behalf of which the search is done.
        @param search_term: The query to be searched.
        @param options: Search options.
        @param sequence_matcher: The SequenceMatcher object to be used for matching.
        @return: A list of matches found.
        """
        matches = SearchItems([])

        if options.entities:
            # Components have no roster. A missing roster is handled as an empty
            # one, instead of asserting on the client type: the assertion failed
            # before the "no roster" case could be handled.
            roster = getattr(client, "roster", None)
            roster_items = [] if roster is None else roster.get_items()

            for roster_item in roster_items:
                jid_search_item = EntitySearchItem(
                    entity=roster_item.entity,
                    name=roster_item.name,
                    in_roster=True,
                    groups=list(roster_item.groups),
                )

                self.process_matching(
                    search_term, sequence_matcher, matches, jid_search_item
                )

        await self.host.trigger.async_point(
            "JID_SEARCH_perform_search",
            client,
            search_term,
            options,
            sequence_matcher,
            matches,
        )

        return matches