From 1b0f30d01ae930e755f2515b1e2edc4bbcd92a0d Mon Sep 17 00:00:00 2001 From: Steven Henke Date: Wed, 23 Oct 2024 13:41:17 -0500 Subject: [PATCH] ptycho+xrf GUI complete --- ptychodus/api/fluorescence.py | 6 + ptychodus/controller/probe/core.py | 2 +- ptychodus/controller/probe/fluorescence.py | 142 ++++++--- ptychodus/model/analysis/__init__.py | 2 - ptychodus/model/analysis/core.py | 42 +-- ptychodus/model/analysis/fluorescence.py | 327 --------------------- ptychodus/model/analysis/settings.py | 31 -- ptychodus/model/core.py | 18 +- ptychodus/model/fluorescence/__init__.py | 10 + ptychodus/model/fluorescence/core.py | 223 ++++++++++++++ ptychodus/model/fluorescence/settings.py | 33 +++ ptychodus/model/fluorescence/two_step.py | 110 +++++++ ptychodus/model/fluorescence/vspi.py | 180 ++++++++++++ ptychodus/view/probe.py | 49 ++- 14 files changed, 718 insertions(+), 457 deletions(-) delete mode 100644 ptychodus/model/analysis/fluorescence.py create mode 100644 ptychodus/model/fluorescence/__init__.py create mode 100644 ptychodus/model/fluorescence/core.py create mode 100644 ptychodus/model/fluorescence/settings.py create mode 100644 ptychodus/model/fluorescence/two_step.py create mode 100644 ptychodus/model/fluorescence/vspi.py diff --git a/ptychodus/api/fluorescence.py b/ptychodus/api/fluorescence.py index 4a27f737..b1352e4c 100644 --- a/ptychodus/api/fluorescence.py +++ b/ptychodus/api/fluorescence.py @@ -24,6 +24,12 @@ class FluorescenceDataset: # scan_indexes: IntegerArray +class FluorescenceEnhancingAlgorithm(ABC): + @abstractmethod + def enhance(self, dataset: FluorescenceDataset, product: Product) -> FluorescenceDataset: + pass + + class FluorescenceFileReader(ABC): @abstractmethod def read(self, filePath: Path) -> FluorescenceDataset: diff --git a/ptychodus/controller/probe/core.py b/ptychodus/controller/probe/core.py index b16b76fb..509b7ef9 100644 --- a/ptychodus/controller/probe/core.py +++ b/ptychodus/controller/probe/core.py @@ -8,10 +8,10 @@ from ...model.analysis import ( ExposureAnalyzer, - FluorescenceEnhancer, ProbePropagator, STXMSimulator, ) +from ...model.fluorescence import FluorescenceEnhancer from ...model.product import ProbeAPI, ProbeRepository from ...model.product.probe import ProbeRepositoryItem from ...model.visualization import VisualizationEngine diff --git a/ptychodus/controller/probe/fluorescence.py b/ptychodus/controller/probe/fluorescence.py index 236f1e7b..a2a49c1b 100644 --- a/ptychodus/controller/probe/fluorescence.py +++ b/ptychodus/controller/probe/fluorescence.py @@ -1,13 +1,23 @@ -from typing import Any +from decimal import Decimal +from typing import Any, Final import logging from PyQt5.QtCore import Qt, QAbstractListModel, QModelIndex, QObject, QStringListModel +from PyQt5.QtWidgets import QWidget from ptychodus.api.observer import Observable, Observer -from ...model.analysis import FluorescenceEnhancer +from ...model.fluorescence import ( + FluorescenceEnhancer, + TwoStepFluorescenceEnhancingAlgorithm, + VSPIFluorescenceEnhancingAlgorithm, +) from ...model.visualization import VisualizationEngine -from ...view.probe import FluorescenceDialog +from ...view.probe import ( + FluorescenceDialog, + FluorescenceTwoStepParametersView, + FluorescenceVSPIParametersView, +) from ...view.widgets import ExceptionDialog from ..data import FileDialogFactory from ..visualization import ( @@ -33,6 +43,71 @@ def rowCount(self, parent: QModelIndex = QModelIndex()) -> int: return self._enhancer.getNumberOfChannels() +class FluorescenceTwoStepViewController(Observer): + def __init__(self, algorithm: TwoStepFluorescenceEnhancingAlgorithm) -> None: + super().__init__() + self._algorithm = algorithm + self._view = FluorescenceTwoStepParametersView() + + self._upscalingModel = QStringListModel() + self._upscalingModel.setStringList(self._algorithm.getUpscalingStrategyList()) + self._view.upscalingStrategyComboBox.setModel(self._upscalingModel) + self._view.upscalingStrategyComboBox.textActivated.connect(algorithm.setUpscalingStrategy) + + self._deconvolutionModel = QStringListModel() + self._deconvolutionModel.setStringList(self._algorithm.getDeconvolutionStrategyList()) + self._view.deconvolutionStrategyComboBox.setModel(self._deconvolutionModel) + self._view.deconvolutionStrategyComboBox.textActivated.connect( + algorithm.setDeconvolutionStrategy + ) + + self._syncModelToView() + algorithm.addObserver(self) + + def getWidget(self) -> QWidget: + return self._view + + def _syncModelToView(self) -> None: + self._view.upscalingStrategyComboBox.setCurrentText(self._algorithm.getUpscalingStrategy()) + self._view.deconvolutionStrategyComboBox.setCurrentText( + self._algorithm.getDeconvolutionStrategy() + ) + + def update(self, observable: Observable) -> None: + if observable is self._algorithm: + self._syncModelToView() + + +class FluorescenceVSPIViewController(Observer): + MAX_INT: Final[int] = 0x7FFFFFFF + + def __init__(self, algorithm: VSPIFluorescenceEnhancingAlgorithm) -> None: + super().__init__() + self._algorithm = algorithm + self._view = FluorescenceVSPIParametersView() + + self._view.dampingFactorLineEdit.valueChanged.connect(self._syncDampingFactorToModel) + self._view.maxIterationsSpinBox.setRange(1, self.MAX_INT) + self._view.maxIterationsSpinBox.valueChanged.connect(algorithm.setMaxIterations) + + algorithm.addObserver(self) + self._syncModelToView() + + def getWidget(self) -> QWidget: + return self._view + + def _syncDampingFactorToModel(self, value: Decimal) -> None: + self._algorithm.setDampingFactor(float(value)) + + def _syncModelToView(self) -> None: + self._view.dampingFactorLineEdit.setValue(Decimal(repr(self._algorithm.getDampingFactor()))) + self._view.maxIterationsSpinBox.setValue(self._algorithm.getMaxIterations()) + + def update(self, observable: Observable) -> None: + if observable is self._algorithm: + self._syncModelToView() + + class FluorescenceViewController(Observer): def __init__( self, @@ -46,43 +121,42 @@ def __init__( self._fileDialogFactory = fileDialogFactory self._dialog = FluorescenceDialog() self._enhancementModel = QStringListModel() - self._enhancementModel.setStringList(self._enhancer.getEnhancementStrategyList()) - # FIXME add vspiDampingFactor - # FIXME add vspiMaxIterations - self._upscalingModel = QStringListModel() - self._upscalingModel.setStringList(self._enhancer.getUpscalingStrategyList()) - self._deconvolutionModel = QStringListModel() - self._deconvolutionModel.setStringList(self._enhancer.getDeconvolutionStrategyList()) + self._enhancementModel.setStringList(self._enhancer.getAlgorithmList()) self._channelListModel = FluorescenceChannelListModel(enhancer) self._dialog.fluorescenceParametersView.openButton.clicked.connect( self._openMeasuredDataset ) - self._dialog.fluorescenceParametersView.enhancementStrategyComboBox.setModel( - self._enhancementModel + twoStepViewController = FluorescenceTwoStepViewController( + enhancer.twoStepEnhancingAlgorithm + ) + self._dialog.fluorescenceParametersView.algorithmComboBox.addItem( + TwoStepFluorescenceEnhancingAlgorithm.DISPLAY_NAME, + self._dialog.fluorescenceParametersView.algorithmComboBox.count(), ) - self._dialog.fluorescenceParametersView.enhancementStrategyComboBox.textActivated.connect( - enhancer.setEnhancementStrategy + self._dialog.fluorescenceParametersView.stackedWidget.addWidget( + twoStepViewController.getWidget() ) - self._dialog.fluorescenceParametersView.upscalingStrategyComboBox.setModel( - self._upscalingModel + vspiViewController = FluorescenceVSPIViewController(enhancer.vspiEnhancingAlgorithm) + self._dialog.fluorescenceParametersView.algorithmComboBox.addItem( + VSPIFluorescenceEnhancingAlgorithm.DISPLAY_NAME, + self._dialog.fluorescenceParametersView.algorithmComboBox.count(), ) - self._dialog.fluorescenceParametersView.upscalingStrategyComboBox.textActivated.connect( - enhancer.setUpscalingStrategy + self._dialog.fluorescenceParametersView.stackedWidget.addWidget( + vspiViewController.getWidget() ) - self._dialog.fluorescenceParametersView.deconvolutionStrategyComboBox.setModel( - self._deconvolutionModel + self._dialog.fluorescenceParametersView.algorithmComboBox.textActivated.connect( + enhancer.setAlgorithm ) - self._dialog.fluorescenceParametersView.deconvolutionStrategyComboBox.textActivated.connect( - enhancer.setDeconvolutionStrategy + self._dialog.fluorescenceParametersView.algorithmComboBox.currentIndexChanged.connect( + self._dialog.fluorescenceParametersView.stackedWidget.setCurrentIndex ) - - self._dialog.fluorescenceChannelListView.setModel(self._channelListModel) - self._dialog.fluorescenceChannelListView.selectionModel().currentChanged.connect( - self._updateView + self._dialog.fluorescenceParametersView.algorithmComboBox.setModel(self._enhancementModel) + self._dialog.fluorescenceParametersView.algorithmComboBox.textActivated.connect( + enhancer.setAlgorithm ) self._dialog.fluorescenceParametersView.enhanceButton.clicked.connect( @@ -92,6 +166,11 @@ def __init__( self._saveEnhancedDataset ) + self._dialog.fluorescenceChannelListView.setModel(self._channelListModel) + self._dialog.fluorescenceChannelListView.selectionModel().currentChanged.connect( + self._updateView + ) + self._measuredWidgetController = VisualizationWidgetController( engine, self._dialog.measuredWidget, @@ -162,16 +241,9 @@ def _saveEnhancedDataset(self) -> None: ExceptionDialog.showException(title, err) def _syncModelToView(self) -> None: - self._dialog.fluorescenceParametersView.enhancementStrategyComboBox.setCurrentText( - self._enhancer.getEnhancementStrategy() - ) - self._dialog.fluorescenceParametersView.upscalingStrategyComboBox.setCurrentText( - self._enhancer.getUpscalingStrategy() + self._dialog.fluorescenceParametersView.algorithmComboBox.setCurrentText( + self._enhancer.getAlgorithm() ) - self._dialog.fluorescenceParametersView.deconvolutionStrategyComboBox.setCurrentText( - self._enhancer.getDeconvolutionStrategy() - ) - self._channelListModel.beginResetModel() self._channelListModel.endResetModel() diff --git a/ptychodus/model/analysis/__init__.py b/ptychodus/model/analysis/__init__.py index aeae296f..250c79ca 100644 --- a/ptychodus/model/analysis/__init__.py +++ b/ptychodus/model/analysis/__init__.py @@ -1,6 +1,5 @@ from .core import AnalysisCore from .exposure import ExposureAnalyzer, ExposureMap -from .fluorescence import FluorescenceEnhancer from .frc import FourierRingCorrelator from .objectInterpolator import ObjectLinearInterpolator from .objectStitcher import ObjectStitcher @@ -12,7 +11,6 @@ 'AnalysisCore', 'ExposureAnalyzer', 'ExposureMap', - 'FluorescenceEnhancer', 'FourierRingCorrelator', 'ObjectLinearInterpolator', 'ObjectStitcher', diff --git a/ptychodus/model/analysis/core.py b/ptychodus/model/analysis/core.py index 7225d657..baf8d098 100644 --- a/ptychodus/model/analysis/core.py +++ b/ptychodus/model/analysis/core.py @@ -1,23 +1,14 @@ -from pathlib import Path import logging -from ptychodus.api.fluorescence import ( - DeconvolutionStrategy, - FluorescenceFileReader, - FluorescenceFileWriter, - UpscalingStrategy, -) -from ptychodus.api.plugins import PluginChooser from ptychodus.api.settings import SettingsRegistry from ..product import ObjectRepository, ProductRepository from ..reconstructor import DiffractionPatternPositionMatcher from ..visualization import VisualizationEngine from .exposure import ExposureAnalyzer -from .fluorescence import FluorescenceEnhancer from .frc import FourierRingCorrelator from .propagator import ProbePropagator -from .settings import FluorescenceSettings, ProbePropagationSettings +from .settings import ProbePropagationSettings from .stxm import STXMSimulator from .xmcd import XMCDAnalyzer @@ -31,10 +22,6 @@ def __init__( dataMatcher: DiffractionPatternPositionMatcher, productRepository: ProductRepository, objectRepository: ObjectRepository, - upscalingStrategyChooser: PluginChooser[UpscalingStrategy], - deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], - fluorescenceFileReaderChooser: PluginChooser[FluorescenceFileReader], - fluorescenceFileWriterChooser: PluginChooser[FluorescenceFileWriter], ) -> None: self.stxmSimulator = STXMSimulator(dataMatcher) self.stxmVisualizationEngine = VisualizationEngine(isComplex=False) @@ -46,32 +33,5 @@ def __init__( self.exposureVisualizationEngine = VisualizationEngine(isComplex=False) self.fourierRingCorrelator = FourierRingCorrelator(objectRepository) - self._fluorescenceSettings = FluorescenceSettings(settingsRegistry) - self.fluorescenceEnhancer = FluorescenceEnhancer( - self._fluorescenceSettings, - productRepository, - upscalingStrategyChooser, - deconvolutionStrategyChooser, - fluorescenceFileReaderChooser, - fluorescenceFileWriterChooser, - settingsRegistry, - ) - self.fluorescenceVisualizationEngine = VisualizationEngine(isComplex=False) self.xmcdAnalyzer = XMCDAnalyzer(objectRepository) self.xmcdVisualizationEngine = VisualizationEngine(isComplex=False) - - def enhanceFluorescence( - self, productIndex: int, inputFilePath: Path, outputFilePath: Path - ) -> int: - fileType = 'XRF-Maps' - - try: - self.fluorescenceEnhancer.setProduct(productIndex) - self.fluorescenceEnhancer.openMeasuredDataset(inputFilePath, fileType) - self.fluorescenceEnhancer.enhanceFluorescence() - self.fluorescenceEnhancer.saveEnhancedDataset(outputFilePath, fileType) - except Exception as exc: - logger.exception(exc) - return -1 - - return 0 diff --git a/ptychodus/model/analysis/fluorescence.py b/ptychodus/model/analysis/fluorescence.py deleted file mode 100644 index 29812299..00000000 --- a/ptychodus/model/analysis/fluorescence.py +++ /dev/null @@ -1,327 +0,0 @@ -from __future__ import annotations -from collections.abc import Sequence -from pathlib import Path -from typing import Final -import logging -import time - -from scipy.sparse.linalg import lsmr, LinearOperator -import numpy - -from ptychodus.api.fluorescence import ( - DeconvolutionStrategy, - ElementMap, - FluorescenceDataset, - FluorescenceFileReader, - FluorescenceFileWriter, - UpscalingStrategy, -) -from ptychodus.api.geometry import PixelGeometry -from ptychodus.api.object import ObjectPoint -from ptychodus.api.observer import Observable, Observer -from ptychodus.api.plugins import PluginChooser -from ptychodus.api.product import Product -from ptychodus.api.typing import RealArrayType - -from ..product import ProductRepository -from .settings import FluorescenceSettings - -logger = logging.getLogger(__name__) - - -class ArrayPatchInterpolator: - def __init__(self, array: RealArrayType, point: ObjectPoint, shape: tuple[int, ...]) -> None: - # top left corner of object support - xmin = point.positionXInPixels - shape[-1] / 2 - ymin = point.positionYInPixels - shape[-2] / 2 - - # whole components (pixel indexes) - xmin_wh = int(xmin) - ymin_wh = int(ymin) - - # fractional (subpixel) components - xmin_fr = xmin - xmin_wh - ymin_fr = ymin - ymin_wh - - # bottom right corner of object patch support - xmax_wh = xmin_wh + shape[-1] + 1 - ymax_wh = ymin_wh + shape[-2] + 1 - - # reused quantities - xmin_fr_c = 1.0 - xmin_fr - ymin_fr_c = 1.0 - ymin_fr - - # barycentric interpolant weights - self._weight00 = ymin_fr_c * xmin_fr_c - self._weight01 = ymin_fr_c * xmin_fr - self._weight10 = ymin_fr * xmin_fr_c - self._weight11 = ymin_fr * xmin_fr - - # extract patch support region from full object - self._support = array[ymin_wh:ymax_wh, xmin_wh:xmax_wh] - - def get_patch(self) -> RealArrayType: - """interpolate array support to extract patch""" - patch = self._weight00 * self._support[:-1, :-1] - patch += self._weight01 * self._support[:-1, 1:] - patch += self._weight10 * self._support[1:, :-1] - patch += self._weight11 * self._support[1:, 1:] - return patch - - def accumulate_patch(self, patch: RealArrayType) -> None: - """add patch update to array support""" - self._support[:-1, :-1] += self._weight00 * patch - self._support[:-1, 1:] += self._weight01 * patch - self._support[1:, :-1] += self._weight10 * patch - self._support[1:, 1:] += self._weight11 * patch - - -class VSPILinearOperator(LinearOperator): - def __init__(self, product: Product) -> None: - """ - M: number of XRF positions - N: number of ptychography object pixels - P: number of XRF channels - - A[M,N] * X[N,P] = B[M,P] - """ - object_geometry = product.object_.getGeometry() - M = len(product.scan) - N = object_geometry.heightInPixels * object_geometry.widthInPixels - super().__init__(float, (M, N)) - self._product = product - - def _get_psf(self) -> RealArrayType: - intensity = self._product.probe.getIntensity() - return intensity / intensity.sum() - - def _matvec(self, X: RealArrayType) -> RealArrayType: - object_geometry = self._product.object_.getGeometry() - object_array = X.reshape((object_geometry.heightInPixels, object_geometry.widthInPixels)) - psf = self._get_psf() - AX = numpy.zeros(len(self._product.scan)) - - for index, scan_point in enumerate(self._product.scan): - object_point = object_geometry.mapScanPointToObjectPoint(scan_point) - interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) - AX[index] = numpy.sum(psf * interpolator.get_patch()) - - return AX - - def _rmatvec(self, X: RealArrayType) -> RealArrayType: - object_geometry = self._product.object_.getGeometry() - object_array = numpy.zeros((object_geometry.heightInPixels, object_geometry.widthInPixels)) - psf = self._get_psf() - - for index, scan_point in enumerate(self._product.scan): - object_point = object_geometry.mapScanPointToObjectPoint(scan_point) - interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) - interpolator.accumulate_patch(X[index] * psf) - - HX = object_array.flatten() - - return HX - - -class FluorescenceEnhancer(Observable, Observer): - VSPI: Final[str] = 'Virtual Single Pixel Imaging' - TWO_STEP: Final[str] = 'Upscale and Deconvolve' - - def __init__( - self, - settings: FluorescenceSettings, - productRepository: ProductRepository, - upscalingStrategyChooser: PluginChooser[UpscalingStrategy], - deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], - fileReaderChooser: PluginChooser[FluorescenceFileReader], - fileWriterChooser: PluginChooser[FluorescenceFileWriter], - reinitObservable: Observable, - ) -> None: - super().__init__() - self._settings = settings - self._productRepository = productRepository - self._upscalingStrategyChooser = upscalingStrategyChooser - self._deconvolutionStrategyChooser = deconvolutionStrategyChooser - self._fileReaderChooser = fileReaderChooser - self._fileWriterChooser = fileWriterChooser - self._reinitObservable = reinitObservable - - self._productIndex = -1 - self._measured: FluorescenceDataset | None = None - self._enhanced: FluorescenceDataset | None = None - - upscalingStrategyChooser.addObserver(self) - upscalingStrategyChooser.setCurrentPluginByName(settings.upscalingStrategy.getValue()) - deconvolutionStrategyChooser.addObserver(self) - deconvolutionStrategyChooser.setCurrentPluginByName( - settings.deconvolutionStrategy.getValue() - ) - fileReaderChooser.setCurrentPluginByName(settings.fileType.getValue()) - fileWriterChooser.setCurrentPluginByName(settings.fileType.getValue()) - reinitObservable.addObserver(self) - - def setProduct(self, productIndex: int) -> None: - if self._productIndex != productIndex: - self._productIndex = productIndex - self._enhanced = None - self.notifyObservers() - - def getProductName(self) -> str: - return self._productRepository[self._productIndex].getName() - - def getOpenFileFilterList(self) -> Sequence[str]: - return self._fileReaderChooser.getDisplayNameList() - - def getOpenFileFilter(self) -> str: - return self._fileReaderChooser.currentPlugin.displayName - - def openMeasuredDataset(self, filePath: Path, fileFilter: str) -> None: - if filePath.is_file(): - self._fileReaderChooser.setCurrentPluginByName(fileFilter) - fileType = self._fileReaderChooser.currentPlugin.simpleName - logger.debug(f'Reading "{filePath}" as "{fileType}"') - fileReader = self._fileReaderChooser.currentPlugin.strategy - - try: - measured = fileReader.read(filePath) - except Exception as exc: - raise RuntimeError(f'Failed to read "{filePath}"') from exc - else: - self._measured = measured - self._enhanced = None - - self._settings.filePath.setValue(filePath) - self._settings.fileType.setValue(fileType) - - self.notifyObservers() - else: - logger.warning(f'Refusing to load dataset from invalid file path "{filePath}"') - - def getNumberOfChannels(self) -> int: - return 0 if self._measured is None else len(self._measured.element_maps) - - def getMeasuredElementMap(self, channelIndex: int) -> ElementMap: - if self._measured is None: - raise ValueError('Fluorescence dataset not loaded!') - - return self._measured.element_maps[channelIndex] - - def getEnhancementStrategyList(self) -> Sequence[str]: - return [self.VSPI, self.TWO_STEP] - - def getEnhancementStrategy(self) -> str: - return self.VSPI if self._settings.useVSPI.getValue() else self.TWO_STEP - - def setEnhancementStrategy(self, name: str) -> None: - self._settings.useVSPI.setValue(name.casefold() == self.VSPI.casefold()) - - def getUpscalingStrategyList(self) -> Sequence[str]: - return self._upscalingStrategyChooser.getDisplayNameList() - - def getUpscalingStrategy(self) -> str: - return self._upscalingStrategyChooser.currentPlugin.displayName - - def setUpscalingStrategy(self, name: str) -> None: - self._upscalingStrategyChooser.setCurrentPluginByName(name) - - def getDeconvolutionStrategyList(self) -> Sequence[str]: - return self._deconvolutionStrategyChooser.getDisplayNameList() - - def getDeconvolutionStrategy(self) -> str: - return self._deconvolutionStrategyChooser.currentPlugin.displayName - - def setDeconvolutionStrategy(self, name: str) -> None: - self._deconvolutionStrategyChooser.setCurrentPluginByName(name) - - def enhanceFluorescence(self) -> None: - if self._measured is None: - raise ValueError('Fluorescence dataset not loaded!') - - product = self._productRepository[self._productIndex].getProduct() - object_geometry = product.object_.getGeometry() - e_cps_shape = object_geometry.heightInPixels, object_geometry.widthInPixels - element_maps: list[ElementMap] = list() - - if self._settings.useVSPI.getValue(): - A = VSPILinearOperator(product) - - for emap in self._measured.element_maps: - logger.info(f'Enhancing "{emap.name}"...') - tic = time.perf_counter() - m_cps = emap.counts_per_second - result = lsmr( - A, - m_cps.flatten(), - damp=self._settings.vspiDampingFactor.getValue(), - maxiter=self._settings.vspiMaximumIterations.getValue(), - show=True, - ) - logger.debug(result) - e_cps = result[0].reshape(e_cps_shape) - emap_enhanced = ElementMap(emap.name, e_cps) - toc = time.perf_counter() - logger.info(f'Enhanced "{emap.name}" in {toc - tic:.4f} seconds.') - - element_maps.append(emap_enhanced) - else: - upscaler = self._upscalingStrategyChooser.currentPlugin.strategy - deconvolver = self._deconvolutionStrategyChooser.currentPlugin.strategy - - for emap in self._measured.element_maps: - logger.info(f'Enhancing "{emap.name}"...') - tic = time.perf_counter() - emap_upscaled = upscaler(emap, product) - emap_enhanced = deconvolver(emap_upscaled, product) - toc = time.perf_counter() - logger.info(f'Enhanced "{emap.name}" in {toc - tic:.4f} seconds.') - - element_maps.append(emap_enhanced) - - self._enhanced = FluorescenceDataset( - element_maps=element_maps, - counts_per_second_path=self._measured.counts_per_second_path, - channel_names_path=self._measured.channel_names_path, - ) - self.notifyObservers() - - def getPixelGeometry(self) -> PixelGeometry: - return self._productRepository[self._productIndex].getGeometry().getPixelGeometry() - - def getEnhancedElementMap(self, channelIndex: int) -> ElementMap: - if self._enhanced is None: - raise ValueError('Fluorescence dataset not enhanced!') - - return self._enhanced.element_maps[channelIndex] - - def getSaveFileFilterList(self) -> Sequence[str]: - return self._fileWriterChooser.getDisplayNameList() - - def getSaveFileFilter(self) -> str: - return self._fileWriterChooser.currentPlugin.displayName - - def saveEnhancedDataset(self, filePath: Path, fileFilter: str) -> None: - if self._enhanced is None: - raise ValueError('Fluorescence dataset not enhanced!') - - self._fileWriterChooser.setCurrentPluginByName(fileFilter) - fileType = self._fileWriterChooser.currentPlugin.simpleName - logger.debug(f'Writing "{filePath}" as "{fileType}"') - writer = self._fileWriterChooser.currentPlugin.strategy - writer.write(filePath, self._enhanced) - - def _openFluorescenceFileFromSettings(self) -> None: - self.openMeasuredDataset( - self._settings.filePath.getValue(), self._settings.fileType.getValue() - ) - - def update(self, observable: Observable) -> None: - if observable is self._reinitObservable: - self._openFluorescenceFileFromSettings() - elif observable is self._upscalingStrategyChooser: - strategy = self._upscalingStrategyChooser.currentPlugin.simpleName - self._settings.upscalingStrategy.setValue(strategy) - self.notifyObservers() - elif observable is self._deconvolutionStrategyChooser: - strategy = self._deconvolutionStrategyChooser.currentPlugin.simpleName - self._settings.deconvolutionStrategy.setValue(strategy) - self.notifyObservers() diff --git a/ptychodus/model/analysis/settings.py b/ptychodus/model/analysis/settings.py index 2edcac95..ea0fcd44 100644 --- a/ptychodus/model/analysis/settings.py +++ b/ptychodus/model/analysis/settings.py @@ -1,5 +1,3 @@ -from pathlib import Path - from ptychodus.api.observer import Observable, Observer from ptychodus.api.settings import SettingsRegistry @@ -21,32 +19,3 @@ def __init__(self, registry: SettingsRegistry) -> None: def update(self, observable: Observable) -> None: if observable is self._settingsGroup: self.notifyObservers() - - -class FluorescenceSettings(Observable, Observer): - def __init__(self, registry: SettingsRegistry) -> None: - super().__init__() - self._settingsGroup = registry.createGroup('Fluorescence') - self._settingsGroup.addObserver(self) - - self.filePath = self._settingsGroup.createPathParameter( - 'FilePath', Path('/path/to/dataset.h5') - ) - self.fileType = self._settingsGroup.createStringParameter('FileType', 'XRF-Maps') - self.useVSPI = self._settingsGroup.createBooleanParameter('UseVSPI', True) - self.vspiDampingFactor = self._settingsGroup.createRealParameter( - 'VSPIDampingFactor', 0.0, minimum=0.0 - ) - self.vspiMaximumIterations = self._settingsGroup.createIntegerParameter( - 'VSPIMaximumIterations', 100, minimum=1 - ) - self.upscalingStrategy = self._settingsGroup.createStringParameter( - 'UpscalingStrategy', 'Linear' - ) - self.deconvolutionStrategy = self._settingsGroup.createStringParameter( - 'DeconvolutionStrategy', 'Richardson-Lucy' - ) - - def update(self, observable: Observable) -> None: - if observable is self._settingsGroup: - self.notifyObservers() diff --git a/ptychodus/model/core.py b/ptychodus/model/core.py index 2faa7467..e8502692 100644 --- a/ptychodus/model/core.py +++ b/ptychodus/model/core.py @@ -24,7 +24,6 @@ from .analysis import ( AnalysisCore, ExposureAnalyzer, - FluorescenceEnhancer, FourierRingCorrelator, ProbePropagator, STXMSimulator, @@ -35,6 +34,7 @@ AutomationPresenter, AutomationProcessingPresenter, ) +from .fluorescence import FluorescenceCore, FluorescenceEnhancer from .memory import MemoryPresenter from .patterns import ( Detector, @@ -142,16 +142,20 @@ def __init__( self.ptychonnReconstructorLibrary, ], ) - self._analysisCore = AnalysisCore( + self._fluorescenceCore = FluorescenceCore( self.settingsRegistry, - self._reconstructorCore.dataMatcher, self._productCore.productRepository, - self._productCore.objectRepository, self._pluginRegistry.upscalingStrategies, self._pluginRegistry.deconvolutionStrategies, self._pluginRegistry.fluorescenceFileReaders, self._pluginRegistry.fluorescenceFileWriters, ) + self._analysisCore = AnalysisCore( + self.settingsRegistry, + self._reconstructorCore.dataMatcher, + self._productCore.productRepository, + self._productCore.objectRepository, + ) self._workflowCore = WorkflowCore( self.settingsRegistry, self._patternsCore.patternsAPI, @@ -307,7 +311,7 @@ def batchModeExecute( ) if fluorescenceInputFilePath is not None and fluorescenceOutputFilePath is not None: - self._analysisCore.enhanceFluorescence( + self._fluorescenceCore.enhanceFluorescence( outputProductIndex, fluorescenceInputFilePath, fluorescenceOutputFilePath, @@ -361,11 +365,11 @@ def fourierRingCorrelator(self) -> FourierRingCorrelator: @property def fluorescenceEnhancer(self) -> FluorescenceEnhancer: - return self._analysisCore.fluorescenceEnhancer + return self._fluorescenceCore.enhancer @property def fluorescenceVisualizationEngine(self) -> VisualizationEngine: - return self._analysisCore.fluorescenceVisualizationEngine + return self._fluorescenceCore.visualizationEngine @property def xmcdAnalyzer(self) -> XMCDAnalyzer: diff --git a/ptychodus/model/fluorescence/__init__.py b/ptychodus/model/fluorescence/__init__.py new file mode 100644 index 00000000..dee43986 --- /dev/null +++ b/ptychodus/model/fluorescence/__init__.py @@ -0,0 +1,10 @@ +from .core import FluorescenceCore, FluorescenceEnhancer +from .two_step import TwoStepFluorescenceEnhancingAlgorithm +from .vspi import VSPIFluorescenceEnhancingAlgorithm + +__all__ = [ + 'FluorescenceCore', + 'FluorescenceEnhancer', + 'TwoStepFluorescenceEnhancingAlgorithm', + 'VSPIFluorescenceEnhancingAlgorithm', +] diff --git a/ptychodus/model/fluorescence/core.py b/ptychodus/model/fluorescence/core.py new file mode 100644 index 00000000..02c396b6 --- /dev/null +++ b/ptychodus/model/fluorescence/core.py @@ -0,0 +1,223 @@ +from __future__ import annotations +from collections.abc import Sequence +from pathlib import Path +import logging + + +from ptychodus.api.fluorescence import ( + DeconvolutionStrategy, + ElementMap, + FluorescenceDataset, + FluorescenceEnhancingAlgorithm, + FluorescenceFileReader, + FluorescenceFileWriter, + UpscalingStrategy, +) +from ptychodus.api.geometry import PixelGeometry +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.plugins import PluginChooser +from ptychodus.api.settings import SettingsRegistry + +from ..product import ProductRepository, ProductRepositoryItem +from ..visualization import VisualizationEngine +from .settings import FluorescenceSettings +from .two_step import TwoStepFluorescenceEnhancingAlgorithm +from .vspi import VSPIFluorescenceEnhancingAlgorithm + +logger = logging.getLogger(__name__) + + +class FluorescenceEnhancer(Observable, Observer): + def __init__( + self, + settings: FluorescenceSettings, + productRepository: ProductRepository, + twoStepEnhancingAlgorithm: TwoStepFluorescenceEnhancingAlgorithm, + vspiEnhancingAlgorithm: VSPIFluorescenceEnhancingAlgorithm, + fileReaderChooser: PluginChooser[FluorescenceFileReader], + fileWriterChooser: PluginChooser[FluorescenceFileWriter], + reinitObservable: Observable, + ) -> None: + super().__init__() + self._settings = settings + self._productRepository = productRepository + self.twoStepEnhancingAlgorithm = twoStepEnhancingAlgorithm + self.vspiEnhancingAlgorithm = vspiEnhancingAlgorithm + self._fileReaderChooser = fileReaderChooser + self._fileWriterChooser = fileWriterChooser + self._reinitObservable = reinitObservable + + self._algorithmChooser = PluginChooser[FluorescenceEnhancingAlgorithm]() + self._algorithmChooser.registerPlugin( + twoStepEnhancingAlgorithm, + simpleName=TwoStepFluorescenceEnhancingAlgorithm.SIMPLE_NAME, + displayName=TwoStepFluorescenceEnhancingAlgorithm.DISPLAY_NAME, + ) + self._algorithmChooser.registerPlugin( + vspiEnhancingAlgorithm, + simpleName=VSPIFluorescenceEnhancingAlgorithm.SIMPLE_NAME, + displayName=VSPIFluorescenceEnhancingAlgorithm.DISPLAY_NAME, + ) + self._syncAlgorithmFromSettings() + self._algorithmChooser.addObserver(self) + + self._productIndex = -1 + self._measured: FluorescenceDataset | None = None + self._enhanced: FluorescenceDataset | None = None + + fileReaderChooser.setCurrentPluginByName(settings.fileType.getValue()) + fileWriterChooser.setCurrentPluginByName(settings.fileType.getValue()) + reinitObservable.addObserver(self) + + @property + def _product(self) -> ProductRepositoryItem: + return self._productRepository[self._productIndex] + + def setProduct(self, productIndex: int) -> None: + if self._productIndex != productIndex: + self._productIndex = productIndex + self._enhanced = None + self.notifyObservers() + + def getProductName(self) -> str: + return self._product.getName() + + def getPixelGeometry(self) -> PixelGeometry: + return self._product.getGeometry().getPixelGeometry() + + def getOpenFileFilterList(self) -> Sequence[str]: + return self._fileReaderChooser.getDisplayNameList() + + def getOpenFileFilter(self) -> str: + return self._fileReaderChooser.currentPlugin.displayName + + def openMeasuredDataset(self, filePath: Path, fileFilter: str) -> None: + if filePath.is_file(): + self._fileReaderChooser.setCurrentPluginByName(fileFilter) + fileType = self._fileReaderChooser.currentPlugin.simpleName + logger.debug(f'Reading "{filePath}" as "{fileType}"') + fileReader = self._fileReaderChooser.currentPlugin.strategy + + try: + measured = fileReader.read(filePath) + except Exception as exc: + raise RuntimeError(f'Failed to read "{filePath}"') from exc + else: + self._measured = measured + self._enhanced = None + + self._settings.filePath.setValue(filePath) + self._settings.fileType.setValue(fileType) + + self.notifyObservers() + else: + logger.warning(f'Refusing to load dataset from invalid file path "{filePath}"') + + def getNumberOfChannels(self) -> int: + return 0 if self._measured is None else len(self._measured.element_maps) + + def getMeasuredElementMap(self, channelIndex: int) -> ElementMap: + if self._measured is None: + raise ValueError('Fluorescence dataset not loaded!') + + return self._measured.element_maps[channelIndex] + + def getAlgorithmList(self) -> Sequence[str]: + return self._algorithmChooser.getDisplayNameList() + + def getAlgorithm(self) -> str: + return self._algorithmChooser.currentPlugin.displayName + + def setAlgorithm(self, name: str) -> None: + self._algorithmChooser.setCurrentPluginByName(name) + self._settings.algorithm.setValue(self._algorithmChooser.currentPlugin.simpleName) + + def _syncAlgorithmFromSettings(self) -> None: + self.setAlgorithm(self._settings.algorithm.getValue()) + + def enhanceFluorescence(self) -> None: + if self._measured is None: + raise ValueError('Fluorescence dataset not loaded!') + else: + algorithm = self._algorithmChooser.currentPlugin.strategy + product = self._product.getProduct() + self._enhanced = algorithm.enhance(self._measured, product) + self.notifyObservers() + + def getEnhancedElementMap(self, channelIndex: int) -> ElementMap: + if self._enhanced is None: + return self.getMeasuredElementMap(channelIndex) + + return self._enhanced.element_maps[channelIndex] + + def getSaveFileFilterList(self) -> Sequence[str]: + return self._fileWriterChooser.getDisplayNameList() + + def getSaveFileFilter(self) -> str: + return self._fileWriterChooser.currentPlugin.displayName + + def saveEnhancedDataset(self, filePath: Path, fileFilter: str) -> None: + if self._enhanced is None: + raise ValueError('Fluorescence dataset not enhanced!') + + self._fileWriterChooser.setCurrentPluginByName(fileFilter) + fileType = self._fileWriterChooser.currentPlugin.simpleName + logger.debug(f'Writing "{filePath}" as "{fileType}"') + writer = self._fileWriterChooser.currentPlugin.strategy + writer.write(filePath, self._enhanced) + + def _openFluorescenceFileFromSettings(self) -> None: + self.openMeasuredDataset( + self._settings.filePath.getValue(), self._settings.fileType.getValue() + ) + + def update(self, observable: Observable) -> None: + if observable is self._algorithmChooser: + self.notifyObservers() + elif observable is self._reinitObservable: + self._syncAlgorithmFromSettings() + self._openFluorescenceFileFromSettings() + + +class FluorescenceCore: + def __init__( + self, + settingsRegistry: SettingsRegistry, + productRepository: ProductRepository, + upscalingStrategyChooser: PluginChooser[UpscalingStrategy], + deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], + fileReaderChooser: PluginChooser[FluorescenceFileReader], + fileWriterChooser: PluginChooser[FluorescenceFileWriter], + ) -> None: + self._settings = FluorescenceSettings(settingsRegistry) + self._twoStepEnhancingAlgorithm = TwoStepFluorescenceEnhancingAlgorithm( + self._settings, upscalingStrategyChooser, deconvolutionStrategyChooser, settingsRegistry + ) + self._vspiEnhancingAlgorithm = VSPIFluorescenceEnhancingAlgorithm(self._settings) + + self.enhancer = FluorescenceEnhancer( + self._settings, + productRepository, + self._twoStepEnhancingAlgorithm, + self._vspiEnhancingAlgorithm, + fileReaderChooser, + fileWriterChooser, + settingsRegistry, + ) + self.visualizationEngine = VisualizationEngine(isComplex=False) + + def enhanceFluorescence( + self, productIndex: int, inputFilePath: Path, outputFilePath: Path + ) -> int: + fileType = 'XRF-Maps' + + try: + self.enhancer.setProduct(productIndex) + self.enhancer.openMeasuredDataset(inputFilePath, fileType) + self.enhancer.enhanceFluorescence() + self.enhancer.saveEnhancedDataset(outputFilePath, fileType) + except Exception as exc: + logger.exception(exc) + return -1 + + return 0 diff --git a/ptychodus/model/fluorescence/settings.py b/ptychodus/model/fluorescence/settings.py new file mode 100644 index 00000000..71540f6b --- /dev/null +++ b/ptychodus/model/fluorescence/settings.py @@ -0,0 +1,33 @@ +from pathlib import Path + +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.settings import SettingsRegistry + + +class FluorescenceSettings(Observable, Observer): + def __init__(self, registry: SettingsRegistry) -> None: + super().__init__() + self._settingsGroup = registry.createGroup('Fluorescence') + self._settingsGroup.addObserver(self) + + self.filePath = self._settingsGroup.createPathParameter( + 'FilePath', Path('/path/to/dataset.h5') + ) + self.fileType = self._settingsGroup.createStringParameter('FileType', 'XRF-Maps') + self.algorithm = self._settingsGroup.createStringParameter('Algorithm', 'VSPI') + self.vspiDampingFactor = self._settingsGroup.createRealParameter( + 'VSPIDampingFactor', 0.0, minimum=0.0 + ) + self.vspiMaxIterations = self._settingsGroup.createIntegerParameter( + 'VSPIMaxIterations', 100, minimum=1 + ) + self.upscalingStrategy = self._settingsGroup.createStringParameter( + 'UpscalingStrategy', 'Linear' + ) + self.deconvolutionStrategy = self._settingsGroup.createStringParameter( + 'DeconvolutionStrategy', 'Richardson-Lucy' + ) + + def update(self, observable: Observable) -> None: + if observable is self._settingsGroup: + self.notifyObservers() diff --git a/ptychodus/model/fluorescence/two_step.py b/ptychodus/model/fluorescence/two_step.py new file mode 100644 index 00000000..28c08e98 --- /dev/null +++ b/ptychodus/model/fluorescence/two_step.py @@ -0,0 +1,110 @@ +from __future__ import annotations +from collections.abc import Sequence +from typing import Final +import logging +import time + +from ptychodus.api.fluorescence import ( + DeconvolutionStrategy, + ElementMap, + FluorescenceDataset, + FluorescenceEnhancingAlgorithm, + UpscalingStrategy, +) +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.plugins import PluginChooser +from ptychodus.api.product import Product + +from .settings import FluorescenceSettings + +logger = logging.getLogger(__name__) + +__all__ = [ + 'TwoStepFluorescenceEnhancingAlgorithm', +] + + +class TwoStepFluorescenceEnhancingAlgorithm(FluorescenceEnhancingAlgorithm, Observable, Observer): + SIMPLE_NAME: Final[str] = 'TwoStep' + DISPLAY_NAME: Final[str] = 'Upscale and Deconvolve' + + def __init__( + self, + settings: FluorescenceSettings, + upscalingStrategyChooser: PluginChooser[UpscalingStrategy], + deconvolutionStrategyChooser: PluginChooser[DeconvolutionStrategy], + reinitObservable: Observable, + ) -> None: + super().__init__() + self._settings = settings + self._upscalingStrategyChooser = upscalingStrategyChooser + self._deconvolutionStrategyChooser = deconvolutionStrategyChooser + self._reinitObservable = reinitObservable + + self._syncUpscalingStrategyFromSettings() + upscalingStrategyChooser.addObserver(self) + + self._syncDeconvolutionStrategyFromSettings() + deconvolutionStrategyChooser.addObserver(self) + + reinitObservable.addObserver(self) + + def enhance(self, dataset: FluorescenceDataset, product: Product) -> FluorescenceDataset: + upscaler = self._upscalingStrategyChooser.currentPlugin.strategy + deconvolver = self._deconvolutionStrategyChooser.currentPlugin.strategy + element_maps: list[ElementMap] = list() + + for emap in dataset.element_maps: + logger.info(f'Enhancing "{emap.name}"...') + tic = time.perf_counter() + emap_upscaled = upscaler(emap, product) + emap_enhanced = deconvolver(emap_upscaled, product) + toc = time.perf_counter() + logger.info(f'Enhanced "{emap.name}" in {toc - tic:.4f} seconds.') + + element_maps.append(emap_enhanced) + + return FluorescenceDataset( + element_maps=element_maps, + counts_per_second_path=dataset.counts_per_second_path, + channel_names_path=dataset.channel_names_path, + ) + + def getUpscalingStrategyList(self) -> Sequence[str]: + return self._upscalingStrategyChooser.getDisplayNameList() + + def getUpscalingStrategy(self) -> str: + return self._upscalingStrategyChooser.currentPlugin.displayName + + def setUpscalingStrategy(self, name: str) -> None: + self._upscalingStrategyChooser.setCurrentPluginByName(name) + self._settings.upscalingStrategy.setValue( + self._upscalingStrategyChooser.currentPlugin.simpleName + ) + + def _syncUpscalingStrategyFromSettings(self) -> None: + self.setUpscalingStrategy(self._settings.upscalingStrategy.getValue()) + + def getDeconvolutionStrategyList(self) -> Sequence[str]: + return self._deconvolutionStrategyChooser.getDisplayNameList() + + def getDeconvolutionStrategy(self) -> str: + return self._deconvolutionStrategyChooser.currentPlugin.displayName + + def setDeconvolutionStrategy(self, name: str) -> None: + self._deconvolutionStrategyChooser.setCurrentPluginByName(name) + self._settings.deconvolutionStrategy.setValue( + self._deconvolutionStrategyChooser.currentPlugin.simpleName + ) + + def _syncDeconvolutionStrategyFromSettings(self) -> None: + self.setDeconvolutionStrategy(self._settings.deconvolutionStrategy.getValue()) + + def update(self, observable: Observable) -> None: + if observable is self._reinitObservable: + self._syncUpscalingStrategyFromSettings() + self._syncDeconvolutionStrategyFromSettings() + elif observable is self._upscalingStrategyChooser: + self.notifyObservers() + elif observable is self._deconvolutionStrategyChooser: + self.notifyObservers() diff --git a/ptychodus/model/fluorescence/vspi.py b/ptychodus/model/fluorescence/vspi.py new file mode 100644 index 00000000..15247c66 --- /dev/null +++ b/ptychodus/model/fluorescence/vspi.py @@ -0,0 +1,180 @@ +from __future__ import annotations +from typing import Final +import logging +import time + +from scipy.sparse.linalg import lsmr, LinearOperator +import numpy + +from ptychodus.api.fluorescence import ( + ElementMap, + FluorescenceDataset, + FluorescenceEnhancingAlgorithm, +) +from ptychodus.api.object import ObjectPoint +from ptychodus.api.observer import Observable, Observer +from ptychodus.api.product import Product +from ptychodus.api.typing import RealArrayType + +from .settings import FluorescenceSettings + +logger = logging.getLogger(__name__) + +__all__ = [ + 'VSPIFluorescenceEnhancingAlgorithm', +] + + +class ArrayPatchInterpolator: + def __init__(self, array: RealArrayType, point: ObjectPoint, shape: tuple[int, ...]) -> None: + # top left corner of object support + xmin = point.positionXInPixels - shape[-1] / 2 + ymin = point.positionYInPixels - shape[-2] / 2 + + # whole components (pixel indexes) + xmin_wh = int(xmin) + ymin_wh = int(ymin) + + # fractional (subpixel) components + xmin_fr = xmin - xmin_wh + ymin_fr = ymin - ymin_wh + + # bottom right corner of object patch support + xmax_wh = xmin_wh + shape[-1] + 1 + ymax_wh = ymin_wh + shape[-2] + 1 + + # reused quantities + xmin_fr_c = 1.0 - xmin_fr + ymin_fr_c = 1.0 - ymin_fr + + # barycentric interpolant weights + self._weight00 = ymin_fr_c * xmin_fr_c + self._weight01 = ymin_fr_c * xmin_fr + self._weight10 = ymin_fr * xmin_fr_c + self._weight11 = ymin_fr * xmin_fr + + # extract patch support region from full object + self._support = array[ymin_wh:ymax_wh, xmin_wh:xmax_wh] + + def get_patch(self) -> RealArrayType: + """interpolate array support to extract patch""" + patch = self._weight00 * self._support[:-1, :-1] + patch += self._weight01 * self._support[:-1, 1:] + patch += self._weight10 * self._support[1:, :-1] + patch += self._weight11 * self._support[1:, 1:] + return patch + + def accumulate_patch(self, patch: RealArrayType) -> None: + """add patch update to array support""" + self._support[:-1, :-1] += self._weight00 * patch + self._support[:-1, 1:] += self._weight01 * patch + self._support[1:, :-1] += self._weight10 * patch + self._support[1:, 1:] += self._weight11 * patch + + +class VSPILinearOperator(LinearOperator): + def __init__(self, product: Product) -> None: + """ + M: number of XRF positions + N: number of ptychography object pixels + P: number of XRF channels + + A[M,N] * X[N,P] = B[M,P] + """ + object_geometry = product.object_.getGeometry() + M = len(product.scan) + N = object_geometry.heightInPixels * object_geometry.widthInPixels + super().__init__(float, (M, N)) + self._product = product + + def _get_psf(self) -> RealArrayType: + intensity = self._product.probe.getIntensity() + return intensity / intensity.sum() + + def _matvec(self, X: RealArrayType) -> RealArrayType: + object_geometry = self._product.object_.getGeometry() + object_array = X.reshape((object_geometry.heightInPixels, object_geometry.widthInPixels)) + psf = self._get_psf() + AX = numpy.zeros(len(self._product.scan)) + + for index, scan_point in enumerate(self._product.scan): + object_point = object_geometry.mapScanPointToObjectPoint(scan_point) + interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) + AX[index] = numpy.sum(psf * interpolator.get_patch()) + + return AX + + def _rmatvec(self, X: RealArrayType) -> RealArrayType: + object_geometry = self._product.object_.getGeometry() + object_array = numpy.zeros((object_geometry.heightInPixels, object_geometry.widthInPixels)) + psf = self._get_psf() + + for index, scan_point in enumerate(self._product.scan): + object_point = object_geometry.mapScanPointToObjectPoint(scan_point) + interpolator = ArrayPatchInterpolator(object_array, object_point, psf.shape) + interpolator.accumulate_patch(X[index] * psf) + + HX = object_array.flatten() + + return HX + + +class VSPIFluorescenceEnhancingAlgorithm(FluorescenceEnhancingAlgorithm, Observable, Observer): + SIMPLE_NAME: Final[str] = 'VSPI' + DISPLAY_NAME: Final[str] = 'Virtual Single Pixel Imaging' + + def __init__(self, settings: FluorescenceSettings) -> None: + super().__init__() + self._settings = settings + + settings.vspiDampingFactor.addObserver(self) + settings.vspiMaxIterations.addObserver(self) + + def enhance(self, dataset: FluorescenceDataset, product: Product) -> FluorescenceDataset: + object_geometry = product.object_.getGeometry() + e_cps_shape = object_geometry.heightInPixels, object_geometry.widthInPixels + element_maps: list[ElementMap] = list() + A = VSPILinearOperator(product) + + for emap in dataset.element_maps: + logger.info(f'Enhancing "{emap.name}"...') + tic = time.perf_counter() + m_cps = emap.counts_per_second + result = lsmr( + A, + m_cps.flatten(), + damp=self._settings.vspiDampingFactor.getValue(), + maxiter=self._settings.vspiMaxIterations.getValue(), + show=True, + ) + logger.debug(result) + e_cps = result[0].reshape(e_cps_shape) + emap_enhanced = ElementMap(emap.name, e_cps) + toc = time.perf_counter() + logger.info(f'Enhanced "{emap.name}" in {toc - tic:.4f} seconds.') + + element_maps.append(emap_enhanced) + + return FluorescenceDataset( + element_maps=element_maps, + counts_per_second_path=dataset.counts_per_second_path, + channel_names_path=dataset.channel_names_path, + ) + + def getDampingFactor(self) -> float: + return self._settings.vspiDampingFactor.getValue() + + def setDampingFactor(self, factor: float) -> None: + self._settings.vspiDampingFactor.setValue(factor) + + def getMaxIterations(self) -> int: + return self._settings.vspiMaxIterations.getValue() + + def setMaxIterations(self, number: int) -> None: + self._settings.vspiMaxIterations.setValue(number) + + def update(self, observable: Observable) -> None: + if observable is self._settings.vspiDampingFactor: + self.notifyObservers() + elif observable is self._settings.vspiMaxIterations: + self.notifyObservers() diff --git a/ptychodus/view/probe.py b/ptychodus/view/probe.py index 4739359a..ff7cc83d 100644 --- a/ptychodus/view/probe.py +++ b/ptychodus/view/probe.py @@ -13,6 +13,7 @@ QRadioButton, QSlider, QSpinBox, + QStackedWidget, QStatusBar, QVBoxLayout, QWidget, @@ -171,27 +172,49 @@ def __init__(self, parent: QWidget | None = None) -> None: self.setLayout(layout) -class FluorescenceParametersView(QGroupBox): +class FluorescenceVSPIParametersView(QWidget): def __init__(self, parent: QWidget | None = None) -> None: - super().__init__('Parameters', parent) - self.openButton = QPushButton('Open') - self.enhancementStrategyComboBox = QComboBox() - self.vspiDampingFactorLineEdit = DecimalLineEdit.createInstance() - self.vspiMaxIterationsSpinBox = QSpinBox() + super().__init__(parent) + self.dampingFactorLineEdit = DecimalLineEdit.createInstance() + self.maxIterationsSpinBox = QSpinBox() + + layout = QFormLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addRow('Damping Factor:', self.dampingFactorLineEdit) + layout.addRow('Max Iterations:', self.maxIterationsSpinBox) + self.setLayout(layout) + + +class FluorescenceTwoStepParametersView(QWidget): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__(parent) self.upscalingStrategyComboBox = QComboBox() self.deconvolutionStrategyComboBox = QComboBox() - self.enhanceButton = QPushButton('Enhance') - self.saveButton = QPushButton('Save') layout = QFormLayout() - layout.addRow('Measured Dataset:', self.openButton) - layout.addRow('Enhancement Strategy:', self.enhancementStrategyComboBox) - layout.addRow('VSPI Damping Factor:', self.vspiDampingFactorLineEdit) - layout.addRow('VSPI Max Iterations:', self.vspiMaxIterationsSpinBox) + layout.setContentsMargins(0, 0, 0, 0) layout.addRow('Upscaling Strategy:', self.upscalingStrategyComboBox) layout.addRow('Deconvolution Strategy:', self.deconvolutionStrategyComboBox) + self.setLayout(layout) + + +class FluorescenceParametersView(QGroupBox): + def __init__(self, parent: QWidget | None = None) -> None: + super().__init__('Enhancement Strategy', parent) + self.openButton = QPushButton('Open Measured Dataset') + self.algorithmComboBox = QComboBox() + self.stackedWidget = QStackedWidget() + self.enhanceButton = QPushButton('Enhance') + self.saveButton = QPushButton('Save Enhanced Dataset') + + self.stackedWidget.layout().setContentsMargins(0, 0, 0, 0) + + layout = QFormLayout() + layout.addRow(self.openButton) + layout.addRow('Algorithm:', self.algorithmComboBox) + layout.addRow(self.stackedWidget) layout.addRow(self.enhanceButton) - layout.addRow('Enhanced Dataset:', self.saveButton) + layout.addRow(self.saveButton) self.setLayout(layout)