view libervia/backend/plugins/plugin_misc_jid_search.py @ 4246:5eb13251fd75

tests (unit/XEP-0272): XEP-0272 tests: fix 429
author Goffi <goffi@goffi.org>
date Wed, 15 May 2024 17:35:16 +0200
parents 238e305f2306
children 0d7bb4df2343
line wrap: on
line source

#!/usr/bin/env python3

# Libervia plugin to handle XMPP entities search
# Copyright (C) 2009-2023 Jérôme Poisson (goffi@goffi.org)

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from collections import OrderedDict
from dataclasses import dataclass, asdict, replace
import difflib
from typing import List, Optional

from twisted.internet import defer
from twisted.words.protocols.jabber import jid

from libervia.backend.core.constants import Const as C
from libervia.backend.core.core_types import SatXMPPEntity
from libervia.backend.core.i18n import _
from libervia.backend.core.log import getLogger
from libervia.backend.tools.common import data_format

log = getLogger(__name__)


# Plugin metadata (presumably read by the backend plugin loader).
PLUGIN_INFO = {
    C.PI_NAME: "JID Search",
    C.PI_IMPORT_NAME: "JID_SEARCH",
    C.PI_TYPE: C.PLUG_TYPE_MISC,
    C.PI_MODES: C.PLUG_MODE_BOTH,
    C.PI_PROTOCOLS: [],
    C.PI_DEPENDENCIES: [],
    C.PI_RECOMMENDATIONS: [],
    C.PI_MAIN: "JidSearch",
    C.PI_HANDLER: "no",
    C.PI_DESCRIPTION: _("""Search for XMPP entities"""),
}
# Minimum difflib similarity ratio (between 0 and 1) for a fuzzy match to be kept.
RATIO_CUTOFF = 0.6
# Maximum number of search terms kept in each client's search cache (oldest evicted).
MAX_CACHE_SIZE = 10


@dataclass
class JidSearchItem:
    """One entity returned by a JID search.

    Field order is significant: it is used for positional construction and by
    ``asdict`` when results are serialised.
    """

    # the entity which has been found
    entity: jid.JID
    # display name of the entity (e.g. its roster name), empty if unknown
    name: str = ""
    # whether the entity comes from the client's roster
    in_roster: bool = False
    # roster groups of the entity, None when not applicable
    groups: Optional[List[str]] = None
    # whether the search term is exactly the entity's name or JID
    exact_match: bool = False
    # similarity score (1 for exact matches), None when not computed
    relevance: Optional[float] = None


# Search term => results, in least recently used first order.
JidSearchCache = OrderedDict[str, List[JidSearchItem]]


