Skip to content

Commit

Permalink
NEW: Use matplotlib fill_between to show object function uncertainty …
Browse files Browse the repository at this point in the history
…for multi-batch reconstructions (#70)

* NEW: Add UncertainPlot for lines with confidence

* NEW: Draw confidence band if plot is uncertain

* REF: Modify TikeReconstructor to return PlotUncertain2D
  • Loading branch information
carterbox authored Jan 26, 2024
1 parent d8c2c32 commit 082417d
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 20 deletions.
26 changes: 26 additions & 0 deletions ptychodus/api/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ class PlotSeries:
label: str
values: Sequence[float]

@dataclass(frozen=True)
class PlotUncertainSeries:
label: str
lo: Sequence[float]
values: Sequence[float]
hi: Sequence[float]

@dataclass(frozen=True)
class PlotAxis:
Expand All @@ -29,6 +35,16 @@ def createNull(cls) -> PlotAxis:
return cls('', [])


@dataclass(frozen=True)
class PlotUncertainAxis:
label: str
series: Sequence[PlotUncertainSeries]

@classmethod
def createNull(cls) -> PlotUncertainAxis:
return cls('', [])


@dataclass(frozen=True)
class Plot2D:
axisX: PlotAxis
Expand All @@ -39,6 +55,16 @@ def createNull(cls) -> Plot2D:
return cls(PlotAxis.createNull(), PlotAxis.createNull())


@dataclass(frozen=True)
class PlotUncertain2D:
axisX: PlotAxis
axisY: PlotUncertainAxis

@classmethod
def createNull(cls) -> PlotUncertain2D:
return cls(PlotAxis.createNull(), PlotUncertainAxis.createNull())


@dataclass(frozen=True)
class LineCut:
distanceInMeters: Sequence[float]
Expand Down
4 changes: 2 additions & 2 deletions ptychodus/api/reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .data import DiffractionPatternArrayType
from .image import ImageExtent
from .object import ObjectArrayType, ObjectInterpolator
from .plot import Plot2D
from .plot import Plot2D, PlotUncertain2D
from .probe import ProbeArrayType
from .scan import Scan

Expand Down Expand Up @@ -51,7 +51,7 @@ class ReconstructOutput:
probeArray: ProbeArrayType | None
objectArray: ObjectArrayType | None
objective: Sequence[Sequence[float]]
plot2D: Plot2D
plot2D: Plot2D | PlotUncertain2D
result: int

@classmethod
Expand Down
15 changes: 8 additions & 7 deletions ptychodus/controller/reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterable
import logging
import itertools

from PyQt5.QtCore import Qt, QStringListModel
from PyQt5.QtGui import QPixmap
Expand Down Expand Up @@ -208,14 +209,14 @@ def _redrawPlot(self) -> None:
ax.set_ylabel(axisY.label)
ax.grid(True)

if len(axisX.series) == len(axisY.series):
for sx, sy in zip(axisX.series, axisY.series):
ax.plot(sx.values, sy.values, '.-', label=sy.label, linewidth=1.5)
elif len(axisX.series) == 1:
sx = axisX.series[0]

for sy in axisY.series:
if (
(len(axisX.series) == len(axisY.series)) or
(len(axisX.series) == 1 and len(axisY.series) >= 1)
):
for sx, sy in zip(itertools.cycle(axisX.series), axisY.series):
ax.plot(sx.values, sy.values, '.-', label=sy.label, linewidth=1.5)
if hasattr(sy, 'hi') and hasattr(sy, 'lo'):
ax.fill_between(sx.values, sy.lo, sy.hi, alpha=0.2)
else:
logger.error('Failed to broadcast plot series!')

Expand Down
4 changes: 2 additions & 2 deletions ptychodus/model/reconstructor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging

from ...api.observer import Observable, Observer
from ...api.plot import Plot2D, PlotAxis, PlotSeries
from ...api.plot import Plot2D, PlotAxis, PlotSeries, PlotUncertain2D
from ...api.reconstructor import ReconstructorLibrary
from ...api.settings import SettingsRegistry
from ..data import ActiveDiffractionDataset
Expand Down Expand Up @@ -66,7 +66,7 @@ def reconstructSplit(self) -> None:
)
self.notifyObservers()

def getPlot(self) -> Plot2D:
def getPlot(self) -> Plot2D | PlotUncertain2D:
return self._plot2D

@property
Expand Down
21 changes: 12 additions & 9 deletions ptychodus/model/tike/reconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ...api.object import ObjectArrayType
from ...api.object import ObjectPoint
from ...api.plot import Plot2D, PlotAxis, PlotSeries
from ...api.plot import PlotUncertain2D, PlotUncertainAxis, PlotUncertainSeries, PlotAxis, PlotSeries
from ...api.probe import ProbeArrayType
from ...api.reconstructor import Reconstructor, ReconstructInput, ReconstructOutput
from ...api.scan import Scan, ScanPoint, TabularScan
Expand Down Expand Up @@ -107,8 +107,8 @@ def getNumGpus(self) -> Union[int, tuple[int, ...]]:

return 1

def _plotCosts(self, costs: Sequence[Sequence[float]]) -> Plot2D:
plot = Plot2D.createNull()
def _plotCosts(self, costs: Sequence[Sequence[float]]) -> PlotUncertain2D:
plot = PlotUncertain2D.createNull()
numIterations = len(costs)

if numIterations > 0:
Expand All @@ -117,22 +117,25 @@ def _plotCosts(self, costs: Sequence[Sequence[float]]) -> Plot2D:
midCost: list[float] = list()
maxCost: list[float] = list()

seriesYList: list[PlotSeries] = list()
seriesYList: list[PlotUncertainSeries] = list()

for values in costs:
minCost.append(min(values))
midCost.append(float(numpy.median(values)))
maxCost.append(max(values))

seriesYList = [
PlotSeries(label='Minimum', values=minCost),
PlotSeries(label='Median', values=midCost),
PlotSeries(label='Maximum', values=maxCost),
PlotUncertainSeries(
label='Median',
lo=minCost,
values=midCost,
hi=maxCost,
),
]

plot = Plot2D(
plot = PlotUncertain2D(
axisX=PlotAxis(label='Iteration', series=[seriesX]),
axisY=PlotAxis(label='Cost', series=seriesYList),
axisY=PlotUncertainAxis(label='Cost', series=seriesYList),
)

return plot
Expand Down

0 comments on commit 082417d

Please sign in to comment.