diff --git a/pyproject.toml b/pyproject.toml index c5b48619..ad43981b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,6 @@ requires-python = ">=3.10" plotting = ["scipy", "matplotlib"] # REST service support service = ["fastapi>=0.100.0", "uvicorn"] -# For development tests/docs dev = [ # This syntax is supported since pip 21.2 # https://github.com/pypa/pip/issues/10393 @@ -41,11 +40,13 @@ dev = [ "sphinx-copybutton", "sphinx-design", "sphinxcontrib-openapi", + "strawberry-graphql[debug-server]", + "strawberry-graphql[fastapi]", "tox-direct", "types-mock", "httpx", "myst-parser", -] +] # For development tests/docs [project.scripts] scanspec = "scanspec.cli:cli" diff --git a/src/scanspec/schema/__init__.py b/src/scanspec/schema/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/scanspec/schema/resolvers.py b/src/scanspec/schema/resolvers.py new file mode 100644 index 00000000..999cd822 --- /dev/null +++ b/src/scanspec/schema/resolvers.py @@ -0,0 +1,46 @@ +from typing import Any + +import numpy as np +from specs import Line + +from scanspec.core import ( + Frames, + Path, +) + + +def reduce_frames(stack: list[Frames[str]], max_frames: int) -> Path: + """Removes frames from a spec so len(path) < max_frames. + + Args: + stack: A stack of Frames created by a spec + max_frames: The maximum number of frames the user wishes to be returned + """ + # Calculate the total number of frames + num_frames = 1 + for frames in stack: + num_frames *= len(frames) + + # Need each dim to be this much smaller + ratio = 1 / np.power(max_frames / num_frames, 1 / len(stack)) + + sub_frames = [sub_sample(f, ratio) for f in stack] + return Path(sub_frames) + + +def sub_sample(frames: Frames[str], ratio: float) -> Frames: + """Provides a sub-sample Frames object whilst preserving its core structure. + + Args: + frames: the Frames object to be reduced + ratio: the reduction ratio of the dimension + """ + num_indexes = int(len(frames) / ratio) + indexes = np.linspace(0, len(frames) - 1, num_indexes, dtype=np.int32) + return frames.extract(indexes, calculate_gap=False) + + +def validate_spec(spec: Line) -> Any: + """A query used to confirm whether or not the Spec will produce a viable scan.""" + # TODO apischema will do all the validation for us + return spec.serialize() diff --git a/src/scanspec/schema/schema.py b/src/scanspec/schema/schema.py new file mode 100644 index 00000000..4f58b172 --- /dev/null +++ b/src/scanspec/schema/schema.py @@ -0,0 +1,107 @@ +from typing import Any + +import strawberry +from fastapi import FastAPI +from resolvers import reduce_frames, validate_spec +from specs import PointsResponse +from strawberry.fastapi import GraphQLRouter + +from scanspec.core import Path +from scanspec.specs import Line, Spec + +# Here is the manual version of what we are trying to do + +# @strawberry.input +# class LineInput(Line): ... + + +# @strawberry.input +# class ZipInput(Zip): ... + + +# @strawberry.input(one_of=True) +# class SpecInput: +# ... + +# line: LineInput | None = strawberry.UNSET +# zip: ZipInput | None = strawberry.UNSET + + +def generate_input_class() -> type[Any]: + # This will be our input class, we're going to fiddle with it + # throughout this function + class SpecInput: ... + + # We want to go through all the possible scan specs, this isn't + # currently possible but can be implemented. + # Raise an issue for a helper function to get all possible scanspec + # types. + for spec_type in Spec.types: + # We make a strawberry input classs using the scanspec pydantic models + # This isn't possible because scanspec models are actually pydantic + # dataclasses. We should have a word with Tom about it and probably + # raise an issue on strawberry. + @strawberry.experimental.pydantic.input(all_fields=True, model=spec_type) + class InputClass: ... + + # Renaming the class to LineInput, ZipInput etc. so the + # schema looks neater + InputClass.__name__ = spec_type.__name__ + "Input" + + # Add a field to the class called line, zip etc. and make it + # strawberry.UNSET + setattr(SpecInput, spec_type.__name__, strawberry.UNSET) + + # Set the type annotation to line | none, zip | none, etc. + # Strawberry will read this and graphqlify it. + SpecInput.__annotations__[spec_type.__name__] = InputClass | None + + # This is just a programtic equivalent of + # @strawberry.input(one_of=True) + # class SpecInput: + # ... + return strawberry.input(one_of=True)(SpecInput) + + +SpecInput = generate_input_class() + + +@strawberry.type +class Query: + @strawberry.field + def validate(self, spec: SpecInput) -> str: + return validate_spec(spec) + + @strawberry.field + def get_points(self, spec: Line, max_frames: int | None = 10000) -> PointsResponse: + """Calculate the frames present in the scan plus some metadata about the points. + + Args: + spec: The specification of the scan + max_frames: The maximum number of frames the user wishes to receive + """ + + dims = spec.calculate() # Grab dimensions from spec + + path = Path(dims) # Convert to a path + + # TOTAL FRAMES + total_frames = len(path) # Capture the total length of the path + + # MAX FRAMES + # Limit the consumed data by the max_frames argument + if max_frames and (max_frames < len(path)): + # Cap the frames by the max limit + path = reduce_frames(dims, max_frames) + # WARNING: path object is consumed after this statement + chunk = path.consume(max_frames) + + return PointsResponse(chunk, total_frames) + + +schema = strawberry.Schema(Query) + +graphql_app = GraphQLRouter(schema, path="/", graphql_ide="apollo-sandbox") + +app = FastAPI() +app.include_router(graphql_app, prefix="/graphql") diff --git a/src/scanspec/schema/specs.py b/src/scanspec/schema/specs.py new file mode 100644 index 00000000..1e629e06 --- /dev/null +++ b/src/scanspec/schema/specs.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from _collections_abc import Callable, Mapping +from typing import Any + +import numpy as np +import strawberry + +from scanspec.core import ( + Axis, + Frames, + gap_between_frames, +) + + +@strawberry.type +class PointsResponse: + """Information about the points provided by a spec.""" + + total_frames: int + returned_frames: int + + def __init__(self, chunk: Frames[str], total_frames: int): + self.total_frames = total_frames + """The number of frames present across the entire spec""" + self.returned_frames = len(chunk) + """The number of frames returned by the getPoints query + (controlled by the max_points argument)""" + self._chunk = chunk + + +@strawberry.interface +class SpecInterface: + def serialize(self) -> Mapping[str, Any]: + """Serialize the spec to a dictionary.""" + return "serialized" + + +def _dimensions_from_indexes( + func: Callable[[np.ndarray], dict[Axis, np.ndarray]], + axes: list, + num: int, + bounds: bool, +) -> list[Frames[Axis]]: + # Calc num midpoints (fences) from 0.5 .. num - 0.5 + midpoints_calc = func(np.linspace(0.5, num - 0.5, num)) + midpoints = {a: midpoints_calc[a] for a in axes} + if bounds: + # Calc num + 1 bounds (posts) from 0 .. num + bounds_calc = func(np.linspace(0, num, num + 1)) + lower = {a: bounds_calc[a][:-1] for a in axes} + upper = {a: bounds_calc[a][1:] for a in axes} + # Points must have no gap as upper[a][i] == lower[a][i+1] + # because we initialized it to be that way + gap = np.zeros(num, dtype=np.bool_) + dimension = Frames(midpoints, lower, upper, gap) + # But calc the first point as difference between first + # and last + gap[0] = gap_between_frames(dimension, dimension) + else: + # Gap can be calculated in Dimension + dimension = Frames(midpoints) + return [dimension] + + +@strawberry.input +class Line(SpecInterface): + axis: str = strawberry.field(description="An identifier for what to move") + start: float = strawberry.field( + description="Midpoint of the first point of the line" + ) + stop: float = strawberry.field(description="Midpoint of the last point of the line") + num: int = strawberry.field(description="Number of frames to produce") + + def axes(self) -> list: + return [self.axis] + + def _line_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]: + if self.num == 1: + # Only one point, stop-start gives length of one point + step = self.stop - self.start + else: + # Multiple points, stop-start gives length of num-1 points + step = (self.stop - self.start) / (self.num - 1) + # self.start is the first centre point, but we need the lower bound + # of the first point as this is where the index array starts + first = self.start - step / 2 + return {self.axis: indexes * step + first} + + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + return _dimensions_from_indexes( + self._line_from_indexes, self.axes(), self.num, bounds + )