# HG changeset patch # User Goffi # Date 1708691464 -3600 # Node ID 5f2d496c633f58ff8ecb2fb802f7ebf79be03de1 # Parent be89ab1cbca4053f6391c86527b6e00c43bc8704 core: get rid of `pickle`: Use of `pickle` to serialise data was a technical legacy that was causing trouble to store in database, to update (if a class was serialised, a change could break update), and to security (pickle can lead to code execution). This patch remove all use of Pickle in favour in JSON, notably: - for caching data, a Pydantic model is now used instead - for SQLAlchemy model, the LegacyPickle is replaced by JSON serialisation - in XEP-0373 a class `PublicKeyMetadata` was serialised. New method `from_dict` and `to_dict` method have been implemented to do serialisation. - new methods to (de)serialise data can now be specified with Identity data types. It is notably used to (de)serialise `path` of avatars. A migration script has been created to convert data (for upgrade or downgrade), with special care for XEP-0373 case. Depending of size of database, this migration script can be long to run. rel 443 diff -r be89ab1cbca4 -r 5f2d496c633f libervia/backend/memory/cache.py --- a/libervia/backend/memory/cache.py Fri Feb 16 18:46:06 2024 +0100 +++ b/libervia/backend/memory/cache.py Fri Feb 23 13:31:04 2024 +0100 @@ -1,8 +1,7 @@ #!/usr/bin/env python3 - -# SAT: a jabber client -# Copyright (C) 2009-2021 Jérôme Poisson (goffi@goffi.org) +# Libervia: an XMPP client +# Copyright (C) 2009-2024 Jérôme Poisson (goffi@goffi.org) # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -20,9 +19,10 @@ from io import BufferedIOBase import mimetypes from pathlib import Path -import pickle as pickle import time -from typing import Any, Dict, Optional +from typing import Any + +from pydantic import BaseModel, ValidationError from libervia.backend.core import exceptions from libervia.backend.core.constants import Const as C @@ -33,11 +33,24 @@ log = getLogger(__name__) +CACHE_METADATA_EXT = ".cache.json" DEFAULT_EXT = ".raw" -class Cache(object): - """generic file caching""" +class CacheMetadata(BaseModel): + source: str + uid: str + filename: str + creation: int + eol: int + max_age: int = C.DEFAULT_MAX_AGE + original_filename: str | None = None + mime_type: str | None = None + last_access: int | None = None + + +class Cache: + """Generic file caching.""" def __init__(self, host, profile): """ @@ -56,60 +69,52 @@ self.purge() def purge(self): - # remove expired files from cache + # Remove expired, unreadable, and unrelated files from cache # TODO: this should not be called only on startup, but at regular interval # (e.g. once a day) - purged = set() - # we sort files to have metadata files first - for cache_file in sorted(self.cache_dir.iterdir()): - if cache_file in purged: - continue + to_delete = set() + seen = set() + now = time.time() + for cache_data_file in self.cache_dir.glob(f"*{CACHE_METADATA_EXT}"): try: - with cache_file.open('rb') as f: - cache_data = pickle.load(f) - except IOError: + with cache_data_file.open("r") as f: + cache_data = CacheMetadata.model_validate_json(f.read()) + except (IOError, ValidationError): log.warning( - _("Can't read metadata file at {path}") - .format(path=cache_file)) - continue - except (pickle.UnpicklingError, EOFError): - log.debug(f"File at {cache_file} is not a metadata file") + _("Can't read metadata file at {path}, deleting it.").format( + path=cache_data_file + ) + ) + to_delete.add(cache_data_file) continue - try: - eol = cache_data['eol'] - filename = cache_data['filename'] - except KeyError: - log.warning( - _("Invalid cache metadata at {path}") - .format(path=cache_file)) - continue + else: + cached_file = self.get_path(cache_data.filename) + if not cached_file.exists(): + log.warning( + f"Cache file {cache_data_file!r} references a non-existent file " + f"and will be deleted: {cache_data_file!r}." + ) + to_delete.add(cache_data_file) + elif cache_data.eol < now: + log.debug( + "Purging expired cache file {cache_data_file!r} (expired for " + "{time}s)".format(time=int(time.time() - cache_data.eol)) + ) + to_delete.add(cache_data_file) + seen.add(cached_file) + seen.add(cache_data_file) - filepath = self.getPath(filename) + for file in to_delete: + log.debug(f"Deleting cache file: {file}") + file.unlink() - if not filepath.exists(): - log.warning(_( - "cache {cache_file!r} references an inexisting file: {filepath!r}" - ).format(cache_file=str(cache_file), filepath=str(filepath))) - log.debug("purging cache with missing file") - cache_file.unlink() - elif eol < time.time(): - log.debug( - "purging expired cache {filepath!r} (expired for {time}s)" - .format(filepath=str(filepath), time=int(time.time() - eol)) - ) - cache_file.unlink() - try: - filepath.unlink() - except FileNotFoundError: - log.warning( - _("following file is missing while purging cache: {path}") - .format(path=filepath) - ) - purged.add(cache_file) - purged.add(filepath) + for file in self.cache_dir.iterdir(): + if file not in seen: + log.debug(f"Deleting irrelevant file in cache dir: {file}") + file.unlink() - def getPath(self, filename: str) -> Path: - """return cached file URL + def get_path(self, filename: str) -> Path: + """Return cached file URL. @param filename: cached file name (cache data or actual file) @return: path to the cached file @@ -121,62 +126,58 @@ raise exceptions.DataError("Invalid char found") return self.cache_dir / filename - def get_metadata(self, uid: str, update_eol: bool = True) -> Optional[Dict[str, Any]]: - """Retrieve metadata for cached data + def get_metadata(self, uid: str, update_eol: bool = True) -> dict[str, Any] | None: + """Retrieve metadata for cached data. - @param uid(unicode): unique identifier of file - @param update_eol(bool): True if eol must extended + @param uid: unique identifier of cache metadata. + @param update_eol: True if eol must extended if True, max_age will be added to eol (only if it is not already expired) - @return (dict, None): metadata with following keys: - see [cache_data] for data details, an additional "path" key is the full path to - cached file. - None if file is not in cache (or cache is invalid) + @return: metadata, see [cache_data] for data details, an additional "path" key is + the full path to cached file. + None if file is not in cache (or cache is invalid). """ - uid = uid.strip() if not uid: raise exceptions.InternalError("uid must not be empty") - cache_url = self.getPath(uid) + cache_url = self.get_path(f"{uid}{CACHE_METADATA_EXT}") if not cache_url.exists(): return None try: - with cache_url.open("rb") as f: - cache_data = pickle.load(f) + with cache_url.open("r") as f: + cache_data = CacheMetadata.model_validate_json(f.read()) except (IOError, EOFError) as e: - log.warning(f"can't read cache at {cache_url}: {e}") + log.warning(f"Can't read cache at {cache_url}: {e}") return None - except pickle.UnpicklingError: - log.warning(f"invalid cache found at {cache_url}") + except ValidationError: + log.warning(f"Invalid cache found at {cache_url}") + return None + except UnicodeDecodeError as e: + log.warning(f"Invalid encoding, this is not a cache metadata file.") return None - try: - eol = cache_data["eol"] - except KeyError: - log.warning("no End Of Life found for cached file {}".format(uid)) - eol = 0 - if eol < time.time(): + if cache_data.eol < time.time(): log.debug( - "removing expired cache (expired for {}s)".format(time.time() - eol) + "removing expired cache (expired for {}s)".format( + time.time() - cache_data.eol + ) ) return None if update_eol: - try: - max_age = cache_data["max_age"] - except KeyError: - log.warning(f"no max_age found for cache at {cache_url}, using default") - max_age = cache_data["max_age"] = C.DEFAULT_MAX_AGE now = int(time.time()) - cache_data["last_access"] = now - cache_data["eol"] = now + max_age - with cache_url.open("wb") as f: - pickle.dump(cache_data, f, protocol=2) + cache_data.last_access = now + cache_data.eol = now + cache_data.max_age + with cache_url.open("w") as f: + f.write(cache_data.model_dump_json(exclude_none=True)) - cache_data["path"] = self.getPath(cache_data["filename"]) - return cache_data + # FIXME: we convert to dict to be compatible with former method (pre Pydantic). + # All call to get_metadata should use directly the Pydantic model in the future. + cache_data_dict = cache_data.model_dump() + cache_data_dict["path"] = self.get_path(cache_data.filename) + return cache_data_dict - def get_file_path(self, uid: str) -> Path: + def get_file_path(self, uid: str) -> Path | None: """Retrieve absolute path to file @param uid(unicode): unique identifier of file @@ -187,7 +188,7 @@ if metadata is not None: return metadata["path"] - def remove_from_cache(self, uid, metadata=None): + def remove_from_cache(self, uid: str, metadata=None) -> None: """Remove data from cache @param uid(unicode): unique identifier cache file @@ -198,32 +199,33 @@ return try: - filename = cache_data['filename'] + filename = cache_data["filename"] except KeyError: - log.warning(_("missing filename for cache {uid!r}") .format(uid=uid)) + log.warning(_("missing filename for cache {uid!r}").format(uid=uid)) else: - filepath = self.getPath(filename) + filepath = self.get_path(filename) try: filepath.unlink() except FileNotFoundError: log.warning( - _("missing file referenced in cache {uid!r}: {filename}") - .format(uid=uid, filename=filename) + _("missing file referenced in cache {uid!r}: {filename}").format( + uid=uid, filename=filename + ) ) - cache_file = self.getPath(uid) + cache_file = self.get_path(f"{uid}{CACHE_METADATA_EXT}") cache_file.unlink() - log.debug(f"cache with uid {uid!r} has been removed") + log.debug(f"Cache with uid {uid!r} has been removed.") def cache_data( self, source: str, uid: str, - mime_type: Optional[str] = None, - max_age: Optional[int] = None, - original_filename: Optional[str] = None + mime_type: str | None = None, + max_age: int = C.DEFAULT_MAX_AGE, + original_filename: str | None = None, ) -> BufferedIOBase: - """create cache metadata and file object to use for actual data + """Create cache metadata and file object to use for actual data. @param source: source of the cache (should be plugin's import_name) @param uid: an identifier of the file which must be unique @@ -235,47 +237,42 @@ None to use default value 0 to ignore cache (file will be re-downloaded on each access) @param original_filename: if not None, will be used to retrieve file extension and - guess - mime type, and stored in "original_filename" + guess mime type, and stored in "original_filename" @return: file object opened in write mode you have to close it yourself (hint: use ``with`` statement) """ - if max_age is None: - max_age = C.DEFAULT_MAX_AGE - cache_data = { - "source": source, - # we also store max_age for updating eol - "max_age": max_age, - } - cache_url = self.getPath(uid) - if original_filename is not None: - cache_data["original_filename"] = original_filename - if mime_type is None: - # we have original_filename but not MIME type, we try to guess the later - mime_type = mimetypes.guess_type(original_filename, strict=False)[0] + if original_filename is not None and mime_type is None: + # we have original_filename but not MIME type, we try to guess the later + mime_type = mimetypes.guess_type(original_filename, strict=False)[0] + if mime_type: ext = mimetypes.guess_extension(mime_type, strict=False) if ext is None: - log.warning( - "can't find extension for MIME type {}".format(mime_type) - ) + log.warning("can't find extension for MIME type {}".format(mime_type)) ext = DEFAULT_EXT elif ext == ".jpe": ext = ".jpg" else: ext = DEFAULT_EXT mime_type = None + filename = uid + ext now = int(time.time()) - cache_data.update({ - "filename": filename, - "creation": now, - "eol": now + max_age, - "mime_type": mime_type, - }) - file_path = self.getPath(filename) + metadata = CacheMetadata( + source=source, + uid=uid, + mime_type=mime_type, + max_age=max_age, + original_filename=original_filename, + filename=filename, + creation=now, + eol=now + max_age, + ) - with open(cache_url, "wb") as f: - pickle.dump(cache_data, f, protocol=2) + cache_metadata_file = self.get_path(f"{uid}{CACHE_METADATA_EXT}") + file_path = self.get_path(filename) - return file_path.open("wb") + with open(cache_metadata_file, "w") as f: + f.write(metadata.model_dump_json(exclude_none=True)) + + return open(file_path, "wb") diff -r be89ab1cbca4 -r 5f2d496c633f libervia/backend/memory/migration/versions/fe3a02cb4bec_convert_legacypickle_columns_to_json.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/libervia/backend/memory/migration/versions/fe3a02cb4bec_convert_legacypickle_columns_to_json.py Fri Feb 23 13:31:04 2024 +0100 @@ -0,0 +1,142 @@ +"""convert LegacyPickle columns to JSON + +Revision ID: fe3a02cb4bec +Revises: 610345f77e75 +Create Date: 2024-02-22 14:55:59.993983 + +""" +from alembic import op +import sqlalchemy as sa +import pickle +import json +from libervia.backend.plugins.plugin_xep_0373 import PublicKeyMetadata + +# revision identifiers, used by Alembic. +revision = "fe3a02cb4bec" +down_revision = "610345f77e75" +branch_labels = None +depends_on = None + + +def convert_pickle_to_json(value, table, primary_keys): + """Convert pickled data to JSON, handling potential errors.""" + if value is None: + return None + try: + # some values are converted to bytes with LegacyPickle + if isinstance(value, str): + value = value.encode() + try: + deserialized = pickle.loads(value, encoding="utf-8") + except ModuleNotFoundError: + deserialized = pickle.loads( + value.replace(b"sat.plugins", b"libervia.backend.plugins"), + encoding="utf-8", + ) + if ( + table == "private_ind_bin" + and primary_keys[0] == "XEP-0373" + and not primary_keys[1].startswith("/trust") + and isinstance(deserialized, set) + and deserialized + and isinstance(next(iter(deserialized)), PublicKeyMetadata) + ): + # XEP-0373 plugin was pickling an internal class, this can't be converted + # directly to JSON, so we do a special treatment with the add `to_dict` and + # `from_dict` methods. + deserialized = [pkm.to_dict() for pkm in deserialized] + + ret = json.dumps(deserialized, ensure_ascii=False, default=str) + if table == 'history' and ret == "{}": + # For history, we can remove empty data, but for other tables it may be + # significant. + ret = None + return ret + except Exception as e: + print( + f"Warning: Failed to convert pickle to JSON, using NULL instead. Error: {e}" + ) + return None + + +def upgrade(): + print( + "This migration may take very long, please be patient and don't stop the process." + ) + connection = op.get_bind() + + tables_and_columns = [ + ("history", "extra", "uid"), + ("private_gen_bin", "value", "namespace", "key"), + ("private_ind_bin", "value", "namespace", "key", "profile_id"), + ] + + for table, column, *primary_keys in tables_and_columns: + primary_key_clause = " AND ".join(f"{pk} = :{pk}" for pk in primary_keys) + select_stmt = sa.text(f"SELECT {', '.join(primary_keys)}, {column} FROM {table}") + update_stmt = sa.text( + f"UPDATE {table} SET {column} = :{column} WHERE {primary_key_clause}" + ) + + result = connection.execute(select_stmt) + for row in result: + value = row[-1] + if value is None: + continue + data = {pk: row[idx] for idx, pk in enumerate(primary_keys)} + data[column] = convert_pickle_to_json(value, table, row[:-1]) + connection.execute(update_stmt.bindparams(**data)) + + +def convert_json_to_pickle(value, table, primary_keys): + """Convert JSON data back to pickled data, handling potential errors.""" + if value is None: + return None + try: + deserialized = json.loads(value) + # Check for the specific table and primary key conditions that require special + # handling + if ( + table == "private_ind_bin" + and primary_keys[0] == "XEP-0373" + and not primary_keys[1].startswith("/trust") + ): + # Convert list of dicts back to set of PublicKeyMetadata objects + if isinstance(deserialized, list): + deserialized = {PublicKeyMetadata.from_dict(d) for d in deserialized} + return pickle.dumps(deserialized, 0) + except Exception as e: + print( + f"Warning: Failed to convert JSON to pickle, using NULL instead. Error: {e}" + ) + return None + + +def downgrade(): + print( + "Reverting JSON columns to LegacyPickle format. This may take a while, please be " + "patient." + ) + connection = op.get_bind() + + tables_and_columns = [ + ("history", "extra", "uid"), + ("private_gen_bin", "value", "namespace", "key"), + ("private_ind_bin", "value", "namespace", "key", "profile_id"), + ] + + for table, column, *primary_keys in tables_and_columns: + primary_key_clause = " AND ".join(f"{pk} = :{pk}" for pk in primary_keys) + select_stmt = sa.text(f"SELECT {', '.join(primary_keys)}, {column} FROM {table}") + update_stmt = sa.text( + f"UPDATE {table} SET {column} = :{column} WHERE {primary_key_clause}" + ) + + result = connection.execute(select_stmt) + for row in result: + value = row[-1] + if value is None: + continue + data = {pk: row[idx] for idx, pk in enumerate(primary_keys)} + data[column] = convert_json_to_pickle(value, table, row[:-1]) + connection.execute(update_stmt.bindparams(**data)) diff -r be89ab1cbca4 -r 5f2d496c633f libervia/backend/memory/sqla.py --- a/libervia/backend/memory/sqla.py Fri Feb 16 18:46:06 2024 +0100 +++ b/libervia/backend/memory/sqla.py Fri Feb 23 13:31:04 2024 +0100 @@ -20,6 +20,7 @@ from asyncio.subprocess import PIPE import copy from datetime import datetime +import json from pathlib import Path import sys import time @@ -214,6 +215,7 @@ engine = create_async_engine( db_config["url"], future=True, + json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False) ) new_base = not db_config["path"].exists() diff -r be89ab1cbca4 -r 5f2d496c633f libervia/backend/memory/sqla_mapping.py --- a/libervia/backend/memory/sqla_mapping.py Fri Feb 16 18:46:06 2024 +0100 +++ b/libervia/backend/memory/sqla_mapping.py Fri Feb 23 13:31:04 2024 +0100 @@ -19,7 +19,6 @@ from datetime import datetime import enum import json -import pickle import time from typing import Any, Dict @@ -132,42 +131,6 @@ URGENT = 40 -class LegacyPickle(TypeDecorator): - """Handle troubles with data pickled by former version of SàT - - This type is temporary until we do migration to a proper data type - """ - - # Blob is used on SQLite but gives errors when used here, while Text works fine - impl = Text - cache_ok = True - - def process_bind_param(self, value, dialect): - if value is None: - return None - return pickle.dumps(value, 0) - - def process_result_value(self, value, dialect): - if value is None: - return None - # value types are inconsistent (probably a consequence of Python 2/3 port - # and/or SQLite dynamic typing) - try: - value = value.encode() - except AttributeError: - pass - # "utf-8" encoding is needed to handle Python 2 pickled data - try: - return pickle.loads(value, encoding="utf-8") - except ModuleNotFoundError: - # FIXME: workaround due to package renaming, need to move all pickle code to - # JSON - return pickle.loads( - value.replace(b"sat.plugins", b"libervia.backend.plugins"), - encoding="utf-8", - ) - - class Json(TypeDecorator): """Handle JSON field in DB independant way""" @@ -178,7 +141,7 @@ def process_bind_param(self, value, dialect): if value is None: return None - return json.dumps(value) + return json.dumps(value, ensure_ascii=False) def process_result_value(self, value, dialect): if value is None: @@ -296,7 +259,7 @@ ), nullable=False, ) - extra = Column(LegacyPickle) + extra = Column(JSON) profile = relationship("Profile") messages = relationship( @@ -573,7 +536,7 @@ namespace = Column(Text, primary_key=True) key = Column(Text, primary_key=True) - value = Column(LegacyPickle) + value = Column(JSON) class PrivateIndBin(Base): @@ -582,7 +545,7 @@ namespace = Column(Text, primary_key=True) key = Column(Text, primary_key=True) profile_id = Column(ForeignKey("profiles.id", ondelete="CASCADE"), primary_key=True) - value = Column(LegacyPickle) + value = Column(JSON) profile = relationship("Profile", back_populates="private_bin_data") diff -r be89ab1cbca4 -r 5f2d496c633f libervia/backend/plugins/plugin_comp_ap_gateway/http_server.py --- a/libervia/backend/plugins/plugin_comp_ap_gateway/http_server.py Fri Feb 16 18:46:06 2024 +0100 +++ b/libervia/backend/plugins/plugin_comp_ap_gateway/http_server.py Fri Feb 23 13:31:04 2024 +0100 @@ -1017,7 +1017,7 @@ if len(extra_args) != 1: raise exceptions.DataError("avatar argument expected in URL") avatar_filename = extra_args[0] - avatar_path = self.apg.host.common_cache.getPath(avatar_filename) + avatar_path = self.apg.host.common_cache.get_path(avatar_filename) return static.File(str(avatar_path)).render(request) elif request_type == "item": ret_data = await self.apg.ap_get_local_object(ap_url) diff -r be89ab1cbca4 -r 5f2d496c633f libervia/backend/plugins/plugin_misc_identity.py --- a/libervia/backend/plugins/plugin_misc_identity.py Fri Feb 16 18:46:06 2024 +0100 +++ b/libervia/backend/plugins/plugin_misc_identity.py Fri Feb 23 13:31:04 2024 +0100 @@ -85,6 +85,8 @@ # we store the metadata in database, to restore it on next connection # (it is stored only for roster entities) "store": True, + "store_serialisation": self._avatar_ser, + "store_deserialisation": self._avatar_deser }, "nicknames": { "type": list, @@ -167,10 +169,16 @@ for key, value in stored_data.items(): entity_s, name = key.split('\n') - if name not in self.metadata.keys(): + try: + metadata = self.metadata[name] + except KeyError: log.debug(f"removing {key} from storage: not an allowed metadata name") to_delete.append(key) continue + if value is not None: + deser_method = metadata.get("store_deserialisation") + if deser_method is not None: + value = deser_method(value) entity = jid.JID(entity_s) if name == 'avatar': @@ -365,6 +373,10 @@ client, entity, metadata_name, data) if metadata.get('store', False): + if data is not None: + ser_method = metadata.get("store_serialisation") + if ser_method is not None: + data = ser_method(data) key = f"{entity}\n{metadata_name}" await client._identity_storage.aset(key, data) @@ -488,6 +500,10 @@ if metadata.get('store', False): key = f"{entity}\n{metadata_name}" + if data is not None: + ser_method = metadata.get("store_serialisation") + if ser_method is not None: + data = ser_method(data) await client._identity_storage.aset(key, data) def default_update_is_new_data(self, client, entity, cached_data, new_data): @@ -633,6 +649,18 @@ raise ValueError(f"missing avatar data keys: {mandatory_keys - data.keys()}") return data + def _avatar_ser(self, data: dict) -> dict: + if data.get("path"): + # Path instance can't be stored + data = data.copy() + data["path"] = str(data["path"]) + return data + + def _avatar_deser(self, data: dict) -> dict: + if data.get("path"): + data["path"] = Path(data["path"]) + return data + async def nicknames_get_post_treatment(self, client, entity, plugin_nicknames): """Prepend nicknames from core locations + set default nickname diff -r be89ab1cbca4 -r 5f2d496c633f libervia/backend/plugins/plugin_xep_0048.py --- a/libervia/backend/plugins/plugin_xep_0048.py Fri Feb 16 18:46:06 2024 +0100 +++ b/libervia/backend/plugins/plugin_xep_0048.py Fri Feb 23 13:31:04 2024 +0100 @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from typing import cast from libervia.backend.core.i18n import _, D_ from libervia.backend.core import exceptions from libervia.backend.core.constants import Const as C @@ -106,9 +107,12 @@ NS_BOOKMARKS, client.profile ) await local.load() + local = cast(dict[str, dict|None]|None, local) if not local: - local[XEP_0048.MUC_TYPE] = dict() - local[XEP_0048.URL_TYPE] = dict() + local = { + XEP_0048.MUC_TYPE: {}, + XEP_0048.URL_TYPE: {} + } private = await self._get_server_bookmarks("private", client.profile) pubsub = client.bookmarks_pubsub = None diff -r be89ab1cbca4 -r 5f2d496c633f libervia/backend/plugins/plugin_xep_0373.py --- a/libervia/backend/plugins/plugin_xep_0373.py Fri Feb 16 18:46:06 2024 +0100 +++ b/libervia/backend/plugins/plugin_xep_0373.py Fri Feb 23 13:31:04 2024 +0100 @@ -20,12 +20,13 @@ import base64 from datetime import datetime, timezone import enum +import json import secrets import string from typing import Any, Dict, Iterable, List, Literal, Optional, Set, Tuple, cast from xml.sax.saxutils import quoteattr -from typing_extensions import Final, NamedTuple, Never, assert_never +from typing import Final, NamedTuple, Never, assert_never from wokkel import muc, pubsub from wokkel.disco import DiscoFeature, DiscoInfo import xmlschema @@ -812,10 +813,21 @@ """ Metadata about a published public key. """ - fingerprint: str timestamp: datetime + def to_dict(self) -> dict: + # Convert the instance to a dictionary and handle datetime serialization + data = self._asdict() + data['timestamp'] = self.timestamp.isoformat() + return data + + @staticmethod + def from_dict(data: dict) -> 'PublicKeyMetadata': + # Load a serialised dictionary + data['timestamp'] = datetime.fromisoformat(data['timestamp']) + return PublicKeyMetadata(**data) + @enum.unique class TrustLevel(enum.Enum): @@ -1102,10 +1114,10 @@ storage_key = STR_KEY_PUBLIC_KEYS_METADATA.format(sender.userhost()) - local_public_keys_metadata = cast( - Set[PublicKeyMetadata], - await self.__storage[profile].get(storage_key, set()) - ) + local_public_keys_metadata = { + PublicKeyMetadata.from_dict(pkm) + for pkm in await self.__storage[profile].get(storage_key, []) + } unchanged_keys = new_public_keys_metadata & local_public_keys_metadata changed_or_new_keys = new_public_keys_metadata - unchanged_keys @@ -1149,7 +1161,10 @@ await self.publish_public_keys_list(client, new_public_keys_metadata) - await self.__storage[profile].force(storage_key, new_public_keys_metadata) + await self.__storage[profile].force( + storage_key, + [pkm.to_dict() for pkm in new_public_keys_metadata] + ) def list_public_keys(self, client: SatXMPPClient, jid: jid.JID) -> Set[GPGPublicKey]: """List GPG public keys available for a JID. @@ -1191,10 +1206,10 @@ storage_key = STR_KEY_PUBLIC_KEYS_METADATA.format(client.jid.userhost()) - public_keys_list = cast( - Set[PublicKeyMetadata], - await self.__storage[client.profile].get(storage_key, set()) - ) + public_keys_list = { + PublicKeyMetadata.from_dict(pkm) + for pkm in await self.__storage[client.profile].get(storage_key, []) + } public_keys_list.add(PublicKeyMetadata( fingerprint=secret_key.public_key.fingerprint, @@ -1508,10 +1523,10 @@ storage_key = STR_KEY_PUBLIC_KEYS_METADATA.format(entity_jid.userhost()) - public_keys_metadata = cast( - Set[PublicKeyMetadata], - await self.__storage[client.profile].get(storage_key, set()) - ) + public_keys_metadata = { + PublicKeyMetadata.from_dict(pkm) + for pkm in await self.__storage[client.profile].get(storage_key, []) + } if not public_keys_metadata: public_keys_metadata = await self.download_public_keys_list( client, entity_jid @@ -1522,7 +1537,8 @@ ) else: await self.__storage[client.profile].aset( - storage_key, public_keys_metadata + storage_key, + [pkm.to_dict() for pkm in public_keys_metadata] )