-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tools to create mock target lists
- Loading branch information
Showing
3 changed files
with
393 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# @Author: José Sánchez-Gallego ([email protected]) | ||
# @Date: 2024-02-08 | ||
# @Filename: datamodel.py | ||
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause) | ||
|
||
from __future__ import annotations | ||
|
||
import polars | ||
|
||
|
||
# Mapping of column name -> polars dtype for the ToO target table. The key
# order is significant when this mapping is used as a DataFrame schema,
# since it then determines the column order.
too_dtypes = {
    "too_id": polars.UInt64,
    "fiber_type": polars.String,
    "catalogid": polars.UInt64,
    "sdss_id": polars.UInt64,
    "gaia_dr3_source_id": polars.UInt64,
    "twomass_pts_key": polars.Int32,
    "sky_brightness_mode": polars.String,
    "ra": polars.Float64,
    "dec": polars.Float64,
    "pmra": polars.Float32,
    "pmdec": polars.Float32,
    "epoch": polars.Float32,
    "parallax": polars.Float32,
    "lambda_eff": polars.Float32,
    "u_mag": polars.Float32,
    "g_mag": polars.Float32,
    "r_mag": polars.Float32,
    "i_mag": polars.Float32,
    "z_mag": polars.Float32,
    "optical_prov": polars.String,
    "gaia_bp_mag": polars.Float32,
    "gaia_rp_mag": polars.Float32,
    "gaia_g_mag": polars.Float32,
    "h_mag": polars.Float32,
    "delta_ra": polars.Float32,
    "delta_dec": polars.Float32,
    "inertial": polars.Boolean,
    "n_exposures": polars.Int16,
    "priority": polars.Int16,
    "active": polars.Boolean,
    "expiration_date": polars.Int32,
    "observed": polars.Boolean,
}

# Identifier columns that link a ToO to an external catalogue entry.
# NOTE(review): "fixed" presumably means these must not change once a target
# has been loaded -- confirm against the code that consumes this list.
too_fixed_columns = ["catalogid", "sdss_id", "gaia_dr3_source_id", "twomass_pts_key"]

# Names of every magnitude column defined in ``too_dtypes``.
mag_columns = [
    "u_mag",
    "g_mag",
    "r_mag",
    "i_mag",
    "z_mag",
    "gaia_bp_mag",
    "gaia_rp_mag",
    "gaia_g_mag",
    "h_mag",
]

# Presumably the accepted values for the ``fiber_type`` column, with the empty
# string meaning "not specified" -- TODO confirm with the validation code.
fiber_type_values = ["APOGEE", "BOSS", ""]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# @Author: José Sánchez-Gallego ([email protected]) | ||
# @Date: 2024-02-08 | ||
# @Filename: mock.py | ||
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause) | ||
|
||
from __future__ import annotations | ||
|
||
import pathlib | ||
|
||
from typing import Literal, overload | ||
|
||
import numpy | ||
import polars | ||
|
||
from too import console, log | ||
from too.datamodel import too_dtypes | ||
from too.tools import download_file | ||
|
||
|
||
@overload
def get_sample_data(
    table: Literal["gaia_dr3", "twomass", "photoobj"],
    lazy: Literal[True] = ...,
) -> polars.LazyFrame: ...


@overload
def get_sample_data(
    table: Literal["gaia_dr3", "twomass", "photoobj"],
    lazy: Literal[False],
) -> polars.DataFrame: ...


