Skip to content

Commit

Permalink
add a update_inputs to plugin's workchain_and_builder (#656)
Browse files Browse the repository at this point in the history
This PR enhances the plugin API by introducing a new feature to the `workchain_and_builder` configuration: the `update_inputs` method. 

This addition addresses limitations in the previous design, where the main workchain rigidly passed only the `structure` to the plugins' workchains. This approach had notable disadvantages:

- It mandated that all plugins' workchains accept an input named `structure`, leading to errors if this was not adhered to. This is the case in the newly developed plugin: [aiidalab-qe-hp](https://github.com/superstar54/aiidalab-qe-hp).
- It restricted data passed to plugins to be hardcoded, thereby limiting plugin developers' ability to tailor the inputs based on their specific requirements.

The introduction of the `update_inputs` method provides a dynamic mechanism for updating the inputs of plugins' workchains. It accepts two parameters: `inputs` and `ctx`. This method enables developers to customize how the inputs to a plugin's workchain are updated, leveraging the broader context of the Workchain. An immediate benefit of this improvement is observed in the **bands** and **pdos** plugins, which can now access `current_number_of_bands` from the context if available.
  • Loading branch information
superstar54 authored Apr 16, 2024
1 parent dcb951d commit 4476651
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 1 deletion.
11 changes: 11 additions & 0 deletions src/aiidalab_qe/plugins/bands/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,19 @@ def get_builder(codes, structure, parameters, **kwargs):
return bands


def update_inputs(inputs, ctx):
"""Update the inputs using context."""
inputs.structure = ctx.current_structure
inputs.scf.pw.parameters = inputs.scf.pw.parameters.get_dict()
if ctx.current_number_of_bands:
inputs.scf.pw.parameters.setdefault("SYSTEM", {}).setdefault(
"nbnd", ctx.current_number_of_bands
)


workchain_and_builder = {
"workchain": PwBandsWorkChain,
"exclude": ("structure", "relax"),
"get_builder": get_builder,
"update_inputs": update_inputs,
}
11 changes: 11 additions & 0 deletions src/aiidalab_qe/plugins/pdos/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,19 @@ def get_builder(codes, structure, parameters, **kwargs):
return pdos


def update_inputs(inputs, ctx):
"""Update the inputs using context."""
inputs.structure = ctx.current_structure
inputs.nscf.pw.parameters = inputs.nscf.pw.parameters.get_dict()
if ctx.current_number_of_bands:
inputs.nscf.pw.parameters.setdefault("SYSTEM", {}).setdefault(
"nbnd", ctx.current_number_of_bands
)


workchain_and_builder = {
"workchain": PdosWorkChain,
"exclude": ("structure", "relax"),
"get_builder": get_builder,
"update_inputs": update_inputs,
}
6 changes: 6 additions & 0 deletions src/aiidalab_qe/plugins/xas/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,14 @@ def get_builder(codes, structure, parameters, **kwargs):
return builder


def update_inputs(inputs, ctx):
"""Update the inputs using context."""
inputs.structure = ctx.current_structure


workchain_and_builder = {
"workchain": XspectraCrystalWorkChain,
"exclude": ("structure", "relax"),
"get_builder": get_builder,
"update_inputs": update_inputs,
}
6 changes: 6 additions & 0 deletions src/aiidalab_qe/plugins/xps/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,14 @@ def get_builder(codes, structure, parameters, **kwargs):
return builder


def update_inputs(inputs, ctx):
"""Update the inputs using context."""
inputs.structure = ctx.current_structure


workchain_and_builder = {
"workchain": XpsWorkChain,
"exclude": ("structure", "relax"),
"get_builder": get_builder,
"update_inputs": update_inputs,
}
3 changes: 2 additions & 1 deletion src/aiidalab_qe/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def run_plugin(self):
self.exposed_inputs(plugin_workchain, namespace=name)
)
inputs.metadata.call_link_label = name
inputs.structure = self.ctx.current_structure
if entry_point.get("update_inputs"):
entry_point["update_inputs"](inputs, self.ctx)
inputs = prepare_process_inputs(plugin_workchain, inputs)
running = self.submit(plugin_workchain, **inputs)
self.report(f"launching plugin {name} <{running.pk}>")
Expand Down

0 comments on commit 4476651

Please sign in to comment.