Skip to content

Commit

Permalink
Rename dump command to dump-targets and add dump-sdss-id command
Browse files Browse the repository at this point in the history
  • Loading branch information
albireox committed Dec 14, 2024
1 parent c941d9c commit 1ad49e4
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 30 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ requires-python = ">=3.12,<3.13"

dependencies = [
"sdsstools>=1.8.2",
"sdss-target-selection",
"sdss-target-selection>=1.3.19",
"adbc-driver-postgresql>=1.2.0",
"polars>=1.17.1",
"httpx>=0.27.0",
Expand All @@ -24,8 +24,8 @@ dependencies = [
"typer>=0.13.0",
]

[tool.uv.sources]
sdss-target-selection = { git = "https://github.com/sdss/target_selection.git", branch = "main" }
# [tool.uv.sources]
# sdss-target-selection = { git = "https://github.com/sdss/target_selection.git", branch = "main" }

[project.urls]
Homepage = "https://github.com/sdss/too"
Expand Down
39 changes: 35 additions & 4 deletions src/too/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def process(
update_sdss_id_tables(database)


@too_cli.command()
def dump(
@too_cli.command(name="dump-targets")
def dump_targets(
file: Annotated[str, typer.Argument(help="The file to dump the ToO targets into.")],
observatory: Annotated[
Observatories,
Expand All @@ -168,10 +168,41 @@ def dump(
):
"""Dumps the current ToO targets into a Parquet file."""

from too import connect_to_database, dump_to_parquet
from too import connect_to_database, dump_targets_to_parquet

database = connect_to_database(dbname, host=host, port=port, user=user)
dump_to_parquet(observatory.value.upper(), file, database=database)
dump_targets_to_parquet(observatory.value.upper(), file, database=database)


@too_cli.command(name="dump-sdss-id")
def dump_sdss_id(
    last_updated: Annotated[
        str,
        typer.Argument(help="The last_updated value to use to search for new targets."),
    ],
    dbname: Annotated[str, dbname] = "sdss5db",
    host: Annotated[str, host] = "localhost",
    port: Annotated[int | None, port] = None,
    user: Annotated[str, user] = "sdss",
):
    """Dumps the ``sdss_id_flat`` and ``sdss_id_stacked`` tables for a given date.

    Searches the ``sandbox.sdss_id_stacked`` table for targets matching
    ``last_updated`` and uses that to determine the ``sdss_id_flat`` associated rows.
    The tables are saved to the current directory as CSV files.

    """

    from too import connect_to_database, dump_sdss_id_tables

    database = connect_to_database(dbname, host=host, port=port, user=user)

    # dump_sdss_id_tables() writes the CSV files as a side effect and has no
    # return value, so wrapping the call in print() would only emit "None".
    dump_sdss_id_tables(last_updated, database)


@too_cli.command()
Expand Down
40 changes: 38 additions & 2 deletions src/too/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from sdssdb.connection import PeeweeDatabaseConnection


__all__ = ["dump_to_parquet"]
__all__ = ["dump_targets_to_parquet", "dump_sdss_id_tables"]


def dump_to_parquet(
def dump_targets_to_parquet(
observatory: str,
path: os.PathLike | str,
database: PeeweeDatabaseConnection | None = None,
Expand Down Expand Up @@ -175,3 +175,39 @@ def dump_to_parquet(
bn.write_parquet(path)

return bn


def dump_sdss_id_tables(last_updated: str, database: PeeweeDatabaseConnection):
    """Dumps the SDSS ID tables to CSV files.

    Writes ``sdss_id_flat.csv`` and ``sdss_id_stacked.csv`` to the current
    working directory.

    Parameters
    ----------
    last_updated
        The value of the ``last_updated`` column in ``sandbox.sdss_id_stacked``
        to filter the data. Must be an ISO-format date (``YYYY-MM-DD``).
    database
        The database connection.

    Raises
    ------
    ValueError
        If ``last_updated`` is not a valid ISO-format date.
    FileExistsError
        If either of the output CSV files already exists.

    """

    import datetime

    # last_updated comes straight from the command line; interpolating an
    # arbitrary string into the SQL below would be injectable (a value with a
    # single quote breaks the query). Normalise it to a validated ISO date
    # first so the !r interpolation is guaranteed safe.
    last_updated = datetime.date.fromisoformat(last_updated).isoformat()

    sdss_id_flat_path = pathlib.Path("sdss_id_flat.csv")
    sdss_id_stacked_path = pathlib.Path("sdss_id_stacked.csv")

    # Refuse to clobber the output of a previous dump.
    if sdss_id_flat_path.exists() or sdss_id_stacked_path.exists():
        raise FileExistsError("SDSS ID CSV files already exist.")

    sdss_id_stacked_data = polars.read_database(
        f"SELECT * FROM sandbox.sdss_id_stacked WHERE last_updated = {last_updated!r}",
        database,
    )

    # The flat rows are selected via their stacked counterparts so both CSV
    # files describe the same set of sdss_id values.
    sdss_id_flat_data = polars.read_database(
        f"""SELECT flat.* FROM sandbox.sdss_id_flat flat
            JOIN sandbox.sdss_id_stacked stacked ON stacked.sdss_id = flat.sdss_id
            WHERE stacked.last_updated = {last_updated!r}
        """,
        database,
    )

    sdss_id_stacked_data.write_csv(sdss_id_stacked_path)
    sdss_id_flat_data.write_csv(sdss_id_flat_path)
27 changes: 24 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from __future__ import annotations

import datetime
import os
import pathlib

import polars
Expand Down Expand Up @@ -66,7 +68,7 @@ def test_cli_only_load(
assert catalogdb.ToO_Metadata.select().count() == n_loaded


def test_cli_only_process(mock_validation):
def test_cli_only_process(tmp_path_factory: pytest.TempPathFactory, mock_validation):
database = connect_to_database("sdss5db_too_test", user="sdss", host="localhost")

n_sdss_id_flat_pre = int(
Expand All @@ -87,6 +89,25 @@ def test_cli_only_process(mock_validation):

assert n_sdss_id_flat_post > n_sdss_id_flat_pre

# Now that we have run the sdss_id update, test that the dump-sdss-id command works.
# First move to a temporary directory for the sdss_id CSV files
sdss_id_dir = tmp_path_factory.mktemp("sdss_id")
os.chdir(sdss_id_dir)

today = datetime.datetime.now().strftime("%Y-%m-%d")

result_dump_sdss_id = runner.invoke(
too_cli,
["dump-sdss-id", "--dbname", "sdss5db_too_test", today],
)
assert result_dump_sdss_id.exit_code == 0

assert (sdss_id_dir / "sdss_id_flat.csv").exists()
assert (sdss_id_dir / "sdss_id_stacked.csv").exists()

stacked = polars.read_csv(sdss_id_dir / "sdss_id_stacked.csv")
assert len(stacked) > 0


def test_cli_update(
files_path: pathlib.Path,
Expand Down Expand Up @@ -128,14 +149,14 @@ def test_cli_update(
assert n_carton_to_target == n_target


def test_cli_dump(tmp_path: pathlib.Path, mock_validation):
def test_cli_dump_targets(tmp_path: pathlib.Path, mock_validation):
tmp_file = tmp_path / "dump_APO.parquet"

runner = CliRunner()
result = runner.invoke(
too_cli,
[
"dump",
"dump-targets",
"--observatory",
"APO",
"--dbname",
Expand Down
8 changes: 4 additions & 4 deletions tests/test_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest

from too.database import load_too_targets
from too.dump import dump_to_parquet
from too.dump import dump_targets_to_parquet
from too.xmatch import xmatch_too_targets


Expand All @@ -41,7 +41,7 @@ def test_dump(

path = tmp_path / "too_dump.parquet"

df = dump_to_parquet("APO", path, database=database)
df = dump_targets_to_parquet("APO", path, database=database)

assert isinstance(df, polars.DataFrame)
assert path.exists()
Expand All @@ -52,9 +52,9 @@ def test_dump(

def test_dump_invalid_observatory(database: PeeweeDatabaseConnection):
    """An observatory name other than the known ones raises ``ValueError``."""

    bad_observatory = "ManuaKea"

    with pytest.raises(ValueError):
        dump_targets_to_parquet(bad_observatory, "/a/b.parquet", database=database)


def test_dump_invalid_database():
    """Passing an object that is not a database connection raises ``ValueError``."""

    not_a_database = "abc"

    with pytest.raises(ValueError):
        dump_targets_to_parquet("APO", "/a/b.parquet", database=not_a_database)  # type: ignore
32 changes: 18 additions & 14 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1ad49e4

Please sign in to comment.