Storage: Add extend_json Tag for Enhanced JSON Field Handling in bulk_update Operations #6659

4 changes: 3 additions & 1 deletion src/aiida/orm/implementation/storage_backend.py
@@ -238,12 +238,14 @@ def bulk_insert(self, entity_type: 'EntityTypes', rows: List[dict], allow_defaul
        """

    @abc.abstractmethod
    def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict]) -> None:
    def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict], extend_json: bool = False) -> None:
        """Update a list of entities in the database, directly with a backend transaction.

        :param entity_type: The type of the entity
        :param rows: A list of dictionaries, containing fields of the backend model to update,
            and the `id` field (a.k.a. primary key)
        :param extend_json: A boolean flag indicating whether updates to JSON fields extend the
            existing JSON object, instead of overwriting it entirely

        :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table
        """
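For orientation, a minimal usage sketch of the new flag, mirroring the pattern in the new test further below. It assumes an initialized AiiDA environment (e.g. inside the test suite's fixtures); the in-memory `SqliteTempBackend` and the stored `Dict` payload are illustrative, not part of this diff:

```python
from aiida import orm
from aiida.orm.entities import EntityTypes
from aiida.storage.sqlite_temp.backend import SqliteTempBackend

# Illustrative setup: an in-memory backend, as the new test below also creates.
profile = SqliteTempBackend.create_profile(debug=False)
backend = SqliteTempBackend(profile)
node = orm.Dict({'a': 1, 'b': {'c': 2}}, backend=backend).store()

# With extend_json=True the payload is merge-patched into the existing JSON
# column instead of replacing it wholesale.
backend.bulk_update(
    EntityTypes.NODE,
    [{'id': node.pk, 'attributes': {'b': {'d': 3}}}],
    extend_json=True,
)
assert node.get('a') == 1  # pre-existing top-level key survives
assert node.get('b') == {'c': 2, 'd': 3}  # nested objects are merged
```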
49 changes: 42 additions & 7 deletions src/aiida/storage/psql_dos/backend.py
@@ -10,13 +10,16 @@

import functools
import gc
import json
import pathlib
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union
from typing import TYPE_CHECKING, Any, Iterator, Optional, Sequence, Set, Union

from disk_objectstore import Container, backup_utils
from pydantic import BaseModel, Field
from sqlalchemy import column, insert, update
from sqlalchemy import case, cast, column, func, insert, update
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from aiida.common import exceptions
@@ -314,7 +317,7 @@
        keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
        return mapper, keys

    def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]:
    def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults: bool = False) -> list[int]:
        mapper, keys = self._get_mapper_from_entity(entity_type, False)
        if not rows:
            return []
@@ -337,18 +340,50 @@
        result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall()
        return [row.id for row in result]

    def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:
    def bulk_update(self, entity_type: EntityTypes, rows: list[dict], extend_json: bool = False) -> None:
        mapper, keys = self._get_mapper_from_entity(entity_type, True)
        if not rows:
            return None

        session = self.get_session()

        def to_json(x: dict[str, Any]):
            return cast(x, JSONB)

        if session.bind is not None and session.bind.dialect.name == 'sqlite':
            # TODO: A workaround: the SQLite-based storage does not yet have a dedicated
            # bulk_update implementation, and SQLite has no JSONB type, so the value is
            # serialized to JSON text instead of being cast.
            def to_json(x: dict[str, Any]):
                return json.dumps(x)

        cases = defaultdict(list)
        id_list = []
        for row in rows:
            if 'id' not in row:
                raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
            if not keys.issuperset(row):
                raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
            when = mapper.c.id == row['id']
            id_list.append(row['id'])

            for key, value in row.items():
                if key == 'id':
                    continue

                update_value = value
                if key in ['extras', 'attributes']:
                    update_value = to_json(update_value)
                    if extend_json:
                        update_value = func.json_patch(mapper.c[key], update_value)
                cases[key].append((when, update_value))

        session = self.get_session()
        with nullcontext() if self.in_transaction else self.transaction():
            session.execute(update(mapper), rows)
            values = {k: case(*v, else_=mapper.c[k]) for k, v in cases.items()}
            stmt = update(mapper).where(mapper.c.id.in_(id_list)).values(**values)
            session.execute(stmt)

    def delete(self, delete_database_user: bool = False) -> None:
        """Delete the storage and all the data.
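To make the construction above concrete, here is a standalone sketch (my own illustration, not AiiDA code) of the single grouped UPDATE statement that the CASE-per-column pattern produces; the table and column names are made up:

```python
from sqlalchemy import JSON, Column, Integer, MetaData, Table, case, func, update

metadata = MetaData()
# Hypothetical table standing in for the backend model.
node = Table('db_dbnode', metadata, Column('id', Integer, primary_key=True), Column('attributes', JSON))

rows = [{'id': 1, 'attributes': {'a': 1}}, {'id': 2, 'attributes': {'b': 2}}]
# One WHEN branch per row; rows whose id does not match keep their value via else_.
whens = [(node.c.id == row['id'], func.json_patch(node.c.attributes, row['attributes'])) for row in rows]
stmt = (
    update(node)
    .where(node.c.id.in_([row['id'] for row in rows]))
    .values(attributes=case(*whens, else_=node.c.attributes))
)
# Renders roughly as: UPDATE db_dbnode SET attributes=CASE WHEN (db_dbnode.id = :id_1)
# THEN json_patch(db_dbnode.attributes, :json_patch_1) ... ELSE db_dbnode.attributes END
# WHERE db_dbnode.id IN (...)
print(stmt)
```

Folding all rows into one statement avoids a round trip per row, which is the point of `bulk_update`.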
50 changes: 49 additions & 1 deletion src/aiida/storage/psql_dos/utils.py
@@ -11,6 +11,8 @@
import json
from typing import TypedDict

from sqlalchemy import event


class PsqlConfig(TypedDict, total=False):
"""Configuration to connect to a PostgreSQL database."""
Expand All @@ -25,6 +27,50 @@
"""keyword argument that will be passed on to the SQLAlchemy engine."""


# Adapted from https://stackoverflow.com/a/79133234/9184828
JSONB_PATCH_FUNCTION = """
CREATE OR REPLACE FUNCTION json_patch (
    target jsonb,  -- target JSON value
    patch jsonb    -- patch JSON value
)
RETURNS jsonb
LANGUAGE plpgsql
IMMUTABLE AS $$
BEGIN
    -- If the patch is not a JSON object, return the patch as the result (base case)
    IF patch isnull OR jsonb_typeof(patch) != 'object' THEN
        RETURN patch;
    END IF;

    -- If the target is not an object, set it to an empty object
    IF target isnull OR jsonb_typeof(target) != 'object' THEN
        target := '{}';
    END IF;

    RETURN coalesce(
        jsonb_object_agg(
            coalesce(targetKey, patchKey),  -- either one key is NULL (full join) or both are equal
            CASE
                WHEN patchKey isnull THEN targetValue  -- key missing in patch: retain the target value
                ELSE json_patch(targetValue, patchValue)
            END
        ),
        -- if the SELECT yields no keys (empty table), jsonb_object_agg returns NULL,
        -- so fall back to {} in that case
        '{}'::jsonb
    )
    FROM jsonb_each(target) temp1(targetKey, targetValue)
    FULL JOIN jsonb_each(patch) temp2(patchKey, patchValue)
        ON targetKey = patchKey
    WHERE jsonb_typeof(patchValue) != 'null' OR patchValue isnull;  -- drop keys that the patch sets to null
END;
$$;
""".strip()


def register_jsonb_patch_function(conn, *args, **kwargs):
    """Create or replace the ``json_patch`` SQL function on a new database connection."""
    conn.execute(JSONB_PATCH_FUNCTION)


def create_sqlalchemy_engine(config: PsqlConfig):
"""Create SQLAlchemy engine (to be used for QueryBuilder queries)

@@ -50,12 +96,14 @@
        port=config['database_port'],
        name=config['database_name'],
    )
    return create_engine(
    engine = create_engine(
        engine_url,
        json_serializer=json.dumps,
        json_deserializer=json.loads,
        **config.get('engine_kwargs', {}),
    )
    event.listen(engine, 'connect', register_jsonb_patch_function)
    return engine


def create_scoped_session_factory(engine, **kwargs):
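As a cross-check of the semantics the plpgsql function implements (recursive object merge, scalar and array overwrite, key removal on null, i.e. RFC 7386 JSON Merge Patch), here is a pure-Python reference version; this is my paraphrase for illustration, not code from the PR:

```python
from typing import Any


def json_merge_patch(target: Any, patch: Any) -> Any:
    """Reference implementation mirroring the json_patch SQL function above."""
    if not isinstance(patch, dict):
        return patch  # base case: a non-object patch replaces the target wholesale
    if not isinstance(target, dict):
        target = {}  # a non-object target is treated as an empty object
    result = dict(target)
    for key, value in patch.items():
        if value is None:
            result.pop(key, None)  # a null patch value removes the key
        else:
            result[key] = json_merge_patch(result.get(key), value)
    return result


assert json_merge_patch({'a': 1, 'b': {'c': 2}}, {'a': None, 'b': {'d': 3}}) == {'b': {'c': 2, 'd': 3}}
```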
31 changes: 26 additions & 5 deletions src/aiida/storage/sqlite_temp/backend.py
@@ -12,16 +12,19 @@

import functools
import hashlib
import json
import os
import shutil
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from pathlib import Path
from tempfile import mkdtemp
from typing import Any, BinaryIO, Iterator, Sequence

from pydantic import BaseModel, Field
from sqlalchemy import column, insert, update
from sqlalchemy import column, func, insert, update
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import case

from aiida.common.exceptions import ClosedStorage, IntegrityError
from aiida.manage.configuration import Profile
@@ -268,18 +271,36 @@
        result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall()
        return [row.id for row in result]

    def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
    def bulk_update(self, entity_type: EntityTypes, rows: list[dict], extend_json: bool = False) -> None:
        mapper, keys = self._get_mapper_from_entity(entity_type, True)
        if not rows:
            return None

        cases = defaultdict(list)
        id_list = []
        for row in rows:
            if 'id' not in row:
                raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
            if not keys.issuperset(row):
                raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
            when = mapper.c.id == row['id']
            id_list.append(row['id'])

            for key, value in row.items():
                if key == 'id':
                    continue

                update_value = value
                if extend_json and key in ['extras', 'attributes']:
                    # SQLite ships a built-in json_patch with the same merge semantics
                    update_value = func.json_patch(mapper.c[key], json.dumps(value))
                cases[key].append((when, update_value))

        session = self.get_session()
        with nullcontext() if self.in_transaction else self.transaction():
            session.execute(update(mapper), rows)
            values = {k: case(*v, else_=mapper.c[k]) for k, v in cases.items()}
            stmt = update(mapper).where(mapper.c.id.in_(id_list)).values(**values)
            session.execute(stmt)

    def delete(self) -> None:
        """Delete the storage and all the data."""
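The SQLite path relies on SQLite's built-in `json_patch`, which follows the same RFC 7386 merge semantics as the PostgreSQL function registered in `utils.py`. A quick sanity check, assuming Python's bundled SQLite has the JSON functions enabled (the default in modern builds):

```python
import json
import sqlite3

conn = sqlite3.connect(':memory:')
target = json.dumps({'a': 1, 'b': {'c': 2}})
patch = json.dumps({'b': {'d': 3}, 'a': None})
# json_patch merges nested objects and drops keys the patch sets to null.
(result,) = conn.execute('SELECT json_patch(?, ?)', (target, patch)).fetchone()
print(result)  # expected: {"b":{"c":2,"d":3}}
```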
2 changes: 1 addition & 1 deletion src/aiida/storage/sqlite_zip/backend.py
@@ -291,7 +291,7 @@ def in_transaction(self) -> bool:
    def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults: bool = False) -> list[int]:
        raise ReadOnlyError()

    def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
    def bulk_update(self, entity_type: EntityTypes, rows: list[dict], extend_json: bool = False) -> None:
        raise ReadOnlyError()

    def delete(self) -> None:
70 changes: 70 additions & 0 deletions tests/orm/implementation/test_backend.py
@@ -20,6 +20,7 @@
from aiida.common import exceptions
from aiida.common.links import LinkType
from aiida.orm.entities import EntityTypes
from aiida.storage.sqlite_temp.backend import SqliteTempBackend


class TestBackend:
@@ -123,6 +124,75 @@ def test_bulk_update(self):
        assert users[1].email == 'other1'
        assert users[2].email == f'{prefix}-2'

    @pytest.mark.parametrize('use_sqlite_temp_backend', (True, False))
    def test_bulk_update_extend_json(self, use_sqlite_temp_backend):
        backend = self.backend
        if use_sqlite_temp_backend:
            profile = SqliteTempBackend.create_profile(debug=False)
            backend = SqliteTempBackend(profile)

        prefix = uuid.uuid4().hex
        nodes = [
            orm.Dict(
                {
                    'key-string': f'{prefix}-{index}',
                    'key-integer': index,
                    'key-null': None,
                    'key-object': {'k1': 'v1', 'k2': 2},
                    'key-array': [11, 45, 14],
                },
                backend=backend,
            ).store()
            for index in range(5)
        ]
        backend.bulk_update(
            EntityTypes.NODE,
            [
                {
                    'id': nodes[0].pk,
                    'attributes': {
                        'key-new': 'foobar',
                    },
                },
                {
                    'id': nodes[1].pk,
                    'attributes': {
                        'key-string': ['change type'],
                        'key-array': [1919, 810],
                    },
                },
                {
                    'id': nodes[2].pk,
                    'attributes': {
                        'key-integer': -1,
                        'key-object': {'k2': 114514},
                    },
                },
            ],
            extend_json=True,
        )

        # new attribute is added
        assert nodes[0].get('key-new') == 'foobar'
        # old attributes are kept
        assert nodes[0].get('key-string') == f'{prefix}-0'
        assert nodes[0].get('key-null') is None
        assert nodes[0].get('key-integer') == 0
        assert nodes[0].get('key-object') == {'k1': 'v1', 'k2': 2}
        assert len(nodes[0].get('key-array')) == 3
        assert all(x == y for x, y in zip(nodes[0].get('key-array'), [11, 45, 14]))
        # change type
        assert isinstance(nodes[1].get('key-string'), list)
        assert len(nodes[1].get('key-string')) == 1
        assert nodes[1].get('key-string')[0] == 'change type'
        # overwrite array
        assert len(nodes[1].get('key-array')) == 2
        assert all(x == y for x, y in zip(nodes[1].get('key-array'), [1919, 810]))
        # overwrite integer
        assert nodes[2].get('key-integer') == -1
        # merge object
        assert nodes[2].get('key-object') == {'k1': 'v1', 'k2': 114514}
    def test_bulk_update_in_transaction(self):
        """Test that bulk update in a cancelled transaction is not committed."""
        prefix = uuid.uuid4().hex