From 4476651aa23fe547546d5b51dd6779fb1e2056c1 Mon Sep 17 00:00:00 2001 From: Xing Wang Date: Tue, 16 Apr 2024 16:20:23 +0200 Subject: [PATCH] add a `update_inputs` to plugin's `workchain_and_builder` (#656) This PR enhances the plugin API by introducing a new feature to the `workchain_and_builder` configuration: the `update_inputs` method. This addition addresses limitations in the previous design, where the main workchain rigidly passed only the `structure` to the plugins' workchains. This approach had notable disadvantages: - It mandated that all plugins' workchains accept an input named `structure`, leading to errors if this was not adhered to. This is the case in the newly developed plugin: [aiidalab-qe-hp](https://github.com/superstar54/aiidalab-qe-hp). - It restricted data passed to plugins to be hardcoded, thereby limiting plugin developers' ability to tailor the inputs based on their specific requirements. The introduction of the `update_inputs` method provides a dynamic mechanism for updating the inputs of plugins' workchains. It accepts two parameters: `inputs` and `ctx`. This method enables developers to customize how the inputs to a plugin's workchain are updated, leveraging the broader context of the Workchain. An immediate benefit of this improvement is observed in the **bands** and **pdos** plugins, which can now access `current_number_of_bands` from the context if available. --- src/aiidalab_qe/plugins/bands/workchain.py | 11 +++++++++++ src/aiidalab_qe/plugins/pdos/workchain.py | 11 +++++++++++ src/aiidalab_qe/plugins/xas/workchain.py | 6 ++++++ src/aiidalab_qe/plugins/xps/workchain.py | 6 ++++++ src/aiidalab_qe/workflows/__init__.py | 3 ++- 5 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/aiidalab_qe/plugins/bands/workchain.py b/src/aiidalab_qe/plugins/bands/workchain.py index 9f7bd7d0e..56594b2b7 100644 --- a/src/aiidalab_qe/plugins/bands/workchain.py +++ b/src/aiidalab_qe/plugins/bands/workchain.py @@ -225,8 +225,19 @@ def get_builder(codes, structure, parameters, **kwargs): return bands +def update_inputs(inputs, ctx): + """Update the inputs using context.""" + inputs.structure = ctx.current_structure + inputs.scf.pw.parameters = inputs.scf.pw.parameters.get_dict() + if ctx.current_number_of_bands: + inputs.scf.pw.parameters.setdefault("SYSTEM", {}).setdefault( + "nbnd", ctx.current_number_of_bands + ) + + workchain_and_builder = { "workchain": PwBandsWorkChain, "exclude": ("structure", "relax"), "get_builder": get_builder, + "update_inputs": update_inputs, } diff --git a/src/aiidalab_qe/plugins/pdos/workchain.py b/src/aiidalab_qe/plugins/pdos/workchain.py index 4dd5fd665..2d5fd3e40 100644 --- a/src/aiidalab_qe/plugins/pdos/workchain.py +++ b/src/aiidalab_qe/plugins/pdos/workchain.py @@ -78,8 +78,19 @@ def get_builder(codes, structure, parameters, **kwargs): return pdos +def update_inputs(inputs, ctx): + """Update the inputs using context.""" + inputs.structure = ctx.current_structure + inputs.nscf.pw.parameters = inputs.nscf.pw.parameters.get_dict() + if ctx.current_number_of_bands: + inputs.nscf.pw.parameters.setdefault("SYSTEM", {}).setdefault( + "nbnd", ctx.current_number_of_bands + ) + + workchain_and_builder = { "workchain": PdosWorkChain, "exclude": ("structure", "relax"), "get_builder": get_builder, + "update_inputs": update_inputs, } diff --git a/src/aiidalab_qe/plugins/xas/workchain.py b/src/aiidalab_qe/plugins/xas/workchain.py index 0c08ce431..47a189cf5 100644 --- a/src/aiidalab_qe/plugins/xas/workchain.py +++ b/src/aiidalab_qe/plugins/xas/workchain.py @@ -104,8 +104,14 @@ def get_builder(codes, structure, parameters, **kwargs): return builder +def update_inputs(inputs, ctx): + """Update the inputs using context.""" + inputs.structure = ctx.current_structure + + workchain_and_builder = { "workchain": XspectraCrystalWorkChain, "exclude": ("structure", "relax"), "get_builder": get_builder, + "update_inputs": update_inputs, } diff --git a/src/aiidalab_qe/plugins/xps/workchain.py b/src/aiidalab_qe/plugins/xps/workchain.py index 238215d62..270400fc7 100644 --- a/src/aiidalab_qe/plugins/xps/workchain.py +++ b/src/aiidalab_qe/plugins/xps/workchain.py @@ -98,8 +98,14 @@ def get_builder(codes, structure, parameters, **kwargs): return builder +def update_inputs(inputs, ctx): + """Update the inputs using context.""" + inputs.structure = ctx.current_structure + + workchain_and_builder = { "workchain": XpsWorkChain, "exclude": ("structure", "relax"), "get_builder": get_builder, + "update_inputs": update_inputs, } diff --git a/src/aiidalab_qe/workflows/__init__.py b/src/aiidalab_qe/workflows/__init__.py index 093484d0e..cb5c5698e 100644 --- a/src/aiidalab_qe/workflows/__init__.py +++ b/src/aiidalab_qe/workflows/__init__.py @@ -238,7 +238,8 @@ def run_plugin(self): self.exposed_inputs(plugin_workchain, namespace=name) ) inputs.metadata.call_link_label = name - inputs.structure = self.ctx.current_structure + if entry_point.get("update_inputs"): + entry_point["update_inputs"](inputs, self.ctx) inputs = prepare_process_inputs(plugin_workchain, inputs) running = self.submit(plugin_workchain, **inputs) self.report(f"launching plugin {name} <{running.pk}>")