Skip to content

Commit

Permalink
Tidy reflect function
Browse files Browse the repository at this point in the history
  • Loading branch information
stuartmcalpine committed Dec 19, 2024
1 parent 6710663 commit e37a038
Showing 1 changed file with 40 additions and 20 deletions.
60 changes: 40 additions & 20 deletions src/dataregistry/db_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,43 @@ def _reflect(self):
are extracted and stored in the `self.metadata` dict.
"""

def _get_db_info(prov_table, get_associated_production=False):
"""
Get provenance information (version and associated production
schema) from provenance table.
Parameters
----------
prov_table : SqlAlchemy metadata
get_associated_production : bool, optional
Returns
-------
schema_version : str
associated_production schema : str
If get_associated_production=True
"""

# Columns to query
cols = ["db_version_major", "db_version_minor", "db_version_patch"]
if get_associated_production:
cols.append("associated_production")

# Execute query
stmt = select(*[column(c) for c in cols]).select_from(prov_table)
stmt = stmt.order_by(prov_table.c.provenance_id.desc())
with self.engine.connect() as conn:
results = conn.execute(stmt)
r = results.fetchone()
if r is None:
raise DataRegistryException(
"During reflection no provenance information was found")

if get_associated_production:
return f"{r[0]}.{r[1]}.{r[2]}", r[3]
else:
return f"{r[0]}.{r[1]}.{r[2]}"

# Reflect the working schema to find database tables
metadata = MetaData(schema=self.schema)
metadata.reflect(self.engine, self.schema)
Expand All @@ -232,32 +269,15 @@ def _reflect(self):
return

# From the procenance table get the associated production schema
cols = ["db_version_major", "db_version_minor", "db_version_patch", "associated_production"]
prov_table = metadata.tables[prov_name]
stmt = select(*[column(c) for c in cols]).select_from(prov_table)
stmt = stmt.order_by(prov_table.c.provenance_id.desc())
with self.engine.connect() as conn:
results = conn.execute(stmt)
r = results.fetchone()
if r is None:
raise DataRegistryException(
"During reflection no provenance information was found")
self._prod_schema = r[3]
self.metadata["schema_version"] = f"{r[0]}.{r[1]}.{r[2]}"
self.metadata["schema_version"], self._prod_schema = _get_db_info(prov_table, get_associated_production=True)

# Add production schema tables to metadata
if self.dialect != "sqlite":
metadata.reflect(self.engine, self._prod_schema)
cols.remove("associated_production")
prov_name = ".".join([self._prod_schema, "provenance"])
stmt = select(*[column(c) for c in cols]).select_from(prov_table)
stmt = stmt.order_by(prov_table.c.provenance_id.desc())
with self.engine.connect() as conn:
results = conn.execute(stmt)
r = results.fetchone()
if r is None:
raise DataRegistryException("Cannot find production provenance table")
self.metadata["prod_schema_version"] = f"{r[0]}.{r[1]}.{r[2]}"
prov_table = metadata.tables[prov_name]
self.metadata["prod_schema_version"] = _get_db_info(prov_table)
else:
self.metadata["prod_schema_version"] = None

Expand Down

0 comments on commit e37a038

Please sign in to comment.