diff --git a/scripts/create_registry_schema.py b/scripts/create_registry_schema.py index a8103d73..cef27f0d 100644 --- a/scripts/create_registry_schema.py +++ b/scripts/create_registry_schema.py @@ -264,42 +264,6 @@ def _BuildTable(schema, table_name, has_production, production): return Model -def _Keyword(schema): - """Stores the list of keywords.""" - - class_name = f"{schema}_keyword" - - # Load columns from `schema.yaml` file - columns = _get_column_definitions(schema, "keyword") - - # Table metadata - meta = { - "__tablename__": "keyword", - "__table_args__": ( - UniqueConstraint("keyword", name="keyword_u_keyword"), - {"schema": schema}, - ), - } - - Model = type(class_name, (Base,), {**columns, **meta}) - return Model - - -def _DatasetKeyword(schema): - """Many-Many link between datasets and keywords.""" - - class_name = f"{schema}_dataset_keyword" - - # Load columns from `schema.yaml` file - columns = _get_column_definitions(schema, "dataset_keyword") - - # Table metadata - meta = {"__tablename__": "dataset_keyword", "__table_args__": {"schema": schema}} - - Model = type(class_name, (Base,), {**columns, **meta}) - return Model - - # ---------------- # Database version # ---------------- @@ -390,34 +354,12 @@ def _DatasetKeyword(schema): > _DB_VERSION_MINOR ): raise RuntimeError("production schema version incompatible") - - if schema: - stmt = f"CREATE SCHEMA IF NOT EXISTS {schema}" - with db_connection.engine.connect() as conn: - conn.execute(text(stmt)) - conn.commit() - - # Grant reg_reader access - try: - with db_connection.engine.connect() as conn: - # Grant reg_reader access. - acct = "reg_reader" - usage_prv = f"GRANT USAGE ON SCHEMA {schema} to {acct}" - select_prv = f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}" - conn.execute(text(usage_prv)) - conn.execute(text(select_prv)) - - if schema == prod_schema: # also grant privileges to reg_writer - acct = "reg_writer" - usage_priv = f"GRANT USAGE ON SCHEMA {schema} to {acct}" - select_priv = ( - f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}" - ) - conn.execute(text(usage_priv)) - conn.execute(text(select_priv)) - conn.commit() - except Exception as e: - print(f"Could not grant access to {acct} on schema {schema}") + + # Create the schema + stmt = f"CREATE SCHEMA IF NOT EXISTS {schema}" + with db_connection.engine.connect() as conn: + conn.execute(text(stmt)) + conn.commit() # Create the tables for table_name in schema_data.keys(): @@ -425,9 +367,8 @@ def _DatasetKeyword(schema): print(f"Built table {table_name} in {schema}") # Generate the database - if schema: - if schema != prod_schema: - Base.metadata.reflect(db_connection.engine, prod_schema) + if schema != prod_schema: + Base.metadata.reflect(db_connection.engine, prod_schema) Base.metadata.create_all(db_connection.engine) # Grant access to other accounts. Can only grant access to objects @@ -440,16 +381,22 @@ def _DatasetKeyword(schema): select_prv = f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}" conn.execute(text(usage_prv)) conn.execute(text(select_prv)) + conn.commit() + except Exception: + print(f"Could not grant access to {acct} on schema {schema}") - if schema == prod_schema: # also grant privileges to reg_writer + if schema == prod_schema: + try: + with db_connection.engine.connect() as conn: + # Grant reg_writer access. acct = "reg_writer" usage_priv = f"GRANT USAGE ON SCHEMA {schema} to {acct}" select_priv = f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}" conn.execute(text(usage_priv)) conn.execute(text(select_priv)) - conn.commit() - except Exception: - print(f"Could not grant access to {acct} on schema {schema}") + conn.commit() + except Exception: + print(f"Could not grant access to {acct} on schema {schema}") # Add initial provenance information prov_id = _insert_provenance( diff --git a/src/dataregistry/registrar/base_table_class.py b/src/dataregistry/registrar/base_table_class.py index 9f06ad77..f47111a5 100644 --- a/src/dataregistry/registrar/base_table_class.py +++ b/src/dataregistry/registrar/base_table_class.py @@ -59,7 +59,7 @@ def __init__(self, db_connection, root_dir, owner, owner_type): self._dialect = db_connection._dialect # Link to Table Metadata. - self._metadata_getter = TableMetadata(db_connection) + self._table_metadata = TableMetadata(db_connection) # Store user id self._uid = os.getenv("USER") @@ -78,7 +78,7 @@ def __init__(self, db_connection, root_dir, owner, owner_type): self.schema_yaml = load_schema() def _get_table_metadata(self, tbl): - return self._metadata_getter.get(tbl) + return self._table_metadata.get(tbl) def delete(self, entry_id): """ diff --git a/src/dataregistry/registrar/dataset.py b/src/dataregistry/registrar/dataset.py index 35e659fa..b90a654e 100644 --- a/src/dataregistry/registrar/dataset.py +++ b/src/dataregistry/registrar/dataset.py @@ -107,7 +107,7 @@ def _validate_register_inputs( raise ValueError("Cannot overwrite production entries") if kwargs_dict["version_suffix"] is not None: raise ValueError("Production entries can't have version suffix") - if (not self._metadata_getter.is_production_schema) and ( + if (not self._table_metadata.is_production_schema) and ( not kwargs_dict["test_production"] ): raise ValueError( @@ -119,7 +119,7 @@ def _validate_register_inputs( raise ValueError("`owner` for production datasets must be 'production'") else: if self._dialect != "sqlite" and not kwargs_dict["test_production"]: - if self._metadata_getter.is_production_schema: + if self._table_metadata.is_production_schema: raise ValueError( "Only owner_type='production' can go in the production schema" )