Skip to content

Commit

Permalink
Implement scarlet lite deconvolved initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
fred3m committed Feb 6, 2025
1 parent 4381c4b commit 0d76254
Show file tree
Hide file tree
Showing 5 changed files with 1,333 additions and 664 deletions.
38 changes: 35 additions & 3 deletions python/lsst/meas/extensions/scarlet/deconvolveExposureTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import logging

import lsst.afw.image as afwImage
import lsst.afw.table as afwTable
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.base.connectionTypes as cT
Expand Down Expand Up @@ -53,13 +54,25 @@ class DeconvolveExposureConnections(
dimensions=("tract", "patch", "band", "skymap"),
)

catalog = cT.Input(
doc="Catalog of sources detected in the deconvolved image",
name="{inputCoaddName}Coadd_mergeDet",
storageClass="SourceCatalog",
dimensions=("tract", "patch", "skymap"),
)

deconvolved = cT.Output(
doc="Deconvolved exposure",
name="deconvolved_{inputCoaddName}_coadd",
storageClass="ExposureF",
dimensions=("tract", "patch", "band", "skymap"),
)

def __init__(self, *, config=None):
if not config.useFootprints:
# Deconvolution does not use input catalog
self.inputs.remove("catalog")


class DeconvolveExposureConfig(
pipeBase.PipelineTaskConfig,
Expand All @@ -84,6 +97,10 @@ class DeconvolveExposureConfig(
doc="Threshold for background subtraction. "
"Pixels in the fit below this threshold will be set to zero",
)
useFootprints = pexConfig.Field[bool](
default=True,
doc="Use footprints to constrain the deconvolved model",
)


class DeconvolveExposureTask(pipeBase.PipelineTask):
Expand All @@ -102,7 +119,11 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs):
outputs = self.run(**inputs)
butlerQC.put(outputs, outputRefs)

def run(self, coadd: afwImage.Exposure) -> pipeBase.Struct:
def run(
self,
coadd: afwImage.Exposure,
catalog: afwTable.SourceCatalog | None = None,
) -> pipeBase.Struct:
"""Deconvolve an Exposure
Parameters
Expand All @@ -117,10 +138,11 @@ def run(self, coadd: afwImage.Exposure) -> pipeBase.Struct:
"""
# Load the scarlet lite Observation
observation = self._buildObservation(coadd)
self.bbox = coadd.getBBox()

# Deconvolve.
# Store the loss history for debugging purposes.
model, self.loss = self._deconvolve(observation)
model, self.loss = self._deconvolve(observation, catalog)

# Store the model in an Exposure
exposure = self._modelToExposure(model.data[0], coadd)
Expand Down Expand Up @@ -160,7 +182,11 @@ def _buildObservation(self, coadd: afwImage.Exposure) -> scl.Observation:
)
return observation

def _deconvolve(self, observation: scl.Observation) -> tuple[scl.Image, list[float]]:
def _deconvolve(
self,
observation: scl.Observation,
catalog: afwTable.SourceCatalog | None = None,
) -> tuple[scl.Image, list[float]]:
"""Deconvolve the observed image.
Parameters
Expand All @@ -170,12 +196,18 @@ def _deconvolve(self, observation: scl.Observation) -> tuple[scl.Image, list[flo
"""
model = observation.images.copy()
loss = []
if catalog is not None:
width, height = self.bbox.getDimensions()
x0, y0 = self.bbox.getMin()
footprintImage = utils.footprintsToNumpy(catalog, (height, width), (x0, y0))
for n in range(self.config.maxIter):
residual = observation.images - observation.convolve(model)
loss.append(-0.5 * np.sum(residual.data**2))
update = observation.convolve(residual, grad=True)
model += update
model.data[model.data < 0] = 0
if catalog is not None:
model.data[:] *= footprintImage

if n > self.config.minIter and np.abs(loss[-1] - loss[-2]) < self.config.eRel * np.abs(loss[-1]):
break
Expand Down
Loading

0 comments on commit 0d76254

Please sign in to comment.