Skip to content

Commit

Permalink
Allowing Hierarchical Non Centered Parametrization (#747)
Browse files Browse the repository at this point in the history
* Allowing non center parametrization

* update notebook

* Adding example in docstring

* change

* Push code changes.

* A painful and ugly change!

The things one does for democracy!

* Missing parts!

* adding missing test 2D

* Missing raise

---------

Co-authored-by: Will Dean <[email protected]>
  • Loading branch information
2 people authored and twiecki committed Sep 10, 2024
1 parent 5b20b5e commit c0b37ce
Show file tree
Hide file tree
Showing 8 changed files with 3,267 additions and 3,041 deletions.
2,442 changes: 1,233 additions & 1,209 deletions docs/source/notebooks/mmm/mmm_example.ipynb

Large diffs are not rendered by default.

3,634 changes: 1,817 additions & 1,817 deletions docs/source/notebooks/mmm/mmm_tvp_example.ipynb

Large diffs are not rendered by default.

5 changes: 1 addition & 4 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pymc_marketing.mmm import base, delayed_saturated_mmm, preprocessing, validating
from pymc_marketing.mmm.base import (
BaseValidateMMM,
MMMModelBuilder,
)
from pymc_marketing.mmm.base import BaseValidateMMM, MMMModelBuilder
from pymc_marketing.mmm.components.adstock import (
AdstockTransformation,
DelayedAdstock,
Expand Down
14 changes: 7 additions & 7 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _save_input_params(self, idata) -> None:
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)

def forward_pass(
self, x: pt.TensorVariable | npt.NDArray[np.float_]
self, x: pt.TensorVariable | npt.NDArray[np.float64]
) -> pt.TensorVariable:
"""Transforms channel input into target contributions of each channel.
Expand All @@ -253,7 +253,7 @@ def forward_pass(
Parameters
------------
x : pt.TensorVariable | npt.NDArray[np.float_]
x : pt.TensorVariable | npt.NDArray[np.float64]
The channel input which could be spends or impressions
Returns
Expand Down Expand Up @@ -532,7 +532,7 @@ def _get_fourier_models_data(self, X) -> pd.DataFrame:
date_data: pd.Series = pd.to_datetime(
arg=X[self.date_column], format="%Y-%m-%d"
)
periods: npt.NDArray[np.float_] = (
periods: npt.NDArray[np.float64] = (
date_data.dt.dayofyear.to_numpy() / DAYS_IN_YEAR
)
return generate_fourier_modes(
Expand All @@ -541,8 +541,8 @@ def _get_fourier_models_data(self, X) -> pd.DataFrame:
)

def channel_contributions_forward_pass(
self, channel_data: npt.NDArray[np.float_]
) -> npt.NDArray[np.float_]:
self, channel_data: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
Parameters
Expand Down Expand Up @@ -891,8 +891,8 @@ class MMM(
version = "0.0.1"

def channel_contributions_forward_pass(
self, channel_data: npt.NDArray[np.float_]
) -> npt.NDArray[np.float_]:
self, channel_data: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
"""Evaluate the channel contribution for a given channel data and a fitted model, ie. the forward pass.
We return the contribution in the original scale of the target variable.
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def weibull_adstock(
return batched_convolution(x, w, axis=axis, mode=mode)


def logistic_saturation(x, lam: npt.NDArray[np.float_] | float = 0.5):
def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5):
"""Logistic saturation transformation.
.. math::
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


def generate_fourier_modes(
periods: npt.NDArray[np.float_], n_order: int
periods: npt.NDArray[np.float64], n_order: int
) -> pd.DataFrame:
"""Generate Fourier modes.
Expand Down
124 changes: 122 additions & 2 deletions pymc_marketing/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@
"dims": "channel",
}
Example parameter configuration with a hierarchical non-centered distribution:
.. code-block:: python
hierarchical_non_centered_parameter = {
"dist": "Normal",
"kwargs": {
"mu": {"dist": "HalfNormal", "kwargs": {"sigma": 2},},
"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 1},},
},
"dims": ("channel"),
"centered": False,
}
Example configuration of a 2D parameter:
.. code-block:: python
Expand Down Expand Up @@ -369,10 +383,110 @@ def handle_parameter_distributions(
}


