Skip to content

Commit

Permalink
Make more explicit
Browse files Browse the repository at this point in the history
with some safeguards to be removed
  • Loading branch information
Yngve S. Kristiansen committed Jul 2, 2024
1 parent 0b9a2ae commit bb1d26e
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 129 deletions.
1 change: 1 addition & 0 deletions src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def sample_prior(
)
ensemble.save_parameters(parameter, realization_nr, ds)

ensemble.refresh_statemap()

Check failure on line 156 in src/ert/enkf_main.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Call to untyped function "refresh_statemap" in typed context
ensemble.unify_parameters()

logger.debug(f"sample_prior() time_used {(time.perf_counter() - t):.4f}s")
Expand Down
1 change: 1 addition & 0 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def load_from_forward_model(
run_context.run_args,
run_context.mask,
)
ensemble.refresh_statemap()

Check failure on line 147 in src/ert/libres_facade.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Call to untyped function "refresh_statemap" in typed context
_logger.debug(
f"load_from_forward_model() time_used {(time.perf_counter() - t):.4f}s"
)
Expand Down
1 change: 1 addition & 0 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def run_ensemble_evaluator(
self._end_queue.get()
return []

run_context.ensemble.refresh_statemap()
run_context.ensemble.unify_parameters()
run_context.ensemble.unify_responses()

Expand Down
180 changes: 56 additions & 124 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,18 +237,41 @@ def create_realization_dir(realization: int) -> Path:

self._realization_dir = create_realization_dir

self._realization_states: Optional[RealizationState] = None
self.try_read_state_map_from_file()
self._response_states_need_update = False
self._parameter_states_need_update = False
self._realization_states = (
RealizationState.from_file(self._path / "state_map.json")
if os.path.exists(self._path / "state_map.json")
else RealizationState()
)

def try_read_state_map_from_file(self):
if self._realization_states is None:
self._realization_states = (
RealizationState.from_file(self._path / "state_map.json")
if os.path.exists(self._path / "state_map.json")
else None
)
self.__response_states_need_update = False # Tmp
self.__parameter_states_need_update = False # Tmp
self._has_invoked_refresh_statemap = False

@property
def _response_states_need_update(self) -> bool:
    """Return the flag saying whether response states need a refresh."""
    return self.__response_states_need_update

@_response_states_need_update.setter
def _response_states_need_update(self, val: bool) -> None:
    """Store the "response states need update" flag.

    Args:
        val: New value of the flag.
    """
    if val and self._has_invoked_refresh_statemap:
        # Temporary safeguard hook: all tests are expected to pass without
        # ever reaching this branch. It is intentionally inert for now; the
        # planned check is
        # raise AssertionError("Expected this line to never be hit")
        pass

    self.__response_states_need_update = val

@property
def _parameter_states_need_update(self) -> bool:
    """Return the flag saying whether parameter states need a refresh."""
    return self.__parameter_states_need_update

@_parameter_states_need_update.setter
def _parameter_states_need_update(self, val: bool) -> None:
    """Store the "parameter states need update" flag.

    Args:
        val: New value of the flag.
    """
    if val and self._has_invoked_refresh_statemap:
        # Temporary safeguard hook: all tests are expected to pass without
        # ever reaching this branch. It is intentionally inert for now.
        pass

    self.__parameter_states_need_update = val

@classmethod
def create(
Expand Down Expand Up @@ -449,7 +472,6 @@ def _parameters_exist_for_realization(self, realization: int) -> bool:

self.refresh_parameters_state_if_needed()

assert self._realization_states is not None
return all(
self._realization_states.has(realization, parameter)
for parameter in self.experiment.parameter_configuration
Expand Down Expand Up @@ -477,34 +499,30 @@ def _load_combined_parameter_dataset(self, key: str) -> xr.Dataset:

return unified_ds

def _ensure_realization_state_initialized(self) -> None:
if self._realization_states is None:
if os.path.exists(self._path / "state_map.json"):
self._realization_states = RealizationState.from_file(
self._path / "state_map.json"
)
else:
self._response_states_need_update = True
self._parameter_states_need_update = True
self._realization_states = RealizationState()

def refresh_responses_state_if_needed(self) -> None:
self._ensure_realization_state_initialized()
if self._response_states_need_update:
raise AssertionError("Expected this line to never be hit")
self._response_states_need_update = False
self._refresh_all_responses_state_for_all_realizations()
assert self._realization_states is not None

self._realization_states.to_file(self._path / "state_map.json")

def refresh_parameters_state_if_needed(self) -> None:
self._ensure_realization_state_initialized()

if self._parameter_states_need_update:
raise AssertionError("Expected this line to never be hit")
self._parameter_states_need_update = False
self._refresh_all_parameters_state_for_all_realizations()
assert self._realization_states is not None
self._realization_states.to_file(self._path / "state_map.json")

def refresh_statemap(self) -> None:
    """Refresh response and parameter states, then write the state map.

    Runs the refresh of all response states followed by all parameter
    states for all realizations, resets both "needs update" flags, records
    that this method has been invoked, and finally writes the realization
    states to ``state_map.json`` under the ensemble path.
    """
    self._refresh_all_responses_state_for_all_realizations()
    self._refresh_all_parameters_state_for_all_realizations()

    # Everything was just refreshed, so neither kind of state is stale.
    self._parameter_states_need_update = False
    self._response_states_need_update = False

    # Read by the "needs update" setters as part of their temporary
    # safeguard condition.
    self._has_invoked_refresh_statemap = True

    self._realization_states.to_file(self._path / "state_map.json")

def _responses_exist_for_realization(
self, realization: int, key: Optional[str] = None
) -> bool:
Expand Down Expand Up @@ -664,13 +682,11 @@ def unset_failure(
if filename.exists():
filename.unlink()

if self._realization_states is not None:
for response_key in self.experiment.response_configuration:
self._realization_states.clear_entry(realization, response_key)
for response_key in self.experiment.response_configuration:
self._realization_states.clear_entry(realization, response_key)

if self._realization_states is not None:
for parameter_group_key in self.experiment.parameter_configuration:
self._realization_states.clear_entry(realization, parameter_group_key)
for parameter_group_key in self.experiment.parameter_configuration:
self._realization_states.clear_entry(realization, parameter_group_key)

self._refresh_all_responses_state_for_realization(realization)
self._refresh_all_parameters_state_for_realization(realization)
Expand Down Expand Up @@ -1176,12 +1192,8 @@ def save_parameters(

dataset.to_netcdf(path, engine="scipy")

if self._realization_states is not None:
self._realization_states.clear_entry(realization, group)

self._realization_states.clear_entry(realization, group)
self._parameter_states_need_update = True
if self._realization_states is not None:
self._realization_states.clear_entry(realization, group)

@require_write
def save_response(self, group: str, data: xr.Dataset, realization: int) -> None:
Expand Down Expand Up @@ -1217,9 +1229,7 @@ def save_response(self, group: str, data: xr.Dataset, realization: int) -> None:

data.to_netcdf(output_path / f"{group}.nc", engine="scipy")
self._response_states_need_update = True

if self._realization_states is not None:
self._realization_states.clear_entry(realization, group)
self._realization_states.clear_entry(realization, group)

def _refresh_all_parameters_state_for_all_realizations(self) -> None:
for real in range(self.ensemble_size):
Expand All @@ -1232,7 +1242,6 @@ def _refresh_all_responses_state_for_all_realizations(self) -> None:
for real in range(self.ensemble_size):
self._refresh_all_responses_state_for_realization(realization=real)

assert self._realization_states is not None
self._realization_states.to_file(self._path / "state_map.json")

def _refresh_all_responses_state_for_realization(self, realization: int) -> None:
Expand All @@ -1243,40 +1252,10 @@ def _refresh_all_parameters_state_for_realization(self, realization: int) -> Non
for parameter_key in self.experiment.parameter_configuration:
self._refresh_parameter_state(parameter_key, realization)

def _refresh_parameter_state(
self, parameter_key: str, realization: int, skip_others: bool = False
) -> None:
if self._realization_states is None:
if os.path.exists(self._path / "state_map.json"):
with open(self._path / "state_map.json", "r") as f:
self._realization_states = RealizationState.from_json(json.load(f))
else:
self._realization_states = RealizationState()

def _refresh_parameter_state(self, parameter_key: str, realization: int) -> None:
if self._realization_states.has_entry(realization, parameter_key):
return

realizations_to_refresh = (
range(self.ensemble_size) if not skip_others else [realization]
)

if self.has_combined_parameter_dataset(parameter_key):
ds = xr.open_dataset(self._path / f"{parameter_key}.nc")

for _real in realizations_to_refresh:
_reals_with_parameter = set(ds["realizations"].values)
self._realization_states.add(
_real,
{
(
parameter_key,
parameter_key,
_real in _reals_with_parameter,
)
},
)
return

self._realization_states.add(
realization,
{
Expand All @@ -1290,66 +1269,19 @@ def _refresh_parameter_state(
},
)

def _refresh_response_state(
self, response_key: str, realization: int, skip_others: bool = False
) -> None:
if self._realization_states is None:
if os.path.exists(self._path / "state_map.json"):
with open(self._path / "state_map.json", "r") as f:
self._realization_states = RealizationState.from_json(json.load(f))
else:
self._realization_states = RealizationState()

def _refresh_response_state(self, response_key: str, realization: int) -> None:
if self._realization_states.has_entry(realization, response_key):
return

combined_ds_key = self._find_unified_dataset_for_response(response_key)

# ex: combined_ds_key == gen_data, response_key = WOPR_OP1
# ex2: response_key = summary, combined_ds_key = summary
is_grouped_ds = combined_ds_key == response_key

realizations_to_refresh = (
range(self.ensemble_size) if not skip_others else [realization]
)

if self.has_combined_response_dataset(response_key):
ds = xr.open_dataset(self._path / f"{combined_ds_key}.nc")

if is_grouped_ds:
for _real in realizations_to_refresh:
_reals_with_response = set(ds["realization"].values)
self._realization_states.add(
_real,
{
(
combined_ds_key,
combined_ds_key,
_real in _reals_with_response,
)
},
)

return

all_names = set(ds["name"].values)
for _key in all_names:
_ds = ds.sel(name=_key, drop=True)
reals_with_response = set(
_ds.dropna("realization", how="all")["realization"].values
)

for _real in realizations_to_refresh:
self._realization_states.add(
_real, {(combined_ds_key, _key, _real in reals_with_response)}
)

return

# We assume we will never receive "sub-keys" for grouped datasets
if combined_ds_key == "summary" and response_key != combined_ds_key:
raise KeyError("Did not expect sub-key for grouped dataset")

# ex: combined_ds_key == gen_data, response_key = WOPR_OP1
# ex2: response_key = summary, combined_ds_key = summary
is_grouped_ds = combined_ds_key == response_key
has_realization_dir = os.path.exists(self._realization_dir(realization))

if not has_realization_dir:
Expand Down
6 changes: 2 additions & 4 deletions src/ert/storage/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def refresh(self) -> None:
# into statemap
# Can be removed if we know 100% that
# no storages will ever have datasets in storage, but have no state map file
for ens in self._ensembles.values():
ens.try_read_state_map_from_file()
# for ens in self._ensembles.values():
# ens.try_read_state_map_from_file()

def get_experiment(self, uuid: UUID) -> LocalExperiment:
"""
Expand Down Expand Up @@ -282,8 +282,6 @@ def close(self) -> None:

if self.can_write:
for ens in self._ensembles.values():
ens.refresh_responses_state_if_needed()
ens.refresh_parameters_state_if_needed()
ens.unify_responses()
ens.unify_parameters()

Expand Down
3 changes: 3 additions & 0 deletions tests/integration_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def test_gen_data_obs_data_mismatch(storage, uniform_parameter):
iens,
)

prior.refresh_statemap()
prior.unify_responses()
prior.unify_parameters()

Expand Down Expand Up @@ -317,6 +318,7 @@ def test_gen_data_missing(storage, uniform_parameter, obs):
iens,
)

prior.refresh_statemap()
prior.unify_responses()
prior.unify_parameters()

Expand Down Expand Up @@ -411,6 +413,7 @@ def test_update_subset_parameters(storage, uniform_parameter, obs):
iens,
)

prior.refresh_statemap()
prior.unify_responses()
prior.unify_parameters()

Expand Down
1 change: 1 addition & 0 deletions tests/integration_tests/test_storage_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def test_that_storage_always_has_state_map_after_migrations(
_ds_bpr1.coords["realizations"] = [i]
ensemble.save_parameters("BPR", i, _ds_bpr1)

ensemble.refresh_statemap()
ensemble.unify_parameters()
ensemble.unify_responses()

Expand Down
3 changes: 3 additions & 0 deletions tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def test_smoother_snapshot_alpha(
),
iens,
)
prior_storage.refresh_statemap()

posterior_storage = storage.create_ensemble(
prior_storage.experiment_id,
Expand Down Expand Up @@ -583,6 +584,7 @@ def g(X):
iens,
)

prior_ensemble.refresh_statemap()
prior_ensemble.unify_parameters()
prior_ensemble.unify_responses()

Expand Down Expand Up @@ -726,6 +728,7 @@ def test_temporary_parameter_storage_with_inactive_fields(
for iens in range(ensemble_size):
prior_ensemble.save_parameters(param_group, iens, fields[iens])

prior_ensemble.refresh_statemap()
prior_ensemble.unify_parameters()

realization_list = list(range(ensemble_size))
Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/gui/tools/test_manage_experiments_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_init_prior(qtbot, storage):
ensemble_size=config.model_config.num_realizations,
name="prior",
)
ensemble.refresh_statemap()
notifier.set_current_ensemble(ensemble)
assert (
ensemble.get_ensemble_state()
Expand Down
Loading

0 comments on commit bb1d26e

Please sign in to comment.