From af0c260bb1c32c3b33c50175d790907774561b3e Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Sun, 28 Apr 2024 09:44:04 +0200 Subject: [PATCH] `SqliteDosStorage`: Fix exception when importing archive (#6359) When an archive was imported into an `SqliteDosStorage` backend an exception was raised by sqlalchemy. It was treating the `uuid` column of the models as a UUID type but in reality it was a string. This is because the storage plugin inherits the implementation largely from the `core.psql_dos` plugin, but it converts the models, since the UUID types that are used by the PostgreSQL implementation are not supported by SQLite. The problem was that for archive importing, the `bulk_insert` method was used, which calls the `_get_mapper_from_entity` method to map a given ORM entity to the corresponding database model. But since this method was inherited from `core.psql_dos`, it was returning the incorrect models. The problem is fixed by overriding it in `SqliteDosStorage` and returning the SQLite-adapted models. --- src/aiida/storage/sqlite_dos/backend.py | 19 ++++++++++++++++++- src/aiida/storage/sqlite_zip/models.py | 24 +++++++++++++----------- tests/storage/sqlite_dos/test_backend.py | 13 +++++++++++++ 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/src/aiida/storage/sqlite_dos/backend.py b/src/aiida/storage/sqlite_dos/backend.py index 890e082914..dd2a48f031 100644 --- a/src/aiida/storage/sqlite_dos/backend.py +++ b/src/aiida/storage/sqlite_dos/backend.py @@ -10,7 +10,7 @@ from __future__ import annotations -from functools import cached_property +from functools import cached_property, lru_cache from pathlib import Path from shutil import rmtree from typing import TYPE_CHECKING, Optional @@ -34,6 +34,7 @@ from ..psql_dos.migrator import REPOSITORY_UUID_KEY, PsqlDosMigrator if TYPE_CHECKING: + from aiida.orm.entities import EntityTypes from aiida.repository.backend import DiskObjectStoreRepositoryBackend __all__ = ('SqliteDosStorage',) @@ -208,3 +209,19 @@ def nodes(self): @cached_property def users(self): return orm.SqliteUserCollection(self) + + @staticmethod + @lru_cache(maxsize=18) + def _get_mapper_from_entity(entity_type: 'EntityTypes', with_pk: bool): + """Return the Sqlalchemy mapper and fields corresponding to the given entity. + + :param with_pk: if True, the fields returned will include the primary key + """ + from sqlalchemy import inspect + + from ..sqlite_zip.models import MAP_ENTITY_TYPE_TO_MODEL + + model = MAP_ENTITY_TYPE_TO_MODEL[entity_type] + mapper = inspect(model).mapper # type: ignore[union-attr] + keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key} + return mapper, keys diff --git a/src/aiida/storage/sqlite_zip/models.py b/src/aiida/storage/sqlite_zip/models.py index 8303120f01..6358d60d26 100644 --- a/src/aiida/storage/sqlite_zip/models.py +++ b/src/aiida/storage/sqlite_zip/models.py @@ -158,21 +158,23 @@ def create_orm_cls(klass: base.Base) -> SqliteBase: ), ) +MAP_ENTITY_TYPE_TO_MODEL = { + EntityTypes.USER: DbUser, + EntityTypes.AUTHINFO: DbAuthInfo, + EntityTypes.GROUP: DbGroup, + EntityTypes.NODE: DbNode, + EntityTypes.COMMENT: DbComment, + EntityTypes.COMPUTER: DbComputer, + EntityTypes.LOG: DbLog, + EntityTypes.LINK: DbLink, + EntityTypes.GROUP_NODE: DbGroupNodes, +} + @functools.lru_cache(maxsize=10) def get_model_from_entity(entity_type: EntityTypes) -> Tuple[Any, Set[str]]: """Return the Sqlalchemy model and column names corresponding to the given entity.""" - model = { - EntityTypes.USER: DbUser, - EntityTypes.AUTHINFO: DbAuthInfo, - EntityTypes.GROUP: DbGroup, - EntityTypes.NODE: DbNode, - EntityTypes.COMMENT: DbComment, - EntityTypes.COMPUTER: DbComputer, - EntityTypes.LOG: DbLog, - EntityTypes.LINK: DbLink, - EntityTypes.GROUP_NODE: DbGroupNodes, - }[entity_type] + model = MAP_ENTITY_TYPE_TO_MODEL[entity_type] mapper = sa.inspect(model).mapper column_names = {col.name for col in mapper.c.values()} return model, column_names diff --git a/tests/storage/sqlite_dos/test_backend.py b/tests/storage/sqlite_dos/test_backend.py index cbb778a5a3..43460f67df 100644 --- a/tests/storage/sqlite_dos/test_backend.py +++ b/tests/storage/sqlite_dos/test_backend.py @@ -12,3 +12,16 @@ def test_model(): filepath = pathlib.Path.cwd() / 'archive.aiida' model = SqliteDosStorage.Model(filepath=filepath.name) assert pathlib.Path(model.filepath).is_absolute() + + +def test_archive_import(aiida_config, aiida_profile_factory): + """Test that archives can be imported.""" + from aiida.orm import Node, QueryBuilder + from aiida.tools.archive.imports import import_archive + + from tests.utils.archives import get_archive_file + + with aiida_profile_factory(aiida_config, storage_backend='core.sqlite_dos'): + assert QueryBuilder().append(Node).count() == 0 + import_archive(get_archive_file('calcjob/arithmetic.add.aiida')) + assert QueryBuilder().append(Node).count() > 0