From 6ee00db383b23176098b2691f306c268fed5e18d Mon Sep 17 00:00:00 2001 From: fred3m Date: Fri, 5 Jul 2024 11:15:27 -0700 Subject: [PATCH] Fix tests for consDB-like schema --- tests/schema.yaml | 62 +++++-------- tests/test_command.py | 78 +++++----------- tests/test_database.py | 51 +++++------ tests/test_query.py | 204 +++++++++++++++++------------------------ tests/utils.py | 119 +++++++++++------------- 5 files changed, 207 insertions(+), 307 deletions(-) diff --git a/tests/schema.yaml b/tests/schema.yaml index 8610e5b..f7cfa19 100644 --- a/tests/schema.yaml +++ b/tests/schema.yaml @@ -5,21 +5,18 @@ description: Small database for testing the package joins: - type: inner matches: - Visit: - - day_obs - - seq_num - - instrument - ExposureInfo: - - day_obs - - seq_num - - instrument + exposure: + - exposure_id + visit1_quicklook: + - visit_id tables: - - name: Visit + - name: exposure index_columns: - - day_obs - - seq_num - - instrument + - exposure_id columns: + - name: exposure_id + datatype: long + description: Unique identifier for the exposure. - name: seq_num datatype: long description: Sequence number @@ -29,9 +26,6 @@ tables: observation date, as this is the night that the observations started, so for observations after midnight obsStart and obsNight will be different days. - - name: instrument - datatype: char - description: Instrument name - name: ra datatype: double unit: degree @@ -40,38 +34,24 @@ tables: datatype: double unit: degree description: Declination of focal plane center - - name: ExposureInfo - index_columns: - - day_obs - - seq_num - - instrument - columns: - - name: seq_num - datatype: long - description: Sequence number - - name: day_obs - datatype: date - description: The night of the observation. This is different than the - observation date, as this is the night that the observations started, - so for observations after midnight obsStart and obsNight will be - different days. - - name: instrument - datatype: char - description: Instrument name - - name: exposure_id - datatype: long - description: Unique identifier of an exposure. - - name: expTime - datatype: double - description: Spatially-averaged duration of exposure, accurate to 10ms. - name: physical_filter datatype: char description: ID of physical filter, the filter associated with a particular instrument. - - name: obsStart + - name: obs_start datatype: datetime description: Start time of the exposure at the fiducial center of the focal plane array, TAI, accurate to 10ms. - - name: obsStartMJD + - name: obs_start_mjd datatype: double description: Start of the exposure in MJD, TAI, accurate to 10ms. + - name: visit1_quicklook + index_columns: + - visit_id + columns: + - name: visit_id + datatype: long + description: Unique identifier for the visit. + - name: exp_time + datatype: double + description: Spatially-averaged duration of exposure, accurate to 10ms. diff --git a/tests/test_command.py b/tests/test_command.py index 7159f8e..58640eb 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -20,43 +20,14 @@ # along with this program. If not, see . import json -import os -import tempfile from typing import cast import astropy.table import lsst.rubintv.analysis.service as lras -import lsst.rubintv.analysis.service.database -import sqlalchemy import utils -import yaml class TestCommand(utils.RasTestCase): - def setUp(self): - path = os.path.dirname(__file__) - yaml_filename = os.path.join(path, "schema.yaml") - - with open(yaml_filename) as file: - schema = yaml.safe_load(file) - db_file = tempfile.NamedTemporaryFile(delete=False) - utils.create_database(schema, db_file.name) - self.db_file = db_file - self.db_filename = db_file.name - - # Load the database connection information - databases = { - "testdb": lsst.rubintv.analysis.service.database.ConsDbSchema( - schema=schema, engine=sqlalchemy.create_engine("sqlite:///" + db_file.name) - ) - } - - self.data_center = lras.data.DataCenter(schemas=databases) - - def tearDown(self) -> None: - self.db_file.close() - os.remove(self.db_file.name) - def execute_command(self, command: dict, response_type: str) -> dict: command_json = json.dumps(command) response = lras.command.execute_command(command_json, self.data_center) @@ -71,12 +42,12 @@ def test_calculate_bounds_command(self): "name": "get bounds", "parameters": { "database": "testdb", - "column": "Visit.dec", + "column": "exposure.dec", }, } print(lras.command.BaseCommand.command_registry) content = self.execute_command(command, "column bounds") - self.assertEqual(content["column"], "Visit.dec") + self.assertEqual(content["column"], "exposure.dec") self.assertListEqual(content["bounds"], [-40, 50]) @@ -87,8 +58,8 @@ def test_load_full_columns(self): "parameters": { "database": "testdb", "columns": [ - "Visit.ra", - "Visit.dec", + "exposure.ra", + "exposure.dec", ], }, } @@ -98,15 +69,14 @@ def test_load_full_columns(self): truth = cast( astropy.table.Table, - utils.get_test_data("Visit")[ - "Visit.ra", - "Visit.dec", - "Visit.day_obs", - "Visit.seq_num", - "Visit.instrument", + utils.get_test_data("exposure")[ + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", ], ) - valid = (truth["Visit.ra"] != None) & (truth["Visit.dec"] != None) # noqa: E711 + valid = (truth["exposure.ra"] != None) & (truth["exposure.dec"] != None) # noqa: E711 truth = cast(astropy.table.Table, truth[valid]) self.assertDataTableEqual(data, truth) @@ -116,14 +86,14 @@ def test_load_columns_with_query(self): "parameters": { "database": "testdb", "columns": [ - "ExposureInfo.exposure_id", - "Visit.ra", - "Visit.dec", + "visit1_quicklook.visit_id", + "exposure.ra", + "exposure.dec", ], "query": { "name": "EqualityQuery", "content": { - "column": "ExposureInfo.expTime", + "column": "visit1_quicklook.exp_time", "operator": "eq", "value": 30, }, @@ -134,26 +104,24 @@ def test_load_columns_with_query(self): content = self.execute_command(command, "table columns") data = content["data"] - visit_truth = utils.get_test_data("Visit") - exp_truth = utils.get_test_data("ExposureInfo") + visit_truth = utils.get_test_data("exposure") + exp_truth = utils.get_test_data("visit1_quicklook") truth = astropy.table.join( visit_truth, exp_truth, - keys_left=("Visit.seq_num", "Visit.day_obs", "Visit.instrument"), - keys_right=("ExposureInfo.seq_num", "ExposureInfo.day_obs", "ExposureInfo.instrument"), + keys_left=("exposure.exposure_id",), + keys_right=("visit1_quicklook.visit_id",), ) truth = truth[ - "ExposureInfo.exposure_id", - "Visit.ra", - "Visit.dec", - "ExposureInfo.day_obs", - "ExposureInfo.seq_num", - "ExposureInfo.instrument", + "visit1_quicklook.visit_id", + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", ] # Select rows with expTime = 30 truth = truth[[True, True, False, False, False, True, False, False, False, False]] - print(data.keys()) self.assertDataTableEqual(data, truth) diff --git a/tests/test_database.py b/tests/test_database.py index d1490d1..cf2d2f7 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -30,60 +30,59 @@ def test_get_table_names(self): self.assertTupleEqual( table_names, ( - "Visit", - "ExposureInfo", + "exposure", + "visit1_quicklook", ), ) def test_get_table_schema(self): - schema = lras.database.get_table_schema(self.database.schema, "ExposureInfo") - self.assertEqual(schema["name"], "ExposureInfo") + schema = lras.database.get_table_schema(self.database.schema, "exposure") + self.assertEqual(schema["name"], "exposure") columns = [ + "exposure_id", "seq_num", "day_obs", - "instrument", - "exposure_id", - "expTime", + "ra", + "dec", "physical_filter", - "obsStart", - "obsStartMJD", + "obs_start", + "obs_start_mjd", ] for n, column in enumerate(schema["columns"]): self.assertEqual(column["name"], columns[n]) def test_single_table_query_columns(self): - truth = utils.get_test_data("Visit") - valid = (truth["Visit.ra"] != None) & (truth["Visit.dec"] != None) # noqa: E711 + truth = utils.get_test_data("exposure") + valid = (truth["exposure.ra"] != None) & (truth["exposure.dec"] != None) # noqa: E711 truth = truth[valid] - truth = truth["Visit.ra", "Visit.dec", "Visit.day_obs", "Visit.seq_num", "Visit.instrument"] - data = self.database.query(columns=["Visit.ra", "Visit.dec"]) + truth = truth["exposure.ra", "exposure.dec", "exposure.day_obs", "exposure.seq_num"] + data = self.database.query(columns=["exposure.ra", "exposure.dec"]) self.assertDataTableEqual(data, truth) # type: ignore def test_multiple_table_query_columns(self): - visit_truth = utils.get_test_data("Visit") - exp_truth = utils.get_test_data("ExposureInfo") + visit_truth = utils.get_test_data("exposure") + exp_truth = utils.get_test_data("visit1_quicklook") truth = astropy.table.join( visit_truth, exp_truth, - keys_left=("Visit.seq_num", "Visit.day_obs", "Visit.instrument"), - keys_right=("ExposureInfo.seq_num", "ExposureInfo.day_obs", "ExposureInfo.instrument"), + keys_left=("exposure.exposure_id"), + keys_right=("visit1_quicklook.visit_id"), ) - valid = (truth["Visit.ra"] != None) & (truth["Visit.dec"] != None) # noqa: E711 + valid = (truth["exposure.ra"] != None) & (truth["exposure.dec"] != None) # noqa: E711 truth = truth[valid] truth = truth[ - "Visit.ra", - "Visit.dec", - "ExposureInfo.exposure_id", - "Visit.day_obs", - "Visit.seq_num", - "Visit.instrument", + "exposure.ra", + "exposure.dec", + "visit1_quicklook.visit_id", + "exposure.day_obs", + "exposure.seq_num", ] - data = self.database.query(columns=["Visit.ra", "Visit.dec", "ExposureInfo.exposure_id"]) + data = self.database.query(columns=["exposure.ra", "exposure.dec", "visit1_quicklook.visit_id"]) self.assertDataTableEqual(data, truth) def test_calculate_bounds(self): - result = self.database.calculate_bounds("Visit.dec") + result = self.database.calculate_bounds("exposure.dec") self.assertTupleEqual(result, (-40, 50)) diff --git a/tests/test_query.py b/tests/test_query.py index 7e828d5..b97da65 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -19,40 +19,15 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -import os -import tempfile - import astropy.table import lsst.rubintv.analysis.service as lras import sqlalchemy import utils -import yaml class TestQuery(utils.RasTestCase): - def setUp(self): - path = os.path.dirname(__file__) - yaml_filename = os.path.join(path, "schema.yaml") - - with open(yaml_filename) as file: - schema = yaml.safe_load(file) - db_file = tempfile.NamedTemporaryFile(delete=False) - utils.create_database(schema, db_file.name) - self.db_file = db_file - self.db_filename = db_file.name - self.schema = schema - - # Set up the sqlalchemy connection - self.engine = sqlalchemy.create_engine("sqlite:///" + db_file.name) - self.metadata = sqlalchemy.MetaData() - self.database = lras.database.ConsDbSchema(schema=schema, engine=self.engine) - - def tearDown(self) -> None: - self.db_file.close() - os.remove(self.db_file.name) - def test_equality(self): - query_table = self.database.tables["Visit"] + query_table = self.database.tables["exposure"] query_column = query_table.columns.dec value = 0 @@ -66,20 +41,20 @@ def test_equality(self): } for operator, truth in truth_dict.items(): - result = lras.query.EqualityQuery("Visit.dec", operator, value)(self.database) + result = lras.query.EqualityQuery("exposure.dec", operator, value)(self.database) self.assertTrue(result.result.compare(truth)) self.assertSetEqual( result.tables, { - "Visit", + "exposure", }, ) def test_query(self): - dec_column = self.database.tables["Visit"].columns.dec - ra_column = self.database.tables["Visit"].columns.ra + dec_column = self.database.tables["exposure"].columns.dec + ra_column = self.database.tables["exposure"].columns.ra # dec > 0 - query = lras.query.EqualityQuery("Visit.dec", "gt", 0) + query = lras.query.EqualityQuery("exposure.dec", "gt", 0) result = query(self.database) self.assertTrue(result.result.compare(dec_column > 0)) @@ -87,8 +62,8 @@ def test_query(self): query = lras.query.ParentQuery( operator="AND", children=[ - lras.query.EqualityQuery("Visit.dec", "lt", 0), - lras.query.EqualityQuery("Visit.ra", "gt", 60), + lras.query.EqualityQuery("exposure.dec", "lt", 0), + lras.query.EqualityQuery("exposure.ra", "gt", 60), ], ) result = query(self.database) @@ -107,13 +82,13 @@ def test_query(self): self.assertFalse(result.result.compare(truth)) def test_database_query(self): - data = utils.get_test_data("Visit") + data = utils.get_test_data("exposure") # dec > 0 (and is not None) query1 = { "name": "EqualityQuery", "content": { - "column": "Visit.dec", + "column": "exposure.dec", "operator": "gt", "value": 0, }, @@ -122,7 +97,7 @@ def test_database_query(self): query2 = { "name": "EqualityQuery", "content": { - "column": "Visit.ra", + "column": "exposure.ra", "operator": "gt", "value": 60, }, @@ -130,14 +105,13 @@ def test_database_query(self): # Test 1: dec > 0 (and is not None) query = query1 - result = self.database.query(["Visit.ra", "Visit.dec"], query=query) + result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, True, False, False, True, True]] truth = truth[ - "Visit.ra", - "Visit.dec", - "Visit.day_obs", - "Visit.seq_num", - "Visit.instrument", + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore @@ -149,14 +123,13 @@ def test_database_query(self): "children": [query1, query2], }, } - result = self.database.query(["Visit.ra", "Visit.dec"], query=query) + result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, False, False, False, True, True]] truth = truth[ - "Visit.ra", - "Visit.dec", - "Visit.day_obs", - "Visit.seq_num", - "Visit.instrument", + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore @@ -178,14 +151,13 @@ def test_database_query(self): }, } - result = self.database.query(["Visit.ra", "Visit.dec"], query=query) + result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) truth = data[[True, True, False, True, True, False, False, False, True, True]] truth = truth[ - "Visit.ra", - "Visit.dec", - "Visit.day_obs", - "Visit.seq_num", - "Visit.instrument", + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore @@ -197,36 +169,34 @@ def test_database_query(self): "children": [query1, query2], }, } - result = self.database.query(["Visit.ra", "Visit.dec"], query=query) + result = self.database.query(["exposure.ra", "exposure.dec"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, True, False, False, False, False]] truth = truth[ - "Visit.ra", - "Visit.dec", - "Visit.day_obs", - "Visit.seq_num", - "Visit.instrument", + "exposure.ra", + "exposure.dec", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore def test_database_string_query(self): - data = utils.get_test_data("ExposureInfo") + data = utils.get_test_data("exposure") # Test equality query = { "name": "EqualityQuery", "content": { - "column": "ExposureInfo.physical_filter", + "column": "exposure.physical_filter", "operator": "eq", "value": "DECam r-band", }, } - result = self.database.query(["ExposureInfo.physical_filter"], query=query) + result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, False, True, False, False, False]] truth = truth[ - "ExposureInfo.physical_filter", - "ExposureInfo.day_obs", - "ExposureInfo.seq_num", - "ExposureInfo.instrument", + "exposure.physical_filter", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore @@ -234,18 +204,17 @@ def test_database_string_query(self): query = { "name": "EqualityQuery", "content": { - "column": "ExposureInfo.physical_filter", + "column": "exposure.physical_filter", "operator": "startswith", "value": "DECam", }, } - result = self.database.query(["ExposureInfo.physical_filter"], query=query) + result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, False, False, False, False, True, True, True, True, True]] truth = truth[ - "ExposureInfo.physical_filter", - "ExposureInfo.day_obs", - "ExposureInfo.seq_num", - "ExposureInfo.instrument", + "exposure.physical_filter", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore @@ -253,18 +222,17 @@ def test_database_string_query(self): query = { "name": "EqualityQuery", "content": { - "column": "ExposureInfo.physical_filter", + "column": "exposure.physical_filter", "operator": "endswith", "value": "r-band", }, } - result = self.database.query(["ExposureInfo.physical_filter"], query=query) + result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, True, False, False, False, False, True, False, False, False]] truth = truth[ - "ExposureInfo.physical_filter", - "ExposureInfo.day_obs", - "ExposureInfo.seq_num", - "ExposureInfo.instrument", + "exposure.physical_filter", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore @@ -272,40 +240,38 @@ def test_database_string_query(self): query = { "name": "EqualityQuery", "content": { - "column": "ExposureInfo.physical_filter", + "column": "exposure.physical_filter", "operator": "contains", "value": "T r", }, } - result = self.database.query(["ExposureInfo.physical_filter"], query=query) + result = self.database.query(["exposure.physical_filter"], query=lras.query.Query.from_dict(query)) truth = data[[False, True, False, False, False, False, False, False, False, False]] truth = truth[ - "ExposureInfo.physical_filter", - "ExposureInfo.day_obs", - "ExposureInfo.seq_num", - "ExposureInfo.instrument", + "exposure.physical_filter", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore def test_database_datatime_query(self): - data = utils.get_test_data("ExposureInfo") + data = utils.get_test_data("exposure") # Test < query1 = { "name": "EqualityQuery", "content": { - "column": "ExposureInfo.obsStart", + "column": "exposure.obs_start", "operator": "lt", "value": "2023-05-19 23:23:23", }, } - result = self.database.query(["ExposureInfo.obsStart"], query=query1) + result = self.database.query(["exposure.obs_start"], query=lras.query.Query.from_dict(query1)) truth = data[[True, True, True, False, False, True, True, True, True, True]] truth = truth[ - "ExposureInfo.obsStart", - "ExposureInfo.day_obs", - "ExposureInfo.seq_num", - "ExposureInfo.instrument", + "exposure.obs_start", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore @@ -313,18 +279,17 @@ def test_database_datatime_query(self): query2 = { "name": "EqualityQuery", "content": { - "column": "ExposureInfo.obsStart", + "column": "exposure.obs_start", "operator": "gt", "value": "2023-05-01 23:23:23", }, } - result = self.database.query(["ExposureInfo.obsStart"], query=query2) + result = self.database.query(["exposure.obs_start"], query=lras.query.Query.from_dict(query2)) truth = data[[True, True, True, True, True, False, False, False, False, False]] truth = truth[ - "ExposureInfo.obsStart", - "ExposureInfo.day_obs", - "ExposureInfo.seq_num", - "ExposureInfo.instrument", + "exposure.obs_start", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore @@ -336,31 +301,30 @@ def test_database_datatime_query(self): "children": [query1, query2], }, } - result = self.database.query(["ExposureInfo.obsStart"], query=query3) + result = self.database.query(["exposure.obs_start"], query=lras.query.Query.from_dict(query3)) truth = data[[True, True, True, False, False, False, False, False, False, False]] truth = truth[ - "ExposureInfo.obsStart", - "ExposureInfo.day_obs", - "ExposureInfo.seq_num", - "ExposureInfo.instrument", + "exposure.obs_start", + "exposure.day_obs", + "exposure.seq_num", ] # type: ignore self.assertDataTableEqual(result, truth) # type:ignore def test_multiple_table_query(self): - visit_truth = utils.get_test_data("Visit") - exp_truth = utils.get_test_data("ExposureInfo") + visit_truth = utils.get_test_data("exposure") + exp_truth = utils.get_test_data("visit1_quicklook") truth = astropy.table.join( visit_truth, exp_truth, - keys_left=("Visit.seq_num", "Visit.day_obs", "Visit.instrument"), - keys_right=("ExposureInfo.seq_num", "ExposureInfo.day_obs", "ExposureInfo.instrument"), + keys_left=("exposure.exposure_id",), + keys_right=("visit1_quicklook.visit_id",), ) # dec > 0 (and is not None) query1 = { "name": "EqualityQuery", "content": { - "column": "Visit.dec", + "column": "exposure.dec", "operator": "gt", "value": 0, }, @@ -369,7 +333,7 @@ def test_multiple_table_query(self): query2 = { "name": "EqualityQuery", "content": { - "column": "ExposureInfo.expTime", + "column": "visit1_quicklook.exp_time", "operator": "eq", "value": 30, }, @@ -384,24 +348,24 @@ def test_multiple_table_query(self): } valid = ( - (truth["Visit.dec"] != None) # noqa: E711 - & (truth["Visit.ra"] != None) # noqa: E711 - & (truth["ExposureInfo.exposure_id"] != None) # noqa: E711 + (truth["exposure.dec"] != None) # noqa: E711 + & (truth["exposure.ra"] != None) # noqa: E711 + & (truth["visit1_quicklook.visit_id"] != None) # noqa: E711 ) truth = truth[valid] - valid = (truth["Visit.dec"] > 0) & (truth["ExposureInfo.expTime"] == 30) + valid = (truth["exposure.dec"] > 0) & (truth["visit1_quicklook.exp_time"] == 30) truth = truth[valid] truth = truth[ - "Visit.dec", - "Visit.ra", - "Visit.seq_num", - "Visit.instrument", - "ExposureInfo.exposure_id", - "Visit.day_obs", + "exposure.dec", + "exposure.ra", + "visit1_quicklook.visit_id", + "exposure.day_obs", + "exposure.seq_num", ] result = self.database.query( - columns=["Visit.ra", "Visit.dec", "ExposureInfo.exposure_id"], query=query3 + columns=["exposure.ra", "exposure.dec", "visit1_quicklook.visit_id"], + query=lras.query.Query.from_dict(query3), ) self.assertDataTableEqual(result, truth) diff --git a/tests/utils.py b/tests/utils.py index 22c7603..62cd1ba 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -43,6 +43,12 @@ "datetime": "text", } +# Convert DataID columns +dataid_transform = { + "exposure.day_obs": "day_obs", + "exposure.seq_num": "seq_num", +} + def create_table(cursor: sqlite3.Cursor, tbl_name: str, schema: dict): """Create a table in an sqlite database. @@ -63,42 +69,9 @@ def create_table(cursor: sqlite3.Cursor, tbl_name: str, schema: dict): cursor.execute(command) -def get_visit_data_dict() -> dict: +def get_exposure_data_dict() -> dict: """Get a dictionary containing the visit test data""" - return { - "Visit.seq_num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], - "Visit.day_obs": [ - "2023-05-19", - "2023-05-19", - "2023-05-19", - "2023-05-19", - "2023-05-19", - "2023-02-14", - "2023-02-14", - "2023-02-14", - "2023-02-14", - "2023-02-14", - ], - "Visit.instrument": [ - "LSST", - "LSST", - "LSST", - "LSST", - "LSST", - "DECam", - "DECam", - "DECam", - "DECam", - "DECam", - ], - "Visit.ra": [10, 20, None, 40, 50, 60, 70, None, 90, 100], - "Visit.dec": [-40, -30, None, -10, 0, 10, None, 30, 40, 50], - } - - -def get_exposure_data_dict() -> dict: - """Get a dictionary containing the exposure test data""" obs_start = [ "2023-05-19 20:20:20", "2023-05-19 21:21:21", @@ -115,8 +88,9 @@ def get_exposure_data_dict() -> dict: obs_start_mjd = [Time(time).mjd for time in obs_start] return { - "ExposureInfo.seq_num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], - "ExposureInfo.day_obs": [ + "exposure.exposure_id": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], + "exposure.seq_num": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + "exposure.day_obs": [ "2023-05-19", "2023-05-19", "2023-05-19", @@ -128,21 +102,9 @@ def get_exposure_data_dict() -> dict: "2023-02-14", "2023-02-14", ], - "ExposureInfo.instrument": [ - "LSST", - "LSST", - "LSST", - "LSST", - "LSST", - "DECam", - "DECam", - "DECam", - "DECam", - "DECam", - ], - "ExposureInfo.exposure_id": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18], - "ExposureInfo.expTime": [30, 30, 10, 15, 15, 30, 30, 30, 15, 20], - "ExposureInfo.physical_filter": [ + "exposure.ra": [10, 20, None, 40, 50, 60, 70, None, 90, 100], + "exposure.dec": [-40, -30, None, -10, 0, 10, None, 30, 40, 50], + "exposure.physical_filter": [ "LSST g-band", "LSST r-band", "LSST i-band", @@ -154,17 +116,25 @@ def get_exposure_data_dict() -> dict: "DECam z-band", "DECam y-band", ], - "ExposureInfo.obsStart": obs_start, - "ExposureInfo.obsStartMJD": obs_start_mjd, + "exposure.obs_start": obs_start, + "exposure.obs_start_mjd": obs_start_mjd, + } + + +def get_visit_data_dict() -> dict: + """Get a dictionary containing the exposure test data""" + return { + "visit1_quicklook.visit_id": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], + "visit1_quicklook.exp_time": [30, 30, 10, 15, 15, 30, 30, 30, 15, 20], } def get_test_data(table: str) -> ApTable: """Generate data for the test database""" - if table == "Visit": - data_dict = get_visit_data_dict() - else: + if table == "exposure": data_dict = get_exposure_data_dict() + else: + data_dict = get_visit_data_dict() return ApTable(list(data_dict.values()), names=list(data_dict.keys())) @@ -186,14 +156,16 @@ def create_database(schema: dict, db_filename: str): create_table(cursor, table["name"], table["columns"]) - if table["name"] == "Visit": - data = get_visit_data_dict() - elif table["name"] == "ExposureInfo": + if table["name"] == "exposure": data = get_exposure_data_dict() + index_key = "exposure.exposure_id" + elif table["name"] == "visit1_quicklook": + data = get_visit_data_dict() + index_key = "visit1_quicklook.visit_id" else: raise ValueError(f"Unknown table name: {table['name']}") - for n in range(len(data[f"{table['name']}.seq_num"])): + for n in range(len(data[index_key])): row = tuple(data[key][n] for key in data.keys()) value_str = "?, " * (len(row) - 1) + "?" command = f"INSERT INTO {table['name']} VALUES({value_str});" @@ -222,17 +194,30 @@ def setUp(self): with open(yaml_filename) as file: schema = yaml.safe_load(file) + # Remove the name of the schema, since sqlite does not have + # schema names and this will break the code otherwise. + schema["name"] = None + # Create the sqlite test database db_file = tempfile.NamedTemporaryFile(delete=False) create_database(schema, db_file.name) self.db_file = db_file + self.db_filename = db_file.name + self.schema = schema # Set up the sqlalchemy connection engine = sqlalchemy.create_engine("sqlite:///" + db_file.name) + # Load the table joins + joins_path = os.path.abspath( + os.path.expanduser(os.path.expandvars(os.path.join(path, "..", "scripts", "joins.yaml"))) + ) + with open(joins_path) as file: + joins = yaml.safe_load(file)["joins"] + # Create the datacenter - self.database = ConsDbSchema(schema=schema, engine=engine) - self.dataCenter = DataCenter(schemas={"testdb": self.database}) + self.database = ConsDbSchema(schema=schema, engine=engine, join_templates=joins) + self.data_center = DataCenter(schemas={"testdb": self.database}) def tearDown(self) -> None: self.db_file.close() @@ -250,7 +235,11 @@ def assertDataTableEqual(self, result: dict | ApTable, truth: ApTable): # NOQA: """ columns = truth.colnames for column in columns: + result_column = column if column not in result: - msg = f"Column {column} not found in result" - raise TableMismatchError(msg) - np.testing.assert_array_equal(np.array(result[column]), np.array(truth[column])) + if column in dataid_transform: + result_column = dataid_transform[column] + else: + msg = f"Column {column} not found in result" + raise TableMismatchError(msg) + np.testing.assert_array_equal(np.array(result[result_column]), np.array(truth[column]))