From 6fec4af052207d84cdfa2a62bdf1cf5f36404677 Mon Sep 17 00:00:00 2001 From: ugyballoons Date: Tue, 7 Jan 2025 15:34:13 +0000 Subject: [PATCH] Implement response_type correctly --- python/lsst/rubintv/analysis/service/commands/db.py | 7 +++---- tests/test_command.py | 12 ++++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/python/lsst/rubintv/analysis/service/commands/db.py b/python/lsst/rubintv/analysis/service/commands/db.py index 47afc87..9e9f299 100644 --- a/python/lsst/rubintv/analysis/service/commands/db.py +++ b/python/lsst/rubintv/analysis/service/commands/db.py @@ -232,7 +232,7 @@ class AggregateQueryCommand(BaseCommand): data_ids : list[tuple[int, int]] | None Specific (day_obs, seq_num) pairs to filter rows. response_type : str - The type of response returned, defaulting to "aggregate". + The type of response returned, defaulting to "aggregated". """ database: str @@ -242,7 +242,7 @@ class AggregateQueryCommand(BaseCommand): global_query: dict | None = None day_obs: str | None = None data_ids: list[tuple[int, int]] | None = None - response_type: str = "aggregate" + response_type: str = "aggregated" def build_contents(self, data_center: DataCenter) -> dict: """Query the database to perform the specified aggregate operation on each column.""" @@ -268,10 +268,9 @@ def build_contents(self, data_center: DataCenter) -> dict: # Extract the aggregate result from the query result result[column] = query_result.result.get(self.query_type.lower(), 0) - # Return the results in the expected format return { "schema": self.database, - self.response_type: result, + self.query_type: result, } diff --git a/tests/test_command.py b/tests/test_command.py index bc7e950..bbcbe0f 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -51,8 +51,8 @@ def test_count_rows(self): "query_type": "count", }, } - content = self.execute_command(command, "aggregate") - data = content["aggregate"] + content = self.execute_command(command, "aggregated") + data = content["count"] self.assertEqual(data, {columns[0]: [8], columns[1]: [8]}) def test_sum_rows(self): @@ -68,8 +68,8 @@ def test_sum_rows(self): "query_type": "sum", }, } - content = self.execute_command(command, "aggregate") - data = content["aggregate"] + content = self.execute_command(command, "aggregated") + data = content["sum"] self.assertEqual(data, {columns[0]: [440.0], columns[1]: [50.0]}) def test_max_rows(self): @@ -85,8 +85,8 @@ def test_max_rows(self): "query_type": "max", }, } - content = self.execute_command(command, "aggregate") - data = content["aggregate"] + content = self.execute_command(command, "aggregated") + data = content["max"] self.assertEqual(data, {columns[0]: [100.0], columns[1]: [50.0]})