Mercurial > libervia-backend
comparison libervia/backend/plugins/plugin_misc_jid_search.py @ 4358:c8d089b0e478
plugin JID Search: Use Pydantic models + better handling of partial search.
author | Goffi <goffi@goffi.org> |
---|---|
date | Fri, 11 Apr 2025 18:19:28 +0200 |
parents | 0d7bb4df2343 |
children |
comparison
equal
deleted
inserted
replaced
4357:f43cbceba2a0 | 4358:c8d089b0e478 |
---|---|
15 | 15 |
16 # You should have received a copy of the GNU Affero General Public License | 16 # You should have received a copy of the GNU Affero General Public License |
17 # along with this program. If not, see <http://www.gnu.org/licenses/>. | 17 # along with this program. If not, see <http://www.gnu.org/licenses/>. |
18 | 18 |
19 from collections import OrderedDict | 19 from collections import OrderedDict |
20 from dataclasses import dataclass, asdict | |
21 import difflib | 20 import difflib |
22 from typing import List, Optional | 21 from typing import Annotated, Iterator, Literal, NamedTuple |
23 | 22 |
23 from pydantic import BaseModel, Field, RootModel | |
24 from twisted.internet import defer | 24 from twisted.internet import defer |
25 from twisted.words.protocols.jabber import jid | 25 from twisted.words.protocols.jabber import jid |
26 | 26 |
27 from libervia.backend.core.constants import Const as C | 27 from libervia.backend.core.constants import Const as C |
28 from libervia.backend.core.core_types import SatXMPPEntity | 28 from libervia.backend.core.core_types import SatXMPPClient, SatXMPPEntity |
29 from libervia.backend.core.i18n import _ | 29 from libervia.backend.core.i18n import _ |
30 from libervia.backend.core.log import getLogger | 30 from libervia.backend.core.log import getLogger |
31 from libervia.backend.tools.common import data_format | 31 from libervia.backend.models.types import JIDType |
32 | 32 |
33 log = getLogger(__name__) | 33 log = getLogger(__name__) |
34 | 34 |
35 | 35 |
36 PLUGIN_INFO = { | 36 PLUGIN_INFO = { |
44 C.PI_MAIN: "JidSearch", | 44 C.PI_MAIN: "JidSearch", |
45 C.PI_HANDLER: "no", | 45 C.PI_HANDLER: "no", |
46 C.PI_DESCRIPTION: _("""Search for XMPP entities"""), | 46 C.PI_DESCRIPTION: _("""Search for XMPP entities"""), |
47 } | 47 } |
48 RATIO_CUTOFF = 0.6 | 48 RATIO_CUTOFF = 0.6 |
49 # Used when a search term matches a substring of relevant data. | |
50 PARTIAL_MATCH_RATIO = 0.8 | |
49 MAX_CACHE_SIZE = 10 | 51 MAX_CACHE_SIZE = 10 |
50 | 52 |
51 | 53 |
52 @dataclass | 54 class JidSearchItem(BaseModel): |
53 class JidSearchItem: | 55 entity: JIDType |
54 entity: jid.JID | |
55 name: str = "" | 56 name: str = "" |
57 exact_match: bool = False | |
58 relevance: float | None = None | |
59 | |
60 | |
61 class EntitySearchItem(JidSearchItem): | |
62 type: Literal["entity"] = "entity" | |
56 in_roster: bool = False | 63 in_roster: bool = False |
57 groups: list[str] | None = None | 64 groups: list[str] | None = None |
58 exact_match: bool = False | 65 |
59 relevance: float | None = None | 66 |
60 | 67 class RoomSearchItem(JidSearchItem): |
61 | 68 type: Literal["room"] = "room" |
62 JidSearchCache = OrderedDict[str, list[JidSearchItem]] | 69 local: bool = Field( |
70 description="True if the room comes from a server local component." | |
71 ) | |
72 service_type: str | None = None | |
73 description: str | None = None | |
74 language: str | None = None | |
75 nusers: int | None = Field(default=None, ge=0) | |
76 anonymity_mode: str | None = None | |
77 is_open: bool | None = None | |
78 | |
79 | |
80 SearchItem = Annotated[EntitySearchItem | RoomSearchItem, Field(discriminator="type")] | |
81 | |
82 | |
83 class SearchItems(RootModel): | |
84 root: list[SearchItem] | |
85 | |
86 def __iter__(self) -> Iterator[SearchItem]: # type: ignore | |
87 return iter(self.root) | |
88 | |
89 def __getitem__(self, item) -> str: | |
90 return self.root[item] | |
91 | |
92 def __len__(self) -> int: | |
93 return len(self.root) | |
94 | |
95 def append(self, item: SearchItem) -> None: | |
96 self.root.append(item) | |
97 | |
98 def sort(self, key=None, reverse=False) -> None: | |
99 self.root.sort(key=key, reverse=reverse) # type: ignore | |
100 | |
101 | |
102 class Options(BaseModel): | |
103 entities: bool = Field( | |
104 default=False, description="Search for entities for direct chat" | |
105 ) | |
106 groupchat: bool = Field( | |
107 default=False, description="Search for group chats." | |
108 ) | |
109 allow_external: bool = Field( | |
110 default=False, description="Authorise doing request to external services." | |
111 ) | |
112 | |
113 | |
114 class CachedSearch(NamedTuple): | |
115 search_items: SearchItems | |
116 options: Options | |
117 | |
118 | |
119 JidSearchCache = OrderedDict[str, CachedSearch] | |
63 | 120 |
64 | 121 |
65 class JidSearch: | 122 class JidSearch: |
66 def __init__(self, host) -> None: | 123 def __init__(self, host) -> None: |
67 log.info(f"plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization") | 124 log.info(f"plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization") |
78 def profile_connecting(self, client: SatXMPPEntity) -> None: | 135 def profile_connecting(self, client: SatXMPPEntity) -> None: |
79 client._jid_search_cache = JidSearchCache() | 136 client._jid_search_cache = JidSearchCache() |
80 | 137 |
81 def _search(self, search_term: str, options_s: str, profile: str) -> defer.Deferred: | 138 def _search(self, search_term: str, options_s: str, profile: str) -> defer.Deferred: |
82 client = self.host.get_client(profile) | 139 client = self.host.get_client(profile) |
83 d = defer.ensureDeferred( | 140 options = Options.model_validate_json(options_s) |
84 self.search(client, search_term, data_format.deserialise(options_s)) | 141 d = defer.ensureDeferred(self.search(client, search_term, options)) |
85 ) | 142 d.addCallback(lambda search_items: search_items.model_dump_json()) |
86 d.addCallback( | |
87 lambda search_items: data_format.serialise([asdict(i) for i in search_items]) | |
88 ) | |
89 return d | 143 return d |
90 | 144 |
91 async def search( | 145 async def search( |
92 self, client: SatXMPPEntity, search_term: str, options: Optional[dict] = None | 146 self, |
93 ) -> List[JidSearchItem]: | 147 client: SatXMPPEntity, |
148 search_term: str, | |
149 options: Options | dict | None = None, | |
150 ) -> SearchItems: | |
94 """Searches for entities in various locations. | 151 """Searches for entities in various locations. |
95 | 152 |
96 @param client: The SatXMPPEntity client where the search is to be performed. | 153 @param client: The SatXMPPEntity client where the search is to be performed. |
97 @param search_term: The query to be searched. | 154 @param search_term: The query to be searched. |
98 @param options: Additional search options. | 155 @param options: Additional search options. |
99 @return: A list of matches found. | 156 @return: A list of matches found. |
100 """ | 157 """ |
158 if options is None: | |
159 options = Options() | |
160 elif isinstance(options, dict): | |
161 options = Options(**options) | |
101 search_term = search_term.strip().lower() | 162 search_term = search_term.strip().lower() |
102 sequence_matcher = difflib.SequenceMatcher() | 163 sequence_matcher = difflib.SequenceMatcher() |
103 sequence_matcher.set_seq1(search_term) | 164 sequence_matcher.set_seq1(search_term) |
104 # FIXME: cache can give different results due to the filtering mechanism (if a | 165 # FIXME: cache can give different results due to the filtering mechanism (if a |
105 # cached search term matches the beginning of the current search term, its results a | 166 # cached search term matches the beginning of the current search term, its results a |
108 cache: JidSearchCache = client._jid_search_cache | 169 cache: JidSearchCache = client._jid_search_cache |
109 | 170 |
110 # Look for a match in the cache | 171 # Look for a match in the cache |
111 for cache_key in cache: | 172 for cache_key in cache: |
112 if search_term.startswith(cache_key): | 173 if search_term.startswith(cache_key): |
174 cached_data = cache[cache_key] | |
175 if cached_data.options != options: | |
176 log.debug("Ignoring cached data due to incompatible options.") | |
177 continue | |
113 log.debug( | 178 log.debug( |
114 f"Match found in cache for {search_term!r} in [{client.profile}]." | 179 f"Match found in cache for {search_term!r} in [{client.profile}]." |
115 ) | 180 ) |
116 # If an exact match is found, return the results as is | 181 # If an exact match is found, return the results as is |
117 if search_term == cache_key: | 182 if search_term == cache_key: |
118 log.debug("Exact match found in cache, reusing results.") | 183 log.debug("Exact match found in cache, reusing results.") |
119 matches = cache[cache_key] | 184 matches = cached_data.search_items |
120 else: | 185 else: |
121 # If only the beginning matches, filter the cache results | 186 # If only the beginning matches, filter the cache results |
122 log.debug("Prefix match found in cache, filtering results.") | 187 log.debug("Prefix match found in cache, filtering results.") |
123 matches = [] | 188 matches = SearchItems([]) |
124 for jid_search_item in cache[cache_key]: | 189 for jid_search_item in cached_data.search_items: |
125 self._process_matching( | 190 self.process_matching( |
126 search_term, sequence_matcher, matches, jid_search_item | 191 search_term, sequence_matcher, matches, jid_search_item |
127 ) | 192 ) |
128 cache.move_to_end(cache_key) | 193 cache.move_to_end(cache_key) |
129 break | 194 break |
130 else: | 195 else: |
131 # If no match is found in the cache, perform a new search | 196 # If no match is found in the cache, perform a new search |
132 matches = await self._perform_search(client, search_term, sequence_matcher) | 197 matches = await self.perform_search( |
133 cache[search_term] = matches | 198 client, search_term, options, sequence_matcher |
199 ) | |
200 cache[search_term] = CachedSearch(matches, options) | |
134 if len(cache) > MAX_CACHE_SIZE: | 201 if len(cache) > MAX_CACHE_SIZE: |
135 cache.popitem(last=False) | 202 cache.popitem(last=False) |
136 | 203 |
137 # If no exact match is found, but the search term is a valid JID, we add the JID | 204 # If no exact match is found, but the search term is a valid JID, we add the JID |
138 # as a result | 205 # as a result |
142 search_jid = jid.JID(search_term) | 209 search_jid = jid.JID(search_term) |
143 except jid.InvalidFormat: | 210 except jid.InvalidFormat: |
144 pass | 211 pass |
145 else: | 212 else: |
146 matches.append( | 213 matches.append( |
147 JidSearchItem( | 214 EntitySearchItem( |
148 entity=search_jid, | 215 entity=search_jid, |
149 in_roster=False, | 216 in_roster=False, |
150 exact_match=True, | 217 exact_match=True, |
151 relevance=1, | 218 relevance=1, |
152 ) | 219 ) |
153 ) | 220 ) |
154 | 221 |
155 matches.sort( | 222 matches.sort( |
156 key=lambda item: (item.exact_match, item.relevance or 0, item.in_roster), | 223 key=lambda item: ( |
224 item.exact_match, item.relevance or 0, getattr(item, "in_roster", False) | |
225 ), | |
157 reverse=True, | 226 reverse=True, |
158 ) | 227 ) |
159 | 228 |
160 return matches | 229 return matches |
161 | 230 |
162 def _process_matching( | 231 def process_matching( |
163 self, | 232 self, |
164 search_term: str, | 233 search_term: str, |
165 sequence_matcher: difflib.SequenceMatcher, | 234 sequence_matcher: difflib.SequenceMatcher, |
166 matches: List[JidSearchItem], | 235 matches: SearchItems, |
167 item: JidSearchItem, | 236 item: SearchItem, |
168 ) -> None: | 237 ) -> None: |
169 """Process matching of items | 238 """Process the matching of an item against a search term. |
170 | 239 |
171 @param sequence_matcher: The sequence matcher to be used for the matching process. | 240 This method checks if the given item is an exact match or if it has any |
172 @param matches: A list where the match is to be appended. | 241 significant similarity to the search term. If a match is found, the item's |
173 @param item: The item that to be matched. | 242 relevance score is set and the item is added to the matches list. |
174 @return: True if it was an exact match | 243 |
244 @param sequence_matcher: The sequence matcher used for comparing strings. | |
245 @param matches: A list where matched items will be appended. | |
246 @param item: The item to be compared against the search term. | |
175 """ | 247 """ |
176 | 248 |
177 item_name_lower = item.name.lower() | 249 item_name_lower = item.name.strip().lower() |
178 item_entity_lower = item.entity.full().lower() | 250 item_entity_lower = item.entity.userhost().lower() |
179 | 251 |
180 if search_term in (item_name_lower, item_entity_lower): | 252 if search_term in (item_name_lower, item_entity_lower): |
181 item.exact_match = True | 253 item.exact_match = True |
182 item.relevance = 1 | 254 item.relevance = 1 |
183 matches.append(item) | 255 matches.append(item) |
184 return | 256 return |
185 | 257 |
186 item.exact_match = False | 258 item.exact_match = False |
187 | 259 |
188 sequence_matcher.set_seq2(item_name_lower) | 260 # Check if search_term is a substring of item_name_lower or item_entity_lower |
189 name_ratio = sequence_matcher.ratio() | 261 if len(search_term) >= 3: |
190 if name_ratio >= RATIO_CUTOFF: | 262 if item_name_lower and search_term in item_name_lower: |
191 item.relevance = name_ratio | 263 item.relevance = PARTIAL_MATCH_RATIO |
192 matches.append(item) | 264 matches.append(item) |
193 return | 265 return |
266 | |
267 if search_term in item_entity_lower: | |
268 item.relevance = PARTIAL_MATCH_RATIO | |
269 matches.append(item) | |
270 return | |
271 | |
272 if item_name_lower: | |
273 sequence_matcher.set_seq2(item_name_lower) | |
274 name_ratio = sequence_matcher.ratio() | |
275 if name_ratio >= RATIO_CUTOFF: | |
276 item.relevance = name_ratio | |
277 matches.append(item) | |
278 return | |
194 | 279 |
195 sequence_matcher.set_seq2(item_entity_lower) | 280 sequence_matcher.set_seq2(item_entity_lower) |
196 jid_ratio = sequence_matcher.ratio() | 281 jid_ratio = sequence_matcher.ratio() |
197 if jid_ratio >= RATIO_CUTOFF: | 282 if jid_ratio >= RATIO_CUTOFF: |
198 item.relevance = jid_ratio | 283 item.relevance = jid_ratio |
206 if domain_ratio >= RATIO_CUTOFF: | 291 if domain_ratio >= RATIO_CUTOFF: |
207 item.relevance = domain_ratio | 292 item.relevance = domain_ratio |
208 matches.append(item) | 293 matches.append(item) |
209 return | 294 return |
210 | 295 |
211 if item.groups: | 296 if isinstance(item, EntitySearchItem) and item.groups: |
212 group_ratios = [] | 297 group_ratios = [] |
213 for group in item.groups: | 298 for group in item.groups: |
214 sequence_matcher.set_seq2(group.lower()) | 299 sequence_matcher.set_seq2(group.lower()) |
215 group_ratios.append(sequence_matcher.ratio()) | 300 group_ratios.append(sequence_matcher.ratio()) |
216 group_ratio = max(group_ratios) | 301 group_ratio = max(group_ratios) |
225 if domain_ratio >= RATIO_CUTOFF: | 310 if domain_ratio >= RATIO_CUTOFF: |
226 item.relevance = domain_ratio | 311 item.relevance = domain_ratio |
227 matches.append(item) | 312 matches.append(item) |
228 return | 313 return |
229 | 314 |
230 async def _perform_search( | 315 async def perform_search( |
231 self, | 316 self, |
232 client: SatXMPPEntity, | 317 client: SatXMPPEntity, |
233 search_term: str, | 318 search_term: str, |
319 options: Options, | |
234 sequence_matcher: difflib.SequenceMatcher, | 320 sequence_matcher: difflib.SequenceMatcher, |
235 ) -> List[JidSearchItem]: | 321 ) -> SearchItems: |
236 """Performs a new search when no match is found in the cache. | 322 """Performs a new search. |
323 | |
324 Cache is not used here. | |
237 | 325 |
238 @param search_term: The query to be searched. | 326 @param search_term: The query to be searched. |
239 @param sequence_matcher: The SequenceMatcher object to be used for matching. | 327 @param sequence_matcher: The SequenceMatcher object to be used for matching. |
240 @return: A list of matches found. | 328 @return: A list of matches found. |
241 """ | 329 """ |
242 matches = [] | 330 matches = SearchItems([]) |
243 | 331 |
244 try: | 332 if options.entities: |
245 roster = client.roster | 333 assert isinstance(client, SatXMPPClient) |
246 except AttributeError: | 334 try: |
247 # components have no roster | 335 client.roster |
248 roster = [] | 336 except AttributeError: |
249 else: | 337 # components have no roster |
250 roster = client.roster.get_items() | 338 roster = [] |
251 | 339 else: |
252 for roster_item in roster: | 340 roster = client.roster.get_items() |
253 jid_search_item = JidSearchItem( | 341 |
254 entity=roster_item.entity, | 342 for roster_item in roster: |
255 name=roster_item.name, | 343 jid_search_item = EntitySearchItem( |
256 in_roster=True, | 344 entity=roster_item.entity, |
257 groups=list(roster_item.groups), | 345 name=roster_item.name, |
258 ) | 346 in_roster=True, |
259 | 347 groups=list(roster_item.groups), |
260 self._process_matching( | 348 ) |
261 search_term, sequence_matcher, matches, jid_search_item | 349 |
262 ) | 350 self.process_matching( |
351 search_term, sequence_matcher, matches, jid_search_item | |
352 ) | |
353 | |
354 await self.host.trigger.async_point( | |
355 "JID_SEARCH_perform_search", | |
356 client, | |
357 search_term, | |
358 options, | |
359 sequence_matcher, | |
360 matches, | |
361 ) | |
263 | 362 |
264 return matches | 363 return matches |