Merge pull request #85 from LSSTDESC/u/jrbogart/parallel_star
allow parallelization for star flux computation
JoanneBogart authored May 16, 2024
2 parents ed55148 + d0f7962 commit e6b60f3
Showing 2 changed files with 112 additions and 27 deletions.
131 changes: 105 additions & 26 deletions skycatalogs/catalog_creator.py
@@ -169,16 +169,54 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, instrument_needed,
     if 'lsst' in instrument_needed:
         all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list]
         all_fluxes_transpose = zip(*all_fluxes)
-        for i, band in enumerate(LSST_BANDS):
-            v = all_fluxes_transpose.__next__()
-            out_dict[f'lsst_flux_{band}'] = v
+        colnames = [f'lsst_flux_{band}' for band in LSST_BANDS]
+        flux_dict = dict(zip(colnames, all_fluxes_transpose))
+        out_dict.update(flux_dict)

     if 'roman' in instrument_needed:
         all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list]
         all_fluxes_transpose = zip(*all_fluxes)
-        for i, band in enumerate(ROMAN_BANDS):
-            v = all_fluxes_transpose.__next__()
-            out_dict[f'roman_flux_{band}'] = v
+        colnames = [f'roman_flux_{band}' for band in ROMAN_BANDS]
+        flux_dict = dict(zip(colnames, all_fluxes_transpose))
+        out_dict.update(flux_dict)

     if send_conn:
         send_conn.send(out_dict)
     else:
         return out_dict


+def _do_star_flux_chunk(send_conn, star_collection, instrument_needed,
+                        l_bnd, u_bnd):
+    '''
+    send_conn          output connection, used to send results to
+                       parent process
+    star_collection    ObjectCollection. Information from main skyCatalogs
+                       star file
+    instrument_needed  List of which calculations should be done. Currently
+                       supported instrument names are 'lsst' and 'roman'
+    l_bnd, u_bnd       demarcate the slice to process
+
+    returns
+        dict with keys id, lsst_flux_u, ... lsst_flux_y
+        (and roman_flux_* keys if 'roman' is requested)
+    '''
+    out_dict = {}
+
+    o_list = star_collection[l_bnd: u_bnd]
+    out_dict['id'] = [o.get_native_attribute('id') for o in o_list]
+    if 'lsst' in instrument_needed:
+        all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list]
+        all_fluxes_transpose = zip(*all_fluxes)
+        colnames = [f'lsst_flux_{band}' for band in LSST_BANDS]
+        flux_dict = dict(zip(colnames, all_fluxes_transpose))
+        out_dict.update(flux_dict)
+
+    if 'roman' in instrument_needed:
+        all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list]
+        all_fluxes_transpose = zip(*all_fluxes)
+        colnames = [f'roman_flux_{band}' for band in ROMAN_BANDS]
+        flux_dict = dict(zip(colnames, all_fluxes_transpose))
+        out_dict.update(flux_dict)
+
+    if send_conn:
+        send_conn.send(out_dict)
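When `send_conn` is None the helper runs synchronously and hands its dict back as a return value; that is how the `n_parallel == 1` branch of `_create_pointsource_flux_pixel` below uses it. A minimal usage sketch, assuming `cat` is an opened skyCatalogs catalog; the pixel number and slice bounds are illustrative:

    # Serial use of the chunk worker: no Pipe, result is the return value.
    object_list = cat.get_object_type_by_hp(9556, 'star')
    star_collection = object_list.get_collections()[0]
    out = _do_star_flux_chunk(None, star_collection, ['lsst'], 0, 1000)
    # out['id']          -> list of 1000 ids
    # out['lsst_flux_u'] -> matching tuple of u-band fluxes; likewise g, r, i, z, y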
@@ -753,8 +791,6 @@ def _create_galaxy_flux_pixel(self, pixel):
         self._sed_gen.generate_pixel(pixel)

         writer = None
-        global _galaxy_collection
-        global _instrument_needed
         _instrument_needed = []
         for field in self._gal_flux_needed:
             if 'lsst' in field and 'lsst' not in _instrument_needed:
@@ -1018,36 +1054,79 @@ def _create_pointsource_flux_pixel(self, pixel):
             self._logger.info(f'Skipping regeneration of {output_path}')
             return

         # NOTE: For now there is only one collection in the object list
         #       because stars are in a single row group
         object_list = self._cat.get_object_type_by_hp(pixel, 'star')
-        last_row_ix = len(object_list) - 1
-        writer = None
+        _star_collection = object_list.get_collections()[0]

         # Write out as a single rowgroup as was done for main catalog
         l_bnd = 0
-        u_bnd = last_row_ix + 1
+        u_bnd = len(_star_collection)
+        n_parallel = self._flux_parallel

+        if n_parallel == 1:
+            n_per = u_bnd - l_bnd
+        else:
+            n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel)
+        fields_needed = self._ps_flux_schema.names
+        instrument_needed = ['lsst']     # for now

+        writer = None
+        rg_written = 0

+        lb = l_bnd
+        u = min(l_bnd + n_per, u_bnd)
+        readers = []

+        if n_parallel == 1:
+            out_dict = _do_star_flux_chunk(None, _star_collection,
+                                           instrument_needed, lb, u)
+        else:
+            # Expect to be able to do about 1500/minute/process
+            out_dict = {}
+            for field in fields_needed:
+                out_dict[field] = []

+            tm = max(int((n_per*60)/500), 5)  # Give ourselves a cushion
+            self._logger.info(f'Using timeout value {tm} for {n_per} sources')
+            p_list = []
+            for i in range(n_parallel):
+                conn_rd, conn_wrt = Pipe(duplex=False)
+                readers.append(conn_rd)
+
+                # For debugging call directly
+                proc = Process(target=_do_star_flux_chunk,
+                               name=f'proc_{i}',
+                               args=(conn_wrt, _star_collection,
+                                     instrument_needed, lb, u))
+                proc.start()
+                p_list.append(proc)
+                lb = u
+                u = min(lb + n_per, u_bnd)

+            self._logger.debug('Processes started')
+            for i in range(n_parallel):
+                ready = readers[i].poll(tm)
+                if not ready:
+                    self._logger.error(f'Process {i} timed out after {tm} sec')
+                    sys.exit(1)
+                dat = readers[i].recv()
+                for field in fields_needed:
+                    out_dict[field] += dat[field]
+            for p in p_list:
+                p.join()

-        o_list = object_list[l_bnd: u_bnd]
-        self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}')
-        out_dict = {}
-        out_dict['id'] = [o.get_native_attribute('id') for o in o_list]
-        all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list]
-        all_fluxes_transpose = zip(*all_fluxes)
-        for i, band in enumerate(LSST_BANDS):
-            self._logger.debug(f'Band {band} is number {i}')
-            v = all_fluxes_transpose.__next__()
-            out_dict[f'lsst_flux_{band}'] = v
-            if i == 1:
-                self._logger.debug(f'Len of flux column: {len(v)}')
-                self._logger.debug(f'Type of flux column: {type(v)}')
         out_df = pd.DataFrame.from_dict(out_dict)
         out_table = pa.Table.from_pandas(out_df,
                                          schema=self._ps_flux_schema)

         if not writer:
             writer = pq.ParquetWriter(output_path, self._ps_flux_schema)
         writer.write_table(out_table)

         rg_written += 1

         writer.close()
-        # self._logger.debug(f'#row groups written to flux file: {rg_written}')
+        self._logger.debug(f'# row groups written to flux file: {rg_written}')
         if self._provenance == 'yaml':
             self.write_provenance_file(output_path)
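The dispatch-and-collect machinery added above is the stock multiprocessing Pipe/Process pattern. A self-contained sketch of the same shape, with a hypothetical stand-in worker and invented sizes (the real code instead derives its timeout from expected throughput, `tm = max(int((n_per*60)/500), 5)`):

    from multiprocessing import Pipe, Process

    def _worker(send_conn, lo, hi):
        # Stand-in for _do_star_flux_chunk: process one slice, send it back.
        send_conn.send({'id': list(range(lo, hi))})

    if __name__ == '__main__':
        u_bnd, n_parallel = 1000, 4
        n_per = int((u_bnd + n_parallel)/n_parallel)
        readers, p_list, lb = [], [], 0
        while lb < u_bnd:
            u = min(lb + n_per, u_bnd)
            conn_rd, conn_wrt = Pipe(duplex=False)
            readers.append(conn_rd)
            p = Process(target=_worker, args=(conn_wrt, lb, u))
            p.start()
            p_list.append(p)
            lb = u
        out = {'id': []}
        for rd in readers:
            if not rd.poll(60):   # timeout so one hung worker cannot stall the run
                raise RuntimeError('worker timed out')
            out['id'] += rd.recv()['id']
        for p in p_list:
            p.join()
        assert out['id'] == list(range(u_bnd))

As in the diff, results are drained in submission order, so the concatenated columns line up with the original slice order.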

8 changes: 7 additions & 1 deletion skycatalogs/utils/config_utils.py
@@ -221,7 +221,13 @@ def get_config_value(self, key_path, silent=False):
             d = d[i]
             if not isinstance(d, dict):
                 raise ValueError(f'intermediate {d} is not a dict')
-        return d[path_items[-1]]
+
+        if path_items[-1] in d:
+            return d[path_items[-1]]
+        else:
+            if silent:
+                return None
+            raise ValueError(f'Item {path_items[-1]} not found')

     def add_key(self, k, v):
         '''

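Usage sketch for the amended lookup, assuming the '/'-delimited key paths and the `Config` wrapper from config_utils; the configuration contents are invented:

    cfg = Config({'object_types': {'star': {'sed_method': 'built-in'}}})
    cfg.get_config_value('object_types/star/sed_method')
    # -> 'built-in'
    cfg.get_config_value('object_types/star/no_such_key', silent=True)
    # -> None   (silent missing-leaf lookup, new in this commit)
    cfg.get_config_value('object_types/star/no_such_key')
    # -> raises ValueError (previously this surfaced as a bare KeyError)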