allow parallelization for star flux computation #85

Merged · 4 commits · May 16, 2024

Changes from 1 commit
118 changes: 100 additions & 18 deletions skycatalogs/catalog_creator.py
@@ -185,6 +185,41 @@ def _do_galaxy_flux_chunk(send_conn, galaxy_collection, instrument_needed,
return out_dict


def _do_star_flux_chunk(send_conn, star_collection, instrument_needed,
l_bnd, u_bnd):
Comment on lines +188 to +189
Collaborator:

I think it would be a bit cleaner to simplify this interface by doing the slicing in the calling code. So the new interface would effectively become:

def _do_star_flux_chunk(send_conn, o_list, instrument_needed):

and in the calling code, it would be used like this:

_do_star_flux_chunk(send_conn, star_collection[l_bnd: u_bnd], instrument_needed)

Collaborator Author:

Yes, see comment above about global declaration.
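
For illustration, a runnable sketch of the interface the reviewer proposes, with the slicing done by the caller. The stand-in objects and attribute access here are placeholders, not the skycatalogs API:

# Sketch of the proposed interface: the caller slices, the worker
# receives the already-sliced list. Objects are stand-in dicts here.
def _do_star_flux_chunk(send_conn, o_list, instrument_needed):
    out_dict = {'id': [o['id'] for o in o_list]}
    # ... per-instrument flux computation would go here ...
    if send_conn:
        send_conn.send(out_dict)
    else:
        return out_dict

star_collection = [{'id': str(i)} for i in range(10)]  # placeholder data
l_bnd, u_bnd = 2, 5
print(_do_star_flux_chunk(None, star_collection[l_bnd:u_bnd], ['lsst']))
# -> {'id': ['2', '3', '4']}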

'''
end_conn output connection
Collaborator:

end_conn -> send_conn. Can you be more explicit about what sort of "connection" this is?

Collaborator Author:

Addressed in next commit.
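
To make the reviewer's question concrete: in the calling code later in this diff, send_conn is the write end of a multiprocessing Pipe. A minimal, self-contained illustration (the payload is a placeholder):

from multiprocessing import Pipe

conn_rd, conn_wrt = Pipe(duplex=False)        # read end, write end
conn_wrt.send({'id': ['star_1', 'star_2']})   # placeholder payload
print(conn_rd.recv())                         # {'id': ['star_1', 'star_2']}
print(type(conn_wrt).__name__)                # Connection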

star_collection information from main file
Collaborator:

This is actually an ObjectCollection, isn't it? It would be helpful to note that here.

Collaborator Author:

Addressed in next commit.

instrument_needed List of which calculations should be done
Collaborator:

Looks like the only relevant entries are 'lsst' and 'roman'. It would be good to note that here. Should there be a check somewhere that this list contains at least one of those two values?

Collaborator Author (@JoanneBogart, May 15, 2024):

I'll add something to the parameter description in the docstring.
There could be a check that instrument_needed has at least one valid value, but there is nothing like that level of checking generally on internal routines. The caller always includes 'lsst'.
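
A sketch of the kind of validation discussed here; the helper name and error behavior are hypothetical, not part of the PR:

SUPPORTED_INSTRUMENTS = {'lsst', 'roman'}   # hypothetical module constant

def _check_instrument_needed(instrument_needed):
    # Require at least one instrument this routine knows how to handle.
    if not SUPPORTED_INSTRUMENTS.intersection(instrument_needed):
        raise ValueError(
            f'instrument_needed must include one of {SUPPORTED_INSTRUMENTS}')

_check_instrument_needed(['lsst'])      # passes silently
# _check_instrument_needed(['euclid'])  # would raise ValueError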

l_bnd, u_bnd demarcates slice to process

returns
dict with keys id, lsst_flux_u, ... lsst_flux_y
'''
out_dict = {}

o_list = star_collection[l_bnd: u_bnd]
out_dict['id'] = [o.get_native_attribute('id') for o in o_list]
if 'lsst' in instrument_needed:
all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list]
all_fluxes_transpose = zip(*all_fluxes)
for i, band in enumerate(LSST_BANDS):
v = all_fluxes_transpose.__next__()
out_dict[f'lsst_flux_{band}'] = v
Collaborator:

The explicit use of .__next__() seems a bit unwieldy. I think the following would be more conventional:

colnames = [f'lsst_flux_{band}' for band in LSST_BANDS]
flux_dict = dict(zip(colnames, all_fluxes_transpose))
out_dict.update(flux_dict)

Collaborator Author:

Changed in next commit
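
A self-contained illustration of the transpose idiom under discussion; the band list and flux values are made-up placeholders:

LSST_BANDS = ['u', 'g', 'r', 'i', 'z', 'y']   # stand-in for the real constant

# One tuple of per-band fluxes per object, as get_LSST_fluxes(as_dict=False)
# is used above; the numbers here are placeholders.
all_fluxes = [(1.0, 2.0, 3.0, 4.0, 5.0, 6.0),
              (1.1, 2.1, 3.1, 4.1, 5.1, 6.1)]

# zip(*rows) transposes: one tuple per band across all objects.
all_fluxes_transpose = zip(*all_fluxes)

colnames = [f'lsst_flux_{band}' for band in LSST_BANDS]
flux_dict = dict(zip(colnames, all_fluxes_transpose))
print(flux_dict['lsst_flux_u'])   # (1.0, 1.1)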


if 'roman' in instrument_needed:
all_fluxes = [o.get_roman_fluxes(as_dict=False) for o in o_list]
all_fluxes_transpose = zip(*all_fluxes)
for i, band in enumerate(ROMAN_BANDS):
v = all_fluxes_transpose.__next__()
out_dict[f'roman_flux_{band}'] = v

if send_conn:
send_conn.send(out_dict)
else:
return out_dict


class CatalogCreator:
def __init__(self, parts, area_partition=None, skycatalog_root=None,
catalog_dir='.', galaxy_truth=None,
@@ -990,6 +1025,9 @@ def _create_pointsource_flux_pixel(self, pixel):
# For schema use self._ps_flux_schema
# output_template should be derived from value for flux_file_template
# in main catalog config. Cheat for now

global _star_collection
Collaborator:

Can you add a comment here why this needs to be global? Given that it's being passed as an argument to _do_star_flux_chunk, it wouldn't seem to need to be.

Collaborator Author:

I'm not positive it does need to be global. The star code is following the same procedure as the galaxy multiprocessing code, which was written a while ago. My intent then was to make the rather large _star_collection available to subprocesses without pickling. I don't know whether that is in fact the case.
If it's not, I would do better to just pass in the slice of _star_collection that each subprocess needs, but that is not something I want to take on for this PR.

Collaborator (@jchiang87, May 15, 2024):

It seems like you would need to omit _star_collection from the argument list and declare it global inside of _do_star_flux_chunk to avoid having it pickled. Given that it's being passed as an argument, I think it must be being serialized, and the global declaration here isn't doing anything. If so, then there's no reason to pass l_bnd and u_bnd and the slicing can be done in the calling code.

> that is not something I want to take on for this PR.

The code may work fine as-is, but these apparent inconsistencies seem like good reasons to try to understand what's going on in case there is something untoward actually happening. So if you don't want to change it, I'd suggest at least some test code to ensure it is behaving as expected.

Collaborator Author (@JoanneBogart, May 15, 2024):

I agree in principle; packaging this up in test code will take some thought. When I first implemented the identical scheme for galaxies I proceeded pretty carefully by

  • examining whatever I could in the debugger
  • comparing output from creating a flux file with number-of-subprocesses = 1 (in which case the code doesn't use subprocesses at all and just calls the _do_X_flux_chunk routine directly) and number-of-subprocesses set to something realistic, like 20. The outputs were identical.

I did the same thing when I implemented parallel processing for stars. I haven't done the comparison test for this last commit, but the code changes have nothing to do with the parallel processing part of the code.

Collaborator:

Sure, but statements like this

> My intent then was to make the rather large _star_collection available to subprocesses without pickling. I don't know whether that is in fact the case.

make it rather unclear to me, without actually having run the code (which apparently isn't that easy to do), that the unmodified implementation is correct. Regardless of whether consistent outputs are obtained with 1 or more than 1 subprocess, it would be nice to have the global vs not global aspects of the code make sense. Right now, they don't appear to me to be consistent. Given that it looks like global _star_collection isn't doing what's intended, it would be remiss of me not to point that out and suggest a fix.
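
For reference, a minimal sketch of one way a large read-only object can be shared without pickling, assuming the 'fork' start method (the long-time default on Linux; not available on Windows). Under fork the child inherits the parent's memory, so a module-level global is visible without being passed as an argument; under 'spawn', Process arguments are pickled and a bare global would not carry over. All names here are illustrative, not the PR's code:

import multiprocessing as mp

_star_collection = None   # module-level global, set by the parent below

def _worker(send_conn, l_bnd, u_bnd):
    # Under 'fork' the child inherits the parent's memory, so the global
    # is visible here without being passed (or pickled) as an argument.
    send_conn.send(_star_collection[l_bnd:u_bnd])
    send_conn.close()

if __name__ == '__main__':
    ctx = mp.get_context('fork')          # not available on Windows
    _star_collection = list(range(100))   # stand-in for the real collection
    conn_rd, conn_wrt = ctx.Pipe(duplex=False)
    p = ctx.Process(target=_worker, args=(conn_wrt, 0, 5))
    p.start()
    print(conn_rd.recv())   # [0, 1, 2, 3, 4]
    p.join()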


output_filename = f'pointsource_flux_{pixel}.parquet'
output_path = os.path.join(self._output_dir, output_filename)

@@ -1000,36 +1038,80 @@ def _create_pointsource_flux_pixel(self, pixel):
self._logger.info(f'Skipping regeneration of {output_path}')
return

# NOTE: For now there is only one collection in the object list
# because stars are in a single row group
object_list = self._cat.get_object_type_by_hp(pixel, 'star')
last_row_ix = len(object_list) - 1
writer = None
obj_coll = object_list.get_collections()[0]
Collaborator:

Why can't this simply be

_star_collection = object_list.get_collections()[0]

?

Collaborator Author:

No reason. Changed in next commit.

_star_collection = obj_coll

# Write out as a single rowgroup as was done for main catalog
l_bnd = 0
u_bnd = last_row_ix + 1
u_bnd = len(_star_collection)
n_parallel = self._flux_parallel

if n_parallel == 1:
n_per = u_bnd - l_bnd
else:
n_per = int((u_bnd - l_bnd + n_parallel)/n_parallel)
fields_needed = self._ps_flux_schema.names
instrument_needed = ['lsst'] # for now
Collaborator (@jchiang87, May 14, 2024):

It seems like instrument_needed should be an instance-level attribute, set in .__init__(...), instead of hard-wired here.

Collaborator Author:

Yes, it should ultimately come from the way the outer script was called, but currently the assumption is that lsst is always included and including roman is an option. I could see making that assumption explicit in the outer script and having it flow only from there, but I think not in this PR.
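
As an aside, a worked example of the chunk-size arithmetic a few lines above (the +n_parallel overshoot guarantees the chunks cover every source):

l_bnd, u_bnd, n_parallel = 0, 100, 8
n_per = int((u_bnd - l_bnd + n_parallel) / n_parallel)   # int(108 / 8) = 13
assert n_per * n_parallel >= u_bnd - l_bnd               # full coverage
# Chunks advance as lb, min(lb + n_per, u_bnd): 0-13, 13-26, ..., 91-100.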


writer = None
rg_written = 0

lb = l_bnd
u = min(l_bnd + n_per, u_bnd)
readers = []

if n_parallel == 1:
out_dict = _do_star_flux_chunk(None, _star_collection,
instrument_needed, lb, u)
else:
# Expect to be able to do about 1500/minute/process
out_dict = {}
for field in fields_needed:
out_dict[field] = []

tm = max(int((n_per*60)/500), 5) # Give ourselves a cushion
self._logger.info(f'Using timeout value {tm} for {n_per} sources')
p_list = []
for i in range(n_parallel):
conn_rd, conn_wrt = Pipe(duplex=False)
readers.append(conn_rd)

# For debugging call directly
proc = Process(target=_do_star_flux_chunk,
name=f'proc_{i}',
args=(conn_wrt, _star_collection,
instrument_needed, lb, u))
proc.start()
p_list.append(proc)
lb = u
u = min(lb + n_per, u_bnd)

self._logger.debug('Processes started')
for i in range(n_parallel):
ready = readers[i].poll(tm)
if not ready:
self._logger.error(f'Process {i} timed out after {tm} sec')
sys.exit(1)
dat = readers[i].recv()
for field in fields_needed:
out_dict[field] += dat[field]
for p in p_list:
p.join()

o_list = object_list[l_bnd: u_bnd]
self._logger.debug(f'Handling range {l_bnd} up to {u_bnd}')
out_dict = {}
out_dict['id'] = [o.get_native_attribute('id') for o in o_list]
all_fluxes = [o.get_LSST_fluxes(as_dict=False) for o in o_list]
all_fluxes_transpose = zip(*all_fluxes)
for i, band in enumerate(LSST_BANDS):
self._logger.debug(f'Band {band} is number {i}')
v = all_fluxes_transpose.__next__()
out_dict[f'lsst_flux_{band}'] = v
if i == 1:
self._logger.debug(f'Len of flux column: {len(v)}')
self._logger.debug(f'Type of flux column: {type(v)}')
out_df = pd.DataFrame.from_dict(out_dict)
out_table = pa.Table.from_pandas(out_df,
schema=self._ps_flux_schema)

if not writer:
writer = pq.ParquetWriter(output_path, self._ps_flux_schema)
writer.write_table(out_table)

rg_written += 1

writer.close()
# self._logger.debug(f'#row groups written to flux file: {rg_written}')
self._logger.debug(f'# row groups written to flux file: {rg_written}')
if self._provenance == 'yaml':
self.write_provenance_file(output_path)
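
Stepping back, a self-contained sketch of the pipe-based fan-out/fan-in pattern used in _create_pointsource_flux_pixel above: one Pipe per worker, poll() with a timeout, then recv() and join(). Worker logic and data are placeholders, not skycatalogs code:

from multiprocessing import Pipe, Process
import sys

def _do_chunk(send_conn, chunk):
    # Stand-in for the per-chunk flux computation.
    send_conn.send({'id': list(chunk)})

if __name__ == '__main__':
    data = list(range(40))
    n_parallel, n_per = 4, 10
    readers, procs = [], []
    for i in range(n_parallel):
        conn_rd, conn_wrt = Pipe(duplex=False)
        readers.append(conn_rd)
        chunk = data[i * n_per:(i + 1) * n_per]
        proc = Process(target=_do_chunk, name=f'proc_{i}',
                       args=(conn_wrt, chunk))
        proc.start()
        procs.append(proc)

    out = {'id': []}
    for i in range(n_parallel):
        if not readers[i].poll(30):   # timeout in seconds
            sys.exit(f'Process {i} timed out')
        out['id'] += readers[i].recv()['id']
    for p in procs:
        p.join()
    print(len(out['id']))   # 40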
