diff --git a/src/aiidalab_qe/app/configuration/__init__.py b/src/aiidalab_qe/app/configuration/__init__.py
index 456e3f388..70b6084c0 100644
--- a/src/aiidalab_qe/app/configuration/__init__.py
+++ b/src/aiidalab_qe/app/configuration/__init__.py
@@ -56,10 +56,10 @@ def __init__(self, model: ConfigurationModel, **kwargs):
self.settings: dict[str, SettingsPanel] = {}
- workchain_model = WorkChainModel(include=True)
+ workchain_model = WorkChainModel()
self._model.add_model("workchain", workchain_model)
- advanced_model = AdvancedModel(include=True)
+ advanced_model = AdvancedModel()
self._model.add_model("advanced", advanced_model)
self._fetch_plugin_settings()
diff --git a/src/aiidalab_qe/app/configuration/advanced/advanced.py b/src/aiidalab_qe/app/configuration/advanced/advanced.py
index 39503639d..327a72bf4 100644
--- a/src/aiidalab_qe/app/configuration/advanced/advanced.py
+++ b/src/aiidalab_qe/app/configuration/advanced/advanced.py
@@ -341,7 +341,7 @@ def _on_input_structure_change(self, _):
def _on_protocol_change(self, _):
self.refresh(specific="protocol")
- def _on_kpoints_distance_change(self, _=None):
+ def _on_kpoints_distance_change(self, _):
self.refresh(specific="mesh")
def _on_override_change(self, change):
diff --git a/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py b/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py
index 51739314e..de5ffd66c 100644
--- a/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py
+++ b/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py
@@ -49,8 +49,19 @@ def __init__(self, *args, **kwargs):
}
def update(self, specific=""):
+ if self.input_structure is None:
+ self.applicable_kinds = []
+ self.orbital_labels = []
+ self._defaults |= {
+ "parameters": {},
+ "eigenvalues": [],
+ }
+ else:
+ self.orbital_labels = self._define_orbital_labels()
+ self._defaults["parameters"] = self._define_default_parameters()
+ self.applicable_kinds = self._define_applicable_kinds()
+ self._defaults["eigenvalues"] = self._define_default_eigenvalues()
with self.hold_trait_notifications():
- self._update_defaults(specific)
self.parameters = self._get_default_parameters()
self.eigenvalues = self._get_default_eigenvalues()
self.needs_eigenvalues_widget = len(self.applicable_kinds) > 0
@@ -82,20 +93,6 @@ def reset(self):
self.parameters = self._get_default_parameters()
self.eigenvalues = self._get_default_eigenvalues()
- def _update_defaults(self, specific=""):
- if self.input_structure is None:
- self.applicable_kinds = []
- self.orbital_labels = []
- self._defaults |= {
- "parameters": {},
- "eigenvalues": [],
- }
- else:
- self.orbital_labels = self._define_orbital_labels()
- self._defaults["parameters"] = self._define_default_parameters()
- self.applicable_kinds = self._define_applicable_kinds()
- self._defaults["eigenvalues"] = self._define_default_eigenvalues()
-
def _define_orbital_labels(self):
hubbard_manifold_list = [
self._get_manifold(Element(kind.symbol))
diff --git a/src/aiidalab_qe/app/configuration/advanced/magnetization/model.py b/src/aiidalab_qe/app/configuration/advanced/magnetization/model.py
index f031fd3f8..1b29fef4e 100644
--- a/src/aiidalab_qe/app/configuration/advanced/magnetization/model.py
+++ b/src/aiidalab_qe/app/configuration/advanced/magnetization/model.py
@@ -35,8 +35,13 @@ def __init__(self, *args, **kwargs):
}
def update(self, specific=""):
+ if self.spin_type == "none" or self.input_structure is None:
+ self._defaults["moments"] = {}
+ else:
+ self._defaults["moments"] = {
+ symbol: 0.0 for symbol in self.input_structure.get_kind_names()
+ }
with self.hold_trait_notifications():
- self._update_defaults(specific)
self.moments = self._get_default_moments()
def reset(self):
@@ -45,13 +50,5 @@ def reset(self):
self.total = self.traits()["total"].default_value
self.moments = self._get_default_moments()
- def _update_defaults(self, specific=""):
- if self.spin_type == "none" or self.input_structure is None:
- self._defaults["moments"] = {}
- else:
- self._defaults["moments"] = {
- symbol: 0.0 for symbol in self.input_structure.get_kind_names()
- }
-
def _get_default_moments(self):
return deepcopy(self._defaults["moments"])
diff --git a/src/aiidalab_qe/app/configuration/advanced/model.py b/src/aiidalab_qe/app/configuration/advanced/model.py
index 25d4bd23f..246a74f2e 100644
--- a/src/aiidalab_qe/app/configuration/advanced/model.py
+++ b/src/aiidalab_qe/app/configuration/advanced/model.py
@@ -60,8 +60,10 @@ class AdvancedModel(SettingsModel):
kpoints_distance = tl.Float(0.0)
mesh_grid = tl.Unicode("")
- def __init__(self, include=False, *args, **kwargs):
- super().__init__(include, *args, **kwargs)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.include = True
self.dftd3_version = {
"dft-d3": 3,
@@ -84,19 +86,21 @@ def __init__(self, include=False, *args, **kwargs):
def update(self, specific=""):
with self.hold_trait_notifications():
- self._update_defaults(specific)
- self.forc_conv_thr = self._defaults["forc_conv_thr"]
- self.forc_conv_thr_step = self._defaults["forc_conv_thr_step"]
- self.etot_conv_thr = self._defaults["etot_conv_thr"]
- self.etot_conv_thr_step = self._defaults["etot_conv_thr_step"]
- self.scf_conv_thr = self._defaults["scf_conv_thr"]
- self.scf_conv_thr_step = self._defaults["scf_conv_thr_step"]
- self.kpoints_distance = self._defaults["kpoints_distance"]
+ if not specific or specific != "mesh":
+ parameters = PwBaseWorkChain.get_protocol_inputs(self.protocol)
+ self._update_kpoints_distance(parameters)
+ self._update_thresholds(parameters)
+ self._update_kpoints_mesh()
def add_model(self, identifier, model):
self._models[identifier] = model
self._link_model(model)
+ def get_model(self, identifier) -> AdvancedSubModel:
+ if identifier in self._models:
+ return self._models[identifier]
+ raise ValueError(f"Model with identifier '{identifier}' not found.")
+
def get_models(self):
return self._models.items()
@@ -122,7 +126,7 @@ def get_model_state(self):
"kpoints_distance": self.kpoints_distance,
}
- hubbard: HubbardModel = self._get_model("hubbard") # type: ignore
+ hubbard: HubbardModel = self.get_model("hubbard") # type: ignore
if hubbard.is_active:
parameters["hubbard_parameters"] = {"hubbard_u": hubbard.parameters}
if hubbard.has_eigenvalues:
@@ -130,7 +134,7 @@ def get_model_state(self):
"starting_ns_eigenvalue": hubbard.get_active_eigenvalues()
}
- pseudos: PseudosModel = self._get_model("pseudos") # type: ignore
+ pseudos: PseudosModel = self.get_model("pseudos") # type: ignore
parameters["pseudo_family"] = pseudos.family
if pseudos.dictionary:
parameters["pw"]["pseudos"] = pseudos.dictionary
@@ -145,8 +149,8 @@ def get_model_state(self):
self.dftd3_version[self.van_der_waals]
)
- smearing: SmearingModel = self._get_model("smearing") # type: ignore
- magnetization: MagnetizationModel = self._get_model("magnetization") # type: ignore
+ smearing: SmearingModel = self.get_model("smearing") # type: ignore
+ magnetization: MagnetizationModel = self.get_model("magnetization") # type: ignore
if self.spin_type == "collinear":
parameters["initial_magnetic_moments"] = magnetization.moments
if self.electronic_type == "metal":
@@ -181,7 +185,7 @@ def get_model_state(self):
return parameters
def set_model_state(self, parameters):
- pseudos: PseudosModel = self._get_model("pseudos") # type: ignore
+ pseudos: PseudosModel = self.get_model("pseudos") # type: ignore
if "pseudo_family" in parameters:
pseudo_family = PseudoFamily.from_string(parameters["pseudo_family"])
library = pseudo_family.library
@@ -199,7 +203,7 @@ def set_model_state(self, parameters):
if (pw_parameters := parameters.get("pw", {}).get("parameters")) is not None:
self._set_pw_parameters(pw_parameters)
- magnetization: MagnetizationModel = self._get_model("magnetization") # type: ignore
+ magnetization: MagnetizationModel = self.get_model("magnetization") # type: ignore
if magnetic_moments := parameters.get("initial_magnetic_moments"):
if isinstance(magnetic_moments, (int, float)):
magnetic_moments = [magnetic_moments]
@@ -212,7 +216,7 @@ def set_model_state(self, parameters):
)
magnetization.moments = magnetic_moments
- hubbard: HubbardModel = self._get_model("hubbard") # type: ignore
+ hubbard: HubbardModel = self.get_model("hubbard") # type: ignore
if parameters.get("hubbard_parameters"):
hubbard.is_active = True
hubbard.parameters = parameters["hubbard_parameters"]["hubbard_u"]
@@ -256,16 +260,6 @@ def _link_model(self, model):
(model, trait),
)
- def _update_defaults(self, specific=""):
- if not specific or specific != "mesh":
- parameters = PwBaseWorkChain.get_protocol_inputs(self.protocol)
- self._update_kpoints_distance(parameters)
-
- self._update_kpoints_mesh()
-
- if not specific or specific == "protocol":
- self._update_thresholds(parameters)
-
def _update_kpoints_mesh(self, _=None):
if self.input_structure is None:
mesh_grid = ""
@@ -284,18 +278,25 @@ def _update_kpoints_mesh(self, _=None):
def _update_kpoints_distance(self, parameters):
kpoints_distance = parameters["kpoints_distance"] if self.has_pbc else 100.0
self._defaults["kpoints_distance"] = kpoints_distance
+ self.kpoints_distance = self._defaults["kpoints_distance"]
def _update_thresholds(self, parameters):
num_atoms = len(self.input_structure.sites) if self.input_structure else 1
etot_value = num_atoms * parameters["meta_parameters"]["etot_conv_thr_per_atom"]
self._set_value_and_step("etot_conv_thr", etot_value)
+ self.etot_conv_thr = self._defaults["etot_conv_thr"]
+ self.etot_conv_thr_step = self._defaults["etot_conv_thr_step"]
scf_value = num_atoms * parameters["meta_parameters"]["conv_thr_per_atom"]
self._set_value_and_step("scf_conv_thr", scf_value)
+ self.scf_conv_thr = self._defaults["scf_conv_thr"]
+ self.scf_conv_thr_step = self._defaults["scf_conv_thr_step"]
forc_value = parameters["pw"]["parameters"]["CONTROL"]["forc_conv_thr"]
self._set_value_and_step("forc_conv_thr", forc_value)
+ self.forc_conv_thr = self._defaults["forc_conv_thr"]
+ self.forc_conv_thr_step = self._defaults["forc_conv_thr_step"]
def _set_value_and_step(self, attribute, value):
self._defaults[attribute] = value
@@ -324,18 +325,13 @@ def _set_pw_parameters(self, pw_parameters):
system_params.get("vdw_corr", "none"),
)
- smearing: SmearingModel = self._get_model("smearing") # type: ignore
+ smearing: SmearingModel = self.get_model("smearing") # type: ignore
if "degauss" in system_params:
smearing.degauss = system_params["degauss"]
if "smearing" in system_params:
smearing.type = system_params["smearing"]
- magnetization: MagnetizationModel = self._get_model("magnetization") # type: ignore
+ magnetization: MagnetizationModel = self.get_model("magnetization") # type: ignore
if "tot_magnetization" in system_params:
magnetization.type = "tot_magnetization"
-
- def _get_model(self, identifier) -> AdvancedSubModel:
- if identifier in self._models:
- return self._models[identifier]
- raise ValueError(f"Model with identifier '{identifier}' not found.")
diff --git a/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py b/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py
index 03c5068dd..7da4d3fb3 100644
--- a/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py
+++ b/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py
@@ -129,7 +129,16 @@ def __init__(self, *args, **kwargs):
def update(self, specific=""):
with self.hold_trait_notifications():
- self._update_defaults(specific)
+ if self.input_structure is None:
+ self._defaults |= {
+ "dictionary": {},
+ "cutoffs": [[0.0], [0.0]],
+ }
+ else:
+ self.update_default_pseudos()
+ self.update_default_cutoffs()
+ self.update_family_parameters()
+ self.update_family()
def update_default_pseudos(self):
try:
@@ -262,18 +271,6 @@ def reset(self):
self.family_help_message = self.PSEUDO_HELP_WO_SOC
self.status_message = ""
- def _update_defaults(self, specific=""):
- if self.input_structure is None:
- self._defaults |= {
- "dictionary": {},
- "cutoffs": [[0.0], [0.0]],
- }
- else:
- self.update_default_pseudos()
- self.update_default_cutoffs()
- self.update_family_parameters()
- self.update_family()
-
def _get_pseudo_family_from_database(self):
"""Get the pseudo family from the database."""
return (
diff --git a/src/aiidalab_qe/app/configuration/advanced/smearing/model.py b/src/aiidalab_qe/app/configuration/advanced/smearing/model.py
index 5dd94b0cf..6d626f988 100644
--- a/src/aiidalab_qe/app/configuration/advanced/smearing/model.py
+++ b/src/aiidalab_qe/app/configuration/advanced/smearing/model.py
@@ -27,17 +27,6 @@ def __init__(self, *args, **kwargs):
}
def update(self, specific=""):
- with self.hold_trait_notifications():
- self._update_defaults(specific)
- self.type = self._defaults["type"]
- self.degauss = self._defaults["degauss"]
-
- def reset(self):
- with self.hold_trait_notifications():
- self.type = self._defaults["type"]
- self.degauss = self._defaults["degauss"]
-
- def _update_defaults(self, specific=""):
parameters = (
PwBaseWorkChain.get_protocol_inputs(self.protocol)
.get("pw", {})
@@ -48,3 +37,11 @@ def _update_defaults(self, specific=""):
"type": parameters["smearing"],
"degauss": parameters["degauss"],
}
+ with self.hold_trait_notifications():
+ self.type = self._defaults["type"]
+ self.degauss = self._defaults["degauss"]
+
+ def reset(self):
+ with self.hold_trait_notifications():
+ self.type = self._defaults["type"]
+ self.degauss = self._defaults["degauss"]
diff --git a/src/aiidalab_qe/app/configuration/advanced/subsettings.py b/src/aiidalab_qe/app/configuration/advanced/subsettings.py
index f0e4c69e6..b7105fd54 100644
--- a/src/aiidalab_qe/app/configuration/advanced/subsettings.py
+++ b/src/aiidalab_qe/app/configuration/advanced/subsettings.py
@@ -1,3 +1,5 @@
+import os
+
import ipywidgets as ipw
import traitlets as tl
@@ -21,16 +23,6 @@ def reset(self):
"""Resets the model to present defaults."""
raise NotImplementedError
- def _update_defaults(self, specific=""):
- """Updates the model's default values.
-
- Parameters
- ----------
- `specific` : `str`, optional
- If provided, specifies the level of update.
- """
- raise NotImplementedError
-
class AdvancedSubSettings(ipw.VBox):
identifier = "sub"
@@ -74,6 +66,9 @@ def refresh(self, specific=""):
self.updated = False
self._unsubscribe()
self._update(specific)
+ if "PYTEST_CURRENT_TEST" in os.environ:
+ # Skip resetting to avoid having to inject a structure when testing
+ return
if hasattr(self._model, "input_structure") and not self._model.input_structure:
self._reset()
diff --git a/src/aiidalab_qe/app/configuration/basic/model.py b/src/aiidalab_qe/app/configuration/basic/model.py
index 3e733c0f8..a40f4e932 100644
--- a/src/aiidalab_qe/app/configuration/basic/model.py
+++ b/src/aiidalab_qe/app/configuration/basic/model.py
@@ -18,8 +18,10 @@ class WorkChainModel(SettingsModel):
spin_type = tl.Unicode(DEFAULT["workchain"]["spin_type"])
electronic_type = tl.Unicode(DEFAULT["workchain"]["electronic_type"])
- def __init__(self, include=False, *args, **kwargs):
- super().__init__(include, *args, **kwargs)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.include = True
self._defaults = {
"protocol": self.traits()["protocol"].default_value,
diff --git a/src/aiidalab_qe/app/configuration/model.py b/src/aiidalab_qe/app/configuration/model.py
index 270ba96b7..83b0862be 100644
--- a/src/aiidalab_qe/app/configuration/model.py
+++ b/src/aiidalab_qe/app/configuration/model.py
@@ -51,23 +51,53 @@ def __init__(self, *args, **kwargs):
"""
def update(self):
- self._update_defaults()
- self.relax_type_help = self._get_default_relax_type_help()
- self.relax_type_options = self._get_default_relax_type_options()
- self.relax_type = self._get_default_relax_type()
+ if self.has_pbc:
+ relax_type_help = self.relax_type_help_template.format(
+ option_count="three",
+ full_relaxation_option=(
+ """
+
+ (3) Full geometry: perform a full relaxation of the internal atomic
+ coordinates and the cell parameters.
+ """
+ ),
+ )
+ relax_type_options = [
+ ("Structure as is", "none"),
+ ("Atomic positions", "positions"),
+ ("Full geometry", "positions_cell"),
+ ]
+ else:
+ relax_type_help = self.relax_type_help_template.format(
+ option_count="two",
+ full_relaxation_option="",
+ )
+ relax_type_options = [
+ ("Structure as is", "none"),
+ ("Atomic positions", "positions"),
+ ]
+ self._defaults = {
+ "relax_type_help": relax_type_help,
+ "relax_type_options": relax_type_options,
+ "relax_type": relax_type_options[-1][-1],
+ }
+ with self.hold_trait_notifications():
+ self.relax_type_help = self._get_default_relax_type_help()
+ self.relax_type_options = self._get_default_relax_type_options()
+ self.relax_type = self._get_default_relax_type()
def add_model(self, identifier, model):
self._models[identifier] = model
self._link_model(model)
- def get_models(self):
- return self._models.items()
-
def get_model(self, identifier) -> SettingsModel:
if identifier in self._models:
return self._models[identifier]
raise ValueError(f"Model with identifier '{identifier}' not found.")
+ def get_models(self):
+ return self._models.items()
+
def get_model_state(self):
parameters = {
identifier: model.get_model_state()
@@ -86,9 +116,9 @@ def set_model_state(self, parameters):
self.relax_type = workchain_parameters.get("relax_type")
properties = set(workchain_parameters.get("properties", []))
for identifier, model in self._models.items():
+ model.include = identifier in self._default_models | properties
if parameters.get(identifier):
model.set_model_state(parameters[identifier])
- model.include = identifier in self._default_models | properties
def reset(self):
self.confirmed = False
@@ -134,38 +164,6 @@ def _get_properties(self):
properties.append("relax")
return properties
- def _update_defaults(self):
- if self.has_pbc:
- relax_type_help = self.relax_type_help_template.format(
- option_count="three",
- full_relaxation_option=(
- """
-
- (3) Full geometry: perform a full relaxation of the internal atomic
- coordinates and the cell parameters.
- """
- ),
- )
- relax_type_options = [
- ("Structure as is", "none"),
- ("Atomic positions", "positions"),
- ("Full geometry", "positions_cell"),
- ]
- else:
- relax_type_help = self.relax_type_help_template.format(
- option_count="two",
- full_relaxation_option="",
- )
- relax_type_options = [
- ("Structure as is", "none"),
- ("Atomic positions", "positions"),
- ]
- self._defaults = {
- "relax_type_help": relax_type_help,
- "relax_type_options": relax_type_options,
- "relax_type": relax_type_options[-1][-1],
- }
-
def _get_default_relax_type_help(self):
return self._defaults.get("relax_type_help", "")
diff --git a/src/aiidalab_qe/common/panel.py b/src/aiidalab_qe/common/panel.py
index 3bd7f81b6..878a51db6 100644
--- a/src/aiidalab_qe/common/panel.py
+++ b/src/aiidalab_qe/common/panel.py
@@ -75,10 +75,9 @@ class SettingsModel(tl.HasTraits):
_defaults = {}
- def __init__(self, include=False, *args, **kwargs):
+ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.observe(self.unconfirm, tl.All)
- self.include = include
@property
def has_pbc(self):
@@ -112,16 +111,6 @@ def unconfirm(self, change):
if change["name"] != "confirmed":
self.confirmed = False
- def _update_defaults(self, specific=""):
- """Updates the model's default values.
-
- Parameters
- ----------
- `specific` : `str`, optional
- If provided, specifies the level of update.
- """
- raise NotImplementedError
-
class SettingsPanel(Panel):
title = "Settings"
@@ -161,9 +150,8 @@ def refresh(self, specific=""):
"""
self.updated = False
self._unsubscribe()
- if not self._model.include:
- return
- self.update(specific)
+ if self._model.include:
+ self.update(specific)
if "PYTEST_CURRENT_TEST" in os.environ:
# Skip resetting to avoid having to inject a structure when testing
return
diff --git a/src/aiidalab_qe/common/widgets.py b/src/aiidalab_qe/common/widgets.py
index 42ee7cfe0..a27152cf3 100644
--- a/src/aiidalab_qe/common/widgets.py
+++ b/src/aiidalab_qe/common/widgets.py
@@ -957,7 +957,6 @@ def __init__(self, title=None, **kwargs):
def render(self):
if self.rendered:
return
- from aiidalab_qe.common.widgets import LoadingWidget
self.children = [LoadingWidget(f"Loading {self.title} widget")]
diff --git a/src/aiidalab_qe/plugins/pdos/model.py b/src/aiidalab_qe/plugins/pdos/model.py
index deb7e541f..8c664acae 100644
--- a/src/aiidalab_qe/plugins/pdos/model.py
+++ b/src/aiidalab_qe/plugins/pdos/model.py
@@ -22,8 +22,8 @@ class PdosModel(SettingsModel):
use_pdos_degauss = tl.Bool(False)
pdos_degauss = tl.Float(0.005)
- def __init__(self, include=False, *args, **kwargs):
- super().__init__(include, *args, **kwargs)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
self._defaults = {
"kpoints_distance": self.traits()["kpoints_distance"].default_value,
@@ -31,8 +31,11 @@ def __init__(self, include=False, *args, **kwargs):
def update(self, specific=""):
with self.hold_trait_notifications():
- self._update_defaults(specific)
- self.kpoints_distance = self._defaults["kpoints_distance"]
+ if not specific or specific != "mesh":
+ parameters = PdosWorkChain.get_protocol_inputs(self.protocol)
+ self._update_kpoints_distance(parameters)
+
+ self._update_kpoints_mesh()
def get_model_state(self):
return {
@@ -55,13 +58,6 @@ def reset(self):
self.use_pdos_degauss = self.traits()["use_pdos_degauss"].default_value
self.pdos_degauss = self.traits()["pdos_degauss"].default_value
- def _update_defaults(self, specific=""):
- if not specific or specific != "mesh":
- parameters = PdosWorkChain.get_protocol_inputs(self.protocol)
- self._update_kpoints_distance(parameters)
-
- self._update_kpoints_mesh()
-
def _update_kpoints_mesh(self, _=None):
if self.input_structure is None:
mesh_grid = ""
@@ -84,3 +80,4 @@ def _update_kpoints_distance(self, parameters):
kpoints_distance = 100.0
self.use_pdos_degauss = True
self._defaults["kpoints_distance"] = kpoints_distance
+ self.kpoints_distance = self._defaults["kpoints_distance"]
diff --git a/src/aiidalab_qe/plugins/xas/model.py b/src/aiidalab_qe/plugins/xas/model.py
index 67c919a5a..a8b70f929 100644
--- a/src/aiidalab_qe/plugins/xas/model.py
+++ b/src/aiidalab_qe/plugins/xas/model.py
@@ -60,8 +60,9 @@ def __init__(self, *args, **kwargs):
self.installed_pseudos = False
def update(self, specific=""):
- self._update_pseudos()
- self._update_core_hole_treatment_recommendations()
+ with self.hold_trait_notifications():
+ self._update_pseudos()
+ self._update_core_hole_treatment_recommendations()
def get_model_state(self):
pseudo_labels = {}
diff --git a/src/aiidalab_qe/plugins/xps/model.py b/src/aiidalab_qe/plugins/xps/model.py
index 4498bc1c7..4911be55a 100644
--- a/src/aiidalab_qe/plugins/xps/model.py
+++ b/src/aiidalab_qe/plugins/xps/model.py
@@ -34,9 +34,10 @@ class XpsModel(SettingsModel):
)
def update(self, specific=""):
- self._update_correction_energies()
- if not specific or specific == "pseudo_group":
- self._update_pseudos()
+ with self.hold_trait_notifications():
+ self._update_correction_energies()
+ if not specific or specific == "pseudo_group":
+ self._update_pseudos()
def get_supported_core_levels(self):
supported_core_levels = {}
@@ -67,7 +68,6 @@ def set_model_state(self, parameters: dict):
self.traits()["structure_type"].default_value,
)
- # TODO check logic
core_level_list = parameters.get("core_level_list", [])
for orbital in self.core_levels:
if orbital in core_level_list:
diff --git a/tests/configuration/test_advanced.py b/tests/configuration/test_advanced.py
index 4a9a1f43b..79a69342d 100644
--- a/tests/configuration/test_advanced.py
+++ b/tests/configuration/test_advanced.py
@@ -7,27 +7,31 @@ def test_advanced_default():
"""Test default behavior of advanced setting."""
model = AdvancedModel()
_ = AdvancedSettings(model=model)
+ smearing = model.get_model("smearing")
# Test override functionality in advanced settings
model.override = True
model.protocol = "fast"
- model.smearing.type = "methfessel-paxton"
- model.smearing.degauss = 0.03
+ smearing.type = "methfessel-paxton"
+ smearing.degauss = 0.03
model.kpoints_distance = 0.22
# Reset values to default w.r.t protocol
model.override = False
- assert model.smearing.type == "cold"
- assert model.smearing.degauss == 0.01
+ assert smearing.type == "cold"
+ assert smearing.degauss == 0.01
assert model.kpoints_distance == 0.5
def test_advanced_smearing_settings():
"""Test Smearing Settings."""
- from aiidalab_qe.app.configuration.advanced.smearing import SmearingSettings
+ from aiidalab_qe.app.configuration.advanced.smearing import (
+ SmearingModel,
+ SmearingSettings,
+ )
- model = AdvancedModel()
+ model = SmearingModel()
smearing = SmearingSettings(model=model)
smearing.render()
@@ -40,23 +44,22 @@ def test_advanced_smearing_settings():
assert smearing.degauss.disabled is False
assert smearing.smearing.disabled is False
- assert model.smearing.type == "cold"
- assert model.smearing.degauss == 0.01
+ assert model.type == "cold"
+ assert model.degauss == 0.01
# Test protocol-dependent smearing change
model.protocol = "fast"
- assert model.smearing.type == "cold"
- assert model.smearing.degauss == 0.01
+ assert model.type == "cold"
+ assert model.degauss == 0.01
# Check reset
- model.smearing.type = "gaussian"
- model.smearing.degauss = 0.05
- model.smearing.reset()
+ model.type = "gaussian"
+ model.degauss = 0.05
+ model.override = False
- assert model.protocol == "fast" # reset does not apply to protocol
- assert model.smearing.type == "cold"
- assert model.smearing.degauss == 0.01
+ assert model.type == "cold"
+ assert model.degauss == 0.01
def test_advanced_kpoints_settings():
@@ -156,9 +159,12 @@ def test_advanced_kpoints_mesh(generate_structure_data):
@pytest.mark.usefixtures("aiida_profile_clean", "sssp")
def test_advanced_hubbard_settings(generate_structure_data):
"""Test Hubbard widget."""
- from aiidalab_qe.app.configuration.advanced.hubbard import HubbardSettings
+ from aiidalab_qe.app.configuration.advanced.hubbard import (
+ HubbardModel,
+ HubbardSettings,
+ )
- model = AdvancedModel()
+ model = HubbardModel()
hubbard = HubbardSettings(model=model)
hubbard.render()
@@ -166,8 +172,8 @@ def test_advanced_hubbard_settings(generate_structure_data):
model.input_structure = structure
# Activate Hubbard U widget
- model.hubbard.is_active = True
- assert model.hubbard.orbital_labels == ["Co - 3d", "O - 2p", "Li - 2s"]
+ model.is_active = True
+ assert model.orbital_labels == ["Co - 3d", "O - 2p", "Li - 2s"]
# Change the Hubbard U parameters for Co, O, and Li
hubbard_parameters = hubbard.hubbard_widget.children[1:] # type: ignore
@@ -175,7 +181,7 @@ def test_advanced_hubbard_settings(generate_structure_data):
hubbard_parameters[1].value = 2 # O - 2p
hubbard_parameters[2].value = 3 # Li - 2s
- assert model.hubbard.parameters == {
+ assert model.parameters == {
"Co - 3d": 1.0,
"O - 2p": 2.0,
"Li - 2s": 3.0,
@@ -191,9 +197,9 @@ def test_advanced_hubbard_settings(generate_structure_data):
# assert model.hubbard.eigenvalues == [] # TODO should they be?
# Check there is only eigenvalues for Co (Transition metal)
- model.hubbard.has_eigenvalues = True
- assert len(model.hubbard.applicable_kinds) == 1
- assert len(model.hubbard.eigenvalues) == 1
+ model.has_eigenvalues = True
+ assert len(model.applicable_kinds) == 1
+ assert len(model.eigenvalues) == 1
Co_eigenvalues = hubbard.eigenvalues_widget.children[0].children[1] # type: ignore
Co_spin_down_row = Co_eigenvalues.children[1]
@@ -201,7 +207,7 @@ def test_advanced_hubbard_settings(generate_structure_data):
Co_spin_down_row.children[3].value = "1"
Co_spin_down_row.children[5].value = "1"
- assert model.hubbard.get_active_eigenvalues() == [
+ assert model.get_active_eigenvalues() == [
[1, 1, "Co", 1],
[3, 1, "Co", 1],
[5, 1, "Co", 1],
diff --git a/tests/conftest.py b/tests/conftest.py
index 4596af2a5..72fc2d854 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -461,15 +461,16 @@ def _submit_app_generator(
advanced_model.electron_maxstep = electron_maxstep
if isinstance(initial_magnetic_moments, (int, float)):
initial_magnetic_moments = [initial_magnetic_moments]
- advanced_model.magnetization.moments = dict(
+ advanced_model.get_model("magnetization").moments = dict(
zip(
app.configure_model.input_structure.get_kind_names(),
initial_magnetic_moments,
)
)
# mimic the behavior of the smearing widget set up
- advanced_model.smearing.type = smearing
- advanced_model.smearing.degauss = degauss
+ smearing_model = advanced_model.get_model("smearing")
+ smearing_model.type = smearing
+ smearing_model.degauss = degauss
app.configure_step.confirm()
app.submit_model.input_structure = generate_structure_data()
@@ -739,7 +740,9 @@ def _generate_qeapp_workchain(
from_example.children[0].value = from_example.children[0].options[1][1]
else:
structure.store()
- aiida_database = app.structure_step.manager.children[0].children[2] # type: ignore
+ aiida_database_wrapper = app.structure_step.manager.children[0].children[2] # type: ignore
+ aiida_database_wrapper.render()
+ aiida_database = aiida_database_wrapper.children[0] # type: ignore
aiida_database.search()
aiida_database.results.value = structure
@@ -764,20 +767,20 @@ def _generate_qeapp_workchain(
if spin_type == "collinear":
advanced_model.override = True
- magnetization = advanced_model.magnetization
+ magnetization_model = advanced_model.get_model("magnetization")
if electronic_type == "insulator":
- magnetization.total = tot_magnetization
+ magnetization_model.total = tot_magnetization
elif magnetization_type == "starting_magnetization":
if isinstance(initial_magnetic_moments, (int, float)):
initial_magnetic_moments = [initial_magnetic_moments]
- magnetization.moments = dict(
+ magnetization_model.moments = dict(
zip(
structure.get_kind_names(),
initial_magnetic_moments,
)
)
else:
- magnetization.total = tot_magnetization
+ magnetization_model.total = tot_magnetization
app.configure_step.confirm()
diff --git a/tests/test_app.py b/tests/test_app.py
index 996455826..6c1681733 100644
--- a/tests/test_app.py
+++ b/tests/test_app.py
@@ -19,7 +19,8 @@ def test_reload_and_reset(generate_qeapp_workchain):
assert app.configure_model.get_model("workchain").spin_type == "collinear"
assert app.configure_model.get_model("bands").include is True
assert app.configure_model.get_model("pdos").include is False
- assert len(app.configure_model.get_model("advanced").pseudos.dictionary) > 0
+ advanced_model = app.configure_model.get_model("advanced")
+ assert len(advanced_model.get_model("pseudos").dictionary) > 0
assert app.configure_step.state == app.configure_step.State.SUCCESS
diff --git a/tests/test_configure.py b/tests/test_configure.py
index 509a039cd..716709537 100644
--- a/tests/test_configure.py
+++ b/tests/test_configure.py
@@ -71,6 +71,6 @@ def test_reminder_info():
assert bands_info.value == ""
bands_model = model.get_model("bands")
bands_model.include = True
- assert bands_info.value == "Customize bands settings in step 2.2 if needed"
+ assert bands_info.value == "Customize bands settings in Step 2.2 if needed"
bands_model.include = False
assert bands_info.value == ""
diff --git a/tests/test_pseudo.py b/tests/test_pseudo.py
index 3c193c6f1..df3e9ffc4 100644
--- a/tests/test_pseudo.py
+++ b/tests/test_pseudo.py
@@ -142,54 +142,53 @@ def test_download_and_install_pseudo_from_file(tmp_path):
@pytest.mark.usefixtures("aiida_profile_clean", "sssp", "pseudodojo")
def test_pseudos_settings(generate_structure_data, generate_upf_data):
- from aiidalab_qe.app.configuration.advanced import AdvancedModel
- from aiidalab_qe.app.configuration.advanced.pseudos import PseudoSettings
+ from aiidalab_qe.app.configuration.advanced.pseudos import (
+ PseudoSettings,
+ PseudosModel,
+ )
- model = AdvancedModel()
+ model = PseudosModel()
pseudos = PseudoSettings(model=model)
- assert model.pseudos.override is False
+ assert model.override is False
# Test the default family
model.override = True
model.spin_orbit = "wo_soc"
- assert model.pseudos.family == f"SSSP/{SSSP_VERSION}/PBEsol/efficiency"
+ assert model.family == f"SSSP/{SSSP_VERSION}/PBEsol/efficiency"
# Test protocol-dependent family change
model.protocol = "precise"
- assert model.pseudos.family == f"SSSP/{SSSP_VERSION}/PBEsol/precision"
+ assert model.family == f"SSSP/{SSSP_VERSION}/PBEsol/precision"
# Test functional-dependent family change
- model.pseudos.functional = "PBE"
- assert model.pseudos.family == f"SSSP/{SSSP_VERSION}/PBE/precision"
+ model.functional = "PBE"
+ assert model.family == f"SSSP/{SSSP_VERSION}/PBE/precision"
# Test library-dependent family change
- model.pseudos.library = "PseudoDojo stringent"
- assert (
- model.pseudos.family == f"PseudoDojo/{PSEUDODOJO_VERSION}/PBE/SR/stringent/upf"
- )
+ model.library = "PseudoDojo stringent"
+ assert model.family == f"PseudoDojo/{PSEUDODOJO_VERSION}/PBE/SR/stringent/upf"
# Test spin-orbit-dependent family change
model.spin_orbit = "soc"
model.protocol = "moderate"
- assert (
- model.pseudos.family == f"PseudoDojo/{PSEUDODOJO_VERSION}/PBE/FR/standard/upf"
- )
+ assert model.family == f"PseudoDojo/{PSEUDODOJO_VERSION}/PBE/FR/standard/upf"
- # Test structure-dependent family change
- model.reset()
+ # Reset the external dependencies of the model
+ model.spin_orbit = "wo_soc"
+ # Test structure-dependent family change
silicon = generate_structure_data("silicon")
model.input_structure = silicon
- assert "Si" in model.pseudos.dictionary.keys()
- assert model.pseudos.ecutwfc == 30
- assert model.pseudos.ecutrho == 240
+ assert "Si" in model.dictionary.keys()
+ assert model.ecutwfc == 30
+ assert model.ecutrho == 240
# Test that changing the structure triggers a reset
silica = generate_structure_data("silica")
model.input_structure = silica
- assert "Si" in model.pseudos.dictionary.keys()
- assert "O" in model.pseudos.dictionary.keys()
+ assert "Si" in model.dictionary.keys()
+ assert "O" in model.dictionary.keys()
# Test pseudo upload
pseudos.render()
@@ -205,18 +204,18 @@ def test_pseudos_settings(generate_structure_data, generate_upf_data):
},
}
)
- pseudo = model.pseudos.dictionary["O"] # type: ignore
+ pseudo = model.dictionary["O"] # type: ignore
assert orm.load_node(pseudo).filename == "O_new.upf"
# TODO necessary for final test - see comment below
- # cutoffs = [model.pseudos.ecutwfc, model.pseudos.ecutrho]
+ # cutoffs = [model.ecutwfc, model.ecutrho]
- model.pseudos.reset()
- pseudo = model.pseudos.dictionary["O"] # type: ignore
+ model.reset()
+ pseudo = model.dictionary["O"] # type: ignore
assert orm.load_node(pseudo).filename != "O_new.upf"
# TODO what is this about?
- # model.pseudos.set_pseudos(pseudos, cutoffs)
+ # model.set_pseudos(pseudos, cutoffs)
# assert orm.load_node(pseudo).filename == "O_new.upf"