Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial stab at DerivedSignal #525 #661

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from ._derived_signal import (
DerivedBackend,
DerivedSignalBackend,
Transform,
TransformArgument,
)
from ._detector import (
DetectorController,
DetectorTrigger,
Expand Down Expand Up @@ -99,6 +105,10 @@
)

__all__ = [
"DerivedBackend",
"DerivedSignalBackend",
"Transform",
"TransformArgument",
"DetectorController",
"DetectorTrigger",
"DetectorWriter",
Expand Down
112 changes: 112 additions & 0 deletions src/ophyd_async/core/_derived_signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import asyncio
from abc import abstractmethod
from typing import Generic, TypedDict, TypeVar, get_args

from ._device import Device
from ._protocol import AsyncMovable
from ._signal import SignalR, SignalRW
from ._signal_backend import SignalBackend, SignalDatatypeT


class TransformArgument(TypedDict, Generic[SignalDatatypeT]):
pass


T = TypeVar("T", bound=TransformArgument)


async def _get_dataclass_from_signals(cls: type[T], device: Device) -> T:
coros = {}
for name in cls.__annotations__:
signal = getattr(device, name)
assert isinstance(
signal, SignalR
), f"{device.name}.{name} is {signal}, not a Signal"
coros[name] = signal.get_value()
results = await asyncio.gather(*coros.values())
kwargs = dict(zip(coros, results, strict=True))
return cls(**kwargs)


RawT = TypeVar("RawT", bound=TransformArgument)
DerivedT = TypeVar("DerivedT", bound=TransformArgument)
ParametersT = TypeVar("ParametersT")


class TransformMeta(type):
def __init__(cls, *_):
if "__orig_bases__" not in cls.__dict__:
raise TypeError(
"Transform classes must be defined with Raw, "
"Derived, and Parameter `TransformArgument`s."
)
orig_base = cls.__orig_bases__[0] # type: ignore
cls.raw_cls, cls.derived_cls, cls.parameters_cls = get_args(orig_base)


class Transform(Generic[RawT, DerivedT, ParametersT], metaclass=TransformMeta):
raw_cls: type[RawT]
derived_cls: type[DerivedT]
parameters_cls: type[ParametersT]

@classmethod
@abstractmethod
def forward(cls, raw: RawT, parameters: ParametersT) -> DerivedT:
pass

@classmethod
@abstractmethod
def inverse(cls, derived: DerivedT, parameters: ParametersT) -> RawT:
pass


class DerivedBackend(Generic[RawT, DerivedT, ParametersT]):
def __init__(
self,
device: Device,
transform: Transform[RawT, DerivedT, ParametersT],
):
self._device = device
self._transform = transform

async def get_parameters(self) -> ParametersT:
return await _get_dataclass_from_signals(
self._transform.parameters_cls, self._device
)

async def get_raw_values(self) -> RawT:
return await _get_dataclass_from_signals(self._transform.raw_cls, self._device)

async def get_derived_values(self) -> DerivedT:
raw, parameters = await asyncio.gather(
self.get_raw_values(), self.get_parameters()
)
return self._transform.forward(raw, parameters)

async def set_derived_values(self, derived: DerivedT):
assert isinstance(self._device, AsyncMovable)
await self._device.set(derived)

async def calculate_raw_values(self, derived: DerivedT) -> RawT:
return self._transform.inverse(derived, await self.get_parameters())

def derived_signal(self, variable: str) -> SignalRW:
return SignalRW(DerivedSignalBackend(self, variable))


class DerivedSignalBackend(SignalBackend[float]):
def __init__(self, backend: DerivedBackend, transform_name: str):
self._backend = backend
self._transform_name = transform_name
super().__init__(float)

async def get_value(self) -> float:
derived = await self._backend.get_derived_values()
return getattr(derived, self._transform_name)

async def put(self, value: float | None, wait: bool):
derived = await self._backend.get_derived_values()
# TODO: we should be calling locate on these as we want to move relative to the
# setpoint, not readback
setattr(derived, self._transform_name, value)
await self._backend.set_derived_values(derived)
8 changes: 8 additions & 0 deletions src/ophyd_async/core/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ def unstage(self) -> AsyncStatus:
"""


@runtime_checkable
class AsyncMovable(Protocol):
@abstractmethod
def set(self, value) -> AsyncStatus:
"""Return a ``Status`` that is marked done when the device is done moving."""
...


C = TypeVar("C", contravariant=True)


Expand Down
109 changes: 109 additions & 0 deletions tests/core/test_derived_signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import asyncio
from typing import TypeVar

import numpy as np
import pytest

from ophyd_async.core import (
Array1D,
AsyncStatus,
DerivedBackend,
Device,
Transform,
TransformArgument,
soft_signal_rw,
)


async def test_transform_argument_cls_inference():
class Raw(TransformArgument[float]): ...

class Derived(TransformArgument[float]): ...

with pytest.raises(
TypeError,
match=(
"Too few arguments for "
"<class 'ophyd_async.core._derived_signal.Transform'>; "
"actual 2, expected at least 3"
),
):

class SomeTransform1(Transform[Raw, Derived]): ... # type: ignore

with pytest.raises(
TypeError,
match=(
"Transform classes must be defined with Raw, Derived, "
"and Parameter `TransformArgument`s."
),
):

class SomeTransform2(Transform): ...

class Parameters(TransformArgument[float]): ...

class SomeTransform(Transform[Raw, Derived, Parameters]): ...

assert SomeTransform.raw_cls is Raw
assert SomeTransform.derived_cls is Derived
assert SomeTransform.parameters_cls is Parameters


F = TypeVar("F", float, Array1D[np.float64])


class SlitsRaw(TransformArgument[F]):
top: F
bottom: F


class SlitsDerived(TransformArgument[F]):
gap: F
centre: F


class SlitsParameters(TransformArgument):
gap_offset: float


class SlitsTransform(Transform[SlitsRaw[F], SlitsDerived[F], SlitsParameters]):
@classmethod
def forward(cls, raw: SlitsRaw[F], parameters: SlitsParameters) -> SlitsDerived[F]:
return SlitsDerived(
gap=raw["top"] - raw["bottom"] + parameters["gap_offset"],
centre=(raw["top"] + raw["bottom"]) / 2,
)

@classmethod
def inverse(
cls, derived: SlitsDerived[F], parameters: SlitsParameters
) -> SlitsRaw[F]:
half_gap = (derived["gap"] - parameters["gap_offset"]) / 2
return SlitsRaw(
top=derived["centre"] + half_gap,
bottom=derived["centre"] - half_gap,
)


class Slits(Device):
def __init__(self, name=""):
self._backend = DerivedBackend(self, SlitsTransform())
# Raw signals
self.top = soft_signal_rw(float)
self.bottom = soft_signal_rw(float)
# Parameter
self.gap_offset = soft_signal_rw(float)
# Derived signals
self.gap = self._backend.derived_signal("gap")
self.centre = self._backend.derived_signal("centre")
super().__init__(name=name)

@AsyncStatus.wrap
async def set(self, derived: SlitsDerived[float]) -> None:
raw: SlitsRaw[float] = await self._backend.calculate_raw_values(derived)
await asyncio.gather(self.top.set(raw["top"]), self.bottom.set(raw["bottom"]))


async def test_derived_signals():
Slits()
Loading