diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
new file mode 100644
index 0000000..7b345f4
--- /dev/null
+++ b/.github/pull_request_template.md
@@ -0,0 +1,4 @@
+## Checklist
+
+- [ ] ran Jenkins
+- [ ] added a release note for user-visible changes to `doc/changes`
diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
new file mode 100644
index 0000000..0055df1
--- /dev/null
+++ b/.github/workflows/build.yaml
@@ -0,0 +1,55 @@
+name: build_and_test
+
+on:
+ push:
+ branches:
+ - main
+ tags:
+ - "*"
+ pull_request:
+
+jobs:
+ build_and_test:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ # Need to clone everything for the git tags.
+ fetch-depth: 0
+
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+ cache: "pip"
+ cache-dependency-path: "setup.cfg"
+
+ - name: Install yaml
+ run: sudo apt-get install libyaml-dev
+
+ - name: Install prereqs for setuptools
+ run: pip install wheel
+
+ # We have two cores so we can speed up the testing with xdist
+ - name: Install xdist, openfiles and flake8 for pytest
+ run: >
+ pip install pytest-xdist pytest-openfiles pytest-flake8
+ pytest-cov "flake8<5"
+
+ - name: Build and install
+ run: pip install -v -e .
+
+ - name: Install documenteer
+ run: pip install 'documenteer[pipelines]<0.7'
+
+ - name: Run tests
+ run: >
+ pytest -r a -v -n 3 --open-files --cov=tests
+ --cov=lsst.rubintv.analysis.service
+ --cov-report=xml --cov-report=term
+ --doctest-modules --doctest-glob="*.rst"
+
+ - name: Upload coverage to codecov
+ uses: codecov/codecov-action@v2
+ with:
+ file: ./coverage.xml
diff --git a/.github/workflows/build_docs.yaml b/.github/workflows/build_docs.yaml
new file mode 100644
index 0000000..75d1dac
--- /dev/null
+++ b/.github/workflows/build_docs.yaml
@@ -0,0 +1,41 @@
+name: docs
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+
+jobs:
+ build_sphinx_docs:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ # Need to clone everything for the git tags.
+ fetch-depth: 0
+
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+ cache: "pip"
+ cache-dependency-path: "setup.cfg"
+
+ - name: Update pip/wheel infrastructure
+ run: |
+ python -m pip install --upgrade pip
+ pip install wheel
+
+ - name: Build and install
+ run: pip install -v -e .
+
+ - name: Show compiled files
+ run: ls python/lsst/rubintv/analysis/service
+
+ - name: Install documenteer
+ run: pip install 'documenteer[pipelines]<0.7'
+
+ - name: Build documentation
+ working-directory: ./doc
+ run: package-docs build
diff --git a/.github/workflows/formatting.yaml b/.github/workflows/formatting.yaml
new file mode 100644
index 0000000..27f34a6
--- /dev/null
+++ b/.github/workflows/formatting.yaml
@@ -0,0 +1,11 @@
+name: Check Python formatting
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+
+jobs:
+ call-workflow:
+ uses: lsst/rubin_workflows/.github/workflows/formatting.yaml@main
diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
new file mode 100644
index 0000000..796ef92
--- /dev/null
+++ b/.github/workflows/lint.yaml
@@ -0,0 +1,11 @@
+name: lint
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+
+jobs:
+ call-workflow:
+ uses: lsst/rubin_workflows/.github/workflows/lint.yaml@main
diff --git a/.github/workflows/yamllint.yaml b/.github/workflows/yamllint.yaml
new file mode 100644
index 0000000..76ad875
--- /dev/null
+++ b/.github/workflows/yamllint.yaml
@@ -0,0 +1,11 @@
+name: Lint YAML Files
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+
+jobs:
+ call-workflow:
+ uses: lsst/rubin_workflows/.github/workflows/yamllint.yaml@main
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..a07ff8a
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,32 @@
+repos:
+
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.4.0
+ hooks:
+ - id: check-yaml
+ args:
+ - "--unsafe"
+ - id: end-of-file-fixer
+ - id: trailing-whitespace
+ - repo: https://github.com/psf/black
+ rev: 23.1.0
+ hooks:
+ - id: black
+ # It is recommended to specify the latest version of Python
+
+ # supported by your project here, or alternatively use
+
+ # pre-commit's default_language_version, see
+
+ # https://pre-commit.com/#top_level-default_language_version
+
+ language_version: python3.10
+ - repo: https://github.com/pycqa/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ name: isort (python)
+ - repo: https://github.com/PyCQA/flake8
+ rev: 6.0.0
+ hooks:
+ - id: flake8
diff --git a/doc/.gitignore b/doc/.gitignore
new file mode 100644
index 0000000..ad2c2bb
--- /dev/null
+++ b/doc/.gitignore
@@ -0,0 +1,10 @@
+# Doxygen products
+html
+xml
+*.tag
+*.inc
+doxygen.conf
+
+# Sphinx products
+_build
+py-api
diff --git a/doc/SConscript b/doc/SConscript
new file mode 100644
index 0000000..61b554a
--- /dev/null
+++ b/doc/SConscript
@@ -0,0 +1,3 @@
+# -*- python -*-
+from lsst.sconsUtils import scripts
+scripts.BasicSConscript.doc()
diff --git a/doc/conf.py b/doc/conf.py
new file mode 100644
index 0000000..0dcae36
--- /dev/null
+++ b/doc/conf.py
@@ -0,0 +1,12 @@
+"""Sphinx configuration file for an LSST stack package.
+This configuration only affects single-package Sphinx documentation builds.
+For more information, see:
+https://developer.lsst.io/stack/building-single-package-docs.html
+"""
+
+from documenteer.conf.pipelinespkg import *
+
+project = "rubintv_analysis_service"
+html_theme_options["logotext"] = project
+html_title = project
+html_short_title = project
diff --git a/doc/doxygen.conf.in b/doc/doxygen.conf.in
new file mode 100644
index 0000000..e69de29
diff --git a/doc/index.rst b/doc/index.rst
new file mode 100644
index 0000000..f7713f5
--- /dev/null
+++ b/doc/index.rst
@@ -0,0 +1,12 @@
+##############################################
+rubintv_analysis_service documentation preview
+##############################################
+
+.. This page is for local development only. It isn't published to pipelines.lsst.io.
+
+.. Link the index pages of package and module documentation directions (listed in manifest.yaml).
+
+.. toctree::
+ :maxdepth: 1
+
+ lsst.rubintv.analysis.service/index
diff --git a/doc/lsst.rubintv.analysis.service/index.rst b/doc/lsst.rubintv.analysis.service/index.rst
new file mode 100644
index 0000000..3707f25
--- /dev/null
+++ b/doc/lsst.rubintv.analysis.service/index.rst
@@ -0,0 +1,40 @@
+.. py:currentmodule:: lsst.rubintv.analysis.service
+
+.. _lsst.rubintv.analysis.service:
+
+#############################
+lsst.rubintv.analysis.service
+#############################
+
+.. Paragraph that describes what this Python module does and links to related modules and frameworks.
+
+.. _lsst.rubintv.analysis.service-using:
+
+Using lsst.rubintv.analysis.service
+===================================
+
+.. toctree linking to topics related to using the module's APIs.
+
+.. toctree::
+ :maxdepth: 2
+
+.. _lsst.rubintv.analysis.service-contributing:
+
+Contributing
+============
+
+``lsst.rubintv.analysis.service`` is developed at https://github.com/lsst-ts/rubintv_analysis_service.
+
+.. If there are topics related to developing this module (rather than using it), link to this from a toctree placed here.
+
+.. .. toctree::
+.. :maxdepth: 2
+
+.. _lsst.rubintv.analysis.service-pyapi:
+
+Python API reference
+====================
+
+.. automodapi:: lsst.rubintv.analysis.service
+ :no-main-docstr:
+ :no-inheritance-diagram:
diff --git a/doc/manifest.yaml b/doc/manifest.yaml
new file mode 100644
index 0000000..222eb97
--- /dev/null
+++ b/doc/manifest.yaml
@@ -0,0 +1,12 @@
+# Documentation manifest.
+
+# List of names of Python modules in this package.
+# For each module there is a corresponding module doc subdirectory.
+modules:
+ - "lsst.rubintv.analysis.service"
+
+# Name of the static content directories (subdirectories of `_static`).
+# Static content directories are usually named after the package.
+# Most packages do not need a static content directory (leave commented out).
+# statics:
+# - "_static/example_standalone"
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 0000000..3d32554
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,22 @@
+[mypy]
+warn_unused_configs = True
+warn_redundant_casts = True
+plugins = pydantic.mypy
+
+[mypy-astropy.*]
+ignore_missing_imports = True
+
+[mypy-matplotlib.*]
+ignore_missing_imports = True
+
+[mypy-numpy.*]
+ignore_missing_imports = True
+
+[mypy-scipy.*]
+ignore_missing_imports = True
+
+[mypy-sqlalchemy.*]
+ignore_missing_imports = True
+
+[mypy-yaml.*]
+ignore_missing_imports = True
diff --git a/pyproject.toml b/pyproject.toml
index 3b951cd..98176c8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,12 +30,18 @@ dependencies = [
"scipy",
"matplotlib",
"pydantic",
- "pyyaml"
+ "pyyaml",
+ "sqlalchemy",
+ "astropy",
+ "websocket-client",
+ "lsst-daf-butler",
+ # temporary dependency for testing
+ "tornado",
]
#dynamic = ["version"]
[project.urls]
-"Homepage" = "https://github.com/lsst/rubintv_analysis_service"
+"Homepage" = "https://github.com/lsst-ts/rubintv_analysis_service"
[project.optional-dependencies]
test = [
@@ -65,9 +71,6 @@ line_length = 110
[tool.lsst_versions]
write_to = "python/lsst/rubintv/analysis/service/version.py"
-[tool.pytest.ini_options]
-addopts = "--flake8"
-flake8-ignore = ["W503", "E203"]
# The matplotlib test may not release font files.
open_files_ignore = ["*.ttf"]
diff --git a/python/lsst/__init__.py b/python/lsst/__init__.py
index eb1f6e6..f77af49 100644
--- a/python/lsst/__init__.py
+++ b/python/lsst/__init__.py
@@ -1,3 +1,3 @@
import pkgutil
-__path__ = pkgutil.extend_path(__path__, __name__)
\ No newline at end of file
+__path__ = pkgutil.extend_path(__path__, __name__)
diff --git a/python/lsst/rubintv/__init__.py b/python/lsst/rubintv/__init__.py
index eb1f6e6..f77af49 100644
--- a/python/lsst/rubintv/__init__.py
+++ b/python/lsst/rubintv/__init__.py
@@ -1,3 +1,3 @@
import pkgutil
-__path__ = pkgutil.extend_path(__path__, __name__)
\ No newline at end of file
+__path__ = pkgutil.extend_path(__path__, __name__)
diff --git a/python/lsst/rubintv/analysis/__init__.py b/python/lsst/rubintv/analysis/__init__.py
index eb1f6e6..f77af49 100644
--- a/python/lsst/rubintv/analysis/__init__.py
+++ b/python/lsst/rubintv/analysis/__init__.py
@@ -1,3 +1,3 @@
import pkgutil
-__path__ = pkgutil.extend_path(__path__, __name__)
\ No newline at end of file
+__path__ = pkgutil.extend_path(__path__, __name__)
diff --git a/python/lsst/rubintv/analysis/service/__init__.py b/python/lsst/rubintv/analysis/service/__init__.py
new file mode 100644
index 0000000..ae91fbd
--- /dev/null
+++ b/python/lsst/rubintv/analysis/service/__init__.py
@@ -0,0 +1 @@
+from . import command, database, query, utils
diff --git a/python/lsst/rubintv/analysis/service/butler.py b/python/lsst/rubintv/analysis/service/butler.py
new file mode 100644
index 0000000..35dd967
--- /dev/null
+++ b/python/lsst/rubintv/analysis/service/butler.py
@@ -0,0 +1,29 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from dataclasses import dataclass
+
+from .command import BaseCommand
+
+
+@dataclass
+class ExampleButlerCommand(BaseCommand):
+ pass
diff --git a/python/lsst/rubintv/analysis/service/client.py b/python/lsst/rubintv/analysis/service/client.py
new file mode 100644
index 0000000..33781b4
--- /dev/null
+++ b/python/lsst/rubintv/analysis/service/client.py
@@ -0,0 +1,89 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+import logging
+
+import sqlalchemy
+import yaml
+from lsst.daf.butler import Butler
+from websocket import WebSocketApp
+
+from .command import DatabaseConnection, execute_command
+from .utils import Colors, printc
+
+logger = logging.getLogger("lsst.rubintv.analysis.service.client")
+
+
+class Worker:
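+    """A worker that connects to the rubinTV server and executes commands.
+
+    Parameters
+    ----------
+    address :
+        Address of the rubinTV web app.
+    port :
+        Port of the rubinTV web app websockets.
+    connection_info :
+        Connection information for the databases and (optionally) the Butler.
+    """
+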
+ def __init__(self, address: str, port: int, connection_info: dict[str, dict]):
+ self._address = address
+ self._port = port
+ self._connection_info = connection_info
+
+ def on_error(self, ws: WebSocketApp, error: str) -> None:
+ """Error received from the server."""
+ printc(f"Error: {error}", color=Colors.BRIGHT_RED)
+
+ def on_close(self, ws: WebSocketApp, close_status_code: str, close_msg: str) -> None:
+ """Connection closed by the server."""
+ printc("Connection closed", Colors.BRIGHT_YELLOW)
+
+ def run(self) -> None:
+ """Run the worker and connect to the rubinTV server.
+
+ Parameters
+ ----------
+ address :
+ Address of the rubinTV web app.
+ port :
+ Port of the rubinTV web app websockets.
+ connection_info :
+ Connections .
+ """
+ # Load the database connection information
+ databases: dict[str, DatabaseConnection] = {}
+
+ for name, info in self._connection_info["databases"].items():
+ with open(info["schema"], "r") as file:
+ engine = sqlalchemy.create_engine(info["url"])
+ schema = yaml.safe_load(file)
+ databases[name] = DatabaseConnection(schema=schema, engine=engine)
+
+ # Load the Butler (if one is available)
+ butler: Butler | None = None
+ if "butler" in self._connection_info:
+ repo = self._connection_info["butler"].pop("repo")
+ butler = Butler(repo, **self._connection_info["butler"])
+
+ def on_message(ws: WebSocketApp, message: str) -> None:
+ """Message received from the server."""
+ response = execute_command(message, databases, butler)
+ ws.send(response)
+
+ printc(f"Connecting to rubinTV at {self._address}:{self._port}", Colors.BRIGHT_GREEN)
+ # Connect to the WebSocket server
+ ws = WebSocketApp(
+ f"ws://{self._address}:{self._port}/ws/worker",
+ on_message=on_message,
+ on_error=self.on_error,
+ on_close=self.on_close,
+ )
+ ws.run_forever()
+ ws.close()
diff --git a/python/lsst/rubintv/analysis/service/command.py b/python/lsst/rubintv/analysis/service/command.py
new file mode 100644
index 0000000..776ea2b
--- /dev/null
+++ b/python/lsst/rubintv/analysis/service/command.py
@@ -0,0 +1,248 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import json
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+import sqlalchemy
+from lsst.daf.butler import Butler
+
+logger = logging.getLogger("lsst.rubintv.analysis.service.command")
+
+
+def construct_error_message(error_name: str, description: str) -> str:
+ """Use a standard format for all error messages.
+
+ Parameters
+ ----------
+ error_name :
+ Name of the error.
+ description :
+ Description of the error.
+
+ Returns
+ -------
+ result :
+ JSON formatted string.
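+
+    For example::
+
+        {"type": "error",
+         "content": {"error": "parsing error", "description": "..."}}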
+ """
+ return json.dumps(
+ {
+ "type": "error",
+ "content": {
+ "error": error_name,
+ "description": description,
+ },
+ }
+ )
+
+
+def error_msg(error: Exception) -> str:
+ """Handle errors received while parsing or executing a command.
+
+ Parameters
+ ----------
+ error :
+ The error that was raised while parsing, executing,
+ or responding to a command.
+
+ Returns
+ -------
+ response :
+ The JSON formatted error message sent to the user.
+
+ """
+ if isinstance(error, json.decoder.JSONDecodeError):
+ return construct_error_message("JSON decoder error", error.args[0])
+
+ if isinstance(error, CommandParsingError):
+ return construct_error_message("parsing error", error.args[0])
+
+ if isinstance(error, CommandExecutionError):
+ return construct_error_message("execution error", error.args[0])
+
+ if isinstance(error, CommandResponseError):
+ return construct_error_message("command response error", error.args[0])
+
+ # We should always receive one of the above errors, so the code should
+ # never get to here. But we generate this response just in case something
+ # very unexpected happens, or (more likely) the code is altered in such a
+    # way that this line is hit.
+ msg = "An unknown error occurred, you should never reach this message."
+ return construct_error_message(error.__class__.__name__, msg)
+
+
+class CommandParsingError(Exception):
+ """An `~Exception` caused by an error in parsing a command and
+ constructing a response.
+ """
+
+ pass
+
+
+class CommandExecutionError(Exception):
+ """An error occurred while executing a command."""
+
+ pass
+
+
+class CommandResponseError(Exception):
+ """An error occurred while converting a command result to JSON"""
+
+ pass
+
+
+@dataclass
+class DatabaseConnection:
+ """A connection to a database.
+
+ Attributes
+ ----------
+ engine :
+ The engine used to connect to the database.
+ schema :
+ The schema for the database.
+ """
+
+ engine: sqlalchemy.engine.Engine
+ schema: dict
+
+
+@dataclass(kw_only=True)
+class BaseCommand(ABC):
+ """Base class for commands.
+
+ Attributes
+ ----------
+ result :
+ The response generated by the command as a `dict` that can
+ be converted into JSON.
+ response_type :
+ The type of response that this command sends to the user.
+ This should be unique for each command.
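+
+    New commands subclass `BaseCommand`, implement `build_contents`, and
+    register themselves by name, for example
+    ``LoadColumnsCommand.register("load columns")``.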
+ """
+
+ command_registry = {}
+ result: dict | None = None
+ response_type: str
+
+ @abstractmethod
+ def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butler | None) -> dict:
+ """Build the contents of the command.
+
+ Parameters
+ ----------
+ databases :
+ The database connections.
+ butler :
+ A connected Butler.
+
+ Returns
+ -------
+ contents :
+ The contents of the response to the user.
+ """
+ pass
+
+ def execute(self, databases: dict[str, DatabaseConnection], butler: Butler | None):
+ """Execute the command.
+
+        This method does not return anything, but sets `result`, the
+        dictionary that is converted to JSON and sent to the user.
+
+ Parameters
+ ----------
+ databases :
+ The database connections.
+ butler :
+            A connected Butler.
+
+ """
+ self.result = {"type": self.response_type, "content": self.build_contents(databases, butler)}
+
+ def to_json(self):
+ """Convert the `result` into JSON."""
+ if self.result is None:
+ raise CommandExecutionError(f"Null result for command {self.__class__.__name__}")
+ return json.dumps(self.result)
+
+ @classmethod
+ def register(cls, name: str):
+ """Register a command."""
+ BaseCommand.command_registry[name] = cls
+
+
+def execute_command(command_str: str, databases: dict[str, DatabaseConnection], butler: Butler | None) -> str:
+ """Parse a JSON formatted string into a command and execute it.
+
+ Command format:
+ ```
+ {
+ name: command name,
+        parameters: command parameters (a dict)
+ }
+ ```
+
+ Parameters
+ ----------
+ command_str :
+ The JSON formatted command received from the user.
+ databases :
+ The database connections.
+ butler :
+ A connected Butler.
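+
+    Returns
+    -------
+    result :
+        The JSON formatted response that is sent back to the user.
+
+    Examples
+    --------
+    An illustrative command string for the registered ``get bounds`` command,
+    with values taken from the test schema::
+
+        {"name": "get bounds",
+         "parameters": {"database": "testdb", "table": "ExposureInfo",
+                        "column": "dec"}}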
+ """
+ try:
+ command_dict = json.loads(command_str)
+ if not isinstance(command_dict, dict):
+ raise CommandParsingError(f"Could not generate a valid command from {command_str}")
+ except Exception as err:
+ logging.exception("Error converting command to JSON.")
+ return error_msg(err)
+
+ try:
+ if "name" not in command_dict.keys():
+ raise CommandParsingError("No command 'name' given")
+
+ if command_dict["name"] not in BaseCommand.command_registry.keys():
+ raise CommandParsingError(f"Unrecognized command '{command_dict['name']}'")
+
+ parameters = command_dict.get("parameters", {})
+ command = BaseCommand.command_registry[command_dict["name"]](**parameters)
+
+ except Exception as err:
+ logging.exception("Error parsing command.")
+ return error_msg(CommandParsingError(f"'{err}' error while parsing command"))
+
+ try:
+ command.execute(databases, butler)
+ except Exception as err:
+ logging.exception("Error executing command.")
+ return error_msg(CommandExecutionError(f"{err} error executing command."))
+
+ try:
+ result = command.to_json()
+ except Exception as err:
+ logging.exception("Error converting command response to JSON.")
+ return error_msg(CommandResponseError(f"{err} error converting command response to JSON."))
+
+ return result
diff --git a/python/lsst/rubintv/analysis/service/database.py b/python/lsst/rubintv/analysis/service/database.py
new file mode 100644
index 0000000..eb06f48
--- /dev/null
+++ b/python/lsst/rubintv/analysis/service/database.py
@@ -0,0 +1,271 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from dataclasses import dataclass
+from typing import Sequence
+
+import sqlalchemy
+from lsst.daf.butler import Butler
+
+from .command import BaseCommand, DatabaseConnection
+from .query import Query
+
+
+class UnrecognizedTableError(Exception):
+ """An error that occurs when a table name does not appear in the schema"""
+
+ pass
+
+
+def get_table_names(schema: dict) -> tuple[str, ...]:
+ """Given a schema, return a list of dataset names
+
+ Parameters
+ ----------
+ schema :
+ The schema for a database.
+
+ Returns
+ -------
+ result :
+ The names of all the tables in the database.
+ """
+ return tuple(tbl["name"] for tbl in schema["tables"])
+
+
+def get_table_schema(schema: dict, table: str) -> dict:
+ """Get the schema for a table from the database schema
+
+ Parameters
+ ----------
+ schema:
+ The schema for a database.
+ table:
+ The name of the table in the database.
+
+ Returns
+ -------
+ result:
+ The schema for the table.
+ """
+ tables = schema["tables"]
+ for _table in tables:
+ if _table["name"] == table:
+ return _table
+    raise UnrecognizedTableError(f"Could not find the table '{table}' in the database")
+
+
+def column_names_to_models(table: sqlalchemy.Table, columns: list[str]) -> list[sqlalchemy.Column]:
+ """Return the sqlalchemy model of a Table column for each column name.
+
+ This method is used to generate a sqlalchemy query based on a `~Query`.
+
+ Parameters
+ ----------
+ table :
+ The name of the table in the database.
+ columns :
+ The names of the columns to generate models for.
+
+ Returns
+ -------
+ A list of sqlalchemy columns.
+ """
+ models = []
+ for column in columns:
+ models.append(getattr(table.columns, column))
+ return models
+
+
+def query_table(
+ table: str,
+ engine: sqlalchemy.engine.Engine,
+ columns: list[str] | None = None,
+ query: dict | None = None,
+) -> Sequence[sqlalchemy.engine.row.Row]:
+ """Query a table and return the results
+
+ Parameters
+ ----------
+ engine :
+ The engine used to connect to the database.
+ table :
+ The table that is being queried.
+ columns :
+ The columns from the table to return.
+ If `columns` is ``None`` then all the columns
+ in the table are returned.
+ query :
+ A query used on the table.
+        If `query` is ``None`` then all the rows
+        in the table are returned.
+
+ Returns
+ -------
+ result :
+ A list of the rows that were returned by the query.
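+
+    Examples
+    --------
+    For example, using the test schema and an already created `engine`::
+
+        query_table("ExposureInfo", engine=engine, columns=["ra", "dec"])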
+ """
+ metadata = sqlalchemy.MetaData()
+ _table = sqlalchemy.Table(table, metadata, autoload_with=engine)
+
+ if columns is None:
+ _query = _table.select()
+ else:
+ _query = sqlalchemy.select(*column_names_to_models(_table, columns))
+
+ if query is not None:
+ _query = _query.where(Query.from_dict(query)(_table))
+
+ connection = engine.connect()
+ result = connection.execute(_query)
+ return result.fetchall()
+
+
+def calculate_bounds(table: str, column: str, engine: sqlalchemy.engine.Engine) -> tuple[float, float]:
+ """Calculate the min, max for a column
+
+ Parameters
+ ----------
+ table :
+ The table that is being queried.
+ column :
+ The column to calculate the bounds of.
+ engine :
+ The engine used to connect to the database.
+
+ Returns
+ -------
+ result :
+ The ``(min, max)`` of the chosen column.
+ """
+ metadata = sqlalchemy.MetaData()
+ _table = sqlalchemy.Table(table, metadata, autoload_with=engine)
+ _column = _table.columns[column]
+
+    query = sqlalchemy.select(sqlalchemy.func.min(_column))
+ connection = engine.connect()
+ result = connection.execute(query)
+ col_min = result.fetchone()
+ if col_min is not None:
+ col_min = col_min[0]
+ else:
+ raise ValueError(f"Could not calculate the min of column {column}")
+
+    query = sqlalchemy.select(sqlalchemy.func.max(_column))
+ connection = engine.connect()
+ result = connection.execute(query)
+ col_max = result.fetchone()
+ if col_max is not None:
+ col_max = col_max[0]
+ else:
+        raise ValueError(f"Could not calculate the max of column {column}")
+
+ return col_min, col_max
+
+
+@dataclass(kw_only=True)
+class LoadColumnsCommand(BaseCommand):
+ """Load columns from a database table with an optional query.
+
+ Attributes
+ ----------
+ database :
+ The name of the database that the table is in.
+ table :
+ The table that the columns are loaded from.
+ columns :
+ Columns that are to be loaded. If `columns` is ``None``
+ then all the columns in the `table` are loaded.
+ query :
+ Query used to select rows in the table.
+ If `query` is ``None`` then all the rows are loaded.
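+
+    For example, the full command message sent by a client looks like
+    (values taken from the test suite)::
+
+        {"name": "load columns",
+         "parameters": {"database": "testdb", "table": "ExposureInfo",
+                        "columns": ["ra", "dec"]}}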
+ """
+
+ database: str
+ table: str
+ columns: list[str] | None = None
+ query: dict | None = None
+ response_type: str = "table columns"
+
+ def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butler | None) -> dict:
+ # Query the database to return the requested columns
+ database = databases[self.database]
+ index_column = get_table_schema(database.schema, self.table)["index_column"]
+ columns = self.columns
+ if columns is not None and index_column not in columns:
+ columns = [index_column] + columns
+ data = query_table(
+ table=self.table,
+ columns=columns,
+ query=self.query,
+ engine=database.engine,
+ )
+
+ if not data:
+ # There is no column data to return
+ content: dict = {
+ "columns": columns,
+ "data": [],
+ }
+ else:
+ content = {
+ "columns": [column for column in data[0]._fields],
+ "data": [list(row) for row in data],
+ }
+
+ return content
+
+
+@dataclass(kw_only=True)
+class CalculateBoundsCommand(BaseCommand):
+ """Calculate the bounds of a table column.
+
+ Attributes
+ ----------
+ database :
+ The name of the database that the table is in.
+ table :
+ The table that the columns are loaded from.
+ column :
+ The column to calculate the bounds of.
+ """
+
+ database: str
+ table: str
+ column: str
+ response_type: str = "column bounds"
+
+ def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butler | None) -> dict:
+ database = databases[self.database]
+ data = calculate_bounds(
+ table=self.table,
+ column=self.column,
+ engine=database.engine,
+ )
+ return {
+ "column": self.column,
+ "bounds": data,
+ }
+
+
+# Register the commands
+LoadColumnsCommand.register("load columns")
+CalculateBoundsCommand.register("get bounds")
diff --git a/python/lsst/rubintv/analysis/service/query.py b/python/lsst/rubintv/analysis/service/query.py
new file mode 100644
index 0000000..4f6c0bf
--- /dev/null
+++ b/python/lsst/rubintv/analysis/service/query.py
@@ -0,0 +1,152 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import operator as op
+from abc import ABC, abstractmethod
+from typing import Any
+
+import sqlalchemy
+
+
+class QueryError(Exception):
+ """An error that occurred during a query"""
+
+ pass
+
+
+class Query(ABC):
+ """Base class for constructing queries."""
+
+ @abstractmethod
+ def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList:
+ """Run the query on a table.
+
+ Parameters
+ ----------
+ table :
+ The table to run the query on.
+ """
+ pass
+
+ @staticmethod
+ def from_dict(query_dict: dict[str, Any]) -> Query:
+ """Construct a query from a dictionary of parameters.
+
+ Parameters
+ ----------
+ query_dict :
+ Kwargs used to initialize the query.
+ There should only be two keys in this dict,
+ the ``name`` of the query and the ``content`` used
+ to initialize the query.
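+
+        For example (column name taken from the test schema)::
+
+            {"name": "EqualityQuery",
+             "content": {"column": "dec", "operator": "gt", "value": 0}}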
+ """
+ try:
+ if query_dict["name"] == "EqualityQuery":
+ return EqualityQuery.from_dict(query_dict["content"])
+ elif query_dict["name"] == "ParentQuery":
+ return ParentQuery.from_dict(query_dict["content"])
+ except Exception:
+ raise QueryError("Failed to parse query.")
+
+ raise QueryError("Unrecognized query type")
+
+
+class EqualityQuery(Query):
+ """A query that compares a column to a static value.
+
+ Parameters
+ ----------
+ column :
+ The column used in the query.
+ operator :
+ The operator to use for the query.
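+        One of ``eq``, ``ne``, ``lt``, ``le``, ``gt``, ``ge``,
+        ``startswith``, ``endswith``, or ``contains``.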
+ value :
+ The value that the column is compared to.
+ """
+
+ def __init__(
+ self,
+ column: str,
+ operator: str,
+ value: Any,
+ ):
+ self.operator = operator
+ self.column = column
+ self.value = value
+
+ def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList:
+ column = table.columns[self.column]
+
+ if self.operator in ("eq", "ne", "lt", "le", "gt", "ge"):
+ operator = getattr(op, self.operator)
+ return operator(column, self.value)
+
+ if self.operator not in ("startswith", "endswith", "contains"):
+ raise QueryError(f"Unrecognized Equality operator {self.operator}")
+
+ return getattr(column, self.operator)(self.value)
+
+ @staticmethod
+ def from_dict(query_dict: dict[str, Any]) -> EqualityQuery:
+ return EqualityQuery(**query_dict)
+
+
+class ParentQuery(Query):
+ """A query that uses a binary operation to combine other queries.
+
+ Parameters
+ ----------
+ children :
+ The child queries that are combined using the binary operator.
+ operator :
+        The operator that is used to combine the queries.
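+        One of ``AND``, ``OR``, ``NOT``, or ``XOR``.
+
+    For example (mirroring the test suite)::
+
+        ParentQuery(
+            operator="AND",
+            children=[
+                EqualityQuery("dec", "lt", 0),
+                EqualityQuery("ra", "gt", 60),
+            ],
+        )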
+ """
+
+ def __init__(self, children: list[Query], operator: str):
+ self._children = children
+ self._operator = operator
+
+ def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList:
+ child_results = [child(table) for child in self._children]
+ try:
+ match self._operator:
+ case "AND":
+ return sqlalchemy.and_(*child_results)
+ case "OR":
+ return sqlalchemy.or_(*child_results)
+ case "NOT":
+ return sqlalchemy.not_(*child_results)
+ case "XOR":
+ return sqlalchemy.and_(
+ sqlalchemy.or_(*child_results),
+ sqlalchemy.not_(sqlalchemy.and_(*child_results)),
+ )
+ except Exception:
+ raise QueryError("Error applying a boolean query statement.")
+
+ @staticmethod
+ def from_dict(query_dict: dict[str, Any]) -> ParentQuery:
+ return ParentQuery(
+ children=[Query.from_dict(child) for child in query_dict["children"]],
+ operator=query_dict["operator"],
+ )
diff --git a/python/lsst/rubintv/analysis/service/utils.py b/python/lsst/rubintv/analysis/service/utils.py
new file mode 100644
index 0000000..806215b
--- /dev/null
+++ b/python/lsst/rubintv/analysis/service/utils.py
@@ -0,0 +1,40 @@
+from enum import Enum
+
+
+# ANSI color codes for printing to the terminal
+class Colors(Enum):
+ RESET = 0
+ BLACK = 30
+ RED = 31
+ GREEN = 32
+ YELLOW = 33
+ BLUE = 34
+ MAGENTA = 35
+ CYAN = 36
+ WHITE = 37
+ DEFAULT = 39
+ BRIGHT_BLACK = 90
+ BRIGHT_RED = 91
+ BRIGHT_GREEN = 92
+ BRIGHT_YELLOW = 93
+ BRIGHT_BLUE = 94
+ BRIGHT_MAGENTA = 95
+ BRIGHT_CYAN = 96
+ BRIGHT_WHITE = 97
+
+
+def printc(message: str, color: Colors, end_color: Colors = Colors.RESET):
+ """Print a message to the terminal in color.
+
+ After printing reset the color by default.
+
+ Parameters
+ ----------
+ message :
+ The message to print.
+ color :
+ The color to print the message in.
+    end_color :
+ The color future messages should be printed in.
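+
+    For example, ``printc("Connection closed", Colors.BRIGHT_YELLOW)``
+    prints the message in bright yellow and then resets the color.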
+ """
+ print(f"\033[{color.value}m{message}\033[{end_color.value}m")
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..f68ad48
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+numpy>=1.25.2
+scipy
+matplotlib
+pydantic
+pyyaml
+sqlalchemy
+astropy
+websocket-client
+lsst-daf-butler
+
+# the following import is temporary while testing
+tornado
diff --git a/scripts/config.yaml b/scripts/config.yaml
new file mode 100644
index 0000000..4e690ac
--- /dev/null
+++ b/scripts/config.yaml
@@ -0,0 +1,8 @@
+---
+databases:
+ summitcdb:
+ schema: "/Users/fred3m/temp/visitDb/summit.yaml"
+ url: "sqlite:////Users/fred3m/temp/visitDb/summit.db"
+#butler:
+# repo: /repos/main
+# skymap: hsc_rings_v1
diff --git a/scripts/mock_server.py b/scripts/mock_server.py
new file mode 100644
index 0000000..e6fb659
--- /dev/null
+++ b/scripts/mock_server.py
@@ -0,0 +1,255 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import uuid
+from dataclasses import dataclass
+from enum import Enum
+
+import tornado.httpserver
+import tornado.ioloop
+import tornado.web
+import tornado.websocket
+from lsst.rubintv.analysis.service.utils import Colors, printc
+
+# Default port and address to listen on
+LISTEN_PORT = 2000
+LISTEN_ADDRESS = "localhost"
+
+
+class WorkerPodStatus(Enum):
+ """Status of a worker pod."""
+
+ IDLE = "idle"
+ BUSY = "busy"
+
+
+class WebSocketHandler(tornado.websocket.WebSocketHandler):
+ """
+    Handler for WebSocket connections from clients and worker pods.
+ """
+
+ workers: dict[str, WorkerPod] = dict() # Keep track of connected worker pods
+ clients: dict[str, WebSocketHandler] = dict() # Keep track of connected clients
+ queue: list[QueueItem] = list() # Queue of messages to be processed
+
+ @classmethod
+ def urls(cls) -> list[tuple[str, type[tornado.web.RequestHandler], dict[str, str]]]:
+ """url to handle websocket connections.
+
+ Websocket URLs should either be followed by 'worker' for worker pods
+ or client for clients.
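+
+        For example, with the defaults above a worker connects to
+        ``ws://localhost:2000/ws/worker`` and a client to
+        ``ws://localhost:2000/ws/client``.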
+ """
+ return [
+ (r"/ws/([^/]+)", cls, {}), # Route/Handler/kwargs
+ ]
+
+ def open(self, client_type: str) -> None:
+ """
+ Client opens a websocket
+
+ Parameters
+ ----------
+        client_type :
+ The type of client that is connecting.
+ """
+ self.client_id = str(uuid.uuid4())
+ if client_type == "worker":
+ WebSocketHandler.workers[self.client_id] = WorkerPod(self.client_id, self)
+ printc(
+ f"New worker {self.client_id} connected. Total workers: {len(WebSocketHandler.workers)}",
+ Colors.BLUE,
+ Colors.RED,
+ )
+ if client_type == "client":
+ WebSocketHandler.clients[self.client_id] = self
+ printc(
+ f"New client {self.client_id} connected. Total clients: {len(WebSocketHandler.clients)}",
+ Colors.YELLOW,
+ Colors.RED,
+ )
+
+ def on_message(self, message: str) -> None:
+ """
+ Message received from a client or worker.
+
+ Parameters
+ ----------
+ message :
+ The message received from the client or worker.
+ """
+ if self.client_id in WebSocketHandler.clients:
+ printc(f"Message received from {self.client_id}", Colors.YELLOW, Colors.RED)
+ client = WebSocketHandler.clients[self.client_id]
+
+ # Find an idle worker
+ idle_worker = None
+ for worker in WebSocketHandler.workers.values():
+ if worker.status == WorkerPodStatus.IDLE:
+ idle_worker = worker
+ break
+
+ if idle_worker is None:
+ # No idle worker found, add to queue
+ WebSocketHandler.queue.append(QueueItem(message, client))
+ return
+ idle_worker.process(message, client)
+ return
+
+ if self.client_id in WebSocketHandler.workers:
+ worker = WebSocketHandler.workers[self.client_id]
+ worker.on_finished(message)
+ printc(
+ f"Message received from worker {self.client_id}. New status {worker.status}",
+ Colors.BLUE,
+ Colors.RED,
+ )
+
+ # Check the queue for any outstanding jobs.
+ if len(WebSocketHandler.queue) > 0:
+ queue_item = WebSocketHandler.queue.pop(0)
+ worker.process(queue_item.message, queue_item.client)
+ return
+
+ def on_close(self) -> None:
+ """
+ Client closes the connection
+ """
+ if self.client_id in WebSocketHandler.clients:
+ del WebSocketHandler.clients[self.client_id]
+ printc(
+ f"Client disconnected. Active clients: {len(WebSocketHandler.clients)}",
+ Colors.YELLOW,
+ Colors.RED,
+ )
+ for worker in WebSocketHandler.workers.values():
+ if worker.connected_client == self:
+ worker.on_finished("Client disconnected")
+ break
+ if self.client_id in WebSocketHandler.workers:
+ del WebSocketHandler.workers[self.client_id]
+ printc(
+ f"Worker disconnected. Active workers: {len(WebSocketHandler.workers)}",
+ Colors.BLUE,
+ Colors.RED,
+ )
+
+ def check_origin(self, origin):
+ """
+ Override the origin check if needed
+ """
+ return True
+
+
+class WorkerPod:
+ """State of a worker pod.
+
+ Attributes
+ ----------
+ id :
+ The id of the worker pod.
+ ws :
+ The websocket connection to the worker pod.
+ status :
+ The status of the worker pod.
+ connected_client :
+ The client that is connected to this worker pod.
+ """
+
+ status: WorkerPodStatus
+ connected_client: WebSocketHandler | None
+
+ def __init__(self, wid: str, ws: WebSocketHandler):
+ self.wid = wid
+ self.ws = ws
+ self.status = WorkerPodStatus.IDLE
+ self.connected_client = None
+
+ def process(self, message: str, connected_client: WebSocketHandler):
+ """Process a message from a client.
+
+ Parameters
+ ----------
+ message :
+ The message to process.
+ connected_client :
+ The client that is connected to this worker pod.
+ """
+ self.status = WorkerPodStatus.BUSY
+ self.connected_client = connected_client
+ printc(
+ f"Worker {self.wid} processing message from client {connected_client.client_id}",
+ Colors.BLUE,
+ Colors.RED,
+ )
+ # Send the job to the worker pod
+ self.ws.write_message(message)
+
+ def on_finished(self, message):
+ """Called when the worker pod has finished processing a message."""
+ if (
+ self.connected_client is not None
+ and self.connected_client.ws_connection is not None
+ and message != "Client disconnected"
+ ):
+ # Send the reply to the client that made the request.
+ self.connected_client.write_message(message)
+ else:
+ printc(
+ f"Worker {self.wid} finished processing, but no client was connected.", Colors.RED, Colors.RED
+ )
+ self.status = WorkerPodStatus.IDLE
+ self.connected_client = None
+
+
+@dataclass
+class QueueItem:
+ """An item in the client queue.
+
+ Attributes
+ ----------
+ message :
+ The message to process.
+ client :
+ The client that is making a request.
+ """
+
+ message: str
+ client: WebSocketHandler
+
+
+def main():
+ # Create tornado application and supply URL routes
+ app = tornado.web.Application(WebSocketHandler.urls()) # type: ignore
+
+ # Setup HTTP Server
+ http_server = tornado.httpserver.HTTPServer(app)
+ http_server.listen(LISTEN_PORT, LISTEN_ADDRESS)
+
+ printc(f"Listening on address: {LISTEN_ADDRESS}, {LISTEN_PORT}", Colors.GREEN, Colors.RED)
+
+ # Start IO/Event loop
+ tornado.ioloop.IOLoop.instance().start()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/rubintv_worker.py b/scripts/rubintv_worker.py
new file mode 100644
index 0000000..97152bf
--- /dev/null
+++ b/scripts/rubintv_worker.py
@@ -0,0 +1,55 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import argparse
+import os
+import pathlib
+
+import yaml
+from lsst.rubintv.analysis.service.client import Worker
+
+default_config = os.path.join(pathlib.Path(__file__).parent.absolute(), "config.yaml")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Initialize a new RubinTV worker.")
+ parser.add_argument(
+ "-a", "--address", default="localhost", type=str, help="Address of the rubinTV web app."
+ )
+ parser.add_argument(
+ "-p", "--port", default=2000, type=int, help="Port of the rubinTV web app websockets."
+ )
+ parser.add_argument(
+ "-c", "--config", default=default_config, type=str, help="Location of the configuration file."
+ )
+ args = parser.parse_args()
+
+ # Load the configuration file
+ with open(args.config, "r") as file:
+ config = yaml.safe_load(file)
+
+ # Run the client and connect to rubinTV via websockets
+ worker = Worker(args.address, args.port, config)
+ worker.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/setup.cfg b/setup.cfg
index bf28c0a..10e0f70 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,7 +1,7 @@
[flake8]
max-line-length = 110
max-doc-length = 79
-ignore = E133, E226, E228, N802, N803, N806, N812, N813, N815, N816, W504
+ignore = W503, E203
exclude =
bin,
doc,
diff --git a/tests/schema.yaml b/tests/schema.yaml
new file mode 100644
index 0000000..835499d
--- /dev/null
+++ b/tests/schema.yaml
@@ -0,0 +1,42 @@
+---
+name: testdb
+"@id": "#test_db"
+description: Small database for testing the package
+tables:
+ - name: ExposureInfo
+ index_column: exposure_id
+ columns:
+ - name: exposure_id
+ datatype: long
+ description: Unique identifier of an exposure.
+ - name: seq_num
+ datatype: long
+ description: Sequence number
+ - name: ra
+ datatype: double
+ unit: degree
+ description: RA of focal plane center.
+ - name: dec
+ datatype: double
+ unit: degree
+ description: Declination of focal plane center
+ - name: expTime
+ datatype: double
+ description: Spatially-averaged duration of exposure, accurate to 10ms.
+ - name: physical_filter
+ datatype: char
+ description: ID of physical filter,
+ the filter associated with a particular instrument.
+ - name: obsNight
+ datatype: date
+ description: The night of the observation. This is different than the
+ observation date, as this is the night that the observations started,
+ so for observations after midnight obsStart and obsNight will be
+ different days.
+ - name: obsStart
+ datatype: datetime
+ description: Start time of the exposure at the fiducial center
+ of the focal plane array, TAI, accurate to 10ms.
+ - name: obsStartMJD
+ datatype: double
+ description: Start of the exposure in MJD, TAI, accurate to 10ms.
diff --git a/tests/test_command.py b/tests/test_command.py
new file mode 100644
index 0000000..bab82c7
--- /dev/null
+++ b/tests/test_command.py
@@ -0,0 +1,210 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import json
+import os
+import tempfile
+
+import lsst.rubintv.analysis.service as lras
+import sqlalchemy
+import utils
+import yaml
+
+
+class TestCommand(utils.RasTestCase):
+ def setUp(self):
+ path = os.path.dirname(__file__)
+ yaml_filename = os.path.join(path, "schema.yaml")
+
+ with open(yaml_filename) as file:
+ schema = yaml.safe_load(file)
+ db_file = tempfile.NamedTemporaryFile(delete=False)
+ utils.create_database(schema, db_file.name)
+ self.db_file = db_file
+ self.db_filename = db_file.name
+ self.schema = schema
+
+ # Load the database connection information
+ self.databases = {
+ "testdb": lras.command.DatabaseConnection(
+ schema=schema, engine=sqlalchemy.create_engine("sqlite:///" + db_file.name)
+ )
+ }
+
+ # Set up the sqlalchemy connection
+ self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name)
+
+ def tearDown(self) -> None:
+ self.db_file.close()
+ os.remove(self.db_file.name)
+
+ def execute_command(self, command: dict, response_type: str) -> dict:
+ command_json = json.dumps(command)
+ response = lras.command.execute_command(command_json, self.databases, None)
+ result = json.loads(response)
+ self.assertEqual(result["type"], response_type)
+ return result["content"]
+
+
+class TestCalculateBoundsCommand(TestCommand):
+ def test_calculate_bounds_command(self):
+ command = {
+ "name": "get bounds",
+ "parameters": {
+ "database": "testdb",
+ "table": "ExposureInfo",
+ "column": "dec",
+ },
+ }
+ content = self.execute_command(command, "column bounds")
+ self.assertEqual(content["column"], "dec")
+ self.assertListEqual(content["bounds"], [-40, 50])
+
+
+class TestLoadColumnsCommand(TestCommand):
+ def test_load_full_dataset(self):
+ command = {"name": "load columns", "parameters": {"database": "testdb", "table": "ExposureInfo"}}
+
+ content = self.execute_command(command, "table columns")
+ data = content["data"]
+
+ truth = utils.ap_table_to_list(utils.get_test_data())
+
+ self.assertDataTableEqual(data, truth)
+
+ def test_load_full_columns(self):
+ command = {
+ "name": "load columns",
+ "parameters": {
+ "database": "testdb",
+ "table": "ExposureInfo",
+ "columns": [
+ "ra",
+ "dec",
+ ],
+ },
+ }
+
+ content = self.execute_command(command, "table columns")
+ columns = content["columns"]
+ data = content["data"]
+
+ truth = utils.get_test_data()["exposure_id", "ra", "dec"]
+ truth_data = utils.ap_table_to_list(truth)
+
+ self.assertTupleEqual(tuple(columns), tuple(truth.columns))
+ self.assertDataTableEqual(data, truth_data)
+
+ def test_load_columns_with_query(self):
+ command = {
+ "name": "load columns",
+ "parameters": {
+ "database": "testdb",
+ "table": "ExposureInfo",
+ "columns": [
+ "exposure_id",
+ "ra",
+ "dec",
+ ],
+ "query": {
+ "name": "EqualityQuery",
+ "content": {
+ "column": "expTime",
+ "operator": "eq",
+ "value": 30,
+ },
+ },
+ },
+ }
+
+ content = self.execute_command(command, "table columns")
+ columns = content["columns"]
+ data = content["data"]
+
+ truth = utils.get_test_data()["exposure_id", "ra", "dec"]
+ # Select rows with expTime = 30
+ truth = truth[[True, True, False, False, False, True, True, True, False, False]]
+ truth_data = utils.ap_table_to_list(truth)
+
+ self.assertTupleEqual(tuple(columns), tuple(truth.columns))
+ self.assertDataTableEqual(data, truth_data)
+
+
+class TestCommandErrors(TestCommand):
+ def check_error_response(self, content: dict, error: str, description: str | None = None):
+ self.assertEqual(content["error"], error)
+ if description is not None:
+ self.assertEqual(content["description"], description)
+
+ def test_errors(self):
+ # Command cannot be decoded as JSON dict
+ content = self.execute_command("{'test': [1,2,3,0004,}", "error")
+ self.check_error_response(content, "parsing error")
+
+ # Command does not contain a "name"
+ command = {"content": {}}
+ content = self.execute_command(command, "error")
+ self.check_error_response(
+ content,
+ "parsing error",
+ "'No command 'name' given' error while parsing command",
+ )
+
+ # Command has an invalid name
+ command = {"name": "invalid name"}
+ content = self.execute_command(command, "error")
+ self.check_error_response(
+ content,
+ "parsing error",
+ "'Unrecognized command 'invalid name'' error while parsing command",
+ )
+
+ # Command has no parameters
+ command = {"name": "get bounds"}
+ content = self.execute_command(command, "error")
+ self.check_error_response(
+ content,
+ "parsing error",
+ )
+
+ # Command has invalid parameters
+ command = {
+ "name": "get bounds",
+ "parameters": {
+ "a": 1,
+ },
+ }
+ content = self.execute_command(command, "error")
+ self.check_error_response(
+ content,
+ "parsing error",
+ )
+
+ # Command execution failed (table name does not exist)
+ command = {
+ "name": "get bounds",
+ "parameters": {"database": "testdb", "table": "InvalidTable", "column": "invalid_column"},
+ }
+ content = self.execute_command(command, "error")
+ self.check_error_response(
+ content,
+ "execution error",
+ )
diff --git a/tests/test_database.py b/tests/test_database.py
new file mode 100644
index 0000000..94d6914
--- /dev/null
+++ b/tests/test_database.py
@@ -0,0 +1,103 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import os
+import tempfile
+from unittest import TestCase
+
+import lsst.rubintv.analysis.service as lras
+import sqlalchemy
+import utils
+import yaml
+
+
+class TestDatabase(TestCase):
+ def setUp(self):
+ path = os.path.dirname(__file__)
+ yaml_filename = os.path.join(path, "schema.yaml")
+
+ with open(yaml_filename) as file:
+ schema = yaml.safe_load(file)
+ db_file = tempfile.NamedTemporaryFile(delete=False)
+ utils.create_database(schema, db_file.name)
+ self.db_file = db_file
+ self.db_filename = db_file.name
+ self.schema = schema
+
+ # Set up the sqlalchemy connection
+ self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name)
+
+ def tearDown(self) -> None:
+ self.db_file.close()
+ os.remove(self.db_file.name)
+
+ def test_get_table_names(self):
+ table_names = lras.database.get_table_names(self.schema)
+ self.assertTupleEqual(table_names, ("ExposureInfo",))
+
+ def test_get_table_schema(self):
+ schema = lras.database.get_table_schema(self.schema, "ExposureInfo")
+ self.assertEqual(schema["name"], "ExposureInfo")
+
+ columns = [
+ "exposure_id",
+ "seq_num",
+ "ra",
+ "dec",
+ "expTime",
+ "physical_filter",
+ "obsNight",
+ "obsStart",
+ "obsStartMJD",
+ ]
+ for n, column in enumerate(schema["columns"]):
+ self.assertEqual(column["name"], columns[n])
+
+ def test_query_full_table(self):
+ truth_table = utils.get_test_data()
+ truth = utils.ap_table_to_list(truth_table)
+
+ data = lras.database.query_table("ExposureInfo", engine=self.engine)
+ print(data)
+
+ self.assertListEqual(list(data[0]._fields), list(truth_table.columns))
+
+ for n in range(len(truth)):
+ true_row = tuple(truth[n])
+ row = tuple(data[n])
+ self.assertTupleEqual(row, true_row)
+
+ def test_query_columns(self):
+ truth = utils.get_test_data()
+ truth = utils.ap_table_to_list(truth["ra", "dec"])
+
+ data = lras.database.query_table("ExposureInfo", columns=["ra", "dec"], engine=self.engine)
+
+ self.assertListEqual(list(data[0]._fields), ["ra", "dec"])
+
+ for n in range(len(truth)):
+ true_row = tuple(truth[n])
+ row = tuple(data[n])
+ self.assertTupleEqual(row, true_row)
+
+ def test_calculate_bounds(self):
+ result = lras.database.calculate_bounds("ExposureInfo", "dec", self.engine)
+ self.assertTupleEqual(result, (-40, 50))
diff --git a/tests/test_query.py b/tests/test_query.py
new file mode 100644
index 0000000..a125218
--- /dev/null
+++ b/tests/test_query.py
@@ -0,0 +1,280 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import os
+import tempfile
+
+import lsst.rubintv.analysis.service as lras
+import sqlalchemy
+import utils
+import yaml
+
+
+class TestQuery(utils.RasTestCase):
+ def setUp(self):
+ path = os.path.dirname(__file__)
+ yaml_filename = os.path.join(path, "schema.yaml")
+
+ with open(yaml_filename) as file:
+ schema = yaml.safe_load(file)
+ db_file = tempfile.NamedTemporaryFile(delete=False)
+ utils.create_database(schema, db_file.name)
+ self.db_file = db_file
+ self.db_filename = db_file.name
+ self.schema = schema
+
+ # Set up the sqlalchemy connection
+ self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name)
+ self.metadata = sqlalchemy.MetaData()
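+        # Reflect the ExposureInfo table definition directly from the sqlite database.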
+ self.table = sqlalchemy.Table("ExposureInfo", self.metadata, autoload_with=self.engine)
+
+ def tearDown(self) -> None:
+ self.db_file.close()
+ os.remove(self.db_file.name)
+
+ def test_equality(self):
+ table = self.table
+ column = table.columns.dec
+
+ value = 0
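+        # Each supported operator string should produce the equivalent
+        # sqlalchemy comparison on the column.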
+ truth_dict = {
+ "eq": column == value,
+ "ne": column != value,
+ "lt": column < value,
+ "le": column <= value,
+ "gt": column > value,
+ "ge": column >= value,
+ }
+
+ for operator, truth in truth_dict.items():
+ self.assertTrue(lras.query.EqualityQuery("dec", operator, value)(table).compare(truth))
+
+ def test_query(self):
+ table = self.table
+
+ # dec > 0
+ query = lras.query.EqualityQuery("dec", "gt", 0)
+ result = query(table)
+ self.assertTrue(result.compare(table.columns.dec > 0))
+
+ # dec < 0 and ra > 60
+ query = lras.query.ParentQuery(
+ operator="AND",
+ children=[
+ lras.query.EqualityQuery("dec", "lt", 0),
+ lras.query.EqualityQuery("ra", "gt", 60),
+ ],
+ )
+ result = query(table)
+ truth = sqlalchemy.and_(
+ table.columns.dec < 0,
+ table.columns.ra > 60,
+ )
+ self.assertTrue(result.compare(truth))
+
+        # Check a non-matching expression to verify that compare() reports the mismatch
+ result = query(table)
+ truth = sqlalchemy.and_(
+ table.columns.dec < 0,
+ table.columns.ra > 70,
+ )
+ self.assertFalse(result.compare(truth))
+
+ def test_database_query(self):
+ data = utils.get_test_data()
+
+ # dec > 0 (and is not None)
+ query1 = {
+ "name": "EqualityQuery",
+ "content": {
+ "column": "dec",
+ "operator": "gt",
+ "value": 0,
+ },
+ }
+ # ra > 60 (and is not None)
+ query2 = {
+ "name": "EqualityQuery",
+ "content": {
+ "column": "ra",
+ "operator": "gt",
+ "value": 60,
+ },
+ }
+
+ # Test 1: dec > 0 (and is not None)
+ query = query1
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
+ truth = data[[False, False, False, False, False, True, False, True, True, True]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
+
+ # Test 2: dec > 0 and ra > 60 (and neither is None)
+ query = {
+ "name": "ParentQuery",
+ "content": {
+ "operator": "AND",
+ "children": [query1, query2],
+ },
+ }
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
+ truth = data[[False, False, False, False, False, False, False, False, True, True]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
+
+ # Test 3: dec <= 0 or ra > 60 (and neither is None)
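+        # NOT is expressed as a ParentQuery with a single child,
+        # so "dec <= 0" is written as NOT(dec > 0).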
+ query = {
+ "name": "ParentQuery",
+ "content": {
+ "operator": "OR",
+ "children": [
+ {
+ "name": "ParentQuery",
+ "content": {
+ "operator": "NOT",
+ "children": [query1],
+ },
+ },
+ query2,
+ ],
+ },
+ }
+
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
+ truth = data[[True, True, False, True, True, False, True, False, True, True]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
+
+ # Test 4: dec > 0 XOR ra > 60
+ query = {
+ "name": "ParentQuery",
+ "content": {
+ "operator": "XOR",
+ "children": [query1, query2],
+ },
+ }
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
+ truth = data[[False, False, False, False, False, True, False, False, False, False]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
+
+ def test_database_string_query(self):
+ data = utils.get_test_data()
+
+ # Test equality
+ query = {
+ "name": "EqualityQuery",
+ "content": {
+ "column": "physical_filter",
+ "operator": "eq",
+ "value": "DECam r-band",
+ },
+ }
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
+ truth = data[[False, False, False, False, False, False, True, False, False, False]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
+
+ # Test "startswith"
+ query = {
+ "name": "EqualityQuery",
+ "content": {
+ "column": "physical_filter",
+ "operator": "startswith",
+ "value": "DECam",
+ },
+ }
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
+ truth = data[[False, False, False, False, False, True, True, True, True, True]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
+
+ # Test "endswith"
+ query = {
+ "name": "EqualityQuery",
+ "content": {
+ "column": "physical_filter",
+ "operator": "endswith",
+ "value": "r-band",
+ },
+ }
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
+ truth = data[[False, True, False, False, False, False, True, False, False, False]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
+
+ # Test "like"
+ query = {
+ "name": "EqualityQuery",
+ "content": {
+ "column": "physical_filter",
+ "operator": "contains",
+ "value": "T r",
+ },
+ }
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query)
+ truth = data[[False, True, False, False, False, False, False, False, False, False]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
+
+    def test_database_datetime_query(self):
+ data = utils.get_test_data()
+
+ # Test <
+ query1 = {
+ "name": "EqualityQuery",
+ "content": {
+ "column": "obsStart",
+ "operator": "lt",
+ "value": "2023-05-19 23:23:23",
+ },
+ }
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query1)
+ truth = data[[True, True, True, False, False, True, True, True, True, True]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
+
+ # Test >
+ query2 = {
+ "name": "EqualityQuery",
+ "content": {
+ "column": "obsStart",
+ "operator": "gt",
+ "value": "2023-05-01 23:23:23",
+ },
+ }
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query2)
+ truth = data[[True, True, True, True, True, False, False, False, False, False]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
+
+ # Test in range
+ query3 = {
+ "name": "ParentQuery",
+ "content": {
+ "operator": "AND",
+ "children": [query1, query2],
+ },
+ }
+ result = lras.database.query_table("ExposureInfo", engine=self.engine, query=query3)
+ truth = data[[True, True, True, False, False, False, False, False, False, False]]
+ truth = utils.ap_table_to_list(truth)
+ self.assertDataTableEqual(result, truth)
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..5b78e5b
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,197 @@
+# This file is part of lsst_rubintv_analysis_service.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import sqlite3
+from unittest import TestCase
+
+from astropy.table import Table as ApTable
+from astropy.time import Time
+
+# Convert visit DB datatypes to sqlite3 datatypes
+datatype_transform = {
+ "int": "integer",
+ "long": "integer",
+ "double": "real",
+ "float": "real",
+ "char": "text",
+ "date": "text",
+ "datetime": "text",
+}
+
+
+def create_table(cursor: sqlite3.Cursor, tbl_name: str, schema: list[dict]):
+ """Create a table in an sqlite database.
+
+ Parameters
+ ----------
+ cursor :
+ The cursor associated with the database connection.
+ tbl_name :
+ The name of the table to create.
+ schema :
+        The list of column definitions for the table.
+ """
+ command = f"CREATE TABLE {tbl_name}(\n"
+ for field in schema:
+ command += f' {field["name"]} {datatype_transform[field["datatype"]]},\n'
+ command = command[:-2] + "\n);"
+ cursor.execute(command)
+
+
+def get_test_data_dict() -> dict:
+ """Get a dictionary containing the test data"""
+ obs_start = [
+ "2023-05-19 20:20:20",
+ "2023-05-19 21:21:21",
+ "2023-05-19 22:22:22",
+ "2023-05-19 23:23:23",
+ "2023-05-20 00:00:00",
+ "2023-02-14 22:22:22",
+ "2023-02-14 23:23:23",
+ "2023-02-14 00:00:00",
+ "2023-02-14 01:01:01",
+ "2023-02-14 02:02:02",
+ ]
+
+ obs_start_mjd = [Time(time).mjd for time in obs_start]
+
+ return {
+ "exposure_id": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18],
+ "seq_num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "ra": [10, 20, None, 40, 50, 60, 70, None, 90, 100],
+ "dec": [-40, -30, None, -10, 0, 10, None, 30, 40, 50],
+ "expTime": [30, 30, 10, 15, 15, 30, 30, 30, 15, 20],
+ "physical_filter": [
+ "LSST g-band",
+ "LSST r-band",
+ "LSST i-band",
+ "LSST z-band",
+ "LSST y-band",
+ "DECam g-band",
+ "DECam r-band",
+ "DECam i-band",
+ "DECam z-band",
+ "DECam y-band",
+ ],
+ "obsNight": [
+ "2023-05-19",
+ "2023-05-19",
+ "2023-05-19",
+ "2023-05-19",
+ "2023-05-19",
+ "2023-02-14",
+ "2023-02-14",
+ "2023-02-14",
+ "2023-02-14",
+ "2023-02-14",
+ ],
+ "obsStart": obs_start,
+ "obsStartMJD": obs_start_mjd,
+ }
+
+
+def get_test_data() -> ApTable:
+ """Generate data for the test database"""
+ data_dict = get_test_data_dict()
+
+ table = ApTable(list(data_dict.values()), names=list(data_dict.keys()))
+ return table
+
+
+def ap_table_to_list(data: ApTable) -> list:
+ """Convert an astropy Table into a list of tuples."""
+ rows = []
+ for row in data:
+ rows.append(tuple(row))
+ return rows
+
+
+def create_database(schema: dict, db_filename: str):
+ """Create the test database"""
+ tbl_name = "ExposureInfo"
+ connection = sqlite3.connect(db_filename)
+ cursor = connection.cursor()
+
+ create_table(cursor, tbl_name, schema["tables"][0]["columns"])
+
+ data = get_test_data_dict()
+
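+    # Insert one row at a time, using "?" placeholders so sqlite3 handles
+    # quoting of strings and NULL for the None entries.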
+ for n in range(len(data["exposure_id"])):
+ row = tuple(data[key][n] for key in data.keys())
+ value_str = "?, " * (len(row) - 1) + "?"
+ command = f"INSERT INTO {tbl_name} VALUES({value_str});"
+ cursor.execute(command, row)
+ connection.commit()
+ cursor.close()
+
+
+class TableMismatchError(AssertionError):
+ pass
+
+
+class RasTestCase(TestCase):
+ """Base class for tests in this package
+
+    For now this only includes methods to check database
+    query results, but other checks may be added in the future.
+ """
+
+ @staticmethod
+ def get_data_table_indices(table: list[tuple]) -> list[int]:
+ """Get the index for each rom in the data table.
+
+ Parameters
+ ----------
+ table :
+ The table containing the data.
+
+ Returns
+ -------
+ result :
+ The index for each row in the table.
+ """
+ # Return the seq_num as an index
+ return [row[1] for row in table]
+
+ def assertDataTableEqual(self, result, truth):
+ """Check if two data tables are equal.
+
+ Parameters
+ ----------
+ result :
+ The result generated by the test that is checked.
+ truth :
+ The expected value of the test.
+ """
+ if len(result) != len(truth):
+ msg = "Data tables have a different number of rows: "
+ msg += f"indices: [{self.get_data_table_indices(result)}], [{self.get_data_table_indices(truth)}]"
+ raise TableMismatchError(msg)
+ try:
+ for n in range(len(truth)):
+ true_row = tuple(truth[n])
+ row = tuple(result[n])
+ self.assertTupleEqual(row, true_row)
+ except AssertionError:
+ msg = "Mismatched tables: "
+ msg += f"indices: [{self.get_data_table_indices(result)}], [{self.get_data_table_indices(truth)}]"
+ raise TableMismatchError(msg)