Skip to content

Commit

Permalink
Mandatorily set ptype, add option to overwrite static parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx committed Nov 6, 2023
1 parent 4ba6d88 commit 7b4f38f
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 9 deletions.
2 changes: 1 addition & 1 deletion alea/examples/configs/unbinned_wimp_statistical_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ parameter_definition:
nominal_value: 50
ptype: shape
fittable: false
description: WIMP mass in GeV/c^2
blueice_anchors:
- 50
description: WIMP mass in GeV/c^2

livetime_sr0:
nominal_value: 0.2
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
parameter_definition:
wimp_mass:
nominal_value: 50
ptype: shape
fittable: false
blueice_anchors:
- 50
description: WIMP mass in GeV/c^2

livetime:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
parameter_definition:
wimp_mass:
nominal_value: 50
ptype: shape
fittable: false
blueice_anchors:
- 50
description: WIMP mass in GeV/c^2

livetime_sr2:
nominal_value: 0.5
ptype: livetime
fittable: false
description: Livetime of SR2 in years

livetime_sr3:
nominal_value: 1.0
ptype: livetime
fittable: false
description: Livetime of SR3 in years

Expand Down
21 changes: 15 additions & 6 deletions alea/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,16 @@ def __init__(
self._confidence_interval_kind = confidence_interval_kind
self.confidence_interval_threshold = confidence_interval_threshold
self.asymptotic_dof = asymptotic_dof
nominal_values = kwargs.get("nominal_values", {})
self._define_parameters(parameter_definition, nominal_values)
self._define_parameters(parameter_definition)

self._check_ll_and_generate_data_signature()
self.set_nominal_values(overwrite_static=True, **kwargs.get("nominal_values", {}))

def _define_parameters(self, parameter_definition, nominal_values=None):
def _define_parameters(self, parameter_definition):
"""Initialize the parameters of the model."""
if parameter_definition is None:
self.parameters = Parameters()
elif isinstance(parameter_definition, dict):
for name, definition in parameter_definition.items():
if name in nominal_values:
definition["nominal_value"] = nominal_values[name]
self.parameters = Parameters.from_config(parameter_definition)
elif isinstance(parameter_definition, list):
self.parameters = Parameters.from_list(parameter_definition)
Expand Down Expand Up @@ -239,6 +236,18 @@ def store_data(
raise ValueError("The number of data sets and data names must be the same")
toydata_to_file(file_name, _data_list, data_name_list, **kw)

def set_nominal_values(self, overwrite_static=False, **nominal_values):
"""Set the nominal values for parameters.
Keyword Args:
nominal_values (dict): A dict of parameter names and values.
"""
self.parameters.set_nominal_values(
overwrite_static=overwrite_static,
**nominal_values,
)

def set_fit_guesses(self, **fit_guesses):
"""Set the fit guesses for parameters.
Expand Down
22 changes: 20 additions & 2 deletions alea/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Parameter:
nominal_value (float, optional (default=None)): The nominal value of the parameter.
fittable (bool, optional (default=True)):
Indicates if the parameter is fittable or always fixed.
ptype (str, optional (default=None)): The ptype of the parameter.
ptype (str, optional (default=shape)): The ptype of the parameter.
uncertainty (float or str, optional (default=None)): The uncertainty of the parameter.
If a string, it can be evaluated as a numpy or
scipy function to define non-gaussian constraints.
Expand All @@ -40,7 +40,7 @@ def __init__(
name: str,
nominal_value: Optional[float] = None,
fittable: bool = True,
ptype: Optional[str] = None,
ptype: str = "shape",
uncertainty: Optional[Union[float, str]] = None,
relative_uncertainty: Optional[bool] = None,
blueice_anchors: Optional[List] = None,
Expand All @@ -54,6 +54,11 @@ def __init__(
self.name = name
self._nominal_value = nominal_value
self.fittable = fittable
if ptype not in ["rate", "shape", "efficiency", "livetime"]:
raise ValueError(
f"{name}'s ptype {ptype} is not valid.",
"it should be one of 'rate', 'shape', 'livetime'.",
)
self.ptype = ptype
self.relative_uncertainty = relative_uncertainty
self.uncertainty = uncertainty
Expand Down Expand Up @@ -315,6 +320,19 @@ def uncertainties(self) -> dict:
"""
return {k: i.uncertainty for k, i in self.parameters.items() if i.uncertainty is not None}

def set_nominal_values(self, overwrite_static=False, **nominal_values):
"""Set the nominal values for parameters.
Keyword Args:
nominal_values (dict): A dict of parameter names and values.
"""
for name, value in nominal_values.items():
if overwrite_static:
self.parameters[name]._nominal_value = value
else:
self.parameters[name].nominal_value = value

@property
def with_uncertainty(self) -> "Parameters":
"""Return parameters with a not-NaN uncertainty.
Expand Down
1 change: 1 addition & 0 deletions alea/submitters/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def submit(
**nominal_values,
**generate_values,
}
runner.model.set_nominal_values(overwrite_static=True, **nominal_values)

# read the likelihood ratio
output_filename = runner_args["output_filename"]
Expand Down

0 comments on commit 7b4f38f

Please sign in to comment.