Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support inject for default plan arguments again #773

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class Task(BlueapiBaseModel):

def prepare_params(self, ctx: BlueskyContext) -> Mapping[str, Any]:
model = _lookup_params(ctx, self)
return _model_to_kwargs(model)
# Re-create dict manually to avoid nesting in model_dump output
return {field: getattr(model, field) for field in model.__pydantic_fields__}

def do_task(self, ctx: BlueskyContext) -> None:
LOGGER.info(f"Asked to run plan {self.name} with {self.params}")
Expand All @@ -49,22 +50,3 @@ def _lookup_params(ctx: BlueskyContext, task: Task) -> BaseModel:
model = plan.model
adapter = TypeAdapter(model)
return adapter.validate_python(task.params)


def _model_to_kwargs(model: BaseModel) -> Mapping[str, Any]:
"""
Converts an instance of BaseModel back to a dictionary that
can be passed as **kwargs.
Used instead of BaseModel.model_dump() because we don't want
the dumping to be nested and because it fires UserWarnings
about data types it is unfamiliar with
(such as ophyd devices).

Args:
model: Pydantic model to convert to kwargs

Returns:
Mapping[str, Any]: Dictionary that can be passed as **kwargs
"""

return {name: getattr(model, name) for name in model.model_fields_set}
23 changes: 23 additions & 0 deletions tests/unit_tests/worker/test_task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,26 @@ def test_begin_task_span_ok(
task_id = worker.submit_task(_SIMPLE_TASK)
with asserting_span_exporter(exporter, "begin_task", "task_id"):
worker.begin_task(task_id)


def test_injected_devices_are_found(
fake_device: FakeDevice,
context: BlueskyContext,
):
def injected_device_plan(dev: FakeDevice = fake_device.name) -> MsgGenerator: # type: ignore
yield from ()

context.register_plan(injected_device_plan)
params = Task(name="injected_device_plan").prepare_params(context)
assert params["dev"] == fake_device


def test_missing_injected_devices_fail_early(
context: BlueskyContext,
):
def missing_injection(dev: FakeDevice = "does_not_exist") -> MsgGenerator: # type: ignore
yield from ()

context.register_plan(missing_injection)
with pytest.raises(ValueError):
Task(name="missing_injection").prepare_params(context)
Loading