class NestedDistributionError(Exception):
    """Raised when a distribution is nested somewhere nesting is not permitted.

    Attributes
    ----------
    param : str
        Name of the parameter whose configuration held the nested distribution.
    message : str
        Human-readable description; also passed to ``Exception`` as its sole arg.
    """

    def __init__(self, param: str) -> None:
        message = f"Nested distribution detected in '{param}', which is not allowed."
        # Keep both values on the instance so handlers can inspect them
        # without having to parse the rendered exception text.
        self.param, self.message = param, message
        super().__init__(message)


def check_for_deeper_nested_distribution(
    param_config: dict[str, Any], param_name: str
) -> None:
    """Reject a distribution config whose kwargs contain another distribution.

    A value counts as a distribution config when it is a dict holding both a
    ``"dist"`` and a ``"kwargs"`` key. If ``param_config`` itself is not such a
    dict, nothing is checked and the function returns silently.

    Parameters
    ----------
    param_config : dict[str, Any]
        Configuration to inspect.
    param_name : str
        Name reported in the raised error.

    Raises
    ------
    NestedDistributionError
        If any value in ``param_config["kwargs"]`` is itself a distribution
        config.
    """
    is_distribution_config = (
        isinstance(param_config, dict)
        and "dist" in param_config
        and "kwargs" in param_config
    )
    if not is_distribution_config:
        return

    for candidate in param_config["kwargs"].values():
        if not isinstance(candidate, dict):
            continue
        if "dist" in candidate and "kwargs" in candidate:
            raise NestedDistributionError(param_name)


class NonCenterInvalidDistributionError(Exception):
    """Raised when a non-centered hierarchical parameter is not ``Normal``.

    Attributes
    ----------
    param : str
        The rejected distribution name. The attribute is called ``param``
        rather than ``name``; that spelling is kept for backward compatibility.
    message : str
        Human-readable description; also passed to ``Exception`` as its sole arg.
    """

    def __init__(self, name: str) -> None:
        self.param = name
        # Built from adjacent string literals rather than a triple-quoted
        # block so the rendered error carries no leading/trailing newlines
        # or source-code indentation.
        self.message = (
            f"Invalid distribution '{name}' for non-centered hierarchical "
            "distribution. Only 'Normal' is allowed."
        )
        super().__init__(self.message)


def create_hierarchical_non_center(
    name: str,
    distribution_kwargs: dict[str, Any],
    **kwargs,
) -> pt.TensorVariable:
    """Create a hierarchical non-centered variable.

    The returned variable is ``mu + offset * sigma`` where ``offset`` is a
    standard Normal created here as ``f"{name}_offset"``, and ``mu`` /
    ``sigma`` are built from the configurations in ``distribution_kwargs`` as
    ``f"{name}_mu"`` and ``f"{name}_sigma"``.

    Parameters
    ----------
    name : str
        The name of the resulting deterministic variable; also the prefix of
        the ``_offset``, ``_mu`` and ``_sigma`` variables.
    distribution_kwargs : dict[str, Any]
        Must contain ``"mu"`` and ``"sigma"`` entries. Each is a distribution
        configuration dict with ``"dist"`` and ``"kwargs"`` keys and an
        optional ``"dims"`` key. The offset is not configurable.
    **kwargs
        Only ``"dims"`` is read: the dims of the offset and of the result.

    Returns
    -------
    pt.TensorVariable
        A PyMC deterministic variable named ``name``.

    Raises
    ------
    NestedDistributionError
        If the ``mu`` or ``sigma`` configuration itself contains a nested
        distribution in its kwargs.
    """
    desired_dims = kwargs.get("dims", ())
    dim_handler = create_dim_handler(desired_dims)

    mu_dist = distribution_kwargs["mu"]
    mu_dims = mu_dist.get("dims", ())
    sigma_dist = distribution_kwargs["sigma"]
    sigma_dims = sigma_dist.get("dims", ())

    # Validate both hyper-prior configurations up front, before any PyMC
    # variable is created, so a rejected configuration does not leave a
    # partially built set of variables behind.
    check_for_deeper_nested_distribution(mu_dist, f"{name}_mu")
    check_for_deeper_nested_distribution(sigma_dist, f"{name}_sigma")

    offset = pm.Normal(name=f"{name}_offset", mu=0, sigma=1, dims=desired_dims)

    mu_global = create_distribution(
        f"{name}_mu",
        mu_dist["dist"],
        mu_dist["kwargs"],
        dims=mu_dims,
    )
    # NOTE(review): dim_handler presumably aligns the variable from its own
    # dims to ``desired_dims`` so the arithmetic below broadcasts - confirm
    # against create_dim_handler.
    mu_global = dim_handler(mu_global, mu_dims)

    sigma_global = create_distribution(
        f"{name}_sigma",
        sigma_dist["dist"],
        sigma_dist["kwargs"],
        dims=sigma_dims,
    )
    sigma_global = dim_handler(sigma_global, sigma_dims)

    return pm.Deterministic(
        name=name, var=mu_global + offset * sigma_global, dims=desired_dims
    )


