diff --git a/cfg/object_types.yaml b/cfg/object_types.yaml index d6d58edf..56f30384 100644 --- a/cfg/object_types.yaml +++ b/cfg/object_types.yaml @@ -66,7 +66,7 @@ object_types : internal_extinction: CCM MW_extinction: F19 spatial_model: knots - + star: file_template: 'pointsource_(?P\d+).parquet' flux_file_template: 'pointsource_flux_(?P\d+).parquet' @@ -77,15 +77,6 @@ object_types : sed_file_root_env_var: SIMS_SED_LIBRARY_DIR MW_extinction: F19 internal_extinction: None - sncosmo: - file_template: 'pointsource_(?P\d+).parquet' - data_file_type: parquet - area_partition: - { type: healpix, ordering: ring, nside: 32} - - sed_model: sncosmo - MW_extinction: F19 - internal_extinction: None gaia_star: data_file_type: butler_refcat butler_parameters: diff --git a/etc/conda_requirements.txt b/etc/conda_requirements.txt index 72ebe0b5..50ff97e7 100644 --- a/etc/conda_requirements.txt +++ b/etc/conda_requirements.txt @@ -2,4 +2,4 @@ stackvana>=0.2023.32 gitpython -sncosmo +# sncosmo diff --git a/pyproject.toml b/pyproject.toml index ea881def..c332194b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ dependencies = [ 'astropy', 'pyarrow', 'pandas', - 'sncosmo', ] requires-python = ">=3.7" # For setuptools >= 61.0 support diff --git a/skycatalogs/catalog_creator.py b/skycatalogs/catalog_creator.py index fe653318..7e82cfda 100644 --- a/skycatalogs/catalog_creator.py +++ b/skycatalogs/catalog_creator.py @@ -20,7 +20,7 @@ from .utils.parquet_schema_utils import make_galaxy_schema from .utils.parquet_schema_utils import make_galaxy_flux_schema from .utils.parquet_schema_utils import make_star_flux_schema -from .utils.parquet_schema_utils import make_pointsource_schema +from .utils.parquet_schema_utils import make_star_schema from .utils.creator_utils import make_MW_extinction_av, make_MW_extinction_rv from .objects.base_object import LSST_BANDS from .objects.base_object import ROMAN_BANDS @@ -227,7 +227,7 @@ def _do_star_flux_chunk(send_conn, star_collection, instrument_needed, class CatalogCreator: def __init__(self, parts, area_partition=None, skycatalog_root=None, catalog_dir='.', galaxy_truth=None, - star_truth=None, sn_truth=None, + star_truth=None, config_path=None, catalog_name='skyCatalog', output_type='parquet', mag_cut=None, sed_subdir='galaxyTopHatSED', knots_mag_cut=27.0, @@ -295,12 +295,14 @@ def __init__(self, parts, area_partition=None, skycatalog_root=None, _cosmo_cat = 'cosmodc2_v1.1.4_image_addon_knots' _diffsky_cat = 'roman_rubin_2023_v1.1.2_elais' _star_db = '/global/cfs/cdirs/lsst/groups/SSim/DC2/dc2_stellar_healpixel.db' - _sn_db = '/global/cfs/cdirs/lsst/groups/SSim/DC2/cosmoDC2_v1.1.4/sne_cosmoDC2_v1.1.4_MS_DDF_healpix.db' - # _star_parquet = '/global/cfs/cdirs/descssim/postDC2/UW_star_catalog' _star_parquet = '/sdf/data/rubin/shared/ops-rehearsal-3/imSim_catalogs/UW_stars' self._galaxy_stride = galaxy_stride + + # Temporary. Should add a separate star_stride argument or change name + # e.g. galaxy_stride --> catalog_stride + self._star_stride = galaxy_stride if pkg_root: self._pkg_root = pkg_root else: @@ -322,10 +324,6 @@ def __init__(self, parts, area_partition=None, skycatalog_root=None, else: self._galaxy_truth = _diffsky_cat - self._sn_truth = sn_truth - if self._sn_truth is None: - self._sn_truth = _sn_db - self._sn_object_type = sn_object_type self._star_truth = star_truth @@ -398,7 +396,7 @@ def _make_tophat_columns(self, dat, names, cmp): dat[cmp + '_magnorm'] = [self._obs_sed_factory.magnorm(s, z) for (s, z) in zip(sed_vals, dat['redshiftHubble'])] for k in names: - del(dat[k]) + del dat[k] return dat def create(self, catalog_type): @@ -792,6 +790,7 @@ def _create_galaxy_flux_pixel(self, pixel): writer = None _instrument_needed = [] + rg_written = 0 for field in self._gal_flux_needed: if 'lsst' in field and 'lsst' not in _instrument_needed: _instrument_needed.append('lsst') @@ -805,7 +804,6 @@ def _create_galaxy_flux_pixel(self, pixel): _ = object_coll.get_native_attribute(att) l_bnd = 0 u_bnd = len(object_coll) - rg_written = 0 self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}') @@ -886,24 +884,23 @@ def create_pointsource_catalog(self): ------- None """ - arrow_schema = make_pointsource_schema() + arrow_schema = make_star_schema() # Need a way to indicate which object types to include; deal with that # later. For now, default is stars + sn for p in self._parts: self._logger.debug(f'Point sources. Starting on pixel {p}') self.create_pointsource_pixel(p, arrow_schema, - star_cat=self._star_truth, - sn_cat=self._sn_truth) + star_cat=self._star_truth) self._logger.debug(f'Completed pixel {p}') - def create_pointsource_pixel(self, pixel, arrow_schema, star_cat=None, - sn_cat=None): - if not star_cat and not sn_cat: - self._logger.info('No point source inputs specified') + def create_pointsource_pixel(self, pixel, arrow_schema, star_cat=None): + if not star_cat: + self._logger.info('No star input specified') return output_filename = f'pointsource_{pixel}.parquet' output_path = os.path.join(self._output_dir, output_filename) + stride = self._star_stride if os.path.exists(output_path): if not self._skip_done: @@ -914,77 +911,53 @@ def create_pointsource_pixel(self, pixel, arrow_schema, star_cat=None, writer = pq.ParquetWriter(output_path, arrow_schema) - if star_cat: - # Get data for this pixel - if self._star_input_fmt == 'sqlite': - cols = ','.join(['format("%s",simobjid) as id', 'ra', - 'decl as dec', 'magNorm as magnorm', 'mura', - 'mudecl as mudec', - 'radialVelocity as radial_velocity', - 'parallax', - 'sedFilename as sed_filepath', 'ebv']) - q = f'select {cols} from stars where hpid={pixel} ' - with sqlite3.connect(star_cat) as conn: - star_df = pd.read_sql_query(q, conn) - elif self._star_input_fmt == 'parquet': - star_df = _star_parquet_reader(self._star_truth, pixel, - arrow_schema) - nobj = len(star_df['id']) - self._logger.debug(f'Found {nobj} stars') - star_df['sed_filepath'] = get_star_sed_path(star_df['sed_filepath']) - star_df['object_type'] = np.full((nobj,), 'star') - star_df['host_galaxy_id'] = np.zeros((nobj,), np.int64()) - - star_df['MW_rv'] = np.full((nobj,), _MW_rv_constant, np.float32()) - - # NOTE MW_av calculation for stars does not use SFD dust map - star_df['MW_av'] = star_df['ebv'] * _MW_rv_constant - - star_df['variability_model'] = np.full((nobj,), '') - star_df['salt2_params'] = np.full((nobj,), None) - out_table = pa.Table.from_pandas(star_df, schema=arrow_schema) + # Get data for this pixel + if self._star_input_fmt == 'sqlite': + cols = ','.join(['format("%s",simobjid) as id', 'ra', + 'decl as dec', 'magNorm as magnorm', 'mura', + 'mudecl as mudec', + 'radialVelocity as radial_velocity', + 'parallax', + 'sedFilename as sed_filepath', 'ebv']) + q = f'select {cols} from stars where hpid={pixel} ' + with sqlite3.connect(star_cat) as conn: + star_df = pd.read_sql_query(q, conn) + elif self._star_input_fmt == 'parquet': + star_df = _star_parquet_reader(self._star_truth, pixel, + arrow_schema) + nobj = len(star_df['id']) + self._logger.debug(f'Found {nobj} stars') + if nobj == 0: + return + star_df['sed_filepath'] = get_star_sed_path(star_df['sed_filepath']) + star_df['object_type'] = np.full((nobj,), 'star') + star_df['host_galaxy_id'] = np.zeros((nobj,), np.int64()) + + star_df['MW_rv'] = np.full((nobj,), _MW_rv_constant, np.float32()) + + # NOTE MW_av calculation for stars does not use SFD dust map + star_df['MW_av'] = star_df['ebv'] * _MW_rv_constant + + star_df['variability_model'] = np.full((nobj,), '') + star_df['salt2_params'] = np.full((nobj,), None) + + last_row_ix = nobj - 1 + u_bnd = min(stride, nobj) + l_bnd = 0 + rg_written = 0 + + while u_bnd > l_bnd: + out_dict = {k: star_df[k][l_bnd: u_bnd] for k in star_df.columns} + out_df = pd.DataFrame.from_dict(out_dict) + + out_table = pa.Table.from_pandas(out_df, schema=arrow_schema) self._logger.debug('Created arrow table from star dataframe') + # write a row broup writer.write_table(out_table) - - if sn_cat: - # Get data for this pixel - cols = ','.join(['snid_in as id', 'snra_in as ra', - 'sndec_in as dec', 'galaxy_id as host_galaxy_id']) - - params = ','.join(['z_in as z', 't0_in as t0, x0_in as x0', - 'x1_in as x1', 'c_in as c']) - - q1 = f'select {cols} from sne_params where hpid={pixel} ' - q2 = f'select {params} from sne_params where hpid={pixel} ' - with sqlite3.connect(sn_cat) as conn: - sn_df = pd.read_sql_query(q1, conn) - params_df = pd.read_sql_query(q2, conn) - - nobj = len(sn_df['ra']) - if nobj > 0: - sn_df['object_type'] = np.full((nobj,), self._sn_object_type) - - sn_df['MW_rv'] = make_MW_extinction_rv(sn_df['ra'], - sn_df['dec']) - sn_df['MW_av'] = make_MW_extinction_av(sn_df['ra'], - sn_df['dec']) - - # Add fillers for columns not relevant for sn - sn_df['sed_filepath'] = np.full((nobj), '') - sn_df['magnorm'] = np.full((nobj,), None) - sn_df['mura'] = np.full((nobj,), None) - sn_df['mudec'] = np.full((nobj,), None) - sn_df['radial_velocity'] = np.full((nobj,), None) - sn_df['parallax'] = np.full((nobj,), None) - sn_df['variability_model'] = np.full((nobj,), 'salt2_extended') - - # Form array of struct from params_df - sn_df['salt2_params'] = params_df.to_records(index=False) - out_table = pa.Table.from_pandas(sn_df, schema=arrow_schema) - self._logger.debug('Created arrow table from sn dataframe') - - writer.write_table(out_table) + rg_written += 1 + l_bnd = u_bnd + u_bnd = min(l_bnd + stride, last_row_ix + 1) writer.close() if self._provenance == 'yaml': @@ -1053,77 +1026,77 @@ def _create_pointsource_flux_pixel(self, pixel): else: self._logger.info(f'Skipping regeneration of {output_path}') return - - # NOTE: For now there is only one collection in the object list - # because stars are in a single row group - object_list = self._cat.get_object_type_by_hp(pixel, 'star') - _star_collection = object_list.get_collections()[0] - - l_bnd = 0 - u_bnd = len(_star_collection) n_parallel = self._flux_parallel - if n_parallel == 1: - n_per = u_bnd - l_bnd - else: - n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel) - fields_needed = self._ps_flux_schema.names - instrument_needed = ['lsst'] # for now - + object_list = self._cat.get_object_type_by_hp(pixel, 'star') writer = None + instrument_needed = ['lsst'] # for now rg_written = 0 + fields_needed = self._ps_flux_schema.names - lb = l_bnd - u = min(l_bnd + n_per, u_bnd) - readers = [] + for i in range(object_list.collection_count): + _star_collection = object_list.get_collections()[i] + + l_bnd = 0 + u_bnd = len(_star_collection) + out_dict = {} - if n_parallel == 1: - out_dict = _do_star_flux_chunk(None, _star_collection, - instrument_needed, lb, u) - else: - # Expect to be able to do about 1500/minute/process out_dict = {} for field in fields_needed: out_dict[field] = [] - tm = max(int((n_per*60)/500), 5) # Give ourselves a cushion - self._logger.info(f'Using timeout value {tm} for {n_per} sources') - p_list = [] - for i in range(n_parallel): - conn_rd, conn_wrt = Pipe(duplex=False) - readers.append(conn_rd) - - # For debugging call directly - proc = Process(target=_do_star_flux_chunk, - name=f'proc_{i}', - args=(conn_wrt, _star_collection, - instrument_needed, lb, u)) - proc.start() - p_list.append(proc) - lb = u - u = min(lb + n_per, u_bnd) - - self._logger.debug('Processes started') - for i in range(n_parallel): - ready = readers[i].poll(tm) - if not ready: - self._logger.error(f'Process {i} timed out after {tm} sec') - sys.exit(1) - dat = readers[i].recv() - for field in fields_needed: - out_dict[field] += dat[field] - for p in p_list: - p.join() - - out_df = pd.DataFrame.from_dict(out_dict) - out_table = pa.Table.from_pandas(out_df, - schema=self._ps_flux_schema) - - if not writer: - writer = pq.ParquetWriter(output_path, self._ps_flux_schema) - writer.write_table(out_table) - - rg_written += 1 + if n_parallel == 1: + n_per = u_bnd - l_bnd + else: + n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel) + + lb = l_bnd + u = min(l_bnd + n_per, u_bnd) + readers = [] + + if n_parallel == 1: + out_dict = _do_star_flux_chunk(None, _star_collection, + instrument_needed, lb, u) + else: + # Expect to be able to do about 1500/minute/process + + tm = max(int((n_per*60)/500), 5) # Give ourselves a cushion + self._logger.info(f'Using timeout value {tm} for {n_per} sources') + p_list = [] + for i in range(n_parallel): + conn_rd, conn_wrt = Pipe(duplex=False) + readers.append(conn_rd) + + # For debugging call directly + proc = Process(target=_do_star_flux_chunk, + name=f'proc_{i}', + args=(conn_wrt, _star_collection, + instrument_needed, lb, u)) + proc.start() + p_list.append(proc) + lb = u + u = min(lb + n_per, u_bnd) + + self._logger.debug('Processes started') + for i in range(n_parallel): + ready = readers[i].poll(tm) + if not ready: + self._logger.error(f'Process {i} timed out after {tm} sec') + sys.exit(1) + dat = readers[i].recv() + for field in fields_needed: + out_dict[field] += dat[field] + for p in p_list: + p.join() + + out_df = pd.DataFrame.from_dict(out_dict) + out_table = pa.Table.from_pandas(out_df, + schema=self._ps_flux_schema) + + if not writer: + writer = pq.ParquetWriter(output_path, self._ps_flux_schema) + writer.write_table(out_table) + rg_written += 1 writer.close() self._logger.debug(f'# row groups written to flux file: {rg_written}') @@ -1174,8 +1147,6 @@ def write_config(self, overwrite=False, path_only=False): config.add_key('knots_magnitude_cut', self._knots_mag_cut) inputs = {'galaxy_truth': self._galaxy_truth} - if self._sn_truth: - inputs['sn_truth'] = self._sn_truth if self._star_truth: inputs['star_truth'] = self._star_truth if self._sso_truth: diff --git a/skycatalogs/data/ci_sample/skyCatalog.yaml b/skycatalogs/data/ci_sample/skyCatalog.yaml index 79a81e66..877a3b32 100644 --- a/skycatalogs/data/ci_sample/skyCatalog.yaml +++ b/skycatalogs/data/ci_sample/skyCatalog.yaml @@ -129,7 +129,6 @@ object_types: provenance: inputs: galaxy_truth: cosmodc2_v1.1.4_image_addon_knots - sn_truth: /global/cfs/cdirs/lsst/groups/SSim/DC2/cosmoDC2_v1.1.4/sne_cosmoDC2_v1.1.4_MS_DDF_healpix.db star_truth: /global/cfs/cdirs/lsst/groups/SSim/DC2/dc2_stellar_healpixel.db skyCatalogs_repo: git_branch: master diff --git a/skycatalogs/data/ci_sample/skyCatalog_top.yaml b/skycatalogs/data/ci_sample/skyCatalog_top.yaml index 69de030e..b2c77c6f 100644 --- a/skycatalogs/data/ci_sample/skyCatalog_top.yaml +++ b/skycatalogs/data/ci_sample/skyCatalog_top.yaml @@ -94,7 +94,6 @@ object_types: provenance: inputs: galaxy_truth: cosmodc2_v1.1.4_image_addon_knots - sn_truth: /global/cfs/cdirs/lsst/groups/SSim/DC2/cosmoDC2_v1.1.4/sne_cosmoDC2_v1.1.4_MS_DDF_healpix.db star_truth: /global/cfs/cdirs/lsst/groups/SSim/DC2/dc2_stellar_healpixel.db skyCatalogs_repo: git_branch: master diff --git a/skycatalogs/objects/sncosmo_object.py b/skycatalogs/objects/sncosmo_object.py deleted file mode 100644 index 66c7ac94..00000000 --- a/skycatalogs/objects/sncosmo_object.py +++ /dev/null @@ -1,57 +0,0 @@ -import galsim -from skycatalogs.utils.sn_tools import SncosmoModel -from .base_object import BaseObject,ObjectCollection - - -__all__ = ['SncosmoObject'] - - -class SncosmoObject(BaseObject): - _type_name = 'sncosmo' - - def _get_sed(self, mjd=None): - params = self.get_native_attribute('salt2_params') - sn = SncosmoModel(params=params) - - if mjd < sn.mintime() or mjd > sn.maxtime(): - return None - return sn.get_sed(mjd) - - def get_gsobject_components(self, gsparams=None, rng=None): - if gsparams is not None: - gsparams = galsim.GSParams(**gsparams) - return {'this_object': galsim.DeltaFunction(gsparams=gsparams)} - - def get_observer_sed_component(self, component, mjd=None): - if mjd is None: - mjd = self._belongs_to._mjd - if mjd is None: - txt = 'SncosmoObject._get_sed: no mjd specified for this call\n' - txt += 'nor when generating object list' - raise ValueError(txt) - sed = self._get_sed(mjd=mjd) - if sed is not None: - sed = self._apply_component_extinction(sed) - return sed - - def get_LSST_flux(self, band, sed=None, cache=False, mjd=None): - # There is usually no reason to cache flux for SNe, in fact it could - # cause problems. If flux has been cached and then this routine - # is called again with a different value of mjd, it would - # return the wrong answer. - return super().get_LSST_flux(band, sed=sed, cache=cache, mjd=mjd) - - -class SncosmoCollection(ObjectCollection): - ''' - This class only exists in order to issue a warning if mjd is None - ''' - def __init__(self, ra, dec, id, object_type, partition_id, sky_catalog, - region=None, mjd=None, mask=None, readers=None, row_group=0): - # Normally mjd should be specified - if mjd is None: - sky_catalog._logger.warning('Creating SncosmoCollection with no mjd value.') - sky_catalog._logger.warning('Transient collections normally have non-None mjd') - super().__init__(ra, dec, id, object_type, partition_id, - sky_catalog, region=region, mjd=mjd, mask=mask, - readers=readers, row_group=row_group) diff --git a/skycatalogs/skyCatalogs.py b/skycatalogs/skyCatalogs.py index 37b51af3..5a80c20e 100644 --- a/skycatalogs/skyCatalogs.py +++ b/skycatalogs/skyCatalogs.py @@ -12,7 +12,6 @@ from skycatalogs.objects.gaia_object import GaiaObject, GaiaCollection from skycatalogs.objects.sso_object import SsoObject, SsoCollection from skycatalogs.objects.sso_object import EXPOSURE_DEFAULT -# from skycatalogs.objects.sso_object import find_sso_files from skycatalogs.readers import ParquetReader from skycatalogs.utils.sed_tools import TophatSedFactory, DiffskySedFactory from skycatalogs.utils.sed_tools import SsoSedFactory @@ -20,7 +19,6 @@ from skycatalogs.utils.config_utils import Config from skycatalogs.utils.shapes import Box, Disk, PolygonalRegion from skycatalogs.utils.shapes import compute_region_mask -from skycatalogs.objects.sncosmo_object import SncosmoObject, SncosmoCollection from skycatalogs.objects.star_object import StarObject from skycatalogs.objects.galaxy_object import GalaxyObject from skycatalogs.objects.diffsky_object import DiffskyObject @@ -347,11 +345,6 @@ def __init__(self, config, mp=False, skycatalog_root=None, verbose=False, object_class=GaiaObject, collection_class=GaiaCollection, custom_load=True) - if 'sncosmo' in config['object_types']: - self.cat_cxt.register_source_type( - 'sncosmo', - object_class=SncosmoObject, - collection_class=SncosmoCollection) if 'star' in config['object_types']: self.cat_cxt.register_source_type('star', object_class=StarObject) @@ -660,7 +653,7 @@ def get_object_type_by_hp(self, hp, object_type, region=None, mjd=None, elif object_type in ['snana']: columns = ['id', 'ra', 'dec', 'start_mjd', 'end_mjd'] id_name = 'id' - elif object_type in ['star', 'sncosmo']: + elif object_type in ['star']: columns = ['object_type', 'id', 'ra', 'dec'] id_name = 'id' elif object_type in ['sso']: @@ -810,4 +803,4 @@ def open_catalog(config_file, mp=False, skycatalog_root=None, verbose=False): config_dict = open_config_file(config_file) return SkyCatalog(config_dict, skycatalog_root=skycatalog_root, mp=mp, - verbose=verbose) \ No newline at end of file + verbose=verbose) diff --git a/skycatalogs/utils/__init__.py b/skycatalogs/utils/__init__.py index a284dbdd..99ce4c5b 100644 --- a/skycatalogs/utils/__init__.py +++ b/skycatalogs/utils/__init__.py @@ -4,6 +4,5 @@ from .exceptions import * from .parquet_schema_utils import * from .sed_tools import * -from .sn_tools import * from .shapes import * from .creator_utils import * diff --git a/skycatalogs/utils/parquet_schema_utils.py b/skycatalogs/utils/parquet_schema_utils.py index 178e4483..54b45bd2 100644 --- a/skycatalogs/utils/parquet_schema_utils.py +++ b/skycatalogs/utils/parquet_schema_utils.py @@ -2,7 +2,7 @@ import logging __all__ = ['make_galaxy_schema', 'make_galaxy_flux_schema', - 'make_pointsource_schema', 'make_star_flux_schema'] + 'make_star_schema', 'make_star_flux_schema'] def _add_roman_fluxes(fields): @@ -140,19 +140,10 @@ def make_star_flux_schema(logname, include_roman_flux=False): return pa.schema(fields) -def make_pointsource_schema(): +def make_star_schema(): ''' - Ultimately should handle stars both static and variable, SN, and AGN - For now add everything needed for SN and put in some additional - star fields, but not structs for star variability models + Just for "regular" stars. ''' - - salt2_fields = [ - pa.field('z', pa.float64(), True), - pa.field('t0', pa.float64(), True), - pa.field('x0', pa.float64(), True), - pa.field('x1', pa.float64(), True), - pa.field('c', pa.float64(), True)] fields = [pa.field('object_type', pa.string(), False), pa.field('id', pa.string(), False), pa.field('ra', pa.float64(), False), @@ -167,28 +158,5 @@ def make_pointsource_schema(): pa.field('radial_velocity', pa.float64(), True), pa.field('parallax', pa.float64(), True), pa.field('variability_model', pa.string(), True), - pa.field('salt2_params', pa.struct(salt2_fields), True) ] return pa.schema(fields) - - -def make_pointsource_flux_schema(logname, include_roman_flux=False): - ''' - Will make a separate parquet file with lsst flux for each band - and id for joining with the main star file. - For static sources mjd field could be -1. Or the field could be - made nullable. - ''' - logger = logging.getLogger(logname) - logger.debug('Creating pointsource flux schema') - fields = [pa.field('id', pa.string()), - pa.field('lsst_flux_u', pa.float32(), True), - pa.field('lsst_flux_g', pa.float32(), True), - pa.field('lsst_flux_r', pa.float32(), True), - pa.field('lsst_flux_i', pa.float32(), True), - pa.field('lsst_flux_z', pa.float32(), True), - pa.field('lsst_flux_y', pa.float32(), True), - pa.field('mjd', pa.float64(), True)] - if include_roman_flux: - fields = _add_roman_fluxes(fields) - return pa.schema(fields) diff --git a/skycatalogs/utils/sn_tools.py b/skycatalogs/utils/sn_tools.py deleted file mode 100644 index 842fb3ce..00000000 --- a/skycatalogs/utils/sn_tools.py +++ /dev/null @@ -1,43 +0,0 @@ -import numpy as np -from astropy import units as u -import sncosmo -import galsim - -__all__ = ['SncosmoModel'] - - -class SncosmoModel(sncosmo.Model): - def __init__(self, source='salt2-extended', params=None): - ''' - params - dict of params suitable for the model - - See also https://sncosmo.readthedocs.io/en/stable/index.html - - ''' - # The following explicitly turns off host and Milky Way extinction. - dust = sncosmo.F99Dust() - super().__init__(source=source, - effects=[dust, dust], - effect_names=['host', 'mw'], - effect_frames=['rest', 'obs']) - self.set(mwebv=0., hostebv=0.) - self.redshift = 0 - if params: - self.set(**params) - self.redshift = params['z'] - - def get_sed(self, mjd, npts=1000): - """ - Return the SED in the observer frame at the requested time. - """ - wl = np.linspace(self.minwave(), self.maxwave(), npts) - - # prepend 0 bins - n_bins = int(self.minwave()) - 1 - pre_wl = [float(i) for i in range(n_bins)] - pre_val = [0.0 for i in range(n_bins)] - flambda = self.flux(mjd, wl) - wl = np.insert(wl, 0, pre_wl) - flambda = np.insert(flambda, 0, pre_val) - lut = galsim.LookupTable(wl, flambda, interpolant='linear') - return galsim.SED(lut, wave_type=u.Angstrom, flux_type='flambda') diff --git a/tests/test_gaia_direct.py b/tests/test_gaia_direct.py index 056b8afa..9e4f6f78 100644 --- a/tests/test_gaia_direct.py +++ b/tests/test_gaia_direct.py @@ -76,9 +76,9 @@ def test_proper_motion(self): gaia_id = obj.id.split('_')[-1] df = self.df.query(f"id=={gaia_id}") row = df.iloc[0] - self.assertAlmostEqual(np.degrees(row.coord_ra), obj.ra, places=15) + self.assertAlmostEqual(np.degrees(row.coord_ra), obj.ra, places=14) self.assertAlmostEqual(np.degrees(row.coord_dec), obj.dec, - places=15) + places=14) self.assertEqual(mjd0, object_list.mjd) diff --git a/tests/test_gaia_objects.py b/tests/test_gaia_objects.py index 813812bb..920f693a 100644 --- a/tests/test_gaia_objects.py +++ b/tests/test_gaia_objects.py @@ -65,9 +65,9 @@ def test_proper_motion(self): gaia_id = obj.id.split('_')[-1] df = self.df.query(f"id=={gaia_id}") row = df.iloc[0] - self.assertAlmostEqual(np.degrees(row.coord_ra), obj.ra, places=15) + self.assertAlmostEqual(np.degrees(row.coord_ra), obj.ra, places=14) self.assertAlmostEqual(np.degrees(row.coord_dec), obj.dec, - places=15) + places=14) self.assertEqual(mjd0, object_list.mjd) diff --git a/tests/test_pointsource.py b/tests/test_pointsource.py index ed7e8605..0f03f17a 100644 --- a/tests/test_pointsource.py +++ b/tests/test_pointsource.py @@ -1,110 +1,9 @@ import os from pathlib import Path -import numpy as np -import matplotlib.pyplot as plt -import sncosmo - - -from skycatalogs.skyCatalogs import SkyCatalog, open_catalog -from skycatalogs.objects.base_object import BaseObject, load_lsst_bandpasses -from skycatalogs.utils.sn_tools import SncosmoModel +from skycatalogs.skyCatalogs import open_catalog PIXEL = 9556 -def explore_lc(obj): - if obj.object_type != 'sn': - print('No light curve for object of type ', obj.object_type) - return - - for native in ['id', 'ra', 'dec', 'salt2_params']: - print(native,'=',obj.get_native_attribute(native)) - - # get fluxes for some days around t0 - params = obj.get_native_attribute('salt2_params') - t0 = params['t0'] - - - fluxes_mjd = dict() - dt_rng = '-5_30_3' - #for dt in np.arange(0, 40, 2): - fluxes_mjd[-50] = obj.get_LSST_fluxes(cache=False, mjd=(t0 - 50)) - for dt in np.arange(-5, 30, 3): - fluxes_mjd[dt] = obj.get_LSST_fluxes(cache=False, mjd=(t0 + dt)) - - # Also add far-out points which could be outside time interval sn - # is active - - fluxes_mjd[50] = obj.get_LSST_fluxes(cache=False, mjd=(t0 + 50)) - - plt.figure() - by_band = dict() - for b in 'ugrizy': - print('fluxes for band ', b) - band_fluxes = [] - for dt in fluxes_mjd.keys(): - print(f'dt: {dt} flux: {fluxes_mjd[dt][b]}') - band_fluxes.append(fluxes_mjd[dt][b]) - # Also plot flux vs dt - plt.plot(list(fluxes_mjd.keys()), band_fluxes, label=f'{b}') - plt.legend(fontsize='x-small', ncol=2) - plt.title(f'{obj.id}') - plt.xlabel('dt (days)') - plt.ylabel('flux') - # _nm suffix probably means "no magnorm" - plt.savefig(f'{obj.id}_dt_fluxes_nm.png') - - plt.close() - - # also plot SEDs - sn_obj = SncosmoModel(params=params) - - - plt.figure() - for dt in np.arange(-5, 30, 3): - sed = sn_obj.get_sed(t0 + dt) - plt.plot(sed.wave_list, sed(sed.wave_list), label=f'{dt}') - plt.yscale('log') - ##plt.ylim(3e-8, 2e-3) - plt.legend(fontsize='x-small', ncol=2) - plt.xlabel('wavelength (nm)') - plt.ylabel('photons/nm/cm^2/s') - plt.savefig(f'{obj.id}_seds.png') - plt.close() - - -def make_sncosmo_lc(obj): - params = obj.get_native_attribute('salt2_params') - - # Set up the sncosmo source - src = sncosmo.Model(source='salt2-extended') - src.set(**params) - - bandpasses = load_lsst_bandpasses() - sncosmo_bandpasses = [] - for nm,val in bandpasses.items(): - snc_bp = sncosmo.Bandpass(val.wave_list, [val(wv) for wv in val.wave_list], name=nm, wave_unit='nm') - sncosmo_bandpasses.append(snc_bp) - - dt_start = -5 - dt_end = 30 - dt_incr = 3 - t0 = params['t0'] - - dt_rng = '-5_30_3' - plt.figure() - #for b_name, b in bandpasses.items(): - for b in sncosmo_bandpasses: - fluxes = [] - rel_times = [] - for dt in np.arange(dt_start, dt_end, dt_incr): - fluxes.append(src.bandflux(b, t0 + dt)) - rel_times.append(dt) - plt.plot(rel_times, fluxes, label=f'{b.name}') - plt.legend(fontsize='x-small', ncol=2) - plt.xlabel('dt (days)') - plt.ylabel('flux') - plt.savefig(f'sn_cosmo_{obj.id}_dt{dt_rng}_fluxes.png') - def explore(cat, obj_type, ix_list=[0]): obj_list = cat.get_object_type_by_hp(PIXEL, obj_type) @@ -126,49 +25,16 @@ def explore(cat, obj_type, ix_list=[0]): ' belongs_index=', obj0._belongs_index) all_cols = obj0.native_columns - if obj0.object_type == 'sn': - extras = {'lsst_flux_u', 'lsst_flux_g', 'lsst_flux_r', - 'lsst_flux_i', 'lsst_flux_z', 'lsst_flux_y', 'mjd'} - all_cols.difference_update(extras) - for native in all_cols: # obj0.native_columns: - print(native,'=',obj0.get_native_attribute(native)) + for native in all_cols: + print(native, '=', obj0.get_native_attribute(native)) icoll += 1 - if obj0.object_type == 'sn': - for ix in ix_list: - print('sn object index is ', ix) - explore_lc(c[ix]) - ##make_sncosmo_lc(c[ix]) - # # get fluxes for some days around t0 - # params = obj0.get_native_attribute('salt2_params') - # t0 = params['t0'] - # fluxes = dict() - # for dt in np.arange(-5, 30, 3): - # fluxes[dt] = obj0.get_LSST_fluxes(cache=False, mjd=(t0 + dt)) - - # for b in 'ugrizy': - # print('fluxes for band ', b) - # for dt in fluxes.keys(): - # print(f'dt: {dt} flux: {fluxes[dt][b]}') skycatalog_root = os.path.join(Path(__file__).resolve().parents[1], 'skycatalogs', 'data') config_path = os.path.join(skycatalog_root, 'ci_sample', 'skyCatalog.yaml') -#skycatalog_root = os.getenv('CFS_SKY_ROOT') -#config_path = os.path.join(skycatalog_root, 'point_test', 'skyCatalog.yaml') - cat = open_catalog(config_path, skycatalog_root=skycatalog_root) -#print('Explore star collection') -#explore(cat, {'star'}) - -print('explore sn collection') -###explore(cat, {'sn'}, ix_list = [0, 3, 100]) -## have tried ix 3, 7, 100. 105, 10 -## For ix 100 there are no visible seds. For 105 almost none -##explore(cat, {'sn'}, ix_list = [3,7,10,100, 105]) -##explore(cat, {'sn'}, ix_list = [202]) -explore(cat, 'star', ix_list = [10]) -#print('explore both sn and star') -#explore(cat, {'sn', 'star'}) +print('Explore star collection') +explore(cat, 'star', ix_list=[10])