From 80ef9294150e600f2ea134d99628745f077b7537 Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Tue, 2 Jul 2024 14:12:18 +0200 Subject: [PATCH] Make more explicit with some safeguards to be removed --- src/ert/enkf_main.py | 1 + src/ert/run_models/base_run_model.py | 1 + src/ert/storage/local_ensemble.py | 129 +++++++----------- src/ert/storage/local_storage.py | 6 +- .../analysis/test_es_update.py | 3 + .../test_storage_migration.py | 1 + tests/unit_tests/analysis/test_es_update.py | 3 + .../gui/tools/test_manage_experiments_tool.py | 1 + .../scenarios/test_summary_response.py | 5 +- .../unit_tests/storage/test_local_storage.py | 2 + 10 files changed, 68 insertions(+), 84 deletions(-) diff --git a/src/ert/enkf_main.py b/src/ert/enkf_main.py index 5301ae5d157..af5a1c7b5f3 100644 --- a/src/ert/enkf_main.py +++ b/src/ert/enkf_main.py @@ -153,6 +153,7 @@ def sample_prior( ) ensemble.save_parameters(parameter, realization_nr, ds) + ensemble.refresh_statemap() ensemble.unify_parameters() logger.debug(f"sample_prior() time_used {(time.perf_counter() - t):.4f}s") diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index ffdf30cc5d3..583f26fcb81 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -539,6 +539,7 @@ def run_ensemble_evaluator( self._end_queue.get() return [] + run_context.ensemble.refresh_statemap() run_context.ensemble.unify_parameters() run_context.ensemble.unify_responses() diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index c12594abf38..a539c66f5f9 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -237,18 +237,41 @@ def create_realization_dir(realization: int) -> Path: self._realization_dir = create_realization_dir - self._realization_states: Optional[RealizationState] = None - self.try_read_state_map_from_file() - self._response_states_need_update = False - self._parameter_states_need_update = False + self._realization_states = ( + RealizationState.from_file(self._path / "state_map.json") + if os.path.exists(self._path / "state_map.json") + else RealizationState() + ) - def try_read_state_map_from_file(self): - if self._realization_states is None: - self._realization_states = ( - RealizationState.from_file(self._path / "state_map.json") - if os.path.exists(self._path / "state_map.json") - else None - ) + self.__response_states_need_update = False # Tmp + self.__parameter_states_need_update = False # Tmp + self._has_invoked_refresh_statemap = False + + @property + def _response_states_need_update(self) -> bool: + return self.__response_states_need_update + + @_response_states_need_update.setter + def _response_states_need_update(self, val: bool): + if val and self._has_invoked_refresh_statemap: + # Temp, all tests should pass without + # hitting this line + pass # raise AssertionError("Expected this line to never be hit") + + self.__response_states_need_update = val + + @property + def _parameter_states_need_update(self) -> bool: + return self.__parameter_states_need_update + + @_parameter_states_need_update.setter + def _parameter_states_need_update(self, val: bool): + if val and self._has_invoked_refresh_statemap: + # Temp, all tests should pass without + # hitting this line + pass + + self.__parameter_states_need_update = val @classmethod def create( @@ -484,6 +507,7 @@ def _ensure_realization_state_initialized(self) -> None: self._path / "state_map.json" ) else: + raise AssertionError("Expected this line to never be hit") self._response_states_need_update = True self._parameter_states_need_update = True self._realization_states = RealizationState() @@ -491,6 +515,7 @@ def _ensure_realization_state_initialized(self) -> None: def refresh_responses_state_if_needed(self) -> None: self._ensure_realization_state_initialized() if self._response_states_need_update: + raise AssertionError("Expected this line to never be hit") self._response_states_need_update = False self._refresh_all_responses_state_for_all_realizations() assert self._realization_states is not None @@ -500,11 +525,20 @@ def refresh_parameters_state_if_needed(self) -> None: self._ensure_realization_state_initialized() if self._parameter_states_need_update: + raise AssertionError("Expected this line to never be hit") self._parameter_states_need_update = False self._refresh_all_parameters_state_for_all_realizations() assert self._realization_states is not None self._realization_states.to_file(self._path / "state_map.json") + def refresh_statemap(self): + self._refresh_all_responses_state_for_all_realizations() + self._refresh_all_parameters_state_for_all_realizations() + self._parameter_states_need_update = False + self._response_states_need_update = False + self._has_invoked_refresh_statemap = True + self._realization_states.to_file(self._path / "state_map.json") + def _responses_exist_for_realization( self, realization: int, key: Optional[str] = None ) -> bool: @@ -1243,9 +1277,7 @@ def _refresh_all_parameters_state_for_realization(self, realization: int) -> Non for parameter_key in self.experiment.parameter_configuration: self._refresh_parameter_state(parameter_key, realization) - def _refresh_parameter_state( - self, parameter_key: str, realization: int, skip_others: bool = False - ) -> None: + def _refresh_parameter_state(self, parameter_key: str, realization: int) -> None: if self._realization_states is None: if os.path.exists(self._path / "state_map.json"): with open(self._path / "state_map.json", "r") as f: @@ -1256,27 +1288,6 @@ def _refresh_parameter_state( if self._realization_states.has_entry(realization, parameter_key): return - realizations_to_refresh = ( - range(self.ensemble_size) if not skip_others else [realization] - ) - - if self.has_combined_parameter_dataset(parameter_key): - ds = xr.open_dataset(self._path / f"{parameter_key}.nc") - - for _real in realizations_to_refresh: - _reals_with_parameter = set(ds["realizations"].values) - self._realization_states.add( - _real, - { - ( - parameter_key, - parameter_key, - _real in _reals_with_parameter, - ) - }, - ) - return - self._realization_states.add( realization, { @@ -1290,9 +1301,7 @@ def _refresh_parameter_state( }, ) - def _refresh_response_state( - self, response_key: str, realization: int, skip_others: bool = False - ) -> None: + def _refresh_response_state(self, response_key: str, realization: int) -> None: if self._realization_states is None: if os.path.exists(self._path / "state_map.json"): with open(self._path / "state_map.json", "r") as f: @@ -1305,51 +1314,13 @@ def _refresh_response_state( combined_ds_key = self._find_unified_dataset_for_response(response_key) - # ex: combined_ds_key == gen_data, response_key = WOPR_OP1 - # ex2: response_key = summary, combined_ds_key = summary - is_grouped_ds = combined_ds_key == response_key - - realizations_to_refresh = ( - range(self.ensemble_size) if not skip_others else [realization] - ) - - if self.has_combined_response_dataset(response_key): - ds = xr.open_dataset(self._path / f"{combined_ds_key}.nc") - - if is_grouped_ds: - for _real in realizations_to_refresh: - _reals_with_response = set(ds["realization"].values) - self._realization_states.add( - _real, - { - ( - combined_ds_key, - combined_ds_key, - _real in _reals_with_response, - ) - }, - ) - - return - - all_names = set(ds["name"].values) - for _key in all_names: - _ds = ds.sel(name=_key, drop=True) - reals_with_response = set( - _ds.dropna("realization", how="all")["realization"].values - ) - - for _real in realizations_to_refresh: - self._realization_states.add( - _real, {(combined_ds_key, _key, _real in reals_with_response)} - ) - - return - # We assume we will never receive "sub-keys" for grouped datasets if combined_ds_key == "summary" and response_key != combined_ds_key: raise KeyError("Did not expect sub-key for grouped dataset") + # ex: combined_ds_key == gen_data, response_key = WOPR_OP1 + # ex2: response_key = summary, combined_ds_key = summary + is_grouped_ds = combined_ds_key == response_key has_realization_dir = os.path.exists(self._realization_dir(realization)) if not has_realization_dir: diff --git a/src/ert/storage/local_storage.py b/src/ert/storage/local_storage.py index 55ace297c44..91efec43a17 100644 --- a/src/ert/storage/local_storage.py +++ b/src/ert/storage/local_storage.py @@ -134,8 +134,8 @@ def refresh(self) -> None: # into statemap # Can be removed if we know 100% that # no storages will ever have datasets in storage, but have no state map file - for ens in self._ensembles.values(): - ens.try_read_state_map_from_file() + # for ens in self._ensembles.values(): + # ens.try_read_state_map_from_file() def get_experiment(self, uuid: UUID) -> LocalExperiment: """ @@ -282,8 +282,6 @@ def close(self) -> None: if self.can_write: for ens in self._ensembles.values(): - ens.refresh_responses_state_if_needed() - ens.refresh_parameters_state_if_needed() ens.unify_responses() ens.unify_parameters() diff --git a/tests/integration_tests/analysis/test_es_update.py b/tests/integration_tests/analysis/test_es_update.py index eb811d68b0a..3863c91bb1d 100644 --- a/tests/integration_tests/analysis/test_es_update.py +++ b/tests/integration_tests/analysis/test_es_update.py @@ -253,6 +253,7 @@ def test_gen_data_obs_data_mismatch(storage, uniform_parameter): iens, ) + prior.refresh_statemap() prior.unify_responses() prior.unify_parameters() @@ -317,6 +318,7 @@ def test_gen_data_missing(storage, uniform_parameter, obs): iens, ) + prior.refresh_statemap() prior.unify_responses() prior.unify_parameters() @@ -411,6 +413,7 @@ def test_update_subset_parameters(storage, uniform_parameter, obs): iens, ) + prior.refresh_statemap() prior.unify_responses() prior.unify_parameters() diff --git a/tests/integration_tests/test_storage_migration.py b/tests/integration_tests/test_storage_migration.py index 2ebc1e7b1b5..f242834fef4 100644 --- a/tests/integration_tests/test_storage_migration.py +++ b/tests/integration_tests/test_storage_migration.py @@ -391,6 +391,7 @@ def test_that_storage_always_has_state_map_after_migrations( _ds_bpr1.coords["realizations"] = [i] ensemble.save_parameters("BPR", i, _ds_bpr1) + ensemble.refresh_statemap() ensemble.unify_parameters() ensemble.unify_responses() diff --git a/tests/unit_tests/analysis/test_es_update.py b/tests/unit_tests/analysis/test_es_update.py index 3db37490643..d20cb90f868 100644 --- a/tests/unit_tests/analysis/test_es_update.py +++ b/tests/unit_tests/analysis/test_es_update.py @@ -418,6 +418,7 @@ def test_smoother_snapshot_alpha( ), iens, ) + prior_storage.refresh_statemap() posterior_storage = storage.create_ensemble( prior_storage.experiment_id, @@ -583,6 +584,7 @@ def g(X): iens, ) + prior_ensemble.refresh_statemap() prior_ensemble.unify_parameters() prior_ensemble.unify_responses() @@ -726,6 +728,7 @@ def test_temporary_parameter_storage_with_inactive_fields( for iens in range(ensemble_size): prior_ensemble.save_parameters(param_group, iens, fields[iens]) + prior_ensemble.refresh_statemap() prior_ensemble.unify_parameters() realization_list = list(range(ensemble_size)) diff --git a/tests/unit_tests/gui/tools/test_manage_experiments_tool.py b/tests/unit_tests/gui/tools/test_manage_experiments_tool.py index 716cfebd0db..71521e7bfb6 100644 --- a/tests/unit_tests/gui/tools/test_manage_experiments_tool.py +++ b/tests/unit_tests/gui/tools/test_manage_experiments_tool.py @@ -31,6 +31,7 @@ def test_init_prior(qtbot, storage): ensemble_size=config.model_config.num_realizations, name="prior", ) + ensemble.refresh_statemap() notifier.set_current_ensemble(ensemble) assert ( ensemble.get_ensemble_state() diff --git a/tests/unit_tests/scenarios/test_summary_response.py b/tests/unit_tests/scenarios/test_summary_response.py index 8657783dc54..bcce27a1c33 100644 --- a/tests/unit_tests/scenarios/test_summary_response.py +++ b/tests/unit_tests/scenarios/test_summary_response.py @@ -17,12 +17,14 @@ @pytest.fixture def prior_ensemble(storage, ert_config): - return storage.create_experiment( + prior = storage.create_experiment( parameters=ert_config.ensemble_config.parameter_configuration, responses=ert_config.ensemble_config.response_configuration, observations=ert_config.observations.datasets, ).create_ensemble(ensemble_size=3, name="prior") + return prior + @pytest.fixture def ert_config(tmpdir): @@ -78,6 +80,7 @@ def create_responses(config_file, prior_ensemble, response_times): facade.load_from_forward_model( prior_ensemble, [True] * facade.get_ensemble_size(), 0 ) + prior_ensemble.refresh_statemap() prior_ensemble.unify_responses() diff --git a/tests/unit_tests/storage/test_local_storage.py b/tests/unit_tests/storage/test_local_storage.py index 92e3fa1b70a..f8956648d85 100644 --- a/tests/unit_tests/storage/test_local_storage.py +++ b/tests/unit_tests/storage/test_local_storage.py @@ -376,6 +376,7 @@ def test_realization_state_updates_on_re_save_response(tmp_path): ds = xr.Dataset({"values": (["report_step", "index"], [[2, 3, 4, 5, 6]])}) ens.save_response("FOPTZ", ds, 1) + ens.refresh_statemap() response_mask = ens.get_realization_mask_with_responses() assert all(response_mask == [False] * 200) @@ -385,6 +386,7 @@ def test_realization_state_updates_on_re_save_response(tmp_path): assert rstate.has(1, "gen_data") ens.save_response("FOPTZZ", ds, 1) + ens.refresh_statemap() response_mask2 = ens.get_realization_mask_with_responses() assert all(response_mask2 == [False] + [True] + [False] * 198) rstate2 = RealizationState.from_file(ens._path / "state_map.json")