def create_distribution(
name: str,
distribution_name: str,
distribution_kwargs: dict[str, Any],
centered: bool | None = None,
**kwargs,
) -> pt.TensorVariable:
"""Create a PyMC distribution with the specified parameters.
Expand All @@ -392,9 +506,13 @@ def create_distribution(
-------
TensorVariable
A PyMC random variable.
"""
dim_handler = create_dim_handler(kwargs.get("dims"))
if centered is False:
if distribution_name != "Normal":
raise NonCenterInvalidDistributionError(distribution_name)
return create_hierarchical_non_center(name, distribution_kwargs, **kwargs)

dim_handler = create_dim_handler(kwargs.get("dims", ()))
parameter_distributions = handle_parameter_distributions(
name, distribution_kwargs, dim_handler=dim_handler
)
Expand Down Expand Up @@ -440,6 +558,7 @@ def create_distribution_from_config(name: str, config) -> pt.TensorVariable:
"""
parameter_config = config[name]
centered_flag = parameter_config.get("centered", True)
try:
dist_name = parameter_config["dist"]
dist_kwargs = parameter_config["kwargs"]
Expand All @@ -450,6 +569,7 @@ def create_distribution_from_config(name: str, config) -> pt.TensorVariable:
name,
dist_name,
dist_kwargs,
centered=centered_flag,
dims=parameter_config.get("dims"),
)

Expand Down
85 changes: 85 additions & 0 deletions tests/test_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,62 @@ def model_config():
},
"dims": ("channel", "control"),
},
# Hierarchical centered distribution
"hierarchical_centered": {
"dist": "Normal",
"kwargs": {
"mu": {
"dist": "Normal",
"kwargs": {
"mu": 0.0,
"sigma": 1.0,
},
"dims": "channel",
},
"sigma": {
"dist": "HalfNormal",
"kwargs": {
"sigma": 1.0,
},
"dims": "geo",
},
},
"dims": ("channel", "geo"),
"centered": True,
},
# Hierarchical non-centered distribution
"hierarchical_non_centered": {
"dist": "Normal",
"kwargs": {
"mu": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 1}},
},
"dims": "channel",
"centered": False,
},
# 2D Hierarchical non-centered distribution
"hierarchical_non_centered_2d": {
"dist": "Normal",
"kwargs": {
"mu": {
"dist": "Normal",
"kwargs": {
"mu": 0.0,
"sigma": 1.0,
},
"dims": "channel",
},
"sigma": {
"dist": "HalfNormal",
"kwargs": {
"sigma": 1.0,
},
"dims": "geo",
},
},
"dims": ("channel", "geo"),
"centered": False,
},
# Incorrect config
"error": {
"dist": "Normal",
Expand Down Expand Up @@ -224,6 +280,35 @@ def coords() -> dict[str, list[str]]:
("alpha", ["alpha", "alpha_mu", "alpha_sigma"], [(3,), (), ()]),
("gamma", ["gamma", "gamma_mu", "gamma_sigma"], [(3, 2), (3,), (2,)]),
("delta", ["delta"], [(3, 1)]),
(
"hierarchical_centered",
[
"hierarchical_centered",
"hierarchical_centered_mu",
"hierarchical_centered_sigma",
],
[(3, 2), (3,), (2,)],
),
(
"hierarchical_non_centered",
[
"hierarchical_non_centered",
"hierarchical_non_centered_mu",
"hierarchical_non_centered_sigma",
"hierarchical_non_centered_offset",
],
[(3,), (), (), (3,)],
),
(
"hierarchical_non_centered_2d",
[
"hierarchical_non_centered_2d",
"hierarchical_non_centered_2d_mu",
"hierarchical_non_centered_2d_sigma",
"hierarchical_non_centered_2d_offset",
],
[(3, 2), (3,), (2,), (3, 2)],
),
],
)
def test_create_distribution(
Expand Down

0 comments on commit c0b37ce

Please sign in to comment.