comparison libervia/backend/plugins/plugin_misc_jid_search.py @ 4358:c8d089b0e478

plugin JID Search: Use Pydantic models + better handling of partial search.
author Goffi <goffi@goffi.org>
date Fri, 11 Apr 2025 18:19:28 +0200
parents 0d7bb4df2343
children
comparison
equal deleted inserted replaced
4357:f43cbceba2a0 4358:c8d089b0e478
15 15
16 # You should have received a copy of the GNU Affero General Public License 16 # You should have received a copy of the GNU Affero General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>. 17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 18
19 from collections import OrderedDict 19 from collections import OrderedDict
20 from dataclasses import dataclass, asdict
21 import difflib 20 import difflib
22 from typing import List, Optional 21 from typing import Annotated, Iterator, Literal, NamedTuple
23 22
23 from pydantic import BaseModel, Field, RootModel
24 from twisted.internet import defer 24 from twisted.internet import defer
25 from twisted.words.protocols.jabber import jid 25 from twisted.words.protocols.jabber import jid
26 26
27 from libervia.backend.core.constants import Const as C 27 from libervia.backend.core.constants import Const as C
28 from libervia.backend.core.core_types import SatXMPPEntity 28 from libervia.backend.core.core_types import SatXMPPClient, SatXMPPEntity
29 from libervia.backend.core.i18n import _ 29 from libervia.backend.core.i18n import _
30 from libervia.backend.core.log import getLogger 30 from libervia.backend.core.log import getLogger
31 from libervia.backend.tools.common import data_format 31 from libervia.backend.models.types import JIDType
32 32
33 log = getLogger(__name__) 33 log = getLogger(__name__)
34 34
35 35
36 PLUGIN_INFO = { 36 PLUGIN_INFO = {
44 C.PI_MAIN: "JidSearch", 44 C.PI_MAIN: "JidSearch",
45 C.PI_HANDLER: "no", 45 C.PI_HANDLER: "no",
46 C.PI_DESCRIPTION: _("""Search for XMPP entities"""), 46 C.PI_DESCRIPTION: _("""Search for XMPP entities"""),
47 } 47 }
48 RATIO_CUTOFF = 0.6 48 RATIO_CUTOFF = 0.6
49 # Relevance assigned when the search term is a substring of the item's name or JID.
50 PARTIAL_MATCH_RATIO = 0.8
49 MAX_CACHE_SIZE = 10 51 MAX_CACHE_SIZE = 10
50 52
51 53
class JidSearchItem(BaseModel):
    """Fields shared by every kind of search result.

    ``exact_match`` and ``relevance`` start unset; they are filled in by
    ``JidSearch.process_matching`` when an item is compared to a search term.
    """

    # JID identifying the result.
    entity: JIDType
    # Display name; an empty string when none was provided.
    name: str = ""
    # True when the search term equals the item's name or JID.
    exact_match: bool = False
    # Matching score of the item; None until a match has been computed.
    relevance: float | None = None
59
60
class EntitySearchItem(JidSearchItem):
    """Search result for a single entity.

    Roster based results are created with ``in_roster=True`` and the groups of
    the roster item.
    """

    # Constant tag used as discriminator of the ``SearchItem`` union.
    type: Literal["entity"] = "entity"
    # Whether the entity was found in the client's roster.
    in_roster: bool = False
    # Roster groups of the entity, when known.
    groups: list[str] | None = None
58 exact_match: bool = False 65
59 relevance: float | None = None 66
class RoomSearchItem(JidSearchItem):
    """Search result describing a group chat room.

    Apart from ``local``, which is required, every room specific field is
    optional and defaults to ``None``.
    """

    # Constant tag used as discriminator of the ``SearchItem`` union.
    type: Literal["room"] = "room"
    # Required field: no default is given.
    local: bool = Field(
        description="True if the room comes from a server local component."
    )
    service_type: str | None = None
    description: str | None = None
    language: str | None = None
    # Must be greater than or equal to zero when provided.
    # NOTE(review): presumably the number of occupants; confirm with the producer.
    nusers: int | None = Field(default=None, ge=0)
    anonymity_mode: str | None = None
    is_open: bool | None = None
