diff --git a/src/dodal/devices/focusing_mirror.py b/src/dodal/devices/focusing_mirror.py index c9129db828..50dd3c1b79 100644 --- a/src/dodal/devices/focusing_mirror.py +++ b/src/dodal/devices/focusing_mirror.py @@ -1,12 +1,16 @@ -from enum import Enum, IntEnum -from typing import Any - -from ophyd import Component, Device, EpicsSignal -from ophyd.status import Status, StatusBase -from ophyd_async.core import StandardReadable +from enum import Enum + +from ophyd_async.core import ( + AsyncStatus, + Device, + DeviceVector, + StandardReadable, + observe_value, +) from ophyd_async.core.signal import soft_signal_r_and_setter from ophyd_async.epics.motion import Motor from ophyd_async.epics.signal import ( + epics_signal_r, epics_signal_rw, epics_signal_x, ) @@ -32,11 +36,11 @@ class MirrorStripe(str, Enum): PLATINUM = "Platinum" -class MirrorVoltageDemand(IntEnum): - N_A = 0 - OK = 1 - FAIL = 2 - SLEW = 3 +class MirrorVoltageDemand(str, Enum): + N_A = "N/A" + OK = "OK" + FAIL = "FAIL" + SLEW = "SLEW" class MirrorVoltageDevice(Device): @@ -44,14 +48,16 @@ class MirrorVoltageDevice(Device): the demanded voltage setpoint is accepted, without blocking the caller as this process can take significant time. """ - _actual_v: EpicsSignal = Component(EpicsSignal, "R") - _setpoint_v: EpicsSignal = Component(EpicsSignal, "D") - _demand_accepted: EpicsSignal = Component(EpicsSignal, "DSEV") + def __init__(self, name: str = "", prefix: str = ""): + self._actual_v = epics_signal_r(int, prefix + "R") + self._setpoint_v = epics_signal_rw(int, prefix + "D") + self._demand_accepted = epics_signal_r(MirrorVoltageDemand, prefix + "DSEV") + super().__init__(name=name) - def set(self, value, *args, **kwargs) -> StatusBase: + @AsyncStatus.wrap + async def set(self, value, *args, **kwargs): """Combine the following operations into a single set: 1. apply the value to the setpoint PV - 2. Return to the caller with a Status future 3. Wait until demand is accepted 4. When either demand is accepted or DEFAULT_SETTLE_TIME expires, signal the result on the Status """ @@ -59,66 +65,60 @@ def set(self, value, *args, **kwargs) -> StatusBase: setpoint_v = self._setpoint_v demand_accepted = self._demand_accepted - if demand_accepted.get() != MirrorVoltageDemand.OK: + if await demand_accepted.get_value() != MirrorVoltageDemand.OK: raise AssertionError( f"Attempted to set {setpoint_v.name} when demand is not accepted." ) - if setpoint_v.get() == value: + if await setpoint_v.get_value() == value: LOGGER.debug(f"{setpoint_v.name} already at {value} - skipping set") - return Status(success=True, done=True) + return LOGGER.debug(f"setting {setpoint_v.name} to {value}") - demand_accepted_status = Status(self, DEFAULT_SETTLE_TIME_S) - - subscription: dict[str, Any] = {"handle": None} - def demand_check_callback(old_value, value, **kwargs): - LOGGER.debug(f"Got event old={old_value} new={value} for {setpoint_v.name}") - if old_value != MirrorVoltageDemand.OK and value == MirrorVoltageDemand.OK: - LOGGER.debug(f"Demand accepted for {setpoint_v.name}") - subs_handle = subscription.pop("handle", None) - if subs_handle is None: - raise AssertionError("Demand accepted before set attempted") - demand_accepted.unsubscribe(subs_handle) - - demand_accepted_status.set_finished() - # else timeout handled by parent demand_accepted_status + # Register an observer up front to ensure we don't miss events after we + # perform the set + demand_accepted_iterator = observe_value( + demand_accepted, timeout=DEFAULT_SETTLE_TIME_S + ) + # discard the current value (OK) so we can await a subsequent change + await anext(demand_accepted_iterator) + await setpoint_v.set(value) + + # The set should always change to SLEW regardless of whether we are + # already at the set point, then change back to OK/FAIL depending on + # success + accepted_value = await anext(demand_accepted_iterator) + assert accepted_value == MirrorVoltageDemand.SLEW + LOGGER.debug( + f"Demand not accepted for {setpoint_v.name}, waiting for acceptance..." + ) + while MirrorVoltageDemand.SLEW == ( + accepted_value := await anext(demand_accepted_iterator) + ): + pass - subscription["handle"] = demand_accepted.subscribe(demand_check_callback) - setpoint_status = setpoint_v.set(value) - status = setpoint_status & demand_accepted_status - return status + if accepted_value != MirrorVoltageDemand.OK: + raise AssertionError( + f"Voltage slew failed for {setpoint_v.name}, new state={accepted_value}" + ) -class VFMMirrorVoltages(Device): - def __init__(self, *args, daq_configuration_path: str, **kwargs): - super().__init__(*args, **kwargs) +class VFMMirrorVoltages(StandardReadable): + def __init__( + self, name: str, prefix: str, *args, daq_configuration_path: str, **kwargs + ): self.voltage_lookup_table_path = ( daq_configuration_path + "/json/mirrorFocus.json" ) - - _channel14_voltage_device = Component(MirrorVoltageDevice, "BM:V14") - _channel15_voltage_device = Component(MirrorVoltageDevice, "BM:V15") - _channel16_voltage_device = Component(MirrorVoltageDevice, "BM:V16") - _channel17_voltage_device = Component(MirrorVoltageDevice, "BM:V17") - _channel18_voltage_device = Component(MirrorVoltageDevice, "BM:V18") - _channel19_voltage_device = Component(MirrorVoltageDevice, "BM:V19") - _channel20_voltage_device = Component(MirrorVoltageDevice, "BM:V20") - _channel21_voltage_device = Component(MirrorVoltageDevice, "BM:V21") - - @property - def voltage_channels(self) -> list[MirrorVoltageDevice]: - return [ - self._channel14_voltage_device, - self._channel15_voltage_device, - self._channel16_voltage_device, - self._channel17_voltage_device, - self._channel18_voltage_device, - self._channel19_voltage_device, - self._channel20_voltage_device, - self._channel21_voltage_device, - ] + with self.add_children_as_readables(): + self.voltage_channels = DeviceVector( + { + i - 14: MirrorVoltageDevice(prefix=f"{prefix}BM:V{i}") + for i in range(14, 22) + } + ) + super().__init__(*args, name=name, **kwargs) class FocusingMirror(StandardReadable): diff --git a/tests/conftest.py b/tests/conftest.py index 339b7c01b7..92a855bc0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,12 +6,11 @@ import time from os import environ, getenv from pathlib import Path -from typing import Mapping, cast +from typing import Mapping from unittest.mock import MagicMock, patch import pytest from bluesky.run_engine import RunEngine -from ophyd.sim import make_fake_device from dodal.beamlines import i03 from dodal.common.beamlines import beamline_utils @@ -73,17 +72,11 @@ def pytest_runtest_teardown(): @pytest.fixture -def vfm_mirror_voltages() -> VFMMirrorVoltages: - voltages = cast( - VFMMirrorVoltages, - make_fake_device(VFMMirrorVoltages)( - name="vfm_mirror_voltages", - prefix="BL-I03-MO-PSU-01:", - daq_configuration_path=i03.DAQ_CONFIGURATION_PATH, - ), - ) +def vfm_mirror_voltages(RE: RunEngine) -> VFMMirrorVoltages: + voltages = i03.vfm_mirror_voltages(fake_with_ophyd_sim=True) voltages.voltage_lookup_table_path = "tests/test_data/test_mirror_focus.json" - return voltages + yield voltages + beamline_utils.clear_devices() s03_epics_server_port = getenv("S03_EPICS_CA_SERVER_PORT") diff --git a/tests/devices/unit_tests/test_focusing_mirror.py b/tests/devices/unit_tests/test_focusing_mirror.py index d13ce98184..c10180e7b4 100644 --- a/tests/devices/unit_tests/test_focusing_mirror.py +++ b/tests/devices/unit_tests/test_focusing_mirror.py @@ -1,9 +1,15 @@ -from threading import Timer -from unittest.mock import DEFAULT, MagicMock, patch +import asyncio + +# prevent python 3.10 exception doppelganger stupidity +# see https://docs.python.org/3.10/library/asyncio-exceptions.html +# https://github.com/python/cpython/issues?q=is%3Aissue+timeouterror++alias+ +from asyncio import TimeoutError +from unittest.mock import DEFAULT, patch import pytest -from ophyd.sim import NullStatus -from ophyd.status import Status, StatusBase +from bluesky import FailedStatus, RunEngine +from bluesky import plan_stubs as bps +from ophyd_async.core import get_mock_put, set_mock_value from dodal.devices.focusing_mirror import ( FocusingMirrorWithStripes, @@ -12,121 +18,242 @@ MirrorVoltageDevice, VFMMirrorVoltages, ) +from dodal.log import LOGGER @pytest.fixture def vfm_mirror_voltages_not_ok(vfm_mirror_voltages) -> VFMMirrorVoltages: - vfm_mirror_voltages._channel14_voltage_device._demand_accepted.sim_put( - MirrorVoltageDemand.FAIL + set_mock_value( + vfm_mirror_voltages.voltage_channels[0]._demand_accepted, + MirrorVoltageDemand.FAIL, ) return vfm_mirror_voltages @pytest.fixture def vfm_mirror_voltages_with_set(vfm_mirror_voltages) -> VFMMirrorVoltages: - def not_ok_then_ok(_): - vfm_mirror_voltages._channel14_voltage_device._demand_accepted.sim_put( - MirrorVoltageDemand.SLEW + return vfm_mirror_voltages_with_set_to_value( + vfm_mirror_voltages, MirrorVoltageDemand.OK + ) + + +@pytest.fixture +def vfm_mirror_voltages_with_set_multiple_spins( + vfm_mirror_voltages, +) -> VFMMirrorVoltages: + return vfm_mirror_voltages_with_set_to_value( + vfm_mirror_voltages, MirrorVoltageDemand.OK, 3 + ) + + +@pytest.fixture +def vfm_mirror_voltages_with_set_accepted_fail( + vfm_mirror_voltages, +) -> VFMMirrorVoltages: + return vfm_mirror_voltages_with_set_to_value( + vfm_mirror_voltages, MirrorVoltageDemand.FAIL + ) + + +def vfm_mirror_voltages_with_set_to_value( + vfm_mirror_voltages, new_value: MirrorVoltageDemand, spins: int = 0 +) -> VFMMirrorVoltages: + async def set_demand_accepted_after_delay(): + await asyncio.sleep(0.1) + nonlocal spins + if spins > 0: + set_mock_value( + vfm_mirror_voltages.voltage_channels[0]._demand_accepted, + MirrorVoltageDemand.SLEW, + ) + spins -= 1 + asyncio.create_task(set_demand_accepted_after_delay()) + else: + set_mock_value( + vfm_mirror_voltages.voltage_channels[0]._demand_accepted, + new_value, + ) + LOGGER.debug("DEMAND ACCEPTED OK") + + def not_ok_then_other_value(*args, **kwargs): + set_mock_value( + vfm_mirror_voltages.voltage_channels[0]._demand_accepted, + MirrorVoltageDemand.SLEW, ) - Timer( - 0.1, - lambda: vfm_mirror_voltages._channel14_voltage_device._demand_accepted.sim_put( - MirrorVoltageDemand.OK - ), - ).start() + asyncio.create_task(set_demand_accepted_after_delay()) return DEFAULT - vfm_mirror_voltages._channel14_voltage_device._setpoint_v.set = MagicMock( - side_effect=not_ok_then_ok - ) - vfm_mirror_voltages._channel14_voltage_device._demand_accepted.sim_put( - MirrorVoltageDemand.OK + get_mock_put( + vfm_mirror_voltages.voltage_channels[0]._setpoint_v + ).side_effect = not_ok_then_other_value + set_mock_value( + vfm_mirror_voltages.voltage_channels[0]._demand_accepted, + MirrorVoltageDemand.OK, ) return vfm_mirror_voltages @pytest.fixture def vfm_mirror_voltages_with_set_timing_out(vfm_mirror_voltages) -> VFMMirrorVoltages: - def not_ok(_): - vfm_mirror_voltages._channel14_voltage_device._demand_accepted.sim_put( - MirrorVoltageDemand.SLEW + def not_ok(*args, **kwargs): + set_mock_value( + vfm_mirror_voltages.voltage_channels[0]._demand_accepted, + MirrorVoltageDemand.SLEW, ) return DEFAULT - vfm_mirror_voltages._channel14_voltage_device._setpoint_v.set = MagicMock( - side_effect=not_ok - ) - vfm_mirror_voltages._channel14_voltage_device._demand_accepted.sim_put( - MirrorVoltageDemand.OK + get_mock_put( + vfm_mirror_voltages.voltage_channels[0]._setpoint_v + ).side_effect = not_ok + set_mock_value( + vfm_mirror_voltages.voltage_channels[0]._demand_accepted, + MirrorVoltageDemand.OK, ) return vfm_mirror_voltages def test_mirror_set_voltage_sets_and_waits_happy_path( + RE: RunEngine, vfm_mirror_voltages_with_set: VFMMirrorVoltages, ): - vfm_mirror_voltages_with_set._channel14_voltage_device._setpoint_v.set.return_value = NullStatus() - vfm_mirror_voltages_with_set._channel14_voltage_device._demand_accepted.sim_put( - MirrorVoltageDemand.OK + async def completed(): + pass + + mock_put = get_mock_put( + vfm_mirror_voltages_with_set.voltage_channels[0]._setpoint_v + ) + mock_put.return_value = completed() + set_mock_value( + vfm_mirror_voltages_with_set.voltage_channels[0]._demand_accepted, + MirrorVoltageDemand.OK, ) - status: StatusBase = vfm_mirror_voltages_with_set.voltage_channels[0].set(100) - status.wait() - vfm_mirror_voltages_with_set._channel14_voltage_device._setpoint_v.set.assert_called_with( - 100 + def plan(): + yield from bps.abs_set( + vfm_mirror_voltages_with_set.voltage_channels[0], 100, wait=True + ) + + RE(plan()) + + mock_put.assert_called_with(100, wait=True, timeout=10.0) + + +def test_mirror_set_voltage_sets_and_waits_happy_path_spin_while_waiting_for_slew( + RE: RunEngine, + vfm_mirror_voltages_with_set_multiple_spins: VFMMirrorVoltages, +): + async def completed(): + pass + + mock_put = get_mock_put( + vfm_mirror_voltages_with_set_multiple_spins.voltage_channels[0]._setpoint_v + ) + mock_put.return_value = completed() + set_mock_value( + vfm_mirror_voltages_with_set_multiple_spins.voltage_channels[ + 0 + ]._demand_accepted, + MirrorVoltageDemand.OK, ) - assert status.success + + def plan(): + yield from bps.abs_set( + vfm_mirror_voltages_with_set_multiple_spins.voltage_channels[0], + 100, + wait=True, + ) + + RE(plan()) + + mock_put.assert_called_with(100, wait=True, timeout=10.0) def test_mirror_set_voltage_set_rejected_when_not_ok( + RE: RunEngine, vfm_mirror_voltages_not_ok: VFMMirrorVoltages, ): - with pytest.raises(AssertionError): - vfm_mirror_voltages_not_ok.voltage_channels[0].set(100) + def plan(): + with pytest.raises(FailedStatus) as e: + yield from bps.abs_set( + vfm_mirror_voltages_not_ok.voltage_channels[0], 100, wait=True + ) + + assert isinstance(e.value.args[0].exception(), AssertionError) + + RE(plan()) def test_mirror_set_voltage_sets_and_waits_set_fail( + RE: RunEngine, vfm_mirror_voltages_with_set: VFMMirrorVoltages, ): - vfm_mirror_voltages_with_set._channel14_voltage_device._setpoint_v.set.return_value = Status( - success=False, done=True - ) + def failed(*args, **kwargs): + raise AssertionError("Test Failure") + + get_mock_put( + vfm_mirror_voltages_with_set.voltage_channels[0]._setpoint_v + ).side_effect = failed - status: StatusBase = vfm_mirror_voltages_with_set.voltage_channels[0].set(100) - with pytest.raises(Exception): - status.wait() + def plan(): + with pytest.raises(FailedStatus) as e: + yield from bps.abs_set( + vfm_mirror_voltages_with_set.voltage_channels[0], 100, wait=True + ) - assert not status.success + assert isinstance(e.value.args[0].exception(), AssertionError) + + RE(plan()) + + +def test_mirror_set_voltage_sets_and_waits_demand_accepted_fail( + RE: RunEngine, vfm_mirror_voltages_with_set_accepted_fail +): + def plan(): + with pytest.raises(FailedStatus) as e: + yield from bps.abs_set( + vfm_mirror_voltages_with_set_accepted_fail.voltage_channels[0], + 100, + wait=True, + ) + + assert isinstance(e.value.args[0].exception(), AssertionError) + + RE(plan()) @patch("dodal.devices.focusing_mirror.DEFAULT_SETTLE_TIME_S", 3) def test_mirror_set_voltage_sets_and_waits_settle_timeout_expires( + RE: RunEngine, vfm_mirror_voltages_with_set_timing_out: VFMMirrorVoltages, ): - vfm_mirror_voltages_with_set_timing_out._channel14_voltage_device._setpoint_v.set.return_value = NullStatus() - - status: StatusBase = vfm_mirror_voltages_with_set_timing_out.voltage_channels[ - 0 - ].set(100) - - with pytest.raises(Exception) as excinfo: - status.wait() + def plan(): + with pytest.raises(Exception) as excinfo: + yield from bps.abs_set( + vfm_mirror_voltages_with_set_timing_out.voltage_channels[0], + 100, + wait=True, + ) + assert isinstance(excinfo.value.args[0].exception(), TimeoutError) - # Cannot assert because ophyd discards the original exception - # assert isinstance(excinfo.value, WaitTimeoutError) - assert excinfo.value + RE(plan()) def test_mirror_set_voltage_returns_immediately_if_voltage_already_demanded( + RE: RunEngine, vfm_mirror_voltages_with_set: VFMMirrorVoltages, ): - vfm_mirror_voltages_with_set._channel14_voltage_device._setpoint_v.sim_put(100) + set_mock_value(vfm_mirror_voltages_with_set.voltage_channels[0]._setpoint_v, 100) + + def plan(): + yield from bps.abs_set( + vfm_mirror_voltages_with_set.voltage_channels[0], 100, wait=True + ) - status: StatusBase = vfm_mirror_voltages_with_set.voltage_channels[0].set(100) - status.wait() + RE(plan()) - assert status.success - vfm_mirror_voltages_with_set._channel14_voltage_device._setpoint_v.set.assert_not_called() + get_mock_put( + vfm_mirror_voltages_with_set.voltage_channels[0]._setpoint_v + ).assert_not_called() def test_mirror_populates_voltage_channels(