diff --git a/aiida/storage/psql_dos/backend.py b/aiida/storage/psql_dos/backend.py
index 315af918d5..e346ac1132 100644
--- a/aiida/storage/psql_dos/backend.py
+++ b/aiida/storage/psql_dos/backend.py
@@ -172,24 +172,21 @@ def _clear(self) -> None:
 
         with self.migrator_context(self._profile) as migrator:
 
-            # First clear the contents of the database
-            with self.transaction() as session:
-
-                # Close the session otherwise the ``delete_tables`` call will hang as there will be an open connection
-                # to the PostgreSQL server and it will block the deletion and the command will hang.
-                self.get_session().close()
-                exclude_tables = [migrator.alembic_version_tbl_name, 'db_dbsetting']
-                migrator.delete_all_tables(exclude_tables=exclude_tables)
+            # Close the session otherwise the ``delete_tables`` call will hang as there will be an open connection
+            # to the PostgreSQL server and it will block the deletion and the command will hang.
+            self.get_session().close()
+            exclude_tables = [migrator.alembic_version_tbl_name, 'db_dbsetting']
+            migrator.delete_all_tables(exclude_tables=exclude_tables)
 
-                # Clear out all references to database model instances which are now invalid.
-                session.expunge_all()
+            # Clear out all references to database model instances which are now invalid.
+            self.get_session().expunge_all()
 
             # Now reset and reinitialise the repository
             migrator.reset_repository()
             migrator.initialise_repository()
             repository_uuid = migrator.get_repository_uuid()
 
-            with self.transaction():
+            with self.transaction() as session:
                 session.execute(
                     DbSetting.__table__.update().where(DbSetting.key == REPOSITORY_UUID_KEY
                                                        ).values(val=repository_uuid)
@@ -243,13 +240,15 @@ def transaction(self) -> Iterator[Session]:
         """
         session = self.get_session()
         if session.in_transaction():
-            with session.begin_nested():
+            with session.begin_nested() as savepoint:
                 yield session
+                savepoint.commit()
             session.commit()
         else:
             with session.begin():
-                with session.begin_nested():
+                with session.begin_nested() as savepoint:
                     yield session
+                    savepoint.commit()
 
     @property
     def in_transaction(self) -> bool:
diff --git a/aiida/storage/psql_dos/orm/querybuilder/main.py b/aiida/storage/psql_dos/orm/querybuilder/main.py
index a9585a3387..23aa3e591b 100644
--- a/aiida/storage/psql_dos/orm/querybuilder/main.py
+++ b/aiida/storage/psql_dos/orm/querybuilder/main.py
@@ -171,9 +171,8 @@ def iterall(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[Li
             # on the session when a yielded row is mutated. This would reset the cursor invalidating it and causing an
             # exception to be raised in the next batch of rows in the iteration.
             # See https://github.com/python/mypy/issues/10109 for the reason of the type warning.
-            in_nested_transaction = session.in_nested_transaction()
-
-            with nullcontext() if in_nested_transaction else session.begin_nested():  # type: ignore[attr-defined]
+            with nullcontext() if session.in_nested_transaction() else self._backend.transaction(
+            ):  # type: ignore[attr-defined]
                 for resultrow in session.execute(stmt):
                     yield [self.to_backend(rowitem) for rowitem in resultrow]
 
@@ -188,9 +187,8 @@ def iterdict(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[D
             # on the session when a yielded row is mutated. This would reset the cursor invalidating it and causing an
             # exception to be raised in the next batch of rows in the iteration.
             # See https://github.com/python/mypy/issues/10109 for the reason of the type warning.
-            in_nested_transaction = session.in_nested_transaction()
-
-            with nullcontext() if in_nested_transaction else session.begin_nested():  # type: ignore[attr-defined]
+            with nullcontext() if session.in_nested_transaction() else self._backend.transaction(
+            ):  # type: ignore[attr-defined]
                 for row in self.get_session().execute(stmt):
                     # build the yield result
                     yield_result: Dict[str, Dict[str, Any]] = {}
diff --git a/aiida/storage/sqlite_zip/backend.py b/aiida/storage/sqlite_zip/backend.py
index afc24807f9..dd58451bf1 100644
--- a/aiida/storage/sqlite_zip/backend.py
+++ b/aiida/storage/sqlite_zip/backend.py
@@ -253,8 +253,19 @@ def users(self):
     def _clear(self) -> None:
         raise ReadOnlyError()
 
+    @contextmanager
     def transaction(self):
-        raise ReadOnlyError()
+        session = self.get_session()
+        if session.in_transaction():
+            with session.begin_nested() as savepoint:
+                yield session
+                savepoint.commit()
+            session.commit()
+        else:
+            with session.begin():
+                with session.begin_nested() as savepoint:
+                    yield session
+                    savepoint.commit()
 
     @property
     def in_transaction(self) -> bool:
diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py
index 341e08836c..4a7ce094c8 100644
--- a/tests/orm/test_querybuilder.py
+++ b/tests/orm/test_querybuilder.py
@@ -1567,6 +1567,38 @@ def test_iterall_with_store_group(self):
 
         for pk, pk_clone in zip(pks, [e[1] for e in sorted(pks_clone)]):
             assert orm.load_node(pk) == orm.load_node(pk_clone)
 
+    @pytest.mark.usefixtures('aiida_profile_clean')
+    def test_iterall_persistence(self, manager):
+        """Test that mutations made during ``QueryBuilder.iterall`` context are automatically committed and persisted.
+
+        This is a regression test for https://github.com/aiidateam/aiida-core/issues/6133 .
+        """
+        count = 10
+
+        # Create number of nodes with specific extra
+        for _ in range(count):
+            node = orm.Data().store()
+            node.base.extras.set('testing', True)
+
+        query = orm.QueryBuilder().append(orm.Data, filters={'extras': {'has_key': 'testing'}})
+        assert query.count() == count
+
+        # Unload and reload the storage, which will reset the session and check that the nodes with extras still exist
+        manager.reset_profile_storage()
+        manager.get_profile_storage()
+        assert query.count() == count
+
+        # Delete the extras and check that the query now matches 0
+        for [node] in orm.QueryBuilder().append(orm.Data).iterall(batch_size=2):
+            node.base.extras.delete('testing')
+
+        assert query.count() == 0
+
+        # Finally, reset the storage again and verify the changes have been persisted
+        manager.reset_profile_storage()
+        manager.get_profile_storage()
+        assert query.count() == 0
+
 
 class TestManager:
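
With this change, ``QueryBuilder.iterall`` and ``iterdict`` wrap iteration in the backend's ``transaction()`` context manager rather than a bare ``session.begin_nested()``, so mutations made on yielded entities are committed once iteration completes. A minimal usage sketch, assuming a configured AiiDA profile (the extra key ``'processed'`` is illustrative, not taken from the patch):

    from aiida import load_profile, orm

    load_profile()

    # Extras set while iterating are persisted after the iteration finishes,
    # because the surrounding savepoint is committed into the outer transaction.
    for [node] in orm.QueryBuilder().append(orm.Data).iterall(batch_size=100):
        node.base.extras.set('processed', True)

The explicit ``savepoint.commit()`` inside ``transaction()`` is what releases these mutations into the enclosing transaction before it commits, which is the behaviour exercised by the new ``test_iterall_persistence`` regression test.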