# libervia/backend/memory/migration/versions/fe3a02cb4bec_convert_legacypickle_columns_to_json.py

"""convert LegacyPickle columns to JSON

Revision ID: fe3a02cb4bec
Revises: 610345f77e75
Create Date: 2024-02-22 14:55:59.993983

"""
import json
import pickle

from alembic import op
import sqlalchemy as sa

from libervia.backend.plugins.plugin_xep_0373 import PublicKeyMetadata
from libervia.backend.plugins.plugin_xep_0384 import TrustMessageCacheEntry

# revision identifiers, used by Alembic.
revision = "fe3a02cb4bec"
down_revision = "610345f77e75"
branch_labels = None
depends_on = None


def convert_pickle_to_json(value, table, primary_keys):
    """Convert pickled data to JSON, handling potential errors.

    @param value: raw LegacyPickle value (bytes or str) read from the database.
    @param table: name of the table the value comes from.
    @param primary_keys: primary key values of the row, used to detect columns
        needing special treatment.
    @return: JSON string, or None if the value is empty or can't be converted.
    """
    if value is None:
        return None
    try:
        # some values are converted to bytes with LegacyPickle
        if isinstance(value, str):
            value = value.encode()
        try:
            deserialized = pickle.loads(value, encoding="utf-8")
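        # The project was renamed from "sat" to "libervia.backend"; pickles
        # created before the rename reference the old module path and fail to
        # import, so we retry after rewriting it.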
        except ModuleNotFoundError:
            deserialized = pickle.loads(
                value.replace(b"sat.plugins", b"libervia.backend.plugins"),
                encoding="utf-8",
            )
        if (
            table == "private_ind_bin"
            and primary_keys[0] == "XEP-0373"
            and not primary_keys[1].startswith("/trust")
            and isinstance(deserialized, set)
            and deserialized
            and isinstance(next(iter(deserialized)), PublicKeyMetadata)
        ):
            # The XEP-0373 plugin was pickling an internal class; this can't be
            # converted directly to JSON, so we handle it specially using the
            # added `to_dict` and `from_dict` methods.
            deserialized = [pkm.to_dict() for pkm in deserialized]

        elif (
            table == "private_ind_bin"
            and primary_keys[0] == "XEP-0384/TM"
            and primary_keys[1] == "cache"
        ):
            # Same issue and solution as for XEP-0373
            try:
                deserialized = [tm.to_dict() for tm in deserialized]
            except Exception as e:
                print(
                    "Warning: Failed to convert Trust Management cache with value "
                    f"{deserialized!r}, using empty array instead: {e}"
                )
                deserialized = []

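        # default=str serialises remaining non-JSON types (e.g. datetime) as
        # strings instead of raising.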
        ret = json.dumps(deserialized, ensure_ascii=False, default=str)
        if table == "history" and ret == "{}":
            # For history, we can remove empty data, but for other tables it may be
            # significant.
            ret = None
        return ret
    except Exception as e:
        print(
            f"Warning: Failed to convert pickle to JSON, using NULL instead. Error: {e}"
        )
        return None
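

# A minimal sanity-check sketch (hypothetical helper, not called during the
# migration): it round-trips a plain value through the same LegacyPickle -> JSON
# path used by upgrade().
def _example_pickle_to_json():
    sample = pickle.dumps({"foo": ["bar", 1]}, 0)
    converted = convert_pickle_to_json(sample, "history", ("some-uid",))
    assert json.loads(converted) == {"foo": ["bar", 1]}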


def upgrade():
    print(
        "This migration may take a very long time; please be patient and don't stop "
        "the process."
    )
    connection = op.get_bind()

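    # (table, column to convert, *primary key columns)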
    tables_and_columns = [
        ("history", "extra", "uid"),
        ("private_gen_bin", "value", "namespace", "key"),
        ("private_ind_bin", "value", "namespace", "key", "profile_id"),
    ]

    for table, column, *primary_keys in tables_and_columns:
        primary_key_clause = " AND ".join(f"{pk} = :{pk}" for pk in primary_keys)
        select_stmt = sa.text(f"SELECT {', '.join(primary_keys)}, {column} FROM {table}")
        update_stmt = sa.text(
            f"UPDATE {table} SET {column} = :{column} WHERE {primary_key_clause}"
        )
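        # e.g. for the "history" table this renders as:
        #   UPDATE history SET extra = :extra WHERE uid = :uid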

        result = connection.execute(select_stmt)
        for row in result:
            value = row[-1]
            if value is None:
                continue
            data = {pk: row[idx] for idx, pk in enumerate(primary_keys)}
            data[column] = convert_pickle_to_json(value, table, row[:-1])
            connection.execute(update_stmt.bindparams(**data))


def convert_json_to_pickle(value, table, primary_keys):
    """Convert JSON data back to pickled data, handling potential errors.

    @param value: JSON string read from the database.
    @param table: name of the table the value comes from.
    @param primary_keys: primary key values of the row, used to detect columns
        needing special treatment.
    @return: pickled bytes (protocol 0), or None if the value is empty or can't
        be converted.
    """
    if value is None:
        return None
    try:
        deserialized = json.loads(value)
        # Check for the specific table and primary key conditions that require special
        # handling
        if (
            table == "private_ind_bin"
            and primary_keys[0] == "XEP-0373"
            and not primary_keys[1].startswith("/trust")
        ):
            # Convert list of dicts back to set of PublicKeyMetadata objects
            if isinstance(deserialized, list):
                deserialized = {PublicKeyMetadata.from_dict(d) for d in deserialized}
        elif (
            table == "private_ind_bin"
            and primary_keys[0] == "XEP-0384/TM"
            and primary_keys[1] == "cache"
        ):
            # Convert list of dicts back to set of TrustMessageCacheEntry objects
            if isinstance(deserialized, list):
                deserialized = {TrustMessageCacheEntry.from_dict(d) for d in deserialized}
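        # Protocol 0 is used to match the text-based format originally stored
        # in the LegacyPickle columns.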
        return pickle.dumps(deserialized, 0)
    except Exception as e:
        print(
            f"Warning: Failed to convert JSON to pickle, using NULL instead. Error: {e}"
        )
        return None
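

# A complementary sketch (hypothetical, not called during the migration):
# values produced by convert_pickle_to_json() should survive the reverse
# conversion done by downgrade().
def _example_json_to_pickle():
    converted = convert_json_to_pickle('{"foo": ["bar", 1]}', "history", ("some-uid",))
    assert pickle.loads(converted) == {"foo": ["bar", 1]}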


def downgrade():
    print(
        "Reverting JSON columns to LegacyPickle format. This may take a while; please "
        "be patient."
    )
    connection = op.get_bind()

    tables_and_columns = [
        ("history", "extra", "uid"),
        ("private_gen_bin", "value", "namespace", "key"),
        ("private_ind_bin", "value", "namespace", "key", "profile_id"),
    ]

    for table, column, *primary_keys in tables_and_columns:
        primary_key_clause = " AND ".join(f"{pk} = :{pk}" for pk in primary_keys)
        select_stmt = sa.text(f"SELECT {', '.join(primary_keys)}, {column} FROM {table}")
        update_stmt = sa.text(
            f"UPDATE {table} SET {column} = :{column} WHERE {primary_key_clause}"
        )

        result = connection.execute(select_stmt)
        for row in result:
            value = row[-1]
            if value is None:
                continue
            data = {pk: row[idx] for idx, pk in enumerate(primary_keys)}
            data[column] = convert_json_to_pickle(value, table, row[:-1])
            connection.execute(update_stmt.bindparams(**data))