diff --git a/tests/core/test_signal.py b/tests/core/test_signal.py index 242d630b72..1502d58a4d 100644 --- a/tests/core/test_signal.py +++ b/tests/core/test_signal.py @@ -6,7 +6,18 @@ import numpy import pytest +from bluesky import Msg from bluesky.protocols import Reading +from bluesky.run_engine import call_in_bluesky_event_loop +from bluesky.suspenders import ( + SuspendBoolHigh, + SuspendBoolLow, + SuspendCeil, + SuspendFloor, + SuspendInBand, + SuspendOutBand, + SuspendWhenOutsideBand, +) from ophyd_async.core import ( DEFAULT_TIMEOUT, @@ -457,3 +468,58 @@ def some_function(self): assert isinstance((await signal.get_value()), SomeClass) await signal.set(1) assert (await signal.get_value()) == 1 + + +@pytest.mark.parametrize( + "klass,sc_args,start_val,fail_val,resume_val,wait_time", + [ + (SuspendBoolHigh, (), 0, 1, 0, 0.2), + (SuspendBoolLow, (), 1, 0, 1, 0.2), + (SuspendFloor, (0.5,), 1, 0, 1, 0.2), + (SuspendCeil, (0.5,), 0, 1, 0, 0.2), + (SuspendWhenOutsideBand, (0.5, 1.5), 1, 0, 1, 0.2), + ((SuspendInBand, True), (0.5, 1.5), 1, 0, 1, 0.2), # renamed to WhenOutsideBand + ((SuspendOutBand, True), (0.5, 1.5), 0, 1, 0, 0.2), + ], +) # deprecated +async def test_bluesky_suspenders( + klass, sc_args, start_val, fail_val, resume_val, wait_time, RE +): + sleep_time = 0.2 + fail_time = 0.1 + resume_time = 0.5 + signal = epics_signal_rw(int, "mock_signal") + await signal.connect(mock=True) + try: + klass, deprecated = klass + except TypeError: + deprecated = False + if deprecated: + with pytest.warns(UserWarning): + suspender = klass(signal, *sc_args, sleep=wait_time) + else: + suspender = klass(signal, *sc_args, sleep=wait_time) + + RE.install_suspender(suspender) + + await signal.set(start_val) + + async def _set_after_time(): + await asyncio.sleep(fail_time) + await signal.set(fail_val) + await asyncio.sleep(resume_time - fail_time) + await signal.set(resume_val) + + start = time.time() + + # loop = RE.loop + + call_in_bluesky_event_loop(_set_after_time()) + # task = RE.loop.create_task(_set_after_time()) + + RE([Msg("checkpoint"), Msg("sleep", None, sleep_time)]) + + stop = time.time() + delta = stop - start + # await task + assert delta >= resume_time + sleep_time + wait_time