Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only initialize runner once in NeymanConstructor #103

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion alea/examples/configs/unbinned_wimp_statistical_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ parameter_definition:
nominal_value: 50
ptype: shape
fittable: false
description: WIMP mass in GeV/c^2
blueice_anchors:
- 50
description: WIMP mass in GeV/c^2

livetime_sr0:
nominal_value: 0.2
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
parameter_definition:
wimp_mass:
nominal_value: 50
ptype: shape
fittable: false
blueice_anchors:
- 50
description: WIMP mass in GeV/c^2
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This modification is to be compatible with the requirement of ptype.


livetime:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
parameter_definition:
wimp_mass:
nominal_value: 50
ptype: shape
fittable: false
blueice_anchors:
- 50
description: WIMP mass in GeV/c^2

livetime_sr2:
nominal_value: 0.5
ptype: livetime
fittable: false
description: Livetime of SR2 in years

livetime_sr3:
nominal_value: 1.0
ptype: livetime
fittable: false
description: Livetime of SR3 in years

Expand Down
21 changes: 15 additions & 6 deletions alea/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,16 @@ def __init__(
self._confidence_interval_kind = confidence_interval_kind
self.confidence_interval_threshold = confidence_interval_threshold
self.asymptotic_dof = asymptotic_dof
nominal_values = kwargs.get("nominal_values", {})
self._define_parameters(parameter_definition, nominal_values)
self._define_parameters(parameter_definition)

self._check_ll_and_generate_data_signature()
self.set_nominal_values(overwrite_static=True, **kwargs.get("nominal_values", {}))
Copy link
Collaborator Author

@dachengx dachengx Nov 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hammannr Can we keep set_nominal_values in the future so that _define_parameters only handle the definition of Parameters? And for the need_reinit parameters, I think they can be defined correctly just by set_nominal_values.


def _define_parameters(self, parameter_definition, nominal_values=None):
def _define_parameters(self, parameter_definition):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think _define_parameters should be as simple as just initializing Parameters.

"""Initialize the parameters of the model."""
if parameter_definition is None:
self.parameters = Parameters()
elif isinstance(parameter_definition, dict):
for name, definition in parameter_definition.items():
if name in nominal_values:
definition["nominal_value"] = nominal_values[name]
self.parameters = Parameters.from_config(parameter_definition)
elif isinstance(parameter_definition, list):
self.parameters = Parameters.from_list(parameter_definition)
Expand Down Expand Up @@ -239,6 +236,18 @@ def store_data(
raise ValueError("The number of data sets and data names must be the same")
toydata_to_file(file_name, _data_list, data_name_list, **kw)

def set_nominal_values(self, overwrite_static=False, **nominal_values):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recover set_nominal_values, with an option to overwrite static parameters. And in principle I think we do need a method to overwrite static parameters(bypass the check).

"""Set the nominal values for parameters.

Keyword Args:
nominal_values (dict): A dict of parameter names and values.

"""
self.parameters.set_nominal_values(
overwrite_static=overwrite_static,
**nominal_values,
)

def set_fit_guesses(self, **fit_guesses):
"""Set the fit guesses for parameters.

Expand Down
22 changes: 20 additions & 2 deletions alea/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Parameter:
nominal_value (float, optional (default=None)): The nominal value of the parameter.
fittable (bool, optional (default=True)):
Indicates if the parameter is fittable or always fixed.
ptype (str, optional (default=None)): The ptype of the parameter.
ptype (str, optional (default=shape)): The ptype of the parameter.
uncertainty (float or str, optional (default=None)): The uncertainty of the parameter.
If a string, it can be evaluated as a numpy or
scipy function to define non-gaussian constraints.
Expand All @@ -40,7 +40,7 @@ def __init__(
name: str,
nominal_value: Optional[float] = None,
fittable: bool = True,
ptype: Optional[str] = None,
ptype: str = "shape",
uncertainty: Optional[Union[float, str]] = None,
relative_uncertainty: Optional[bool] = None,
blueice_anchors: Optional[List] = None,
Expand All @@ -54,6 +54,11 @@ def __init__(
self.name = name
self._nominal_value = nominal_value
self.fittable = fittable
if ptype not in ["rate", "shape", "efficiency", "livetime"]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the requirement of ptype is more strict.

raise ValueError(
f"{name}'s ptype {ptype} is not valid.",
"it should be one of 'rate', 'shape', 'livetime'.",
)
self.ptype = ptype
self.relative_uncertainty = relative_uncertainty
self.uncertainty = uncertainty
Expand Down Expand Up @@ -315,6 +320,19 @@ def uncertainties(self) -> dict:
"""
return {k: i.uncertainty for k, i in self.parameters.items() if i.uncertainty is not None}

def set_nominal_values(self, overwrite_static=False, **nominal_values):
"""Set the nominal values for parameters.

Keyword Args:
nominal_values (dict): A dict of parameter names and values.

"""
for name, value in nominal_values.items():
if overwrite_static:
self.parameters[name]._nominal_value = value
else:
self.parameters[name].nominal_value = value

@property
def with_uncertainty(self) -> "Parameters":
"""Return parameters with a not-NaN uncertainty.
Expand Down
39 changes: 29 additions & 10 deletions alea/submitters/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ def submit(
if os.path.splitext(limit_threshold)[-1] != ".json":
raise ValueError("The limit_threshold file should be a json file.")

# initialize the runner
script = next(self.computation_tickets_generator())[0]
runner = self.initialized_runner(script, pop_limit_threshold=True)

# calculate the threshold, iterate over the output files
threshold = cast(Dict[str, Any], {})
for runner_args in self.merged_arguments_generator():
Expand All @@ -162,6 +166,7 @@ def submit(
**nominal_values,
**generate_values,
}
runner.model.set_nominal_values(overwrite_static=True, **nominal_values)

# read the likelihood ratio
output_filename = runner_args["output_filename"]
Expand All @@ -187,15 +192,22 @@ def submit(
)

# update poi according to poi_expectation
runner_args["statistical_model_args"].pop("limit_threshold", None)
runner = Runner(**runner_args)
expectation_values = runner.model.get_expectation_values(
**{**nominal_values, **generate_values}
)
# in some rare cases the poi is not a rate multiplier
# then the poi_expectation is not in the nominal_expectation_values
component = self.poi.replace("_rate_multiplier", "")
poi_expectation = expectation_values.get(component, None)
if runner.input_poi_expectation:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just "if poi_expectation" in generate_values?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should also be fine. What is the benefit?

poi_expectation = generate_values.get("poi_expectation")
# nominal_values are passed to update_poi to update the poi
# like wimp_mass, livetime, etc.
generate_values = runner.update_poi(
runner.model, self.poi, generate_values, nominal_values
Comment on lines +199 to +200
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this won't work with static parameters. That was the entire point of the previous PR

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran it quickly and yes, this crashes if you have static parameters (as it should).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The recent commit 7b4f38f makes the PR compatible with static parameters.

)
else:
expectation_values = runner.model.get_expectation_values(
**{**nominal_values, **generate_values}
)
# in some rare cases the poi is not a rate multiplier
# then the poi_expectation is not in the expectation_values
# in these cases we only assign None to poi_expectation
component = self.poi.replace("_rate_multiplier", "")
poi_expectation = expectation_values.get(component, None)
poi_value = generate_values.pop(self.poi)

# make sure no poi and poi_expectation in the hashed_keys
Expand Down Expand Up @@ -226,7 +238,7 @@ def submit(
"threshold": [],
"poi_expectation": [],
}
threshold[threshold_key] = threshold_value
threshold[threshold_key] = deepcopy(threshold_value)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very important because the shallow copy will convey the modification of nominal_values and generate_values to threshold_value and threshold.

threshold[threshold_key][self.poi].append(poi_value)
threshold[threshold_key]["threshold"].append(q_llr)
threshold[threshold_key]["poi_expectation"].append(poi_expectation)
Expand All @@ -240,6 +252,13 @@ def submit(
threshold[k]["threshold"] = [x[1] for x in sorted_pairs]
threshold[k]["poi_expectation"] = [x[2] for x in sorted_pairs]

for k, v in threshold.items():
if k != deterministic_hash(v["hashed_keys"]):
raise ValueError(
"Something wrong with the threshold, "
"inconsistency between hash and hashed keys found."
)

# save the threshold into a json file
with open(limit_threshold, mode="w") as f:
json.dump(threshold, f, indent=4)
Expand Down
2 changes: 1 addition & 1 deletion alea/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def convert_variations(variations: dict, iteration) -> list:
if not isinstance(v, list):
raise ValueError(f"variations {k} must be a list, not {v} with {type(v)}")
variations[k] = expand_grid_dict(v)
result = [dict(zip(variations, t)) for t in iteration(*variations.values())]
result = [dict(zip(variations, deepcopy(t))) for t in iteration(*variations.values())]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very important because the shallow copy will convey the modification of items in t in one entry of result, to another entry of result.

if result:
return result
else:
Expand Down
16 changes: 16 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
clip_limits,
can_expand_grid,
expand_grid_dict,
convert_to_vary,
deterministic_hash,
)

Expand Down Expand Up @@ -74,6 +75,21 @@ def test_expand_grid_dict(self):
],
)

def test_convert_to_vary(self):
"""Test of the convert_to_zip function."""
varied = convert_to_vary({"a": [1, 2], "b": [{"c": 3}, {"c": 4}]})
self.assertNotEqual(id(varied[0]["b"]), id(varied[2]["b"]))
self.assertNotEqual(id(varied[1]["b"]), id(varied[3]["b"]))
self.assertEqual(
varied,
[
{"a": 1, "b": {"c": 3}},
{"a": 1, "b": {"c": 4}},
{"a": 2, "b": {"c": 3}},
{"a": 2, "b": {"c": 4}},
],
)

def test_deterministic_hash(self):
"""Test of the deterministic_hash function."""
self.assertEqual(deterministic_hash([0, 1]), "si3ifpvg2u")
Expand Down