Skip to content

Commit

Permalink
Address reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
stuartmcalpine committed Sep 17, 2024
1 parent 031dd48 commit 13e4c41
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 75 deletions.
89 changes: 18 additions & 71 deletions scripts/create_registry_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,42 +264,6 @@ def _BuildTable(schema, table_name, has_production, production):
return Model


def _Keyword(schema):
"""Stores the list of keywords."""

class_name = f"{schema}_keyword"

# Load columns from `schema.yaml` file
columns = _get_column_definitions(schema, "keyword")

# Table metadata
meta = {
"__tablename__": "keyword",
"__table_args__": (
UniqueConstraint("keyword", name="keyword_u_keyword"),
{"schema": schema},
),
}

Model = type(class_name, (Base,), {**columns, **meta})
return Model


def _DatasetKeyword(schema):
"""Many-Many link between datasets and keywords."""

class_name = f"{schema}_dataset_keyword"

# Load columns from `schema.yaml` file
columns = _get_column_definitions(schema, "dataset_keyword")

# Table metadata
meta = {"__tablename__": "dataset_keyword", "__table_args__": {"schema": schema}}

Model = type(class_name, (Base,), {**columns, **meta})
return Model


# ----------------
# Database version
# ----------------
Expand Down Expand Up @@ -390,44 +354,21 @@ def _DatasetKeyword(schema):
> _DB_VERSION_MINOR
):
raise RuntimeError("production schema version incompatible")

if schema:
stmt = f"CREATE SCHEMA IF NOT EXISTS {schema}"
with db_connection.engine.connect() as conn:
conn.execute(text(stmt))
conn.commit()

# Grant reg_reader access
try:
with db_connection.engine.connect() as conn:
# Grant reg_reader access.
acct = "reg_reader"
usage_prv = f"GRANT USAGE ON SCHEMA {schema} to {acct}"
select_prv = f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}"
conn.execute(text(usage_prv))
conn.execute(text(select_prv))

if schema == prod_schema: # also grant privileges to reg_writer
acct = "reg_writer"
usage_priv = f"GRANT USAGE ON SCHEMA {schema} to {acct}"
select_priv = (
f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}"
)
conn.execute(text(usage_priv))
conn.execute(text(select_priv))
conn.commit()
except Exception as e:
print(f"Could not grant access to {acct} on schema {schema}")

# Create the schema
stmt = f"CREATE SCHEMA IF NOT EXISTS {schema}"
with db_connection.engine.connect() as conn:
conn.execute(text(stmt))
conn.commit()

# Create the tables
for table_name in schema_data.keys():
_BuildTable(schema, table_name, db_connection.dialect != "sqlite", prod_schema)
print(f"Built table {table_name} in {schema}")

# Generate the database
if schema:
if schema != prod_schema:
Base.metadata.reflect(db_connection.engine, prod_schema)
if schema != prod_schema:
Base.metadata.reflect(db_connection.engine, prod_schema)
Base.metadata.create_all(db_connection.engine)

# Grant access to other accounts. Can only grant access to objects
Expand All @@ -440,16 +381,22 @@ def _DatasetKeyword(schema):
select_prv = f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}"
conn.execute(text(usage_prv))
conn.execute(text(select_prv))
conn.commit()
except Exception:
print(f"Could not grant access to {acct} on schema {schema}")

if schema == prod_schema: # also grant privileges to reg_writer
if schema == prod_schema:
try:
with db_connection.engine.connect() as conn:
# Grant reg_writer access.
acct = "reg_writer"
usage_priv = f"GRANT USAGE ON SCHEMA {schema} to {acct}"
select_priv = f"GRANT SELECT ON ALL TABLES IN SCHEMA {schema} to {acct}"
conn.execute(text(usage_priv))
conn.execute(text(select_priv))
conn.commit()
except Exception:
print(f"Could not grant access to {acct} on schema {schema}")
conn.commit()
except Exception:
print(f"Could not grant access to {acct} on schema {schema}")

# Add initial provenance information
prov_id = _insert_provenance(
Expand Down
4 changes: 2 additions & 2 deletions src/dataregistry/registrar/base_table_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, db_connection, root_dir, owner, owner_type):
self._dialect = db_connection._dialect

# Link to Table Metadata.
self._metadata_getter = TableMetadata(db_connection)
self._table_metadata = TableMetadata(db_connection)

# Store user id
self._uid = os.getenv("USER")
Expand All @@ -78,7 +78,7 @@ def __init__(self, db_connection, root_dir, owner, owner_type):
self.schema_yaml = load_schema()

def _get_table_metadata(self, tbl):
return self._metadata_getter.get(tbl)
return self._table_metadata.get(tbl)

def delete(self, entry_id):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/dataregistry/registrar/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _validate_register_inputs(
raise ValueError("Cannot overwrite production entries")
if kwargs_dict["version_suffix"] is not None:
raise ValueError("Production entries can't have version suffix")
if (not self._metadata_getter.is_production_schema) and (
if (not self._table_metadata.is_production_schema) and (
not kwargs_dict["test_production"]
):
raise ValueError(
Expand All @@ -119,7 +119,7 @@ def _validate_register_inputs(
raise ValueError("`owner` for production datasets must be 'production'")
else:
if self._dialect != "sqlite" and not kwargs_dict["test_production"]:
if self._metadata_getter.is_production_schema:
if self._table_metadata.is_production_schema:
raise ValueError(
"Only owner_type='production' can go in the production schema"
)
Expand Down

0 comments on commit 13e4c41

Please sign in to comment.