class JidSearch:
    """Plugin searching XMPP entities (currently in the client's roster).

    Matching is fuzzy (based on ``difflib``), and results are cached per client,
    keyed by search term.
    """

    def __init__(self, host) -> None:
        log.info(f"plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization")
        self.host = host
        host.bridge.add_method(
            "jid_search",
            ".plugin",
            in_sign="sss",
            out_sign="s",
            method=self._search,
            async_=True,
        )

    def profile_connecting(self, client: SatXMPPEntity) -> None:
        """Create the search cache of the connecting client."""
        client._jid_search_cache = JidSearchCache()

    def _search(self, search_term: str, options_s: str, profile: str) -> defer.Deferred:
        """Bridge wrapper around ``search``.

        @param search_term: The query to be searched.
        @param options_s: Serialised search options.
        @param profile: Profile of the client doing the search.
        @return: A Deferred firing with the serialised list of matches.
        """
        client = self.host.get_client(profile)
        d = defer.ensureDeferred(
            self.search(client, search_term, data_format.deserialise(options_s))
        )
        d.addCallback(
            lambda search_items: data_format.serialise([asdict(i) for i in search_items])
        )
        return d

    async def search(
        self, client: SatXMPPEntity, search_term: str, options: Optional[dict] = None
    ) -> List[JidSearchItem]:
        """Searches for entities in various locations.

        The search term is stripped and lowercased. If the exact same term is in the
        cache, its results are reused. Otherwise, if a cached term is a prefix of the
        search term, the results of the longest such term are filtered again. In all
        other cases a new search is performed, and its results are cached (unless the
        search term is empty).

        If there is no exact match and the search term contains "@" and is a valid
        JID, this JID is added to the results as an exact match.

        @param client: The SatXMPPEntity client where the search is to be performed.
        @param search_term: The query to be searched.
        @param options: Additional search options (currently unused).
        @return: A list of matches found, sorted in descending order by exact match,
            then relevance, then roster membership.
        """
        search_term = search_term.strip().lower()
        sequence_matcher = difflib.SequenceMatcher()
        sequence_matcher.set_seq1(search_term)
        # FIXME: cache can give different results due to the filtering mechanism (if a
        #   cached search term matches the beginning of the current search term, its
        #   results are re-used and filtered, and sometimes items can be missing in
        #   comparison to the results without caching, as fuzzy matching is not
        #   monotonic with respect to the search term). This may need to be fixed.
        cache: JidSearchCache = client._jid_search_cache

        matches: List[JidSearchItem]
        cached_matches = cache.get(search_term)
        if cached_matches is not None:
            log.debug(
                f"Match found in cache for {search_term!r} in [{client.profile}]."
            )
            log.debug("Exact match found in cache, reusing results.")
            # Work on a copy: the cached list must not be extended or re-sorted below.
            matches = list(cached_matches)
            cache.move_to_end(search_term)
        else:
            prefix_keys = [key for key in cache if search_term.startswith(key)]
            if prefix_keys:
                log.debug(
                    f"Match found in cache for {search_term!r} in [{client.profile}]."
                )
                log.debug("Prefix match found in cache, filtering results.")
                # The longest prefix is the cached search closest to the current one.
                cache_key = max(prefix_keys, key=len)
                matches = []
                for jid_search_item in cache[cache_key]:
                    self._process_matching(
                        search_term, sequence_matcher, matches, jid_search_item
                    )
                cache.move_to_end(cache_key)
            else:
                # If no match is found in the cache, perform a new search.
                matches = await self._perform_search(
                    client, search_term, sequence_matcher
                )
                # An empty term is a prefix of every other term: caching it would make
                # all future searches filter its (nearly empty) results.
                if search_term:
                    cache[search_term] = matches
                    if len(cache) > MAX_CACHE_SIZE:
                        cache.popitem(last=False)
                # Work on a copy: the cached list must not be extended or re-sorted.
                matches = list(matches)

        # If no exact match is found, but the search term is a valid JID, we add the JID
        # as a result
        exact_match = any(m.exact_match for m in matches)
        if not exact_match and "@" in search_term:
            try:
                search_jid = jid.JID(search_term)
            except jid.InvalidFormat:
                pass
            else:
                matches.append(
                    JidSearchItem(
                        entity=search_jid,
                        in_roster=False,
                        exact_match=True,
                        relevance=1,
                    )
                )

        matches.sort(
            key=lambda item: (item.exact_match, item.relevance or 0, item.in_roster),
            reverse=True,
        )

        return matches

    def _process_matching(
        self,
        search_term: str,
        sequence_matcher: difflib.SequenceMatcher,
        matches: List[JidSearchItem],
        item: JidSearchItem,
    ) -> None:
        """Append a copy of ``item`` to ``matches`` if it matches the search term.

        The copy has ``exact_match`` and ``relevance`` set for this search term. The
        given item itself is never modified, because it may be shared with cached
        results of other search terms.

        @param search_term: The stripped and lowercased search term. It must already be
            the first sequence of ``sequence_matcher``.
        @param sequence_matcher: The sequence matcher to be used for the matching process.
        @param matches: A list where the match is to be appended.
        @param item: The item to be matched.
        """
        item_name_lower = (item.name or "").lower()
        item_entity_lower = item.entity.full().lower()

        if search_term in (item_name_lower, item_entity_lower):
            matches.append(replace(item, exact_match=True, relevance=1))
            return

        relevance = self._fuzzy_relevance(
            sequence_matcher, item, item_name_lower, item_entity_lower
        )
        if relevance is not None:
            matches.append(replace(item, exact_match=False, relevance=relevance))

    def _fuzzy_relevance(
        self,
        sequence_matcher: difflib.SequenceMatcher,
        item: JidSearchItem,
        item_name_lower: str,
        item_entity_lower: str,
    ) -> Optional[float]:
        """Return the similarity ratio of the first item field matching the search term.

        Fields are compared in this order: name, full JID, JID localpart, best matching
        group, JID domain. A field matches when its ratio reaches RATIO_CUTOFF.

        @param sequence_matcher: Matcher whose first sequence is the search term.
        @param item: The item to be matched.
        @param item_name_lower: Lowercased name of the item.
        @param item_entity_lower: Lowercased full JID of the item.
        @return: The ratio of the first matching field, or None if no field matches.
        """

        def ratio(text: str) -> float:
            sequence_matcher.set_seq2(text)
            return sequence_matcher.ratio()

        name_ratio = ratio(item_name_lower)
        if name_ratio >= RATIO_CUTOFF:
            return name_ratio

        jid_ratio = ratio(item_entity_lower)
        if jid_ratio >= RATIO_CUTOFF:
            return jid_ratio

        localpart = (item.entity.user or "").lower()
        if localpart:
            localpart_ratio = ratio(localpart)
            if localpart_ratio >= RATIO_CUTOFF:
                return localpart_ratio

        if item.groups:
            group_ratio = max(ratio(group.lower()) for group in item.groups)
            if group_ratio >= RATIO_CUTOFF:
                return group_ratio

        domain_ratio = ratio(item.entity.host.lower())
        if domain_ratio >= RATIO_CUTOFF:
            return domain_ratio

        return None

    async def _perform_search(
        self,
        client: SatXMPPEntity,
        search_term: str,
        sequence_matcher: difflib.SequenceMatcher,
    ) -> List[JidSearchItem]:
        """Performs a new search when no match is found in the cache.

        Only the client's roster is searched for now; a client without roster (e.g. a
        component) gives no results.

        @param client: The client doing the search.
        @param search_term: The query to be searched.
        @param sequence_matcher: The SequenceMatcher object to be used for matching.
        @return: A list of matches found.
        """
        matches: List[JidSearchItem] = []

        try:
            roster = client.roster
        except AttributeError:
            # components have no roster
            roster_items = []
        else:
            roster_items = roster.get_items()

        for roster_item in roster_items:
            jid_search_item = JidSearchItem(
                entity=roster_item.entity,
                # the roster "name" attribute is optional (RFC 6121), so it may be unset
                name=roster_item.name or "",
                in_roster=True,
                groups=list(roster_item.groups),
            )

            self._process_matching(
                search_term, sequence_matcher, matches, jid_search_item
            )

        return matches