diff --git a/src/dataregistry/db_basic.py b/src/dataregistry/db_basic.py index 833dafb6..b1a4118b 100644 --- a/src/dataregistry/db_basic.py +++ b/src/dataregistry/db_basic.py @@ -182,7 +182,7 @@ def __init__(self, db_connection, get_db_version=True): f"listed tables are {self._metadata.tables}" ) - if prov_name in self._metadata.tables and get_db_version: + if get_db_version: prov_table = self._metadata.tables[prov_name] stmt = select(column("associated_production")).select_from(prov_table) stmt = stmt.order_by(prov_table.c.provenance_id.desc()) @@ -206,6 +206,14 @@ def __init__(self, db_connection, get_db_version=True): self._db_major = None self._db_minor = None self._db_patch = None + self._prod_schema = None + + @property + def is_production_schema(self): + if self._prod_schema == self._schema: + return True + else: + return False @property def db_version_major(self): diff --git a/src/dataregistry/registrar/dataset.py b/src/dataregistry/registrar/dataset.py index 9b5e8a4a..b10f6793 100644 --- a/src/dataregistry/registrar/dataset.py +++ b/src/dataregistry/registrar/dataset.py @@ -7,7 +7,6 @@ from dataregistry.db_basic import add_table_row from dataregistry.exceptions import DataRegistryRootDirBadState -from dataregistry.schema import DEFAULT_SCHEMA_PRODUCTION from sqlalchemy import select, update from functools import wraps @@ -108,7 +107,9 @@ def _validate_register_inputs( raise ValueError("Cannot overwrite production entries") if kwargs_dict["version_suffix"] is not None: raise ValueError("Production entries can't have version suffix") - if self._schema != DEFAULT_SCHEMA_PRODUCTION and not kwargs_dict["test_production"]: + if (not self._metadata_getter.is_production_schema) and ( + not kwargs_dict["test_production"] + ): raise ValueError( "Only the production schema can handle owner_type='production'" ) @@ -117,7 +118,10 @@ def _validate_register_inputs( if kwargs_dict["owner"] != "production": raise ValueError("`owner` for production datasets must be 'production'") else: - if self._schema == DEFAULT_SCHEMA_PRODUCTION or kwargs_dict["test_production"]: + if ( + self._metadata_getter.is_production_schema + or kwargs_dict["test_production"] + ): raise ValueError( "Only owner_type='production' can go in the production schema" )