Skip to content

Commit

Permalink
Final changes before rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
fred3m committed Feb 6, 2025
1 parent 8c9e9e2 commit 673c704
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 30 deletions.
2 changes: 2 additions & 0 deletions python/lsst/meas/extensions/scarlet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def updateCatalogFootprints(
"""
# All of the blends should have the same PSF,
# so we extract it from the first blend data.
if len(modelData.blends) == 0:
        raise ValueError("Scarlet model data is empty")
refBlend = next(iter(modelData.blends.values()))
bands = refBlend.bands
bandIndex = bands.index(band)
Expand Down
92 changes: 62 additions & 30 deletions python/lsst/meas/extensions/scarlet/scarletDeblendTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _getDeconvolvedFootprints(
# catalog will produce unexpected results if a deconvolved footprint
# is in more than one footprint from the source catalog or has
# flux outside of its parent footprint.
sourceImage = utils.footprintsToNumpy(sources, detect.shape, (ymin, xmin))
sourceImage = utils.footprintsToNumpy(sources, detect.shape, (xmin, ymin))
detectionArray = detectionArray * sourceImage

footprints = get_footprints(
Expand Down Expand Up @@ -784,6 +784,13 @@ def _addSchemaKeys(self, schema: afwTable.Schema):
size=25,
doc="Name of error if the blend failed",
)
self.incompleteDataKey = schema.addField(
"deblend_incompleteData",
type="Flag",
doc="True when a blend has at least one band "
"that could not generate a PSF and was "
"not included in the model.",
)
# Deblended source fields
self.peakCenter = afwTable.Point2IKey.addFields(
schema,
Expand Down Expand Up @@ -812,7 +819,7 @@ def _addSchemaKeys(self, schema: afwTable.Schema):
doc="The type of model used, for example "
"MultiExtendedSource, SingleExtendedSource, PointSource",
)
self.deblendDepth = schema.addField(
self.deblendDepthKey = schema.addField(
"deblend_depth",
type=np.int32,
doc="The depth of deblending in the hierarchy."
Expand Down Expand Up @@ -993,10 +1000,10 @@ def deblend(
)
# Keep all of the isolated parents and the first
# `ciNumParentsToDeblend` children
children = nPeaks == 1
parents = nPeaks == 1
children = np.zeros((len(catalog),), dtype=bool)
children[childrenInRange[: self.config.ciNumParentsToDeblend]] = True
catalog = catalog[children | children]
catalog = catalog[parents | children]
# We need to update the IdFactory, otherwise the source ids
# will not be sequential
idFactory = catalog.getIdFactory()
Expand Down Expand Up @@ -1026,18 +1033,18 @@ def deblend(

# Subdivide the psf blended parents into deconvolved parents
# using the deconvolved footprints.
self.mExposure = mExposure
nPsfBlendedParents = len(catalog)
psfParents, parentHierarchy = self._buildParentHierarchy(catalog, context)
psfParents, parentHierarchy, skippedParents = self._buildParentHierarchy(catalog, context)

self.log.info(
"Subdivided the top level parents to create "
f"{np.sum((catalog[self.deblendDepth] == 1) & (catalog[self.nPeaksKey] > 1))} "
f"{np.sum((catalog[self.deblendDepthKey] == 1) & (catalog[self.nPeaksKey] > 1))} "
"deconvolved parents."
)

# Attach full image objects to the task to simplify the API
# and use for debugging.
self.mExposure = mExposure
self.catalog = catalog
self.context = context
self.parentHierarchy = parentHierarchy
Expand All @@ -1056,11 +1063,7 @@ def deblend(

psfParent = catalog.find(psfParentId)

# Since we use the first peak for the parent object, we should
# propagate its flags to the parent source.
psfParent.assign(psfParent.getFootprint().peaks[0], self.peakSchemaMapper)

if psfParentId in psfParents:
if psfParentId in psfParents and psfParentId not in skippedParents:
self.log.trace(f"Deblending parent {psfParent.getId()} directly")
# The deconvolved footprint had all of the same peaks,
# so there is no deconvolved parent record.
Expand Down Expand Up @@ -1144,8 +1147,8 @@ def deblend(
mask, mask.getPlaneBitMask(self.config.notDeblendedMask)
)

nDeconvolvedParents = np.sum((catalog[self.deblendDepth] == 1) & (catalog[self.nPeaksKey] > 1))
nDeblendedSources = np.sum((catalog[self.deblendDepth] > 0) & (catalog[self.nPeaksKey] == 1))
nDeconvolvedParents = np.sum((catalog[self.deblendDepthKey] == 1) & (catalog[self.nPeaksKey] > 1))
nDeblendedSources = np.sum((catalog[self.deblendDepthKey] > 0) & (catalog[self.nPeaksKey] == 1))
self.log.info(
f"Deblender results: {nPsfBlendedParents} parent sources were "
f"split into {nDeconvolvedParents} deconvovled parents,"
Expand Down Expand Up @@ -1202,9 +1205,8 @@ def _deblendParent(
blendRecord.assign(peaks[0], self.peakSchemaMapper)

# Skip the source if it meets the skipping criteria
if (skipArgs := self._checkSkipped(blendRecord, self.mExposure)) is not None:
self._skipParent(blendRecord, *skipArgs)
return None
if self._checkSkipped(blendRecord, self.mExposure) is not None:
raise RuntimeError("Skipped parents should be handled when building the parent hierachy")

self.log.trace(f"Parent {blendRecord.getId()}: deblending {len(peaks)} peaks")
# Run the deblender
Expand Down Expand Up @@ -1408,7 +1410,7 @@ def _skipParent(
blendRecord.set(self.deblendSkippedKey, True)
blendRecord.set(skipKey, True)

def _checkSkipped(self, parent: afwTable.SourceRecord, mExposure: afwImage.MultibandExposure):
def _checkSkipped(self, parent: afwTable.SourceRecord, mExposure: afwImage.MultibandExposure) -> bool:
"""Update a parent record that is not being deblended.
This is a fairly trivial function but is implemented to ensure
Expand All @@ -1432,7 +1434,10 @@ def _checkSkipped(self, parent: afwTable.SourceRecord, mExposure: afwImage.Multi
skipKey = None
skipMessage = None
footprint = parent.getFootprint()
if len(footprint.peaks) < 2 and not self.config.processSingles:
if (len(footprint.peaks) < 2
and not self.config.processSingles
and parent.get(self.deblendDepthKey) == 0
):

Check failure on line 1440 in python/lsst/meas/extensions/scarlet/scarletDeblendTask.py

View workflow job for this annotation

GitHub Actions / call-workflow / lint

E124

closing bracket does not match visual indentation
# Skip isolated sources unless processSingles is turned on.
# Note: this does not flag isolated sources as skipped or
# set the NOT_DEBLENDED mask in the exposure,
Expand Down Expand Up @@ -1555,6 +1560,7 @@ def _buildParentHierarchy(
psfParents = {}
deconvolvedParents = {}
hierarchy = {}
skipped = []

def _addChildren(parent: afwTable.SourceRecord):
"""Add a child record for every peak in the parent footprint"""
Expand All @@ -1571,9 +1577,18 @@ def _addChildren(parent: afwTable.SourceRecord):

for n in range(nParents):
psfParent = catalog[n]
if isPseudoSource(psfParent, self.config.pseudoColumns):
continue
parentFoot = psfParent.getFootprint()
# Since we use the first peak for the parent object, we should
# propagate its flags to the parent source.
# For example, this propagates `merge_peak_sky` to the parent
psfParent.assign(psfParent.getFootprint().peaks[0], self.peakSchemaMapper)

if (skipArgs := self._checkSkipped(psfParent, self.mExposure)) is not None:
self._skipParent(psfParent, *skipArgs)
# Do not add sources for skipped parents
skipped.append(psfParent.getId())

continue
parentId = psfParent.getId()
deconvolvedFootprints = self._getIntersectingFootprints(
parentFoot,
Expand All @@ -1600,9 +1615,13 @@ def _addChildren(parent: afwTable.SourceRecord):
for parentId, sources in deconvolvedParents.items():
hierarchy[parentId] = {}
for sourceRecord in sources:
if (skipArgs := self._checkSkipped(sourceRecord, self.mExposure)) is not None:
self._skipParent(sourceRecord, *skipArgs)
skipped.append(sourceRecord.getId())
continue
hierarchy[parentId][sourceRecord.getId()] = _addChildren(sourceRecord)

return psfParents, hierarchy
return psfParents, hierarchy, skipped

def _getIntersectingFootprints(
self,
Expand Down Expand Up @@ -1682,7 +1701,9 @@ def _addDeconvolvedParents(
deconvolvedParents.append(deconvolvedParent)
deconvolvedParent.setParent(parent.getId())
deconvolvedParent.setFootprint(footprint)
deconvolvedParent.set(self.deblendDepth, parent[self.deblendDepth] + 1)
deconvolvedParent.set(self.deblendDepthKey, parent[self.deblendDepthKey] + 1)
self._propagateToChild(parent, deconvolvedParent)

return deconvolvedParents

def _addDeblendedSource(
Expand Down Expand Up @@ -1713,8 +1734,6 @@ def _addDeblendedSource(
The new child source record.
"""
src = catalog.addNew()
for key in self.toCopyFromParent:
src.set(key, parent.get(key))
# The peak catalog is the same for all bands,
# so we just use the first peak catalog
src.assign(peak, self.peakSchemaMapper)
Expand All @@ -1735,7 +1754,7 @@ def _addDeblendedSource(
src.set(self.peakIdKey, peak["id"])

# Set the deblend depth
src.set(self.deblendDepth, parent[self.deblendDepth] + 1)
src.set(self.deblendDepthKey, parent[self.deblendDepthKey] + 1)

return src

Expand All @@ -1756,7 +1775,7 @@ def _updateDeblendedSource(
----------
parent :
The parent of the new child record.
peak : `lsst.afw.table.PeakRecord`
peak :
The peak record for the peak from the parent peak catalog.
catalog :
The merged `SourceCatalog` that contains parent footprints
Expand All @@ -1782,7 +1801,20 @@ def _updateDeblendedSource(
src.set(self.scarletChi2Key, np.sum(chi2[:, scarletSource.bbox].data/area))

# Propagate columns from the parent to the child
for parentColumn, childColumn in self.config.columnInheritance.items():
src.set(childColumn, parent.get(parentColumn))

self._propagateToChild(parent, src)
return src

def _propagateToChild(self, parent: afwTable.SourceRecord, child: afwTable.SourceRecord):
    """Copy the configured set of columns from a parent record to a child.

    Parameters
    ----------
    parent :
        The parent source record whose values are read.
    child :
        The child source record that receives the copied values.
    """
    # Columns copied verbatim under the same key.
    for sharedKey in self.toCopyFromParent:
        child.set(sharedKey, parent.get(sharedKey))
    # Columns that are renamed on the way down, as configured
    # by the task's `columnInheritance` mapping.
    for srcColumn, dstColumn in self.config.columnInheritance.items():
        child.set(dstColumn, parent.get(srcColumn))

0 comments on commit 673c704

Please sign in to comment.