def get_sample_data(
    table: Literal["gaia_dr3", "twomass", "photoobj"],
    lazy: bool = True,
) -> polars.DataFrame | polars.LazyFrame:
    """Downloads and caches the table sample data.

    Parameters
    ----------
    table
        The table to download. Can be ``'gaia_dr3'``, ``'twomass'``, or
        ``'photoobj'``. If the table file is not found in the cache, it will
        be downloaded from the SAS server.
    lazy
        Whether to return a lazy dataframe.

    Returns
    -------
    A polars dataframe (or lazy frame if ``lazy=True``) with the sample data.

    Raises
    ------
    ValueError
        If ``table`` is not one of the supported table names.

    """

    BASE_URL = "https://data.sdss5.org/resources/target/mocks/samples/"
    CACHE_PATH = pathlib.Path("~/.cache/sdss/too/samples").expanduser()

    if table == "gaia_dr3":
        filename = "gaia_dr3_1M_sample.parquet"
    elif table == "twomass":
        filename = "twomass_psc_1M_sample.parquet"
    elif table == "photoobj":
        filename = "sdss_dr13_photoobj_1M_sample.parquet"
    else:
        raise ValueError(f"Invalid table {table!r}")

    cached_file = CACHE_PATH / filename

    if not cached_file.exists():
        log.debug(f"File {filename!r} not found in cache. Downloading from {BASE_URL}.")
        download_file(
            # BASE_URL already ends with a slash; joining with another "/"
            # would request ".../samples//<filename>".
            f"{BASE_URL}{filename}",
            CACHE_PATH,
            transient_progress=True,
            console=console,
        )

    if lazy:
        return polars.scan_parquet(cached_file)
    return polars.read_parquet(cached_file)
|
||
|
||
def get_sample_targets(
    table: Literal["gaia_dr3", "twomass", "photoobj"],
    n_targets: int,
):
    """Draws ``n_targets`` random rows from a sample table.

    The selected rows keep only the columns relevant for the ToO data model
    (plus ``catalogid`` and ``sdss_id``), renamed to the ToO column names.
    ``ra`` and ``dec`` are cast to 64-bit floats.

    """

    # Fetch first: an invalid table name is rejected by get_sample_data().
    drawn = get_sample_data(table, lazy=False).sample(n_targets)

    # Source column name -> ToO column name, for each supported table.
    renames_by_table = {
        "gaia_dr3": {
            "source_id": "gaia_dr3_source_id",
            "ra": "ra",
            "dec": "dec",
            "pmra": "pmra",
            "pmdec": "pmdec",
            "parallax": "parallax",
            "phot_g_mean_mag": "gaia_g_mag",
            "phot_bp_mean_mag": "gaia_bp_mag",
            "phot_rp_mean_mag": "gaia_rp_mag",
        },
        "twomass": {
            "pts_key": "twomass_pts_key",
            "ra": "ra",
            "decl": "dec",
            "h_m": "h_mag",
        },
        "photoobj": {
            "ra": "ra",
            "dec": "dec",
            "psfmag_u": "u_mag",
            "psfmag_g": "g_mag",
            "psfmag_r": "r_mag",
            "psfmag_i": "i_mag",
            "psfmag_z": "z_mag",
        },
    }
    renames = renames_by_table[table]
    source_columns = [*renames, "catalogid", "sdss_id"]

    return (
        drawn.select(*source_columns)
        .rename(renames)
        .with_columns(polars.col(["ra", "dec"]).cast(polars.Float64))
    )