78
79
# Tagged union of the concrete result models: the ``type`` field tells which
# class an item must be validated as.
SearchItem = Annotated[
    EntitySearchItem | RoomSearchItem,
    Field(discriminator="type"),
]
81
82
class SearchItems(RootModel):
    """Root model wrapping a list of search results.

    The list operations needed by the plugin (iteration, indexing, ``len``,
    ``append`` and in-place ``sort``) are forwarded to the wrapped ``root``
    list, so that an instance can be handled like a list of results while
    still being dumped as a whole by Pydantic.
    """

    root: list[SearchItem]

    # NOTE(review): the ``type: ignore`` presumably silences the incompatible
    # override of the ``__iter__`` inherited from the Pydantic base class.
    def __iter__(self) -> Iterator[SearchItem]:  # type: ignore
        """Iterate over the wrapped results."""
        return iter(self.root)

    def __getitem__(self, item) -> SearchItem:
        """Return the result stored at index ``item``.

        The return annotation used to be ``str`` while elements of ``root``
        are ``SearchItem`` instances.
        """
        return self.root[item]

    def __len__(self) -> int:
        """Return the number of wrapped results."""
        return len(self.root)

    def append(self, item: SearchItem) -> None:
        """Add ``item`` at the end of the wrapped list."""
        self.root.append(item)

    def sort(self, key=None, reverse=False) -> None:
        """Sort the wrapped list in place, with the semantics of ``list.sort``."""
        self.root.sort(key=key, reverse=reverse)  # type: ignore
100
101
class Options(BaseModel):
    """Flags selecting what a search covers.

    Every flag defaults to ``False``.
    """

    entities: bool = Field(
        default=False,
        description="Search for entities for direct chat",
    )
    groupchat: bool = Field(
        default=False,
        description="Search for group chats.",
    )
    allow_external: bool = Field(
        default=False,
        description="Authorise doing request to external services.",
    )
112
113
class CachedSearch(NamedTuple):
    """Cache entry: the results of a search and the options used to get them.

    The options are stored so that a cached entry is only reused by a search
    done with equal options.
    """

    # Results of the cached search.
    search_items: SearchItems
    # Options the results were computed with.
    options: Options
117
118
# Cache stored on each client: maps a search term to its ``CachedSearch``.
# Being ordered, the oldest entry can be dropped once ``MAX_CACHE_SIZE`` is
# exceeded.
JidSearchCache = OrderedDict[str, CachedSearch]
63 120
64 121
65 class JidSearch: 122 class JidSearch:
66 def __init__(self, host) -> None: 123 def __init__(self, host) -> None:
67 log.info(f"plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization") 124 log.info(f"plugin {PLUGIN_INFO[C.PI_NAME]!r} initialization")
def profile_connecting(self, client: SatXMPPEntity) -> None:
    """Give the connecting client a new, empty search cache.

    @param client: The client which is connecting.
    """
    client._jid_search_cache = JidSearchCache()
80 137
def _search(self, search_term: str, options_s: str, profile: str) -> defer.Deferred:
    """String based front end of :meth:`search`.

    @param search_term: The query to be searched.
    @param options_s: JSON serialisation of an ``Options`` model.
    @param profile: Profile used to retrieve the client.
    @return: A Deferred firing with the JSON dump of the search results.
    """
    client = self.host.get_client(profile)
    options = Options.model_validate_json(options_s)

    def dump_results(search_items: SearchItems) -> str:
        # Results go back to the caller as a JSON string.
        return search_items.model_dump_json()

    search_d = defer.ensureDeferred(self.search(client, search_term, options))
    search_d.addCallback(dump_results)
    return search_d
