Skip to content

Commit

Permalink
SqliteDosStorage: Fix exception when importing archive (#6359)
Browse files Browse the repository at this point in the history
When an archive was imported into an `SqliteDosStorage` backend an
exception was raised by sqlalchemy. It was treating the `uuid` column of
the models as a UUID type but in reality it was a string. This is
because the storage plugin inherits the implementation largely from the
`core.psql_dos` plugin, but it converts the models, since the UUID types
that are used by the PostgreSQL implementation are not supported by
SQLite.

The problem was that for archive importing, the `bulk_insert` method was
used, which calls the `_get_mapper_from_entity` method to map a given
ORM entity to the corresponding database model. But since this method
was inherited from `core.psql_dos`, it was returning the incorrect
models. The problem is fixed by overriding it in `SqliteDosStorage` and
returning the SQLite-adapted models.
  • Loading branch information
sphuber authored Apr 28, 2024
1 parent ffc6e4f commit af0c260
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 12 deletions.
19 changes: 18 additions & 1 deletion src/aiida/storage/sqlite_dos/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from __future__ import annotations

from functools import cached_property
from functools import cached_property, lru_cache
from pathlib import Path
from shutil import rmtree
from typing import TYPE_CHECKING, Optional
Expand All @@ -34,6 +34,7 @@
from ..psql_dos.migrator import REPOSITORY_UUID_KEY, PsqlDosMigrator

if TYPE_CHECKING:
from aiida.orm.entities import EntityTypes
from aiida.repository.backend import DiskObjectStoreRepositoryBackend

__all__ = ('SqliteDosStorage',)
Expand Down Expand Up @@ -208,3 +209,19 @@ def nodes(self):
@cached_property
def users(self):
return orm.SqliteUserCollection(self)

@staticmethod
@lru_cache(maxsize=18)
def _get_mapper_from_entity(entity_type: 'EntityTypes', with_pk: bool):
"""Return the Sqlalchemy mapper and fields corresponding to the given entity.
:param with_pk: if True, the fields returned will include the primary key
"""
from sqlalchemy import inspect

from ..sqlite_zip.models import MAP_ENTITY_TYPE_TO_MODEL

model = MAP_ENTITY_TYPE_TO_MODEL[entity_type]
mapper = inspect(model).mapper # type: ignore[union-attr]
keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
return mapper, keys
24 changes: 13 additions & 11 deletions src/aiida/storage/sqlite_zip/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,21 +158,23 @@ def create_orm_cls(klass: base.Base) -> SqliteBase:
),
)

MAP_ENTITY_TYPE_TO_MODEL = {
EntityTypes.USER: DbUser,
EntityTypes.AUTHINFO: DbAuthInfo,
EntityTypes.GROUP: DbGroup,
EntityTypes.NODE: DbNode,
EntityTypes.COMMENT: DbComment,
EntityTypes.COMPUTER: DbComputer,
EntityTypes.LOG: DbLog,
EntityTypes.LINK: DbLink,
EntityTypes.GROUP_NODE: DbGroupNodes,
}


@functools.lru_cache(maxsize=10)
def get_model_from_entity(entity_type: EntityTypes) -> Tuple[Any, Set[str]]:
"""Return the Sqlalchemy model and column names corresponding to the given entity."""
model = {
EntityTypes.USER: DbUser,
EntityTypes.AUTHINFO: DbAuthInfo,
EntityTypes.GROUP: DbGroup,
EntityTypes.NODE: DbNode,
EntityTypes.COMMENT: DbComment,
EntityTypes.COMPUTER: DbComputer,
EntityTypes.LOG: DbLog,
EntityTypes.LINK: DbLink,
EntityTypes.GROUP_NODE: DbGroupNodes,
}[entity_type]
model = MAP_ENTITY_TYPE_TO_MODEL[entity_type]
mapper = sa.inspect(model).mapper
column_names = {col.name for col in mapper.c.values()}
return model, column_names
13 changes: 13 additions & 0 deletions tests/storage/sqlite_dos/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,16 @@ def test_model():
filepath = pathlib.Path.cwd() / 'archive.aiida'
model = SqliteDosStorage.Model(filepath=filepath.name)
assert pathlib.Path(model.filepath).is_absolute()


def test_archive_import(aiida_config, aiida_profile_factory):
"""Test that archives can be imported."""
from aiida.orm import Node, QueryBuilder
from aiida.tools.archive.imports import import_archive

from tests.utils.archives import get_archive_file

with aiida_profile_factory(aiida_config, storage_backend='core.sqlite_dos'):
assert QueryBuilder().append(Node).count() == 0
import_archive(get_archive_file('calcjob/arithmetic.add.aiida'))
assert QueryBuilder().append(Node).count() > 0

0 comments on commit af0c260

Please sign in to comment.