diff --git a/py/fiberassign/assign.py b/py/fiberassign/assign.py
index a3126806..34c46c74 100644
--- a/py/fiberassign/assign.py
+++ b/py/fiberassign/assign.py
@@ -20,6 +20,7 @@
 import multiprocessing as mp
 
 from multiprocessing.sharedctypes import RawArray
+from multiprocessing.pool import ThreadPool
 
 from functools import partial
 
@@ -573,7 +574,7 @@ def write_assignment_fits_tile(asgn, tagalong, fulltarget, overwrite, params):
 def write_assignment_fits(tiles, tagalong, asgn, out_dir=".", out_prefix="fba-",
                           split_dir=False, all_targets=False,
                           gfa_targets=None, overwrite=False, stucksky=None,
-                          tile_xy_cs5=None):
+                          tile_xy_cs5=None, numproc=0):
     """Write out assignment results in FITS format.
 
     For each tile, all available targets (not only the assigned targets) and
@@ -593,6 +594,7 @@ def write_assignment_fits(tiles, tagalong, asgn, out_dir=".", out_prefix="fba-",
             properties of assigned targets.
         gfa_targets (list of numpy arrays): Include these as GFA_TARGETS HDUs
         overwrite (bool): overwrite pre-existing output files
+        numproc (int): if >0, runs with numproc parallel jobs (with pool.starmap()); default: 0
 
     Returns:
         None
@@ -610,9 +612,13 @@
     tiletime = tiles.obstime
     tileha = tiles.obshourang
 
+    write_tile = partial(write_assignment_fits_tile, asgn, tagalong, all_targets, overwrite)
+    if numproc > 0:
+        all_params = []
+
     for i, tid in enumerate(tileids):
         tra = tilera[tileorder[tid]]
         tdec = tiledec[tileorder[tid]]
@@ -635,7 +641,16 @@
             txy = tile_xy_cs5.get(tid, None)
 
         params = (tid, tra, tdec, ttheta, ttime, tha, outfile, gfa, stuck, txy)
-        write_tile(params)
+
+        if numproc == 0:
+            write_tile(params)
+        else:
+            all_params.append((asgn, tagalong, all_targets, overwrite, params))
+
+    if numproc > 0:
+        pool = ThreadPool(numproc)
+        with pool:
+            _ = pool.starmap(write_assignment_fits_tile, all_params)
 
     tm.stop()
     tm.report("Write output files")
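Note on the write path above: with numproc > 0 the per-tile argument tuples are collected and handed to multiprocessing.pool.ThreadPool.starmap() instead of being written one at a time. The snippet below is a minimal, self-contained sketch of that pattern only; write_one_tile is a hypothetical placeholder standing in for write_assignment_fits_tile(), and the fake parameter tuples mirror the (asgn, tagalong, all_targets, overwrite, params) layout used in the patch.

    from multiprocessing.pool import ThreadPool

    def write_one_tile(asgn, tagalong, all_targets, overwrite, params):
        # Placeholder worker: in the patch this is write_assignment_fits_tile(),
        # which writes one per-tile FITS file from its params tuple.
        print("writing tile", params[0])

    def write_all_tiles(all_params, numproc=0):
        # numproc == 0 keeps the original serial behaviour; numproc > 0 mirrors
        # the new code path and fans the per-tile writes out over a thread pool.
        if numproc == 0:
            for args in all_params:
                write_one_tile(*args)
        else:
            with ThreadPool(numproc) as pool:
                pool.starmap(write_one_tile, all_params)

    if __name__ == "__main__":
        write_all_tiles([(None, None, True, False, (tid,)) for tid in (1000, 1001, 1002)], numproc=2)

A thread pool (rather than a process pool) keeps asgn and tagalong shared with the parent process instead of pickling them to workers; the expectation, presumably, is that the per-tile work is dominated by FITS I/O, where threads do not contend much for the GIL.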
diff --git a/py/fiberassign/scripts/assign.py b/py/fiberassign/scripts/assign.py
index 5cef659c..695b03c1 100644
--- a/py/fiberassign/scripts/assign.py
+++ b/py/fiberassign/scripts/assign.py
@@ -197,6 +197,14 @@
                         choices=["ls", "gaia"],
                         help="Source for the look-up table for sky positions for stuck fibers:"
                         " 'ls': uses $SKYBRICKS_DIR; 'gaia': uses $SKYHEALPIXS_DIR (default=ls)")
+    parser.add_argument("--write_fits_numproc", required=False, default=0,
+                        type=int,
+                        help="if >0, then runs the write_assignment_fits() in parallel with numproc jobs (default=0)")
+    parser.add_argument("--fast_match", required=False, default=False,
+                        type=bool,
+                        help="use a fast method to match TARGETID in the TargetTagalong Class;"
+                        "assumes there are no duplicates in TARGETID in the input files (default=False)")
+
     args = None
     if optlist is None:
@@ -327,7 +335,7 @@ def run_assign_init(args, plate_radec=True):
     # Create empty target list
     tgs = Targets()
     # Create structure for carrying along auxiliary target data not needed by C++.
-    tagalong = create_tagalong(plate_radec=plate_radec)
+    tagalong = create_tagalong(plate_radec=plate_radec, fast_match=args.fast_match)
     # Append each input target file. These target files must all be of the
     # same survey type, and will set the Targets object to be of that survey.
@@ -444,7 +452,8 @@ def run_assign_full(args, plate_radec=True):
                           out_prefix=args.prefix, split_dir=args.split,
                           all_targets=args.write_all_targets,
                           gfa_targets=gfa_targets, overwrite=args.overwrite,
-                          stucksky=stucksky, tile_xy_cs5=tile_xy_cs5)
+                          stucksky=stucksky, tile_xy_cs5=tile_xy_cs5,
+                          numproc=args.write_fits_numproc)
     gt.stop("run_assign_full write output")
@@ -539,7 +548,8 @@ def run_assign_bytile(args):
                           out_prefix=args.prefix, split_dir=args.split,
                           all_targets=args.write_all_targets,
                           gfa_targets=gfa_targets, overwrite=args.overwrite,
-                          stucksky=stucksky, tile_xy_cs5=tile_xy_cs5)
+                          stucksky=stucksky, tile_xy_cs5=tile_xy_cs5,
+                          numproc=args.write_fits_numproc)
     gt.stop("run_assign_bytile write output")
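For reference, the two new options can be exercised through the existing parse_assign()/run_assign_full() entry points along these lines. The target, footprint, and output paths are placeholders, and the flags other than the two added here are quoted from memory of the pre-existing parser, so treat this as a sketch rather than a tested command. Also note that --fast_match is declared with type=bool, so argparse converts its value with bool(): any non-empty string (including "False") enables it, and the only way to keep it disabled is to omit the flag.

    from fiberassign.scripts.assign import parse_assign, run_assign_full

    # Hypothetical option list; file names are placeholders.
    optlist = [
        "--targets", "mtl-targets.fits",
        "--footprint", "tiles.fits",
        "--dir", "out",
        "--write_fits_numproc", "8",   # write the per-tile FITS files with 8 parallel jobs
        "--fast_match", "True",        # enable the fast TARGETID matching in TargetTagalong
    ]
    args = parse_assign(optlist)
    run_assign_full(args)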
diff --git a/py/fiberassign/targets.py b/py/fiberassign/targets.py
index 6cb05e1e..b33be88c 100644
--- a/py/fiberassign/targets.py
+++ b/py/fiberassign/targets.py
@@ -31,6 +31,7 @@
 from desitarget.sv3.sv3_targetmask import desi_mask as sv3_mask
 
 from desitarget.targets import main_cmx_or_sv
+from desitarget.geomask import match
 
 from .utils import Logger, Timer
 from .hardware import radec2xy, cs52xy
@@ -54,7 +55,7 @@ class TargetTagalong(object):
     to propagate to the output fiberassign files, and that are not needed
     by the C++ layer.
     '''
-    def __init__(self, columns, outnames={}, aliases={}):
+    def __init__(self, columns, outnames={}, aliases={}, fast_match=False):
         '''
         Create a new tag-along object.
 
@@ -64,10 +65,14 @@ def __init__(self, columns, outnames={}, aliases={}):
             the column will be given in the output file; None to omit
            from the output file.
         *aliases*: dict, string to string: for get_for_ids(), column aliases.
+        *fast_match*: bool (default to False): use a fast method to match TARGETIDs
+            assumes there are no duplicates in TARGETIDs
+            [added in Feb. 2024]
         '''
         self.columns = columns
         self.outnames = outnames
         self.aliases = aliases
+        self.fast_match = fast_match
         # Internally, we store one tuple for each targeting file read
         # (to avoid manipulating/reformatting the arrays too much),
         # where each tuple starts with the TARGETID of the targets, followed
@@ -129,16 +134,20 @@ def set_data(self, targetids, tabledata):
                 outarr[:] = defval
             outarrs.append(outarr)
         # Build output targetid-to-index map
-        outmap = dict([(tid,i) for i,tid in enumerate(targetids)])
+        if not self.fast_match:
+            outmap = dict([(tid,i) for i,tid in enumerate(targetids)])
         # Go through my many data arrays
         for thedata in self.data:
             # TARGETIDs are the first element in the tuple
             tids = thedata[0]
             # Search for output array indices for these targetids
-            outinds = np.array([outmap.get(tid, -1) for tid in tids])
-            # Keep only the indices of targetids that were found
-            ininds = np.flatnonzero(outinds >= 0)
-            outinds = outinds[ininds]
+            if self.fast_match:
+                outinds, ininds = match(targetids, tids)
+            else:
+                outinds = np.array([outmap.get(tid, -1) for tid in tids])
+                # Keep only the indices of targetids that were found
+                ininds = np.flatnonzero(outinds >= 0)
+                outinds = outinds[ininds]
             for outarr,inarr in zip(outarrs, thedata[1:]):
                 if outarr is None:
                     continue
@@ -160,19 +169,23 @@ def get_for_ids(self, targetids, names):
             outarrs.append(np.zeros(len(targetids), dtype))
             colinds.append(ic+1)
         # Build output targetid-to-index map
-        outmap = dict([(tid,i) for i,tid in enumerate(targetids)])
+        if not self.fast_match:
+            outmap = dict([(tid,i) for i,tid in enumerate(targetids)])
         # Go through my many data arrays
         for thedata in self.data:
             tids = thedata[0]
             # Search for output array indices for these targetids
-            outinds = np.array([outmap.get(tid, -1) for tid in tids])
-            ininds = np.flatnonzero(outinds >= 0)
-            outinds = outinds[ininds]
+            if self.fast_match:
+                outinds, ininds = match(targetids, tids)
+            else:
+                outinds = np.array([outmap.get(tid, -1) for tid in tids])
+                ininds = np.flatnonzero(outinds >= 0)
+                outinds = outinds[ininds]
             for outarr,ic in zip(outarrs, colinds):
                 outarr[outinds] = thedata[ic][ininds]
         return outarrs
 
-def create_tagalong(plate_radec=True):
+def create_tagalong(plate_radec=True, fast_match=False):
     cols = [
         'TARGET_RA',
         'TARGET_DEC',
@@ -200,7 +213,7 @@
     # (OBSCOND doesn't appear in all the fiberassign output HDUs,
     # so we handle it specially)
-    return TargetTagalong(cols, outnames={'OBSCOND':None}, aliases=aliases)
+    return TargetTagalong(cols, outnames={'OBSCOND':None}, aliases=aliases, fast_match=fast_match)
 
 def str_to_target_type(input):
     if input == "science":
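The fast path above relies on desitarget.geomask.match(targetids, tids) returning two index arrays (outinds, ininds) such that targetids[outinds] == tids[ininds], with unmatched entries dropped, under the stated assumption that neither array contains duplicate TARGETIDs. The stand-in below illustrates that contract with plain numpy (a vectorized argsort/searchsorted lookup instead of the per-element dict lookups of the slow path); it sketches the expected behaviour and is not the desitarget implementation.

    import numpy as np

    def match_ids(a, b):
        # Stand-in for desitarget.geomask.match(a, b): index arrays (ia, ib)
        # with a[ia] == b[ib]; assumes neither array contains duplicates.
        order = np.argsort(a)
        a_sorted = a[order]
        pos = np.clip(np.searchsorted(a_sorted, b), 0, len(a) - 1)
        found = a_sorted[pos] == b
        ia = order[pos[found]]
        ib = np.flatnonzero(found)
        return ia, ib

    targetids = np.array([39627, 10012, 55511, 70003])   # output row order
    tids = np.array([70003, 99999, 39627])               # TARGETIDs from one input file
    ia, ib = match_ids(targetids, tids)
    assert np.all(targetids[ia] == tids[ib])              # 99999 has no match and is dropped

This is also why the --fast_match help text warns about duplicated TARGETIDs: a sorted-search match can silently pair a row with only one of its copies, whereas the dict built by the slow path keys each TARGETID on its last occurrence.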