From a5a92779511acd02a98299f0b689ec87d00b3576 Mon Sep 17 00:00:00 2001 From: fred3m Date: Mon, 8 Jul 2024 08:50:55 -0700 Subject: [PATCH] Refactor to work with improved rubintv_visualization client --- .github/workflows/build_docs.yaml | 4 +- doc/lsst.rubintv.analysis.service/design.rst | 51 ++ doc/lsst.rubintv.analysis.service/index.rst | 5 +- .../lsst/rubintv/analysis/service/__init__.py | 23 +- .../lsst/rubintv/analysis/service/butler.py | 2 + .../lsst/rubintv/analysis/service/command.py | 95 ++- .../analysis/service/commands/__init__.py | 24 + .../analysis/service/commands/butler.py | 126 ++++ .../rubintv/analysis/service/commands/db.py | 218 +++++++ .../analysis/service/commands/image.py | 52 ++ python/lsst/rubintv/analysis/service/data.py | 158 +++++ .../lsst/rubintv/analysis/service/database.py | 600 ++++++++++++------ python/lsst/rubintv/analysis/service/efd.py | 26 + python/lsst/rubintv/analysis/service/query.py | 72 ++- python/lsst/rubintv/analysis/service/utils.py | 92 ++- .../lsst/rubintv/analysis/service/viewer.py | 435 +++++++++++++ .../analysis/service/{client.py => worker.py} | 59 +- scripts/config.yaml | 20 +- scripts/joins.yaml | 58 ++ scripts/mock_server.py | 81 +-- scripts/rubintv_worker.py | 118 +++- tests/joins.yaml | 21 + tests/schema.yaml | 68 +- tests/test_command.py | 111 ++-- tests/test_database.py | 104 ++- tests/test_query.py | 279 +++++--- tests/utils.py | 187 ++++-- ups/rubintv_analysis_service.table | 2 + 28 files changed, 2437 insertions(+), 654 deletions(-) create mode 100644 doc/lsst.rubintv.analysis.service/design.rst create mode 100644 python/lsst/rubintv/analysis/service/commands/__init__.py create mode 100644 python/lsst/rubintv/analysis/service/commands/butler.py create mode 100644 python/lsst/rubintv/analysis/service/commands/db.py create mode 100644 python/lsst/rubintv/analysis/service/commands/image.py create mode 100644 python/lsst/rubintv/analysis/service/data.py create mode 100644 python/lsst/rubintv/analysis/service/efd.py create mode 100644 python/lsst/rubintv/analysis/service/viewer.py rename python/lsst/rubintv/analysis/service/{client.py => worker.py} (63%) create mode 100644 scripts/joins.yaml create mode 100644 tests/joins.yaml diff --git a/.github/workflows/build_docs.yaml b/.github/workflows/build_docs.yaml index 75d1dac..e2c05f7 100644 --- a/.github/workflows/build_docs.yaml +++ b/.github/workflows/build_docs.yaml @@ -34,7 +34,9 @@ jobs: run: ls python/lsst/rubintv/analysis/service - name: Install documenteer - run: pip install 'documenteer[pipelines]<0.7' + run: | + pip install 'sphinx<7' + pip install 'documenteer[pipelines]==0.8.2' - name: Build documentation working-directory: ./doc diff --git a/doc/lsst.rubintv.analysis.service/design.rst b/doc/lsst.rubintv.analysis.service/design.rst new file mode 100644 index 0000000..fe60375 --- /dev/null +++ b/doc/lsst.rubintv.analysis.service/design.rst @@ -0,0 +1,51 @@ +.. _rubintv_analysis_service-design: + +===================================== +Design of rubintv_analysis_service +===================================== + +.. contents:: Table of Contents + :depth: 2 + +Overview +======== + +The ``rubintv_analysis_service`` is a backend Python service designed to support the Derived Data Visualization (DDV) tool within the Rubin Observatory's software ecosystem. It provides a set of libraries and scripts that facilitate the analysis and visualization of astronomical data. + +Architecture +============ + +The service is structured around a series of commands and tasks, each responsible for a specific aspect of data processing and visualization. Key components include: + +- **Worker Script**: A script that initializes and runs the service, handling configuration and database connections. + + - [`rubintv_worker.py`](rubintv_analysis_service/scripts/rubintv_worker.py) + +The script is designed to be run on a worker POD that is part of a Kubernetes cluster. It is responsible for initializing the service, loading configuration, and connecting to the Butler and consDB. It listens for incoming commands from the web application, executes them, and returns the results. + +There is also a [`mock server`](rubintv_analysis_service/scripts/mock_server.py) that can be used for testing the service before being built on either the USDF or summit. + +- **Commands**: Modular operations that perform specific tasks, such as loading columns, detector images, and detector information. These are implemented in various Python modules within the ``commands`` directory, for example the[`db.py`](rubintv_analysis_service/python/lsst/rubintv/analysis/service/commands/db.py) module contains commands for loading information from the consolidated database (consDB), while the [`image.py`](rubintv_analysis_service/python/lsst/rubintv/analysis/service/commands/image.py) module contains commands for loading detector images (not yet implemented), and [`butler.py`](rubintv_analysis_service/python/lsst/rubintv/analysis/service/commands/butler.py) contains commands for loading data from a Butler repository. + +All commands derive from the `BaseCommand` class, which provides a common interface for command execution. All inherited classes are required to have parameters as keyword arguments, and implement the `BaseCommand.build_contents` method. This is done to separate the different steps in processing a command: +1. Reading the JSON command and converting it into a python dictionary. +2. Parsing the command and converting it from JSON into a `BaseCommand` instance. +3. Executing the command. +4. Packaging the results of the command into a JSON response and sending it to the rubintv web application. + +The `BaseCommand.build_contents` method is called during execution, and must return the result as a `dict` that will be converted into JSON and returned to the user. + +Configuration +============= + +Configuration for the service is managed through the following YAML files, allowing for flexible deployment and customization of the service's behavior: + +- **config.yaml**: Main configuration file specifying service parameters. +- **joins.yaml**: Configuration for database joins. + +Configuration options can be overwritten using commad line arguments, which are parsed using the `argparse` module. + +Dart/Flutter Frontend +===================== + +The frontend of the DDV tool is implemented using the Dart programming language and the Flutter framework. It provides a web-based interface for users to interact with the service, submit commands, and visualize the results, and is located at https://github.com/lsst-ts/rubintv_visualization, which is built on top of [`rubin_chart`](https://github.com/lsst-sitcom/rubin_chart), an open source plotting library in flutter also written by the project. diff --git a/doc/lsst.rubintv.analysis.service/index.rst b/doc/lsst.rubintv.analysis.service/index.rst index 3707f25..6f85025 100644 --- a/doc/lsst.rubintv.analysis.service/index.rst +++ b/doc/lsst.rubintv.analysis.service/index.rst @@ -6,7 +6,8 @@ lsst.rubintv.analysis.service ############################# -.. Paragraph that describes what this Python module does and links to related modules and frameworks. +This is the backend python service to run the Derived Data Visualization (DDV) tool. + .. _lsst.rubintv.analysis.service-using: @@ -18,6 +19,8 @@ toctree linking to topics related to using the module's APIs. .. toctree:: :maxdepth: 2 + design + .. _lsst.rubintv.analysis.service-contributing: Contributing diff --git a/python/lsst/rubintv/analysis/service/__init__.py b/python/lsst/rubintv/analysis/service/__init__.py index ae91fbd..cc66cb1 100644 --- a/python/lsst/rubintv/analysis/service/__init__.py +++ b/python/lsst/rubintv/analysis/service/__init__.py @@ -1 +1,22 @@ -from . import command, database, query, utils +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from . import butler, command, commands, data, database, efd, query, utils, viewer, worker diff --git a/python/lsst/rubintv/analysis/service/butler.py b/python/lsst/rubintv/analysis/service/butler.py index 35dd967..7677338 100644 --- a/python/lsst/rubintv/analysis/service/butler.py +++ b/python/lsst/rubintv/analysis/service/butler.py @@ -26,4 +26,6 @@ @dataclass class ExampleButlerCommand(BaseCommand): + """Placeholder for butler commands""" + pass diff --git a/python/lsst/rubintv/analysis/service/command.py b/python/lsst/rubintv/analysis/service/command.py index 776ea2b..4b8a469 100644 --- a/python/lsst/rubintv/analysis/service/command.py +++ b/python/lsst/rubintv/analysis/service/command.py @@ -19,18 +19,23 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations + import json import logging +import traceback from abc import ABC, abstractmethod from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .data import DataCenter -import sqlalchemy -from lsst.daf.butler import Butler logger = logging.getLogger("lsst.rubintv.analysis.service.command") -def construct_error_message(error_name: str, description: str) -> str: +def construct_error_message(error_name: str, description: str, traceback: str) -> str: """Use a standard format for all error messages. Parameters @@ -51,12 +56,13 @@ def construct_error_message(error_name: str, description: str) -> str: "content": { "error": error_name, "description": description, + "traceback": traceback, }, } ) -def error_msg(error: Exception) -> str: +def error_msg(error: Exception, traceback: str) -> str: """Handle errors received while parsing or executing a command. Parameters @@ -72,23 +78,23 @@ def error_msg(error: Exception) -> str: """ if isinstance(error, json.decoder.JSONDecodeError): - return construct_error_message("JSON decoder error", error.args[0]) + return construct_error_message("JSON decoder error", error.args[0], traceback) if isinstance(error, CommandParsingError): - return construct_error_message("parsing error", error.args[0]) + return construct_error_message("parsing error", error.args[0], traceback) if isinstance(error, CommandExecutionError): - return construct_error_message("execution error", error.args[0]) + return construct_error_message("execution error", error.args[0], traceback) if isinstance(error, CommandResponseError): - return construct_error_message("command response error", error.args[0]) + return construct_error_message("command response error", error.args[0], traceback) # We should always receive one of the above errors, so the code should # never get to here. But we generate this response just in case something # very unexpected happens, or (more likely) the code is altered in such a # way that this line is it. msg = "An unknown error occurred, you should never reach this message." - return construct_error_message(error.__class__.__name__, msg) + return construct_error_message(error.__class__.__name__, msg, traceback) class CommandParsingError(Exception): @@ -111,22 +117,6 @@ class CommandResponseError(Exception): pass -@dataclass -class DatabaseConnection: - """A connection to a database. - - Attributes - ---------- - engine : - The engine used to connect to the database. - schema : - The schema for the database. - """ - - engine: sqlalchemy.engine.Engine - schema: dict - - @dataclass(kw_only=True) class BaseCommand(ABC): """Base class for commands. @@ -146,15 +136,13 @@ class BaseCommand(ABC): response_type: str @abstractmethod - def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butler | None) -> dict: + def build_contents(self, data_center: DataCenter) -> dict: """Build the contents of the command. Parameters ---------- - databases : - The database connections. - butler : - A connected Butler. + data_center : + Connections to databases, the Butler, and the EFD. Returns ------- @@ -163,7 +151,7 @@ def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butle """ pass - def execute(self, databases: dict[str, DatabaseConnection], butler: Butler | None): + def execute(self, data_center: DataCenter): """Execute the command. This method does not return anything, buts sets the `result`, @@ -171,18 +159,18 @@ def execute(self, databases: dict[str, DatabaseConnection], butler: Butler | Non Parameters ---------- - databases : - The database connections. - butler : - A conencted Butler. + data_center : + Connections to databases, the Butler, and the EFD. """ - self.result = {"type": self.response_type, "content": self.build_contents(databases, butler)} + self.result = {"type": self.response_type, "content": self.build_contents(data_center)} - def to_json(self): + def to_json(self, request_id: str | None = None): """Convert the `result` into JSON.""" if self.result is None: raise CommandExecutionError(f"Null result for command {self.__class__.__name__}") + if request_id is not None: + self.result["requestId"] = request_id return json.dumps(self.result) @classmethod @@ -191,7 +179,7 @@ def register(cls, name: str): BaseCommand.command_registry[name] = cls -def execute_command(command_str: str, databases: dict[str, DatabaseConnection], butler: Butler | None) -> str: +def execute_command(command_str: str, data_center: DataCenter) -> str: """Parse a JSON formatted string into a command and execute it. Command format: @@ -206,10 +194,8 @@ def execute_command(command_str: str, databases: dict[str, DatabaseConnection], ---------- command_str : The JSON formatted command received from the user. - databases : - The database connections. - butler : - A connected Butler. + data_center : + Connections to databases, the Butler, and the EFD. """ try: command_dict = json.loads(command_str) @@ -217,7 +203,8 @@ def execute_command(command_str: str, databases: dict[str, DatabaseConnection], raise CommandParsingError(f"Could not generate a valid command from {command_str}") except Exception as err: logging.exception("Error converting command to JSON.") - return error_msg(err) + traceback_string = traceback.format_exc() + return error_msg(err, traceback_string) try: if "name" not in command_dict.keys(): @@ -230,19 +217,27 @@ def execute_command(command_str: str, databases: dict[str, DatabaseConnection], command = BaseCommand.command_registry[command_dict["name"]](**parameters) except Exception as err: - logging.exception("Error parsing command.") - return error_msg(CommandParsingError(f"'{err}' error while parsing command")) + logging.exception(f"Error parsing command {command_dict}") + traceback_string = traceback.format_exc() + return error_msg(CommandParsingError(f"'{err}' error while parsing command"), traceback_string) try: - command.execute(databases, butler) + command.execute(data_center) except Exception as err: - logging.exception("Error executing command.") - return error_msg(CommandExecutionError(f"{err} error executing command.")) + logging.exception(f"Error executing command {command_dict}") + traceback_string = traceback.format_exc() + return error_msg(CommandExecutionError(f"{err} error executing command."), traceback_string) try: - result = command.to_json() + if "requestId" in command_dict: + result = command.to_json(command_dict["requestId"]) + else: + result = command.to_json() except Exception as err: logging.exception("Error converting command response to JSON.") - return error_msg(CommandResponseError(f"{err} error converting command response to JSON.")) + traceback_string = traceback.format_exc() + return error_msg( + CommandResponseError(f"{err} error converting command response to JSON."), traceback_string + ) return result diff --git a/python/lsst/rubintv/analysis/service/commands/__init__.py b/python/lsst/rubintv/analysis/service/commands/__init__.py new file mode 100644 index 0000000..268e95b --- /dev/null +++ b/python/lsst/rubintv/analysis/service/commands/__init__.py @@ -0,0 +1,24 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from .butler import * +from .db import * +from .image import * diff --git a/python/lsst/rubintv/analysis/service/commands/butler.py b/python/lsst/rubintv/analysis/service/commands/butler.py new file mode 100644 index 0000000..3437eb1 --- /dev/null +++ b/python/lsst/rubintv/analysis/service/commands/butler.py @@ -0,0 +1,126 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ..command import BaseCommand + +if TYPE_CHECKING: + from lsst.afw.cameraGeom import Camera + + from ..data import DataCenter + + +def get_camera(instrument_name: str) -> Camera: + """Load a camera based on the instrument name + + Parameters + ---------- + instrument_name : str + The name of the instrument. + + Returns + ------- + camera : Camera + The camera object. + """ + # Import afw packages here to prevent tests from failing + from lsst.obs.lsst import Latiss, LsstCam, LsstComCam + + instrument_name = instrument_name.lower() + match instrument_name: + case "lsstcam": + camera = LsstCam.getCamera() + case "lsstcomcam": + camera = LsstComCam.getCamera() + case "latiss": + camera = Latiss.getCamera() + case _: + raise ValueError(f"Unsupported instrument: {instrument_name}") + return camera + + +@dataclass(kw_only=True) +class LoadDetectorInfoCommand(BaseCommand): + """Load the detector information from the Butler. + + Attributes + ---------- + instrument : str + The instrument name. + """ + + instrument: str + response_type: str = "detector_info" + + def build_contents(self, data_center: DataCenter) -> dict: + # Import afw packages here to prevent tests from failing + from lsst.afw.cameraGeom import FOCAL_PLANE, Detector + + # Load the detector information from the Butler + camera = get_camera(self.instrument) + detector_info = {} + for detector in camera: + if isinstance(detector, Detector): + detector_info[detector.getId()] = { + "corners": detector.getCorners(FOCAL_PLANE), + "id": detector.getId(), + "name": detector.getName(), + } + return detector_info + + +@dataclass(kw_only=True) +class LoadImageCommand(BaseCommand): + """Load an image from the Butler. + + Attributes + ---------- + collection : str + The name of the collection to load the image from. + image_name : str + The name of the image to load. + data_id : dict + The data ID of the image. Depending on the type of image this could + include things like "band" or "visit" or "detector". + """ + + repo: str + image_name: str + collection: dict + data_id: dict + response_type: str = "image" + + def build_contents(self, data_center: DataCenter) -> dict: + # Load the image from the Butler + assert data_center.butlers is not None + image = data_center.butlers[self.repo].get( + self.image_name, collections=[self.collection], **self.data_id + ) + if hasattr(image, "image"): + # Extract the Image from an Exposure or MaskedImage. + image = image.image + return { + "image": image.array, + } diff --git a/python/lsst/rubintv/analysis/service/commands/db.py b/python/lsst/rubintv/analysis/service/commands/db.py new file mode 100644 index 0000000..e82870e --- /dev/null +++ b/python/lsst/rubintv/analysis/service/commands/db.py @@ -0,0 +1,218 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ..command import BaseCommand +from ..database import exposure_tables, visit1_tables +from ..query import EqualityQuery, ParentQuery, Query + +if TYPE_CHECKING: + from ..data import DataCenter + + +logger = logging.getLogger("lsst.rubintv.analysis.service.commands.db") + + +@dataclass(kw_only=True) +class LoadColumnsCommand(BaseCommand): + """Load columns from a database table with an optional query. + + Attributes + ---------- + database : + The name of the database that the table is in. + columns : + Columns that are to be loaded. + This should be a string with the format `table.columnName`. + If there is only a single entry and it does not contain a `.`, + then the table name is used and all of the columns matching the + `query` are loaded. + query : + Query used to select rows in the table for a specific Widget. + If `query` is ``None`` then all the rows are loaded. + global_query : + Query used to select rows for all Widgets in the remote workspace. + If `global_query` is ``None`` then no global query is used. + day_obs : + The day_obs to filter the data on. + If day_obs is None then no filter on day_obs is used unless otherwise + specified in `query` or `global_query`. + data_ids : + The data IDs to filter the data on. + If data_ids is specified then only rows with the specified + day_obs and seq_num are selected. + """ + + database: str + columns: list[str] + query: dict | None = None + global_query: dict | None = None + day_obs: str | None = None + data_ids: list[tuple[int, int]] | None = None + response_type: str = "table columns" + + def build_contents(self, data_center: DataCenter) -> dict: + # Query the database to return the requested columns + database = data_center.schemas[self.database] + + query: Query | None = None + if self.query is not None: + query = Query.from_dict(self.query) + if self.global_query is not None: + global_query = Query.from_dict(self.global_query) + if query is None: + query = global_query + else: + query = ParentQuery( + children=[query, global_query], + operator="AND", + ) + if self.day_obs is not None: + table_name = self.columns[0].split(".")[0] + if table_name in exposure_tables: + column = "exposure.day_obs" + elif table_name in visit1_tables: + column = "visit1.day_obs" + else: + raise ValueError(f"Unsupported table name: {table_name}") + day_obs_query = EqualityQuery( + column=column, + value=int(self.day_obs.replace("-", "")), + operator="eq", + ) + if query is None: + query = day_obs_query + else: + query = ParentQuery( + children=[query, day_obs_query], + operator="AND", + ) + + data = database.query(self.columns, query, self.data_ids) + + if not data: + # There is no data to return + data = [] + content = { + "schema": self.database, + "columns": self.columns, + "data": data, + } + + return content + + +@dataclass(kw_only=True) +class CalculateBoundsCommand(BaseCommand): + """Calculate the bounds of a table column. + + Attributes + ---------- + database : + The name of the database that the table is in. + column : + The column to calculate the bounds of in the format "table.column". + """ + + database: str + column: str + response_type: str = "column bounds" + + def build_contents(self, data_center: DataCenter) -> dict: + # Query the database to return the requested columns + database = data_center.schemas[self.database] + data = database.calculate_bounds( + column=self.column, + ) + return { + "column": self.column, + "bounds": data, + } + + +@dataclass(kw_only=True) +class LoadInstrumentCommand(BaseCommand): + """Load the instruments for a database. + + Attributes + ---------- + instrument : + The name of the instrument (camera) to load. + """ + + instrument: str + response_type: str = "instrument info" + + def build_contents(self, data_center: DataCenter) -> dict: + from lsst.afw.cameraGeom import FOCAL_PLANE + from lsst.obs.lsst import Latiss, LsstCam, LsstComCam, LsstComCamSim + + instrument = self.instrument.lower() + + match instrument: + case "lsstcam": + camera = LsstCam.getCamera() + case "lsstcomcam": + camera = LsstComCam.getCamera() + case "latiss": + camera = Latiss.getCamera() + case "lsstcomcamsim": + camera = LsstComCamSim.getCamera() + case _: + raise ValueError(f"Unsupported instrument: {instrument}") + + detectors = [] + for detector in camera: + corners = [(c.getX(), c.getY()) for c in detector.getCorners(FOCAL_PLANE)] + detectors.append( + { + "id": detector.getId(), + "name": detector.getName(), + "corners": corners, + } + ) + + result = { + "instrument": self.instrument, + "detectors": detectors, + } + + # Load the data base to access the schema + schema_name = f"cdb_{instrument}" + try: + database = data_center.schemas[schema_name] + result["schema"] = database.schema + except KeyError: + logger.warning(f"No database connection available for {schema_name}") + logger.warning(f"Available databases: {data_center.schemas.keys()}") + + return result + + +# Register the commands +LoadColumnsCommand.register("load columns") +CalculateBoundsCommand.register("get bounds") +LoadInstrumentCommand.register("load instrument") diff --git a/python/lsst/rubintv/analysis/service/commands/image.py b/python/lsst/rubintv/analysis/service/commands/image.py new file mode 100644 index 0000000..98f9a2e --- /dev/null +++ b/python/lsst/rubintv/analysis/service/commands/image.py @@ -0,0 +1,52 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +from lsst.rubintv.analysis.service.data import DataCenter + +from ..command import BaseCommand + +logger = logging.getLogger("lsst.rubintv.analysis.service.commands.image") + + +@dataclass(kw_only=True) +class LoadDetectorImageCommand(BaseCommand): + """Load an image from a data center. + + This command is not yet implemented, but will use the + `viewer.py` module, adapted from `https://github.com/fred3m/toyz` + to load image tiles and send them to the client to display + detector images. + """ + + database: str + detector: int + visit_id: int + + def build_contents(self, data_center: DataCenter) -> dict: + # butler = data_center.butler + # assert butler is not None + # image = butler.get(, **data_id) + return {} diff --git a/python/lsst/rubintv/analysis/service/data.py b/python/lsst/rubintv/analysis/service/data.py new file mode 100644 index 0000000..196d569 --- /dev/null +++ b/python/lsst/rubintv/analysis/service/data.py @@ -0,0 +1,158 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from lsst.daf.butler import Butler + + from .database import ConsDbSchema + from .efd import EfdClient + + +class DataId: + """A unique identifier for a dataset.""" + + def __init__(self, **parameters): + self.parameters = parameters + for name, parameter in parameters.items(): + setattr(self, name, parameter) + + def __hash__(self): + return hash(tuple(sorted(self.parameters.items()))) + + def __eq__(self, other: DataId): + for parameter in self.parameters: + if parameter not in other.parameters or self.parameters[parameter] != other.parameters[parameter]: + return False + return True + + +@dataclass(kw_only=True) +class SelectionId: + """A unique identifier for a data entry.""" + + pass + + +@dataclass(kw_only=True) +class DatabaseSelectionId(SelectionId): + """A unique identifier for a database row. + + Attributes + ---------- + data_id : + The data ID of the row. + columns : + Columns in the selected row. + """ + + data_id: DataId + columns: tuple[str] + + def __hash__(self): + return hash((self.data_id, sorted(self.columns))) + + def __eq__(self, other: DatabaseSelectionId): + if self.data_id != other.data_id or len(self.columns) != len(other.columns): + return False + for column, other_column in zip(sorted(self.columns), sorted(other.columns)): + if column != other_column: + return False + return True + + +@dataclass(kw_only=True) +class ButlerSelectionId(SelectionId): + """A unique identifier for a Butler dataset.""" + + pass + + +@dataclass(kw_only=True) +class EfdSelectionId(SelectionId): + """A unique identifier for an EFD dataset entry.""" + + pass + + +@dataclass +class DataMatch(ABC): + """A match between two datasets. + + Attributes + ---------- + data_id1 : + The data ID of the first dataset to match. + data_id2 : + The data ID of the second dataset to match. + """ + + data_id1: DataId + data_id2: DataId + + @abstractmethod + def match_forward(self, indices: list[SelectionId]): + """Match the first dataset to the second.""" + pass + + @abstractmethod + def match_backward(self, indices: list[SelectionId]): + """Match the second dataset to the first.""" + pass + + +class DataCenter: + """A class that manages access to data. + + This includes functions to match entries between datasets, + for example the exposure ID from the visit database can + be matched to exposures in the Butler. + + Attributes + ---------- + matches : + A dictionary of matches between datasets. + databases : + A dictionary of database connections. + butlers : + Butler repos that the data center has access to. + efd_client : + An EFD client instance. + """ + + schemas: dict[str, ConsDbSchema] + butlers: dict[str, Butler] | None = None + efd_client: EfdClient | None = None + + def __init__( + self, + schemas: dict[str, ConsDbSchema], + butlers: dict[str, Butler] | None = None, + efd_client: EfdClient | None = None, + ): + self.schemas = schemas + self.butlers = butlers + self.efdClient = efd_client diff --git a/python/lsst/rubintv/analysis/service/database.py b/python/lsst/rubintv/analysis/service/database.py index eb06f48..fff04e2 100644 --- a/python/lsst/rubintv/analysis/service/database.py +++ b/python/lsst/rubintv/analysis/service/database.py @@ -19,36 +19,48 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from dataclasses import dataclass -from typing import Sequence +from __future__ import annotations + +import logging import sqlalchemy -from lsst.daf.butler import Butler -from .command import BaseCommand, DatabaseConnection +from .data import DatabaseSelectionId, DataId from .query import Query +logger = logging.getLogger("lsst.rubintv.analysis.service.database") -class UnrecognizedTableError(Exception): - """An error that occurs when a table name does not appear in the schema""" - pass +# Exposure tables currently in the schema +exposure_tables = [ + "exposure", + "ccdexposure", + "ccdexposure_camera", +] +# Tables in the schema for single visit exposures +visit1_tables = [ + "visit1", + "visit1_quicklook", + "ccdvisit1", + "ccdvisit1_quicklook", +] -def get_table_names(schema: dict) -> tuple[str, ...]: - """Given a schema, return a list of dataset names +# Flex tables in the schema. +# These are currently not implement and would take some thought implmenting +# correctly, so we ignore them for now. +flex_tables = [ + "exposure_flexdata", + "exposure_flexdata_schema", + "ccdexposure_flexdata", + "ccdexposure_flexdata_schema", +] - Parameters - ---------- - schema : - The schema for a database. - Returns - ------- - result : - The names of all the tables in the database. - """ - return tuple(tbl["name"] for tbl in schema["tables"]) +class UnrecognizedTableError(Exception): + """An error that occurs when a table name does not appear in the schema""" + + pass def get_table_schema(schema: dict, table: str) -> dict: @@ -73,199 +85,387 @@ def get_table_schema(schema: dict, table: str) -> dict: raise UnrecognizedTableError("Could not find the table '{table}' in database") -def column_names_to_models(table: sqlalchemy.Table, columns: list[str]) -> list[sqlalchemy.Column]: - """Return the sqlalchemy model of a Table column for each column name. +class JoinError(Exception): + """An error that occurs when a join cannot be made between two tables""" - This method is used to generate a sqlalchemy query based on a `~Query`. + pass - Parameters - ---------- - table : - The name of the table in the database. - columns : - The names of the columns to generate models for. - - Returns - ------- - A list of sqlalchemy columns. - """ - models = [] - for column in columns: - models.append(getattr(table.columns, column)) - return models +class JoinBuilder: + """Builds joins between tables in sqlalchemy. -def query_table( - table: str, - engine: sqlalchemy.engine.Engine, - columns: list[str] | None = None, - query: dict | None = None, -) -> Sequence[sqlalchemy.engine.row.Row]: - """Query a table and return the results + Using a dictionary of joins, usually from the joins.yaml file, + this class builds a graph of joins between tables so that given a + list of tables it can create a join that connects all the tables. - Parameters + Attributes ---------- - engine : - The engine used to connect to the database. - table : - The table that is being queried. - columns : - The columns from the table to return. - If `columns` is ``None`` then all the columns - in the table are returned. - query : - A query used on the table. - If `query` is ``None`` then all the rows - in the query are returned. - - Returns - ------- - result : - A list of the rows that were returned by the query. + tables : + A dictionary of tables in the schema. + joins : + A list of inner joins between tables. Each item in the list should + have a ``matches`` key with another dictionary as values. + The values will have the names of the tables being joined as keys + and a list of columns to join on as values. """ - metadata = sqlalchemy.MetaData() - _table = sqlalchemy.Table(table, metadata, autoload_with=engine) - - if columns is None: - _query = _table.select() - else: - _query = sqlalchemy.select(*column_names_to_models(_table, columns)) - - if query is not None: - _query = _query.where(Query.from_dict(query)(_table)) - - connection = engine.connect() - result = connection.execute(_query) - return result.fetchall() + def __init__(self, tables: dict[str, sqlalchemy.Table], joins: list[dict]): + self.tables = tables + self.joins = joins + self.join_graph = self._build_join_graph() + + def _build_join_graph(self) -> dict[str, dict[str, list[str]]]: + """Create the graph of joins from the list of joins.""" + graph = {table: {} for table in self.tables} + for join in self.joins: + tables = list(join["matches"].keys()) + t1, t2 = tables[0], tables[1] + join_columns = list(zip(join["matches"][t1], join["matches"][t2])) + graph[t1][t2] = join_columns + graph[t2][t1] = [(col2, col1) for col1, col2 in join_columns] + return graph + + def _find_join_path(self, start: str, end: str) -> list[str]: + """Find a path between two tables in the join graph. + + In some cases, such as between vist1 and ccdvisit1_quicklook, + this might require intermediary joins. + + Parameters + ---------- + start : + The name of the table to start the join from. + end : + The name of the table to join to. + + Returns + ------- + result : + A list of tables that can be joined to get from the + first table to the last table. + """ + queue = [(start, [start])] + visited = set() + + while queue: + (node, path) = queue.pop(0) + if node not in visited: + if node == end: + return path + visited.add(node) + for neighbor in self.join_graph[node]: + if neighbor not in visited: + queue.append((neighbor, path + [neighbor])) + raise JoinError(f"No path found between {start} and {end}") + + def build_join(self, table_names: set[str]) -> sqlalchemy.Table | sqlalchemy.Join: + """Build a join between all of the tables in a SQL statement. + + Parameters + ---------- + table_names : + A set of table names to join. + + Returns + ------- + result : + The join between all of the tables. + """ + tables = list(table_names) + select_from = self.tables[tables[0]] + # Use the first table as the starting point + joined_tables = set([tables[0]]) + logger.info(f"Starting join with table: {tables[0]}") + logger.info(f"all tables: {tables}") + + for i in range(1, len(tables)): + # Move to the next table + current_table = tables[i] + if current_table in joined_tables: + logger.info(f"Skipping {current_table} as it's already joined") + continue + + # find the join path from the first table to the current table + join_path = self._find_join_path(tables[0], current_table) + logger.info(f"Join path from {tables[0]} to {current_table}: {join_path}") + + for j in range(1, len(join_path)): + # Join all of the tables in the join_path + t1, t2 = join_path[j - 1], join_path[j] + if t2 in joined_tables: + logger.info(f"Skipping {t2} as it's already joined") + continue + + logger.info(f"Joining {t1} to {t2}") + join_conditions = [] + for col1, col2 in self.join_graph[t1][t2]: + logger.info(f"Attempting to join {t1}.{col1} = {t2}.{col2}") + try: + condition = self.tables[t1].columns[col1] == self.tables[t2].columns[col2] + join_conditions.append(condition) + except KeyError as e: + logger.error(f"Column not found: {e}") + logger.error(f"Available columns in {t1}: {list(self.tables[t1].columns.keys())}") + logger.error(f"Available columns in {t2}: {list(self.tables[t2].columns.keys())}") + raise + + if not join_conditions: + raise ValueError(f"No valid join conditions found between {t1} and {t2}") + + # Implement the join in sqlalchemy + select_from = sqlalchemy.join(select_from, self.tables[t2], *join_conditions) + joined_tables.add(t2) + + return select_from + + +class ConsDbSchema: + """A schema (instrument) in the consolidated database. -def calculate_bounds(table: str, column: str, engine: sqlalchemy.engine.Engine) -> tuple[float, float]: - """Calculate the min, max for a column - - Parameters + Attributes ---------- - table : - The table that is being queried. - column : - The column to calculate the bounds of. engine : The engine used to connect to the database. - - Returns - ------- - result : - The ``(min, max)`` of the chosen column. - """ - metadata = sqlalchemy.MetaData() - _table = sqlalchemy.Table(table, metadata, autoload_with=engine) - _column = _table.columns[column] - - query = sqlalchemy.select((sqlalchemy.func.min(_column))) - connection = engine.connect() - result = connection.execute(query) - col_min = result.fetchone() - if col_min is not None: - col_min = col_min[0] - else: - raise ValueError(f"Could not calculate the min of column {column}") - - query = sqlalchemy.select((sqlalchemy.func.max(_column))) - connection = engine.connect() - result = connection.execute(query) - col_max = result.fetchone() - if col_max is not None: - col_max = col_max[0] - else: - raise ValueError(f"Could not calculate the min of column {column}") - - return col_min, col_max - - -@dataclass(kw_only=True) -class LoadColumnsCommand(BaseCommand): - """Load columns from a database table with an optional query. - - Attributes - ---------- - database : - The name of the database that the table is in. - table : - The table that the columns are loaded from. - columns : - Columns that are to be loaded. If `columns` is ``None`` - then all the columns in the `table` are loaded. - query : - Query used to select rows in the table. - If `query` is ``None`` then all the rows are loaded. + schema : + The schema yaml converted into a dict for the instrument. + metadata : + The metadata for the database. + joins : + A JoinBuilder object that builds joins between tables. """ - database: str - table: str - columns: list[str] | None = None - query: dict | None = None - response_type: str = "table columns" - - def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butler | None) -> dict: - # Query the database to return the requested columns - database = databases[self.database] - index_column = get_table_schema(database.schema, self.table)["index_column"] - columns = self.columns - if columns is not None and index_column not in columns: - columns = [index_column] + columns - data = query_table( - table=self.table, - columns=columns, - query=self.query, - engine=database.engine, - ) - - if not data: - # There is no column data to return - content: dict = { - "columns": columns, - "data": [], - } + engine: sqlalchemy.engine.Engine + schema: dict + metadata: sqlalchemy.MetaData + tables: dict[str, sqlalchemy.Table] + joins: JoinBuilder + + def __init__(self, engine: sqlalchemy.engine.Engine, schema: dict, join_templates: list): + self.engine = engine + self.schema = schema + self.metadata = sqlalchemy.MetaData() + + self.tables = {} + for table in schema["tables"]: + if ( + table["name"] not in exposure_tables + and table["name"] not in visit1_tables + and table["name"] not in flex_tables + ): + # A new table was added to the schema and cannot be parsed + logger.warn(f"Table {table['name']} has not been implemented in the RubinTV analysis service") + else: + self.tables[table["name"]] = sqlalchemy.Table( + table["name"], + self.metadata, + autoload_with=self.engine, + schema=schema["name"], + ) + + self.joins = JoinBuilder(self.tables, join_templates) + + def get_table_names(self) -> tuple[str, ...]: + """Given a schema, return a list of dataset names + + Returns + ------- + result : + The names of all the tables in the database. + """ + return tuple(tbl["name"] for tbl in self.schema["tables"]) + + def get_data_id(self, table: str) -> DataId: + """Return the data id for a table. + + Parameters + ---------- + table : + The name of the table in the database. + + Returns + ------- + result : + The data id for the table. + """ + return DataId(database=self.schema["name"], table=table) + + def get_selection_id(self, table: str) -> DatabaseSelectionId: + """Return the selection indices for a table. + + Parameters + ---------- + table : + The name of the table in the database. + + Returns + ------- + result : + The selection indices for the table. + """ + _table = self.schema["tables"][table] + index_columns = _table["index_columns"] + return DatabaseSelectionId(data_id=self.get_data_id(table), columns=index_columns) + + def get_column(self, column: str) -> sqlalchemy.Column: + """Return the column model for a column. + + Parameters + ---------- + column : + The name of the column in the database. + + Returns + ------- + result : + The column model for the column. + """ + table, column = column.split(".") + return self.tables[table].columns[column] + + def fetch_data(self, query_model: sqlalchemy.Select) -> dict[str, list]: + """Load data from the database. + + Parameters + ---------- + query_model : + The query to run on the database. + """ + logger.info(f"Query: {query_model}") + connection = self.engine.connect() + result = connection.execute(query_model) + data = result.fetchall() + connection.close() + + # Convert the unnamed row data into columns + return {str(col): [row[i] for row in data] for i, col in enumerate(result.keys())} + + def get_column_models( + self, columns: list[str] + ) -> tuple[set[sqlalchemy.Column], set[str], list[sqlalchemy.Column]]: + """Return the sqlalchemy models for a list of columns. + + Parameters + ---------- + columns : + The names of the columns in the database. + + Returns + ------- + result : + The column models for the columns. + """ + table_columns = set() + table_names: set[str] = set() + # get the sql alchemy model for each column + for column in columns: + table_name, column_name = column.split(".") + table_names.add(table_name) + column_obj = self.get_column(column) + # Label each column as 'table_name.column_name' + table_columns.add(column_obj.label(f"{table_name}.{column_name}")) + + # Add the data Ids (seq_num and day_obs) to the query. + def add_data_ids(table_name: str) -> list[sqlalchemy.Column]: + day_obs_column = self.get_column(f"{table_name}.day_obs") + seq_num_column = self.get_column(f"{table_name}.seq_num") + # Strip off the table name to make the data IDs uniform + table_columns.add(day_obs_column.label("day_obs")) + table_columns.add(seq_num_column.label("seq_num")) + return [day_obs_column, seq_num_column] + + if list(table_names)[0] in visit1_tables: + data_id_columns = add_data_ids("visit1") + table_names.add("visit1") + elif list(table_names)[0] in exposure_tables: + data_id_columns = add_data_ids("exposure") + table_names.add("exposure") else: - content = { - "columns": [column for column in data[0]._fields], - "data": [list(row) for row in data], - } - - return content - - -@dataclass(kw_only=True) -class CalculateBoundsCommand(BaseCommand): - """Calculate the bounds of a table column. - - Attributes - ---------- - database : - The name of the database that the table is in. - table : - The table that the columns are loaded from. - column : - The column to calculate the bounds of. - """ - - database: str - table: str - column: str - response_type: str = "column bounds" - - def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butler | None) -> dict: - database = databases[self.database] - data = calculate_bounds( - table=self.table, - column=self.column, - engine=database.engine, - ) - return { - "column": self.column, - "bounds": data, - } - - -# Register the commands -LoadColumnsCommand.register("load columns") -CalculateBoundsCommand.register("get bounds") + raise ValueError(f"Unsupported table name: {list(table_names)[0]}") + + return table_columns, table_names, data_id_columns + + def query( + self, + columns: list[str], + query: Query | None = None, + data_ids: list[tuple[int, int]] | None = None, + ) -> dict[str, list]: + """Query a table and return the results + + Parameters + ---------- + columns : + The ``table.column`` names of the columns to load. + query : + A query used on the table. + If `query` is ``None`` then all the rows + in the query are returned. + data_ids : + The data IDs to query, in the format ``(day_obs, seq_num)``. + + Returns + ------- + result : + A dictionary of columns as keys and lists of values as values. + """ + # Get the models for the columns + table_columns, table_names, data_id_columns = self.get_column_models(columns) + day_obs_column, seq_num_column = data_id_columns + + logger.info(f"table names: {table_names}") + + # generate the query + query_model = sqlalchemy.and_(*[col.isnot(None) for col in table_columns]) + if query is not None: + query_result = query(self) + query_model = sqlalchemy.and_(query_model, query_result.result) + table_names.update(query_result.tables) + if data_ids is not None: + data_id_select = sqlalchemy.tuple_(day_obs_column, seq_num_column).in_(data_ids) + query_model = sqlalchemy.and_(query_model, data_id_select) + + # Build the join + select_from = self.joins.build_join(table_names) + + # Build the query + query_model = sqlalchemy.select(*table_columns).select_from(select_from).where(query_model) + + # Fetch the data + result = self.fetch_data(query_model) + + return result + + def calculate_bounds(self, column: str) -> tuple[float, float]: + """Calculate the min, max for a column + + Parameters + ---------- + column : + The column to calculate the bounds of in the format "table.column". + + Returns + ------- + result : + The ``(min, max)`` of the chosen column. + """ + table, column = column.split(".") + _table = sqlalchemy.Table(table, self.metadata, autoload_with=self.engine) + _column = _table.columns[column] + + with self.engine.connect() as connection: + query = sqlalchemy.select((sqlalchemy.func.min(_column))) + result = connection.execute(query) + col_min = result.fetchone() + if col_min is not None: + col_min = col_min[0] + else: + raise ValueError(f"Could not calculate the min of column {column}") + + query = sqlalchemy.select((sqlalchemy.func.max(_column))) + result = connection.execute(query) + col_max = result.fetchone() + if col_max is not None: + col_max = col_max[0] + else: + raise ValueError(f"Could not calculate the max of column {column}") + return col_min, col_max diff --git a/python/lsst/rubintv/analysis/service/efd.py b/python/lsst/rubintv/analysis/service/efd.py new file mode 100644 index 0000000..67517b5 --- /dev/null +++ b/python/lsst/rubintv/analysis/service/efd.py @@ -0,0 +1,26 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +class EfdClient: + """Placeholder for the EFD client.""" + + pass diff --git a/python/lsst/rubintv/analysis/service/query.py b/python/lsst/rubintv/analysis/service/query.py index 4f6c0bf..08a6b58 100644 --- a/python/lsst/rubintv/analysis/service/query.py +++ b/python/lsst/rubintv/analysis/service/query.py @@ -23,10 +23,14 @@ import operator as op from abc import ABC, abstractmethod -from typing import Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any import sqlalchemy +if TYPE_CHECKING: + from .database import ConsDbSchema + class QueryError(Exception): """An error that occurred during a query""" @@ -34,17 +38,39 @@ class QueryError(Exception): pass +@dataclass +class QueryResult: + """The result of a query. + + Attributes + ---------- + result : + The result of the query as an sqlalchemy expression. + tables : + All of the tables that were used in the query. + """ + + result: sqlalchemy.ColumnElement + tables: set[str] + + class Query(ABC): """Base class for constructing queries.""" @abstractmethod - def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList: + def __call__(self, database: ConsDbSchema) -> QueryResult: """Run the query on a table. Parameters ---------- - table : - The table to run the query on. + database : + The connection to the database that is being queried. + + Returns + ------- + QueryResult : + The result of the query, including the tables that are + needed for the query. """ pass @@ -66,7 +92,7 @@ def from_dict(query_dict: dict[str, Any]) -> Query: elif query_dict["name"] == "ParentQuery": return ParentQuery.from_dict(query_dict["content"]) except Exception: - raise QueryError("Failed to parse query.") + raise QueryError(f"Failed to parse query: {query_dict}") raise QueryError("Unrecognized query type") @@ -94,17 +120,19 @@ def __init__( self.column = column self.value = value - def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList: - column = table.columns[self.column] - + def __call__(self, database: ConsDbSchema) -> QueryResult: + table_name, _ = self.column.split(".") + column = database.get_column(self.column) + result = None if self.operator in ("eq", "ne", "lt", "le", "gt", "ge"): operator = getattr(op, self.operator) - return operator(column, self.value) - - if self.operator not in ("startswith", "endswith", "contains"): + result = operator(column, self.value) + elif self.operator in ("startswith", "endswith", "contains"): + result = getattr(column, self.operator)(self.value) + else: raise QueryError(f"Unrecognized Equality operator {self.operator}") - return getattr(column, self.operator)(self.value) + return QueryResult(result, set((table_name,))) @staticmethod def from_dict(query_dict: dict[str, Any]) -> EqualityQuery: @@ -126,24 +154,32 @@ def __init__(self, children: list[Query], operator: str): self._children = children self._operator = operator - def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList: - child_results = [child(table) for child in self._children] + def __call__(self, database: ConsDbSchema) -> QueryResult: + child_results = [] + tables = set() + for child in self._children: + result = child(database) + child_results.append(result.result) + tables.update(result.tables) + try: match self._operator: case "AND": - return sqlalchemy.and_(*child_results) + result = sqlalchemy.and_(*child_results) case "OR": - return sqlalchemy.or_(*child_results) + result = sqlalchemy.or_(*child_results) case "NOT": - return sqlalchemy.not_(*child_results) + result = sqlalchemy.not_(*child_results) case "XOR": - return sqlalchemy.and_( + result = sqlalchemy.and_( sqlalchemy.or_(*child_results), sqlalchemy.not_(sqlalchemy.and_(*child_results)), ) except Exception: raise QueryError("Error applying a boolean query statement.") + return QueryResult(result, tables) # type: ignore + @staticmethod def from_dict(query_dict: dict[str, Any]) -> ParentQuery: return ParentQuery( diff --git a/python/lsst/rubintv/analysis/service/utils.py b/python/lsst/rubintv/analysis/service/utils.py index 806215b..29f9b37 100644 --- a/python/lsst/rubintv/analysis/service/utils.py +++ b/python/lsst/rubintv/analysis/service/utils.py @@ -1,5 +1,60 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import logging from enum import Enum +__all__ = ["ServerFormatter"] + +# Define custom loggers for the web app +WORKER_LEVEL = 21 +CLIENT_LEVEL = 22 +CONNECTION_LEVEL = 23 +logging.addLevelName(WORKER_LEVEL, "WORKER") +logging.addLevelName(CLIENT_LEVEL, "CLIENT") +logging.addLevelName(CONNECTION_LEVEL, "CONNECTION") + + +def worker(self, message, *args, **kws): + """Special log level for workers""" + if self.isEnabledFor(WORKER_LEVEL): + self._log(WORKER_LEVEL, message, args, **kws) + + +def client(self, message, *args, **kws): + """Special log level for clients""" + if self.isEnabledFor(CLIENT_LEVEL): + self._log(CLIENT_LEVEL, message, args, **kws) + + +def connection(self, message, *args, **kws): + """Special log level for connections""" + if self.isEnabledFor(CONNECTION_LEVEL): + self._log(CONNECTION_LEVEL, message, args, **kws) + + +logging.Logger.worker = worker +logging.Logger.client = client +logging.Logger.connection = connection + # ANSI color codes for printing to the terminal class Colors(Enum): @@ -22,19 +77,30 @@ class Colors(Enum): BRIGHT_CYAN = 96 BRIGHT_WHITE = 97 + @property + def ansi_code(self): + return f"\x1b[{self.value};20m" + + +def color_to_ansi(color: Colors) -> str: + return f"\x1b[{color.value};20m" + -def printc(message: str, color: Colors, end_color: Colors = Colors.RESET): - """Print a message to the terminal in color. +class ServerFormatter(logging.Formatter): + format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - After printing reset the color by default. + FORMATS = { + logging.DEBUG: Colors.BRIGHT_BLACK.ansi_code + format + Colors.RESET.ansi_code, + logging.INFO: Colors.WHITE.ansi_code + format + Colors.RESET.ansi_code, + WORKER_LEVEL: Colors.BLUE.ansi_code + format + Colors.BRIGHT_RED.ansi_code, + CLIENT_LEVEL: Colors.YELLOW.ansi_code + format + Colors.BRIGHT_RED.ansi_code, + CONNECTION_LEVEL: Colors.GREEN.ansi_code + format + Colors.BRIGHT_RED.ansi_code, + logging.WARNING: Colors.YELLOW.ansi_code + format + Colors.RESET.ansi_code, + logging.ERROR: Colors.RED.ansi_code + format + Colors.RESET.ansi_code, + logging.CRITICAL: Colors.BRIGHT_RED.ansi_code + format + Colors.RESET.ansi_code, + } - Parameters - ---------- - message : - The message to print. - color : - The color to print the message in. - end : - The color future messages should be printed in. - """ - print(f"\033[{color.value}m{message}\033[{end_color.value}m") + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno, self.format) + formatter = logging.Formatter(log_fmt) + return formatter.format(record) diff --git a/python/lsst/rubintv/analysis/service/viewer.py b/python/lsst/rubintv/analysis/service/viewer.py new file mode 100644 index 0000000..d0011b1 --- /dev/null +++ b/python/lsst/rubintv/analysis/service/viewer.py @@ -0,0 +1,435 @@ +# This file is part of lsst_rubintv_analysis_service. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +# This module was adapted from https://github.com/fred3m/toyz and has not yet +# been tested. + +from __future__ import annotations + +import datetime +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np +import scipy.ndimage +from matplotlib import cm as cmap +from matplotlib.colors import Normalize +from PIL import Image + +if TYPE_CHECKING: + from lsst.geom import Box2I + +# It may be desirabe in the future to allow users to choose +# what type of image they want to send to the client. +# For now the default is sent to png. +img_formats = { + "png": "PNG", + "bmp": "BMP", + "eps": "EPS", + "gif": "GIF", + "im": "IM", + "jpg": "JPEG", + "j2k": "JPEG 2000", + "msp": "MSP", + "pcx": "PCX", + "pbm": "PBM", + "pgm": "PGM", + "ppm": "PPM", + "spi": "SPIDER", + "tiff": "TIFF", + "webp": "WEBP", + "xbm": "XBM", +} + + +@dataclass(kw_only=True) +class ColorMap: + name: str = "Spectral" + color_scale: str = "linear" + invert_color: bool = False + px_min: float | None = None + px_max: float | None = None + + @property + def set_bounds(self) -> bool: + return self.px_min is not None and self.px_max is not None + + def copy_with( + self, + name: str | None = None, + color_scale: str | None = None, + px_min: float | None = None, + px_max: float | None = None, + ) -> ColorMap: + return ColorMap( + name=name if name is not None else self.name, + color_scale=color_scale if color_scale is not None else self.color_scale, + invert_color=self.invert_color, + px_min=px_min if px_min is not None else self.px_min, + px_max=px_max if px_max is not None else self.px_max, + ) + + def to_json(self): + return { + "name": self.name, + "color_scale": self.color_scale, + "invert_color": self.invert_color, + "px_min": self.px_min, + "px_max": self.px_max, + } + + +@dataclass(kw_only=True) +class FileInfo: + data_id: dict[str, Any] + tile_width: int = 400 + tile_height: int = 200 + image_type: str = "image" + resampling: str = "NEAREST" + invert_x: bool = False + invert_y: bool = False + tile_format: str = "png" + data: np.ndarray + bbox: Box2I + colormap: ColorMap + + def to_json(self): + return { + "data_id": self.data_id, + "tile_width": self.tile_width, + "tile_height": self.tile_height, + "img_type": self.image_type, + "resampling": self.resampling, + "invert_x": self.invert_x, + "invert_y": self.invert_y, + "tile_format": self.tile_format, + "bbox": [self.bbox.getMinX(), self.bbox.getMinY(), self.bbox.getMaxX(), self.bbox.getMaxY()], + "colormap": self.colormap.to_json(), + } + + +@dataclass(kw_only=True) +class ImageFile: + file_info: FileInfo + data: np.ndarray + created: datetime.datetime + modified: datetime.datetime + + +# Cached images loaded on the image worker +images_loaded: dict[str, ImageFile] = {} + + +class ImageViewer: + x_center: int + y_center: int + width: int + height: int + scale: float + left: int = 0 + bottom: int = 0 + right: int = 0 + top: int = 0 + + def __init__(self, x_center: int, y_center: int, width: int, height: int, scale: float): + self.x_center = x_center + self.y_center = y_center + self.width = width + self.height = height + self.scale = scale + self.left = x_center - int(width / 2) + self.right = x_center + int(width / 2) + self.bottom = y_center - int(height / 2) + self.top = y_center + int(height / 2) + + @staticmethod + def best_fit(data_width: int, data_height: int, viewer_width: int, viewer_height: int): + x_scale = viewer_width / data_width * 0.97 + y_scale = viewer_height / data_height * 0.97 + scale = min(y_scale, x_scale) + x_center = int(np.floor(data_width / 2 * scale)) + y_center = int(np.floor(data_height / 2 * scale)) + return ImageViewer(x_center, y_center, viewer_width, viewer_height, scale) + + def to_json(self): + return { + "x_center": self.x_center, + "y_center": self.y_center, + "width": self.width, + "height": self.height, + "scale": self.scale, + } + + +class ImageInfo: + viewer: ImageViewer + width: int + height: int + scale: float + scaled_width: int + scaled_height: int + columns: int + rows: int + invert_x: bool = False + invert_y: bool = False + tiles: dict = {} + colormap: ColorMap | None = None + + def __init__(self, file_info: FileInfo, viewer: ImageViewer): + self.viewer = viewer + self.width = file_info.bbox.getWidth() + self.height = file_info.bbox.getHeight() + self.scaled_width = int(np.ceil(self.width * self.viewer.scale)) + self.scaled_height = int(np.ceil(self.height * self.viewer.scale)) + self.columns = int(np.ceil(self.scaled_width / file_info.tile_width)) + self.rows = int(np.ceil(self.scaled_height / file_info.tile_height)) + + +@dataclass(kw_only=True) +class TileInfo: + idx: str + left: int + right: int + top: int + bottom: int + y0_idx: int + yf_idx: int + x0_idx: int + xf_idx: int + loaded: bool + row: int + col: int + x: int + y: int + width: int + height: int + + def to_json(self): + return { + "idx": self.idx, + "left": self.left, + "right": self.right, + "top": self.top, + "bottom": self.bottom, + "y0_idx": self.y0_idx, + "yf_idx": self.yf_idx, + "x0_idx": self.x0_idx, + "xf_idx": self.xf_idx, + "loaded": self.loaded, + "row": self.row, + "col": self.col, + "x": self.x, + "y": self.y, + "width": self.width, + "height": self.height, + } + + def to_basic_tile_info(self): + return BasicTileInfo( + y0_idx=self.y0_idx, + yf_idx=self.yf_idx, + x0_idx=self.x0_idx, + xf_idx=self.xf_idx, + width=self.width, + height=self.height, + ) + + +@dataclass(kw_only=True) +class BasicTileInfo: + y0_idx: int + yf_idx: int + x0_idx: int + xf_idx: int + width: int + height: int + + +def get_all_tile_info(file_info: FileInfo, img_info: ImageInfo): + """Get info for all tiles available in the viewer. + + If the tile has not been loaded yet, it is added to the new_tiles array. + """ + all_tiles = [] + new_tiles = {} + if img_info.invert_x: + xmin = img_info.width * img_info.scale - img_info.viewer.right + xmax = img_info.width * img_info.scale - img_info.viewer.left + else: + xmin = img_info.viewer.left + xmax = img_info.viewer.right + if img_info.invert_y: + ymin = img_info.height * img_info.scale - img_info.viewer.bottom + ymax = img_info.height * img_info.scale - img_info.viewer.top + else: + ymin = img_info.viewer.top + ymax = img_info.viewer.bottom + min_col = int(max(1, np.floor(xmin / file_info.tile_width))) - 1 + max_col = int(min(img_info.columns, np.ceil(xmax / file_info.tile_width))) + min_row = int(max(1, np.floor(ymin / file_info.tile_height))) - 1 + max_row = int(min(img_info.rows, np.ceil(ymax / file_info.tile_height))) + + block_width = int(np.ceil(file_info.tile_width / img_info.scale)) + block_height = int(np.ceil(file_info.tile_height / img_info.scale)) + + for row in range(min_row, max_row): + y0 = row * file_info.tile_height + yf = (row + 1) * file_info.tile_height + y0_idx = int(y0 / img_info.scale) + yf_idx = min(y0_idx + block_height, img_info.height) + for col in range(min_col, max_col): + all_tiles.append(str(col) + "," + str(row)) + tile_idx = str(col) + "," + str(row) + if ( + tile_idx not in img_info.tiles + or "loaded" not in img_info.tiles[tile_idx] + or not img_info.tiles[tile_idx]["loaded"] + ): + x0 = col * file_info.tile_width + xf = (col + 1) * file_info.tile_width + x0_idx = int(x0 / img_info.scale) + xf_idx = min(x0_idx + block_width, img_info.width) + tile_width = int((xf_idx - x0_idx) * img_info.scale) + tile_height = int((yf_idx - y0_idx) * img_info.scale) + tile = TileInfo( + idx=tile_idx, + left=x0, + right=xf, + top=y0, + bottom=yf, + y0_idx=y0_idx, + yf_idx=yf_idx, + x0_idx=x0_idx, + xf_idx=xf_idx, + loaded=False, + row=row, + col=col, + x=col * file_info.tile_width, + y=row * file_info.tile_height, + width=tile_width, + height=tile_height, + ) + if img_info.invert_y: + tile.top = yf + tile.bottom = y0 + if img_info.invert_x: + tile.left = xf + tile.right = x0 + new_tiles[tile_idx] = tile + print("viewer:", img_info.viewer) + print("new tiles", new_tiles.keys()) + return all_tiles, new_tiles + + +def scale_data(img_info: ImageInfo, tile_info: BasicTileInfo, data: np.ndarray): + if img_info.scale == 1: + data = data[tile_info.y0_idx : tile_info.yf_idx, tile_info.x0_idx : tile_info.xf_idx] + else: + data = data[tile_info.y0_idx : tile_info.yf_idx, tile_info.x0_idx : tile_info.xf_idx] + data = scipy.ndimage.zoom(data, img_info.scale, order=0) + return data + + +def create_tile(file_info: FileInfo, img_info: ImageInfo, tile_info: BasicTileInfo) -> Image.Image | None: + if file_info.resampling == "NEAREST": + data = scale_data(img_info, tile_info, file_info.data) + else: + data = file_info.data[tile_info.y0_idx : tile_info.yf_idx, tile_info.x0_idx : tile_info.xf_idx] + # FITS images have a flipped y-axis from what browsers + # and other image formats expect. + if img_info.invert_y: + data = np.flipud(data) + if img_info.invert_x: + data = np.fliplr(data) + + assert img_info.colormap is not None + + norm = Normalize(img_info.colormap.px_min, img_info.colormap.px_max, True) + colormap_name = img_info.colormap.name + if img_info.colormap.invert_color: + colormap_name = colormap_name + "_r" + colormap = getattr(cmap, colormap_name) + cm = cmap.ScalarMappable(norm, colormap) + img = np.uint8(cm.to_rgba(data) * 255) + img = Image.fromarray(img) + if file_info.resampling != "NEAREST": + img = img.resize((tile_info.width, tile_info.height), getattr(Image, file_info.resampling)) + + width, height = img.size + if width > 0 and height > 0: + return img + + return None + + +def get_img_data( + file_info: FileInfo, img_info: ImageInfo, width: int, height: int, x: int, y: int, rescale: bool = False +): + """ + Get data from an image or FITS file + """ + assert file_info.data is not None + data = file_info.data + + if rescale: + width = int(width / 2 / img_info.viewer.scale) + height = int(height / 2 / img_info.viewer.scale) + else: + width = int(width / 2) + height = int(height / 2) + x0 = max(0, x - width) + y0 = max(0, y - height) + xf = min(data.shape[1], x + width) + yf = min(data.shape[0], y + height) + if rescale: + tile_info = BasicTileInfo( + y0_idx=y0, + yf_idx=yf, + x0_idx=x0, + xf_idx=xf, + width=width, + height=height, + ) + data = scale_data(img_info, tile_info, data) + else: + data = data[y0:yf, x0:xf] + response = { + "id": "data", + "min": float(data.min()), + "max": float(data.max()), + "mean": float(data.mean()), + "median": float(np.median(data)), + "std_dev": float(np.std(data)), + "data": data.tolist(), + } + + return response + + +def get_point_data(file_info: FileInfo, img_info: ImageInfo, x: int, y: int) -> dict[str, Any]: + assert file_info.data is not None + data = file_info.data + + if x < data.shape[1] and y < data.shape[0] and x >= 0 and y >= 0: + response = {"id": "datapoint", "px_value": float(data[y, x])} + else: + response = {"id": "datapoint", "px_value": 0} + return response diff --git a/python/lsst/rubintv/analysis/service/client.py b/python/lsst/rubintv/analysis/service/worker.py similarity index 63% rename from python/lsst/rubintv/analysis/service/client.py rename to python/lsst/rubintv/analysis/service/worker.py index 33781b4..8cfdf67 100644 --- a/python/lsst/rubintv/analysis/service/client.py +++ b/python/lsst/rubintv/analysis/service/worker.py @@ -18,32 +18,54 @@ # # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations + import logging +from typing import TYPE_CHECKING -import sqlalchemy -import yaml -from lsst.daf.butler import Butler from websocket import WebSocketApp -from .command import DatabaseConnection, execute_command -from .utils import Colors, printc +from .command import execute_command + +if TYPE_CHECKING: + from .command import DataCenter logger = logging.getLogger("lsst.rubintv.analysis.service.client") class Worker: - def __init__(self, address: str, port: int, connection_info: dict[str, dict]): + """A worker that connects to the rubinTV server and executes commands. + + Attributes + ---------- + _address : + Address of the rubinTV web app websockets. + _port : + Port of the rubinTV web app websockets. + _dataCenter : + Data center for the worker. + """ + + _address: str + _port: int + _data_center: DataCenter + + def __init__(self, address: str, port: int, data_center: DataCenter): self._address = address self._port = port - self._connection_info = connection_info + self._data_center = data_center + + @property + def data_center(self) -> DataCenter: + return self._data_center def on_error(self, ws: WebSocketApp, error: str) -> None: """Error received from the server.""" - printc(f"Error: {error}", color=Colors.BRIGHT_RED) + logger.error(f"Error: {error}") def on_close(self, ws: WebSocketApp, close_status_code: str, close_msg: str) -> None: """Connection closed by the server.""" - printc("Connection closed", Colors.BRIGHT_YELLOW) + logger.connection("Connection closed") def run(self) -> None: """Run the worker and connect to the rubinTV server. @@ -57,27 +79,14 @@ def run(self) -> None: connection_info : Connections . """ - # Load the database connection information - databases: dict[str, DatabaseConnection] = {} - - for name, info in self._connection_info["databases"].items(): - with open(info["schema"], "r") as file: - engine = sqlalchemy.create_engine(info["url"]) - schema = yaml.safe_load(file) - databases[name] = DatabaseConnection(schema=schema, engine=engine) - - # Load the Butler (if one is available) - butler: Butler | None = None - if "butler" in self._connection_info: - repo = self._connection_info["butler"].pop("repo") - butler = Butler(repo, **self._connection_info["butler"]) def on_message(ws: WebSocketApp, message: str) -> None: """Message received from the server.""" - response = execute_command(message, databases, butler) + response = execute_command(message, self.data_center) ws.send(response) - printc(f"Connecting to rubinTV at {self._address}:{self._port}", Colors.BRIGHT_GREEN) + logger.connection(f"Connecting to rubinTV at {self._address}:{self._port}") + # Connect to the WebSocket server ws = WebSocketApp( f"ws://{self._address}:{self._port}/ws/worker", diff --git a/scripts/config.yaml b/scripts/config.yaml index 4e690ac..690199c 100644 --- a/scripts/config.yaml +++ b/scripts/config.yaml @@ -1,8 +1,14 @@ --- -databases: - summitcdb: - schema: "/Users/fred3m/temp/visitDb/summit.yaml" - url: "sqlite:////Users/fred3m/temp/visitDb/summit.db" -#butler: -# repo: /repos/main -# skymap: hsc_rings_v1 +locations: + summit: "postgresdb01.cp.lsst.org" + usdf: "usdf-summitdb.slac.stanford.edu" +schemas: + cdb_latiss: cdb_latiss.yaml + cdb_lsstcomcam: cdb_lsstcomcam.yaml + cdb_lsstcomcamsim: cdb_lsstcomcamsim.yaml +# cdb_lsstcam: cdb_lsstcam.yaml +repos: + - /repo/main + - embargo_or4 +#efd: +# url: connection info here diff --git a/scripts/joins.yaml b/scripts/joins.yaml new file mode 100644 index 0000000..acad6e7 --- /dev/null +++ b/scripts/joins.yaml @@ -0,0 +1,58 @@ +--- +joins: + # exposure and ccdexposure + - type: inner + matches: + exposure: + - exposure_id + ccdexposure: + - exposure_id + + # exposure and visit1 + - type: inner + matches: + exposure: + - exposure_id + visit1: + - visit_id + + # exposure and ccdvisit1 + - type: inner + matches: + exposure: + - exposure_id + ccdvisit1: + - visit_id + + # ccdexposure and ccdexposure_camera + - type: inner + matches: + ccdexposure: + - ccdexposure_id + ccdexposure_camera: + - ccdexposure_id + + # visit1 and ccdvisit1 + - type: inner + matches: + visit1: + - visit_id + ccdvisit1: + #- visit_id + - exposure_id + + # visit1 and visit1_quicklook + - type: inner + matches: + visit1: + - visit_id + visit1_quicklook: + - visit_id + + # ccdvisit1 and ccdvisit1_quicklook + - type: inner + matches: + ccdvisit1: + - ccdvisit_id + ccdvisit1_quicklook: + - ccdvisit_id diff --git a/scripts/mock_server.py b/scripts/mock_server.py index e6fb659..108d62e 100644 --- a/scripts/mock_server.py +++ b/scripts/mock_server.py @@ -21,6 +21,8 @@ from __future__ import annotations +import argparse +import logging import uuid from dataclasses import dataclass from enum import Enum @@ -29,11 +31,9 @@ import tornado.ioloop import tornado.web import tornado.websocket -from lsst.rubintv.analysis.service.utils import Colors, printc +from lsst.rubintv.analysis.service.utils import ServerFormatter -# Default port and address to listen on -LISTEN_PORT = 2000 -LISTEN_ADDRESS = "localhost" +logger = logging.getLogger("lsst.rubintv.analysis.service.server") class WorkerPodStatus(Enum): @@ -75,17 +75,13 @@ def open(self, client_type: str) -> None: self.client_id = str(uuid.uuid4()) if client_type == "worker": WebSocketHandler.workers[self.client_id] = WorkerPod(self.client_id, self) - printc( + logger.worker( f"New worker {self.client_id} connected. Total workers: {len(WebSocketHandler.workers)}", - Colors.BLUE, - Colors.RED, ) if client_type == "client": WebSocketHandler.clients[self.client_id] = self - printc( + logger.client( f"New client {self.client_id} connected. Total clients: {len(WebSocketHandler.clients)}", - Colors.YELLOW, - Colors.RED, ) def on_message(self, message: str) -> None: @@ -98,7 +94,7 @@ def on_message(self, message: str) -> None: The message received from the client or worker. """ if self.client_id in WebSocketHandler.clients: - printc(f"Message received from {self.client_id}", Colors.YELLOW, Colors.RED) + logger.client(f"Message received from {self.client_id}") client = WebSocketHandler.clients[self.client_id] # Find an idle worker @@ -118,11 +114,7 @@ def on_message(self, message: str) -> None: if self.client_id in WebSocketHandler.workers: worker = WebSocketHandler.workers[self.client_id] worker.on_finished(message) - printc( - f"Message received from worker {self.client_id}. New status {worker.status}", - Colors.BLUE, - Colors.RED, - ) + logger.worker(f"Message received from worker {self.client_id}. New status {worker.status}") # Check the queue for any outstanding jobs. if len(WebSocketHandler.queue) > 0: @@ -136,22 +128,14 @@ def on_close(self) -> None: """ if self.client_id in WebSocketHandler.clients: del WebSocketHandler.clients[self.client_id] - printc( - f"Client disconnected. Active clients: {len(WebSocketHandler.clients)}", - Colors.YELLOW, - Colors.RED, - ) + logger.client(f"Client disconnected. Active clients: {len(WebSocketHandler.clients)}") for worker in WebSocketHandler.workers.values(): if worker.connected_client == self: worker.on_finished("Client disconnected") break if self.client_id in WebSocketHandler.workers: del WebSocketHandler.workers[self.client_id] - printc( - f"Worker disconnected. Active workers: {len(WebSocketHandler.workers)}", - Colors.BLUE, - Colors.RED, - ) + logger.worker(f"Worker disconnected. Active workers: {len(WebSocketHandler.workers)}") def check_origin(self, origin): """ @@ -196,11 +180,7 @@ def process(self, message: str, connected_client: WebSocketHandler): """ self.status = WorkerPodStatus.BUSY self.connected_client = connected_client - printc( - f"Worker {self.wid} processing message from client {connected_client.client_id}", - Colors.BLUE, - Colors.RED, - ) + logger.worker(f"Worker {self.wid} processing message from client {connected_client.client_id}") # Send the job to the worker pod self.ws.write_message(message) @@ -214,9 +194,7 @@ def on_finished(self, message): # Send the reply to the client that made the request. self.connected_client.write_message(message) else: - printc( - f"Worker {self.wid} finished processing, but no client was connected.", Colors.RED, Colors.RED - ) + logger.error(f"Worker {self.wid} finished processing, but no client was connected.") self.status = WorkerPodStatus.IDLE self.connected_client = None @@ -238,14 +216,45 @@ class QueueItem: def main(): + parser = argparse.ArgumentParser(description="Initialize a new RubinTV worker.") + parser.add_argument( + "-a", "--address", default="localhost", type=str, help="Address of the rubinTV web app." + ) + parser.add_argument( + "-p", "--port", default=8080, type=int, help="Port of the rubinTV web app websockets." + ) + parser.add_argument( + "--log", + default="INFO", + help="Set the logging level of web app (DEBUG, INFO, WARNING, ERROR, CRITICAL).", + ) + args = parser.parse_args() + + # Configure logging + log_level = getattr(logging, args.log.upper(), None) + if not isinstance(log_level, int): + raise ValueError(f"Invalid log level: {args.log}") + + # Use custom formatting for the server logs + handler = logging.StreamHandler() + handler.setFormatter(ServerFormatter()) + for logger_name in [ + "lsst.rubintv.analysis.service.worker", + "lsst.rubintv.analysis.service.client", + "lsst.rubintv.analysis.service.server", + ]: + logger = logging.getLogger(logger_name) + logger.setLevel(log_level) + logger.addHandler(handler) + # Create tornado application and supply URL routes app = tornado.web.Application(WebSocketHandler.urls()) # type: ignore # Setup HTTP Server http_server = tornado.httpserver.HTTPServer(app) - http_server.listen(LISTEN_PORT, LISTEN_ADDRESS) + http_server.listen(args.port, args.address) - printc(f"Listening on address: {LISTEN_ADDRESS}, {LISTEN_PORT}", Colors.GREEN, Colors.RED) + logger.connection(f"Listening on address: {args.address}, {args.port}") # Start IO/Event loop tornado.ioloop.IOLoop.instance().start() diff --git a/scripts/rubintv_worker.py b/scripts/rubintv_worker.py index 97152bf..1d63bd7 100644 --- a/scripts/rubintv_worker.py +++ b/scripts/rubintv_worker.py @@ -20,13 +20,29 @@ # along with this program. If not, see . import argparse +import logging import os import pathlib +import sqlalchemy import yaml -from lsst.rubintv.analysis.service.client import Worker +from lsst.daf.butler import Butler +from lsst.rubintv.analysis.service.data import DataCenter, DataMatch +from lsst.rubintv.analysis.service.database import ConsDbSchema +from lsst.rubintv.analysis.service.efd import EfdClient +from lsst.rubintv.analysis.service.utils import ServerFormatter +from lsst.rubintv.analysis.service.worker import Worker default_config = os.path.join(pathlib.Path(__file__).parent.absolute(), "config.yaml") +default_joins = os.path.join(pathlib.Path(__file__).parent.absolute(), "joins.yaml") +logger = logging.getLogger("lsst.rubintv.analysis.server.worker") +sdm_schemas_path = os.path.join(os.path.expandvars("$SDM_SCHEMAS_DIR"), "yml") +credentials_path = os.path.join(os.path.expanduser("~"), ".lsst", "postgres-credentials.txt") + + +class UniversalToVisit(DataMatch): + def get_join(self): + return def main(): @@ -35,19 +51,113 @@ def main(): "-a", "--address", default="localhost", type=str, help="Address of the rubinTV web app." ) parser.add_argument( - "-p", "--port", default=2000, type=int, help="Port of the rubinTV web app websockets." + "-p", "--port", default=8080, type=int, help="Port of the rubinTV web app websockets." ) parser.add_argument( "-c", "--config", default=default_config, type=str, help="Location of the configuration file." ) + parser.add_argument("-j", "--joins", default=default_joins, type=str, help="Location of the joins file.") + parser.add_argument( + "-l", + "--location", + default="usdf", + type=str, + help="Location of the worker (either 'summit' or 'usdf')", + ) + parser.add_argument( + "--log", + default="INFO", + help="Set the logging level of the worker pod modules (DEBUG, INFO, WARNING, ERROR, CRITICAL).", + ) + parser.add_argument( + "--log-all", + default="WARNING", + help="Set the logging level of the remainder of packages (DEBUG, INFO, WARNING, ERROR, CRITICAL).", + ) args = parser.parse_args() - # Load the configuration file + # Configure logging for all modules + log_level = getattr(logging, args.log_all.upper(), None) + if not isinstance(log_level, int): + raise ValueError(f"Invalid log level: {args.log}") + logging.basicConfig(level=log_level) + + # Configure logging for the worker pod modules + worker_log_level = getattr(logging, args.log.upper(), None) + if not isinstance(worker_log_level, int): + raise ValueError(f"Invalid log level: {args.log}") + + # Use custom formatting for the server logs + handler = logging.StreamHandler() + handler.setFormatter(ServerFormatter()) + for logger_name in [ + "lsst.rubintv.analysis.service.worker", + "lsst.rubintv.analysis.service.client", + "lsst.rubintv.analysis.service.server", + ]: + logger = logging.getLogger(logger_name) + logger.setLevel(worker_log_level) + logger.addHandler(handler) + + # Load the configuration and join files + logger.info("Loading config") with open(args.config, "r") as file: config = yaml.safe_load(file) + with open(args.joins, "r") as file: + joins = yaml.safe_load(file)["joins"] + + # Set the database URL based on the location + logger.info("Connecting to the database") + server = "" + if args.location.lower() == "summit": + server = config["locations"]["summit"] + elif args.location.lower() == "usdf": + server = config["locations"]["usdf"] + else: + raise ValueError(f"Invalid location: {args.location}, must be either 'summit' or 'usdf'") + + with open(credentials_path, "r") as file: + credentials = file.readlines() + for credential in credentials: + _server, _, database, user, password = credential.split(":") + if _server == server: + password = password.strip() + break + else: + raise ValueError(f"Could not find credentials for {server}") + database_url = f"postgresql://{user}:{password}@{server}/{database}" + engine = sqlalchemy.create_engine(database_url) + + # Initialize the data center that provides access to various data sources + schemas: dict[str, ConsDbSchema] = {} + + for name, filename in config["schemas"].items(): + full_path = os.path.join(sdm_schemas_path, filename) + with open(full_path, "r") as file: + schema = yaml.safe_load(file) + schemas[name] = ConsDbSchema(schema=schema, engine=engine, join_templates=joins) + + # Load the Butler (if one is available) + butlers: dict[str, Butler] | None = None + if "butlers" in config: + logger.info("Connecting to Butlers") + for repo in config["butlers"]: + butlers[repo] = Butler(repo) # type: ignore + + # Load the EFD client (if one is available) + efd_client: EfdClient | None = None + if "efd" in config: + logger.info("Connecting to EFD") + raise NotImplementedError("EFD client not yet implemented") + + # Create the DataCenter that keeps track of all data sources. + # This will have to be updated every time we want to + # change/add a new data source. + data_center = DataCenter(schemas=schemas, butlers=butlers, efd_client=efd_client) # Run the client and connect to rubinTV via websockets - worker = Worker(args.address, args.port, config) + logger.info("Initializing worker") + worker = Worker(args.address, args.port, data_center) worker.run() diff --git a/tests/joins.yaml b/tests/joins.yaml new file mode 100644 index 0000000..bb035f9 --- /dev/null +++ b/tests/joins.yaml @@ -0,0 +1,21 @@ +--- +joins: + # visit1 and visit1_quicklook + - type: inner + matches: + exposure: + - exposure_id + visit1_quicklook: + - visit_id + - type: inner + matches: + exposure: + - exposure_id + visit1: + - visit_id + - type: inner + matches: + visit1: + - visit_id + visit1_quicklook: + - visit_id diff --git a/tests/schema.yaml b/tests/schema.yaml index 835499d..8ff7caf 100644 --- a/tests/schema.yaml +++ b/tests/schema.yaml @@ -2,16 +2,30 @@ name: testdb "@id": "#test_db" description: Small database for testing the package +joins: + - type: inner + matches: + exposure: + - exposure_id + visit1_quicklook: + - visit_id tables: - - name: ExposureInfo - index_column: exposure_id + - name: exposure + index_columns: + - exposure_id columns: - name: exposure_id datatype: long - description: Unique identifier of an exposure. + description: Unique identifier for the exposure. - name: seq_num datatype: long description: Sequence number + - name: day_obs + datatype: date + description: The night of the observation. This is different than the + observation date, as this is the night that the observations started, + so for observations after midnight obsStart and obsNight will be + different days. - name: ra datatype: double unit: degree @@ -20,23 +34,59 @@ tables: datatype: double unit: degree description: Declination of focal plane center - - name: expTime - datatype: double - description: Spatially-averaged duration of exposure, accurate to 10ms. - name: physical_filter datatype: char description: ID of physical filter, the filter associated with a particular instrument. - - name: obsNight + - name: obs_start + datatype: datetime + description: Start time of the exposure at the fiducial center + of the focal plane array, TAI, accurate to 10ms. + - name: obs_start_mjd + datatype: double + description: Start of the exposure in MJD, TAI, accurate to 10ms. + - name: visit1 + index_columns: + - visit_id + columns: + - name: visit_id + datatype: long + description: Unique identifier for the exposure. + - name: seq_num + datatype: long + description: Sequence number + - name: day_obs datatype: date description: The night of the observation. This is different than the observation date, as this is the night that the observations started, so for observations after midnight obsStart and obsNight will be different days. - - name: obsStart + - name: ra + datatype: double + unit: degree + description: RA of focal plane center. + - name: dec + datatype: double + unit: degree + description: Declination of focal plane center + - name: physical_filter + datatype: char + description: ID of physical filter, + the filter associated with a particular instrument. + - name: obs_start datatype: datetime description: Start time of the exposure at the fiducial center of the focal plane array, TAI, accurate to 10ms. - - name: obsStartMJD + - name: obs_start_mjd datatype: double description: Start of the exposure in MJD, TAI, accurate to 10ms. + - name: visit1_quicklook + index_columns: + - visit_id + columns: + - name: visit_id + datatype: long + description: Unique identifier for the visit. + - name: exp_time + datatype: double + description: Spatially-averaged duration of exposure, accurate to 10ms. diff --git a/tests/test_command.py b/tests/test_command.py index bab82c7..58640eb 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -20,45 +20,17 @@ # along with this program. If not, see . import json -import os -import tempfile +from typing import cast +import astropy.table import lsst.rubintv.analysis.service as lras -import sqlalchemy import utils -import yaml class TestCommand(utils.RasTestCase): - def setUp(self): - path = os.path.dirname(__file__) - yaml_filename = os.path.join(path, "schema.yaml") - - with open(yaml_filename) as file: - schema = yaml.safe_load(file) - db_file = tempfile.NamedTemporaryFile(delete=False) - utils.create_database(schema, db_file.name) - self.db_file = db_file - self.db_filename = db_file.name - self.schema = schema - - # Load the database connection information - self.databases = { - "testdb": lras.command.DatabaseConnection( - schema=schema, engine=sqlalchemy.create_engine("sqlite:///" + db_file.name) - ) - } - - # Set up the sqlalchemy connection - self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name) - - def tearDown(self) -> None: - self.db_file.close() - os.remove(self.db_file.name) - def execute_command(self, command: dict, response_type: str) -> dict: command_json = json.dumps(command) - response = lras.command.execute_command(command_json, self.databases, None) + response = lras.command.execute_command(command_json, self.data_center) result = json.loads(response) self.assertEqual(result["type"], response_type) return result["content"] @@ -70,64 +42,58 @@ def test_calculate_bounds_command(self): "name": "get bounds", "parameters": { "database": "testdb", - "table": "ExposureInfo", - "column": "dec", + "column": "exposure.dec", }, } + print(lras.command.BaseCommand.command_registry) content = self.execute_command(command, "column bounds") - self.assertEqual(content["column"], "dec") + self.assertEqual(content["column"], "exposure.dec") self.assertListEqual(content["bounds"], [-40, 50]) class TestLoadColumnsCommand(TestCommand): - def test_load_full_dataset(self): - command = {"name": "load columns", "parameters": {"database": "testdb", "table": "ExposureInfo"}} - - content = self.execute_command(command, "table columns") - data = content["data"] - - truth = utils.ap_table_to_list(utils.get_test_data()) - - self.assertDataTableEqual(data, truth) - def test_load_full_columns(self): command = { "name": "load columns", "parameters": { "database": "testdb", - "table": "ExposureInfo", "columns": [ - "ra", - "dec", + "exposure.ra", + "exposure.dec", ], }, } content = self.execute_command(command, "table columns") - columns = content["columns"] data = content["data"] - truth = utils.get_test_data()["exposure_id", "ra", "dec"] - truth_data = utils.ap_table_to_list(truth) - - self.assertTupleEqual(tuple(columns), tuple(truth.columns)) - self.assertDataTableEqual(data, truth_data) + truth = cast( + astropy.table.Table, + utils.get_test_data("exposure")[ + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", + ], + ) + valid = (truth["exposure.ra"] != None) & (truth["exposure.dec"] != None) # noqa: E711 + truth = cast(astropy.table.Table, truth[valid]) + self.assertDataTableEqual(data, truth) def test_load_columns_with_query(self): command = { "name": "load columns", "parameters": { "database": "testdb", - "table": "ExposureInfo", "columns": [ - "exposure_id", - "ra", - "dec", + "visit1_quicklook.visit_id", + "exposure.ra", + "exposure.dec", ], "query": { "name": "EqualityQuery", "content": { - "column": "expTime", + "column": "visit1_quicklook.exp_time", "operator": "eq", "value": 30, }, @@ -136,16 +102,27 @@ def test_load_columns_with_query(self): } content = self.execute_command(command, "table columns") - columns = content["columns"] data = content["data"] - truth = utils.get_test_data()["exposure_id", "ra", "dec"] - # Select rows with expTime = 30 - truth = truth[[True, True, False, False, False, True, True, True, False, False]] - truth_data = utils.ap_table_to_list(truth) + visit_truth = utils.get_test_data("exposure") + exp_truth = utils.get_test_data("visit1_quicklook") + truth = astropy.table.join( + visit_truth, + exp_truth, + keys_left=("exposure.exposure_id",), + keys_right=("visit1_quicklook.visit_id",), + ) + truth = truth[ + "visit1_quicklook.visit_id", + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", + ] - self.assertTupleEqual(tuple(columns), tuple(truth.columns)) - self.assertDataTableEqual(data, truth_data) + # Select rows with expTime = 30 + truth = truth[[True, True, False, False, False, True, False, False, False, False]] + self.assertDataTableEqual(data, truth) class TestCommandErrors(TestCommand): @@ -156,7 +133,7 @@ def check_error_response(self, content: dict, error: str, description: str | Non def test_errors(self): # Command cannot be decoded as JSON dict - content = self.execute_command("{'test': [1,2,3,0004,}", "error") + content = self.execute_command("{'test': [1,2,3,0004,}", "error") # type: ignore self.check_error_response(content, "parsing error") # Command does not contain a "name" @@ -201,7 +178,7 @@ def test_errors(self): # Command execution failed (table name does not exist) command = { "name": "get bounds", - "parameters": {"database": "testdb", "table": "InvalidTable", "column": "invalid_column"}, + "parameters": {"database": "testdb", "column": "InvalidTable.invalid_column"}, } content = self.execute_command(command, "error") self.check_error_response( diff --git a/tests/test_database.py b/tests/test_database.py index 94d6914..ddc6a38 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -19,85 +19,71 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import os -import tempfile -from unittest import TestCase - +import astropy.table import lsst.rubintv.analysis.service as lras -import sqlalchemy import utils -import yaml - - -class TestDatabase(TestCase): - def setUp(self): - path = os.path.dirname(__file__) - yaml_filename = os.path.join(path, "schema.yaml") - with open(yaml_filename) as file: - schema = yaml.safe_load(file) - db_file = tempfile.NamedTemporaryFile(delete=False) - utils.create_database(schema, db_file.name) - self.db_file = db_file - self.db_filename = db_file.name - self.schema = schema - - # Set up the sqlalchemy connection - self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name) - - def tearDown(self) -> None: - self.db_file.close() - os.remove(self.db_file.name) +class TestDatabase(utils.RasTestCase): def test_get_table_names(self): - table_names = lras.database.get_table_names(self.schema) - self.assertTupleEqual(table_names, ("ExposureInfo",)) + table_names = self.database.get_table_names() + self.assertTupleEqual( + table_names, + ( + "exposure", + "visit1", + "visit1_quicklook", + ), + ) def test_get_table_schema(self): - schema = lras.database.get_table_schema(self.schema, "ExposureInfo") - self.assertEqual(schema["name"], "ExposureInfo") + schema = lras.database.get_table_schema(self.database.schema, "exposure") + self.assertEqual(schema["name"], "exposure") columns = [ "exposure_id", "seq_num", + "day_obs", "ra", "dec", - "expTime", "physical_filter", - "obsNight", - "obsStart", - "obsStartMJD", + "obs_start", + "obs_start_mjd", ] for n, column in enumerate(schema["columns"]): self.assertEqual(column["name"], columns[n]) - def test_query_full_table(self): - truth_table = utils.get_test_data() - truth = utils.ap_table_to_list(truth_table) - - data = lras.database.query_table("ExposureInfo", engine=self.engine) - print(data) - - self.assertListEqual(list(data[0]._fields), list(truth_table.columns)) - - for n in range(len(truth)): - true_row = tuple(truth[n]) - row = tuple(data[n]) - self.assertTupleEqual(row, true_row) - - def test_query_columns(self): - truth = utils.get_test_data() - truth = utils.ap_table_to_list(truth["ra", "dec"]) - - data = lras.database.query_table("ExposureInfo", columns=["ra", "dec"], engine=self.engine) + def test_single_table_query_columns(self): + truth = utils.get_test_data("exposure") + valid = (truth["exposure.ra"] != None) & (truth["exposure.dec"] != None) # noqa: E711 + truth = truth[valid] + truth = truth["exposure.ra", "exposure.dec", "exposure.day_obs", "exposure.seq_num"] + data = self.database.query(columns=["exposure.ra", "exposure.dec"]) + self.assertDataTableEqual(data, truth) # type: ignore + + def test_multiple_table_query_columns(self): + visit_truth = utils.get_test_data("exposure") + exp_truth = utils.get_test_data("visit1_quicklook") + truth = astropy.table.join( + visit_truth, + exp_truth, + keys_left=("exposure.exposure_id"), + keys_right=("visit1_quicklook.visit_id"), + ) + valid = (truth["exposure.ra"] != None) & (truth["exposure.dec"] != None) # noqa: E711 + truth = truth[valid] + truth = truth[ + "exposure.ra", + "exposure.dec", + "visit1_quicklook.visit_id", + "exposure.day_obs", + "exposure.seq_num", + ] - self.assertListEqual(list(data[0]._fields), ["ra", "dec"]) + data = self.database.query(columns=["exposure.ra", "exposure.dec", "visit1_quicklook.visit_id"]) - for n in range(len(truth)): - true_row = tuple(truth[n]) - row = tuple(data[n]) - self.assertTupleEqual(row, true_row) + self.assertDataTableEqual(data, truth) def test_calculate_bounds(self): - result = lras.database.calculate_bounds("ExposureInfo", "dec", self.engine) + result = self.database.calculate_bounds("exposure.dec") self.assertTupleEqual(result, (-40, 50)) diff --git a/tests/test_query.py b/tests/test_query.py index a125218..b97da65 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -19,93 +19,76 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import os -import tempfile - +import astropy.table import lsst.rubintv.analysis.service as lras import sqlalchemy import utils -import yaml class TestQuery(utils.RasTestCase): - def setUp(self): - path = os.path.dirname(__file__) - yaml_filename = os.path.join(path, "schema.yaml") - - with open(yaml_filename) as file: - schema = yaml.safe_load(file) - db_file = tempfile.NamedTemporaryFile(delete=False) - utils.create_database(schema, db_file.name) - self.db_file = db_file - self.db_filename = db_file.name - self.schema = schema - - # Set up the sqlalchemy connection - self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name) - self.metadata = sqlalchemy.MetaData() - self.table = sqlalchemy.Table("ExposureInfo", self.metadata, autoload_with=self.engine) - - def tearDown(self) -> None: - self.db_file.close() - os.remove(self.db_file.name) - def test_equality(self): - table = self.table - column = table.columns.dec + query_table = self.database.tables["exposure"] + query_column = query_table.columns.dec value = 0 truth_dict = { - "eq": column == value, - "ne": column != value, - "lt": column < value, - "le": column <= value, - "gt": column > value, - "ge": column >= value, + "eq": query_column == value, + "ne": query_column != value, + "lt": query_column < value, + "le": query_column <= value, + "gt": query_column > value, + "ge": query_column >= value, } for operator, truth in truth_dict.items(): - self.assertTrue(lras.query.EqualityQuery("dec", operator, value)(table).compare(truth)) + result = lras.query.EqualityQuery("exposure.dec", operator, value)(self.database) + self.assertTrue(result.result.compare(truth)) + self.assertSetEqual( + result.tables, + { + "exposure", + }, + ) def test_query(self): - table = self.table - + dec_column = self.database.tables["exposure"].columns.dec + ra_column = self.database.tables["exposure"].columns.ra # dec > 0 - query = lras.query.EqualityQuery("dec", "gt", 0) - result = query(table) - self.assertTrue(result.compare(table.columns.dec > 0)) + query = lras.query.EqualityQuery("exposure.dec", "gt", 0) + result = query(self.database) + self.assertTrue(result.result.compare(dec_column > 0)) # dec < 0 and ra > 60 query = lras.query.ParentQuery( operator="AND", children=[ - lras.query.EqualityQuery("dec", "lt", 0), - lras.query.EqualityQuery("ra", "gt", 60), + lras.query.EqualityQuery("exposure.dec", "lt", 0), + lras.query.EqualityQuery("exposure.ra", "gt", 60), ], ) - result = query(table) + result = query(self.database) truth = sqlalchemy.and_( - table.columns.dec < 0, - table.columns.ra > 60, + dec_column < 0, + ra_column > 60, ) - self.assertTrue(result.compare(truth)) + self.assertTrue(result.result.compare(truth)) # Check queries that are unequal to verify that they don't work - result = query(table) + result = query(self.database) truth = sqlalchemy.and_( - table.columns.dec < 0, - table.columns.ra > 70, + dec_column < 0, + ra_column > 70, ) - self.assertFalse(result.compare(truth)) + self.assertFalse(result.result.compare(truth)) def test_database_query(self): - data = utils.get_test_data() + data = utils.get_test_data("exposure") # dec > 0 (and is not None) query1 = { "name": "EqualityQuery", "content": { - "column": "dec", + "column": "exposure.dec", "operator": "gt", "value": 0, }, @@ -114,7 +97,7 @@ def test_database_query(self): query2 = { "name": "EqualityQuery", "content": { - "column": "ra", + "column": "exposure.ra", "operator": "gt", "value": 60, }, @@ -122,10 +105,15 @@ def test_database_query(self): # Test 1: dec > 0 (and is not None) query = query1 - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) - truth = data[[False, False, False, False, False, True, False, True, True, True]] - truth = utils.ap_table_to_list(truth) - self.assertDataTableEqual(result, truth) + result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) + truth = data[[False, False, False, False, False, True, False, False, True, True]] + truth = truth[ + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore # Test 2: dec > 0 and ra > 60 (and neither is None) query = { @@ -135,10 +123,15 @@ def test_database_query(self): "children": [query1, query2], }, } - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, False, False, False, True, True]] - truth = utils.ap_table_to_list(truth) - self.assertDataTableEqual(result, truth) + truth = truth[ + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore # Test 3: dec <= 0 or ra > 60 (and neither is None) query = { @@ -158,10 +151,15 @@ def test_database_query(self): }, } - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) - truth = data[[True, True, False, True, True, False, True, False, True, True]] - truth = utils.ap_table_to_list(truth) - self.assertDataTableEqual(result, truth) + result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) + truth = data[[True, True, False, True, True, False, False, False, True, True]] + truth = truth[ + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore # Test 4: dec > 0 XOR ra > 60 query = { @@ -171,100 +169,129 @@ def test_database_query(self): "children": [query1, query2], }, } - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, True, False, False, False, False]] - truth = utils.ap_table_to_list(truth) - self.assertDataTableEqual(result, truth) + truth = truth[ + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore def test_database_string_query(self): - data = utils.get_test_data() + data = utils.get_test_data("exposure") # Test equality query = { "name": "EqualityQuery", "content": { - "column": "physical_filter", + "column": "exposure.physical_filter", "operator": "eq", "value": "DECam r-band", }, } - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, False, True, False, False, False]] - truth = utils.ap_table_to_list(truth) - self.assertDataTableEqual(result, truth) + truth = truth[ + "exposure.physical_filter", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore # Test "startswith" query = { "name": "EqualityQuery", "content": { - "column": "physical_filter", + "column": "exposure.physical_filter", "operator": "startswith", "value": "DECam", }, } - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, True, True, True, True, True]] - truth = utils.ap_table_to_list(truth) - self.assertDataTableEqual(result, truth) + truth = truth[ + "exposure.physical_filter", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore # Test "endswith" query = { "name": "EqualityQuery", "content": { - "column": "physical_filter", + "column": "exposure.physical_filter", "operator": "endswith", "value": "r-band", }, } - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, True, False, False, False, False, True, False, False, False]] - truth = utils.ap_table_to_list(truth) - self.assertDataTableEqual(result, truth) + truth = truth[ + "exposure.physical_filter", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore # Test "like" query = { "name": "EqualityQuery", "content": { - "column": "physical_filter", + "column": "exposure.physical_filter", "operator": "contains", "value": "T r", }, } - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query) + result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, True, False, False, False, False, False, False, False, False]] - truth = utils.ap_table_to_list(truth) - self.assertDataTableEqual(result, truth) + truth = truth[ + "exposure.physical_filter", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore def test_database_datatime_query(self): - data = utils.get_test_data() + data = utils.get_test_data("exposure") # Test < query1 = { "name": "EqualityQuery", "content": { - "column": "obsStart", + "column": "exposure.obs_start", "operator": "lt", "value": "2023-05-19 23:23:23", }, } - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query1) + result = self.database.query(["exposure.obs_start"], query=lras.query.Query.from_dict(query1)) truth = data[[True, True, True, False, False, True, True, True, True, True]] - truth = utils.ap_table_to_list(truth) - self.assertDataTableEqual(result, truth) + truth = truth[ + "exposure.obs_start", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore # Test > query2 = { "name": "EqualityQuery", "content": { - "column": "obsStart", + "column": "exposure.obs_start", "operator": "gt", "value": "2023-05-01 23:23:23", }, } - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query2) + result = self.database.query(["exposure.obs_start"], query=lras.query.Query.from_dict(query2)) truth = data[[True, True, True, True, True, False, False, False, False, False]] - truth = utils.ap_table_to_list(truth) - self.assertDataTableEqual(result, truth) + truth = truth[ + "exposure.obs_start", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore # Test in range query3 = { @@ -274,7 +301,71 @@ def test_database_datatime_query(self): "children": [query1, query2], }, } - result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query3) + result = self.database.query(["exposure.obs_start"], query=lras.query.Query.from_dict(query3)) truth = data[[True, True, True, False, False, False, False, False, False, False]] - truth = utils.ap_table_to_list(truth) + truth = truth[ + "exposure.obs_start", + "exposure.day_obs", + "exposure.seq_num", + ] # type: ignore + self.assertDataTableEqual(result, truth) # type:ignore + + def test_multiple_table_query(self): + visit_truth = utils.get_test_data("exposure") + exp_truth = utils.get_test_data("visit1_quicklook") + truth = astropy.table.join( + visit_truth, + exp_truth, + keys_left=("exposure.exposure_id",), + keys_right=("visit1_quicklook.visit_id",), + ) + + # dec > 0 (and is not None) + query1 = { + "name": "EqualityQuery", + "content": { + "column": "exposure.dec", + "operator": "gt", + "value": 0, + }, + } + # exposure time == 30 (and is not None) + query2 = { + "name": "EqualityQuery", + "content": { + "column": "visit1_quicklook.exp_time", + "operator": "eq", + "value": 30, + }, + } + # Intersection of the two queries + query3 = { + "name": "ParentQuery", + "content": { + "operator": "AND", + "children": [query1, query2], + }, + } + + valid = ( + (truth["exposure.dec"] != None) # noqa: E711 + & (truth["exposure.ra"] != None) # noqa: E711 + & (truth["visit1_quicklook.visit_id"] != None) # noqa: E711 + ) + truth = truth[valid] + valid = (truth["exposure.dec"] > 0) & (truth["visit1_quicklook.exp_time"] == 30) + truth = truth[valid] + truth = truth[ + "exposure.dec", + "exposure.ra", + "visit1_quicklook.visit_id", + "exposure.day_obs", + "exposure.seq_num", + ] + + result = self.database.query( + columns=["exposure.ra", "exposure.dec", "visit1_quicklook.visit_id"], + query=lras.query.Query.from_dict(query3), + ) + self.assertDataTableEqual(result, truth) diff --git a/tests/utils.py b/tests/utils.py index 5b78e5b..828a6e0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,11 +19,18 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +import os import sqlite3 +import tempfile from unittest import TestCase +import numpy as np +import sqlalchemy +import yaml from astropy.table import Table as ApTable from astropy.time import Time +from lsst.rubintv.analysis.service.data import DataCenter +from lsst.rubintv.analysis.service.database import ConsDbSchema # Convert visit DB datatypes to sqlite3 datatypes datatype_transform = { @@ -36,6 +43,12 @@ "datetime": "text", } +# Convert DataID columns +dataid_transform = { + "exposure.day_obs": "day_obs", + "exposure.seq_num": "seq_num", +} + def create_table(cursor: sqlite3.Cursor, tbl_name: str, schema: dict): """Create a table in an sqlite database. @@ -56,8 +69,9 @@ def create_table(cursor: sqlite3.Cursor, tbl_name: str, schema: dict): cursor.execute(command) -def get_test_data_dict() -> dict: - """Get a dictionary containing the test data""" +def get_exposure_data_dict(table_name: str, id_name: str) -> dict: + """Get a dictionary containing the visit test data""" + obs_start = [ "2023-05-19 20:20:20", "2023-05-19 21:21:21", @@ -74,24 +88,9 @@ def get_test_data_dict() -> dict: obs_start_mjd = [Time(time).mjd for time in obs_start] return { - "exposure_id": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18], - "seq_num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], - "ra": [10, 20, None, 40, 50, 60, 70, None, 90, 100], - "dec": [-40, -30, None, -10, 0, 10, None, 30, 40, 50], - "expTime": [30, 30, 10, 15, 15, 30, 30, 30, 15, 20], - "physical_filter": [ - "LSST g-band", - "LSST r-band", - "LSST i-band", - "LSST z-band", - "LSST y-band", - "DECam g-band", - "DECam r-band", - "DECam i-band", - "DECam z-band", - "DECam y-band", - ], - "obsNight": [ + f"{table_name}.{id_name}": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], + f"{table_name}.seq_num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + f"{table_name}.day_obs": [ "2023-05-19", "2023-05-19", "2023-05-19", @@ -103,44 +102,79 @@ def get_test_data_dict() -> dict: "2023-02-14", "2023-02-14", ], - "obsStart": obs_start, - "obsStartMJD": obs_start_mjd, + f"{table_name}.ra": [10, 20, None, 40, 50, 60, 70, None, 90, 100], + f"{table_name}.dec": [-40, -30, None, -10, 0, 10, None, 30, 40, 50], + f"{table_name}.physical_filter": [ + "LSST g-band", + "LSST r-band", + "LSST i-band", + "LSST z-band", + "LSST y-band", + "DECam g-band", + "DECam r-band", + "DECam i-band", + "DECam z-band", + "DECam y-band", + ], + f"{table_name}.obs_start": obs_start, + f"{table_name}.obs_start_mjd": obs_start_mjd, } -def get_test_data() -> ApTable: +def get_visit_data_dict() -> dict: + """Get a dictionary containing the exposure test data""" + return { + "visit1_quicklook.visit_id": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], + "visit1_quicklook.exp_time": [30, 30, 10, 15, 15, 30, 30, 30, 15, 20], + } + + +def get_test_data(table: str) -> ApTable: """Generate data for the test database""" - data_dict = get_test_data_dict() + if table == "exposure": + data_dict = get_exposure_data_dict("exposure", "exposure_id") + else: + data_dict = get_visit_data_dict() - table = ApTable(list(data_dict.values()), names=list(data_dict.keys())) - return table + return ApTable(list(data_dict.values()), names=list(data_dict.keys())) def ap_table_to_list(data: ApTable) -> list: """Convert an astropy Table into a list of tuples.""" rows = [] for row in data: - rows.append(tuple(row)) + rows.append(tuple(row)) # type: ignore return rows def create_database(schema: dict, db_filename: str): """Create the test database""" - tbl_name = "ExposureInfo" - connection = sqlite3.connect(db_filename) - cursor = connection.cursor() - - create_table(cursor, tbl_name, schema["tables"][0]["columns"]) - data = get_test_data_dict() - - for n in range(len(data["exposure_id"])): - row = tuple(data[key][n] for key in data.keys()) - value_str = "?, " * (len(row) - 1) + "?" - command = f"INSERT INTO {tbl_name} VALUES({value_str});" - cursor.execute(command, row) - connection.commit() - cursor.close() + for table in schema["tables"]: + connection = sqlite3.connect(db_filename) + cursor = connection.cursor() + + create_table(cursor, table["name"], table["columns"]) + + if table["name"] == "exposure": + data = get_exposure_data_dict("exposure", "exposure_id") + index_key = "exposure.exposure_id" + elif table["name"] == "visit1": + data = get_exposure_data_dict("visit1", "visit_id") + index_key = "visit1.visit_id" + elif table["name"] == "visit1_quicklook": + data = get_visit_data_dict() + index_key = "visit1_quicklook.visit_id" + else: + raise ValueError(f"Unknown table name: {table['name']}") + + for n in range(len(data[index_key])): + row = tuple(data[key][n] for key in data.keys()) + value_str = "?, " * (len(row) - 1) + "?" + command = f"INSERT INTO {table['name']} VALUES({value_str});" + cursor.execute(command, row) + connection.commit() + cursor.close() class TableMismatchError(AssertionError): @@ -155,24 +189,42 @@ class RasTestCase(TestCase): might be put in place. """ - @staticmethod - def get_data_table_indices(table: list[tuple]) -> list[int]: - """Get the index for each rom in the data table. + def setUp(self): + # Load the testdb schema + path = os.path.dirname(__file__) + yaml_filename = os.path.join(path, "schema.yaml") - Parameters - ---------- - table : - The table containing the data. + with open(yaml_filename) as file: + schema = yaml.safe_load(file) - Returns - ------- - result : - The index for each row in the table. - """ - # Return the seq_num as an index - return [row[1] for row in table] + # Remove the name of the schema, since sqlite does not have + # schema names and this will break the code otherwise. + schema["name"] = None + + # Create the sqlite test database + db_file = tempfile.NamedTemporaryFile(delete=False) + create_database(schema, db_file.name) + self.db_file = db_file + self.db_filename = db_file.name + self.schema = schema + + # Set up the sqlalchemy connection + engine = sqlalchemy.create_engine("sqlite:///" + db_file.name) + + # Load the table joins + joins_path = os.path.join(path, "joins.yaml") + with open(joins_path) as file: + joins = yaml.safe_load(file)["joins"] + + # Create the datacenter + self.database = ConsDbSchema(schema=schema, engine=engine, join_templates=joins) + self.data_center = DataCenter(schemas={"testdb": self.database}) + + def tearDown(self) -> None: + self.db_file.close() + os.remove(self.db_file.name) - def assertDataTableEqual(self, result, truth): + def assertDataTableEqual(self, result: dict | ApTable, truth: ApTable): # NOQA: N802 """Check if two data tables are equal. Parameters @@ -182,16 +234,13 @@ def assertDataTableEqual(self, result, truth): truth : The expected value of the test. """ - if len(result) != len(truth): - msg = "Data tables have a different number of rows: " - msg += f"indices: [{self.get_data_table_indices(result)}], [{self.get_data_table_indices(truth)}]" - raise TableMismatchError(msg) - try: - for n in range(len(truth)): - true_row = tuple(truth[n]) - row = tuple(result[n]) - self.assertTupleEqual(row, true_row) - except AssertionError: - msg = "Mismatched tables: " - msg += f"indices: [{self.get_data_table_indices(result)}], [{self.get_data_table_indices(truth)}]" - raise TableMismatchError(msg) + columns = truth.colnames + for column in columns: + result_column = column + if column not in result: + if column in dataid_transform: + result_column = dataid_transform[column] + else: + msg = f"Column {column} not found in result" + raise TableMismatchError(msg) + np.testing.assert_array_equal(np.array(result[result_column]), np.array(truth[column])) diff --git a/ups/rubintv_analysis_service.table b/ups/rubintv_analysis_service.table index e1e210e..10589fd 100644 --- a/ups/rubintv_analysis_service.table +++ b/ups/rubintv_analysis_service.table @@ -3,6 +3,8 @@ # - Common third-party packages can be assumed to be recursively included by # the "base" package. setupRequired(lsst_distrib) +setupRequired(obs_lsst) +setupRequired(afw) # The following is boilerplate for all packages. # See https://dmtn-001.lsst.io for details on LSST_LIBRARY_PATH.