Skip to content

Commit

Permalink
Use new storage in test_config_branch_entry.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Jan 22, 2025
1 parent a9f574a commit 17382e8
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions tests/everest/entry_points/test_config_branch_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
from os.path import exists
from pathlib import Path

from seba_sqlite.snapshot import SebaSnapshot

from everest.bin.config_branch_script import config_branch_entry
from everest.config_file_loader import load_yaml
from everest.everest_storage import EverestStorage


def test_config_branch_entry(cached_example):
Expand All @@ -27,19 +26,21 @@ def test_config_branch_entry(cached_example):
assert len(new_controls) == len(old_controls)
assert len(new_controls[0]["variables"]) == len(old_controls[0]["variables"])

opt_controls = {}

snapshot = SebaSnapshot(Path(path) / "everest_output" / "optimization_output")
for opt_data in snapshot._optimization_data():
if opt_data.batch_id == 1:
opt_controls = opt_data.controls
storage = EverestStorage(Path(path) / "everest_output" / "optimization_output")
storage.read_from_output_dir()

new_controls_initial_guesses = {
var["initial_guess"] for var in new_controls[0]["variables"]
}
opt_control_val_for_batch_id = {v for k, v in opt_controls.items()}

assert new_controls_initial_guesses == opt_control_val_for_batch_id
control_names = storage.data.controls["control_name"]
batch_1_info = next(b for b in storage.data.batches if b.batch_id == 1)
realization_control_vals = batch_1_info.realization_controls.select(
*control_names
).to_dicts()[0]
control_values = set(realization_control_vals.values())

assert new_controls_initial_guesses == control_values


def test_config_branch_preserves_config_section_order(cached_example):
Expand All @@ -48,14 +49,6 @@ def test_config_branch_preserves_config_section_order(cached_example):
config_branch_entry(["config_minimal.yml", "new_restart_config.yml", "-b", "1"])

assert exists("new_restart_config.yml")
opt_controls = {}

snapshot = SebaSnapshot(Path(path) / "everest_output" / "optimization_output")
for opt_data in snapshot._optimization_data():
if opt_data.batch_id == 1:
opt_controls = opt_data.controls

opt_control_val_for_batch_id = {v for k, v in opt_controls.items()}

diff_lines = []
with (
Expand All @@ -77,5 +70,16 @@ def test_config_branch_preserves_config_section_order(cached_example):
diff_lines.append(line.replace(" ", "").strip())

assert len(diff_lines) == 4
for control_val in opt_control_val_for_batch_id:
assert "-initial_guess:0.1" in diff_lines

storage = EverestStorage(Path(path) / "everest_output" / "optimization_output")
storage.read_from_output_dir()
control_names = storage.data.controls["control_name"]
batch_1_info = next(b for b in storage.data.batches if b.batch_id == 1)
realization_control_vals = batch_1_info.realization_controls.select(
*control_names
).to_dicts()[0]
control_values = set(realization_control_vals.values())

for control_val in control_values:
assert f"+initial_guess:{control_val}" in diff_lines

0 comments on commit 17382e8

Please sign in to comment.