-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Only initialize runner once in NeymanConstructor #103
Changes from 3 commits
568004e
4ba6d88
7b4f38f
cfc7b3c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,19 +95,16 @@ def __init__( | |
self._confidence_interval_kind = confidence_interval_kind | ||
self.confidence_interval_threshold = confidence_interval_threshold | ||
self.asymptotic_dof = asymptotic_dof | ||
nominal_values = kwargs.get("nominal_values", {}) | ||
self._define_parameters(parameter_definition, nominal_values) | ||
self._define_parameters(parameter_definition) | ||
|
||
self._check_ll_and_generate_data_signature() | ||
self.set_nominal_values(overwrite_static=True, **kwargs.get("nominal_values", {})) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @hammannr Can we keep |
||
|
||
def _define_parameters(self, parameter_definition, nominal_values=None): | ||
def _define_parameters(self, parameter_definition): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||
"""Initialize the parameters of the model.""" | ||
if parameter_definition is None: | ||
self.parameters = Parameters() | ||
elif isinstance(parameter_definition, dict): | ||
for name, definition in parameter_definition.items(): | ||
if name in nominal_values: | ||
definition["nominal_value"] = nominal_values[name] | ||
self.parameters = Parameters.from_config(parameter_definition) | ||
elif isinstance(parameter_definition, list): | ||
self.parameters = Parameters.from_list(parameter_definition) | ||
|
@@ -239,6 +236,18 @@ def store_data( | |
raise ValueError("The number of data sets and data names must be the same") | ||
toydata_to_file(file_name, _data_list, data_name_list, **kw) | ||
|
||
def set_nominal_values(self, overwrite_static=False, **nominal_values): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Recover |
||
"""Set the nominal values for parameters. | ||
|
||
Keyword Args: | ||
nominal_values (dict): A dict of parameter names and values. | ||
|
||
""" | ||
self.parameters.set_nominal_values( | ||
overwrite_static=overwrite_static, | ||
**nominal_values, | ||
) | ||
|
||
def set_fit_guesses(self, **fit_guesses): | ||
"""Set the fit guesses for parameters. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,7 @@ class Parameter: | |
nominal_value (float, optional (default=None)): The nominal value of the parameter. | ||
fittable (bool, optional (default=True)): | ||
Indicates if the parameter is fittable or always fixed. | ||
ptype (str, optional (default=None)): The ptype of the parameter. | ||
ptype (str, optional (default=shape)): The ptype of the parameter. | ||
uncertainty (float or str, optional (default=None)): The uncertainty of the parameter. | ||
If a string, it can be evaluated as a numpy or | ||
scipy function to define non-gaussian constraints. | ||
|
@@ -40,7 +40,7 @@ def __init__( | |
name: str, | ||
nominal_value: Optional[float] = None, | ||
fittable: bool = True, | ||
ptype: Optional[str] = None, | ||
ptype: str = "shape", | ||
uncertainty: Optional[Union[float, str]] = None, | ||
relative_uncertainty: Optional[bool] = None, | ||
blueice_anchors: Optional[List] = None, | ||
|
@@ -54,6 +54,11 @@ def __init__( | |
self.name = name | ||
self._nominal_value = nominal_value | ||
self.fittable = fittable | ||
if ptype not in ["rate", "shape", "efficiency", "livetime"]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here the requirement of |
||
raise ValueError( | ||
f"{name}'s ptype {ptype} is not valid.", | ||
"it should be one of 'rate', 'shape', 'livetime'.", | ||
) | ||
self.ptype = ptype | ||
self.relative_uncertainty = relative_uncertainty | ||
self.uncertainty = uncertainty | ||
|
@@ -315,6 +320,19 @@ def uncertainties(self) -> dict: | |
""" | ||
return {k: i.uncertainty for k, i in self.parameters.items() if i.uncertainty is not None} | ||
|
||
def set_nominal_values(self, overwrite_static=False, **nominal_values): | ||
"""Set the nominal values for parameters. | ||
|
||
Keyword Args: | ||
nominal_values (dict): A dict of parameter names and values. | ||
|
||
""" | ||
for name, value in nominal_values.items(): | ||
if overwrite_static: | ||
self.parameters[name]._nominal_value = value | ||
else: | ||
self.parameters[name].nominal_value = value | ||
|
||
@property | ||
def with_uncertainty(self) -> "Parameters": | ||
"""Return parameters with a not-NaN uncertainty. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -137,6 +137,10 @@ def submit( | |
if os.path.splitext(limit_threshold)[-1] != ".json": | ||
raise ValueError("The limit_threshold file should be a json file.") | ||
|
||
# initialize the runner | ||
script = next(self.computation_tickets_generator())[0] | ||
runner = self.initialized_runner(script, pop_limit_threshold=True) | ||
|
||
# calculate the threshold, iterate over the output files | ||
threshold = cast(Dict[str, Any], {}) | ||
for runner_args in self.merged_arguments_generator(): | ||
|
@@ -162,6 +166,7 @@ def submit( | |
**nominal_values, | ||
**generate_values, | ||
} | ||
runner.model.set_nominal_values(overwrite_static=True, **nominal_values) | ||
|
||
# read the likelihood ratio | ||
output_filename = runner_args["output_filename"] | ||
|
@@ -187,15 +192,22 @@ def submit( | |
) | ||
|
||
# update poi according to poi_expectation | ||
runner_args["statistical_model_args"].pop("limit_threshold", None) | ||
runner = Runner(**runner_args) | ||
expectation_values = runner.model.get_expectation_values( | ||
**{**nominal_values, **generate_values} | ||
) | ||
# in some rare cases the poi is not a rate multiplier | ||
# then the poi_expectation is not in the nominal_expectation_values | ||
component = self.poi.replace("_rate_multiplier", "") | ||
poi_expectation = expectation_values.get(component, None) | ||
if runner.input_poi_expectation: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not just "if poi_expectation" in generate_values? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should also be fine. What is the benefit? |
||
poi_expectation = generate_values.get("poi_expectation") | ||
# nominal_values are passed to update_poi to update the poi | ||
# like wimp_mass, livetime, etc. | ||
generate_values = runner.update_poi( | ||
runner.model, self.poi, generate_values, nominal_values | ||
Comment on lines
+199
to
+200
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this won't work with static parameters. That was the entire point of the previous PR There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I ran it quickly and yes, this crashes if you have static parameters (as it should). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The recent commit 7b4f38f makes the PR compatible with static parameters. |
||
) | ||
else: | ||
expectation_values = runner.model.get_expectation_values( | ||
**{**nominal_values, **generate_values} | ||
) | ||
# in some rare cases the poi is not a rate multiplier | ||
# then the poi_expectation is not in the expectation_values | ||
# in these cases we only assign None to poi_expectation | ||
component = self.poi.replace("_rate_multiplier", "") | ||
poi_expectation = expectation_values.get(component, None) | ||
poi_value = generate_values.pop(self.poi) | ||
|
||
# make sure no poi and poi_expectation in the hashed_keys | ||
|
@@ -226,7 +238,7 @@ def submit( | |
"threshold": [], | ||
"poi_expectation": [], | ||
} | ||
threshold[threshold_key] = threshold_value | ||
threshold[threshold_key] = deepcopy(threshold_value) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very important because the shallow copy will convey the modification of |
||
threshold[threshold_key][self.poi].append(poi_value) | ||
threshold[threshold_key]["threshold"].append(q_llr) | ||
threshold[threshold_key]["poi_expectation"].append(poi_expectation) | ||
|
@@ -240,6 +252,13 @@ def submit( | |
threshold[k]["threshold"] = [x[1] for x in sorted_pairs] | ||
threshold[k]["poi_expectation"] = [x[2] for x in sorted_pairs] | ||
|
||
for k, v in threshold.items(): | ||
if k != deterministic_hash(v["hashed_keys"]): | ||
raise ValueError( | ||
"Something wrong with the threshold, " | ||
"inconsistency between hash and hashed keys found." | ||
) | ||
|
||
# save the threshold into a json file | ||
with open(limit_threshold, mode="w") as f: | ||
json.dump(threshold, f, indent=4) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -449,7 +449,7 @@ def convert_variations(variations: dict, iteration) -> list: | |
if not isinstance(v, list): | ||
raise ValueError(f"variations {k} must be a list, not {v} with {type(v)}") | ||
variations[k] = expand_grid_dict(v) | ||
result = [dict(zip(variations, t)) for t in iteration(*variations.values())] | ||
result = [dict(zip(variations, deepcopy(t))) for t in iteration(*variations.values())] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very important because the shallow copy will convey the modification of items in |
||
if result: | ||
return result | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This modification is to be compatible with the requirement of
ptype
.