diff --git a/ptychodus/api/plot.py b/ptychodus/api/plot.py index e53301dd..6ad09d25 100644 --- a/ptychodus/api/plot.py +++ b/ptychodus/api/plot.py @@ -18,6 +18,12 @@ class PlotSeries: label: str values: Sequence[float] +@dataclass(frozen=True) +class PlotUncertainSeries: + label: str + lo: Sequence[float] + values: Sequence[float] + hi: Sequence[float] @dataclass(frozen=True) class PlotAxis: @@ -29,6 +35,16 @@ def createNull(cls) -> PlotAxis: return cls('', []) +@dataclass(frozen=True) +class PlotUncertainAxis: + label: str + series: Sequence[PlotUncertainSeries] + + @classmethod + def createNull(cls) -> PlotUncertainAxis: + return cls('', []) + + @dataclass(frozen=True) class Plot2D: axisX: PlotAxis @@ -39,6 +55,16 @@ def createNull(cls) -> Plot2D: return cls(PlotAxis.createNull(), PlotAxis.createNull()) +@dataclass(frozen=True) +class PlotUncertain2D: + axisX: PlotAxis + axisY: PlotUncertainAxis + + @classmethod + def createNull(cls) -> PlotUncertain2D: + return cls(PlotAxis.createNull(), PlotUncertainAxis.createNull()) + + @dataclass(frozen=True) class LineCut: distanceInMeters: Sequence[float] diff --git a/ptychodus/api/reconstructor.py b/ptychodus/api/reconstructor.py index 95fd56d7..defddd6b 100644 --- a/ptychodus/api/reconstructor.py +++ b/ptychodus/api/reconstructor.py @@ -7,7 +7,7 @@ from .data import DiffractionPatternArrayType from .image import ImageExtent from .object import ObjectArrayType, ObjectInterpolator -from .plot import Plot2D +from .plot import Plot2D, PlotUncertain2D from .probe import ProbeArrayType from .scan import Scan @@ -51,7 +51,7 @@ class ReconstructOutput: probeArray: ProbeArrayType | None objectArray: ObjectArrayType | None objective: Sequence[Sequence[float]] - plot2D: Plot2D + plot2D: Plot2D | PlotUncertain2D result: int @classmethod diff --git a/ptychodus/controller/reconstructor.py b/ptychodus/controller/reconstructor.py index a2216634..4351069f 100644 --- a/ptychodus/controller/reconstructor.py +++ b/ptychodus/controller/reconstructor.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable import logging +import itertools from PyQt5.QtCore import Qt, QStringListModel from PyQt5.QtGui import QPixmap @@ -208,14 +209,14 @@ def _redrawPlot(self) -> None: ax.set_ylabel(axisY.label) ax.grid(True) - if len(axisX.series) == len(axisY.series): - for sx, sy in zip(axisX.series, axisY.series): - ax.plot(sx.values, sy.values, '.-', label=sy.label, linewidth=1.5) - elif len(axisX.series) == 1: - sx = axisX.series[0] - - for sy in axisY.series: + if ( + (len(axisX.series) == len(axisY.series)) or + (len(axisX.series) == 1 and len(axisY.series) >= 1) + ): + for sx, sy in zip(itertools.cycle(axisX.series), axisY.series): ax.plot(sx.values, sy.values, '.-', label=sy.label, linewidth=1.5) + if hasattr(sy, 'hi') and hasattr(sy, 'lo'): + ax.fill_between(sx.values, sy.lo, sy.hi, alpha=0.2) else: logger.error('Failed to broadcast plot series!') diff --git a/ptychodus/model/reconstructor/core.py b/ptychodus/model/reconstructor/core.py index f94ef300..2f9560ff 100644 --- a/ptychodus/model/reconstructor/core.py +++ b/ptychodus/model/reconstructor/core.py @@ -4,7 +4,7 @@ import logging from ...api.observer import Observable, Observer -from ...api.plot import Plot2D, PlotAxis, PlotSeries +from ...api.plot import Plot2D, PlotAxis, PlotSeries, PlotUncertain2D from ...api.reconstructor import ReconstructorLibrary from ...api.settings import SettingsRegistry from ..data import ActiveDiffractionDataset @@ -66,7 +66,7 @@ def reconstructSplit(self) -> None: ) self.notifyObservers() - def getPlot(self) -> Plot2D: + def getPlot(self) -> Plot2D | PlotUncertain2D: return self._plot2D @property diff --git a/ptychodus/model/tike/reconstructor.py b/ptychodus/model/tike/reconstructor.py index e0f57a73..1a9f2d51 100644 --- a/ptychodus/model/tike/reconstructor.py +++ b/ptychodus/model/tike/reconstructor.py @@ -11,7 +11,7 @@ from ...api.object import ObjectArrayType from ...api.object import ObjectPoint -from ...api.plot import Plot2D, PlotAxis, PlotSeries +from ...api.plot import PlotUncertain2D, PlotUncertainAxis, PlotUncertainSeries, PlotAxis, PlotSeries from ...api.probe import ProbeArrayType from ...api.reconstructor import Reconstructor, ReconstructInput, ReconstructOutput from ...api.scan import Scan, ScanPoint, TabularScan @@ -107,8 +107,8 @@ def getNumGpus(self) -> Union[int, tuple[int, ...]]: return 1 - def _plotCosts(self, costs: Sequence[Sequence[float]]) -> Plot2D: - plot = Plot2D.createNull() + def _plotCosts(self, costs: Sequence[Sequence[float]]) -> PlotUncertain2D: + plot = PlotUncertain2D.createNull() numIterations = len(costs) if numIterations > 0: @@ -117,7 +117,7 @@ def _plotCosts(self, costs: Sequence[Sequence[float]]) -> Plot2D: midCost: list[float] = list() maxCost: list[float] = list() - seriesYList: list[PlotSeries] = list() + seriesYList: list[PlotUncertainSeries] = list() for values in costs: minCost.append(min(values)) @@ -125,14 +125,17 @@ def _plotCosts(self, costs: Sequence[Sequence[float]]) -> Plot2D: maxCost.append(max(values)) seriesYList = [ - PlotSeries(label='Minimum', values=minCost), - PlotSeries(label='Median', values=midCost), - PlotSeries(label='Maximum', values=maxCost), + PlotUncertainSeries( + label='Median', + lo=minCost, + values=midCost, + hi=maxCost, + ), ] - plot = Plot2D( + plot = PlotUncertain2D( axisX=PlotAxis(label='Iteration', series=[seriesX]), - axisY=PlotAxis(label='Cost', series=seriesYList), + axisY=PlotUncertainAxis(label='Cost', series=seriesYList), ) return plot