Skip to content

Commit

Permalink
Make more explicit
Browse files Browse the repository at this point in the history
with some safeguards to be removed
  • Loading branch information
Yngve S. Kristiansen committed Jul 2, 2024
1 parent 0b9a2ae commit 80ef929
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 84 deletions.
1 change: 1 addition & 0 deletions src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def sample_prior(
)
ensemble.save_parameters(parameter, realization_nr, ds)

ensemble.refresh_statemap()
ensemble.unify_parameters()

logger.debug(f"sample_prior() time_used {(time.perf_counter() - t):.4f}s")
Expand Down
1 change: 1 addition & 0 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def run_ensemble_evaluator(
self._end_queue.get()
return []

run_context.ensemble.refresh_statemap()
run_context.ensemble.unify_parameters()
run_context.ensemble.unify_responses()

Expand Down
129 changes: 50 additions & 79 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,18 +237,41 @@ def create_realization_dir(realization: int) -> Path:

self._realization_dir = create_realization_dir

self._realization_states: Optional[RealizationState] = None
self.try_read_state_map_from_file()
self._response_states_need_update = False
self._parameter_states_need_update = False
self._realization_states = (
RealizationState.from_file(self._path / "state_map.json")
if os.path.exists(self._path / "state_map.json")
else RealizationState()
)

def try_read_state_map_from_file(self):
if self._realization_states is None:
self._realization_states = (
RealizationState.from_file(self._path / "state_map.json")
if os.path.exists(self._path / "state_map.json")
else None
)
self.__response_states_need_update = False # Tmp
self.__parameter_states_need_update = False # Tmp
self._has_invoked_refresh_statemap = False

@property
def _response_states_need_update(self) -> bool:
return self.__response_states_need_update

@_response_states_need_update.setter
def _response_states_need_update(self, val: bool):
if val and self._has_invoked_refresh_statemap:
# Temp, all tests should pass without
# hitting this line
pass # raise AssertionError("Expected this line to never be hit")

self.__response_states_need_update = val

@property
def _parameter_states_need_update(self) -> bool:
return self.__parameter_states_need_update

@_parameter_states_need_update.setter
def _parameter_states_need_update(self, val: bool):
if val and self._has_invoked_refresh_statemap:
# Temp, all tests should pass without
# hitting this line
pass

self.__parameter_states_need_update = val

@classmethod
def create(
Expand Down Expand Up @@ -484,13 +507,15 @@ def _ensure_realization_state_initialized(self) -> None:
self._path / "state_map.json"
)
else:
raise AssertionError("Expected this line to never be hit")
self._response_states_need_update = True
self._parameter_states_need_update = True
self._realization_states = RealizationState()

def refresh_responses_state_if_needed(self) -> None:
self._ensure_realization_state_initialized()
if self._response_states_need_update:
raise AssertionError("Expected this line to never be hit")
self._response_states_need_update = False
self._refresh_all_responses_state_for_all_realizations()
assert self._realization_states is not None
Expand All @@ -500,11 +525,20 @@ def refresh_parameters_state_if_needed(self) -> None:
self._ensure_realization_state_initialized()

if self._parameter_states_need_update:
raise AssertionError("Expected this line to never be hit")
self._parameter_states_need_update = False
self._refresh_all_parameters_state_for_all_realizations()
assert self._realization_states is not None
self._realization_states.to_file(self._path / "state_map.json")

def refresh_statemap(self):
self._refresh_all_responses_state_for_all_realizations()
self._refresh_all_parameters_state_for_all_realizations()
self._parameter_states_need_update = False
self._response_states_need_update = False
self._has_invoked_refresh_statemap = True
self._realization_states.to_file(self._path / "state_map.json")

def _responses_exist_for_realization(
self, realization: int, key: Optional[str] = None
) -> bool:
Expand Down Expand Up @@ -1243,9 +1277,7 @@ def _refresh_all_parameters_state_for_realization(self, realization: int) -> Non
for parameter_key in self.experiment.parameter_configuration:
self._refresh_parameter_state(parameter_key, realization)

def _refresh_parameter_state(
self, parameter_key: str, realization: int, skip_others: bool = False
) -> None:
def _refresh_parameter_state(self, parameter_key: str, realization: int) -> None:
if self._realization_states is None:
if os.path.exists(self._path / "state_map.json"):
with open(self._path / "state_map.json", "r") as f:
Expand All @@ -1256,27 +1288,6 @@ def _refresh_parameter_state(
if self._realization_states.has_entry(realization, parameter_key):
return

realizations_to_refresh = (
range(self.ensemble_size) if not skip_others else [realization]
)

if self.has_combined_parameter_dataset(parameter_key):
ds = xr.open_dataset(self._path / f"{parameter_key}.nc")

for _real in realizations_to_refresh:
_reals_with_parameter = set(ds["realizations"].values)
self._realization_states.add(
_real,
{
(
parameter_key,
parameter_key,
_real in _reals_with_parameter,
)
},
)
return

self._realization_states.add(
realization,
{
Expand All @@ -1290,9 +1301,7 @@ def _refresh_parameter_state(
},
)

def _refresh_response_state(
self, response_key: str, realization: int, skip_others: bool = False
) -> None:
def _refresh_response_state(self, response_key: str, realization: int) -> None:
if self._realization_states is None:
if os.path.exists(self._path / "state_map.json"):
with open(self._path / "state_map.json", "r") as f:
Expand All @@ -1305,51 +1314,13 @@ def _refresh_response_state(

combined_ds_key = self._find_unified_dataset_for_response(response_key)

# ex: combined_ds_key == gen_data, response_key = WOPR_OP1
# ex2: response_key = summary, combined_ds_key = summary
is_grouped_ds = combined_ds_key == response_key

realizations_to_refresh = (
range(self.ensemble_size) if not skip_others else [realization]
)

if self.has_combined_response_dataset(response_key):
ds = xr.open_dataset(self._path / f"{combined_ds_key}.nc")

if is_grouped_ds:
for _real in realizations_to_refresh:
_reals_with_response = set(ds["realization"].values)
self._realization_states.add(
_real,
{
(
combined_ds_key,
combined_ds_key,
_real in _reals_with_response,
)
},
)

return

all_names = set(ds["name"].values)
for _key in all_names:
_ds = ds.sel(name=_key, drop=True)
reals_with_response = set(
_ds.dropna("realization", how="all")["realization"].values
)

for _real in realizations_to_refresh:
self._realization_states.add(
_real, {(combined_ds_key, _key, _real in reals_with_response)}
)

return

# We assume we will never receive "sub-keys" for grouped datasets
if combined_ds_key == "summary" and response_key != combined_ds_key:
raise KeyError("Did not expect sub-key for grouped dataset")

# ex: combined_ds_key == gen_data, response_key = WOPR_OP1
# ex2: response_key = summary, combined_ds_key = summary
is_grouped_ds = combined_ds_key == response_key
has_realization_dir = os.path.exists(self._realization_dir(realization))

if not has_realization_dir:
Expand Down
6 changes: 2 additions & 4 deletions src/ert/storage/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def refresh(self) -> None:
# into statemap
# Can be removed if we know 100% that
# no storages will ever have datasets in storage, but have no state map file
for ens in self._ensembles.values():
ens.try_read_state_map_from_file()
# for ens in self._ensembles.values():
# ens.try_read_state_map_from_file()

def get_experiment(self, uuid: UUID) -> LocalExperiment:
"""
Expand Down Expand Up @@ -282,8 +282,6 @@ def close(self) -> None:

if self.can_write:
for ens in self._ensembles.values():
ens.refresh_responses_state_if_needed()
ens.refresh_parameters_state_if_needed()
ens.unify_responses()
ens.unify_parameters()

Expand Down
3 changes: 3 additions & 0 deletions tests/integration_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def test_gen_data_obs_data_mismatch(storage, uniform_parameter):
iens,
)

prior.refresh_statemap()
prior.unify_responses()
prior.unify_parameters()

Expand Down Expand Up @@ -317,6 +318,7 @@ def test_gen_data_missing(storage, uniform_parameter, obs):
iens,
)

prior.refresh_statemap()
prior.unify_responses()
prior.unify_parameters()

Expand Down Expand Up @@ -411,6 +413,7 @@ def test_update_subset_parameters(storage, uniform_parameter, obs):
iens,
)

prior.refresh_statemap()
prior.unify_responses()
prior.unify_parameters()

Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/test_storage_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def test_that_storage_always_has_state_map_after_migrations(
_ds_bpr1.coords["realizations"] = [i]
ensemble.save_parameters("BPR", i, _ds_bpr1)

ensemble.refresh_statemap()
ensemble.unify_parameters()
ensemble.unify_responses()

Expand Down
3 changes: 3 additions & 0 deletions tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def test_smoother_snapshot_alpha(
),
iens,
)
prior_storage.refresh_statemap()

posterior_storage = storage.create_ensemble(
prior_storage.experiment_id,
Expand Down Expand Up @@ -583,6 +584,7 @@ def g(X):
iens,
)

prior_ensemble.refresh_statemap()
prior_ensemble.unify_parameters()
prior_ensemble.unify_responses()

Expand Down Expand Up @@ -726,6 +728,7 @@ def test_temporary_parameter_storage_with_inactive_fields(
for iens in range(ensemble_size):
prior_ensemble.save_parameters(param_group, iens, fields[iens])

prior_ensemble.refresh_statemap()
prior_ensemble.unify_parameters()

realization_list = list(range(ensemble_size))
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/gui/tools/test_manage_experiments_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_init_prior(qtbot, storage):
ensemble_size=config.model_config.num_realizations,
name="prior",
)
ensemble.refresh_statemap()
notifier.set_current_ensemble(ensemble)
assert (
ensemble.get_ensemble_state()
Expand Down
5 changes: 4 additions & 1 deletion tests/unit_tests/scenarios/test_summary_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

@pytest.fixture
def prior_ensemble(storage, ert_config):
return storage.create_experiment(
prior = storage.create_experiment(
parameters=ert_config.ensemble_config.parameter_configuration,
responses=ert_config.ensemble_config.response_configuration,
observations=ert_config.observations.datasets,
).create_ensemble(ensemble_size=3, name="prior")

return prior


@pytest.fixture
def ert_config(tmpdir):
Expand Down Expand Up @@ -78,6 +80,7 @@ def create_responses(config_file, prior_ensemble, response_times):
facade.load_from_forward_model(
prior_ensemble, [True] * facade.get_ensemble_size(), 0
)
prior_ensemble.refresh_statemap()
prior_ensemble.unify_responses()


Expand Down
2 changes: 2 additions & 0 deletions tests/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def test_realization_state_updates_on_re_save_response(tmp_path):
ds = xr.Dataset({"values": (["report_step", "index"], [[2, 3, 4, 5, 6]])})

ens.save_response("FOPTZ", ds, 1)
ens.refresh_statemap()
response_mask = ens.get_realization_mask_with_responses()
assert all(response_mask == [False] * 200)

Expand All @@ -385,6 +386,7 @@ def test_realization_state_updates_on_re_save_response(tmp_path):
assert rstate.has(1, "gen_data")

ens.save_response("FOPTZZ", ds, 1)
ens.refresh_statemap()
response_mask2 = ens.get_realization_mask_with_responses()
assert all(response_mask2 == [False] + [True] + [False] * 198)
rstate2 = RealizationState.from_file(ens._path / "state_map.json")
Expand Down

0 comments on commit 80ef929

Please sign in to comment.