diff --git a/.github/workflows/code.yml b/.github/workflows/code.yml index 9a8954befe..88b4e565cb 100644 --- a/.github/workflows/code.yml +++ b/.github/workflows/code.yml @@ -112,96 +112,6 @@ jobs: # If more than one module in src/ replace with module name to test run: python -m dodal --version - container: - needs: [lint, dist, test] - runs-on: ubuntu-latest - - permissions: - contents: read - packages: write - - env: - TEST_TAG: "testing" - - steps: - - name: Checkout - uses: actions/checkout@v4 - - # image names must be all lower case - - name: Generate image repo name - run: echo IMAGE_REPOSITORY=ghcr.io/$(tr '[:upper:]' '[:lower:]' <<< "${{ github.repository }}") >> $GITHUB_ENV - - - name: Set lockfile location in environment - run: | - echo "DIST_LOCKFILE_PATH=lockfiles-${{ env.CONTAINER_PYTHON }}-dist-${{ github.sha }}" >> $GITHUB_ENV - - - name: Download wheel and lockfiles - uses: actions/download-artifact@v4.1.2 - with: - path: artifacts/ - pattern: "*dist*" - - - name: Log in to GitHub Docker Registry - if: github.event_name != 'pull_request' - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Set up Docker Buildx - id: buildx - uses: docker/setup-buildx-action@v3 - - - name: Build and export to Docker local cache - uses: docker/build-push-action@v5 - with: - # Note build-args, context, file, and target must all match between this - # step and the later build-push-action, otherwise the second build-push-action - # will attempt to build the image again - build-args: | - PIP_OPTIONS=-r ${{ env.DIST_LOCKFILE_PATH }}/requirements.txt ${{ env.DIST_WHEEL_PATH }}/*.whl - context: artifacts/ - file: ./Dockerfile - target: runtime - load: true - tags: ${{ env.TEST_TAG }} - # If you have a long docker build (2+ minutes), uncomment the - # following to turn on caching. For short build times this - # makes it a little slower - #cache-from: type=gha - #cache-to: type=gha,mode=max - - - name: Create tags for publishing image - id: meta - uses: docker/metadata-action@v5 - with: - images: ${{ env.IMAGE_REPOSITORY }} - tags: | - type=ref,event=tag - type=raw,value=latest, enable=${{ github.ref_type == 'tag' }} - # type=edge,branch=main - # Add line above to generate image for every commit to given branch, - # and uncomment the end of if clause in next step - - - name: Push cached image to container registry - if: github.ref_type == 'tag' # || github.ref_name == 'main' - uses: docker/build-push-action@v5 - # This does not build the image again, it will find the image in the - # Docker cache and publish it - with: - # Note build-args, context, file, and target must all match between this - # step and the previous build-push-action, otherwise this step will - # attempt to build the image again - build-args: | - PIP_OPTIONS=-r ${{ env.DIST_LOCKFILE_PATH }}/requirements.txt ${{ env.DIST_WHEEL_PATH }}/*.whl - context: artifacts/ - file: ./Dockerfile - target: runtime - push: true - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - release: # upload to PyPI and make a release on every tag needs: [lint, dist, test] diff --git a/Dockerfile b/Dockerfile index a932e69851..a7cf36f3bb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,15 +23,3 @@ WORKDIR /context # install python package into /venv RUN pip install ${PIP_OPTIONS} - -FROM python:3.11-slim as runtime - -# Add apt-get system dependecies for runtime here if needed - -# copy the virtual environment from the build stage and put it in PATH -COPY --from=build /venv/ /venv/ -ENV PATH=/venv/bin:$PATH - -# change this entrypoint if it is not the same as the repo -ENTRYPOINT ["dodal"] -CMD ["--version"] diff --git a/docs/user/how-to/run-container.rst b/docs/user/how-to/run-container.rst deleted file mode 100644 index 64baca9b3e..0000000000 --- a/docs/user/how-to/run-container.rst +++ /dev/null @@ -1,15 +0,0 @@ -Run in a container -================== - -Pre-built containers with dodal and its dependencies already -installed are available on `Github Container Registry -`_. - -Starting the container ----------------------- - -To pull the container from github container registry and run:: - - $ docker run ghcr.io/DiamondLightSource/dodal:main --version - -To get a released version, use a numbered release instead of ``main``. diff --git a/docs/user/index.rst b/docs/user/index.rst index c3ba4b7c0c..538d0fce92 100644 --- a/docs/user/index.rst +++ b/docs/user/index.rst @@ -26,7 +26,6 @@ side-bar. :caption: How-to Guides :maxdepth: 1 - how-to/run-container how-to/create-beamline.rst +++ diff --git a/pyproject.toml b/pyproject.toml index 1027e09cd7..1deadf9edb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers = [ description = "Ophyd devices and other utils that could be used across DLS beamlines" dependencies = [ "ophyd", - "ophyd_async@git+https://github.com/bluesky/ophyd-async@ec5729640041ee5b77b4614158793af3a34cf9d8", #Use a specific branch from ophyd async until https://github.com/bluesky/ophyd-async/pull/101 is merged + "ophyd-async@git+https://github.com/bluesky/ophyd-async", "bluesky", "pyepics", "dataclasses-json", @@ -22,10 +22,11 @@ dependencies = [ "zocalo", "requests", "graypy", - "pydantic<2.0", - "opencv-python-headless", # For pin-tip detection. - "aioca", # Required for CA support with ophyd-async. - "p4p", # Required for PVA support with ophyd-async. + "pydantic", + "opencv-python-headless", # For pin-tip detection. + "aioca", # Required for CA support with ophyd-async. + "p4p", # Required for PVA support with ophyd-async. + "numpy", ] dynamic = ["version"] @@ -78,6 +79,7 @@ ignore_missing_imports = true # Ignore missing stubs in imported modules [tool.pytest.ini_options] # Run pytest with all our checkers, and don't spam us with massive tracebacks on error +asyncio_mode = "auto" markers = [ "s03: marks tests as requiring the s03 simulator running (deselect with '-m \"not s03\"')", ] @@ -88,6 +90,11 @@ addopts = """ # Doctest python code in docs, python code in src docstrings, test functions in tests testpaths = "docs src tests" +[tool.coverage.report] +exclude_also = [ + '^"""', # Ignore the start/end of a file-level triple quoted docstring +] + [tool.coverage.run] data_file = "/tmp/dodal.coverage" diff --git a/src/dodal/beamlines/beamline_utils.py b/src/dodal/beamlines/beamline_utils.py index 21a7fd42df..1647eba4ef 100644 --- a/src/dodal/beamlines/beamline_utils.py +++ b/src/dodal/beamlines/beamline_utils.py @@ -52,8 +52,12 @@ def _wait_for_connection( device.wait_for_connection(timeout=timeout) elif isinstance(device, OphydV2Device): call_in_bluesky_event_loop( - v2_device_wait_for_connection(coros=device.connect(sim=sim)), - timeout=timeout, + v2_device_wait_for_connection( + coros=device.connect( + sim=sim, + timeout=timeout, + ) + ), ) else: raise TypeError( @@ -98,9 +102,11 @@ def device_instantiation( if already_existing_device is None: device_instance = device_factory( name=name, - prefix=f"{(BeamlinePrefix(BL).beamline_prefix)}{prefix}" - if bl_prefix - else prefix, + prefix=( + f"{(BeamlinePrefix(BL).beamline_prefix)}{prefix}" + if bl_prefix + else prefix + ), **kwargs, ) ACTIVE_DEVICES[name] = device_instance diff --git a/src/dodal/beamlines/i03.py b/src/dodal/beamlines/i03.py index aaa335f794..373a2c7c03 100644 --- a/src/dodal/beamlines/i03.py +++ b/src/dodal/beamlines/i03.py @@ -187,7 +187,7 @@ def fast_grid_scan( return device_instantiation( device_factory=FastGridScan, name="fast_grid_scan", - prefix="-MO-SGON-01:FGS:", + prefix="-MO-SGON-01:", wait=wait_for_connection, fake=fake_with_ophyd_sim, ) diff --git a/src/dodal/beamlines/i20_1.py b/src/dodal/beamlines/i20_1.py index e1259f7d3d..fe5f1be071 100644 --- a/src/dodal/beamlines/i20_1.py +++ b/src/dodal/beamlines/i20_1.py @@ -1,33 +1,24 @@ from dodal.beamlines.beamline_utils import device_instantiation from dodal.beamlines.beamline_utils import set_beamline as set_utils_beamline -from dodal.devices.motors import EpicsMotor -from dodal.devices.i20_1.turbo_slit import TurboSlit +from dodal.devices.turbo_slit import TurboSlit from dodal.log import set_beamline as set_log_beamline -from dodal.utils import get_beamline_name, get_hostname, skip_device +from dodal.utils import get_beamline_name BL = get_beamline_name("i20_1") set_log_beamline(BL) set_utils_beamline(BL) -def _is_i20_1_machine(): - """ - Devices using PVA can only connect from i20_1 machines, due to the absence of - PVA gateways at present. - """ - hostname = get_hostname() - return hostname.startswith("i20_1") - - -@skip_device(lambda: not _is_i20_1_machine()) -def turbo_slit_motor( +def turbo_slit( wait_for_connection: bool = True, fake_with_ophyd_sim: bool = False ) -> TurboSlit: - """Get the i20-1 motor""" + """ + turboslit for selecting energy from the polychromator + """ return device_instantiation( TurboSlit, - prefix="-OP-PCHRO-01", + prefix="-OP-PCHRO-01:TS:", name="turbo_slit", wait=wait_for_connection, fake=fake_with_ophyd_sim, diff --git a/src/dodal/common/__init__.py b/src/dodal/common/__init__.py new file mode 100644 index 0000000000..4a53b30530 --- /dev/null +++ b/src/dodal/common/__init__.py @@ -0,0 +1,12 @@ +from .coordination import group_uuid, inject +from .maths import in_micros, step_to_num +from .types import MsgGenerator, PlanGenerator + +__all__ = [ + "group_uuid", + "inject", + "in_micros", + "MsgGenerator", + "PlanGenerator", + "step_to_num", +] diff --git a/src/dodal/common/coordination.py b/src/dodal/common/coordination.py new file mode 100644 index 0000000000..3c061490bf --- /dev/null +++ b/src/dodal/common/coordination.py @@ -0,0 +1,38 @@ +import uuid + +from dodal.common.types import Group + + +def group_uuid(name: str) -> Group: + """ + Returns a unique but human-readable string, to assist debugging orchestrated groups. + + Args: + name (str): A human readable name + + Returns: + readable_uid (Group): name appended with a unique string + """ + return f"{name}-{str(uuid.uuid4())[:6]}" + + +def inject(name: str): # type: ignore + """ + Function to mark a default argument of a plan method as a reference to a device + that is stored in the Blueapi context, as devices are constructed on startup of the + service, and are not available to be used when writing plans. + Bypasses mypy linting, returning x as Any and therefore valid as a default + argument. + e.g. For a 1-dimensional scan, that is usually performed on a consistent Movable + axis with name "stage_x" + def scan(x: Movable = inject("stage_x"), start: float = 0.0 ...) + + Args: + name (str): Name of a device to be fetched from the Blueapi context + + Returns: + Any: name but without typing checking, valid as any default type + + """ + + return name diff --git a/src/dodal/common/maths.py b/src/dodal/common/maths.py new file mode 100644 index 0000000000..a691279f24 --- /dev/null +++ b/src/dodal/common/maths.py @@ -0,0 +1,52 @@ +from typing import Tuple + +import numpy as np + + +def step_to_num(start: float, stop: float, step: float) -> Tuple[float, float, int]: + """ + Standard handling for converting from start, stop, step to start, stop, num + Forces step to be same direction as length + Includes a final point if it is within 1% of the final step, prevents floating + point arithmatic errors from giving inconsistent shaped scans between steps of an + outer axis. + + Args: + start (float): + Start of length, will be returned unchanged + stop (float): + End of length, if length/step does not divide cleanly will be returned + extended up to 1% of step, or else truncated. + step (float): + Length of a step along the line formed from start to stop. + If stop < start, will be coerced to be backwards. + + Returns: + start, adjusted_stop, num = Tuple[float, float, int] + start will be returned unchanged + adjusted_stop = start + (num - 1) * step + num is the maximal number of steps that could fit into the length. + + """ + # Make step be the right direction + step = abs(step) if stop >= start else -abs(step) + # If stop is within 1% of a step then include it + steps = int((stop - start) / step + 0.01) + return start, start + steps * step, steps + 1 # include 1st point + + +def in_micros(t: float) -> int: + """ + Converts between a positive number of seconds and an equivalent + number of microseconds. + + Args: + t (float): A time in seconds + Raises: + ValueError: if t < 0 + Returns: + t (int): A time in microseconds, rounded up to the nearest whole microsecond, + """ + if t < 0: + raise ValueError(f"Expected a positive time in seconds, got {t!r}") + return int(np.ceil(t * 1e6)) diff --git a/src/dodal/common/types.py b/src/dodal/common/types.py new file mode 100644 index 0000000000..1eeedcfa59 --- /dev/null +++ b/src/dodal/common/types.py @@ -0,0 +1,14 @@ +from typing import ( + Any, + Callable, + Generator, +) + +from bluesky.utils import Msg + +# String identifier used by 'wait' or stubs that await +Group = str +# A true 'plan', usually the output of a generator function +MsgGenerator = Generator[Msg, Any, None] +# A function that generates a plan +PlanGenerator = Callable[..., MsgGenerator] diff --git a/src/dodal/devices/aperture.py b/src/dodal/devices/aperture.py index d78c983e51..fa2b0d1d8d 100644 --- a/src/dodal/devices/aperture.py +++ b/src/dodal/devices/aperture.py @@ -1,11 +1,12 @@ -from ophyd import Component, Device, EpicsMotor, EpicsSignalRO +from ophyd import Component, Device, EpicsSignalRO - -class EpicsMotorWithMRES(EpicsMotor): - motor_resolution: Component[EpicsSignalRO] = Component(EpicsSignalRO, ".MRES") +from dodal.devices.util.motor_utils import ExtendedEpicsMotor class Aperture(Device): - x = Component(EpicsMotor, "X") - y = Component(EpicsMotor, "Y") - z = Component(EpicsMotorWithMRES, "Z") + x = Component(ExtendedEpicsMotor, "X") + y = Component(ExtendedEpicsMotor, "Y") + z = Component(ExtendedEpicsMotor, "Z") + small = Component(EpicsSignalRO, "Y:SMALL_CALC") + medium = Component(EpicsSignalRO, "Y:MEDIUM_CALC") + large = Component(EpicsSignalRO, "Y:LARGE_CALC") diff --git a/src/dodal/devices/aperturescatterguard.py b/src/dodal/devices/aperturescatterguard.py index 911818735d..66cf56e459 100644 --- a/src/dodal/devices/aperturescatterguard.py +++ b/src/dodal/devices/aperturescatterguard.py @@ -1,8 +1,13 @@ +import operator +from collections import namedtuple from dataclasses import dataclass -from typing import Optional, Tuple +from functools import reduce +from typing import List, Optional, Sequence from ophyd import Component as Cpt -from ophyd.status import AndStatus, Status +from ophyd import SignalRO +from ophyd.epics_motor import EpicsMotor +from ophyd.status import AndStatus, Status, StatusBase from dodal.devices.aperture import Aperture from dodal.devices.logging_ophyd_device import InfoLoggingDevice @@ -14,85 +19,134 @@ class InvalidApertureMove(Exception): pass +ApertureFiveDimensionalLocation = namedtuple( + "ApertureFiveDimensionalLocation", + [ + "aperture_x", + "aperture_y", + "aperture_z", + "scatterguard_x", + "scatterguard_y", + ], +) + + +@dataclass +class SingleAperturePosition: + name: str + GDA_name: str + radius_microns: Optional[float] + location: ApertureFiveDimensionalLocation + + +def position_from_params( + name: str, GDA_name: str, radius_microns: Optional[float], params: dict +) -> SingleAperturePosition: + return SingleAperturePosition( + name, + GDA_name, + radius_microns, + ApertureFiveDimensionalLocation( + params[f"miniap_x_{GDA_name}"], + params[f"miniap_y_{GDA_name}"], + params[f"miniap_z_{GDA_name}"], + params[f"sg_x_{GDA_name}"], + params[f"sg_y_{GDA_name}"], + ), + ) + + @dataclass class AperturePositions: - """Holds tuples (miniap_x, miniap_y, miniap_z, scatterguard_x, scatterguard_y) - representing the motor positions needed to select a particular aperture size. - """ + """Holds the motor positions needed to select a particular aperture size.""" - LARGE: Tuple[float, float, float, float, float] - MEDIUM: Tuple[float, float, float, float, float] - SMALL: Tuple[float, float, float, float, float] - ROBOT_LOAD: Tuple[float, float, float, float, float] + LARGE: SingleAperturePosition + MEDIUM: SingleAperturePosition + SMALL: SingleAperturePosition + ROBOT_LOAD: SingleAperturePosition @classmethod def from_gda_beamline_params(cls, params): return cls( - LARGE=( - params["miniap_x_LARGE_APERTURE"], - params["miniap_y_LARGE_APERTURE"], - params["miniap_z_LARGE_APERTURE"], - params["sg_x_LARGE_APERTURE"], - params["sg_y_LARGE_APERTURE"], - ), - MEDIUM=( - params["miniap_x_MEDIUM_APERTURE"], - params["miniap_y_MEDIUM_APERTURE"], - params["miniap_z_MEDIUM_APERTURE"], - params["sg_x_MEDIUM_APERTURE"], - params["sg_y_MEDIUM_APERTURE"], - ), - SMALL=( - params["miniap_x_SMALL_APERTURE"], - params["miniap_y_SMALL_APERTURE"], - params["miniap_z_SMALL_APERTURE"], - params["sg_x_SMALL_APERTURE"], - params["sg_y_SMALL_APERTURE"], - ), - ROBOT_LOAD=( - params["miniap_x_ROBOT_LOAD"], - params["miniap_y_ROBOT_LOAD"], - params["miniap_z_ROBOT_LOAD"], - params["sg_x_ROBOT_LOAD"], - params["sg_y_ROBOT_LOAD"], - ), + LARGE=position_from_params("Large", "LARGE_APERTURE", 100, params), + MEDIUM=position_from_params("Medium", "MEDIUM_APERTURE", 50, params), + SMALL=position_from_params("Small", "SMALL_APERTURE", 20, params), + ROBOT_LOAD=position_from_params("Robot load", "ROBOT_LOAD", None, params), ) - def position_valid(self, pos: Tuple[float, float, float, float, float]): - """ - Check if argument 'pos' is a valid position in this AperturePositions object. - """ - if pos not in [self.LARGE, self.MEDIUM, self.SMALL, self.ROBOT_LOAD]: - return False - return True + def as_list(self) -> List[SingleAperturePosition]: + return [ + self.LARGE, + self.MEDIUM, + self.SMALL, + self.ROBOT_LOAD, + ] class ApertureScatterguard(InfoLoggingDevice): aperture = Cpt(Aperture, "-MO-MAPT-01:") scatterguard = Cpt(Scatterguard, "-MO-SCAT-01:") aperture_positions: Optional[AperturePositions] = None - APERTURE_Z_TOLERANCE = 3 # Number of MRES steps + TOLERANCE_STEPS = 3 # Number of MRES steps + + class SelectedAperture(SignalRO): + def get(self): + assert isinstance(self.parent, ApertureScatterguard) + return self.parent._get_current_aperture_position() + + selected_aperture = Cpt(SelectedAperture) def load_aperture_positions(self, positions: AperturePositions): LOGGER.info(f"{self.name} loaded in {positions}") self.aperture_positions = positions - def set(self, pos: Tuple[float, float, float, float, float]) -> AndStatus: - try: - assert isinstance(self.aperture_positions, AperturePositions) - assert self.aperture_positions.position_valid(pos) - except AssertionError as e: - raise InvalidApertureMove(repr(e)) - return self._safe_move_within_datacollection_range(*pos) + def set(self, pos: SingleAperturePosition) -> StatusBase: + assert isinstance(self.aperture_positions, AperturePositions) + if pos not in self.aperture_positions.as_list(): + raise InvalidApertureMove(f"Unknown aperture: {pos}") + + return self._safe_move_within_datacollection_range(pos.location) + + def _get_motor_list(self): + return [ + self.aperture.x, + self.aperture.y, + self.aperture.z, + self.scatterguard.x, + self.scatterguard.y, + ] + + def _set_raw_unsafe(self, positions: ApertureFiveDimensionalLocation) -> AndStatus: + motors: Sequence[EpicsMotor] = self._get_motor_list() + return reduce( + operator.and_, [motor.set(pos) for motor, pos in zip(motors, positions)] + ) + + def _get_current_aperture_position(self) -> SingleAperturePosition: + """ + Returns the current aperture position using readback values + for SMALL, MEDIUM, LARGE. ROBOT_LOAD position defined when + mini aperture y <= ROBOT_LOAD.location.aperture_y + tolerance. + If no position is found then raises InvalidApertureMove. + """ + assert isinstance(self.aperture_positions, AperturePositions) + current_ap_y = float(self.aperture.y.user_readback.get()) + robot_load_ap_y = self.aperture_positions.ROBOT_LOAD.location.aperture_y + tolerance = self.TOLERANCE_STEPS * self.aperture.y.motor_resolution.get() + if int(self.aperture.large.get()) == 1: + return self.aperture_positions.LARGE + elif int(self.aperture.medium.get()) == 1: + return self.aperture_positions.MEDIUM + elif int(self.aperture.small.get()) == 1: + return self.aperture_positions.SMALL + elif current_ap_y <= robot_load_ap_y + tolerance: + return self.aperture_positions.ROBOT_LOAD + + raise InvalidApertureMove("Current aperture/scatterguard state unrecognised") def _safe_move_within_datacollection_range( - self, - aperture_x: float, - aperture_y: float, - aperture_z: float, - scatterguard_x: float, - scatterguard_y: float, - ) -> Status: + self, pos: ApertureFiveDimensionalLocation + ) -> StatusBase: """ Move the aperture and scatterguard combo safely to a new position. See https://github.com/DiamondLightSource/hyperion/wiki/Aperture-Scatterguard-Collisions @@ -101,6 +155,10 @@ def _safe_move_within_datacollection_range( # EpicsMotor does not have deadband/MRES field, so the way to check if we are # in a datacollection position is to see if we are "ready" (DMOV) and the target # position is correct + + # unpacking the position + aperture_x, aperture_y, aperture_z, scatterguard_x, scatterguard_y = pos + ap_z_in_position = self.aperture.z.motor_done_move.get() if not ap_z_in_position: status: Status = Status(obj=self) @@ -111,9 +169,11 @@ def _safe_move_within_datacollection_range( ) ) return status + current_ap_z = self.aperture.z.user_setpoint.get() - tolerance = self.APERTURE_Z_TOLERANCE * self.aperture.z.motor_resolution.get() - if abs(current_ap_z - aperture_z) > tolerance: + tolerance = self.TOLERANCE_STEPS * self.aperture.z.motor_resolution.get() + diff_on_z = abs(current_ap_z - aperture_z) + if diff_on_z > tolerance: raise InvalidApertureMove( "ApertureScatterguard safe move is not yet defined for positions " "outside of LARGE, MEDIUM, SMALL, ROBOT_LOAD. " @@ -126,24 +186,22 @@ def _safe_move_within_datacollection_range( scatterguard_x ) & self.scatterguard.y.set(scatterguard_y) sg_status.wait() - final_status = ( + return ( sg_status & self.aperture.x.set(aperture_x) & self.aperture.y.set(aperture_y) & self.aperture.z.set(aperture_z) ) - return final_status - else: - ap_status: AndStatus = ( - self.aperture.x.set(aperture_x) - & self.aperture.y.set(aperture_y) - & self.aperture.z.set(aperture_z) - ) - ap_status.wait() - final_status = ( - ap_status - & self.scatterguard.x.set(scatterguard_x) - & self.scatterguard.y.set(scatterguard_y) - ) - return final_status + ap_status: AndStatus = ( + self.aperture.x.set(aperture_x) + & self.aperture.y.set(aperture_y) + & self.aperture.z.set(aperture_z) + ) + ap_status.wait() + final_status: AndStatus = ( + ap_status + & self.scatterguard.x.set(scatterguard_x) + & self.scatterguard.y.set(scatterguard_y) + ) + return final_status diff --git a/src/dodal/devices/areadetector/plugins/MXSC.py b/src/dodal/devices/areadetector/plugins/MXSC.py deleted file mode 100644 index 82177896d0..0000000000 --- a/src/dodal/devices/areadetector/plugins/MXSC.py +++ /dev/null @@ -1,130 +0,0 @@ -from typing import List, Tuple - -import numpy as np -from ophyd import Component, Device, EpicsSignal, EpicsSignalRO, Kind, Signal -from ophyd.status import StableSubscriptionStatus, Status, StatusBase - -from dodal.log import LOGGER - -Pixel = Tuple[int, int] - - -def statistics_of_positions( - positions: List[Pixel], -) -> Tuple[Pixel, Tuple[float, float]]: - """Get the median and standard deviation from a list of readings. - - Note that x/y are treated separately so the median position is not guaranteed to be - a position that was actually read. - - Args: - positions (List[Pixel]): A list of tip positions. - - Returns: - Tuple[Pixel, Tuple[float, float]]: The median tip position and the standard - deviation in x/y - """ - x_coords, y_coords = np.array(positions).T - - median = (int(np.median(x_coords)), int(np.median(y_coords))) - std = (np.std(x_coords, dtype=float), np.std(y_coords, dtype=float)) - - return median, std - - -class PinTipDetect(Device): - """This will read the pin tip location from the MXSC plugin. - - If the plugin finds no tip it will return {INVALID_POSITION}. However, it will also - occassionally give incorrect data. Therefore, it is recommended that you trigger - this device, which will set {triggered_tip} to a median of the valid points taken - for {settle_time_s} seconds. - - If no valid points are found within {validity_timeout} seconds a {triggered_tip} - will be set to {INVALID_POSITION}. - """ - - INVALID_POSITION = (-1, -1) - tip_x = Component(EpicsSignalRO, "TipX") - tip_y = Component(EpicsSignalRO, "TipY") - - triggered_tip = Component(Signal, kind=Kind.hinted, value=INVALID_POSITION) - validity_timeout = Component(Signal, value=5) - settle_time_s = Component(Signal, value=0.5) - - def log_tips_and_statistics(self, _): - median, standard_deviation = statistics_of_positions(self.tip_positions) - LOGGER.info( - f"Found tips {self.tip_positions} with median {median} and standard deviation {standard_deviation}" - ) - - def update_tip_if_valid(self, value: int, **_): - current_value = (value, int(self.tip_y.get())) - if current_value != self.INVALID_POSITION: - self.tip_positions.append(current_value) - - ( - median_tip_location, - __, - ) = statistics_of_positions(self.tip_positions) - - self.triggered_tip.put(median_tip_location) - return True - return False - - def trigger(self) -> StatusBase: - self.tip_positions: List[Pixel] = [] - - subscription_status = StableSubscriptionStatus( - self.tip_x, - self.update_tip_if_valid, - stability_time=self.settle_time_s.get(), - run=True, - ) - - def set_to_default_and_finish(timeout_status: Status): - try: - if not timeout_status.success: - self.triggered_tip.set(self.INVALID_POSITION) - subscription_status.set_finished() - except Exception as e: - subscription_status.set_exception(e) - - # We use a separate status for measuring the timeout as we don't want an error - # on the returned status - self._timeout_status = Status(self, timeout=self.validity_timeout.get()) - self._timeout_status.add_callback(set_to_default_and_finish) - subscription_status.add_callback(lambda _: self._timeout_status.set_finished()) - subscription_status.add_callback(self.log_tips_and_statistics) - - return subscription_status - - -class MXSC(Device): - """ - Device for edge detection plugin. - """ - - input_plugin = Component(EpicsSignal, "NDArrayPort") - enable_callbacks = Component(EpicsSignal, "EnableCallbacks") - min_callback_time = Component(EpicsSignal, "MinCallbackTime") - blocking_callbacks = Component(EpicsSignal, "BlockingCallbacks") - read_file = Component(EpicsSignal, "ReadFile") - filename = Component(EpicsSignal, "Filename", string=True) - preprocess_operation = Component(EpicsSignal, "Preprocess") - preprocess_ksize = Component(EpicsSignal, "PpParam1") - canny_upper_threshold = Component(EpicsSignal, "CannyUpper") - canny_lower_threshold = Component(EpicsSignal, "CannyLower") - close_ksize = Component(EpicsSignal, "CloseKsize") - scan_direction = Component(EpicsSignal, "ScanDirection") - min_tip_height = Component(EpicsSignal, "MinTipHeight") - - top = Component(EpicsSignal, "Top") - bottom = Component(EpicsSignal, "Bottom") - output_array = Component(EpicsSignal, "OutputArray") - draw_tip = Component(EpicsSignal, "DrawTip") - draw_edges = Component(EpicsSignal, "DrawEdges") - waveform_size_x = Component(EpicsSignal, "ArraySize1_RBV") - waveform_size_y = Component(EpicsSignal, "ArraySize2_RBV") - - pin_tip = Component(PinTipDetect, "") diff --git a/src/dodal/devices/detector/detector.py b/src/dodal/devices/detector/detector.py index 5d970a6823..be407b909e 100644 --- a/src/dodal/devices/detector/detector.py +++ b/src/dodal/devices/detector/detector.py @@ -1,7 +1,7 @@ from enum import Enum, auto from typing import Any, Optional, Tuple -from pydantic import BaseModel, validator +from pydantic import BaseModel, root_validator, validator from dodal.devices.detector.det_dim_constants import ( EIGER2_X_16M_SIZE, @@ -28,7 +28,7 @@ class DetectorParams(BaseModel): """Holds parameters for the detector. Provides access to a list of Dectris detector sizes and a converter for distance to beam centre.""" - expected_energy_ev: Optional[float] + expected_energy_ev: Optional[float] = None exposure_time: float directory: str prefix: str @@ -41,8 +41,8 @@ class DetectorParams(BaseModel): det_dist_to_beam_converter_path: str trigger_mode: TriggerMode = TriggerMode.SET_FRAMES detector_size_constants: DetectorSizeConstants = EIGER2_X_16M_SIZE - beam_xy_converter: DetectorDistanceToBeamXYConverter = None - run_number: Optional[int] = None + beam_xy_converter: DetectorDistanceToBeamXYConverter + run_number: int class Config: arbitrary_types_allowed = True @@ -51,6 +51,15 @@ class Config: DetectorSizeConstants: lambda d: d.det_type_string, } + @root_validator(pre=True, skip_on_failure=True) # type: ignore # should be replaced with model_validator once move to pydantic 2 is complete + def create_beamxy_and_runnumber(cls, values: dict[str, Any]) -> dict[str, Any]: + values["beam_xy_converter"] = DetectorDistanceToBeamXYConverter( + values["det_dist_to_beam_converter_path"] + ) + if values.get("run_number") is None: + values["run_number"] = get_run_number(values["directory"]) + return values + @validator("detector_size_constants", pre=True) def _parse_detector_size_constants( cls, det_type: str, values: dict[str, Any] @@ -63,28 +72,6 @@ def _parse_directory(cls, directory: str, values: dict[str, Any]) -> str: directory += "/" return directory - @validator("beam_xy_converter", always=True) - def _parse_beam_xy_converter( - cls, - beam_xy_converter: DetectorDistanceToBeamXYConverter, - values: dict[str, Any], - ) -> DetectorDistanceToBeamXYConverter: - return DetectorDistanceToBeamXYConverter( - values["det_dist_to_beam_converter_path"] - ) - - @validator("run_number", always=True) - def _set_run_number(cls, run_number: int, values: dict[str, Any]): - if run_number is None: - return get_run_number(values["directory"]) - else: - return run_number - - def __post_init__(self): - self.beam_xy_converter = DetectorDistanceToBeamXYConverter( - self.det_dist_to_beam_converter_path - ) - def get_beam_position_mm(self, detector_distance: float) -> Tuple[float, float]: x_beam_mm = self.beam_xy_converter.get_beam_xy_from_det_dist( detector_distance, Axis.X_AXIS diff --git a/src/dodal/devices/fast_grid_scan.py b/src/dodal/devices/fast_grid_scan.py index bde9120f12..7d58f1768f 100644 --- a/src/dodal/devices/fast_grid_scan.py +++ b/src/dodal/devices/fast_grid_scan.py @@ -14,13 +14,13 @@ Signal, ) from ophyd.status import DeviceStatus, StatusBase -from pydantic import BaseModel, validator +from pydantic import validator from pydantic.dataclasses import dataclass from dodal.devices.motors import XYZLimitBundle from dodal.devices.status import await_value from dodal.log import LOGGER -from dodal.parameters.experiment_parameter_base import AbstractExperimentParameterBase +from dodal.parameters.experiment_parameter_base import AbstractExperimentWithBeamParams @dataclass @@ -44,7 +44,7 @@ def is_within(self, steps): return 0 <= steps <= self.full_steps -class GridScanParamsCommon(BaseModel, AbstractExperimentParameterBase): +class GridScanParamsCommon(AbstractExperimentWithBeamParams): """ Common holder class for the parameters of a grid scan in a similar layout to EPICS. The parameters and functions of this class are common @@ -237,36 +237,43 @@ def clean_up(self): self.device.status.clear_sub(self._running_changed) +class MotionProgram(Device): + running = Component(EpicsSignalRO, "PROGBITS") + program_number = Component(EpicsSignalRO, "CS1:PROG_NUM") + + class FastGridScan(Device): - x_steps = Component(EpicsSignalWithRBV, "X_NUM_STEPS") - y_steps = Component(EpicsSignalWithRBV, "Y_NUM_STEPS") - z_steps = Component(EpicsSignalWithRBV, "Z_NUM_STEPS") + x_steps = Component(EpicsSignalWithRBV, "FGS:X_NUM_STEPS") + y_steps = Component(EpicsSignalWithRBV, "FGS:Y_NUM_STEPS") + z_steps = Component(EpicsSignalWithRBV, "FGS:Z_NUM_STEPS") - x_step_size = Component(EpicsSignalWithRBV, "X_STEP_SIZE") - y_step_size = Component(EpicsSignalWithRBV, "Y_STEP_SIZE") - z_step_size = Component(EpicsSignalWithRBV, "Z_STEP_SIZE") + x_step_size = Component(EpicsSignalWithRBV, "FGS:X_STEP_SIZE") + y_step_size = Component(EpicsSignalWithRBV, "FGS:Y_STEP_SIZE") + z_step_size = Component(EpicsSignalWithRBV, "FGS:Z_STEP_SIZE") - dwell_time_ms = Component(EpicsSignalWithRBV, "DWELL_TIME") + dwell_time_ms = Component(EpicsSignalWithRBV, "FGS:DWELL_TIME") - x_start = Component(EpicsSignalWithRBV, "X_START") - y1_start = Component(EpicsSignalWithRBV, "Y_START") - y2_start = Component(EpicsSignalWithRBV, "Y2_START") - z1_start = Component(EpicsSignalWithRBV, "Z_START") - z2_start = Component(EpicsSignalWithRBV, "Z2_START") + x_start = Component(EpicsSignalWithRBV, "FGS:X_START") + y1_start = Component(EpicsSignalWithRBV, "FGS:Y_START") + y2_start = Component(EpicsSignalWithRBV, "FGS:Y2_START") + z1_start = Component(EpicsSignalWithRBV, "FGS:Z_START") + z2_start = Component(EpicsSignalWithRBV, "FGS:Z2_START") position_counter = Component( - EpicsSignal, "POS_COUNTER", write_pv="POS_COUNTER_WRITE" + EpicsSignal, "FGS:POS_COUNTER", write_pv="FGS:POS_COUNTER_WRITE" ) - x_counter = Component(EpicsSignalRO, "X_COUNTER") - y_counter = Component(EpicsSignalRO, "Y_COUNTER") - scan_invalid = Component(EpicsSignalRO, "SCAN_INVALID") + x_counter = Component(EpicsSignalRO, "FGS:X_COUNTER") + y_counter = Component(EpicsSignalRO, "FGS:Y_COUNTER") + scan_invalid = Component(EpicsSignalRO, "FGS:SCAN_INVALID") - run_cmd = Component(EpicsSignal, "RUN.PROC") - stop_cmd = Component(EpicsSignal, "STOP.PROC") - status = Component(EpicsSignalRO, "SCAN_STATUS") + run_cmd = Component(EpicsSignal, "FGS:RUN.PROC") + stop_cmd = Component(EpicsSignal, "FGS:STOP.PROC") + status = Component(EpicsSignalRO, "FGS:SCAN_STATUS") expected_images = Component(Signal) + motion_program = Component(MotionProgram, "") + # Kickoff timeout in seconds KICKOFF_TIMEOUT: float = 5.0 @@ -291,11 +298,15 @@ def is_invalid(self) -> bool: return bool(self.scan_invalid.get()) def kickoff(self) -> StatusBase: - # Check running already here? st = DeviceStatus(device=self, timeout=self.KICKOFF_TIMEOUT) def scan(): try: + curr_prog = self.motion_program.program_number.get() + running = self.motion_program.running.get() + if running: + LOGGER.info(f"Motion program {curr_prog} still running, waiting...") + await_value(self.motion_program.running, 0).wait() LOGGER.debug("Running scan") self.run_cmd.put(1) LOGGER.info("Waiting for FGS to start") diff --git a/src/dodal/devices/motors.py b/src/dodal/devices/motors.py index da98a88dd6..3c522243ae 100644 --- a/src/dodal/devices/motors.py +++ b/src/dodal/devices/motors.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import List, Tuple, Union import numpy as np from ophyd import Component, Device, EpicsMotor @@ -24,8 +25,8 @@ def is_within(self, position: float) -> bool: :param position: The position to check :return: True if position is within the limits """ - low = self.motor.low_limit_travel.get() - high = self.motor.high_limit_travel.get() + low = float(self.motor.low_limit_travel.get()) + high = float(self.motor.high_limit_travel.get()) return low <= position <= high @@ -39,7 +40,9 @@ class XYZLimitBundle: y: MotorLimitHelper z: MotorLimitHelper - def position_valid(self, position: np.ndarray): + def position_valid( + self, position: Union[np.ndarray, List[float], Tuple[float, float, float]] + ): if len(position) != 3: raise ValueError( f"Position valid expects a 3-vector, got {position} instead" diff --git a/src/dodal/devices/oav/grid_overlay.py b/src/dodal/devices/oav/grid_overlay.py index a0908e054a..57e74908fc 100644 --- a/src/dodal/devices/oav/grid_overlay.py +++ b/src/dodal/devices/oav/grid_overlay.py @@ -15,7 +15,7 @@ class Orientation(Enum): def _add_parallel_lines_to_image( - image: Image, + image: Image.Image, start_x: int, start_y: int, line_length: int, @@ -44,13 +44,15 @@ def _add_parallel_lines_to_image( parallel lines to draw.""" lines = [ ( - (start_x, start_y + i * spacing), - (start_x + line_length, start_y + i * spacing), - ) - if orientation == Orientation.horizontal - else ( - (start_x + i * spacing, start_y), - (start_x + i * spacing, start_y + line_length), + ( + (start_x, start_y + i * spacing), + (start_x + line_length, start_y + i * spacing), + ) + if orientation == Orientation.horizontal + else ( + (start_x + i * spacing, start_y), + (start_x + i * spacing, start_y + line_length), + ) ) for i in range(num_lines) ] diff --git a/src/dodal/devices/oav/oav_calculations.py b/src/dodal/devices/oav/oav_calculations.py index cc3ca79b75..6796e3a411 100644 --- a/src/dodal/devices/oav/oav_calculations.py +++ b/src/dodal/devices/oav/oav_calculations.py @@ -1,228 +1,5 @@ -from typing import Tuple - import numpy as np -from dodal.devices.oav.oav_errors import ( - OAVError_MissingRotations, - OAVError_NoRotationsPassValidityTest, -) -from dodal.log import LOGGER - - -def smooth(array): - """ - Remove noise from waveform using a convolution. - - Args: - array (np.ndarray): waveform to be smoothed. - Returns: - array_smooth (np.ndarray): array with noise removed. - """ - - # the smoothing window is set to 50 on i03 - smoothing_window = 50 - box = np.ones(smoothing_window) / smoothing_window - array_smooth = np.convolve(array, box, mode="same") - return array_smooth - - -def find_midpoint(top, bottom): - """ - Finds the midpoint from MXSC edge PVs. The midpoint is considered the centre of the first - bulge in the waveforms. This will correspond to the pin where the sample is located. - - Args: - top (np.ndarray): The waveform corresponding to the top of the pin. - bottom (np.ndarray): The waveform corresponding to the bottom of the pin. - Returns: - i_pixel (int): The i position of the located centre (in pixels). - j_pixel (int): The j position of the located centre (in pixels). - width (int): The width of the pin at the midpoint (in pixels). - """ - - # Widths between top and bottom. - widths = bottom - top - - # The line going down the middle of the waveform. - middle_line = (bottom + top) * 0.5 - - smoothed_width = smooth(widths) - first_derivative = np.gradient(smoothed_width) - - # The derivative introduces more noise, so another application of smooth is neccessary. - # The gradient is reversed prior since a new index has been introduced in smoothing, that is - # negated by smoothing in the reversed array. - reversed_derivative = first_derivative[::-1] - reversed_grad = smooth(reversed_derivative) - grad = reversed_grad[::-1] - - # np.sign gives us the positions where the gradient is positive and negative. - # Taking the diff of th/at gives us an array with all 0's apart from the places - # sign of the gradient went from -1 -> 1 or 1 -> -1. - # Indices are -1 for decreasing width, +1 for increasing width. - increasing_or_decreasing = np.sign(grad) - - # Taking the difference will give us an array with -2/2 for the places the gradient where the gradient - # went from negative->positive/postitive->negative, 0 where it didn't change, and -1 where the gradient goes from 0->1 - # at the pin tip. - gradient_changed = np.diff(increasing_or_decreasing) - - # np.where will give all non-zero indices: the indices where the gradient changed. - # We take the 0th element as the x pos since it's the first place where the gradient changed, indicating a bulge. - stationary_points = np.where(gradient_changed)[0] - - # We'll have one stationary point before the midpoint. - i_pixel = stationary_points[1] - - j_pixel = middle_line[int(i_pixel)] - width = widths[int(i_pixel)] - return (i_pixel, j_pixel, width) - - -def get_rotation_increment(rotations: int, omega: int, high_limit: int) -> float: - """ - By default we'll rotate clockwise (viewing the goniometer from the front), but if we - can't rotate 180 degrees clockwise without exceeding the high_limit threshold then - the goniometer rotates in the anticlockwise direction. - - Args: - rotations (int): The number of rotations we want to add up to 180/-180 - omega (int): The current omega angle of the smargon. - high_limit (int): The maximum allowed angle we want the smargon omega to have. - Returns: - The inrement we should rotate omega by. - """ - - # Number of degrees to rotate to. - increment = 180.0 / rotations - - # If the rotation threshhold would be exceeded flip the rotation direction. - if omega + 180 > high_limit: - increment = -increment - - return increment - - -def filter_rotation_data( - i_positions: np.ndarray, - j_positions: np.ndarray, - widths: np.ndarray, - omega_angles: np.ndarray, - acceptable_i_difference=100, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """ - Filters out outlier positions - those for which the i value of the midpoint unreasonably differs from the median of the i values at other rotations. - - Args: - i_positions (numpy.ndarray): Array where the n-th element corresponds to the i value (in pixels) of the midpoint at rotation n. - j_positions (numpy.ndarray): Array where the n-th element corresponds to the j value (in pixels) of the midpoint at rotation n. - widths (numpy.ndarray): Array where the n-th element corresponds to the pin width (in pixels) of the midpoint at rotation n. - acceptable_i_difference: the acceptable difference between the average value of i and - any individual value of i. We don't want to use exceptional positions for calculation. - Returns: - i_positions_filtered: the i_positions with outliers filtered out - j_positions_filtered: the j_positions with outliers filtered out - widths_filtered: the widths with outliers filtered out - omega_angles_filtered: the omega_angles with outliers filtered out - """ - # Find the average of the non zero elements of the array. - i_median = np.median(i_positions) - - # Filter out outliers. - outlier_i_positions = np.where( - abs(i_positions - i_median) > acceptable_i_difference - )[0] - i_positions_filtered = np.delete(i_positions, outlier_i_positions) - j_positions_filtered = np.delete(j_positions, outlier_i_positions) - widths_filtered = np.delete(widths, outlier_i_positions) - omega_angles_filtered = np.delete(omega_angles, outlier_i_positions) - - if not widths_filtered.size: - raise OAVError_NoRotationsPassValidityTest( - "No rotations pass the validity test." - ) - - return ( - i_positions_filtered, - j_positions_filtered, - widths_filtered, - omega_angles_filtered, - ) - - -def check_i_within_bounds( - max_tip_distance_pixels: int, tip_i: int, i_pixels: int -) -> int: - """ - Checks if i_pixels exceeds max tip distance (in pixels), if so returns max_tip_distance, else i_pixels. - This is necessary as some users send in wierd loops for which the differential method isn't functional. - OAV centring only needs to get in the right ballpark so Xray centring can do its thing. - """ - - tip_distance_pixels = i_pixels - tip_i - if tip_distance_pixels > max_tip_distance_pixels: - LOGGER.warning( - f"x_pixels={i_pixels} exceeds maximum tip distance {max_tip_distance_pixels}, using setting x_pixels within the max tip distance" - ) - i_pixels = max_tip_distance_pixels + tip_i - return i_pixels - - -def extract_pixel_centre_values_from_rotation_data( - i_positions: np.ndarray, - j_positions: np.ndarray, - widths: np.ndarray, - omega_angles: np.ndarray, -) -> Tuple[int, int, int, float, float]: - """ - Takes the obtained midpoints x_positions, y_positions, the pin widths, omega_angles from the rotations - and returns i, j, k, the angle the pin is widest, and the angle orthogonal to it. - - Args: - i_positions (numpy.ndarray): Array where the n-th element corresponds to the x value (in pixels) of the midpoint at rotation n. - j_positions (numpy.ndarray): Array where the n-th element corresponds to the y value (in pixels) of the midpoint at rotation n. - widths (numpy.ndarray): Array where the n-th element corresponds to the pin width (in pixels) of the midpoint at rotation n. - omega_angles (numpy.ndarray): Array where the n-th element corresponds to the omega angle at rotation n. - Returns: - i_pixels (int): The i value (x in pixels) of the midpoint when omega is equal to widest_omega_angle - j_pixels (int): The j value (y in pixels) of the midpoint when omega is equal to widest_omega_angle - k_pixels (int): The k value - the distance in pixels between the the midpoint and the top/bottom of the pin, - when omega is equal to `widest_omega_angle_orthogonal` - widest_omega_angle (float): The value of omega where the pin is widest in the image. - widest_omega_angle_orthogonal (float): The value of omega orthogonal to the angle where the pin is widest in the image. - """ - - ( - i_positions, - j_positions, - widths, - omega_angles, - ) = filter_rotation_data(i_positions, j_positions, widths, omega_angles) - - ( - index_of_largest_width, - index_orthogonal_to_largest_width, - ) = find_widest_point_and_orthogonal_point(widths, omega_angles) - - i_pixels = int(i_positions[index_of_largest_width]) - j_pixels = int(j_positions[index_of_largest_width]) - widest_omega_angle = float(omega_angles[index_of_largest_width]) - - widest_omega_angle_orthogonal = float( - omega_angles[index_orthogonal_to_largest_width] - ) - - # Store the y value which will be the magnitude in the z axis on 90 degree rotation - k_pixels = int(j_positions[index_orthogonal_to_largest_width]) - - return ( - i_pixels, - j_pixels, - k_pixels, - widest_omega_angle, - widest_omega_angle_orthogonal, - ) - def camera_coordinates_to_xyz( horizontal: float, @@ -260,80 +37,3 @@ def camera_coordinates_to_xyz( z = vertical * sine return np.array([x, y, z], dtype=np.float64) - - -def keep_inside_bounds(value: float, lower_bound: float, upper_bound: float) -> float: - """ - If value is above an upper bound then the upper bound is returned. - If value is below a lower bound then the lower bound is returned. - If value is within bounds then the value is returned. - - Args: - value (float): The value being checked against bounds. - lower_bound (float): The lower bound. - lower_bound (float): The upper bound. - """ - if value < lower_bound: - return lower_bound - if value > upper_bound: - return upper_bound - return value - - -def find_widest_point_and_orthogonal_point( - widths: np.ndarray, - omega_angles: np.ndarray, -) -> Tuple[int, int]: - """ - Find the index of the rotation where the pin was widest in the camera, and the indices of rotations orthogonal to it. - - Args: Lists of values taken, the i-th value of the list is the i-th point sampled: - widths (numpy.ndarray): Array where the i-th element corresponds to the pin width (in pixels) of the midpoint at rotation i. - omega_angles (numpy.ndarray): Array where the i-th element corresponds to the omega angle at rotation i. - Returns: The index of the sample which had the widest pin as an int, and the index orthogonal to that - """ - - # Find omega for face-on position: where bulge was widest. - index_of_largest_width = widths.argmax() - widest_omega_angle = omega_angles[index_of_largest_width] - - # Find the best angles orthogonal to the best_omega_angle. - index_orthogonal_to_largest_width = get_orthogonal_index( - omega_angles, widest_omega_angle - ) - - return int(index_of_largest_width), index_orthogonal_to_largest_width - - -def get_orthogonal_index( - angle_array: np.ndarray, angle: float, error_bounds: float = 5 -) -> int: - """ - Takes a numpy array of angles that encompasses 180 deg, and an angle from within - that 180 deg and returns the index of the element most orthogonal to that angle. - - Args: - angle_array (np.ndarray): Numpy array of angles. - angle (float): The angle we want to be orthogonal to - error_bounds (float): The absolute error allowed on the angle - - Returns: - The index of the orthogonal angle - """ - smallest_angle = angle_array.min() - - # Normalise values to be positive - normalised_array = angle_array - smallest_angle - normalised_angle = angle - smallest_angle - - orthogonal_angle = (normalised_angle + 90) % 180 - - angle_distance_to_orthogonal: np.ndarray = abs(normalised_array - orthogonal_angle) - index_of_orthogonal = int(angle_distance_to_orthogonal.argmin()) - - if not (abs((angle_distance_to_orthogonal[index_of_orthogonal])) <= error_bounds): - raise OAVError_MissingRotations( - f"Orthogonal angle found {angle_array[index_of_orthogonal]} not sufficiently orthogonal to angle {angle}" - ) - - return index_of_orthogonal diff --git a/src/dodal/devices/oav/oav_detector.py b/src/dodal/devices/oav/oav_detector.py index d4f9f67814..d95840a43b 100644 --- a/src/dodal/devices/oav/oav_detector.py +++ b/src/dodal/devices/oav/oav_detector.py @@ -17,7 +17,6 @@ StatusBase, ) -from dodal.devices.areadetector.plugins.MXSC import MXSC from dodal.devices.oav.grid_overlay import SnapshotWithGrid from dodal.devices.oav.oav_errors import ( OAVError_BeamPositionNotFound, @@ -60,11 +59,9 @@ class ZoomController(Device): def set_flatfield_on_zoom_level_one(self, value): flat_applied = self.parent.proc.port_name.get() no_flat_applied = self.parent.cam.port_name.get() - - input_plugin = flat_applied if value == "1.0x" else no_flat_applied - - flat_field_status = self.parent.mxsc.input_plugin.set(input_plugin) - return flat_field_status & self.parent.snapshot.input_plugin.set(input_plugin) + return self.parent.snapshot.input_plugin.set( + flat_applied if value == "1.0x" else no_flat_applied + ) @property def allowed_zoom_levels(self): @@ -108,7 +105,7 @@ def update_on_zoom(self, value, xsize, ysize, *args, **kwargs): zoom, xsize, ysize ) - def load_microns_per_pixel(self, zoom: float, xsize: int, ysize: int): + def load_microns_per_pixel(self, zoom: float, xsize: int, ysize: int) -> None: """ Loads the microns per x pixel and y pixel for a given zoom level. These are currently generated by GDA, though hyperion could generate them in future. @@ -198,7 +195,6 @@ class OAV(AreaDetector): tiff = ADC(OverlayPlugin, "-DI-OAV-01:TIFF:") hdf5 = ADC(HDF5Plugin, "-DI-OAV-01:HDF5:") snapshot = Component(SnapshotWithGrid, "-DI-OAV-01:MJPG:") - mxsc = ADC(MXSC, "-DI-OAV-01:MXSC:") zoom_controller = Component(ZoomController, "-EA-OAV-01:FZOOM:") def __init__(self, *args, params: OAVConfigParams, **kwargs): diff --git a/src/dodal/devices/oav/oav_errors.py b/src/dodal/devices/oav/oav_errors.py index 90febccb75..82834196df 100644 --- a/src/dodal/devices/oav/oav_errors.py +++ b/src/dodal/devices/oav/oav_errors.py @@ -1,6 +1,7 @@ """ Module for containing errors in operation of the OAV. """ + from dodal.log import LOGGER diff --git a/src/dodal/devices/oav/pin_image_recognition/__init__.py b/src/dodal/devices/oav/pin_image_recognition/__init__.py index 1e785400be..191b133792 100644 --- a/src/dodal/devices/oav/pin_image_recognition/__init__.py +++ b/src/dodal/devices/oav/pin_image_recognition/__init__.py @@ -4,12 +4,18 @@ import numpy as np from numpy.typing import NDArray -from ophyd_async.core import AsyncStatus, StandardReadable, observe_value +from ophyd_async.core import ( + DEFAULT_TIMEOUT, + AsyncStatus, + StandardReadable, + observe_value, +) from ophyd_async.epics.signal import epics_signal_r from dodal.devices.oav.pin_image_recognition.utils import ( ARRAY_PROCESSING_FUNCTIONS_MAP, MxSampleDetect, + SampleLocation, ScanDirections, identity, ) @@ -46,6 +52,12 @@ def __init__(self, prefix: str, name: str = ""): self._name = name self.triggered_tip = create_soft_signal_r(Tip, "triggered_tip", self.name) + self.triggered_top_edge = create_soft_signal_r( + NDArray[np.uint32], "triggered_top_edge", self.name + ) + self.triggered_bottom_edge = create_soft_signal_r( + NDArray[np.uint32], "triggered_bottom_edge", self.name + ) self.array_data = epics_signal_r(NDArray[np.uint8], f"pva://{prefix}PVA:ARRAY") # Soft parameters for pin-tip detection. @@ -73,23 +85,29 @@ def __init__(self, prefix: str, name: str = ""): ) self.set_readable_signals( - read=[self.triggered_tip], + read=[ + self.triggered_tip, + self.triggered_top_edge, + self.triggered_bottom_edge, + ], ) super().__init__(name=name) - async def _set_triggered_tip(self, value): - if value == self.INVALID_POSITION: + async def _set_triggered_values(self, results: SampleLocation): + tip = (results.tip_x, results.tip_y) + if tip == self.INVALID_POSITION: raise InvalidPinException else: - await self.triggered_tip._backend.put(value) + await self.triggered_tip._backend.put(tip) + await self.triggered_top_edge._backend.put(results.edge_top) + await self.triggered_bottom_edge._backend.put(results.edge_bottom) - async def _get_tip_position(self, array_data: NDArray[np.uint8]) -> Tip: + async def _get_tip_and_edge_data( + self, array_data: NDArray[np.uint8] + ) -> SampleLocation: """ - Gets the location of the pin tip. - - Returns tuple of: - (tip_x, tip_y) + Gets the location of the pin tip and the top and bottom edges. """ preprocess_key = await self.preprocess_operation.get_value() preprocess_iter = await self.preprocess_iterations.get_value() @@ -127,11 +145,10 @@ async def _get_tip_position(self, array_data: NDArray[np.uint8]) -> Tip: (end_time - start_time) * 1000.0 ) ) + return location - return (location.tip_x, location.tip_y) - - async def connect(self, sim: bool = False): - await super().connect(sim) + async def connect(self, sim: bool = False, timeout: float = DEFAULT_TIMEOUT): + await super().connect(sim, timeout) # Set defaults for soft parameters await self.validity_timeout.set(5.0) @@ -156,7 +173,8 @@ async def _set_triggered_tip(): """ async for value in observe_value(self.array_data): try: - await self._set_triggered_tip(await self._get_tip_position(value)) + location = await self._get_tip_and_edge_data(value) + await self._set_triggered_values(location) except Exception as e: LOGGER.warn( f"Failed to detect pin-tip location, will retry with next image: {e}" @@ -173,3 +191,5 @@ async def _set_triggered_tip(): f"No tip found in {await self.validity_timeout.get_value()} seconds." ) await self.triggered_tip._backend.put(self.INVALID_POSITION) + await self.triggered_bottom_edge._backend.put(np.array([])) + await self.triggered_top_edge._backend.put(np.array([])) diff --git a/src/dodal/devices/oav/pin_image_recognition/manual_test.py b/src/dodal/devices/oav/pin_image_recognition/manual_test.py index b70b0a10da..d6f5e495fc 100644 --- a/src/dodal/devices/oav/pin_image_recognition/manual_test.py +++ b/src/dodal/devices/oav/pin_image_recognition/manual_test.py @@ -5,6 +5,7 @@ It is otherwise unused. """ + import asyncio from dodal.devices.oav.pin_image_recognition import PinTipDetection diff --git a/src/dodal/devices/oav/pin_image_recognition/utils.py b/src/dodal/devices/oav/pin_image_recognition/utils.py index 7f42f136c6..ff249a1f6f 100644 --- a/src/dodal/devices/oav/pin_image_recognition/utils.py +++ b/src/dodal/devices/oav/pin_image_recognition/utils.py @@ -97,8 +97,8 @@ class SampleLocation: Holder type for results from sample detection. """ - tip_y: Optional[int] tip_x: Optional[int] + tip_y: Optional[int] edge_top: np.ndarray edge_bottom: np.ndarray @@ -209,7 +209,7 @@ def _locate_sample(self, edge_arr: np.ndarray) -> SampleLocation: "pin-tip detection: No non-narrow edges found - cannot locate pin tip" ) return SampleLocation( - tip_y=None, tip_x=None, edge_bottom=bottom, edge_top=top + tip_x=None, tip_y=None, edge_bottom=bottom, edge_top=top ) # Choose our starting point - i.e. first column with non-narrow width for positive scan, last one for negative scan. @@ -248,5 +248,5 @@ def _locate_sample(self, edge_arr: np.ndarray) -> SampleLocation: ) ) return SampleLocation( - tip_y=tip_y, tip_x=tip_x, edge_bottom=bottom, edge_top=top + tip_x=tip_x, tip_y=tip_y, edge_bottom=bottom, edge_top=top ) diff --git a/src/dodal/devices/robot.py b/src/dodal/devices/robot.py index a30c448895..e775e8702f 100644 --- a/src/dodal/devices/robot.py +++ b/src/dodal/devices/robot.py @@ -1,9 +1,15 @@ +import asyncio from collections import OrderedDict +from dataclasses import dataclass +from enum import Enum from typing import Dict, Sequence -from bluesky.protocols import Descriptor, Reading -from ophyd_async.core import StandardReadable -from ophyd_async.epics.signal import epics_signal_r +from bluesky.protocols import Descriptor, Movable, Reading +from ophyd_async.core import AsyncStatus, StandardReadable, wait_for_value +from ophyd_async.epics.signal import epics_signal_r, epics_signal_x + +from dodal.devices.util.epics_util import epics_signal_rw_rbv +from dodal.log import LOGGER class SingleIndexWaveformReadable(StandardReadable): @@ -44,14 +50,57 @@ async def describe(self) -> dict[str, Descriptor]: return desc -class BartRobot(StandardReadable): +@dataclass +class SampleLocation: + puck: int + pin: int + + +class PinMounted(str, Enum): + NO_PIN_MOUNTED = "No Pin Mounted" + PIN_MOUNTED = "Pin Mounted" + + +class BartRobot(StandardReadable, Movable): """The sample changing robot.""" + LOAD_TIMEOUT = 60 + def __init__( self, name: str, prefix: str, ) -> None: self.barcode = SingleIndexWaveformReadable(prefix + "BARCODE") - self.gonio_pin_sensor = epics_signal_r(bool, prefix + "PIN_MOUNTED") + self.gonio_pin_sensor = epics_signal_r(PinMounted, prefix + "PIN_MOUNTED") + self.next_pin = epics_signal_rw_rbv(float, prefix + "NEXT_PIN") + self.next_puck = epics_signal_rw_rbv(float, prefix + "NEXT_PUCK") + self.load = epics_signal_x(prefix + "LOAD.PROC") + self.program_running = epics_signal_r(bool, prefix + "PROGRAM_RUNNING") + self.program_name = epics_signal_r(str, prefix + "PROGRAM_NAME") super().__init__(name=name) + + async def _load_pin_and_puck(self, sample_location: SampleLocation): + LOGGER.info(f"Loading pin {sample_location}") + if await self.program_running.get_value(): + LOGGER.info( + f"Waiting on robot to finish {await self.program_name.get_value()}" + ) + await wait_for_value(self.program_running, False, None) + await asyncio.gather( + self.next_puck.set(sample_location.puck), + self.next_pin.set(sample_location.pin), + ) + await self.load.trigger() + if await self.gonio_pin_sensor.get_value() == PinMounted.PIN_MOUNTED: + LOGGER.info("Waiting on old pin unloaded") + await wait_for_value(self.gonio_pin_sensor, PinMounted.NO_PIN_MOUNTED, None) + LOGGER.info("Waiting on new pin loaded") + await wait_for_value(self.gonio_pin_sensor, PinMounted.PIN_MOUNTED, None) + + def set(self, sample_location: SampleLocation) -> AsyncStatus: + return AsyncStatus( + asyncio.wait_for( + self._load_pin_and_puck(sample_location), timeout=self.LOAD_TIMEOUT + ) + ) diff --git a/src/dodal/devices/scatterguard.py b/src/dodal/devices/scatterguard.py index 6c9374169f..b29b148bd7 100644 --- a/src/dodal/devices/scatterguard.py +++ b/src/dodal/devices/scatterguard.py @@ -1,7 +1,9 @@ from ophyd import Component as Cpt -from ophyd import Device, EpicsMotor +from ophyd import Device + +from dodal.devices.util.motor_utils import ExtendedEpicsMotor class Scatterguard(Device): - x = Cpt(EpicsMotor, "X") - y = Cpt(EpicsMotor, "Y") + x = Cpt(ExtendedEpicsMotor, "X") + y = Cpt(ExtendedEpicsMotor, "Y") diff --git a/src/dodal/devices/smargon.py b/src/dodal/devices/smargon.py index 9a27bde268..c4ea5c421e 100644 --- a/src/dodal/devices/smargon.py +++ b/src/dodal/devices/smargon.py @@ -1,13 +1,14 @@ from enum import Enum from ophyd import Component as Cpt -from ophyd import Device, EpicsMotor, EpicsSignal, EpicsSignalRO +from ophyd import Device, EpicsMotor, EpicsSignal from ophyd.epics_motor import MotorBundle from ophyd.status import StatusBase from dodal.devices.motors import MotorLimitHelper, XYZLimitBundle from dodal.devices.status import await_approx_value from dodal.devices.util.epics_util import SetWhenEnabled +from dodal.devices.util.motor_utils import ExtendedEpicsMotor class StubPosition(Enum): @@ -48,13 +49,12 @@ class Smargon(MotorBundle): Robot loading can nudge these and lead to errors. """ - x = Cpt(EpicsMotor, "X") - x_speed_limit_mm_per_s = Cpt(EpicsSignalRO, "X.VMAX") + x = Cpt(ExtendedEpicsMotor, "X") y = Cpt(EpicsMotor, "Y") z = Cpt(EpicsMotor, "Z") chi = Cpt(EpicsMotor, "CHI") phi = Cpt(EpicsMotor, "PHI") - omega = Cpt(EpicsMotor, "OMEGA") + omega = Cpt(ExtendedEpicsMotor, "OMEGA") real_x1 = Cpt(EpicsMotor, "MOTOR_3") real_x2 = Cpt(EpicsMotor, "MOTOR_4") diff --git a/src/dodal/devices/synchrotron.py b/src/dodal/devices/synchrotron.py index cc75bcef65..ae9eabe8ca 100644 --- a/src/dodal/devices/synchrotron.py +++ b/src/dodal/devices/synchrotron.py @@ -1,9 +1,25 @@ from enum import Enum -from ophyd import Component, Device, EpicsSignal +from ophyd_async.core import StandardReadable +from ophyd_async.epics.signal import epics_signal_r -class SynchrotronMode(Enum): +class Prefix(str, Enum): + STATUS = "CS-CS-MSTAT-01:" + TOP_UP = "SR-CS-FILL-01:" + SIGNAL = "SR-DI-DCCT-01:" + + +class Suffix(str, Enum): + SIGNAL = "SIGNAL" + MODE = "MODE" + USER_COUNTDOWN = "USERCOUNTDN" + BEAM_ENERGY = "BEAMENERGY" + COUNTDOWN = "COUNTDOWN" + END_COUNTDOWN = "ENDCOUNTDN" + + +class SynchrotronMode(str, Enum): SHUTDOWN = "Shutdown" INJECTION = "Injection" NOBEAM = "No Beam" @@ -14,19 +30,41 @@ class SynchrotronMode(Enum): UNKNOWN = "Unknown" -class SynchrotoronMachineStatus(Device): - synchrotron_mode = Component(EpicsSignal, "MODE", string=True) - user_countdown = Component(EpicsSignal, "USERCOUNTDN") - beam_energy = Component(EpicsSignal, "BEAMENERGY") - - -class SynchrotronTopUp(Device): - start_countdown = Component(EpicsSignal, "COUNTDOWN") - end_countdown = Component(EpicsSignal, "ENDCOUNTDN") - - -class Synchrotron(Device): - machine_status = Component(SynchrotoronMachineStatus, "CS-CS-MSTAT-01:") - top_up = Component(SynchrotronTopUp, "SR-CS-FILL-01:") +class Synchrotron(StandardReadable): + def __init__( + self, + prefix: str = "", + name: str = "synchrotron", + *, + signal_prefix=Prefix.SIGNAL, + status_prefix=Prefix.STATUS, + topup_prefix=Prefix.TOP_UP, + ): + self.ring_current = epics_signal_r(float, signal_prefix + Suffix.SIGNAL) + self.synchrotron_mode = epics_signal_r( + SynchrotronMode, status_prefix + Suffix.MODE + ) + self.machine_user_countdown = epics_signal_r( + float, status_prefix + Suffix.USER_COUNTDOWN + ) + self.beam_energy = epics_signal_r(float, status_prefix + Suffix.BEAM_ENERGY) + self.topup_start_countdown = epics_signal_r( + float, topup_prefix + Suffix.COUNTDOWN + ) + self.top_up_end_countdown = epics_signal_r( + float, topup_prefix + Suffix.END_COUNTDOWN + ) - ring_current = Component(EpicsSignal, "SR-DI-DCCT-01:SIGNAL") + self.set_readable_signals( + read=[ + self.ring_current, + self.machine_user_countdown, + self.topup_start_countdown, + self.top_up_end_countdown, + ], + config=[ + self.beam_energy, + self.synchrotron_mode, + ], + ) + super().__init__(name=name) diff --git a/src/dodal/devices/turbo_slit.py b/src/dodal/devices/turbo_slit.py new file mode 100644 index 0000000000..d9db5491c2 --- /dev/null +++ b/src/dodal/devices/turbo_slit.py @@ -0,0 +1,20 @@ +from ophyd_async.core import Device +from ophyd_async.epics.motion.motor import Motor + + +class TurboSlit(Device): + """ + This collection of motors coordinates time resolved XAS experiments. + It selects a beam out of the polychromatic fan. + These slits can be scanned continously or in step mode. + The relationship between the three motors is as follows: + - gap provides energy resolution + - xfine selects the energy + - arc - ??? + """ + + def __init__(self, prefix: str, name: str): + self.gap = Motor(prefix=prefix + "GAP") + self.arc = Motor(prefix=prefix + "ARC") + self.xfine = Motor(prefix=prefix + "XFINE") + super().__init__(name=name) diff --git a/src/dodal/devices/util/adjuster_plans.py b/src/dodal/devices/util/adjuster_plans.py index 98ddd9c91a..9c7cc69826 100644 --- a/src/dodal/devices/util/adjuster_plans.py +++ b/src/dodal/devices/util/adjuster_plans.py @@ -2,6 +2,7 @@ All the methods in this module return a bluesky plan generator that adjusts a value according to some criteria either via feedback, preset positions, lookup tables etc. """ + from typing import Callable, Generator from bluesky import plan_stubs as bps diff --git a/src/dodal/devices/util/epics_util.py b/src/dodal/devices/util/epics_util.py index 80525a8d2e..0f71e62025 100644 --- a/src/dodal/devices/util/epics_util.py +++ b/src/dodal/devices/util/epics_util.py @@ -3,6 +3,7 @@ from ophyd import Component, Device, EpicsSignal from ophyd.status import Status, StatusBase +from ophyd_async.epics.signal import epics_signal_rw from dodal.devices.status import await_value from dodal.log import LOGGER @@ -125,3 +126,9 @@ def set(self, proc: int) -> Status: lambda: self.proc.set(proc), ] ) + + +def epics_signal_rw_rbv( + T, write_pv: str +): # Remove when https://github.com/bluesky/ophyd-async/issues/139 is done + return epics_signal_rw(T, write_pv + "_RBV", write_pv) diff --git a/src/dodal/devices/util/lookup_tables.py b/src/dodal/devices/util/lookup_tables.py index 24fff56109..0e1b74475d 100644 --- a/src/dodal/devices/util/lookup_tables.py +++ b/src/dodal/devices/util/lookup_tables.py @@ -2,6 +2,7 @@ All the public methods in this module return a lookup table of some kind that converts the source value s to a target value t for different values of s. """ + from collections.abc import Sequence from typing import Callable diff --git a/src/dodal/devices/util/motor_utils.py b/src/dodal/devices/util/motor_utils.py new file mode 100644 index 0000000000..07638ba3f6 --- /dev/null +++ b/src/dodal/devices/util/motor_utils.py @@ -0,0 +1,6 @@ +from ophyd import Component, EpicsMotor, EpicsSignalRO + + +class ExtendedEpicsMotor(EpicsMotor): + motor_resolution: Component[EpicsSignalRO] = Component(EpicsSignalRO, ".MRES") + max_velocity: Component[EpicsSignalRO] = Component(EpicsSignalRO, ".VMAX") diff --git a/src/dodal/devices/zebra.py b/src/dodal/devices/zebra.py index 3a1066041d..12c582d0c1 100644 --- a/src/dodal/devices/zebra.py +++ b/src/dodal/devices/zebra.py @@ -1,24 +1,18 @@ from __future__ import annotations -from enum import Enum, IntEnum +import asyncio +from enum import Enum from functools import partialmethod from typing import List -from ophyd import Component, Device, EpicsSignal, StatusBase - -from dodal.devices.status import await_value -from dodal.devices.util.epics_util import epics_signal_put_wait - -PC_ARM_SOURCE_SOFT = "Soft" -PC_ARM_SOURCE_EXT = "External" - -PC_GATE_SOURCE_POSITION = 0 -PC_GATE_SOURCE_TIME = 1 -PC_GATE_SOURCE_EXTERNAL = 2 - -PC_PULSE_SOURCE_POSITION = 0 -PC_PULSE_SOURCE_TIME = 1 -PC_PULSE_SOURCE_EXTERNAL = 2 +from ophyd_async.core import ( + AsyncStatus, + DeviceVector, + SignalRW, + StandardReadable, + observe_value, +) +from ophyd_async.epics.signal import epics_signal_rw # Sources DISCONNECT = 0 @@ -45,99 +39,127 @@ TTL_PANDA = 4 -class I03Axes(Enum): - SMARGON_X1 = "Enc1" - SMARGON_Y = "Enc2" - SMARGON_Z = "Enc3" - OMEGA = "Enc4" +class ArmSource(str, Enum): + SOFT = "Soft" + EXTERNAL = "External" + + +class TrigSource(str, Enum): + POSITION = "Position" + TIME = "Time" + EXTERNAL = "External" + + +class EncEnum(str, Enum): + Enc1 = "Enc1" + Enc2 = "Enc2" + Enc3 = "Enc3" + Enc4 = "Enc4" + Enc1_4Av = "Enc1-4Av" -class I24Axes(Enum): - VGON_Z = "Enc1" - OMEGA = "Enc2" - VGON_X = "Enc3" - VGON_YH = "Enc4" +class I03Axes: + SMARGON_X1 = EncEnum.Enc1 + SMARGON_Y = EncEnum.Enc2 + SMARGON_Z = EncEnum.Enc3 + OMEGA = EncEnum.Enc4 -class RotationDirection(IntEnum): - POSITIVE = 1 - NEGATIVE = -1 +class I24Axes: + VGON_Z = EncEnum.Enc1 + OMEGA = EncEnum.Enc2 + VGON_X = EncEnum.Enc3 + VGON_YH = EncEnum.Enc4 -class ArmDemand(IntEnum): +class RotationDirection(str, Enum): + POSITIVE = "Positive" + NEGATIVE = "Negative" + + +class ArmDemand(Enum): ARM = 1 DISARM = 0 -class FastShutterAction(IntEnum): - OPEN = 1 - CLOSE = 0 +class SoftInState(str, Enum): + YES = "Yes" + NO = "No" -class ArmingDevice(Device): +class ArmingDevice(StandardReadable): """A useful device that can abstract some of the logic of arming. Allows a user to just call arm.set(ArmDemand.ARM)""" TIMEOUT = 3 - arm_set = Component(EpicsSignal, "PC_ARM") - disarm_set = Component(EpicsSignal, "PC_DISARM") - armed = Component(EpicsSignal, "PC_ARM_OUT") + def __init__(self, prefix: str, name: str = "") -> None: + self.arm_set = epics_signal_rw(float, prefix + "PC_ARM") + self.disarm_set = epics_signal_rw(float, prefix + "PC_DISARM") + self.armed = epics_signal_rw(float, prefix + "PC_ARM_OUT") + super().__init__(name) - def set(self, demand: ArmDemand) -> StatusBase: - status = await_value(self.armed, demand.value, timeout=self.TIMEOUT) + async def _set_armed(self, demand: ArmDemand): signal_to_set = self.arm_set if demand == ArmDemand.ARM else self.disarm_set - status &= signal_to_set.set(1) - return status + await signal_to_set.set(1) + async for reading in observe_value(self.armed): + if reading == demand.value: + return + def set(self, demand: ArmDemand) -> AsyncStatus: + return AsyncStatus( + asyncio.wait_for(self._set_armed(demand), timeout=self.TIMEOUT) + ) -class PositionCompare(Device): - num_gates = epics_signal_put_wait("PC_GATE_NGATE") - gate_trigger = epics_signal_put_wait("PC_ENC") - gate_source = epics_signal_put_wait("PC_GATE_SEL") - gate_input = epics_signal_put_wait("PC_GATE_INP") - gate_width = epics_signal_put_wait("PC_GATE_WID") - gate_start = epics_signal_put_wait("PC_GATE_START") - gate_step = epics_signal_put_wait("PC_GATE_STEP") - pulse_source = epics_signal_put_wait("PC_PULSE_SEL") - pulse_input = epics_signal_put_wait("PC_PULSE_INP") - pulse_start = epics_signal_put_wait("PC_PULSE_START") - pulse_width = epics_signal_put_wait("PC_PULSE_WID") - pulse_step = epics_signal_put_wait("PC_PULSE_STEP") - pulse_max = epics_signal_put_wait("PC_PULSE_MAX") +class PositionCompare(StandardReadable): + def __init__(self, prefix: str, name: str = "") -> None: + self.num_gates = epics_signal_rw(float, prefix + "PC_GATE_NGATE") + self.gate_trigger = epics_signal_rw(EncEnum, prefix + "PC_ENC") + self.gate_source = epics_signal_rw(TrigSource, prefix + "PC_GATE_SEL") + self.gate_input = epics_signal_rw(float, prefix + "PC_GATE_INP") + self.gate_width = epics_signal_rw(float, prefix + "PC_GATE_WID") + self.gate_start = epics_signal_rw(float, prefix + "PC_GATE_START") + self.gate_step = epics_signal_rw(float, prefix + "PC_GATE_STEP") - dir = Component(EpicsSignal, "PC_DIR") - arm_source = epics_signal_put_wait("PC_ARM_SEL") - reset = Component(EpicsSignal, "SYS_RESET.PROC") + self.pulse_source = epics_signal_rw(TrigSource, prefix + "PC_PULSE_SEL") + self.pulse_input = epics_signal_rw(float, prefix + "PC_PULSE_INP") + self.pulse_start = epics_signal_rw(float, prefix + "PC_PULSE_START") + self.pulse_width = epics_signal_rw(float, prefix + "PC_PULSE_WID") + self.pulse_step = epics_signal_rw(float, prefix + "PC_PULSE_STEP") + self.pulse_max = epics_signal_rw(float, prefix + "PC_PULSE_MAX") - arm = Component(ArmingDevice, "") + self.dir = epics_signal_rw(RotationDirection, prefix + "PC_DIR") + self.arm_source = epics_signal_rw(ArmSource, prefix + "PC_ARM_SEL") + self.reset = epics_signal_rw(int, prefix + "SYS_RESET.PROC") - def is_armed(self) -> bool: - return self.arm.armed.get() == 1 + self.arm = ArmingDevice(prefix) + super().__init__(name) + async def is_armed(self) -> bool: + arm_state = await self.arm.armed.get_value() + return arm_state == 1 -class PulseOutput(Device): - input = epics_signal_put_wait("_INP") - delay = epics_signal_put_wait("_DLY") - width = epics_signal_put_wait("_WID") +class PulseOutput(StandardReadable): + """Zebra pulse output panel.""" -class ZebraOutputPanel(Device): - pulse_1 = Component(PulseOutput, "PULSE1") - pulse_2 = Component(PulseOutput, "PULSE2") + def __init__(self, prefix: str, name: str = "") -> None: + self.input = epics_signal_rw(float, prefix + "_INP") + self.delay = epics_signal_rw(float, prefix + "_DLY") + self.width = epics_signal_rw(float, prefix + "_WID") + super().__init__(name) - out_1 = epics_signal_put_wait("OUT1_TTL") - out_2 = epics_signal_put_wait("OUT2_TTL") - out_3 = epics_signal_put_wait("OUT3_TTL") - out_4 = epics_signal_put_wait("OUT4_TTL") - @property - def out_pvs(self) -> List[EpicsSignal]: - """A list of all the output TTL PVs. Note that as the PVs are 1 indexed - `out_pvs[0]` is `None`. - """ - return [None, self.out_1, self.out_2, self.out_3, self.out_4] +class ZebraOutputPanel(StandardReadable): + def __init__(self, prefix: str, name: str = "") -> None: + self.pulse_1 = PulseOutput(prefix + "PULSE1") + self.pulse_2 = PulseOutput(prefix + "PULSE2") + + self.out_pvs: DeviceVector[SignalRW] = DeviceVector( + {i: epics_signal_rw(float, prefix + f"OUT{i}_TTL") for i in range(1, 5)} + ) + super().__init__(name) def boolean_array_to_integer(values: List[bool]) -> int: @@ -153,17 +175,14 @@ def boolean_array_to_integer(values: List[bool]) -> int: return sum(v << i for i, v in enumerate(values)) -class GateControl(Device): - enable = epics_signal_put_wait("_ENA", 30.0) - source_1 = epics_signal_put_wait("_INP1", 30.0) - source_2 = epics_signal_put_wait("_INP2", 30.0) - source_3 = epics_signal_put_wait("_INP3", 30.0) - source_4 = epics_signal_put_wait("_INP4", 30.0) - invert = epics_signal_put_wait("_INV", 30.0) - - @property - def sources(self): - return [self.source_1, self.source_2, self.source_3, self.source_4] +class GateControl(StandardReadable): + def __init__(self, prefix: str, name: str = "") -> None: + self.enable = epics_signal_rw(int, prefix + "_ENA") + self.sources = DeviceVector( + {i: epics_signal_rw(float, prefix + f"_INP{i}") for i in range(1, 5)} + ) + self.invert = epics_signal_rw(int, prefix + "_INV") + super().__init__(name) class GateType(Enum): @@ -171,36 +190,25 @@ class GateType(Enum): OR = "OR" -class LogicGateConfigurer(Device): +class LogicGateConfigurer(StandardReadable): DEFAULT_SOURCE_IF_GATE_NOT_USED = 0 - and_gate_1 = Component(GateControl, "AND1") - and_gate_2 = Component(GateControl, "AND2") - and_gate_3 = Component(GateControl, "AND3") - and_gate_4 = Component(GateControl, "AND4") + def __init__(self, prefix: str, name: str = "") -> None: + self.and_gates: DeviceVector[GateControl] = DeviceVector( + {i: GateControl(prefix + f"AND{i}") for i in range(1, 5)} + ) - or_gate_1 = Component(GateControl, "OR1") - or_gate_2 = Component(GateControl, "OR2") - or_gate_3 = Component(GateControl, "OR3") - or_gate_4 = Component(GateControl, "OR4") + self.or_gates: DeviceVector[GateControl] = DeviceVector( + {i: GateControl(prefix + f"OR{i}") for i in range(1, 5)} + ) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) self.all_gates = { - GateType.AND: [ - self.and_gate_1, - self.and_gate_2, - self.and_gate_3, - self.and_gate_4, - ], - GateType.OR: [ - self.or_gate_1, - self.or_gate_2, - self.or_gate_3, - self.or_gate_4, - ], + GateType.AND: list(self.and_gates.values()), + GateType.OR: list(self.or_gates.values()), } + super().__init__(name) + def apply_logic_gate_config( self, type: GateType, gate_number: int, config: LogicGateConfiguration ): @@ -213,17 +221,17 @@ def apply_logic_gate_config( """ gate: GateControl = self.all_gates[type][gate_number - 1] - gate.enable.put(boolean_array_to_integer([True] * len(config.sources))) + gate.enable.set(boolean_array_to_integer([True] * len(config.sources))) # Input Source - for source_number, source_pv in enumerate(gate.sources): + for source_number, source_pv in gate.sources.items(): try: - source_pv.put(config.sources[source_number]) + source_pv.set(config.sources[source_number - 1]) except IndexError: - source_pv.put(self.DEFAULT_SOURCE_IF_GATE_NOT_USED) + source_pv.set(self.DEFAULT_SOURCE_IF_GATE_NOT_USED) # Invert - gate.invert.put(boolean_array_to_integer(config.invert)) + gate.invert.set(boolean_array_to_integer(config.invert)) apply_and_gate_config = partialmethod(apply_logic_gate_config, GateType.AND) apply_or_gate_config = partialmethod(apply_logic_gate_config, GateType.OR) @@ -265,15 +273,21 @@ def __str__(self) -> str: return ", ".join(input_strings) -class SoftInputs(Device): - soft_in_1 = Component(EpicsSignal, "SOFT_IN:B0") - soft_in_2 = Component(EpicsSignal, "SOFT_IN:B1") - soft_in_3 = Component(EpicsSignal, "SOFT_IN:B2") - soft_in_4 = Component(EpicsSignal, "SOFT_IN:B3") +class SoftInputs(StandardReadable): + def __init__(self, prefix: str, name: str = "") -> None: + self.soft_in_1 = epics_signal_rw(SoftInState, prefix + "SOFT_IN:B0") + self.soft_in_2 = epics_signal_rw(SoftInState, prefix + "SOFT_IN:B1") + self.soft_in_3 = epics_signal_rw(SoftInState, prefix + "SOFT_IN:B2") + self.soft_in_4 = epics_signal_rw(SoftInState, prefix + "SOFT_IN:B3") + super().__init__(name) + +class Zebra(StandardReadable): + """The Zebra device.""" -class Zebra(Device): - pc = Component(PositionCompare, "") - output = Component(ZebraOutputPanel, "") - inputs = Component(SoftInputs, "") - logic_gates = Component(LogicGateConfigurer, "") + def __init__(self, name: str, prefix: str) -> None: + self.pc = PositionCompare(prefix, name) + self.output = ZebraOutputPanel(prefix, name) + self.inputs = SoftInputs(prefix, name) + self.logic_gates = LogicGateConfigurer(prefix, name) + super().__init__(name=name) diff --git a/src/dodal/devices/zocalo/__init__.py b/src/dodal/devices/zocalo/__init__.py index edd2579d2e..3644917422 100644 --- a/src/dodal/devices/zocalo/__init__.py +++ b/src/dodal/devices/zocalo/__init__.py @@ -1,6 +1,4 @@ -from dodal.devices.zocalo.zocalo_interaction import ( - ZocaloTrigger, -) +from dodal.devices.zocalo.zocalo_interaction import ZocaloStartInfo, ZocaloTrigger from dodal.devices.zocalo.zocalo_results import ( NoResultsFromZocalo, NoZocaloSubscription, @@ -17,6 +15,7 @@ "ZOCALO_READING_PLAN_NAME", "NoResultsFromZocalo", "NoZocaloSubscription", + "ZocaloStartInfo", ] ZOCALO_READING_PLAN_NAME = "zocalo reading" diff --git a/src/dodal/devices/zocalo/zocalo_interaction.py b/src/dodal/devices/zocalo/zocalo_interaction.py index 217da22d41..ead267a58b 100644 --- a/src/dodal/devices/zocalo/zocalo_interaction.py +++ b/src/dodal/devices/zocalo/zocalo_interaction.py @@ -1,5 +1,8 @@ +import dataclasses import getpass import socket +from dataclasses import dataclass +from typing import Optional import zocalo.configuration from workflows.transport import lookup @@ -16,6 +19,25 @@ def _get_zocalo_connection(environment): return transport +@dataclass +class ZocaloStartInfo: + """ + ispyb_dcid (int): The ID of the data collection in ISPyB + filename (str): The name of the file that the detector will store into dev/shm + start_frame_index (int): The index of the first image of this collection within the file + written by the detector + number_of_frames (int): The number of frames in this collection + message_index (int): Which trigger this is in the detector collection e.g. 0 for the + first collection after a single arm, 1 for the next... + """ + + ispyb_dcid: int + filename: Optional[str] + start_frame_index: int + number_of_frames: int + message_index: int + + class ZocaloTrigger: """This class just sends 'run_start' and 'run_end' messages to zocalo, it is intended to be used in bluesky callback classes. To get results from zocalo back @@ -42,16 +64,20 @@ def _send_to_zocalo(self, parameters: dict): finally: transport.disconnect() - def run_start(self, data_collection_id: int): + def run_start( + self, + start_data: ZocaloStartInfo, + ): """Tells the data analysis pipeline we have started a run. Assumes that appropriate data has already been put into ISPyB Args: - data_collection_id (int): The ID of the data collection representing the - gridscan in ISPyB + start_data (ZocaloStartInfo): Data about the collection to send to zocalo """ - LOGGER.info(f"Starting Zocalo job with ispyb id {data_collection_id}") - self._send_to_zocalo({"event": "start", "ispyb_dcid": data_collection_id}) + LOGGER.info(f"Starting Zocalo job {start_data}") + data = dataclasses.asdict(start_data) + data["event"] = "start" + self._send_to_zocalo(data) def run_end(self, data_collection_id: int): """Tells the data analysis pipeline we have finished a run. diff --git a/src/dodal/log.py b/src/dodal/log.py index c59c735738..ad7083181c 100644 --- a/src/dodal/log.py +++ b/src/dodal/log.py @@ -18,8 +18,9 @@ DEFAULT_FORMATTER = logging.Formatter( "[%(asctime)s] %(name)s %(module)s %(levelname)s: %(message)s" ) -ERROR_LOG_BUFFER_LINES = 200000 +ERROR_LOG_BUFFER_LINES = 20000 INFO_LOG_DAYS = 30 +DEBUG_LOG_FILES_TO_KEEP = 7 class CircularMemoryHandler(logging.Handler): @@ -131,10 +132,14 @@ def set_up_DEBUG_memory_handler( print(f"Logging to {path/filename}") debug_path = path / "debug" debug_path.mkdir(parents=True, exist_ok=True) - file_handler = TimedRotatingFileHandler(filename=debug_path / filename, when="H") + file_handler = TimedRotatingFileHandler( + filename=debug_path / filename, when="H", backupCount=DEBUG_LOG_FILES_TO_KEEP + ) file_handler.setLevel(logging.DEBUG) memory_handler = CircularMemoryHandler( - capacity=capacity, flushLevel=logging.ERROR, target=file_handler + capacity=capacity, + flushLevel=logging.ERROR, + target=file_handler, ) memory_handler.setLevel(logging.DEBUG) memory_handler.addFilter(beamline_filter) diff --git a/src/dodal/parameters/experiment_parameter_base.py b/src/dodal/parameters/experiment_parameter_base.py index ccfd758e4a..604453fd63 100644 --- a/src/dodal/parameters/experiment_parameter_base.py +++ b/src/dodal/parameters/experiment_parameter_base.py @@ -1,7 +1,15 @@ from abc import ABC, abstractmethod +from pydantic import BaseModel + + +class AbstractExperimentParameterBase(BaseModel, ABC): + pass + + +class AbstractExperimentWithBeamParams(AbstractExperimentParameterBase): + transmission_fraction: float -class AbstractExperimentParameterBase(ABC): @abstractmethod def get_num_images(self) -> int: pass diff --git a/src/dodal/plans/check_topup.py b/src/dodal/plans/check_topup.py new file mode 100644 index 0000000000..78110d70c5 --- /dev/null +++ b/src/dodal/plans/check_topup.py @@ -0,0 +1,82 @@ +import bluesky.plan_stubs as bps + +from dodal.devices.synchrotron import Synchrotron, SynchrotronMode +from dodal.log import LOGGER + +ALLOWED_MODES = [SynchrotronMode.USER, SynchrotronMode.SPECIAL] +DECAY_MODE_COUNTDOWN = -1 # Value of the start_countdown PV when in decay mode +COUNTDOWN_DURING_TOPUP = 0 + + +def _in_decay_mode(time_to_topup): + if time_to_topup == DECAY_MODE_COUNTDOWN: + LOGGER.info("Machine in decay mode, gating disabled") + return True + return False + + +def _gating_permitted(machine_mode: SynchrotronMode): + if machine_mode in ALLOWED_MODES: + LOGGER.info("Machine in allowed mode, gating top up enabled.") + return True + LOGGER.info("Machine not in allowed mode, gating disabled") + return False + + +def _delay_to_avoid_topup(total_run_time, time_to_topup): + if total_run_time > time_to_topup: + LOGGER.info( + """ + Total run time for this collection exceeds time to next top up. + Collection delayed until top up done. + """ + ) + return True + LOGGER.info( + """ + Total run time less than time to next topup. Proceeding with collection. + """ + ) + return False + + +def wait_for_topup_complete(synchrotron: Synchrotron): + LOGGER.info("Waiting for topup to complete") + start = yield from bps.rd(synchrotron.topup_start_countdown) + while start == COUNTDOWN_DURING_TOPUP: + yield from bps.sleep(0.1) + start = yield from bps.rd(synchrotron.topup_start_countdown) + + +def check_topup_and_wait_if_necessary( + synchrotron: Synchrotron, + total_exposure_time: float, + ops_time: float, # Account for xray centering, rotation speed, etc +): # See https://github.com/DiamondLightSource/hyperion/issues/932 + """A small plan to check if topup gating is permitted and sleep until the topup\ + is over if it starts before the end of collection. + + Args: + synchrotron (Synchrotron): Synchrotron device. + total_exposure_time (float): Expected total exposure time for \ + collection, in seconds. + ops_time (float): Additional time to account for various operations,\ + eg. x-ray centering, in seconds. Defaults to 30.0. + """ + machine_mode = yield from bps.rd(synchrotron.synchrotron_mode) + assert isinstance(machine_mode, SynchrotronMode) + time_to_topup = yield from bps.rd(synchrotron.topup_start_countdown) + if _in_decay_mode(time_to_topup) or not _gating_permitted(machine_mode): + yield from bps.null() + return + tot_run_time = total_exposure_time + ops_time + end_topup = yield from bps.rd(synchrotron.top_up_end_countdown) + time_to_wait = ( + end_topup if _delay_to_avoid_topup(tot_run_time, time_to_topup) else 0.0 + ) + + yield from bps.sleep(time_to_wait) + + check_start = yield from bps.rd(synchrotron.topup_start_countdown) + if check_start == COUNTDOWN_DURING_TOPUP: + yield from wait_for_topup_complete(synchrotron) diff --git a/src/dodal/utils.py b/src/dodal/utils.py index 2becfb1755..991909ddaf 100644 --- a/src/dodal/utils.py +++ b/src/dodal/utils.py @@ -184,9 +184,7 @@ def collect_factories( def _is_device_skipped(func: AnyDeviceFactory) -> bool: - if not hasattr(func, "__skip__"): - return False - return func.__skip__ # type: ignore + return getattr(func, "__skip__", False) def is_v1_device_factory(func: Callable) -> bool: @@ -239,7 +237,7 @@ def get_beamline_based_on_environment_variable() -> ModuleType: if ( len(beamline) == 0 or beamline[0] not in string.ascii_letters - or not all(c in valid_characters for c in beamline) + or any(c not in valid_characters for c in beamline) ): raise ValueError( "Invalid BEAMLINE variable - module name is not a permissible python module name, got '{}'".format( @@ -269,10 +267,7 @@ def _find_next_run_number_from_files(file_names: List[str]) -> int: dodal.log.LOGGER.warning( f"Identified nexus file {file_name} with unexpected format" ) - if len(valid_numbers) != 0: - return max(valid_numbers) + 1 - else: - return 1 + return max(valid_numbers) + 1 if valid_numbers else 1 def get_run_number(directory: str) -> int: diff --git a/tests/beamlines/unit_tests/test_beamline_utils.py b/tests/beamlines/unit_tests/test_beamline_utils.py index 20f182723d..39653bf741 100644 --- a/tests/beamlines/unit_tests/test_beamline_utils.py +++ b/tests/beamlines/unit_tests/test_beamline_utils.py @@ -6,6 +6,8 @@ from ophyd.device import Device as OphydV1Device from ophyd.sim import FakeEpicsSignal from ophyd_async.core import Device as OphydV2Device +from ophyd_async.core import StandardReadable +from ophyd_async.core.sim_signal_backend import SimSignalBackend from dodal.beamlines import beamline_utils, i03 from dodal.devices.aperturescatterguard import ApertureScatterguard @@ -49,12 +51,21 @@ def test_instantiating_different_device_with_same_name(): assert dev2 in beamline_utils.ACTIVE_DEVICES.values() -def test_instantiate_function_fake_makes_fake(): +def test_instantiate_v1_function_fake_makes_fake(): + smargon: Smargon = beamline_utils.device_instantiation( + i03.Smargon, "smargon", "", True, True, None + ) + assert isinstance(smargon, Device) + assert isinstance(smargon.disabled, FakeEpicsSignal) + + +def test_instantiate_v2_function_fake_makes_fake(): + RE() fake_zeb: Zebra = beamline_utils.device_instantiation( i03.Zebra, "zebra", "", True, True, None ) - assert isinstance(fake_zeb, Device) - assert isinstance(fake_zeb.pc.arm_source, FakeEpicsSignal) + assert isinstance(fake_zeb, StandardReadable) + assert isinstance(fake_zeb.pc.arm.armed._backend, SimSignalBackend) def test_clear_devices(RE): @@ -100,10 +111,14 @@ def test_wait_for_v2_device_connection_passes_through_timeout( ): RE() device = OphydV2Device() + device.connect = MagicMock() beamline_utils._wait_for_connection(device, **kwargs) - call_in_bluesky_el.assert_called_once_with(ANY, timeout=expected_timeout) + device.connect.assert_called_once_with( + sim=ANY, + timeout=expected_timeout, + ) def test_default_directory_provider_is_singleton(): diff --git a/tests/common/test_coordination.py b/tests/common/test_coordination.py new file mode 100644 index 0000000000..b4db6e1204 --- /dev/null +++ b/tests/common/test_coordination.py @@ -0,0 +1,12 @@ +import uuid + +import pytest + +from dodal.common.coordination import group_uuid + + +@pytest.mark.parametrize("group", ["foo", "bar", "baz", str(uuid.uuid4())]) +def test_group_uid(group: str): + gid = group_uuid(group) + assert gid.startswith(f"{group}-") + assert not gid.endswith(f"{group}-") diff --git a/tests/common/test_maths.py b/tests/common/test_maths.py new file mode 100644 index 0000000000..7cf682d1c1 --- /dev/null +++ b/tests/common/test_maths.py @@ -0,0 +1,65 @@ +from typing import Optional + +import pytest + +from dodal.common import in_micros, step_to_num + + +@pytest.mark.parametrize( + "s,us", + [ + (4.000_001, 4_000_001), + (4.999_999, 4_999_999), + (4, 4_000_000), + (4.000_000_1, 4_000_001), + (4.999_999_9, 5_000_000), + (0.1, 100_000), + (0.000_000_1, 1), + (0, 0), + ], +) +def test_in_micros(s: float, us: int): + assert in_micros(s) == us + + +@pytest.mark.parametrize( + "s", [-4.000_001, -4.999_999, -4, -4.000_000_5, -4.999_999_9, -4.05] +) +def test_in_micros_negative(s: float): + with pytest.raises(ValueError): + in_micros(s) + + +@pytest.mark.parametrize( + "start,stop,step,expected_num,truncated_stop", + [ + (0, 0, 1, 1, None), # start=stop, 1 point at start + (0, 0.5, 1, 1, 0), # step>length, 1 point at start + (0, 1, 1, 2, None), # stop=start+step, point at start & stop + (0, 0.99, 1, 2, 1), # stop >= start + 0.99*step, included + (0, 0.98, 1, 1, 0), # stop < start + 0.99*step, not included + (0, 1.01, 1, 2, 1), # stop >= start + 0.99*step, included + (0, 1.75, 0.25, 8, 1.75), + (0, 0, -1, 1, None), # start=stop, 1 point at start + (0, 0.5, -1, 1, 0), # abs(step)>length, 1 point at start + (0, -1, 1, 2, None), # stop=start+-abs(step), point at start & stop + (0, -0.99, 1, 2, -1), # stop >= start + 0.99*-abs(step), included + (0, -0.98, 1, 1, 0), # stop < start + 0.99*-abs(step), not included + (0, -1.01, 1, 2, -1), # stop >= start + 0.99*-abs(step), included + (0, -1.75, 0.25, 8, -1.75), + (1, 10, -0.901, 10, 9.109), # length overrules step for direction + (10, 1, -0.901, 10, 1.891), + ], +) +def test_step_to_num( + start: float, + stop: float, + step: float, + expected_num: int, + truncated_stop: Optional[float], +): + truncated_stop = stop if truncated_stop is None else truncated_stop + actual_start, actual_stop, num = step_to_num(start, stop, step) + assert actual_start == start + assert actual_stop == truncated_stop + assert num == expected_num diff --git a/tests/conftest.py b/tests/conftest.py index 8d0236987a..54ccd42836 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,8 +42,9 @@ def module_and_devices_for_beamline(request): bl_mod = importlib.import_module("dodal.beamlines." + beamline) importlib.reload(bl_mod) mock_beamline_module_filepaths(beamline, bl_mod) - yield bl_mod, make_all_devices( - bl_mod, include_skipped=True, fake_with_ophyd_sim=True + yield ( + bl_mod, + make_all_devices(bl_mod, include_skipped=True, fake_with_ophyd_sim=True), ) beamline_utils.clear_devices() del bl_mod diff --git a/tests/devices/system_tests/test_gridscan_system.py b/tests/devices/system_tests/test_gridscan_system.py index bf7f00a91b..4a15093ffd 100644 --- a/tests/devices/system_tests/test_gridscan_system.py +++ b/tests/devices/system_tests/test_gridscan_system.py @@ -32,7 +32,12 @@ def test_when_program_data_set_and_staged_then_expected_images_correct( fast_grid_scan: FastGridScan, ): RE = RunEngine() - RE(set_fast_grid_scan_params(fast_grid_scan, GridScanParams(x_steps=2, y_steps=2))) + RE( + set_fast_grid_scan_params( + fast_grid_scan, + GridScanParams(transmission_fraction=0.01, x_steps=2, y_steps=2), + ) + ) assert fast_grid_scan.expected_images.get() == 2 * 2 fast_grid_scan.stage() assert fast_grid_scan.position_counter.get() == 0 @@ -44,7 +49,8 @@ def test_given_valid_params_when_kickoff_then_completion_status_increases_and_fi ): def set_and_wait_plan(fast_grid_scan: FastGridScan): yield from set_fast_grid_scan_params( - fast_grid_scan, GridScanParams(x_steps=3, y_steps=3) + fast_grid_scan, + GridScanParams(transmission_fraction=0.01, x_steps=3, y_steps=3), ) yield from wait_for_fgs_valid(fast_grid_scan) diff --git a/tests/devices/system_tests/test_synchrotron_system.py b/tests/devices/system_tests/test_synchrotron_system.py index 655094b1d4..c39edf869d 100644 --- a/tests/devices/system_tests/test_synchrotron_system.py +++ b/tests/devices/system_tests/test_synchrotron_system.py @@ -10,5 +10,5 @@ def synchrotron(): @pytest.mark.s03 -def test_synchrotron_connects(synchrotron: Synchrotron): - synchrotron.wait_for_connection() +async def test_synchrotron_connects(synchrotron: Synchrotron): + await synchrotron.connect() diff --git a/tests/devices/system_tests/test_zocalo_results.py b/tests/devices/system_tests/test_zocalo_results.py index 18849f0c46..4bbd0f994f 100644 --- a/tests/devices/system_tests/test_zocalo_results.py +++ b/tests/devices/system_tests/test_zocalo_results.py @@ -4,7 +4,6 @@ import bluesky.plan_stubs as bps import psutil import pytest -import pytest_asyncio from bluesky.preprocessors import stage_decorator from bluesky.run_engine import RunEngine from bluesky.utils import FailedStatus @@ -27,7 +26,7 @@ } -@pytest_asyncio.fixture +@pytest.fixture async def zocalo_device(): zd = ZocaloResults() await zd.connect() @@ -35,11 +34,10 @@ async def zocalo_device(): @pytest.mark.s03 -@pytest.mark.asyncio async def test_read_results_from_fake_zocalo(zocalo_device: ZocaloResults): zocalo_device._subscribe_to_results() zc = ZocaloTrigger("dev_artemis") - zc.run_start(0) + zc.run_start(0, 0, 100) zc.run_end(0) zocalo_device.timeout_s = 5 @@ -56,7 +54,6 @@ def plan(): @pytest.mark.s03 -@pytest.mark.asyncio async def test_stage_unstage_controls_read_results_from_fake_zocalo( zocalo_device: ZocaloResults, ): @@ -66,7 +63,7 @@ async def test_stage_unstage_controls_read_results_from_fake_zocalo( def plan(): yield from bps.open_run() - zc.run_start(0) + zc.run_start(0, 0, 100) zc.run_end(0) yield from bps.sleep(0.15) yield from bps.trigger_and_read([zocalo_device]) @@ -104,7 +101,6 @@ def plan_with_stage(): @pytest.mark.s03 -@pytest.mark.asyncio async def test_stale_connections_closed_after_unstage( zocalo_device: ZocaloResults, ): diff --git a/tests/devices/unit_tests/conftest.py b/tests/devices/unit_tests/conftest.py new file mode 100644 index 0000000000..f27893fbed --- /dev/null +++ b/tests/devices/unit_tests/conftest.py @@ -0,0 +1,24 @@ +from functools import partial +from typing import Union +from unittest.mock import MagicMock, patch + +from ophyd.epics_motor import EpicsMotor +from ophyd.status import Status + +from dodal.devices.util.motor_utils import ExtendedEpicsMotor + + +def mock_set(motor: EpicsMotor, val): + motor.user_setpoint.sim_put(val) # type: ignore + motor.user_readback.sim_put(val) # type: ignore + return Status(done=True, success=True) + + +def patch_motor(motor: Union[EpicsMotor, ExtendedEpicsMotor], initial_position=0): + motor.user_setpoint.sim_put(initial_position) # type: ignore + motor.user_readback.sim_put(initial_position) # type: ignore + motor.motor_done_move.sim_put(1) # type: ignore + motor.user_setpoint._use_limits = False + if isinstance(motor, ExtendedEpicsMotor): + motor.motor_resolution.sim_put(0.001) # type: ignore + return patch.object(motor, "set", MagicMock(side_effect=partial(mock_set, motor))) diff --git a/tests/devices/unit_tests/oav/image_recognition/test_pin_tip_detect.py b/tests/devices/unit_tests/oav/image_recognition/test_pin_tip_detect.py index b3ee5eb4b4..432d3b7ae2 100644 --- a/tests/devices/unit_tests/oav/image_recognition/test_pin_tip_detect.py +++ b/tests/devices/unit_tests/oav/image_recognition/test_pin_tip_detect.py @@ -2,7 +2,6 @@ from unittest.mock import MagicMock, patch import numpy as np -import pytest from ophyd_async.core import set_sim_value from dodal.devices.oav.pin_image_recognition import MxSampleDetect, PinTipDetection @@ -11,9 +10,10 @@ EVENT_LOOP = asyncio.new_event_loop() -pytest_plugins = ("pytest_asyncio",) DEVICE_NAME = "pin_tip_detection" TRIGGERED_TIP_READING = DEVICE_NAME + "-triggered_tip" +TRIGGERED_TOP_EDGE_READING = DEVICE_NAME + "-triggered_top_edge" +TRIGGERED_BOTTOM_EDGE_READING = DEVICE_NAME + "-triggered_bottom_edge" async def _get_pin_tip_detection_device() -> PinTipDetection: @@ -22,13 +22,11 @@ async def _get_pin_tip_detection_device() -> PinTipDetection: return device -@pytest.mark.asyncio async def test_pin_tip_detect_can_be_connected_in_sim_mode(): device = await _get_pin_tip_detection_device() await device.connect(sim=True) -@pytest.mark.asyncio async def test_soft_parameter_defaults_are_correct(): device = await _get_pin_tip_detection_device() @@ -44,7 +42,6 @@ async def test_soft_parameter_defaults_are_correct(): assert await device.preprocess_ksize.get_value() == 5 -@pytest.mark.asyncio async def test_numeric_soft_parameters_can_be_changed(): device = await _get_pin_tip_detection_device() @@ -71,7 +68,6 @@ async def test_numeric_soft_parameters_can_be_changed(): assert await device.preprocess_iterations.get_value() == 4 -@pytest.mark.asyncio async def test_invalid_processing_func_uses_identity_function(): device = await _get_pin_tip_detection_device() test_sample_location = SampleLocation(100, 200, np.array([]), np.array([])) @@ -82,7 +78,7 @@ async def test_invalid_processing_func_uses_identity_function(): patch.object(MxSampleDetect, "__init__", return_value=None) as mock_init, patch.object(MxSampleDetect, "processArray", return_value=test_sample_location), ): - await device._get_tip_position(np.array([])) + await device._get_tip_and_edge_data(np.array([])) mock_init.assert_called_once() @@ -93,11 +89,12 @@ async def test_invalid_processing_func_uses_identity_function(): assert arg == captured_func(arg) -@pytest.mark.asyncio async def test_given_valid_data_reading_then_used_to_find_location(): device = await _get_pin_tip_detection_device() image_array = np.array([1, 2, 3]) - test_sample_location = SampleLocation(100, 200, np.array([]), np.array([])) + test_sample_location = SampleLocation( + 100, 200, np.array([1, 2, 3]), np.array([4, 5, 6]) + ) set_sim_value(device.array_data, image_array) with ( @@ -111,11 +108,16 @@ async def test_given_valid_data_reading_then_used_to_find_location(): process_call = mock_process_array.call_args[0][0] assert np.array_equal(process_call, image_array) - assert location[TRIGGERED_TIP_READING]["value"] == (200, 100) + assert location[TRIGGERED_TIP_READING]["value"] == (100, 200) + assert np.all( + location[TRIGGERED_TOP_EDGE_READING]["value"] == np.array([1, 2, 3]) + ) + assert np.all( + location[TRIGGERED_BOTTOM_EDGE_READING]["value"] == np.array([4, 5, 6]) + ) assert location[TRIGGERED_TIP_READING]["timestamp"] > 0 -@pytest.mark.asyncio async def test_given_find_tip_fails_when_triggered_then_tip_invalid(): device = await _get_pin_tip_detection_device() await device.validity_timeout.set(0.1) @@ -128,9 +130,10 @@ async def test_given_find_tip_fails_when_triggered_then_tip_invalid(): await device.trigger() reading = await device.read() assert reading[TRIGGERED_TIP_READING]["value"] == device.INVALID_POSITION + assert len(reading[TRIGGERED_TOP_EDGE_READING]["value"]) == 0 + assert len(reading[TRIGGERED_BOTTOM_EDGE_READING]["value"]) == 0 -@pytest.mark.asyncio @patch("dodal.devices.oav.pin_image_recognition.observe_value") async def test_given_find_tip_fails_twice_when_triggered_then_tip_invalid_and_tried_twice( mock_image_read, @@ -156,7 +159,6 @@ async def get_array_data(_): assert mock_process_array.call_count > 1 -@pytest.mark.asyncio @patch("dodal.devices.oav.pin_image_recognition.LOGGER.warn") @patch("dodal.devices.oav.pin_image_recognition.observe_value") async def test_given_tip_invalid_then_loop_keeps_retrying_until_valid( @@ -172,14 +174,25 @@ async def get_array_data(_): device = await _get_pin_tip_detection_device() class FakeLocation: - def __init__(self, tip_x, tip_y): + def __init__(self, tip_x, tip_y, edge_top, edge_bottom): self.tip_x = tip_x self.tip_y = tip_y + self.edge_top = edge_top + self.edge_bottom = edge_bottom + + fake_top_edge = np.array([1, 2, 3]) + fake_bottom_edge = np.array([4, 5, 6]) - with patch.object(MxSampleDetect, "__init__", return_value=None), patch.object( - MxSampleDetect, - "processArray", - side_effect=[FakeLocation(None, None), FakeLocation(1, 1)], + with ( + patch.object(MxSampleDetect, "__init__", return_value=None), + patch.object( + MxSampleDetect, + "processArray", + side_effect=[ + FakeLocation(None, None, fake_top_edge, fake_bottom_edge), + FakeLocation(1, 1, fake_top_edge, fake_bottom_edge), + ], + ), ): await device.trigger() mock_logger.assert_called_once() diff --git a/tests/devices/unit_tests/oav/test_oav.py b/tests/devices/unit_tests/oav/test_oav.py index 0fe1148996..58daa13adf 100644 --- a/tests/devices/unit_tests/oav/test_oav.py +++ b/tests/devices/unit_tests/oav/test_oav.py @@ -2,7 +2,7 @@ import pytest from ophyd.sim import instantiate_fake_device -from ophyd.status import Status +from ophyd.status import AndStatus, Status from dodal.devices.oav.oav_detector import OAV, OAVConfigParams from dodal.devices.oav.oav_errors import ( @@ -46,25 +46,18 @@ def oav() -> OAV: def test_when_zoom_level_changed_then_oav_rewired(zoom, expected_plugin, oav: OAV): oav.zoom_controller.set(zoom).wait() - assert oav.mxsc.input_plugin.get() == expected_plugin assert oav.snapshot.input_plugin.get() == expected_plugin def test_when_zoom_level_changed_then_status_waits_for_all_plugins_to_be_updated( oav: OAV, ): - mxsc_status = Status(obj="msxc - test_when_zoom_level...") - oav.mxsc.input_plugin.set = MagicMock(return_value=mxsc_status) - mjpg_status = Status("mjpg - test_when_zoom_level...") oav.snapshot.input_plugin.set = MagicMock(return_value=mjpg_status) - full_status = oav.zoom_controller.set("1.0x") - - assert mxsc_status in full_status + assert isinstance(full_status := oav.zoom_controller.set("1.0x"), AndStatus) assert mjpg_status in full_status - mxsc_status.set_finished() mjpg_status.set_finished() full_status.wait() diff --git a/tests/devices/unit_tests/test_aperture_scatterguard.py b/tests/devices/unit_tests/test_aperture_scatterguard.py index 84a71b0c31..b119571062 100644 --- a/tests/devices/unit_tests/test_aperture_scatterguard.py +++ b/tests/devices/unit_tests/test_aperture_scatterguard.py @@ -1,65 +1,48 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, call import pytest from ophyd.sim import make_fake_device -from ophyd.status import Status, StatusBase +from ophyd.status import StatusBase from dodal.devices.aperturescatterguard import ( + ApertureFiveDimensionalLocation, AperturePositions, ApertureScatterguard, InvalidApertureMove, + SingleAperturePosition, ) +from .conftest import patch_motor + @pytest.fixture -def fake_aperture_scatterguard(): +def ap_sg(aperture_positions: AperturePositions): FakeApertureScatterguard = make_fake_device(ApertureScatterguard) ap_sg: ApertureScatterguard = FakeApertureScatterguard(name="test_ap_sg") - yield ap_sg + ap_sg.load_aperture_positions(aperture_positions) + with ( + patch_motor(ap_sg.aperture.x), + patch_motor(ap_sg.aperture.y), + patch_motor(ap_sg.aperture.z), + patch_motor(ap_sg.scatterguard.x), + patch_motor(ap_sg.scatterguard.y), + ): + yield ap_sg @pytest.fixture def aperture_in_medium_pos( - fake_aperture_scatterguard: ApertureScatterguard, + ap_sg: ApertureScatterguard, aperture_positions: AperturePositions, ): - fake_aperture_scatterguard.load_aperture_positions(aperture_positions) - fake_aperture_scatterguard.aperture.x.user_setpoint.sim_put( # type: ignore - aperture_positions.MEDIUM[0] - ) - fake_aperture_scatterguard.aperture.y.user_setpoint.sim_put( # type: ignore - aperture_positions.MEDIUM[1] - ) - fake_aperture_scatterguard.aperture.z.user_setpoint.sim_put( # type: ignore - aperture_positions.MEDIUM[2] - ) - fake_aperture_scatterguard.aperture.x.user_readback.sim_put( # type: ignore - aperture_positions.MEDIUM[1] - ) - fake_aperture_scatterguard.aperture.y.user_readback.sim_put( # type: ignore - aperture_positions.MEDIUM[1] - ) - fake_aperture_scatterguard.aperture.z.user_readback.sim_put( # type: ignore - aperture_positions.MEDIUM[1] - ) - fake_aperture_scatterguard.scatterguard.x.user_setpoint.sim_put( # type: ignore - aperture_positions.MEDIUM[3] - ) - fake_aperture_scatterguard.scatterguard.y.user_setpoint.sim_put( # type: ignore - aperture_positions.MEDIUM[4] - ) - fake_aperture_scatterguard.scatterguard.x.user_readback.sim_put( # type: ignore - aperture_positions.MEDIUM[3] - ) - fake_aperture_scatterguard.scatterguard.y.user_readback.sim_put( # type: ignore - aperture_positions.MEDIUM[4] - ) - fake_aperture_scatterguard.aperture.x.motor_done_move.sim_put(1) # type: ignore - fake_aperture_scatterguard.aperture.y.motor_done_move.sim_put(1) # type: ignore - fake_aperture_scatterguard.aperture.z.motor_done_move.sim_put(1) # type: ignore - fake_aperture_scatterguard.scatterguard.x.motor_done_move.sim_put(1) # type: ignore - fake_aperture_scatterguard.scatterguard.y.motor_done_move.sim_put(1) # type: ignore - return fake_aperture_scatterguard + medium = aperture_positions.MEDIUM.location + ap_sg.aperture.x.set(medium.aperture_x) + ap_sg.aperture.y.set(medium.aperture_y) + ap_sg.aperture.z.set(medium.aperture_z) + ap_sg.scatterguard.x.set(medium.scatterguard_x) + ap_sg.scatterguard.y.set(medium.scatterguard_y) + ap_sg.aperture.medium.sim_put(1) # type: ignore + yield ap_sg @pytest.fixture @@ -91,138 +74,207 @@ def aperture_positions(): return aperture_positions -def test_aperture_scatterguard_rejects_unknown_position( - aperture_positions, aperture_in_medium_pos -): - for i in range(0, len(aperture_positions.MEDIUM)): - temp_pos = list(aperture_positions.MEDIUM) - temp_pos[i] += 0.01 - position_to_reject = tuple(temp_pos) +def test_aperture_scatterguard_rejects_unknown_position(aperture_in_medium_pos): + position_to_reject = ApertureFiveDimensionalLocation(0, 0, 0, 0, 0) - with pytest.raises(InvalidApertureMove): - aperture_in_medium_pos.set(position_to_reject) + with pytest.raises(InvalidApertureMove): + aperture_in_medium_pos.set( + SingleAperturePosition("test", "GDA_NAME", 10, position_to_reject) + ) def test_aperture_scatterguard_select_bottom_moves_sg_down_then_assembly_up( - aperture_positions, aperture_in_medium_pos + aperture_positions: AperturePositions, + aperture_in_medium_pos: ApertureScatterguard, ): aperture_scatterguard = aperture_in_medium_pos call_logger = install_logger_for_aperture_and_scatterguard(aperture_scatterguard) aperture_scatterguard.set(aperture_positions.SMALL) - actual_calls = call_logger.mock_calls - expected_calls = [ - ("_mock_sg_x", (5.3375,)), - ("_mock_sg_y", (-3.55,)), - lambda call: call[0].endswith("__and__().wait"), - ("_mock_ap_x", (2.43,)), - ("_mock_ap_y", (48.974,)), - ("_mock_ap_z", (15.8,)), - ] + call_logger.assert_has_calls( + [ + call._mock_sg_x(5.3375), + call._mock_sg_y(-3.55), + call._mock_ap_x(2.43), + call._mock_ap_y(48.974), + call._mock_ap_z(15.8), + ] + ) + - compare_actual_and_expected_calls(actual_calls, expected_calls) +def test_aperture_unsafe_move( + aperture_positions: AperturePositions, + aperture_in_medium_pos: ApertureScatterguard, +): + (a, b, c, d, e) = (0.2, 3.4, 5.6, 7.8, 9.0) + aperture_scatterguard = aperture_in_medium_pos + call_logger = install_logger_for_aperture_and_scatterguard(aperture_scatterguard) + aperture_scatterguard._set_raw_unsafe((a, b, c, d, e)) # type: ignore + + call_logger.assert_has_calls( + [ + call._mock_ap_x(a), + call._mock_ap_y(b), + call._mock_ap_z(c), + call._mock_sg_x(d), + call._mock_sg_y(e), + ] + ) def test_aperture_scatterguard_select_top_moves_assembly_down_then_sg_up( - aperture_positions, aperture_in_medium_pos + aperture_positions: AperturePositions, aperture_in_medium_pos: ApertureScatterguard ): aperture_scatterguard = aperture_in_medium_pos call_logger = install_logger_for_aperture_and_scatterguard(aperture_scatterguard) aperture_scatterguard.set(aperture_positions.LARGE) - actual_calls = call_logger.mock_calls - expected_calls = [ - ("_mock_ap_x", (2.389,)), - ("_mock_ap_y", (40.986,)), - ("_mock_ap_z", (15.8,)), - lambda call: call[0].endswith("__and__().wait"), - ("_mock_sg_x", (5.25,)), - ("_mock_sg_y", (4.43,)), - ] - - compare_actual_and_expected_calls(actual_calls, expected_calls) + call_logger.assert_has_calls( + [ + call._mock_ap_x(2.389), + call._mock_ap_y(40.986), + call._mock_ap_z(15.8), + call._mock_sg_x(5.25), + call._mock_sg_y(4.43), + ] + ) def test_aperture_scatterguard_throws_error_if_outside_tolerance( - fake_aperture_scatterguard: ApertureScatterguard, + ap_sg: ApertureScatterguard, ): - fake_aperture_scatterguard.aperture.z.motor_resolution.sim_put(0.001) # type: ignore - fake_aperture_scatterguard.aperture.z.user_setpoint.sim_put(1) # type: ignore - fake_aperture_scatterguard.aperture.z.motor_done_move.sim_put(1) # type: ignore + ap_sg.aperture.z.motor_resolution.sim_put(0.001) # type: ignore + ap_sg.aperture.z.user_setpoint.sim_put(1) # type: ignore + ap_sg.aperture.z.motor_done_move.sim_put(1) # type: ignore with pytest.raises(InvalidApertureMove): - fake_aperture_scatterguard._safe_move_within_datacollection_range( - 0, 0, 1.1, 0, 0 - ) + pos = ApertureFiveDimensionalLocation(0, 0, 1.1, 0, 0) + ap_sg._safe_move_within_datacollection_range(pos) def test_aperture_scatterguard_returns_status_if_within_tolerance( - fake_aperture_scatterguard: ApertureScatterguard, + ap_sg: ApertureScatterguard, ): - fake_aperture_scatterguard.aperture.z.motor_resolution.sim_put(0.001) # type: ignore - fake_aperture_scatterguard.aperture.z.user_setpoint.sim_put(1) # type: ignore - fake_aperture_scatterguard.aperture.z.motor_done_move.sim_put(1) # type: ignore + ap_sg.aperture.z.motor_resolution.sim_put(0.001) # type: ignore + ap_sg.aperture.z.user_setpoint.sim_put(1) # type: ignore + ap_sg.aperture.z.motor_done_move.sim_put(1) # type: ignore - mock_set = MagicMock(return_value=Status(done=True, success=True)) + pos = ApertureFiveDimensionalLocation(0, 0, 1, 0, 0) + status = ap_sg._safe_move_within_datacollection_range(pos) + assert isinstance(status, StatusBase) - fake_aperture_scatterguard.aperture.x.set = mock_set - fake_aperture_scatterguard.aperture.y.set = mock_set - fake_aperture_scatterguard.aperture.z.set = mock_set - fake_aperture_scatterguard.scatterguard.x.set = mock_set - fake_aperture_scatterguard.scatterguard.y.set = mock_set +def set_underlying_motors( + ap_sg: ApertureScatterguard, position: ApertureFiveDimensionalLocation +): + ap_sg.aperture.x.set(position.aperture_x) + ap_sg.aperture.y.set(position.aperture_y) + ap_sg.aperture.z.set(position.aperture_z) + ap_sg.scatterguard.x.set(position.scatterguard_x) + ap_sg.scatterguard.y.set(position.scatterguard_y) + - status = fake_aperture_scatterguard._safe_move_within_datacollection_range( - 0, 0, 1, 0, 0 +def test_aperture_positions_large( + ap_sg: ApertureScatterguard, aperture_positions: AperturePositions +): + ap_sg.aperture.large.sim_put(1) # type: ignore + assert ap_sg._get_current_aperture_position() == aperture_positions.LARGE + + +def test_aperture_positions_medium( + ap_sg: ApertureScatterguard, aperture_positions: AperturePositions +): + ap_sg.aperture.medium.sim_put(1) # type: ignore + assert ap_sg._get_current_aperture_position() == aperture_positions.MEDIUM + + +def test_aperture_positions_small( + ap_sg: ApertureScatterguard, aperture_positions: AperturePositions +): + ap_sg.aperture.small.sim_put(1) # type: ignore + assert ap_sg._get_current_aperture_position() == aperture_positions.SMALL + + +def test_aperture_positions_robot_load( + ap_sg: ApertureScatterguard, aperture_positions: AperturePositions +): + ap_sg.aperture.large.sim_put(0) # type: ignore + ap_sg.aperture.medium.sim_put(0) # type: ignore + ap_sg.aperture.small.sim_put(0) # type: ignore + ap_sg.aperture.y.set(aperture_positions.ROBOT_LOAD.location.aperture_y) # type: ignore + assert ap_sg._get_current_aperture_position() == aperture_positions.ROBOT_LOAD + + +def test_aperture_positions_robot_load_within_tolerance( + ap_sg: ApertureScatterguard, aperture_positions: AperturePositions +): + robot_load_ap_y = aperture_positions.ROBOT_LOAD.location.aperture_y + tolerance = ap_sg.TOLERANCE_STEPS * ap_sg.aperture.y.motor_resolution.get() + ap_sg.aperture.large.sim_put(0) # type: ignore + ap_sg.aperture.medium.sim_put(0) # type: ignore + ap_sg.aperture.small.sim_put(0) # type: ignore + ap_sg.aperture.y.set(robot_load_ap_y + tolerance) # type: ignore + assert ap_sg._get_current_aperture_position() == aperture_positions.ROBOT_LOAD + + +def test_aperture_positions_robot_load_outside_tolerance( + ap_sg: ApertureScatterguard, aperture_positions: AperturePositions +): + robot_load_ap_y = aperture_positions.ROBOT_LOAD.location.aperture_y + tolerance = (ap_sg.TOLERANCE_STEPS + 1) * ap_sg.aperture.y.motor_resolution.get() + ap_sg.aperture.large.sim_put(0) # type: ignore + ap_sg.aperture.medium.sim_put(0) # type: ignore + ap_sg.aperture.small.sim_put(0) # type: ignore + ap_sg.aperture.y.set(robot_load_ap_y + tolerance) # type: ignore + with pytest.raises(InvalidApertureMove): + ap_sg._get_current_aperture_position() + + +def test_aperture_positions_robot_load_unsafe( + ap_sg: ApertureScatterguard, aperture_positions: AperturePositions +): + ap_sg.aperture.large.sim_put(0) # type: ignore + ap_sg.aperture.medium.sim_put(0) # type: ignore + ap_sg.aperture.small.sim_put(0) # type: ignore + ap_sg.aperture.y.set(50.0) # type: ignore + with pytest.raises(InvalidApertureMove): + ap_sg._get_current_aperture_position() + + +def test_given_aperture_not_set_through_device_but_motors_in_position_when_device_read_then_position_returned( + aperture_in_medium_pos: ApertureScatterguard, aperture_positions: AperturePositions +): + selected_aperture = aperture_in_medium_pos.read() + assert ( + selected_aperture["test_ap_sg_selected_aperture"]["value"] + == aperture_positions.MEDIUM + ) + + +def test_when_aperture_set_and_device_read_then_position_returned( + aperture_in_medium_pos: ApertureScatterguard, aperture_positions: AperturePositions +): + set_status = aperture_in_medium_pos.set(aperture_positions.MEDIUM) + set_status.wait() + selected_aperture = aperture_in_medium_pos.read() + assert ( + selected_aperture["test_ap_sg_selected_aperture"]["value"] + == aperture_positions.MEDIUM ) - assert isinstance(status, StatusBase) def install_logger_for_aperture_and_scatterguard(aperture_scatterguard): parent_mock = MagicMock() - mock_ap_x = MagicMock(aperture_scatterguard.aperture.x.set) - mock_ap_y = MagicMock(aperture_scatterguard.aperture.y.set) - mock_ap_z = MagicMock(aperture_scatterguard.aperture.z.set) - mock_sg_x = MagicMock(aperture_scatterguard.scatterguard.x.set) - mock_sg_y = MagicMock(aperture_scatterguard.scatterguard.y.set) - aperture_scatterguard.aperture.x.set = mock_ap_x - aperture_scatterguard.aperture.y.set = mock_ap_y - aperture_scatterguard.aperture.z.set = mock_ap_z - aperture_scatterguard.scatterguard.x.set = mock_sg_x - aperture_scatterguard.scatterguard.y.set = mock_sg_y + mock_ap_x = aperture_scatterguard.aperture.x.set + mock_ap_y = aperture_scatterguard.aperture.y.set + mock_ap_z = aperture_scatterguard.aperture.z.set + mock_sg_x = aperture_scatterguard.scatterguard.x.set + mock_sg_y = aperture_scatterguard.scatterguard.y.set parent_mock.attach_mock(mock_ap_x, "_mock_ap_x") parent_mock.attach_mock(mock_ap_y, "_mock_ap_y") parent_mock.attach_mock(mock_ap_z, "_mock_ap_z") parent_mock.attach_mock(mock_sg_x, "_mock_sg_x") parent_mock.attach_mock(mock_sg_y, "_mock_sg_y") return parent_mock - - -def compare_actual_and_expected_calls(actual_calls, expected_calls): - # ideally, we could use MagicMock.assert_has_calls but a) it doesn't work properly and b) doesn't do what I need - i_actual = 0 - for i_expected in range(0, len(expected_calls)): - try: - expected = expected_calls[i_expected] - if isinstance(expected, tuple): - # simple comparison - i_actual = actual_calls.index(expected_calls[i_expected], i_actual) - else: - # expected is a predicate to be satisfied - i_matches = [ - i - for i in range(i_actual, len(actual_calls)) - if expected(actual_calls[i]) - ] - if i_matches: - i_actual = i_matches[0] - else: - raise ValueError("Couldn't find call matching predicate") - except ValueError: - assert ( - False - ), f"Couldn't find call #{i_expected}: {expected_calls[i_expected]}" - - i_actual += 1 diff --git a/tests/devices/unit_tests/test_bart_robot.py b/tests/devices/unit_tests/test_bart_robot.py index 30ee449f9c..8f28b43649 100644 --- a/tests/devices/unit_tests/test_bart_robot.py +++ b/tests/devices/unit_tests/test_bart_robot.py @@ -1,24 +1,89 @@ +from asyncio import TimeoutError, sleep +from unittest.mock import AsyncMock, MagicMock, patch + import pytest from ophyd_async.core import set_sim_value -from dodal.devices.robot import BartRobot +from dodal.devices.robot import BartRobot, PinMounted, SampleLocation async def _get_bart_robot() -> BartRobot: device = BartRobot("robot", "-MO-ROBOT-01:") + device.LOAD_TIMEOUT = 0.01 # type: ignore await device.connect(sim=True) return device -@pytest.mark.asyncio async def test_bart_robot_can_be_connected_in_sim_mode(): device = await _get_bart_robot() await device.connect(sim=True) -@pytest.mark.asyncio async def test_when_barcode_updates_then_new_barcode_read(): device = await _get_bart_robot() expected_barcode = "expected" set_sim_value(device.barcode.bare_signal, [expected_barcode, "other_barcode"]) assert (await device.barcode.read())["robot-barcode"]["value"] == expected_barcode + + +@patch("dodal.devices.robot.LOGGER") +async def test_given_program_running_when_load_pin_then_logs_the_program_name_and_times_out( + patch_logger: MagicMock, +): + device = await _get_bart_robot() + program_name = "BAD_PROGRAM" + set_sim_value(device.program_running, True) + set_sim_value(device.program_name, program_name) + with pytest.raises(TimeoutError): + await device.set(SampleLocation(0, 0)) + last_log = patch_logger.mock_calls[1].args[0] + assert program_name in last_log + + +@patch("dodal.devices.robot.LOGGER") +async def test_given_program_not_running_but_pin_not_unmounting_when_load_pin_then_timeout( + patch_logger: MagicMock, +): + device = await _get_bart_robot() + set_sim_value(device.program_running, False) + set_sim_value(device.gonio_pin_sensor, PinMounted.PIN_MOUNTED) + device.load = AsyncMock(side_effect=device.load) + with pytest.raises(TimeoutError): + await device.set(SampleLocation(15, 10)) + device.load.trigger.assert_called_once() # type:ignore + last_log = patch_logger.mock_calls[1].args[0] + assert "Waiting on old pin unloaded" in last_log + + +@patch("dodal.devices.robot.LOGGER") +async def test_given_program_not_running_and_pin_unmounting_but_new_pin_not_mounting_when_load_pin_then_timeout( + patch_logger: MagicMock, +): + device = await _get_bart_robot() + set_sim_value(device.program_running, False) + set_sim_value(device.gonio_pin_sensor, PinMounted.NO_PIN_MOUNTED) + device.load = AsyncMock(side_effect=device.load) + with pytest.raises(TimeoutError): + await device.set(SampleLocation(15, 10)) + device.load.trigger.assert_called_once() # type:ignore + last_log = patch_logger.mock_calls[1].args[0] + assert "Waiting on new pin loaded" in last_log + + +async def test_given_program_not_running_and_pin_unmounts_then_mounts_when_load_pin_then_pin_loaded(): + device = await _get_bart_robot() + device.LOAD_TIMEOUT = 0.03 # type: ignore + set_sim_value(device.program_running, False) + set_sim_value(device.gonio_pin_sensor, PinMounted.PIN_MOUNTED) + + device.load = AsyncMock(side_effect=device.load) + status = device.set(SampleLocation(15, 10)) + await sleep(0.01) + set_sim_value(device.gonio_pin_sensor, PinMounted.NO_PIN_MOUNTED) + await sleep(0.005) + set_sim_value(device.gonio_pin_sensor, PinMounted.PIN_MOUNTED) + await status + assert status.success + assert (await device.next_puck.get_value()) == 15 + assert (await device.next_pin.get_value()) == 10 + device.load.trigger.assert_called_once() # type:ignore diff --git a/tests/devices/unit_tests/test_gridscan.py b/tests/devices/unit_tests/test_gridscan.py index 2a5ee6de74..f3ad6a5cb2 100644 --- a/tests/devices/unit_tests/test_gridscan.py +++ b/tests/devices/unit_tests/test_gridscan.py @@ -9,6 +9,7 @@ from mockito.matchers import ANY, ARGS, KWARGS from ophyd.sim import make_fake_device from ophyd.status import DeviceStatus, Status +from ophyd.utils.errors import StatusTimeoutError from dodal.devices.fast_grid_scan import ( FastGridScan, @@ -43,7 +44,7 @@ def test_given_settings_valid_when_kickoff_then_run_started( mock_run_set_status = mock() when(fast_grid_scan.run_cmd).put(ANY).thenReturn(mock_run_set_status) - fast_grid_scan.status.subscribe = lambda func, **_: func(1) + fast_grid_scan.status.subscribe = lambda func, **_: func(1) # type: ignore status = fast_grid_scan.kickoff() @@ -53,13 +54,41 @@ def test_given_settings_valid_when_kickoff_then_run_started( verify(fast_grid_scan.run_cmd).put(1) +def test_waits_for_running_motion( + fast_grid_scan: FastGridScan, +): + when(fast_grid_scan.motion_program.running).get().thenReturn(1) + + fast_grid_scan.KICKOFF_TIMEOUT = 0.01 + + with pytest.raises(StatusTimeoutError): + status = fast_grid_scan.kickoff() + status.wait() + + fast_grid_scan.KICKOFF_TIMEOUT = 1 + + mock_run_set_status = mock() + when(fast_grid_scan.run_cmd).put(ANY).thenReturn(mock_run_set_status) + fast_grid_scan.status.subscribe = lambda func, **_: func(1) # type: ignore + + when(fast_grid_scan.motion_program.running).get().thenReturn(0) + status = fast_grid_scan.kickoff() + status.wait() + verify(fast_grid_scan.run_cmd).put(1) + + def run_test_on_complete_watcher( fast_grid_scan: FastGridScan, num_pos_1d, put_value, expected_frac ): RE = RunEngine() RE( set_fast_grid_scan_params( - fast_grid_scan, GridScanParams(x_steps=num_pos_1d, y_steps=num_pos_1d) + fast_grid_scan, + GridScanParams( + x_steps=num_pos_1d, + y_steps=num_pos_1d, + transmission_fraction=0.01, + ), ) ) @@ -122,7 +151,10 @@ def test_running_finished_with_all_images_done_then_complete_status_finishes_not RE = RunEngine() RE( set_fast_grid_scan_params( - fast_grid_scan, GridScanParams(x_steps=num_pos_1d, y_steps=num_pos_1d) + fast_grid_scan, + GridScanParams( + transmission_fraction=0.01, x_steps=num_pos_1d, y_steps=num_pos_1d + ), ) ) @@ -188,7 +220,9 @@ def test_within_limits_check(position, expected_in_limit): ) def test_scan_within_limits_1d(start, steps, size, expected_in_limits): motor_bundle = create_motor_bundle_with_limits(0.0, 10.0) - grid_params = GridScanParams(x_start=start, x_steps=steps, x_step_size=size) + grid_params = GridScanParams( + transmission_fraction=0.01, x_start=start, x_steps=steps, x_step_size=size + ) assert grid_params.is_valid(motor_bundle.get_xyz_limits()) == expected_in_limits @@ -205,6 +239,7 @@ def test_scan_within_limits_2d( ): motor_bundle = create_motor_bundle_with_limits(0.0, 10.0) grid_params = GridScanParams( + transmission_fraction=0.01, x_start=x_start, x_steps=x_steps, x_step_size=x_size, @@ -261,6 +296,7 @@ def test_scan_within_limits_3d( ): motor_bundle = create_motor_bundle_with_limits(0.0, 10.0) grid_params = GridScanParams( + transmission_fraction=0.01, x_start=x_start, x_steps=x_steps, x_step_size=x_size, @@ -279,6 +315,7 @@ def test_scan_within_limits_3d( @pytest.fixture def grid_scan_params(): yield GridScanParams( + transmission_fraction=0.01, x_steps=10, y_steps=15, z_steps=20, @@ -380,8 +417,14 @@ def test_given_x_y_z_steps_when_full_number_calculated_then_answer_is_as_expecte ) def test_non_test_integer_dwell_time(test_dwell_times, expected_dwell_time_is_integer): if expected_dwell_time_is_integer: - params = GridScanParams(dwell_time_ms=test_dwell_times) + params = GridScanParams( + dwell_time_ms=test_dwell_times, + transmission_fraction=0.01, + ) assert params.dwell_time_ms == test_dwell_times else: with pytest.raises(ValueError): - GridScanParams(dwell_time_ms=test_dwell_times) + GridScanParams( + dwell_time_ms=test_dwell_times, + transmission_fraction=0.01, + ) diff --git a/tests/devices/unit_tests/test_oav.py b/tests/devices/unit_tests/test_oav.py index f32c1d2a45..0e85a70fcd 100644 --- a/tests/devices/unit_tests/test_oav.py +++ b/tests/devices/unit_tests/test_oav.py @@ -67,17 +67,16 @@ def test_snapshot_trigger_saves_to_correct_file( mock_open: MagicMock, mock_get, fake_oav ): image = PIL.Image.open("test") - mock_save = MagicMock() - image.save = mock_save mock_open.return_value.__enter__.return_value = image - st = fake_oav.snapshot.trigger() - st.wait() - expected_calls_to_save = [ - call(f"test directory/test filename{addition}.png") - for addition in ["", "_outer_overlay", "_grid_overlay"] - ] - calls_to_save = mock_save.mock_calls - assert calls_to_save == expected_calls_to_save + with patch.object(image, "save") as mock_save: + st = fake_oav.snapshot.trigger() + st.wait() + expected_calls_to_save = [ + call(f"test directory/test filename{addition}.png") + for addition in ["", "_outer_overlay", "_grid_overlay"] + ] + calls_to_save = mock_save.mock_calls + assert calls_to_save == expected_calls_to_save @patch("requests.get") @@ -120,14 +119,12 @@ def test_bottom_right_from_top_left(): def test_when_zoom_1_then_flat_field_applied(fake_oav: OAV): RE = RunEngine() RE(bps.abs_set(fake_oav.zoom_controller, "1.0x")) - assert fake_oav.mxsc.input_plugin.get() == "PROC" assert fake_oav.snapshot.input_plugin.get() == "PROC" def test_when_zoom_not_1_then_flat_field_removed(fake_oav: OAV): RE = RunEngine() RE(bps.abs_set(fake_oav.zoom_controller, "10.0x")) - assert fake_oav.mxsc.input_plugin.get() == "CAM" assert fake_oav.snapshot.input_plugin.get() == "CAM" @@ -136,11 +133,9 @@ def test_when_zoom_is_externally_changed_to_1_then_flat_field_not_changed( ): """This test is required to ensure that Hyperion doesn't cause unexpected behaviour e.g. change the flatfield when the zoom level is changed through the synoptic""" - fake_oav.mxsc.input_plugin.sim_put("CAM") # type: ignore fake_oav.snapshot.input_plugin.sim_put("CAM") # type: ignore fake_oav.zoom_controller.level.sim_put("1.0x") # type: ignore - assert fake_oav.mxsc.input_plugin.get() == "CAM" assert fake_oav.snapshot.input_plugin.get() == "CAM" diff --git a/tests/devices/unit_tests/test_oav_centring.py b/tests/devices/unit_tests/test_oav_centring.py deleted file mode 100644 index 6b2d477d3f..0000000000 --- a/tests/devices/unit_tests/test_oav_centring.py +++ /dev/null @@ -1,269 +0,0 @@ -import bluesky.plan_stubs as bps -import bluesky.preprocessors as bpp -import numpy as np -import pytest -from bluesky.run_engine import RunEngine -from ophyd.sim import make_fake_device - -from dodal.devices.backlight import Backlight -from dodal.devices.oav.oav_calculations import ( - camera_coordinates_to_xyz, - check_i_within_bounds, - extract_pixel_centre_values_from_rotation_data, - filter_rotation_data, - find_midpoint, - find_widest_point_and_orthogonal_point, - get_orthogonal_index, - get_rotation_increment, - keep_inside_bounds, -) -from dodal.devices.oav.oav_detector import OAV, OAVConfigParams -from dodal.devices.oav.oav_errors import ( - OAVError_MissingRotations, - OAVError_NoRotationsPassValidityTest, -) -from dodal.devices.oav.oav_parameters import OAVParameters -from dodal.devices.smargon import Smargon - -OAV_CENTRING_JSON = "tests/devices/unit_tests/test_OAVCentring.json" -DISPLAY_CONFIGURATION = "tests/devices/unit_tests/test_display.configuration" -ZOOM_LEVELS_XML = "tests/devices/unit_tests/test_jCameraManZoomLevels.xml" - - -def do_nothing(*args, **kwargs): - pass - - -@pytest.fixture -def mock_oav(): - oav_params = OAVConfigParams(ZOOM_LEVELS_XML, DISPLAY_CONFIGURATION) - oav: OAV = make_fake_device(OAV)( - name="oav", prefix="a fake beamline", params=oav_params - ) - oav.snapshot.x_size.sim_put("1024") # type: ignore - oav.snapshot.y_size.sim_put("768") # type: ignore - oav.wait_for_connection() - return oav - - -@pytest.fixture -def mock_parameters(): - return OAVParameters( - "loopCentring", - OAV_CENTRING_JSON, - ) - - -@pytest.fixture -def mock_smargon(): - smargon: Smargon = make_fake_device(Smargon)(name="smargon") - smargon.wait_for_connection = do_nothing - return smargon - - -@pytest.fixture -def mock_backlight(): - backlight: Backlight = make_fake_device(Backlight)(name="backlight") - backlight.wait_for_connection = do_nothing - return backlight - - -def test_can_make_fake_testing_devices_and_use_run_engine( - mock_oav: OAV, - mock_parameters: OAVParameters, - mock_smargon: Smargon, - mock_backlight: Backlight, -): - @bpp.run_decorator() - def fake_run( - mock_oav: OAV, - mock_parameters: OAVParameters, - mock_smargon: Smargon, - mock_backlight: Backlight, - ): - yield from bps.abs_set(mock_oav.cam.acquire_period, 5) - mock_parameters.acquire_period = 10 - # can't change the smargon motors because of limit issues with FakeEpicsDevice - # yield from bps.mv(mock_smargon.omega, 1) - yield from bps.mv(mock_backlight.pos, 1) - - RE = RunEngine() - RE(fake_run(mock_oav, mock_parameters, mock_smargon, mock_backlight)) - - -def test_find_midpoint_symmetric_pin(): - x = np.arange(-15, 10, 25 / 1024) - x2 = x**2 - top = -1 * x2 + 100 - bottom = x2 - 100 - top += 500 - bottom += 500 - - # set the waveforms to 0 before the edge is found - top[np.where(top < bottom)[0]] = 0 - bottom[np.where(bottom > top)[0]] = 0 - - (x_pos, y_pos, width) = find_midpoint(top, bottom) - assert x_pos == 614 - assert y_pos == 500 - - -def test_find_midpoint_non_symmetric_pin(): - x = np.arange(-4, 2.35, 6.35 / 1024) - x2 = x**2 - x4 = x2**2 - top = -1 * x2 + 6 - bottom = x4 - 5 * x2 - 3 - - top += 400 - bottom += 400 - - # set the waveforms to 0 before the edge is found - top[np.where(top < bottom)[0]] = 0 - bottom[np.where(bottom > top)[0]] = 0 - - (x_pos, y_pos, width) = find_midpoint(top, bottom) - assert x_pos == 419 - assert np.floor(y_pos) == 397 - # x = 205/1024*4.7 - 2.35 ≈ -1.41 which is the first stationary point of the width on - # our midpoint line - - -def test_get_rotation_increment_threshold_within_180(): - increment = get_rotation_increment(6, 0, 180) - assert increment == 180 / 6 - - -def test_get_rotation_increment_threshold_exceeded(): - increment = get_rotation_increment(6, 30, 180) - assert increment == -180 / 6 - - -@pytest.mark.parametrize( - "value,lower_bound,upper_bound,expected_value", - [(0.5, -10, 10, 0.5), (-100, -10, 10, -10), (10000, -213, 50, 50)], -) -def test_keep_inside_bounds(value, lower_bound, upper_bound, expected_value): - assert keep_inside_bounds(value, lower_bound, upper_bound) == expected_value - - -def test_filter_rotation_data(): - x_positions = np.array([400, 450, 7, 500]) - y_positions = np.array([400, 450, 7, 500]) - widths = np.array([400, 450, 7, 500]) - omegas = np.array([400, 450, 7, 500]) - - ( - filtered_x, - filtered_y, - filtered_widths, - filtered_omegas, - ) = filter_rotation_data(x_positions, y_positions, widths, omegas) - - assert filtered_x[2] == 500 - assert filtered_omegas[2] == 500 - - -def test_filter_rotation_data_throws_error_when_all_fail(): - x_positions = np.array([1020, 20]) - y_positions = np.array([10, 450]) - widths = np.array([400, 450]) - omegas = np.array([400, 450]) - with pytest.raises(OAVError_NoRotationsPassValidityTest): - ( - filtered_x, - filtered_y, - filtered_widths, - filtered_omegas, - ) = filter_rotation_data(x_positions, y_positions, widths, omegas) - - -@pytest.mark.parametrize( - "max_tip_distance, tip_x, x, expected_return", - [ - (180, 400, 600, 580), - (180, 400, 450, 450), - ], -) -def test_keep_x_within_bounds(max_tip_distance, tip_x, x, expected_return): - assert check_i_within_bounds(max_tip_distance, tip_x, x) == expected_return - - -@pytest.mark.parametrize( - "h,v,omega,expected_values", - [ - (0.0, 0.0, 0.0, np.array([0.0, 0.0, 0.0])), - (10, -5, 90, np.array([-10, 3.062e-16, -5])), - (100, -50, 40, np.array([-100, 38.302, -32.139])), - (10, 100, -4, np.array([-10, -99.756, -6.976])), - ], -) -def test_distance_from_beam_centre_to_motor_coords_returns_the_same_values_as_GDA( - h, v, omega, expected_values, mock_oav: OAV, mock_parameters: OAVParameters -): - mock_parameters.zoom = "5.0" - mock_oav.zoom_controller.level.sim_put(mock_parameters.zoom) # type: ignore - results = camera_coordinates_to_xyz( - h, - v, - omega, - mock_oav.parameters.micronsPerXPixel, - mock_oav.parameters.micronsPerYPixel, - ) - expected_values = expected_values * 1e-3 - expected_values[0] *= mock_oav.parameters.micronsPerXPixel - expected_values[1] *= mock_oav.parameters.micronsPerYPixel - expected_values[2] *= mock_oav.parameters.micronsPerYPixel - expected_values = np.around(expected_values, decimals=3) - - assert np.array_equal(np.around(results, decimals=3), expected_values) - - -def test_find_widest_point_and_orthogonal_point(): - widths = np.array([400, 450, 7, 500, 600, 400]) - omegas = np.array([0, 30, 60, 90, 120, 180]) - assert find_widest_point_and_orthogonal_point(widths, omegas) == (4, 1) - - -def test_find_widest_point_and_orthogonal_point_no_orthogonal_angles(): - widths = np.array([400, 7, 500, 600, 400]) - omegas = np.array([0, 60, 90, 120, 180]) - with pytest.raises(OAVError_MissingRotations): - find_widest_point_and_orthogonal_point(widths, omegas) - - -def test_extract_pixel_centre_values_from_rotation_data(): - x_positions = np.array([400, 450, 7, 500, 475, 412]) - y_positions = np.array([500, 512, 518, 498, 486, 530]) - widths = np.array([400, 450, 7, 500, 600, 400]) - omegas = np.array([0, 30, 60, 90, 120, 180]) - assert extract_pixel_centre_values_from_rotation_data( - x_positions, y_positions, widths, omegas - ) == (475, 486, 512, 120, 30) - - -@pytest.mark.parametrize( - "angle_array,angle,expected_index", - [ - (np.array([0, 30, 60, 75, 110, 140, 160, 179]), 50, 5), - (np.array([0, 15, 10, 65, 89, 135, 174]), 0, 4), - (np.array([-40, -80, -52, 10, -3, -5, 60]), 85, 5), - (np.array([-150, -120, -90, -60, -30, 0]), 30, 3), - ( - np.array( - [6.0013e01, 3.0010e01, 7.0000e-03, -3.0002e01, -6.0009e01, -9.0016e01] - ), - -90.016, - 2, - ), - ], -) -def test_get_closest_orthogonal_index(angle_array, angle, expected_index): - assert get_orthogonal_index(angle_array, angle) == expected_index - - -def test_get_closest_orthogonal_index_not_orthogonal_enough(): - with pytest.raises(OAVError_MissingRotations): - get_orthogonal_index( - np.array([0, 30, 60, 90, 160, 180, 210, 240, 250, 255]), 50 - ) diff --git a/tests/devices/unit_tests/test_panda_gridscan.py b/tests/devices/unit_tests/test_panda_gridscan.py index c3db08245a..6338fd420a 100644 --- a/tests/devices/unit_tests/test_panda_gridscan.py +++ b/tests/devices/unit_tests/test_panda_gridscan.py @@ -67,7 +67,12 @@ def test_running_finished_with_all_images_done_then_complete_status_finishes_not RE = RunEngine() RE( set_fast_grid_scan_params( - fast_grid_scan, PandAGridScanParams(x_steps=num_pos_1d, y_steps=num_pos_1d) + fast_grid_scan, + PandAGridScanParams( + x_steps=num_pos_1d, + y_steps=num_pos_1d, + transmission_fraction=0.01, + ), ) ) diff --git a/tests/devices/unit_tests/test_pin_tip_detect.py b/tests/devices/unit_tests/test_pin_tip_detect.py deleted file mode 100644 index 951c305f19..0000000000 --- a/tests/devices/unit_tests/test_pin_tip_detect.py +++ /dev/null @@ -1,133 +0,0 @@ -from typing import Generator, List, Tuple - -import bluesky.plan_stubs as bps -import pytest -from bluesky.run_engine import RunEngine -from ophyd.sim import make_fake_device - -from dodal.devices.areadetector.plugins.MXSC import ( - PinTipDetect, - statistics_of_positions, -) - - -@pytest.fixture -def fake_pin_tip_detect() -> Generator[PinTipDetect, None, None]: - FakePinTipDetect = make_fake_device(PinTipDetect) - pin_tip_detect: PinTipDetect = FakePinTipDetect(name="pin_tip") - pin_tip_detect.settle_time_s.set(0).wait() - yield pin_tip_detect - - -def trigger_and_read( - fake_pin_tip_detect, values_to_set_during_trigger: List[Tuple] = None -): - yield from bps.trigger(fake_pin_tip_detect) - if values_to_set_during_trigger: - for position in values_to_set_during_trigger: - fake_pin_tip_detect.tip_y.sim_put(position[1]) # type: ignore - fake_pin_tip_detect.tip_x.sim_put(position[0]) # type: ignore - yield from bps.wait() - return (yield from bps.rd(fake_pin_tip_detect)) - - -def test_given_pin_tip_stays_invalid_when_triggered_then_return_( - fake_pin_tip_detect: PinTipDetect, -): - def set_small_timeout_then_trigger_and_read(): - yield from bps.abs_set(fake_pin_tip_detect.validity_timeout, 0.01) - return (yield from trigger_and_read(fake_pin_tip_detect)) - - RE = RunEngine(call_returns_result=True) - result = RE(set_small_timeout_then_trigger_and_read()) - - assert result.plan_result == fake_pin_tip_detect.INVALID_POSITION - - -def test_given_pin_tip_invalid_when_triggered_and_set_then_rd_returns_value( - fake_pin_tip_detect: PinTipDetect, -): - RE = RunEngine(call_returns_result=True) - result = RE(trigger_and_read(fake_pin_tip_detect, [(200, 100)])) - - assert result.plan_result == (200, 100) - - -def test_given_pin_tip_found_before_timeout_then_timeout_status_cleaned_up_and_tip_value_remains( - fake_pin_tip_detect: PinTipDetect, -): - RE = RunEngine(call_returns_result=True) - RE(trigger_and_read(fake_pin_tip_detect, [(100, 200)])) - # A success should clear up the timeout status but it may clear it up slightly later - # so we need the small timeout to avoid the race condition - fake_pin_tip_detect._timeout_status.wait(0.01) - assert fake_pin_tip_detect.triggered_tip.get() == (100, 200) - - -def test_median_of_positions_calculated_correctly(): - test = [(1, 2), (1, 5), (3, 3)] - - actual_med, _ = statistics_of_positions(test) - - assert actual_med == (1, 3) - - -def test_standard_dev_of_positions_calculated_correctly(): - test = [(1, 2), (1, 3)] - - _, actual_std = statistics_of_positions(test) - - assert actual_std == (0, 0.5) - - -def test_given_multiple_tips_found_then_running_median_calculated( - fake_pin_tip_detect: PinTipDetect, -): - fake_pin_tip_detect.settle_time_s.set(0.1).wait() - - RE = RunEngine(call_returns_result=True) - RE(trigger_and_read(fake_pin_tip_detect, [(100, 200), (50, 60), (400, 800)])) - - assert fake_pin_tip_detect.triggered_tip.get() == (100, 200) - - -def trigger_and_read_twice( - fake_pin_tip_detect: PinTipDetect, first_values: List[Tuple], second_value: Tuple -): - yield from trigger_and_read(fake_pin_tip_detect, first_values) - fake_pin_tip_detect.tip_y.sim_put(second_value[1]) # type: ignore - fake_pin_tip_detect.tip_x.sim_put(second_value[0]) # type: ignore - return (yield from trigger_and_read(fake_pin_tip_detect, [])) - - -def test_given_median_previously_calculated_when_triggered_again_then_only_calculated_on_new_values( - fake_pin_tip_detect: PinTipDetect, -): - fake_pin_tip_detect.settle_time_s.set(0.1).wait() - - RE = RunEngine(call_returns_result=True) - - def my_plan(): - tip_pos = yield from trigger_and_read_twice( - fake_pin_tip_detect, [(10, 20), (1, 3), (4, 8)], (100, 200) - ) - assert tip_pos == (100, 200) - - RE(my_plan()) - - -def test_given_previous_tip_found_when_this_tip_not_found_then_returns_invalid( - fake_pin_tip_detect: PinTipDetect, -): - fake_pin_tip_detect.settle_time_s.set(0.1).wait() - fake_pin_tip_detect.validity_timeout.set(0.5).wait() - - RE = RunEngine(call_returns_result=True) - - def my_plan(): - tip_pos = yield from trigger_and_read_twice( - fake_pin_tip_detect, [(10, 20), (1, 3), (4, 8)], (-1, -1) - ) - assert tip_pos == (-1, -1) - - RE(my_plan()) diff --git a/tests/devices/unit_tests/test_synchrotron.py b/tests/devices/unit_tests/test_synchrotron.py new file mode 100644 index 0000000000..b935df7e48 --- /dev/null +++ b/tests/devices/unit_tests/test_synchrotron.py @@ -0,0 +1,228 @@ +import json +from typing import Any, Awaitable, Callable, Dict, List + +import bluesky.plan_stubs as bps +import pytest +from bluesky.run_engine import RunEngine +from ophyd_async.core import DeviceCollector, StandardReadable, set_sim_value + +from dodal.devices.synchrotron import ( + Prefix, + Suffix, + Synchrotron, + SynchrotronMode, +) + +RING_CURRENT = 0.556677 +USER_COUNTDOWN = 55.0 +START_COUNTDOWN = 66.0 +END_COUNTDOWN = 77.0 +BEAM_ENERGY = 3.0158 +MODE = SynchrotronMode.INJECTION +NUMBER = "number" +STRING = "string" +EMPTY_LIST: List = [] + +READINGS = [RING_CURRENT, USER_COUNTDOWN, START_COUNTDOWN, END_COUNTDOWN] +CONFIGS = [BEAM_ENERGY, MODE.value] + +READING_FIELDS = ["value", "alarm_severity"] +DESCRIPTION_FIELDS = ["source", "dtype", "shape"] +READING_ADDRESSES = [ + f"sim://{Prefix.SIGNAL + Suffix.SIGNAL}", + f"sim://{Prefix.STATUS + Suffix.USER_COUNTDOWN}", + f"sim://{Prefix.TOP_UP + Suffix.COUNTDOWN}", + f"sim://{Prefix.TOP_UP + Suffix.END_COUNTDOWN}", +] + +CONFIG_ADDRESSES = [ + f"sim://{Prefix.STATUS + Suffix.BEAM_ENERGY}", + f"sim://{Prefix.STATUS + Suffix.MODE}", +] + +READ_SIGNALS = [ + "synchrotron-ring_current", + "synchrotron-machine_user_countdown", + "synchrotron-topup_start_countdown", + "synchrotron-top_up_end_countdown", +] + +CONFIG_SIGNALS = [ + "synchrotron-beam_energy", + "synchrotron-synchrotron_mode", +] + +EXPECTED_READ_RESULT = f"""{{ + "{READ_SIGNALS[0]}": {{ + "{READING_FIELDS[0]}": {READINGS[0]}, + "{READING_FIELDS[1]}": 0 + }}, + "{READ_SIGNALS[1]}": {{ + "{READING_FIELDS[0]}": {READINGS[1]}, + "{READING_FIELDS[1]}": 0 + }}, + "{READ_SIGNALS[2]}": {{ + "{READING_FIELDS[0]}": {READINGS[2]}, + "{READING_FIELDS[1]}": 0 + }}, + "{READ_SIGNALS[3]}": {{ + "{READING_FIELDS[0]}": {READINGS[3]}, + "{READING_FIELDS[1]}": 0 + }} +}}""" + +EXPECTED_READ_CONFIG_RESULT = f"""{{ + "{CONFIG_SIGNALS[0]}":{{ + "{READING_FIELDS[0]}": {CONFIGS[0]}, + "{READING_FIELDS[1]}": 0 + }}, + "{CONFIG_SIGNALS[1]}":{{ + "{READING_FIELDS[0]}": "{CONFIGS[1]}", + "{READING_FIELDS[1]}": 0 + }} +}}""" + +EXPECTED_DESCRIBE_RESULT = f"""{{ + "{READ_SIGNALS[0]}":{{ + "{DESCRIPTION_FIELDS[0]}": "{READING_ADDRESSES[0]}", + "{DESCRIPTION_FIELDS[1]}": "{NUMBER}", + "{DESCRIPTION_FIELDS[2]}": {EMPTY_LIST} + }}, + "{READ_SIGNALS[1]}":{{ + "{DESCRIPTION_FIELDS[0]}": "{READING_ADDRESSES[1]}", + "{DESCRIPTION_FIELDS[1]}": "{NUMBER}", + "{DESCRIPTION_FIELDS[2]}": {EMPTY_LIST} + }}, + "{READ_SIGNALS[2]}":{{ + "{DESCRIPTION_FIELDS[0]}": "{READING_ADDRESSES[2]}", + "{DESCRIPTION_FIELDS[1]}": "{NUMBER}", + "{DESCRIPTION_FIELDS[2]}": {EMPTY_LIST} + }}, + "{READ_SIGNALS[3]}":{{ + "{DESCRIPTION_FIELDS[0]}": "{READING_ADDRESSES[3]}", + "{DESCRIPTION_FIELDS[1]}": "{NUMBER}", + "{DESCRIPTION_FIELDS[2]}": {EMPTY_LIST} + }} +}}""" + +EXPECTED_DESCRIBE_CONFIG_RESULT = f"""{{ + "{CONFIG_SIGNALS[0]}":{{ + "{DESCRIPTION_FIELDS[0]}": "{CONFIG_ADDRESSES[0]}", + "{DESCRIPTION_FIELDS[1]}": "{NUMBER}", + "{DESCRIPTION_FIELDS[2]}": {EMPTY_LIST} + }}, + "{CONFIG_SIGNALS[1]}":{{ + "{DESCRIPTION_FIELDS[0]}": "{CONFIG_ADDRESSES[1]}", + "{DESCRIPTION_FIELDS[1]}": "{STRING}", + "{DESCRIPTION_FIELDS[2]}": {EMPTY_LIST} + }} +}}""" + + +@pytest.fixture +async def sim_synchrotron() -> Synchrotron: + async with DeviceCollector(sim=True): + sim_synchrotron = Synchrotron() + set_sim_value(sim_synchrotron.ring_current, RING_CURRENT) + set_sim_value(sim_synchrotron.machine_user_countdown, USER_COUNTDOWN) + set_sim_value(sim_synchrotron.topup_start_countdown, START_COUNTDOWN) + set_sim_value(sim_synchrotron.top_up_end_countdown, END_COUNTDOWN) + set_sim_value(sim_synchrotron.beam_energy, BEAM_ENERGY) + set_sim_value(sim_synchrotron.synchrotron_mode, MODE) + return sim_synchrotron + + +async def test_synchrotron_read(sim_synchrotron: Synchrotron): + await verify( + sim_synchrotron.read, + READ_SIGNALS, + READING_FIELDS, + EXPECTED_READ_RESULT, + ) + + +async def test_synchrotron_read_configuration(sim_synchrotron: Synchrotron): + await verify( + sim_synchrotron.read_configuration, + CONFIG_SIGNALS, + READING_FIELDS, + EXPECTED_READ_CONFIG_RESULT, + ) + + +async def test_synchrotron_describe(sim_synchrotron: Synchrotron): + await verify( + sim_synchrotron.describe, + READ_SIGNALS, + DESCRIPTION_FIELDS, + EXPECTED_DESCRIBE_RESULT, + ) + + +async def test_synchrotron_describe_configuration(sim_synchrotron: Synchrotron): + await verify( + sim_synchrotron.describe_configuration, + CONFIG_SIGNALS, + DESCRIPTION_FIELDS, + EXPECTED_DESCRIBE_CONFIG_RESULT, + ) + + +async def test_synchrotron_count(RE: RunEngine, sim_synchrotron: Synchrotron): + docs = [] + RE(count_sim(sim_synchrotron), lambda x, y: docs.append(y)) + + assert len(docs) == 4 + assert sim_synchrotron.name in docs[1]["configuration"] + cfg_data_keys = docs[1]["configuration"][sim_synchrotron.name]["data_keys"] + for sig, addr in zip(CONFIG_SIGNALS, CONFIG_ADDRESSES): + assert sig in cfg_data_keys + dtype = NUMBER if sig == CONFIG_SIGNALS[0] else STRING + assert cfg_data_keys[sig][DESCRIPTION_FIELDS[0]] == addr + assert cfg_data_keys[sig][DESCRIPTION_FIELDS[1]] == dtype + assert cfg_data_keys[sig][DESCRIPTION_FIELDS[2]] == EMPTY_LIST + cfg_data = docs[1]["configuration"][sim_synchrotron.name]["data"] + for sig, value in zip(CONFIG_SIGNALS, CONFIGS): + assert cfg_data[sig] == value + data_keys = docs[1]["data_keys"] + for sig, addr in zip(READ_SIGNALS, READING_ADDRESSES): + assert sig in data_keys + assert data_keys[sig][DESCRIPTION_FIELDS[0]] == addr + assert data_keys[sig][DESCRIPTION_FIELDS[1]] == NUMBER + assert data_keys[sig][DESCRIPTION_FIELDS[2]] == EMPTY_LIST + + data = docs[2]["data"] + assert len(data) == len(READ_SIGNALS) + for sig, value in zip(READ_SIGNALS, READINGS): + assert sig in data + assert data[sig] == value + + +async def verify( + func: Callable[[], Awaitable[Dict[str, Any]]], + signals: List[str], + fields: List[str], + expectation: str, +): + expected = json.loads(expectation) + result = await func() + + for signal in signals: + for field in fields: + assert result[signal][field] == expected[signal][field] + + +def count_sim(det: StandardReadable, times: int = 1): + """Test plan to do equivalent of bp.count for a sim detector (no file writing).""" + + yield from bps.stage_all(det) + yield from bps.open_run() + yield from bps.declare_stream(det, name="primary", collect=False) + for _ in range(times): + yield from bps.wait(group="wait_for_trigger") + yield from bps.create() + yield from bps.read(det) + yield from bps.save() + + yield from bps.close_run() + yield from bps.unstage_all(det) diff --git a/tests/devices/unit_tests/test_zebra.py b/tests/devices/unit_tests/test_zebra.py index 3285e18899..73a3d8686e 100644 --- a/tests/devices/unit_tests/test_zebra.py +++ b/tests/devices/unit_tests/test_zebra.py @@ -1,15 +1,64 @@ +from unittest.mock import AsyncMock + import pytest +from bluesky.run_engine import RunEngine from mockito import mock, verify -from ophyd.sim import make_fake_device from dodal.devices.zebra import ( + ArmDemand, + ArmingDevice, + ArmSource, GateType, + I03Axes, LogicGateConfiguration, LogicGateConfigurer, + PositionCompare, + TrigSource, boolean_array_to_integer, ) +async def test_arming_device(): + RunEngine() + arming_device = ArmingDevice("", name="fake arming device") + await arming_device.connect(sim=True) + status = arming_device.set(ArmDemand.DISARM) + await status + assert status.success + assert await arming_device.disarm_set.get_value() == 1 + + +async def test_position_compare_sets_signals(): + RunEngine() + fake_pc = PositionCompare("", name="fake position compare") + await fake_pc.connect(sim=True) + + async def mock_arm(demand): + fake_pc.arm.disarm_set._backend._set_value(not demand) # type: ignore + fake_pc.arm.arm_set._backend._set_value(demand) # type: ignore + await fake_pc.arm.armed.set(demand) + + fake_pc.arm.arm_set.set = AsyncMock(side_effect=mock_arm) + fake_pc.arm.disarm_set.set = AsyncMock(side_effect=mock_arm) + + fake_pc.gate_source.set(TrigSource.EXTERNAL) + fake_pc.gate_trigger.set(I03Axes.OMEGA) + fake_pc.num_gates.set(10) + + assert await fake_pc.gate_source.get_value() == "External" + assert await fake_pc.gate_trigger.get_value() == "Enc4" + assert await fake_pc.num_gates.get_value() == 10 + + fake_pc.arm_source.set(ArmSource.SOFT) + status = fake_pc.arm.set(ArmDemand.ARM) + await status + + assert await fake_pc.arm_source.get_value() == "Soft" + assert await fake_pc.arm.arm_set.get_value() == 1 + assert await fake_pc.arm.disarm_set.get_value() == 0 + assert await fake_pc.is_armed() + + @pytest.mark.parametrize( "boolean_array,expected_integer", [ @@ -45,14 +94,20 @@ def test_logic_gate_configuration_62_and_34_inv_and_15_inv(): assert str(config) == "INP1=62, INP2=!34, INP3=!15" -def run_configurer_test(gate_type: GateType, gate_num, config, expected_pv_values): - FakeLogicConfigurer = make_fake_device(LogicGateConfigurer) - configurer = FakeLogicConfigurer(name="test fake logicconfigurer") +async def run_configurer_test( + gate_type: GateType, + gate_num, + config, + expected_pv_values, +): + RunEngine() + configurer = LogicGateConfigurer(prefix="", name="test fake logicconfigurer") + await configurer.connect(sim=True) mock_gate_control = mock() mock_pvs = [mock() for i in range(6)] mock_gate_control.enable = mock_pvs[0] - mock_gate_control.sources = mock_pvs[1:5] + mock_gate_control.sources = {i: mock_pvs[i] for i in range(1, 5)} mock_gate_control.invert = mock_pvs[5] configurer.all_gates[gate_type][gate_num - 1] = mock_gate_control @@ -62,21 +117,21 @@ def run_configurer_test(gate_type: GateType, gate_num, config, expected_pv_value configurer.apply_or_gate_config(gate_num, config) for pv, value in zip(mock_pvs, expected_pv_values): - verify(pv).put(value) + verify(pv).set(value) -def test_apply_and_logic_gate_configuration_32_and_51_inv_and_1(): +async def test_apply_and_logic_gate_configuration_32_and_51_inv_and_1(): config = LogicGateConfiguration(32).add_input(51, True).add_input(1) expected_pv_values = [7, 32, 51, 1, 0, 2] - run_configurer_test(GateType.AND, 1, config, expected_pv_values) + await run_configurer_test(GateType.AND, 1, config, expected_pv_values) -def test_apply_or_logic_gate_configuration_19_and_36_inv_and_60_inv(): +async def test_apply_or_logic_gate_configuration_19_and_36_inv_and_60_inv(): config = LogicGateConfiguration(19).add_input(36, True).add_input(60, True) expected_pv_values = [7, 19, 36, 60, 0, 6] - run_configurer_test(GateType.OR, 2, config, expected_pv_values) + await run_configurer_test(GateType.OR, 2, config, expected_pv_values) @pytest.mark.parametrize( diff --git a/tests/devices/unit_tests/test_zocalo_interaction.py b/tests/devices/unit_tests/test_zocalo_interaction.py index ea43abfafe..688a3d751d 100644 --- a/tests/devices/unit_tests/test_zocalo_interaction.py +++ b/tests/devices/unit_tests/test_zocalo_interaction.py @@ -9,11 +9,20 @@ from dodal.devices.zocalo import ( ZocaloTrigger, ) +from dodal.devices.zocalo.zocalo_interaction import ZocaloStartInfo SIM_ZOCALO_ENV = "dev_artemis" EXPECTED_DCID = 100 -EXPECTED_RUN_START_MESSAGE = {"event": "start", "ispyb_dcid": EXPECTED_DCID} +EXPECTED_FILENAME = "test/file" +EXPECTED_RUN_START_MESSAGE = { + "ispyb_dcid": EXPECTED_DCID, + "filename": EXPECTED_FILENAME, + "start_frame_index": 0, + "number_of_frames": 100, + "message_index": 0, + "event": "start", +} EXPECTED_RUN_END_MESSAGE = { "event": "end", "ispyb_dcid": EXPECTED_DCID, @@ -65,27 +74,40 @@ def with_exception(function_to_run, mock_transport): @mark.parametrize( - "function_to_test,function_wrapper,expected_message", + "function_wrapper,expected_message", [ - (zc.run_start, normally, EXPECTED_RUN_START_MESSAGE), + (normally, EXPECTED_RUN_START_MESSAGE), ( - zc.run_start, with_exception, EXPECTED_RUN_START_MESSAGE, ), - (zc.run_end, normally, EXPECTED_RUN_END_MESSAGE), - (zc.run_end, with_exception, EXPECTED_RUN_END_MESSAGE), ], ) -def test__run_start_and_end( - function_to_test: Callable, function_wrapper: Callable, expected_message: Dict -): +def test_run_start(function_wrapper: Callable, expected_message: Dict): + """ + Args: + function_wrapper (Callable): A wrapper used to test for expected exceptions + expected_message (Dict): The expected dictionary sent to zocalo + """ + data = ZocaloStartInfo(EXPECTED_DCID, EXPECTED_FILENAME, 0, 100, 0) + function_to_run = partial(zc.run_start, data) + function_to_run = partial(function_wrapper, function_to_run) + _test_zocalo(function_to_run, expected_message) + + +@mark.parametrize( + "function_wrapper,expected_message", + [ + (normally, EXPECTED_RUN_END_MESSAGE), + (with_exception, EXPECTED_RUN_END_MESSAGE), + ], +) +def test__run_start_and_end(function_wrapper: Callable, expected_message: Dict): """ Args: - function_to_test (Callable): The function to test e.g. start/stop zocalo - function_wrapper (Callable): A wrapper around the function, used to test for expected exceptions + function_wrapper (Callable): A wrapper used to test for expected exceptions expected_message (Dict): The expected dictionary sent to zocalo """ - function_to_run = partial(function_to_test, EXPECTED_DCID) + function_to_run = partial(zc.run_end, EXPECTED_DCID) function_to_run = partial(function_wrapper, function_to_run) _test_zocalo(function_to_run, expected_message) diff --git a/tests/devices/unit_tests/test_zocalo_results.py b/tests/devices/unit_tests/test_zocalo_results.py index 281ab449f2..346c941b86 100644 --- a/tests/devices/unit_tests/test_zocalo_results.py +++ b/tests/devices/unit_tests/test_zocalo_results.py @@ -4,7 +4,6 @@ import bluesky.plan_stubs as bps import numpy as np import pytest -import pytest_asyncio from bluesky.run_engine import RunEngine from bluesky.utils import FailedStatus from ophyd_async.core.async_status import AsyncStatus @@ -81,7 +80,7 @@ @patch("dodal.devices.zocalo_results._get_zocalo_connection") -@pytest_asyncio.fixture +@pytest.fixture async def mocked_zocalo_device(RE): async def device(results, run_setup=False): zd = ZocaloResults(zocalo_environment="test_env") @@ -106,7 +105,6 @@ def plan(): return device -@pytest.mark.asyncio async def test_put_result_read_results( mocked_zocalo_device, RE, @@ -122,7 +120,6 @@ async def test_put_result_read_results( assert np.all(bboxes[0] == [2, 2, 1]) -@pytest.mark.asyncio async def test_rd_top_results( mocked_zocalo_device, RE, @@ -141,7 +138,6 @@ def test_plan(): RE(test_plan()) -@pytest.mark.asyncio async def test_trigger_and_wait_puts_results( mocked_zocalo_device, RE, @@ -159,7 +155,6 @@ def plan(): zocalo_device._put_results.assert_called() -@pytest.mark.asyncio async def test_extraction_plan(mocked_zocalo_device, RE) -> None: zocalo_device: ZocaloResults = await mocked_zocalo_device( TEST_RESULTS, run_setup=False @@ -176,7 +171,6 @@ def plan(): RE(plan()) -@pytest.mark.asyncio @patch( "dodal.devices.zocalo.zocalo_results.workflows.recipe.wrap_subscribe", autospec=True ) @@ -202,7 +196,6 @@ async def test_subscribe_only_on_called_stage( mock_wrap_subscribe.assert_called_once() -@pytest.mark.asyncio @patch("dodal.devices.zocalo.zocalo_results._get_zocalo_connection", autospec=True) async def test_when_exception_caused_by_zocalo_message_then_exception_propagated( mock_connection, diff --git a/tests/plans/test_topup_plan.py b/tests/plans/test_topup_plan.py new file mode 100644 index 0000000000..4cb079af1c --- /dev/null +++ b/tests/plans/test_topup_plan.py @@ -0,0 +1,94 @@ +from unittest.mock import patch + +import bluesky.plan_stubs as bps +import pytest +from bluesky.run_engine import RunEngine +from ophyd_async.core import set_sim_value + +from dodal.beamlines import i03 +from dodal.devices.synchrotron import Synchrotron, SynchrotronMode +from dodal.plans.check_topup import ( + check_topup_and_wait_if_necessary, + wait_for_topup_complete, +) + + +@pytest.fixture +def synchrotron() -> Synchrotron: + return i03.synchrotron(fake_with_ophyd_sim=True) + + +@patch("dodal.plans.check_topup.wait_for_topup_complete") +@patch("dodal.plans.check_topup.bps.sleep") +def test_when_topup_before_end_of_collection_wait( + fake_sleep, fake_wait, synchrotron: Synchrotron +): + set_sim_value(synchrotron.synchrotron_mode, SynchrotronMode.USER) + set_sim_value(synchrotron.topup_start_countdown, 20.0) + set_sim_value(synchrotron.top_up_end_countdown, 60.0) + + RE = RunEngine() + RE( + check_topup_and_wait_if_necessary( + synchrotron=synchrotron, + total_exposure_time=40.0, + ops_time=30.0, + ) + ) + fake_sleep.assert_called_once_with(60.0) + + +@patch("dodal.plans.check_topup.bps.rd") +@patch("dodal.plans.check_topup.bps.sleep") +def test_wait_for_topup_complete(fake_sleep, fake_rd, synchrotron): + def fake_generator(value): + yield from bps.null() + return value + + fake_rd.side_effect = [ + fake_generator(0.0), + fake_generator(0.0), + fake_generator(0.0), + fake_generator(10.0), + ] + + RE = RunEngine() + RE(wait_for_topup_complete(synchrotron)) + + assert fake_sleep.call_count == 3 + fake_sleep.assert_called_with(0.1) + + +@patch("dodal.plans.check_topup.bps.sleep") +@patch("dodal.plans.check_topup.bps.null") +def test_no_waiting_if_decay_mode(fake_null, fake_sleep, synchrotron: Synchrotron): + set_sim_value(synchrotron.topup_start_countdown, -1) + + RE = RunEngine() + RE( + check_topup_and_wait_if_necessary( + synchrotron=synchrotron, + total_exposure_time=10.0, + ops_time=1.0, + ) + ) + fake_null.assert_called_once() + assert fake_sleep.call_count == 0 + + +@patch("dodal.plans.check_topup.bps.null") +def test_no_waiting_when_mode_does_not_allow_gating( + fake_null, synchrotron: Synchrotron +): + set_sim_value(synchrotron.topup_start_countdown, 1.0) + set_sim_value(synchrotron.synchrotron_mode, SynchrotronMode.SHUTDOWN) + + RE = RunEngine() + RE( + check_topup_and_wait_if_necessary( + synchrotron=synchrotron, + total_exposure_time=10.0, + ops_time=1.0, + ) + ) + fake_null.assert_called_once() diff --git a/tests/unit_tests/test_log.py b/tests/unit_tests/test_log.py index 410618f605..d0ed3205f6 100644 --- a/tests/unit_tests/test_log.py +++ b/tests/unit_tests/test_log.py @@ -91,7 +91,7 @@ def test_no_env_variable_sets_correct_file_handler( expected_calls = [ call(filename=PosixPath("tmp/dev/dodal.log"), when="MIDNIGHT", backupCount=30), - call(PosixPath("tmp/dev/debug/dodal.log"), when="H"), + call(PosixPath("tmp/dev/debug/dodal.log"), when="H", backupCount=7), ] mock_file_handler.assert_has_calls(expected_calls, any_order=True)