90 144
91 async def search( 145 async def search(
92 self, client: SatXMPPEntity, search_term: str, options: Optional[dict] = None 146 self,
93 ) -> List[JidSearchItem]: 147 client: SatXMPPEntity,
148 search_term: str,
149 options: Options | dict | None = None,
150 ) -> SearchItems:
94 """Searches for entities in various locations. 151 """Searches for entities in various locations.
95 152
96 @param client: The SatXMPPEntity client where the search is to be performed. 153 @param client: The SatXMPPEntity client where the search is to be performed.
97 @param search_term: The query to be searched. 154 @param search_term: The query to be searched.
98 @param options: Additional search options. 155 @param options: Additional search options.
99 @return: A list of matches found. 156 @return: A list of matches found.
100 """ 157 """
158 if options is None:
159 options = Options()
160 elif isinstance(options, dict):
161 options = Options(**options)
101 search_term = search_term.strip().lower() 162 search_term = search_term.strip().lower()
102 sequence_matcher = difflib.SequenceMatcher() 163 sequence_matcher = difflib.SequenceMatcher()
103 sequence_matcher.set_seq1(search_term) 164 sequence_matcher.set_seq1(search_term)
104 # FIXME: cache can give different results due to the filtering mechanism (if a 165 # FIXME: cache can give different results due to the filtering mechanism (if a
105 # cached search term match the beginning of current search term, its results a 166 # cached search term match the beginning of current search term, its results a
108 cache: JidSearchCache = client._jid_search_cache 169 cache: JidSearchCache = client._jid_search_cache
109 170
110 # Look for a match in the cache 171 # Look for a match in the cache
111 for cache_key in cache: 172 for cache_key in cache:
112 if search_term.startswith(cache_key): 173 if search_term.startswith(cache_key):
174 cached_data = cache[cache_key]
175 if cached_data.options != options:
176 log.debug("Ignoring cached data due to incompatible options.")
177 continue
113 log.debug( 178 log.debug(
114 f"Match found in cache for {search_term!r} in [{client.profile}]." 179 f"Match found in cache for {search_term!r} in [{client.profile}]."
115 ) 180 )
116 # If an exact match is found, return the results as is 181 # If an exact match is found, return the results as is
117 if search_term == cache_key: 182 if search_term == cache_key:
118 log.debug("Exact match found in cache, reusing results.") 183 log.debug("Exact match found in cache, reusing results.")
119 matches = cache[cache_key] 184 matches = cached_data.search_items
120 else: 185 else:
121 # If only the beginning matches, filter the cache results 186 # If only the beginning matches, filter the cache results
122 log.debug("Prefix match found in cache, filtering results.") 187 log.debug("Prefix match found in cache, filtering results.")
123 matches = [] 188 matches = SearchItems([])
124 for jid_search_item in cache[cache_key]: 189 for jid_search_item in cached_data.search_items:
125 self._process_matching( 190 self.process_matching(
126 search_term, sequence_matcher, matches, jid_search_item 191 search_term, sequence_matcher, matches, jid_search_item
127 ) 192 )
128 cache.move_to_end(cache_key) 193 cache.move_to_end(cache_key)
129 break 194 break
130 else: 195 else:
131 # If no match is found in the cache, perform a new search 196 # If no match is found in the cache, perform a new search
132 matches = await self._perform_search(client, search_term, sequence_matcher) 197 matches = await self.perform_search(
133 cache[search_term] = matches 198 client, search_term, options, sequence_matcher
199 )
200 cache[search_term] = CachedSearch(matches, options)
134 if len(cache) > MAX_CACHE_SIZE: 201 if len(cache) > MAX_CACHE_SIZE:
135 cache.popitem(last=False) 202 cache.popitem(last=False)
136 203
137 # If no exact match is found, but the search term is a valid JID, we add the JID 204 # If no exact match is found, but the search term is a valid JID, we add the JID
138 # as a result 205 # as a result
142 search_jid = jid.JID(search_term) 209 search_jid = jid.JID(search_term)
143 except jid.InvalidFormat: 210 except jid.InvalidFormat:
144 pass 211 pass
145 else: 212 else:
146 matches.append( 213 matches.append(
147 JidSearchItem( 214 EntitySearchItem(
148 entity=search_jid, 215 entity=search_jid,
149 in_roster=False, 216 in_roster=False,
150 exact_match=True, 217 exact_match=True,
151 relevance=1, 218 relevance=1,
152 ) 219 )
153 ) 220 )
154 221
155 matches.sort( 222 matches.sort(
156 key=lambda item: (item.exact_match, item.relevance or 0, item.in_roster), 223 key=lambda item: (
224 item.exact_match, item.relevance or 0, getattr(item, "in_roster", False)
225 ),
157 reverse=True, 226 reverse=True,
158 ) 227 )
159 228
160 return matches 229 return matches
161 230
162 def _process_matching( 231 def process_matching(
163 self, 232 self,
164 search_term: str, 233 search_term: str,
165 sequence_matcher: difflib.SequenceMatcher, 234 sequence_matcher: difflib.SequenceMatcher,
166 matches: List[JidSearchItem], 235 matches: SearchItems,
167 item: JidSearchItem, 236 item: SearchItem,
168 ) -> None: 237 ) -> None:
169 """Process matching of items 238 """Process the matching of an item against a search term.
170 239
171 @param sequence_matcher: The sequence matcher to be used for the matching process. 240 This method checks if the given item is an exact match or if it has any
172 @param matches: A list where the match is to be appended. 241 significant similarity to the search term. If a match is found, the item's
173 @param item: The item that to be matched. 242 relevance score is set and the item is added to the matches list.
174 @return: True if it was an exact match 243
244 @param sequence_matcher: The sequence matcher used for comparing strings.
245 @param matches: A list where matched items will be appended.
246 @param item: The item to be compared against the search term.
175 """ 247 """
176 248
177 item_name_lower = item.name.lower() 249 item_name_lower = item.name.strip().lower()
178 item_entity_lower = item.entity.full().lower() 250 item_entity_lower = item.entity.userhost().lower()
179 251
180 if search_term in (item_name_lower, item_entity_lower): 252 if search_term in (item_name_lower, item_entity_lower):
181 item.exact_match = True 253 item.exact_match = True
182 item.relevance = 1 254 item.relevance = 1
183 matches.append(item) 255 matches.append(item)
184 return 256 return
185 257
186 item.exact_match = False 258 item.exact_match = False
187 259
188 sequence_matcher.set_seq2(item_name_lower) 260 # Check if search_term is a substring of item_name_lower or item_entity_lower
189 name_ratio = sequence_matcher.ratio() 261 if len(search_term) >= 3:
190 if name_ratio >= RATIO_CUTOFF: 262 if item_name_lower and search_term in item_name_lower:
191 item.relevance = name_ratio 263 item.relevance = PARTIAL_MATCH_RATIO
192 matches.append(item) 264 matches.append(item)
193 return 265 return
266
267 if search_term in item_entity_lower:
268 item.relevance = PARTIAL_MATCH_RATIO
269 matches.append(item)
270 return
271
272 if item_name_lower:
273 sequence_matcher.set_seq2(item_name_lower)
274 name_ratio = sequence_matcher.ratio()
275 if name_ratio >= RATIO_CUTOFF:
276 item.relevance = name_ratio
277 matches.append(item)
278 return
194 279
195 sequence_matcher.set_seq2(item_entity_lower) 280 sequence_matcher.set_seq2(item_entity_lower)
196 jid_ratio = sequence_matcher.ratio() 281 jid_ratio = sequence_matcher.ratio()
197 if jid_ratio >= RATIO_CUTOFF: 282 if jid_ratio >= RATIO_CUTOFF:
198 item.relevance = jid_ratio 283 item.relevance = jid_ratio
206 if domain_ratio >= RATIO_CUTOFF: 291 if domain_ratio >= RATIO_CUTOFF:
207 item.relevance = domain_ratio 292 item.relevance = domain_ratio
208 matches.append(item) 293 matches.append(item)
209 return 294 return
210 295
211 if item.groups: 296 if isinstance(item, EntitySearchItem) and item.groups:
212 group_ratios = [] 297 group_ratios = []
213 for group in item.groups: 298 for group in item.groups:
214 sequence_matcher.set_seq2(group.lower()) 299 sequence_matcher.set_seq2(group.lower())
215 group_ratios.append(sequence_matcher.ratio()) 300 group_ratios.append(sequence_matcher.ratio())
216 group_ratio = max(group_ratios) 301 group_ratio = max(group_ratios)
225 if domain_ratio >= RATIO_CUTOFF: 310 if domain_ratio >= RATIO_CUTOFF:
226 item.relevance = domain_ratio 311 item.relevance = domain_ratio
227 matches.append(item) 312 matches.append(item)
228 return 313 return
229 314
async def perform_search(
    self,
    client: SatXMPPEntity,
    search_term: str,
    options: Options,
    sequence_matcher: difflib.SequenceMatcher,
) -> SearchItems:
    """Performs a new search.

    Cache is not used here.

    When ``options.entities`` is set, every item of the client's roster is run
    through :meth:`process_matching`. The ``JID_SEARCH_perform_search`` trigger
    point is then awaited with the current matches.
    NOTE(review): the trigger presumably lets other plugins add their own
    results to ``matches``; confirm against the plugins using it.

    @param client: The client for which the search is performed.
    @param search_term: The query to be searched.
    @param options: Options selecting what must be searched.
    @param sequence_matcher: The SequenceMatcher object to be used for matching.
    @return: A list of matches found.
    """
    matches = SearchItems([])

    if options.entities:
        # Components have no roster. The attribute is looked up defensively
        # instead of asserting the client class: the assertion made the
        # fallback unreachable (a component raised AssertionError first) and
        # assertions are stripped when Python runs with -O.
        roster = getattr(client, "roster", None)
        roster_items = [] if roster is None else roster.get_items()

        for roster_item in roster_items:
            jid_search_item = EntitySearchItem(
                entity=roster_item.entity,
                # ``name`` is validated as ``str``: fall back to the model
                # default if the roster item carries no name.
                name=roster_item.name or "",
                in_roster=True,
                groups=list(roster_item.groups),
            )

            self.process_matching(
                search_term, sequence_matcher, matches, jid_search_item
            )

    await self.host.trigger.async_point(
        "JID_SEARCH_perform_search",
        client,
        search_term,
        options,
        sequence_matcher,
        matches,
    )

    return matches