comparison sat/memory/sqla.py @ 3673:bd13391ee29e

core (memory/sqla): fix `fileUpdate`
author Goffi <goffi@goffi.org>
date Wed, 08 Sep 2021 17:58:48 +0200
parents 72b0e4053ab0
children cf930bb282ac
comparison
equal deleted inserted replaced
3672:e4054b648111 3673:bd13391ee29e
17 # along with this program. If not, see <http://www.gnu.org/licenses/>. 17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 18
19 import sys 19 import sys
20 import time 20 import time
21 import asyncio 21 import asyncio
22 import copy
22 from datetime import datetime 23 from datetime import datetime
23 from asyncio.subprocess import PIPE 24 from asyncio.subprocess import PIPE
24 from pathlib import Path 25 from pathlib import Path
25 from typing import Union, Dict, List, Tuple, Iterable, Any, Callable, Optional 26 from typing import Union, Dict, List, Tuple, Iterable, Any, Callable, Optional
26 from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine 27 from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine
979 value = (await session.execute( 980 value = (await session.execute(
980 select(orm_col).filter_by(id=file_id) 981 select(orm_col).filter_by(id=file_id)
981 )).scalar_one() 982 )).scalar_one()
982 except NoResultFound: 983 except NoResultFound:
983 raise exceptions.NotFound 984 raise exceptions.NotFound
985 old_value = copy.deepcopy(value)
984 update_cb(value) 986 update_cb(value)
985 stmt = update(orm_col).filter_by(id=file_id) 987 stmt = update(File).filter_by(id=file_id).values({column: value})
986 if not value: 988 if not old_value:
987 # because JsonDefaultDict convert NULL to an empty dict, we have to 989 # because JsonDefaultDict convert NULL to an empty dict, we have to
988 # test both for empty dict and None when we have and empty dict 990 # test both for empty dict and None when we have an empty dict
989 stmt = stmt.where((orm_col == None) | (orm_col == value)) 991 stmt = stmt.where((orm_col == None) | (orm_col == old_value))
990 else: 992 else:
991 stmt = stmt.where(orm_col == value) 993 stmt = stmt.where(orm_col == old_value)
992 result = await session.execute(stmt) 994 result = await session.execute(stmt)
993 await session.commit() 995 await session.commit()
994 996
995 if result.rowcount == 1: 997 if result.rowcount == 1:
996 break 998 break