Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plugin API: add a update_inputs to plugin's workchain_and_builder #656

Merged
merged 8 commits into from
Apr 16, 2024
11 changes: 11 additions & 0 deletions src/aiidalab_qe/plugins/bands/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,19 @@ def get_builder(codes, structure, parameters, **kwargs):
return bands


def update_inputs(inputs, ctx):
"""Update the inputs using context."""
inputs.structure = ctx.current_structure
inputs.scf.pw.parameters = inputs.scf.pw.parameters.get_dict()
if ctx.current_number_of_bands:
inputs.scf.pw.parameters.setdefault("SYSTEM", {}).setdefault(
"nbnd", ctx.current_number_of_bands
)


workchain_and_builder = {
"workchain": PwBandsWorkChain,
"exclude": ("structure", "relax"),
"get_builder": get_builder,
"update_inputs": update_inputs,
}
11 changes: 11 additions & 0 deletions src/aiidalab_qe/plugins/pdos/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,19 @@ def get_builder(codes, structure, parameters, **kwargs):
return pdos


def update_inputs(inputs, ctx):
"""Update the inputs using context."""
inputs.structure = ctx.current_structure
inputs.nscf.pw.parameters = inputs.nscf.pw.parameters.get_dict()
if ctx.current_number_of_bands:
inputs.nscf.pw.parameters.setdefault("SYSTEM", {}).setdefault(
"nbnd", ctx.current_number_of_bands
)


workchain_and_builder = {
"workchain": PdosWorkChain,
"exclude": ("structure", "relax"),
"get_builder": get_builder,
"update_inputs": update_inputs,
}
6 changes: 6 additions & 0 deletions src/aiidalab_qe/plugins/xas/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,14 @@ def get_builder(codes, structure, parameters, **kwargs):
return builder


def update_inputs(inputs, ctx):
"""Update the inputs using context."""
inputs.structure = ctx.current_structure


workchain_and_builder = {
"workchain": XspectraCrystalWorkChain,
"exclude": ("structure", "relax"),
"get_builder": get_builder,
"update_inputs": update_inputs,
}
6 changes: 6 additions & 0 deletions src/aiidalab_qe/plugins/xps/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,14 @@ def get_builder(codes, structure, parameters, **kwargs):
return builder


def update_inputs(inputs, ctx):
"""Update the inputs using context."""
inputs.structure = ctx.current_structure


workchain_and_builder = {
"workchain": XpsWorkChain,
"exclude": ("structure", "relax"),
"get_builder": get_builder,
"update_inputs": update_inputs,
}
3 changes: 2 additions & 1 deletion src/aiidalab_qe/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def run_plugin(self):
self.exposed_inputs(plugin_workchain, namespace=name)
)
inputs.metadata.call_link_label = name
inputs.structure = self.ctx.current_structure
if entry_point.get("update_inputs"):
entry_point["update_inputs"](inputs, self.ctx)
inputs = prepare_process_inputs(plugin_workchain, inputs)
running = self.submit(plugin_workchain, **inputs)
self.report(f"launching plugin {name} <{running.pk}>")
Expand Down
Loading