|
||
|
||
def create_mock_too_catalogue(
    n_targets: int = 1000000,
    fraction_unknown: float = 0.6,
    fraction_unknown_sdss: float = 0.2,
    fraction_unknown_gaia: float = 0.3,
    fraction_known_gaia: float = 0.8,
    fraction_known_sdss: float = 0.1,
    fraction_known_twomass: float = 0.1,
    catalogid_likelihood: float = 0.2,
):
    """Creates a mock ToO catalogue.

    Parameters
    ----------
    n_targets
        The number of targets in the mocked catalogue.
    fraction_unknown
        Fraction of the targets that will not have an associated target in
        ``catalogdb``.
    fraction_unknown_sdss
        Fraction of the unknown targets that will actually be drawn from
        ``sdss_dr13_photoobj``.
    fraction_unknown_gaia
        Fraction of the unknown targets that will actually be drawn from
        ``gaia_dr3_source``.
    fraction_known_gaia
        Fraction of the targets that will have a known Gaia source. Along with
        ``fraction_known_sdss`` and ``fraction_known_twomass``, must add up to 1.
        The total number of targets with ``gaia_dr3_source_id`` will be
        ``n_targets * fraction_known_gaia * (1 - fraction_unknown)``.
    fraction_known_sdss
        Fraction of the targets that will have a known SDSS source.
    fraction_known_twomass
        Fraction of the targets that will have a known 2MASS source.
    catalogid_likelihood
        The likelihood (0-1) that a target will have a known ``catalogid`` or
        ``sdss_id``.

    Returns
    -------
    A polars dataframe with the mock catalogue. Any targets not accounted for
    by the fractions above are filled with uniformly random coordinates and a
    single random magnitude.

    Raises
    ------
    ValueError
        If the requested fractions select more than ``n_targets`` targets.

    """

    n_known = n_targets * (1 - fraction_unknown)

    gaia_known_targets = get_sample_targets(
        "gaia_dr3",
        int(n_known * fraction_known_gaia),
    )

    twomass_known_targets = get_sample_targets(
        "twomass",
        int(n_known * fraction_known_twomass),
    )

    sdss_known_targets = get_sample_targets(
        "photoobj",
        int(n_known * fraction_known_sdss),
    )
    sdss_known_targets = sdss_known_targets.with_columns(catalogid=None, sdss_id=None)

    n_unknown = n_targets * fraction_unknown

    # "Unknown" targets drawn from real catalogues get their identifiers and
    # most of their photometry blanked out.
    gaia_unknown_targets = get_sample_targets(
        "gaia_dr3",
        int(n_unknown * fraction_unknown_gaia),
    )
    gaia_unknown_targets = gaia_unknown_targets.with_columns(
        catalogid=None,
        sdss_id=None,
        gaia_bp_mag=None,
        gaia_rp_mag=None,
    )

    sdss_unknown_targets = get_sample_targets(
        "photoobj",
        int(n_unknown * fraction_unknown_sdss),
    )
    sdss_unknown_targets = sdss_unknown_targets.with_columns(
        # An empty mapping with default=None replaces every value with null,
        # so only r_mag (not matched by the regex) is kept.
        polars.col(r"^[ugiz]_mag$").replace({}, default=None),
        catalogid=None,
        sdss_id=None,
    )

    # Starting from an empty frame with the full schema ensures that every
    # data model column is present after the diagonal concatenation.
    df = polars.DataFrame(schema=too_dtypes)
    df = polars.concat(
        [
            df,
            gaia_known_targets,
            twomass_known_targets,
            sdss_known_targets,
            gaia_unknown_targets,
            sdss_unknown_targets,
        ],
        how="diagonal",
    )

    n_random = int(n_targets - df.height)
    if n_random < 0:
        raise ValueError(
            f"The requested fractions select {df.height} targets, "
            f"more than n_targets={n_targets}."
        )

    # Give half of them Gaia mags and half SDSS. The two halves must add up
    # to exactly n_random: assigning a shorter sequence to a list slice would
    # silently shrink the list when n_random is odd.
    n_random_gaia = n_random // 2
    n_random_sdss = n_random - n_random_gaia

    random_gaia = [None] * n_random
    random_sdss = [None] * n_random
    random_gaia[:n_random_gaia] = numpy.random.uniform(10, 20, size=n_random_gaia)
    random_sdss[n_random_gaia:] = numpy.random.uniform(10, 20, size=n_random_sdss)

    random_targets = polars.DataFrame(
        {
            "ra": numpy.random.uniform(0, 360, n_random),
            "dec": numpy.random.uniform(-90, 90, n_random),
            "gaia_g_mag": polars.Series(random_gaia, dtype=polars.Float32),
            "r_mag": polars.Series(random_sdss, dtype=polars.Float32),
        },
    )

    df = polars.concat([df, random_targets], how="diagonal")

    # Keep catalogid/sdss_id only for a random subset of the rows; the
    # when/then without otherwise yields null for the rest.
    df = df.with_columns(
        keep_cid=polars.Series(numpy.random.rand(df.height) < catalogid_likelihood)
    )
    df = df.with_columns(
        catalogid=polars.when(polars.col.keep_cid).then(polars.col.catalogid),
        sdss_id=polars.when(polars.col.keep_cid).then(polars.col.sdss_id),
    )
    df.drop_in_place("keep_cid")
    df = df.sample(df.height)  # Shuffle

    df = df.with_columns(polars.int_range(1, df.height + 1).alias("too_id"))

    # Fill out some columns. Roughly half of the targets are assigned to BOSS
    # and the rest to APOGEE.
    boss_mask = numpy.random.rand(df.height) < 0.5
    fiber_type = numpy.where(boss_mask, "BOSS", "APOGEE")

    df = df.with_columns(
        fiber_type=polars.Series(fiber_type, dtype=polars.String),
        observed=False,
        active=True,
        priority=polars.lit(5, dtype=polars.Int16),
    )

    return df
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# @Author: José Sánchez-Gallego ([email protected]) | ||
# @Date: 2024-02-10 | ||
# @Filename: tools.py | ||
# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause) | ||
|
||
from __future__ import annotations | ||
|
||
import pathlib | ||
import shutil | ||
import tempfile | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
import httpx | ||
import rich.progress | ||
|
||
|
||
if TYPE_CHECKING: | ||
import os | ||
|
||
import rich.console | ||
|
||
|
||
__all__ = ["download_file"] | ||
|
||
|
||
def download_file(
    url: str,
    path: os.PathLike | str,
    transient_progress: bool = False,
    console: rich.console.Console | None = None,
):
    """Downloads a file from a URL to a local path.

    The file is streamed to a temporary file inside ``path`` and only moved
    to its final name once the download has completed, so an interrupted or
    failed download never leaves a partial file under the final name.

    Parameters
    ----------
    url
        The URL of the file to download. The last component of the URL is
        used as the name of the local file.
    path
        The directory in which to save the file. It is created if it does
        not exist.
    transient_progress
        Whether the progress bar should be removed once the download ends.
    console
        The rich console on which to display the progress bar.

    Raises
    ------
    httpx.HTTPStatusError
        If the server does not respond with a successful status code.

    """

    path = pathlib.Path(path).expanduser()
    path.mkdir(parents=True, exist_ok=True)

    destination = path / pathlib.Path(url).name

    # delete=False because the file is moved away on success; with the default
    # delete=True the context exit tries to unlink a file that no longer
    # exists (FileNotFoundError on Python < 3.12). Creating it inside the
    # destination directory keeps the final move on the same filesystem.
    tmp_file = tempfile.NamedTemporaryFile(dir=path, delete=False)
    tmp_path = pathlib.Path(tmp_file.name)

    try:
        with tmp_file, httpx.stream("GET", url) as response:
            # Do not save (and cache) the body of an error response.
            response.raise_for_status()

            # Content-Length is absent for chunked responses; a total of None
            # makes the progress bar indeterminate instead of failing.
            content_length = response.headers.get("Content-Length")
            total = int(content_length) if content_length is not None else None

            with rich.progress.Progress(
                "[progress.percentage]{task.percentage:>3.0f}%",
                rich.progress.BarColumn(bar_width=None),
                rich.progress.DownloadColumn(),
                rich.progress.TransferSpeedColumn(),
                transient=transient_progress,
                console=console,
            ) as progress:
                download_task = progress.add_task("Download", total=total)
                for chunk in response.iter_bytes():
                    tmp_file.write(chunk)
                    progress.update(
                        download_task,
                        completed=response.num_bytes_downloaded,
                    )
    except BaseException:
        # The temporary file is already closed here. Remove the partial
        # download, also on KeyboardInterrupt, and propagate the error.
        tmp_path.unlink(missing_ok=True)
        raise

    shutil.move(str(tmp_path), destination)