diff --git a/python/lsst/rubintv/analysis/service/commands/butler.py b/python/lsst/rubintv/analysis/service/commands/butler.py index 9de7141..65cbb38 100644 --- a/python/lsst/rubintv/analysis/service/commands/butler.py +++ b/python/lsst/rubintv/analysis/service/commands/butler.py @@ -24,13 +24,11 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from lsst.afw.cameraGeom import FOCAL_PLANE, Camera, Detector -from lsst.obs.lsst import Latiss, LsstCam, LsstComCam - from ..command import BaseCommand if TYPE_CHECKING: from ..data import DataCenter + from lsst.afw.cameraGeom import Camera def get_camera(instrument_name: str) -> Camera: @@ -46,6 +44,9 @@ def get_camera(instrument_name: str) -> Camera: camera : Camera The camera object. """ + # Import afw packages here to prevent tests from failing + from lsst.obs.lsst import Latiss, LsstCam, LsstComCam + instrument_name = instrument_name.lower() match instrument_name: case "lsstcam": @@ -73,6 +74,9 @@ class LoadDetectorInfoCommand(BaseCommand): response_type: str = "detector_info" def build_contents(self, data_center: DataCenter) -> dict: + # Import afw packages here to prevent tests from failing + from lsst.afw.cameraGeom import FOCAL_PLANE, Detector + # Load the detector information from the Butler camera = get_camera(self.instrument) detector_info = {} diff --git a/python/lsst/rubintv/analysis/service/commands/db.py b/python/lsst/rubintv/analysis/service/commands/db.py index 8486fea..e82870e 100644 --- a/python/lsst/rubintv/analysis/service/commands/db.py +++ b/python/lsst/rubintv/analysis/service/commands/db.py @@ -25,9 +25,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from lsst.afw.cameraGeom import FOCAL_PLANE -from lsst.obs.lsst import Latiss, LsstCam, LsstComCam, LsstComCamSim - from ..command import BaseCommand from ..database import exposure_tables, visit1_tables from ..query import EqualityQuery, ParentQuery, Query @@ -170,6 +167,9 @@ class LoadInstrumentCommand(BaseCommand): response_type: str = "instrument info" def build_contents(self, data_center: DataCenter) -> dict: + from lsst.afw.cameraGeom import FOCAL_PLANE + from lsst.obs.lsst import Latiss, LsstCam, LsstComCam, LsstComCamSim + instrument = self.instrument.lower() match instrument: diff --git a/python/lsst/rubintv/analysis/service/database.py b/python/lsst/rubintv/analysis/service/database.py index 4910e8e..85a8e09 100644 --- a/python/lsst/rubintv/analysis/service/database.py +++ b/python/lsst/rubintv/analysis/service/database.py @@ -22,8 +22,6 @@ from __future__ import annotations import logging -from abc import ABC, abstractmethod -from typing import cast import sqlalchemy @@ -82,78 +80,68 @@ def get_table_schema(schema: dict, table: str) -> dict: raise UnrecognizedTableError("Could not find the table '{table}' in database") -class Join(ABC): - """A join between two tables in a database. +class EnhancedJoinBuilder: + def __init__(self, tables: dict[str, sqlalchemy.Table], joins: list[dict]): + self.tables = tables + self.joins = joins + self.join_graph = self._build_join_graph() + + def _build_join_graph(self) -> dict[str, dict[str, list[str]]]: + graph = {table: {} for table in self.tables} + for join in self.joins: + tables = list(join["matches"].keys()) + t1, t2 = tables[0], tables[1] + join_columns = list(zip(join["matches"][t1], join["matches"][t2])) + graph[t1][t2] = join_columns + graph[t2][t1] = [(col2, col1) for col1, col2 in join_columns] + return graph + + def _find_join_path(self, start: str, end: str) -> list[str]: + queue = [(start, [start])] + visited = set() + + while queue: + (node, path) = queue.pop(0) + if node not in visited: + if node == end: + return path + visited.add(node) + for neighbor in self.join_graph[node]: + if neighbor not in visited: + queue.append((neighbor, path + [neighbor])) + return [] - Attributes - ---------- - join_type : - The type of join. For now only "inner" joins are supported. - """ - - join_type: str - - @abstractmethod - def __call__(self, database: ConsDbSchema): - pass - - -class InnerJoin(Join): - """An inner join between two tables in a database. - - Attributes - ---------- - n_columns : - The number of columns in the join. - matches : - Dictionary with table names as keys and tuples of column names as - values in the order in which they are matched in the join. - """ - - n_columns: int - matches: dict[str, tuple[str, ...]] - - def __init__(self, matches: dict[str, tuple[str, ...]]): - self.join_type = "inner" - if len(matches) != 2: - raise ValueError(f"Inner joins must have exactly two tables: got {len(matches)}") - - n_columns = 0 - for _, fields in matches.items(): - if n_columns == 0: - n_columns = len(fields) - else: - if n_columns != len(fields): - raise ValueError( - "Inner joins must have the same number of fields for each table: " - f"got {n_columns} and {len(fields)}" - ) - self.n_columns = n_columns - self.matches = matches + def build_join(self, table_names: set[str]) -> sqlalchemy.Table | sqlalchemy.Join: + tables = list(table_names) + select_from = self.tables[tables[0]] - def __call__(self, database: ConsDbSchema): - """Create the sqlalchemy join between the two tables. + for i in range(1, len(tables)): + current_table = tables[i] + previous_table = tables[i - 1] + join_path = self._find_join_path(previous_table, current_table) + + if not join_path: + raise ValueError(f"No join path found between {previous_table} and {current_table}") + + for j in range(1, len(join_path)): + t1, t2 = join_path[j - 1], join_path[j] + join_conditions = [] + for col1, col2 in self.join_graph[t1][t2]: + try: + condition = self.tables[t1].columns[col1] == self.tables[t2].columns[col2] + join_conditions.append(condition) + except KeyError as e: + logger.error(f"Column not found: {e}") + logger.error(f"Available columns in {t1}: {self.tables[t1].columns.keys()}") + logger.error(f"Available columns in {t2}: {self.tables[t2].columns.keys()}") + raise + + if not join_conditions: + raise ValueError(f"No valid join conditions found between {t1} and {t2}") + + select_from = sqlalchemy.join(select_from, self.tables[t2], *join_conditions) - Parameters - ---------- - database : - The database connection. - """ - tables = tuple(self.matches.keys()) - table1 = tables[0] - table2 = tables[1] - table_model1 = database.tables[table1] - table_model2 = database.tables[table2] - joins = [] - print("matches is", self.matches) - print("table1 is", table1) - print("table2 is", table2) - for index in range(self.n_columns): - joins.append( - table_model1.columns[self.matches[table1][index]] - == table_model2.columns[self.matches[table2][index]] - ) - return sqlalchemy.and_(*joins) + return select_from class ConsDbSchema: @@ -175,23 +163,13 @@ class ConsDbSchema: schema: dict metadata: sqlalchemy.MetaData tables: dict[str, sqlalchemy.Table] - joins: dict[str, tuple[Join, ...]] + joins: EnhancedJoinBuilder def __init__(self, engine: sqlalchemy.engine.Engine, schema: dict, join_templates: list): self.engine = engine self.schema = schema self.metadata = sqlalchemy.MetaData() - joins = {} - for join in join_templates: - if join["type"] == "inner": - if "inner" not in joins: - joins["inner"] = [] - joins["inner"].append(InnerJoin(join["matches"])) - else: - raise NotImplementedError(f"Join type {join['type']} is not implemented") - self.joins = {key: tuple(value) for key, value in joins.items()} - self.tables = {} for table in schema["tables"]: if ( @@ -209,6 +187,8 @@ def __init__(self, engine: sqlalchemy.engine.Engine, schema: dict, join_template schema=schema["name"], ) + self.joins = EnhancedJoinBuilder(self.tables, join_templates) + def get_table_names(self) -> tuple[str, ...]: """Given a schema, return a list of dataset names @@ -267,41 +247,6 @@ def get_column(self, column: str) -> sqlalchemy.Column: table, column = column.split(".") return self.tables[table].columns[column] - def get_join(self, table1: str, table2: str) -> sqlalchemy.ColumnElement: - """Return the join between two tables. - - Parameters - ---------- - table1 : - The first table in the join. - table2 : - The second table in the join. - - Returns - ------- - result : - The join between the two tables. - """ - joins = cast(tuple[InnerJoin, ...], self.joins["inner"]) - for join in joins: - tables = join.matches.keys() - if table1 in tables and table2 in tables: - return join(self) - - raise ValueError(f"Could not find a join between {table1} and {table2}") - - def build_join(self, table_names: set[str]) -> sqlalchemy.Table | sqlalchemy.Join: - tables = list(table_names) - select_from = self.tables[tables[0]] - print("tables are", tables) - for i in range(1, len(tables)): - current_table = tables[i] - previous_table = tables[i-1] - print("current:", current_table, "previous:", previous_table) - join = self.get_join(previous_table, current_table) - select_from = sqlalchemy.join(select_from, self.tables[current_table], join) - return select_from - def fetch_data(self, query_model: sqlalchemy.Select) -> dict[str, list]: # Temporary, for testing. TODO: remove this code block before merging _log_level = logger.getEffectiveLevel() @@ -351,10 +296,12 @@ def add_data_ids(table_name: str) -> list[sqlalchemy.Column]: table_columns.add(seq_num_column.label("seq_num")) return [day_obs_column, seq_num_column] - if "visit1" in table_names: + if list(table_names)[0] in visit1_tables: data_id_columns = add_data_ids("visit1") - else: + elif list(table_names)[0] in exposure_tables: data_id_columns = add_data_ids("exposure") + else: + raise ValueError(f"Unsupported table name: {list(table_names)[0]}") return table_columns, table_names, data_id_columns @@ -389,13 +336,13 @@ def query( if query is not None: query_result = query(self) query_model = sqlalchemy.and_(query_model, query_result.result) - table_names.add(*query_result.tables) + table_names.update(query_result.tables) if data_ids is not None: data_id_select = sqlalchemy.tuple_(day_obs_column, seq_num_column).in_(data_ids) query_model = sqlalchemy.and_(query_model, data_id_select) # Build the join - select_from = self.build_join(table_names) + select_from = self.joins.build_join(table_names) # Build the query query_model = sqlalchemy.select(*table_columns).select_from(select_from).where(query_model) diff --git a/scripts/joins.yaml b/scripts/joins.yaml index e442a93..acad6e7 100644 --- a/scripts/joins.yaml +++ b/scripts/joins.yaml @@ -8,14 +8,6 @@ joins: ccdexposure: - exposure_id - # exposure and exposure_flexdata - - type: inner - matches: - exposure: - - exposure_id - exposure_flexdata: - - obs_id - # exposure and visit1 - type: inner matches: @@ -24,7 +16,7 @@ joins: visit1: - visit_id - # exposure and ccdvisit1 (through ccdexposure) + # exposure and ccdvisit1 - type: inner matches: exposure: @@ -32,22 +24,6 @@ joins: ccdvisit1: - visit_id - # exposure and ccdvisit1_quicklook (through ccdexposure and ccdvisit1) - - type: inner - matches: - exposure: - - exposure_id - ccdvisit1_quicklook: - - visit_id - - # exposure and visit1_quicklook - - type: inner - matches: - exposure: - - exposure_id - visit1_quicklook: - - visit_id - # ccdexposure and ccdexposure_camera - type: inner matches: @@ -56,29 +32,14 @@ joins: ccdexposure_camera: - ccdexposure_id - # ccdexposure and ccdexposure_flexdata - - type: inner - matches: - ccdexposure: - - ccdexposure_id - ccdexposure_flexdata: - - obs_id - - # ccdexposure and ccdvisit1 - - type: inner - matches: - ccdexposure: - - ccdexposure_id - ccdvisit1: - - ccdvisit_id - # visit1 and ccdvisit1 - type: inner matches: visit1: - visit_id ccdvisit1: - - visit_id + #- visit_id + - exposure_id # visit1 and visit1_quicklook - type: inner @@ -88,22 +49,6 @@ joins: visit1_quicklook: - visit_id - # visit1 and ccdvisit1 - - type: inner - matches: - visit1: - - visit_id - ccdvisit1: - - visit_id - - # visit1 and ccdvisit1_quicklook - - type: inner - matches: - visit1: - - visit_id - ccdvisit1_quicklook: - - visit_id - # ccdvisit1 and ccdvisit1_quicklook - type: inner matches: @@ -111,35 +56,3 @@ joins: - ccdvisit_id ccdvisit1_quicklook: - ccdvisit_id - - # exposure_flexdata and exposure_flexdata_schema - - type: inner - matches: - exposure_flexdata: - - key - exposure_flexdata_schema: - - key - - # ccdexposure_flexdata and ccdexposure_flexdata_schema - - type: inner - matches: - ccdexposure_flexdata: - - key - ccdexposure_flexdata_schema: - - key - - # ccdexposure and visit1 - - type: inner - matches: - ccdexposure: - - exposure_id - visit1: - - visit_id - - # ccdexposure and visit1_quicklook - - type: inner - matches: - ccdexposure: - - exposure_id - visit1_quicklook: - - visit_id diff --git a/tests/joins.yaml b/tests/joins.yaml new file mode 100644 index 0000000..809b7ad --- /dev/null +++ b/tests/joins.yaml @@ -0,0 +1,9 @@ +--- +joins: + # visit1 and visit1_quicklook + - type: inner + matches: + exposure: + - exposure_id + visit1_quicklook: + - visit_id diff --git a/tests/utils.py b/tests/utils.py index 62cd1ba..3cb20f3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -209,9 +209,7 @@ def setUp(self): engine = sqlalchemy.create_engine("sqlite:///" + db_file.name) # Load the table joins - joins_path = os.path.abspath( - os.path.expanduser(os.path.expandvars(os.path.join(path, "..", "scripts", "joins.yaml"))) - ) + joins_path = os.path.join(path, "joins.yaml") with open(joins_path) as file: joins = yaml.safe_load(file)["joins"]