Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hierarchical Model Configuration #743

Merged
merged 16 commits into from
Jun 13, 2024
1 change: 1 addition & 0 deletions docs/source/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@

clv
mmm
model_config
```
24 changes: 15 additions & 9 deletions pymc_marketing/mmm/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import warnings
from collections.abc import Generator, MutableMapping, Sequence
from copy import deepcopy
from inspect import signature
from itertools import product
from typing import Any
Expand All @@ -35,7 +36,11 @@
from pymc.distributions.shape_utils import Dims
from pytensor import tensor as pt

from pymc_marketing.mmm.utils import _get_distribution_from_dict
from pymc_marketing.model_config import (
DimHandler,
create_dim_handler,
create_distribution,
)

Values = Sequence[Any] | npt.NDArray[Any]
Coords = dict[str, Values]
Expand Down Expand Up @@ -154,7 +159,7 @@ class Transformation:
def __init__(self, priors: dict | None = None, prefix: str | None = None) -> None:
self._checks()
priors = priors or {}
self.function_priors = {**self.default_priors, **priors}
self.function_priors = {**deepcopy(self.default_priors), **priors}
self.prefix = prefix or self.prefix

def update_priors(self, priors: dict[str, Any]) -> None:
Expand Down Expand Up @@ -271,20 +276,21 @@ def variable_mapping(self) -> dict[str, str]:
def _create_distributions(
self, dims: Dims | None = None
) -> dict[str, pt.TensorVariable]:
dim_handler: DimHandler = create_dim_handler(dims)
distributions: dict[str, pt.TensorVariable] = {}
for parameter_name, variable_name in self.variable_mapping.items():
parameter_prior = self.function_priors[parameter_name]

distribution = _get_distribution_from_dict(
dist=parameter_prior,
)

distributions[parameter_name] = distribution(
var_dims = parameter_prior.get("dims")
var = create_distribution(
name=variable_name,
dims=dims,
**parameter_prior["kwargs"],
distribution_name=parameter_prior["dist"],
distribution_kwargs=parameter_prior["kwargs"],
dims=var_dims,
)

distributions[parameter_name] = dim_handler(var, var_dims)

return distributions

def sample_prior(
Expand Down
30 changes: 30 additions & 0 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,36 @@ def function(self, x, b):
saturation.plot_curve(curve)
plt.show()

Define a logistic saturation function in which only the ``lam`` parameter has
hierarchical priors, while ``beta`` keeps a prior with fixed parameters.

.. code-block:: python

from pymc_marketing.mmm import LogisticSaturation

priors = {
"lam": {
"dist": "Gamma",
"kwargs": {
"alpha": {
"dist": "HalfNormal",
"kwargs": {"sigma": 1},
},
"beta": {
"dist": "HalfNormal",
"kwargs": {"sigma": 1},
},
},
"dims": "channel",
},
"beta": {
"dist": "HalfNormal",
"kwargs": {"sigma": 1},
"dims": "channel",
},
}
saturation = LogisticSaturation(priors=priors)

"""

import numpy as np
Expand Down
186 changes: 55 additions & 131 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import pymc as pm
import pytensor.tensor as pt
import seaborn as sns
from pytensor.tensor import TensorVariable
from xarray import DataArray, Dataset

from pymc_marketing.constants import DAYS_IN_YEAR
Expand All @@ -47,12 +46,16 @@
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget
from pymc_marketing.mmm.tvp import create_time_varying_intercept, infer_time_index
from pymc_marketing.mmm.utils import (
_get_distribution_from_dict,
apply_sklearn_transformer_across_dim,
create_new_spend_data,
generate_fourier_modes,
)
from pymc_marketing.mmm.validating import ValidateControlColumns
from pymc_marketing.model_config import (
create_distribution_from_config,
create_likelihood_distribution,
get_distribution,
)

__all__ = ["BaseMMM", "MMM", "DelayedSaturatedMMM"]

Expand Down Expand Up @@ -236,112 +239,6 @@
idata.attrs["validate_data"] = json.dumps(self.validate_data)
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)

def _create_likelihood_distribution(
self,
dist: dict,
mu: TensorVariable,
observed: np.ndarray | pd.Series,
dims: str,
) -> TensorVariable:
"""
Create and return a likelihood distribution for the model.

This method prepares the distribution and its parameters as specified in the
configuration dictionary, validates them, and constructs the likelihood
distribution using PyMC.

Parameters
----------
dist : Dict
A configuration dictionary that must contain a 'dist' key with the name of
the distribution and a 'kwargs' key with parameters for the distribution.
mu : TensorVariable
The tensor passed as the `mu` argument of the likelihood distribution.
observed : Union[np.ndarray, pd.Series]
The observed data to which the likelihood distribution will be fitted.
dims : str
The dimensions of the data.

Returns
-------
TensorVariable
The likelihood distribution constructed with PyMC.

Raises
------
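ValueError
If the distribution named in `dist` is not one of the allowed likelihood
distributions, if the 'kwargs' key is missing from `dist` or is not a
dictionary, if 'mu' is present in 'kwargs', or if a parameter configuration
is neither a numeric value nor a dictionary with both 'dist' and 'kwargs' keys.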
"""
allowed_distributions = [
"Normal",
"StudentT",
"Laplace",
"Logistic",
"LogNormal",
"Wald",
"TruncatedNormal",
"Gamma",
"AsymmetricLaplace",
"VonMises",
]

if dist["dist"] not in allowed_distributions:
raise ValueError(
f"""
The distribution used for the likelihood is not allowed.
Please, use one of the following distributions: {allowed_distributions}.
"""
)

# Validate that 'kwargs' is present and is a dictionary
if "kwargs" not in dist or not isinstance(dist["kwargs"], dict):
raise ValueError(
"The 'kwargs' key must be present in the 'dist' dictionary and be a dictionary itself."
)

if "mu" in dist["kwargs"]:
raise ValueError(
"The 'mu' key is not allowed directly within 'kwargs' of the main distribution as it is reserved."
)

parameter_distributions = {}
for param, param_config in dist["kwargs"].items():
# Check if param_config is a dictionary with a 'dist' key
if isinstance(param_config, dict) and "dist" in param_config:
# Prepare nested distribution
if "kwargs" not in param_config:
raise ValueError(
f"The parameter configuration for '{param}' must contain 'kwargs'."
)

parameter_distributions[param] = _get_distribution_from_dict(
dist=param_config
)(**param_config["kwargs"], name=f"likelihood_{param}")
elif isinstance(param_config, int | float):
# Use the value directly
parameter_distributions[param] = param_config
else:
raise ValueError(
f"""
Invalid parameter configuration for '{param}'.
It must be either a dictionary with a 'dist' key or a numeric value.
"""
)

# Extract the likelihood distribution name and instantiate it
likelihood_dist_name = dist["dist"]
likelihood_dist = _get_distribution_from_dict(
dist={"dist": likelihood_dist_name}
)

return likelihood_dist(
name=self.output_var,
mu=mu,
observed=observed,
dims=dims,
**parameter_distributions,
)

def forward_pass(
self, x: pt.TensorVariable | npt.NDArray[np.float_]
) -> pt.TensorVariable:
Expand Down Expand Up @@ -429,16 +326,6 @@
)
"""

self.intercept_dist = _get_distribution_from_dict(
dist=self.model_config["intercept"]
)
self.gamma_control_dist = _get_distribution_from_dict(
dist=self.model_config["gamma_control"]
)
self.gamma_fourier_dist = _get_distribution_from_dict(
dist=self.model_config["gamma_fourier"]
)

self._generate_and_preprocess_model_data(X, y)
with pm.Model(
coords=self.model_coords,
Expand All @@ -464,16 +351,19 @@
self._time_index,
dims="date",
)
intercept_dist = get_distribution(
name=self.model_config["intercept"]["dist"]
)
intercept = create_time_varying_intercept(
time_index,
self._time_index_mid,
self._time_resolution,
self.intercept_dist,
intercept_dist,
self.model_config,
)
else:
intercept = self.intercept_dist(
name="intercept", **self.model_config["intercept"]["kwargs"]
intercept = create_distribution_from_config(
name="intercept", config=self.model_config
)

channel_contributions = pm.Deterministic(
Expand All @@ -492,10 +382,17 @@
for column in self.control_columns
)
):
gamma_control = self.gamma_control_dist(
if self.model_config["gamma_control"].get("dims") != "control":
msg = (

Check warning on line 386 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L386

Added line #L386 was not covered by tests
"The 'dims' key in gamma_control must be 'control'."
" This will be fixed automatically."
)
warnings.warn(msg, stacklevel=2)
self.model_config["gamma_control"]["dims"] = "control"

Check warning on line 391 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L390-L391

Added lines #L390 - L391 were not covered by tests

gamma_control = create_distribution_from_config(
name="gamma_control",
dims="control",
**self.model_config["gamma_control"]["kwargs"],
config=self.model_config,
)

control_data_ = pm.Data(
Expand Down Expand Up @@ -529,10 +426,17 @@
mutable=True,
)

gamma_fourier = self.gamma_fourier_dist(
if self.model_config["gamma_fourier"].get("dims") != "fourier_mode":
msg = (
"The 'dims' key in gamma_fourier must be 'fourier_mode'."
" This will be fixed automatically."
)
warnings.warn(msg, stacklevel=2)
self.model_config["gamma_fourier"]["dims"] = "fourier_mode"

gamma_fourier = create_distribution_from_config(
name="gamma_fourier",
dims="fourier_mode",
**self.model_config["gamma_fourier"]["kwargs"],
config=self.model_config,
)

fourier_contribution = pm.Deterministic(
Expand All @@ -551,8 +455,9 @@

mu = pm.Deterministic(name="mu", var=mu_var, dims="date")

self._create_likelihood_distribution(
dist=self.model_config["likelihood"],
create_likelihood_distribution(
name=self.output_var,
param_config=self.model_config["likelihood"],
mu=mu,
observed=target_,
dims="date",
Expand All @@ -568,8 +473,16 @@
"sigma": {"dist": "HalfNormal", "kwargs": {"sigma": 2}},
},
},
"gamma_control": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"gamma_fourier": {"dist": "Laplace", "kwargs": {"mu": 0, "b": 1}},
"gamma_control": {
"dist": "Normal",
"kwargs": {"mu": 0, "sigma": 2},
"dims": "control",
},
"gamma_fourier": {
"dist": "Laplace",
"kwargs": {"mu": 0, "b": 1},
"dims": "fourier_mode",
},
"intercept_tvp_kwargs": {
"m": 200,
"L": None,
Expand All @@ -580,6 +493,17 @@
},
}

for media_transform in [self.adstock, self.saturation]:
for param, config in media_transform.function_priors.items():
if "dims" not in config:
msg = (
f"{param} doesn't have a 'dims' key in config. Setting to channel."
f" Set priors explicitly in {media_transform.__class__.__name__}"
" to avoid this warning."
)
warnings.warn(msg, stacklevel=2)
config["dims"] = "channel"

return {
**base_config,
**self.adstock.model_config,
Expand Down
Loading
Loading