diff --git a/alea/utils.py b/alea/utils.py
index 0c1d9bb4..79b9876b 100644
--- a/alea/utils.py
+++ b/alea/utils.py
@@ -14,6 +14,7 @@
 from typing import Any, List, Dict, Tuple, Optional, Union, cast, get_args, get_origin
 
 import h5py
+import matplotlib.pyplot as plt
 
 # These imports are needed to evaluate strings
 import numpy  # noqa: F401
@@ -608,3 +609,58 @@ def deterministic_hash(thing, length=10):
     # disable bandit
     digest = sha256(jsonned.encode("ascii")).digest()
     return b32encode(digest)[:length].decode("ascii").lower()
+
+
+def signal_multiplier_estimator(
+    signal: np.ndarray,
+    background: np.ndarray,
+    data: np.ndarray,
+    iteration: int = 100,
+    diagnostic: bool = False,
+) -> float:
+    """Estimate the best-fit signal multiplier using perturbation theory.
+
+    The method solves for the critical point of the binned Poisson likelihood of the data, given
+    the signal and background models, by iterative perturbation theory.
+
+    Args:
+        signal (np.ndarray): signal model
+        background (np.ndarray): background model
+        data (np.ndarray): data array
+        iteration (int, optional (default=100)): number of iterations
+        diagnostic (bool, optional (default=False)): if True, plot the multiplier at each iteration
+
+    Returns:
+        float: best-fit signal multiplier
+
+    """
+    mask = (signal > 0) | (background > 0)
+    if np.any(data[~mask] > 0):
+        raise ValueError("Data has non-zero values where both signal and background are zero.")
+
+    sig = signal[mask].ravel()
+    bkg = background[mask].ravel()
+    obs = data[mask].ravel()
+
+    @np.errstate(invalid="ignore", divide="ignore")
+    def correction_on_multiplier(x):
+        # Newton-like correction derived from the binned Poisson likelihood
+        exp = sig * x + bkg
+        return np.sum(np.where(exp > 0, (obs / exp - 1) * sig, 0)) / np.sum(
+            np.where(exp > 0, obs * sig**2 / exp**2, 0)
+        )
+
+    # In case of a downward fluctuation, the best-fit multiplier could be negative,
+    # in which case the perturbation series may not converge or may turn negative.
+    # Thus we clip it to be non-negative.
+    x = np.sum(obs - bkg) / np.sum(sig)
+    xs = [x]
+    for _ in range(iteration):
+        x += correction_on_multiplier(x)
+        x = np.clip(x, 0, None)
+        xs.append(x)
+    if diagnostic:
+        plt.plot(xs, marker=".")
+        plt.xlabel("Iteration")
+        plt.ylabel("x")
+    return x
diff --git a/tests/test_utils.py b/tests/test_utils.py
index c66fa809..027c347b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,8 @@
 from unittest import TestCase
 
 import numpy as np
+import inference_interface as ii
+import multihist as mh
 from scipy.stats import chi2
 
 from alea.utils import (
@@ -14,6 +16,8 @@
     expand_grid_dict,
     convert_to_vary,
     deterministic_hash,
+    get_file_path,
+    signal_multiplier_estimator,
 )
 
 
@@ -98,3 +102,14 @@ def test_deterministic_hash(self):
         self.assertEqual(
             deterministic_hash({"a": np.array([0, 1]), "b": np.array([0, 1])}), "anxefavaju"
         )
+
+    def test_signal_multiplier_estimator(self):
+        bkg = ii.template_to_multihist(
+            get_file_path("er_template_0.ii.h5"), hist_name="er_template"
+        )
+        sig = ii.template_to_multihist(
+            get_file_path("wimp50gev_template.ii.h5"), hist_name="wimp_template"
+        )
+        data = (bkg + sig * 1e-1).get_random(size=np.random.poisson(bkg.n))
+        data = mh.Histdd(*data.T, bins=bkg.bin_edges)
+        signal_multiplier_estimator(sig.histogram, bkg.histogram, data.histogram, diagnostic=True)
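
For reference, a minimal sketch of how the new estimator can be cross-checked outside the template-based test: it compares the perturbative estimate with a brute-force minimization of the same binned Poisson negative log-likelihood on toy arrays. The toy shapes, the injected multiplier of 2.0, the random seed, and the use of scipy.optimize.minimize_scalar are illustrative assumptions, not part of this diff.

# Illustrative sketch, not part of the diff: the toy arrays, injected multiplier, and seed
# are assumptions chosen only to exercise the function.
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.stats import poisson

from alea.utils import signal_multiplier_estimator

rng = np.random.default_rng(42)
signal = np.array([1.0, 2.0, 3.0, 2.0, 1.0])
background = np.array([10.0, 20.0, 30.0, 20.0, 10.0])
data = rng.poisson(2.0 * signal + background)

# Perturbative estimate from the new helper
x_pert = signal_multiplier_estimator(signal, background, data)

# Reference: minimize the negative binned Poisson log-likelihood directly
def neg_loglik(x):
    mu = x * signal + background
    return -poisson.logpmf(data, mu).sum()

x_ref = minimize_scalar(neg_loglik, bounds=(0, 10), method="bounded").x
print(f"perturbative: {x_pert:.4f}, brute force: {x_ref:.4f}")

Both numbers should agree to within the optimizer tolerance, since both approaches maximize the same binned Poisson likelihood.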