diff --git a/src/aiidalab_qe/app/configuration/__init__.py b/src/aiidalab_qe/app/configuration/__init__.py index 456e3f388..70b6084c0 100644 --- a/src/aiidalab_qe/app/configuration/__init__.py +++ b/src/aiidalab_qe/app/configuration/__init__.py @@ -56,10 +56,10 @@ def __init__(self, model: ConfigurationModel, **kwargs): self.settings: dict[str, SettingsPanel] = {} - workchain_model = WorkChainModel(include=True) + workchain_model = WorkChainModel() self._model.add_model("workchain", workchain_model) - advanced_model = AdvancedModel(include=True) + advanced_model = AdvancedModel() self._model.add_model("advanced", advanced_model) self._fetch_plugin_settings() diff --git a/src/aiidalab_qe/app/configuration/advanced/advanced.py b/src/aiidalab_qe/app/configuration/advanced/advanced.py index 39503639d..327a72bf4 100644 --- a/src/aiidalab_qe/app/configuration/advanced/advanced.py +++ b/src/aiidalab_qe/app/configuration/advanced/advanced.py @@ -341,7 +341,7 @@ def _on_input_structure_change(self, _): def _on_protocol_change(self, _): self.refresh(specific="protocol") - def _on_kpoints_distance_change(self, _=None): + def _on_kpoints_distance_change(self, _): self.refresh(specific="mesh") def _on_override_change(self, change): diff --git a/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py b/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py index 51739314e..de5ffd66c 100644 --- a/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py +++ b/src/aiidalab_qe/app/configuration/advanced/hubbard/model.py @@ -49,8 +49,19 @@ def __init__(self, *args, **kwargs): } def update(self, specific=""): + if self.input_structure is None: + self.applicable_kinds = [] + self.orbital_labels = [] + self._defaults |= { + "parameters": {}, + "eigenvalues": [], + } + else: + self.orbital_labels = self._define_orbital_labels() + self._defaults["parameters"] = self._define_default_parameters() + self.applicable_kinds = self._define_applicable_kinds() + self._defaults["eigenvalues"] = self._define_default_eigenvalues() with self.hold_trait_notifications(): - self._update_defaults(specific) self.parameters = self._get_default_parameters() self.eigenvalues = self._get_default_eigenvalues() self.needs_eigenvalues_widget = len(self.applicable_kinds) > 0 @@ -82,20 +93,6 @@ def reset(self): self.parameters = self._get_default_parameters() self.eigenvalues = self._get_default_eigenvalues() - def _update_defaults(self, specific=""): - if self.input_structure is None: - self.applicable_kinds = [] - self.orbital_labels = [] - self._defaults |= { - "parameters": {}, - "eigenvalues": [], - } - else: - self.orbital_labels = self._define_orbital_labels() - self._defaults["parameters"] = self._define_default_parameters() - self.applicable_kinds = self._define_applicable_kinds() - self._defaults["eigenvalues"] = self._define_default_eigenvalues() - def _define_orbital_labels(self): hubbard_manifold_list = [ self._get_manifold(Element(kind.symbol)) diff --git a/src/aiidalab_qe/app/configuration/advanced/magnetization/model.py b/src/aiidalab_qe/app/configuration/advanced/magnetization/model.py index f031fd3f8..1b29fef4e 100644 --- a/src/aiidalab_qe/app/configuration/advanced/magnetization/model.py +++ b/src/aiidalab_qe/app/configuration/advanced/magnetization/model.py @@ -35,8 +35,13 @@ def __init__(self, *args, **kwargs): } def update(self, specific=""): + if self.spin_type == "none" or self.input_structure is None: + self._defaults["moments"] = {} + else: + self._defaults["moments"] = { + symbol: 0.0 for symbol in self.input_structure.get_kind_names() + } with self.hold_trait_notifications(): - self._update_defaults(specific) self.moments = self._get_default_moments() def reset(self): @@ -45,13 +50,5 @@ def reset(self): self.total = self.traits()["total"].default_value self.moments = self._get_default_moments() - def _update_defaults(self, specific=""): - if self.spin_type == "none" or self.input_structure is None: - self._defaults["moments"] = {} - else: - self._defaults["moments"] = { - symbol: 0.0 for symbol in self.input_structure.get_kind_names() - } - def _get_default_moments(self): return deepcopy(self._defaults["moments"]) diff --git a/src/aiidalab_qe/app/configuration/advanced/model.py b/src/aiidalab_qe/app/configuration/advanced/model.py index 25d4bd23f..246a74f2e 100644 --- a/src/aiidalab_qe/app/configuration/advanced/model.py +++ b/src/aiidalab_qe/app/configuration/advanced/model.py @@ -60,8 +60,10 @@ class AdvancedModel(SettingsModel): kpoints_distance = tl.Float(0.0) mesh_grid = tl.Unicode("") - def __init__(self, include=False, *args, **kwargs): - super().__init__(include, *args, **kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.include = True self.dftd3_version = { "dft-d3": 3, @@ -84,19 +86,21 @@ def __init__(self, include=False, *args, **kwargs): def update(self, specific=""): with self.hold_trait_notifications(): - self._update_defaults(specific) - self.forc_conv_thr = self._defaults["forc_conv_thr"] - self.forc_conv_thr_step = self._defaults["forc_conv_thr_step"] - self.etot_conv_thr = self._defaults["etot_conv_thr"] - self.etot_conv_thr_step = self._defaults["etot_conv_thr_step"] - self.scf_conv_thr = self._defaults["scf_conv_thr"] - self.scf_conv_thr_step = self._defaults["scf_conv_thr_step"] - self.kpoints_distance = self._defaults["kpoints_distance"] + if not specific or specific != "mesh": + parameters = PwBaseWorkChain.get_protocol_inputs(self.protocol) + self._update_kpoints_distance(parameters) + self._update_thresholds(parameters) + self._update_kpoints_mesh() def add_model(self, identifier, model): self._models[identifier] = model self._link_model(model) + def get_model(self, identifier) -> AdvancedSubModel: + if identifier in self._models: + return self._models[identifier] + raise ValueError(f"Model with identifier '{identifier}' not found.") + def get_models(self): return self._models.items() @@ -122,7 +126,7 @@ def get_model_state(self): "kpoints_distance": self.kpoints_distance, } - hubbard: HubbardModel = self._get_model("hubbard") # type: ignore + hubbard: HubbardModel = self.get_model("hubbard") # type: ignore if hubbard.is_active: parameters["hubbard_parameters"] = {"hubbard_u": hubbard.parameters} if hubbard.has_eigenvalues: @@ -130,7 +134,7 @@ def get_model_state(self): "starting_ns_eigenvalue": hubbard.get_active_eigenvalues() } - pseudos: PseudosModel = self._get_model("pseudos") # type: ignore + pseudos: PseudosModel = self.get_model("pseudos") # type: ignore parameters["pseudo_family"] = pseudos.family if pseudos.dictionary: parameters["pw"]["pseudos"] = pseudos.dictionary @@ -145,8 +149,8 @@ def get_model_state(self): self.dftd3_version[self.van_der_waals] ) - smearing: SmearingModel = self._get_model("smearing") # type: ignore - magnetization: MagnetizationModel = self._get_model("magnetization") # type: ignore + smearing: SmearingModel = self.get_model("smearing") # type: ignore + magnetization: MagnetizationModel = self.get_model("magnetization") # type: ignore if self.spin_type == "collinear": parameters["initial_magnetic_moments"] = magnetization.moments if self.electronic_type == "metal": @@ -181,7 +185,7 @@ def get_model_state(self): return parameters def set_model_state(self, parameters): - pseudos: PseudosModel = self._get_model("pseudos") # type: ignore + pseudos: PseudosModel = self.get_model("pseudos") # type: ignore if "pseudo_family" in parameters: pseudo_family = PseudoFamily.from_string(parameters["pseudo_family"]) library = pseudo_family.library @@ -199,7 +203,7 @@ def set_model_state(self, parameters): if (pw_parameters := parameters.get("pw", {}).get("parameters")) is not None: self._set_pw_parameters(pw_parameters) - magnetization: MagnetizationModel = self._get_model("magnetization") # type: ignore + magnetization: MagnetizationModel = self.get_model("magnetization") # type: ignore if magnetic_moments := parameters.get("initial_magnetic_moments"): if isinstance(magnetic_moments, (int, float)): magnetic_moments = [magnetic_moments] @@ -212,7 +216,7 @@ def set_model_state(self, parameters): ) magnetization.moments = magnetic_moments - hubbard: HubbardModel = self._get_model("hubbard") # type: ignore + hubbard: HubbardModel = self.get_model("hubbard") # type: ignore if parameters.get("hubbard_parameters"): hubbard.is_active = True hubbard.parameters = parameters["hubbard_parameters"]["hubbard_u"] @@ -256,16 +260,6 @@ def _link_model(self, model): (model, trait), ) - def _update_defaults(self, specific=""): - if not specific or specific != "mesh": - parameters = PwBaseWorkChain.get_protocol_inputs(self.protocol) - self._update_kpoints_distance(parameters) - - self._update_kpoints_mesh() - - if not specific or specific == "protocol": - self._update_thresholds(parameters) - def _update_kpoints_mesh(self, _=None): if self.input_structure is None: mesh_grid = "" @@ -284,18 +278,25 @@ def _update_kpoints_mesh(self, _=None): def _update_kpoints_distance(self, parameters): kpoints_distance = parameters["kpoints_distance"] if self.has_pbc else 100.0 self._defaults["kpoints_distance"] = kpoints_distance + self.kpoints_distance = self._defaults["kpoints_distance"] def _update_thresholds(self, parameters): num_atoms = len(self.input_structure.sites) if self.input_structure else 1 etot_value = num_atoms * parameters["meta_parameters"]["etot_conv_thr_per_atom"] self._set_value_and_step("etot_conv_thr", etot_value) + self.etot_conv_thr = self._defaults["etot_conv_thr"] + self.etot_conv_thr_step = self._defaults["etot_conv_thr_step"] scf_value = num_atoms * parameters["meta_parameters"]["conv_thr_per_atom"] self._set_value_and_step("scf_conv_thr", scf_value) + self.scf_conv_thr = self._defaults["scf_conv_thr"] + self.scf_conv_thr_step = self._defaults["scf_conv_thr_step"] forc_value = parameters["pw"]["parameters"]["CONTROL"]["forc_conv_thr"] self._set_value_and_step("forc_conv_thr", forc_value) + self.forc_conv_thr = self._defaults["forc_conv_thr"] + self.forc_conv_thr_step = self._defaults["forc_conv_thr_step"] def _set_value_and_step(self, attribute, value): self._defaults[attribute] = value @@ -324,18 +325,13 @@ def _set_pw_parameters(self, pw_parameters): system_params.get("vdw_corr", "none"), ) - smearing: SmearingModel = self._get_model("smearing") # type: ignore + smearing: SmearingModel = self.get_model("smearing") # type: ignore if "degauss" in system_params: smearing.degauss = system_params["degauss"] if "smearing" in system_params: smearing.type = system_params["smearing"] - magnetization: MagnetizationModel = self._get_model("magnetization") # type: ignore + magnetization: MagnetizationModel = self.get_model("magnetization") # type: ignore if "tot_magnetization" in system_params: magnetization.type = "tot_magnetization" - - def _get_model(self, identifier) -> AdvancedSubModel: - if identifier in self._models: - return self._models[identifier] - raise ValueError(f"Model with identifier '{identifier}' not found.") diff --git a/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py b/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py index 03c5068dd..7da4d3fb3 100644 --- a/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py +++ b/src/aiidalab_qe/app/configuration/advanced/pseudos/model.py @@ -129,7 +129,16 @@ def __init__(self, *args, **kwargs): def update(self, specific=""): with self.hold_trait_notifications(): - self._update_defaults(specific) + if self.input_structure is None: + self._defaults |= { + "dictionary": {}, + "cutoffs": [[0.0], [0.0]], + } + else: + self.update_default_pseudos() + self.update_default_cutoffs() + self.update_family_parameters() + self.update_family() def update_default_pseudos(self): try: @@ -262,18 +271,6 @@ def reset(self): self.family_help_message = self.PSEUDO_HELP_WO_SOC self.status_message = "" - def _update_defaults(self, specific=""): - if self.input_structure is None: - self._defaults |= { - "dictionary": {}, - "cutoffs": [[0.0], [0.0]], - } - else: - self.update_default_pseudos() - self.update_default_cutoffs() - self.update_family_parameters() - self.update_family() - def _get_pseudo_family_from_database(self): """Get the pseudo family from the database.""" return ( diff --git a/src/aiidalab_qe/app/configuration/advanced/smearing/model.py b/src/aiidalab_qe/app/configuration/advanced/smearing/model.py index 5dd94b0cf..6d626f988 100644 --- a/src/aiidalab_qe/app/configuration/advanced/smearing/model.py +++ b/src/aiidalab_qe/app/configuration/advanced/smearing/model.py @@ -27,17 +27,6 @@ def __init__(self, *args, **kwargs): } def update(self, specific=""): - with self.hold_trait_notifications(): - self._update_defaults(specific) - self.type = self._defaults["type"] - self.degauss = self._defaults["degauss"] - - def reset(self): - with self.hold_trait_notifications(): - self.type = self._defaults["type"] - self.degauss = self._defaults["degauss"] - - def _update_defaults(self, specific=""): parameters = ( PwBaseWorkChain.get_protocol_inputs(self.protocol) .get("pw", {}) @@ -48,3 +37,11 @@ def _update_defaults(self, specific=""): "type": parameters["smearing"], "degauss": parameters["degauss"], } + with self.hold_trait_notifications(): + self.type = self._defaults["type"] + self.degauss = self._defaults["degauss"] + + def reset(self): + with self.hold_trait_notifications(): + self.type = self._defaults["type"] + self.degauss = self._defaults["degauss"] diff --git a/src/aiidalab_qe/app/configuration/advanced/subsettings.py b/src/aiidalab_qe/app/configuration/advanced/subsettings.py index f0e4c69e6..b7105fd54 100644 --- a/src/aiidalab_qe/app/configuration/advanced/subsettings.py +++ b/src/aiidalab_qe/app/configuration/advanced/subsettings.py @@ -1,3 +1,5 @@ +import os + import ipywidgets as ipw import traitlets as tl @@ -21,16 +23,6 @@ def reset(self): """Resets the model to present defaults.""" raise NotImplementedError - def _update_defaults(self, specific=""): - """Updates the model's default values. - - Parameters - ---------- - `specific` : `str`, optional - If provided, specifies the level of update. - """ - raise NotImplementedError - class AdvancedSubSettings(ipw.VBox): identifier = "sub" @@ -74,6 +66,9 @@ def refresh(self, specific=""): self.updated = False self._unsubscribe() self._update(specific) + if "PYTEST_CURRENT_TEST" in os.environ: + # Skip resetting to avoid having to inject a structure when testing + return if hasattr(self._model, "input_structure") and not self._model.input_structure: self._reset() diff --git a/src/aiidalab_qe/app/configuration/basic/model.py b/src/aiidalab_qe/app/configuration/basic/model.py index 3e733c0f8..a40f4e932 100644 --- a/src/aiidalab_qe/app/configuration/basic/model.py +++ b/src/aiidalab_qe/app/configuration/basic/model.py @@ -18,8 +18,10 @@ class WorkChainModel(SettingsModel): spin_type = tl.Unicode(DEFAULT["workchain"]["spin_type"]) electronic_type = tl.Unicode(DEFAULT["workchain"]["electronic_type"]) - def __init__(self, include=False, *args, **kwargs): - super().__init__(include, *args, **kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.include = True self._defaults = { "protocol": self.traits()["protocol"].default_value, diff --git a/src/aiidalab_qe/app/configuration/model.py b/src/aiidalab_qe/app/configuration/model.py index 270ba96b7..83b0862be 100644 --- a/src/aiidalab_qe/app/configuration/model.py +++ b/src/aiidalab_qe/app/configuration/model.py @@ -51,23 +51,53 @@ def __init__(self, *args, **kwargs): """ def update(self): - self._update_defaults() - self.relax_type_help = self._get_default_relax_type_help() - self.relax_type_options = self._get_default_relax_type_options() - self.relax_type = self._get_default_relax_type() + if self.has_pbc: + relax_type_help = self.relax_type_help_template.format( + option_count="three", + full_relaxation_option=( + """ +
+ (3) Full geometry: perform a full relaxation of the internal atomic + coordinates and the cell parameters. + """ + ), + ) + relax_type_options = [ + ("Structure as is", "none"), + ("Atomic positions", "positions"), + ("Full geometry", "positions_cell"), + ] + else: + relax_type_help = self.relax_type_help_template.format( + option_count="two", + full_relaxation_option="", + ) + relax_type_options = [ + ("Structure as is", "none"), + ("Atomic positions", "positions"), + ] + self._defaults = { + "relax_type_help": relax_type_help, + "relax_type_options": relax_type_options, + "relax_type": relax_type_options[-1][-1], + } + with self.hold_trait_notifications(): + self.relax_type_help = self._get_default_relax_type_help() + self.relax_type_options = self._get_default_relax_type_options() + self.relax_type = self._get_default_relax_type() def add_model(self, identifier, model): self._models[identifier] = model self._link_model(model) - def get_models(self): - return self._models.items() - def get_model(self, identifier) -> SettingsModel: if identifier in self._models: return self._models[identifier] raise ValueError(f"Model with identifier '{identifier}' not found.") + def get_models(self): + return self._models.items() + def get_model_state(self): parameters = { identifier: model.get_model_state() @@ -86,9 +116,9 @@ def set_model_state(self, parameters): self.relax_type = workchain_parameters.get("relax_type") properties = set(workchain_parameters.get("properties", [])) for identifier, model in self._models.items(): + model.include = identifier in self._default_models | properties if parameters.get(identifier): model.set_model_state(parameters[identifier]) - model.include = identifier in self._default_models | properties def reset(self): self.confirmed = False @@ -134,38 +164,6 @@ def _get_properties(self): properties.append("relax") return properties - def _update_defaults(self): - if self.has_pbc: - relax_type_help = self.relax_type_help_template.format( - option_count="three", - full_relaxation_option=( - """ -
- (3) Full geometry: perform a full relaxation of the internal atomic - coordinates and the cell parameters. - """ - ), - ) - relax_type_options = [ - ("Structure as is", "none"), - ("Atomic positions", "positions"), - ("Full geometry", "positions_cell"), - ] - else: - relax_type_help = self.relax_type_help_template.format( - option_count="two", - full_relaxation_option="", - ) - relax_type_options = [ - ("Structure as is", "none"), - ("Atomic positions", "positions"), - ] - self._defaults = { - "relax_type_help": relax_type_help, - "relax_type_options": relax_type_options, - "relax_type": relax_type_options[-1][-1], - } - def _get_default_relax_type_help(self): return self._defaults.get("relax_type_help", "") diff --git a/src/aiidalab_qe/common/panel.py b/src/aiidalab_qe/common/panel.py index 3bd7f81b6..878a51db6 100644 --- a/src/aiidalab_qe/common/panel.py +++ b/src/aiidalab_qe/common/panel.py @@ -75,10 +75,9 @@ class SettingsModel(tl.HasTraits): _defaults = {} - def __init__(self, include=False, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.observe(self.unconfirm, tl.All) - self.include = include @property def has_pbc(self): @@ -112,16 +111,6 @@ def unconfirm(self, change): if change["name"] != "confirmed": self.confirmed = False - def _update_defaults(self, specific=""): - """Updates the model's default values. - - Parameters - ---------- - `specific` : `str`, optional - If provided, specifies the level of update. - """ - raise NotImplementedError - class SettingsPanel(Panel): title = "Settings" @@ -161,9 +150,8 @@ def refresh(self, specific=""): """ self.updated = False self._unsubscribe() - if not self._model.include: - return - self.update(specific) + if self._model.include: + self.update(specific) if "PYTEST_CURRENT_TEST" in os.environ: # Skip resetting to avoid having to inject a structure when testing return diff --git a/src/aiidalab_qe/common/widgets.py b/src/aiidalab_qe/common/widgets.py index 42ee7cfe0..a27152cf3 100644 --- a/src/aiidalab_qe/common/widgets.py +++ b/src/aiidalab_qe/common/widgets.py @@ -957,7 +957,6 @@ def __init__(self, title=None, **kwargs): def render(self): if self.rendered: return - from aiidalab_qe.common.widgets import LoadingWidget self.children = [LoadingWidget(f"Loading {self.title} widget")] diff --git a/src/aiidalab_qe/plugins/pdos/model.py b/src/aiidalab_qe/plugins/pdos/model.py index deb7e541f..8c664acae 100644 --- a/src/aiidalab_qe/plugins/pdos/model.py +++ b/src/aiidalab_qe/plugins/pdos/model.py @@ -22,8 +22,8 @@ class PdosModel(SettingsModel): use_pdos_degauss = tl.Bool(False) pdos_degauss = tl.Float(0.005) - def __init__(self, include=False, *args, **kwargs): - super().__init__(include, *args, **kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self._defaults = { "kpoints_distance": self.traits()["kpoints_distance"].default_value, @@ -31,8 +31,11 @@ def __init__(self, include=False, *args, **kwargs): def update(self, specific=""): with self.hold_trait_notifications(): - self._update_defaults(specific) - self.kpoints_distance = self._defaults["kpoints_distance"] + if not specific or specific != "mesh": + parameters = PdosWorkChain.get_protocol_inputs(self.protocol) + self._update_kpoints_distance(parameters) + + self._update_kpoints_mesh() def get_model_state(self): return { @@ -55,13 +58,6 @@ def reset(self): self.use_pdos_degauss = self.traits()["use_pdos_degauss"].default_value self.pdos_degauss = self.traits()["pdos_degauss"].default_value - def _update_defaults(self, specific=""): - if not specific or specific != "mesh": - parameters = PdosWorkChain.get_protocol_inputs(self.protocol) - self._update_kpoints_distance(parameters) - - self._update_kpoints_mesh() - def _update_kpoints_mesh(self, _=None): if self.input_structure is None: mesh_grid = "" @@ -84,3 +80,4 @@ def _update_kpoints_distance(self, parameters): kpoints_distance = 100.0 self.use_pdos_degauss = True self._defaults["kpoints_distance"] = kpoints_distance + self.kpoints_distance = self._defaults["kpoints_distance"] diff --git a/src/aiidalab_qe/plugins/xas/model.py b/src/aiidalab_qe/plugins/xas/model.py index 67c919a5a..a8b70f929 100644 --- a/src/aiidalab_qe/plugins/xas/model.py +++ b/src/aiidalab_qe/plugins/xas/model.py @@ -60,8 +60,9 @@ def __init__(self, *args, **kwargs): self.installed_pseudos = False def update(self, specific=""): - self._update_pseudos() - self._update_core_hole_treatment_recommendations() + with self.hold_trait_notifications(): + self._update_pseudos() + self._update_core_hole_treatment_recommendations() def get_model_state(self): pseudo_labels = {} diff --git a/src/aiidalab_qe/plugins/xps/model.py b/src/aiidalab_qe/plugins/xps/model.py index 4498bc1c7..4911be55a 100644 --- a/src/aiidalab_qe/plugins/xps/model.py +++ b/src/aiidalab_qe/plugins/xps/model.py @@ -34,9 +34,10 @@ class XpsModel(SettingsModel): ) def update(self, specific=""): - self._update_correction_energies() - if not specific or specific == "pseudo_group": - self._update_pseudos() + with self.hold_trait_notifications(): + self._update_correction_energies() + if not specific or specific == "pseudo_group": + self._update_pseudos() def get_supported_core_levels(self): supported_core_levels = {} @@ -67,7 +68,6 @@ def set_model_state(self, parameters: dict): self.traits()["structure_type"].default_value, ) - # TODO check logic core_level_list = parameters.get("core_level_list", []) for orbital in self.core_levels: if orbital in core_level_list: diff --git a/tests/configuration/test_advanced.py b/tests/configuration/test_advanced.py index 4a9a1f43b..79a69342d 100644 --- a/tests/configuration/test_advanced.py +++ b/tests/configuration/test_advanced.py @@ -7,27 +7,31 @@ def test_advanced_default(): """Test default behavior of advanced setting.""" model = AdvancedModel() _ = AdvancedSettings(model=model) + smearing = model.get_model("smearing") # Test override functionality in advanced settings model.override = True model.protocol = "fast" - model.smearing.type = "methfessel-paxton" - model.smearing.degauss = 0.03 + smearing.type = "methfessel-paxton" + smearing.degauss = 0.03 model.kpoints_distance = 0.22 # Reset values to default w.r.t protocol model.override = False - assert model.smearing.type == "cold" - assert model.smearing.degauss == 0.01 + assert smearing.type == "cold" + assert smearing.degauss == 0.01 assert model.kpoints_distance == 0.5 def test_advanced_smearing_settings(): """Test Smearing Settings.""" - from aiidalab_qe.app.configuration.advanced.smearing import SmearingSettings + from aiidalab_qe.app.configuration.advanced.smearing import ( + SmearingModel, + SmearingSettings, + ) - model = AdvancedModel() + model = SmearingModel() smearing = SmearingSettings(model=model) smearing.render() @@ -40,23 +44,22 @@ def test_advanced_smearing_settings(): assert smearing.degauss.disabled is False assert smearing.smearing.disabled is False - assert model.smearing.type == "cold" - assert model.smearing.degauss == 0.01 + assert model.type == "cold" + assert model.degauss == 0.01 # Test protocol-dependent smearing change model.protocol = "fast" - assert model.smearing.type == "cold" - assert model.smearing.degauss == 0.01 + assert model.type == "cold" + assert model.degauss == 0.01 # Check reset - model.smearing.type = "gaussian" - model.smearing.degauss = 0.05 - model.smearing.reset() + model.type = "gaussian" + model.degauss = 0.05 + model.override = False - assert model.protocol == "fast" # reset does not apply to protocol - assert model.smearing.type == "cold" - assert model.smearing.degauss == 0.01 + assert model.type == "cold" + assert model.degauss == 0.01 def test_advanced_kpoints_settings(): @@ -156,9 +159,12 @@ def test_advanced_kpoints_mesh(generate_structure_data): @pytest.mark.usefixtures("aiida_profile_clean", "sssp") def test_advanced_hubbard_settings(generate_structure_data): """Test Hubbard widget.""" - from aiidalab_qe.app.configuration.advanced.hubbard import HubbardSettings + from aiidalab_qe.app.configuration.advanced.hubbard import ( + HubbardModel, + HubbardSettings, + ) - model = AdvancedModel() + model = HubbardModel() hubbard = HubbardSettings(model=model) hubbard.render() @@ -166,8 +172,8 @@ def test_advanced_hubbard_settings(generate_structure_data): model.input_structure = structure # Activate Hubbard U widget - model.hubbard.is_active = True - assert model.hubbard.orbital_labels == ["Co - 3d", "O - 2p", "Li - 2s"] + model.is_active = True + assert model.orbital_labels == ["Co - 3d", "O - 2p", "Li - 2s"] # Change the Hubbard U parameters for Co, O, and Li hubbard_parameters = hubbard.hubbard_widget.children[1:] # type: ignore @@ -175,7 +181,7 @@ def test_advanced_hubbard_settings(generate_structure_data): hubbard_parameters[1].value = 2 # O - 2p hubbard_parameters[2].value = 3 # Li - 2s - assert model.hubbard.parameters == { + assert model.parameters == { "Co - 3d": 1.0, "O - 2p": 2.0, "Li - 2s": 3.0, @@ -191,9 +197,9 @@ def test_advanced_hubbard_settings(generate_structure_data): # assert model.hubbard.eigenvalues == [] # TODO should they be? # Check there is only eigenvalues for Co (Transition metal) - model.hubbard.has_eigenvalues = True - assert len(model.hubbard.applicable_kinds) == 1 - assert len(model.hubbard.eigenvalues) == 1 + model.has_eigenvalues = True + assert len(model.applicable_kinds) == 1 + assert len(model.eigenvalues) == 1 Co_eigenvalues = hubbard.eigenvalues_widget.children[0].children[1] # type: ignore Co_spin_down_row = Co_eigenvalues.children[1] @@ -201,7 +207,7 @@ def test_advanced_hubbard_settings(generate_structure_data): Co_spin_down_row.children[3].value = "1" Co_spin_down_row.children[5].value = "1" - assert model.hubbard.get_active_eigenvalues() == [ + assert model.get_active_eigenvalues() == [ [1, 1, "Co", 1], [3, 1, "Co", 1], [5, 1, "Co", 1], diff --git a/tests/conftest.py b/tests/conftest.py index 4596af2a5..72fc2d854 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -461,15 +461,16 @@ def _submit_app_generator( advanced_model.electron_maxstep = electron_maxstep if isinstance(initial_magnetic_moments, (int, float)): initial_magnetic_moments = [initial_magnetic_moments] - advanced_model.magnetization.moments = dict( + advanced_model.get_model("magnetization").moments = dict( zip( app.configure_model.input_structure.get_kind_names(), initial_magnetic_moments, ) ) # mimic the behavior of the smearing widget set up - advanced_model.smearing.type = smearing - advanced_model.smearing.degauss = degauss + smearing_model = advanced_model.get_model("smearing") + smearing_model.type = smearing + smearing_model.degauss = degauss app.configure_step.confirm() app.submit_model.input_structure = generate_structure_data() @@ -739,7 +740,9 @@ def _generate_qeapp_workchain( from_example.children[0].value = from_example.children[0].options[1][1] else: structure.store() - aiida_database = app.structure_step.manager.children[0].children[2] # type: ignore + aiida_database_wrapper = app.structure_step.manager.children[0].children[2] # type: ignore + aiida_database_wrapper.render() + aiida_database = aiida_database_wrapper.children[0] # type: ignore aiida_database.search() aiida_database.results.value = structure @@ -764,20 +767,20 @@ def _generate_qeapp_workchain( if spin_type == "collinear": advanced_model.override = True - magnetization = advanced_model.magnetization + magnetization_model = advanced_model.get_model("magnetization") if electronic_type == "insulator": - magnetization.total = tot_magnetization + magnetization_model.total = tot_magnetization elif magnetization_type == "starting_magnetization": if isinstance(initial_magnetic_moments, (int, float)): initial_magnetic_moments = [initial_magnetic_moments] - magnetization.moments = dict( + magnetization_model.moments = dict( zip( structure.get_kind_names(), initial_magnetic_moments, ) ) else: - magnetization.total = tot_magnetization + magnetization_model.total = tot_magnetization app.configure_step.confirm() diff --git a/tests/test_app.py b/tests/test_app.py index 996455826..6c1681733 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -19,7 +19,8 @@ def test_reload_and_reset(generate_qeapp_workchain): assert app.configure_model.get_model("workchain").spin_type == "collinear" assert app.configure_model.get_model("bands").include is True assert app.configure_model.get_model("pdos").include is False - assert len(app.configure_model.get_model("advanced").pseudos.dictionary) > 0 + advanced_model = app.configure_model.get_model("advanced") + assert len(advanced_model.get_model("pseudos").dictionary) > 0 assert app.configure_step.state == app.configure_step.State.SUCCESS diff --git a/tests/test_configure.py b/tests/test_configure.py index 509a039cd..716709537 100644 --- a/tests/test_configure.py +++ b/tests/test_configure.py @@ -71,6 +71,6 @@ def test_reminder_info(): assert bands_info.value == "" bands_model = model.get_model("bands") bands_model.include = True - assert bands_info.value == "Customize bands settings in step 2.2 if needed" + assert bands_info.value == "Customize bands settings in Step 2.2 if needed" bands_model.include = False assert bands_info.value == "" diff --git a/tests/test_pseudo.py b/tests/test_pseudo.py index 3c193c6f1..df3e9ffc4 100644 --- a/tests/test_pseudo.py +++ b/tests/test_pseudo.py @@ -142,54 +142,53 @@ def test_download_and_install_pseudo_from_file(tmp_path): @pytest.mark.usefixtures("aiida_profile_clean", "sssp", "pseudodojo") def test_pseudos_settings(generate_structure_data, generate_upf_data): - from aiidalab_qe.app.configuration.advanced import AdvancedModel - from aiidalab_qe.app.configuration.advanced.pseudos import PseudoSettings + from aiidalab_qe.app.configuration.advanced.pseudos import ( + PseudoSettings, + PseudosModel, + ) - model = AdvancedModel() + model = PseudosModel() pseudos = PseudoSettings(model=model) - assert model.pseudos.override is False + assert model.override is False # Test the default family model.override = True model.spin_orbit = "wo_soc" - assert model.pseudos.family == f"SSSP/{SSSP_VERSION}/PBEsol/efficiency" + assert model.family == f"SSSP/{SSSP_VERSION}/PBEsol/efficiency" # Test protocol-dependent family change model.protocol = "precise" - assert model.pseudos.family == f"SSSP/{SSSP_VERSION}/PBEsol/precision" + assert model.family == f"SSSP/{SSSP_VERSION}/PBEsol/precision" # Test functional-dependent family change - model.pseudos.functional = "PBE" - assert model.pseudos.family == f"SSSP/{SSSP_VERSION}/PBE/precision" + model.functional = "PBE" + assert model.family == f"SSSP/{SSSP_VERSION}/PBE/precision" # Test library-dependent family change - model.pseudos.library = "PseudoDojo stringent" - assert ( - model.pseudos.family == f"PseudoDojo/{PSEUDODOJO_VERSION}/PBE/SR/stringent/upf" - ) + model.library = "PseudoDojo stringent" + assert model.family == f"PseudoDojo/{PSEUDODOJO_VERSION}/PBE/SR/stringent/upf" # Test spin-orbit-dependent family change model.spin_orbit = "soc" model.protocol = "moderate" - assert ( - model.pseudos.family == f"PseudoDojo/{PSEUDODOJO_VERSION}/PBE/FR/standard/upf" - ) + assert model.family == f"PseudoDojo/{PSEUDODOJO_VERSION}/PBE/FR/standard/upf" - # Test structure-dependent family change - model.reset() + # Reset the external dependencies of the model + model.spin_orbit = "wo_soc" + # Test structure-dependent family change silicon = generate_structure_data("silicon") model.input_structure = silicon - assert "Si" in model.pseudos.dictionary.keys() - assert model.pseudos.ecutwfc == 30 - assert model.pseudos.ecutrho == 240 + assert "Si" in model.dictionary.keys() + assert model.ecutwfc == 30 + assert model.ecutrho == 240 # Test that changing the structure triggers a reset silica = generate_structure_data("silica") model.input_structure = silica - assert "Si" in model.pseudos.dictionary.keys() - assert "O" in model.pseudos.dictionary.keys() + assert "Si" in model.dictionary.keys() + assert "O" in model.dictionary.keys() # Test pseudo upload pseudos.render() @@ -205,18 +204,18 @@ def test_pseudos_settings(generate_structure_data, generate_upf_data): }, } ) - pseudo = model.pseudos.dictionary["O"] # type: ignore + pseudo = model.dictionary["O"] # type: ignore assert orm.load_node(pseudo).filename == "O_new.upf" # TODO necessary for final test - see comment below - # cutoffs = [model.pseudos.ecutwfc, model.pseudos.ecutrho] + # cutoffs = [model.ecutwfc, model.ecutrho] - model.pseudos.reset() - pseudo = model.pseudos.dictionary["O"] # type: ignore + model.reset() + pseudo = model.dictionary["O"] # type: ignore assert orm.load_node(pseudo).filename != "O_new.upf" # TODO what is this about? - # model.pseudos.set_pseudos(pseudos, cutoffs) + # model.set_pseudos(pseudos, cutoffs) # assert orm.load_node(pseudo).filename == "O_new.upf"