diff --git a/skycatalogs/catalog_creator.py b/skycatalogs/catalog_creator.py index 3cf36788..eb1d1dc4 100644 --- a/skycatalogs/catalog_creator.py +++ b/skycatalogs/catalog_creator.py @@ -168,16 +168,54 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, instrument_needed, if 'lsst' in instrument_needed: all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] all_fluxes_transpose = zip(*all_fluxes) - for i, band in enumerate(LSST_BANDS): - v = all_fluxes_transpose.__next__() - out_dict[f'lsst_flux_{band}'] = v + colnames = [f'lsst_flux_{band}' for band in LSST_BANDS] + flux_dict = dict(zip(colnames, all_fluxes_transpose)) + out_dict.update(flux_dict) if 'roman' in instrument_needed: all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list] all_fluxes_transpose = zip(*all_fluxes) - for i, band in enumerate(ROMAN_BANDS): - v = all_fluxes_transpose.__next__() - out_dict[f'roman_flux_{band}'] = v + colnames = [f'roman_flux_{band}' for band in ROMAN_BANDS] + flux_dict = dict(zip(colnames, all_fluxes_transpose)) + out_dict.update(flux_dict) + + if send_conn: + send_conn.send(out_dict) + else: + return out_dict + + +def _do_star_flux_chunk(send_conn, star_collection, instrument_needed, + l_bnd, u_bnd): + ''' + send_conn output connection, used to send results to + parent process + star_collection ObjectCollection. Information from main skyCatalogs + star file + instrument_needed List of which calculations should be done. Currently + supported instrument names are 'lsst' and 'roman' + l_bnd, u_bnd demarcates slice to process + + returns + dict with keys id, lsst_flux_u, ... lsst_flux_y + ''' + out_dict = {} + + o_list = star_collection[l_bnd: u_bnd] + out_dict['id'] = [o.get_native_attribute('id') for o in o_list] + if 'lsst' in instrument_needed: + all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] + all_fluxes_transpose = zip(*all_fluxes) + colnames = [f'lsst_flux_{band}' for band in LSST_BANDS] + flux_dict = dict(zip(colnames, all_fluxes_transpose)) + out_dict.update(flux_dict) + + if 'roman' in instrument_needed: + all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list] + all_fluxes_transpose = zip(*all_fluxes) + colnames = [f'roman_flux_{band}' for band in ROMAN_BANDS] + flux_dict = dict(zip(colnames, all_fluxes_transpose)) + out_dict.update(flux_dict) if send_conn: send_conn.send(out_dict) @@ -735,8 +773,6 @@ def _create_galaxy_flux_pixel(self, pixel): self._sed_gen.generate_pixel(pixel) writer = None - global _galaxy_collection - global _instrument_needed _instrument_needed = [] for field in self._gal_flux_needed: if 'lsst' in field and 'lsst' not in _instrument_needed: @@ -1000,36 +1036,79 @@ def _create_pointsource_flux_pixel(self, pixel): self._logger.info(f'Skipping regeneration of {output_path}') return + # NOTE: For now there is only one collection in the object list + # because stars are in a single row group object_list = self._cat.get_object_type_by_hp(pixel, 'star') - last_row_ix = len(object_list) - 1 - writer = None + _star_collection = object_list.get_collections()[0] - # Write out as a single rowgroup as was done for main catalog l_bnd = 0 - u_bnd = last_row_ix + 1 + u_bnd = len(_star_collection) + n_parallel = self._flux_parallel + + if n_parallel == 1: + n_per = u_bnd - l_bnd + else: + n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel) + fields_needed = self._ps_flux_schema.names + instrument_needed = ['lsst'] # for now + + writer = None + rg_written = 0 + + lb = l_bnd + u = min(l_bnd + n_per, u_bnd) + readers = [] + + if n_parallel == 1: + out_dict = _do_star_flux_chunk(None, _star_collection, + instrument_needed, lb, u) + else: + # Expect to be able to do about 1500/minute/process + out_dict = {} + for field in fields_needed: + out_dict[field] = [] + + tm = max(int((n_per*60)/500), 5) # Give ourselves a cushion + self._logger.info(f'Using timeout value {tm} for {n_per} sources') + p_list = [] + for i in range(n_parallel): + conn_rd, conn_wrt = Pipe(duplex=False) + readers.append(conn_rd) + + # For debugging call directly + proc = Process(target=_do_star_flux_chunk, + name=f'proc_{i}', + args=(conn_wrt, _star_collection, + instrument_needed, lb, u)) + proc.start() + p_list.append(proc) + lb = u + u = min(lb + n_per, u_bnd) + + self._logger.debug('Processes started') + for i in range(n_parallel): + ready = readers[i].poll(tm) + if not ready: + self._logger.error(f'Process {i} timed out after {tm} sec') + sys.exit(1) + dat = readers[i].recv() + for field in fields_needed: + out_dict[field] += dat[field] + for p in p_list: + p.join() - o_list = object_list[l_bnd: u_bnd] - self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}') - out_dict = {} - out_dict['id'] = [o.get_native_attribute('id') for o in o_list] - all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list] - all_fluxes_transpose = zip(*all_fluxes) - for i, band in enumerate(LSST_BANDS): - self._logger.debug(f'Band {band} is number {i}') - v = all_fluxes_transpose.__next__() - out_dict[f'lsst_flux_{band}'] = v - if i == 1: - self._logger.debug(f'Len of flux column: {len(v)}') - self._logger.debug(f'Type of flux column: {type(v)}') out_df = pd.DataFrame.from_dict(out_dict) out_table = pa.Table.from_pandas(out_df, schema=self._ps_flux_schema) + if not writer: writer = pq.ParquetWriter(output_path, self._ps_flux_schema) writer.write_table(out_table) + rg_written += 1 + writer.close() - # self._logger.debug(f'#row groups written to flux file: {rg_written}') + self._logger.debug(f'# row groups written to flux file: {rg_written}') if self._provenance == 'yaml': self.write_provenance_file(output_path) diff --git a/skycatalogs/utils/config_utils.py b/skycatalogs/utils/config_utils.py index 00450b89..23a01228 100644 --- a/skycatalogs/utils/config_utils.py +++ b/skycatalogs/utils/config_utils.py @@ -150,7 +150,13 @@ def get_config_value(self, key_path, silent=False): d = d[i] if not isinstance(d, dict): raise ValueError(f'intermediate {d} is not a dict') - return d[path_items[-1]] + + if path_items[-1] in d: + return d[path_items[-1]] + else: + if silent: + return None + raise ValueError(f'Item {i} not found') def add_key(self, k, v): '''