Pull out seasonality as YearlyFourier and MonthlyFourier #802

Merged · 18 commits · Jul 5, 2024
Changes from 5 commits
738 changes: 557 additions & 181 deletions docs/source/notebooks/mmm/mmm_components.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pymc_marketing/constants.py
@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
DAYS_IN_YEAR: float = 365.25
DAYS_IN_MONTH: float = DAYS_IN_YEAR / 12
juanitorduz marked this conversation as resolved.
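For orientation, a minimal sketch (not the library implementation) of how these two constants act as the periods of the yearly and monthly Fourier bases; the fourier_features helper below is hypothetical and mirrors the generate_fourier_modes path removed from delayed_saturated_mmm.py further down.

```python
import numpy as np

DAYS_IN_YEAR: float = 365.25
DAYS_IN_MONTH: float = DAYS_IN_YEAR / 12


def fourier_features(dayofyear: np.ndarray, n_order: int, period: float) -> np.ndarray:
    """Sin/cos features at integer multiples of the base frequency 1 / period."""
    angles = 2 * np.pi * np.arange(1, n_order + 1) * dayofyear[:, None] / period
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


yearly = fourier_features(np.arange(1, 366), n_order=2, period=DAYS_IN_YEAR)    # shape (365, 4)
monthly = fourier_features(np.arange(1, 366), n_order=2, period=DAYS_IN_MONTH)  # shape (365, 4)
```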
33 changes: 18 additions & 15 deletions pymc_marketing/mmm/__init__.py
@@ -28,33 +28,36 @@
TanhSaturationBaselined,
)
from pymc_marketing.mmm.delayed_saturated_mmm import MMM, DelayedSaturatedMMM
from pymc_marketing.mmm.fourier import MonthlyFourier, YearlyFourier
from pymc_marketing.mmm.preprocessing import (
preprocessing_method_X,
preprocessing_method_y,
)
from pymc_marketing.mmm.validating import validation_method_X, validation_method_y

__all__ = [
"base",
"delayed_saturated_mmm",
"preprocessing",
"validating",
"MMM",
"MMMModelBuilder",
"BaseValidateMMM",
"DelayedSaturatedMMM",
"preprocessing_method_X",
"preprocessing_method_y",
"validation_method_X",
"validation_method_y",
"AdstockTransformation",
"BaseValidateMMM",
"DelayedAdstock",
"DelayedSaturatedMMM",
"GeometricAdstock",
"WeibullAdstock",
"SaturationTransformation",
"MichaelisMentenSaturation",
"HillSaturation",
"LogisticSaturation",
"MMM",
"MMMModelBuilder",
"MichaelisMentenSaturation",
"MonthlyFourier",
"SaturationTransformation",
"TanhSaturation",
"TanhSaturationBaselined",
"WeibullAdstock",
"YearlyFourier",
"base",
"delayed_saturated_mmm",
"preprocessing",
"preprocessing_method_X",
"preprocessing_method_y",
"validating",
"validation_method_X",
"validation_method_y",
]
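Both new components are now exported from the pymc_marketing.mmm namespace. Below is a hedged construction sketch that uses only the keyword arguments visible in this diff (n_order, prefix, prior); the Laplace(0, 1) prior stands in for the model's gamma_fourier default, MonthlyFourier is imported only to show the new public name, and released defaults or validation may differ.

```python
from pymc_marketing.mmm import MonthlyFourier, YearlyFourier
from pymc_marketing.prior import Prior

# Mirrors the construction added to DelayedSaturatedMMM.__init__ further down.
yearly = YearlyFourier(
    n_order=2,
    prefix="fourier_mode",
    prior=Prior("Laplace", mu=0, b=1, dims="fourier_mode"),
)
```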
131 changes: 22 additions & 109 deletions pymc_marketing/mmm/components/base.py
@@ -21,10 +21,8 @@
"""

import warnings
from collections.abc import Generator, MutableMapping, Sequence
from copy import deepcopy
from inspect import signature
from itertools import product
from typing import Any

import arviz as az
@@ -36,55 +34,17 @@
from pymc.distributions.shape_utils import Dims
from pytensor import tensor as pt

from pymc_marketing.mmm.plot import (
plot_hdi,
plot_samples,
)
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.prior import DimHandler, Prior, create_dim_handler

Values = Sequence[Any] | npt.NDArray[Any]
Coords = dict[str, Values]

# chain and draw from sampling
# "x" for saturation, "time since exposure" for adstock
NON_GRID_NAMES = {"chain", "draw", "x", "time since exposure"}


def get_plot_coords(coords: Coords) -> Coords:
plot_coord_names = list(key for key in coords.keys() if key not in NON_GRID_NAMES)
return {name: np.array(coords[name]) for name in plot_coord_names}


def get_total_coord_size(coords: Coords) -> int:
total_size: int = (
1 if coords == {} else np.prod([len(values) for values in coords.values()]) # type: ignore
)
if total_size >= 12:
warnings.warn("Large number of coordinates!", stacklevel=2)

return total_size


def set_subplot_kwargs_defaults(
subplot_kwargs: MutableMapping[str, Any],
total_size: int,
) -> None:
if "ncols" in subplot_kwargs and "nrows" in subplot_kwargs:
raise ValueError("Only specify one")

if "ncols" not in subplot_kwargs and "nrows" not in subplot_kwargs:
subplot_kwargs["ncols"] = total_size

if "ncols" in subplot_kwargs:
subplot_kwargs["nrows"] = total_size // subplot_kwargs["ncols"]
elif "nrows" in subplot_kwargs:
subplot_kwargs["ncols"] = total_size // subplot_kwargs["nrows"]


def selections(
coords: Coords,
) -> Generator[dict[str, Any], None, None]:
"""Helper to create generator of selections."""
coord_names = coords.keys()
for values in product(*coords.values()):
yield {name: value for name, value in zip(coord_names, values, strict=True)}
# lower, higher from hdi
NON_GRID_NAMES = {"chain", "draw", "x", "time since exposure", "hdi"}

Collaborator:

I do not know if it is overkill, but could we use from types import MappingProxyType for these immutable dictionaries?

Contributor Author:

I switched them to frozensets
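
To illustrate the two options discussed in this thread (neither snippet is taken from the PR): MappingProxyType gives a read-only view over a dict, while a frozenset, the option the author settled on, is enough when the helpers only ever test membership.

```python
from types import MappingProxyType

# Read-only mapping view, as suggested by the reviewer.
NON_GRID_DIMS = MappingProxyType({"chain": None, "draw": None})
# NON_GRID_DIMS["x"] = None  # would raise TypeError: item assignment is not supported

# Immutable set, the approach adopted in later commits; only membership tests are needed.
NON_GRID_NAMES = frozenset({"chain", "draw", "x", "time since exposure", "hdi"})
assert "hdi" in NON_GRID_NAMES
```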



class ParameterPriorException(Exception):
@@ -439,36 +399,15 @@ def plot_curve_samples(
The axes with the plot.

"""
plot_coords = get_plot_coords(curve.coords)
total_size = get_total_coord_size(plot_coords)

if axes is None:
subplot_kwargs = subplot_kwargs or {}
set_subplot_kwargs_defaults(subplot_kwargs, total_size)
fig, axes = plt.subplots(**subplot_kwargs)
else:
fig = plt.gcf()

plot_kwargs = plot_kwargs or {}
plot_kwargs["alpha"] = plot_kwargs.get("alpha", 0.3)
plot_kwargs["legend"] = False

for i, (ax, sel) in enumerate(
zip(np.ravel(axes), selections(plot_coords), strict=False)
):
color = f"C{i}"

df_curve = curve.sel(sel).to_series().unstack()
df_sample = df_curve.sample(n=n, random_state=rng)

df_sample.T.plot(ax=ax, color=color, **plot_kwargs)
title = ", ".join(f"{name}={value}" for name, value in sel.items())
ax.set_title(title)

if not isinstance(axes, np.ndarray):
axes = np.array([axes])

return fig, axes
return plot_samples(
curve,
non_grid_names=NON_GRID_NAMES,
n=n,
rng=rng,
axes=axes,
subplot_kwargs=subplot_kwargs,
plot_kwargs=plot_kwargs,
)

def plot_curve_hdi(
self,
@@ -498,42 +437,16 @@ def plot_curve_hdi(
tuple[plt.Figure, npt.NDArray[plt.Axes]]

"""
plot_coords = get_plot_coords(curve.coords)
total_size = get_total_coord_size(plot_coords)

hdi_kwargs = hdi_kwargs or {}
conf = az.hdi(curve, **hdi_kwargs)[curve.name]

if axes is None:
subplot_kwargs = subplot_kwargs or {}
set_subplot_kwargs_defaults(subplot_kwargs, total_size)
fig, axes = plt.subplots(**subplot_kwargs)
else:
fig = plt.gcf()

plot_kwargs = plot_kwargs or {}
plot_kwargs["alpha"] = plot_kwargs.get("alpha", 0.3)

for i, (ax, sel) in enumerate(
zip(np.ravel(axes), selections(plot_coords), strict=False)
):
color = f"C{i}"
df_conf = conf.sel(sel).to_series().unstack()

ax.fill_between(
x=df_conf.index,
y1=df_conf["lower"],
y2=df_conf["higher"],
color=color,
**plot_kwargs,
)
title = ", ".join(f"{name}={value}" for name, value in sel.items())
ax.set_title(title)

if not isinstance(axes, np.ndarray):
axes = np.array([axes])

return fig, axes
return plot_hdi(
conf,
non_grid_names=NON_GRID_NAMES,
axes=axes,
subplot_kwargs=subplot_kwargs,
plot_kwargs=plot_kwargs,
)

def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> pt.TensorVariable:
"""Called within a model context.
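For readers following the refactor: plot_curve_samples and plot_curve_hdi now delegate to plot_samples and plot_hdi in pymc_marketing.mmm.plot, passing NON_GRID_NAMES. The sketch below reconstructs, as a standalone function, the grid-selection logic removed from base.py (one subplot selection per combination of coordinate values outside NON_GRID_NAMES); it is adapted from the deleted helpers, not copied from the new plot module.

```python
from collections.abc import Generator, Sequence
from itertools import product
from typing import Any

NON_GRID_NAMES = frozenset({"chain", "draw", "x", "time since exposure", "hdi"})


def selections(coords: dict[str, Sequence[Any]]) -> Generator[dict[str, Any], None, None]:
    """Yield one {dim: value, ...} mapping per plotting grid cell."""
    plot_coords = {k: v for k, v in coords.items() if k not in NON_GRID_NAMES}
    names = plot_coords.keys()
    for values in product(*plot_coords.values()):
        yield dict(zip(names, values, strict=True))


# Example: {"chain": [0, 1], "draw": range(100), "channel": ["tv", "radio"]}
# yields {"channel": "tv"} and {"channel": "radio"}.
```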
90 changes: 20 additions & 70 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
@@ -28,7 +28,6 @@
import seaborn as sns
from xarray import DataArray, Dataset

from pymc_marketing.constants import DAYS_IN_YEAR
from pymc_marketing.mmm.base import BaseValidateMMM
from pymc_marketing.mmm.budget_optimizer import BudgetOptimizer
from pymc_marketing.mmm.components.adstock import (
@@ -39,6 +38,7 @@
SaturationTransformation,
_get_saturation_function,
)
from pymc_marketing.mmm.fourier import YearlyFourier
from pymc_marketing.mmm.lift_test import (
add_lift_measurements_to_likelihood_from_saturation,
scale_lift_measurements,
@@ -48,7 +48,6 @@
from pymc_marketing.mmm.utils import (
apply_sklearn_transformer_across_dim,
create_new_spend_data,
generate_fourier_modes,
)
from pymc_marketing.mmm.validating import ValidateControlColumns
from pymc_marketing.model_config import parse_model_config
@@ -127,7 +126,6 @@ def __init__(
self.adstock_max_lag = adstock_max_lag
self.time_varying_intercept = time_varying_intercept
self.time_varying_media = time_varying_media
self.yearly_seasonality = yearly_seasonality
self.date_column = date_column
self.validate_data = validate_data

@@ -145,6 +143,14 @@ def __init__(
self.adstock.update_priors({**self.default_model_config, **model_config})
self.saturation.update_priors({**self.default_model_config, **model_config})

self.yearly_seasonality = yearly_seasonality
if self.yearly_seasonality is not None:
self.yearly_fourier = YearlyFourier(
n_order=self.yearly_seasonality,
prefix="fourier_mode",
prior=self.model_config["gamma_fourier"],
)

super().__init__(
date_column=date_column,
channel_columns=channel_columns,
@@ -214,13 +220,6 @@ def _generate_and_preprocess_model_data( # type: ignore
coords["control"] = self.control_columns
X_data = pd.concat([X_data, control_data], axis=1)

fourier_features: pd.DataFrame | None = None
if self.yearly_seasonality is not None:
fourier_features = self._get_fourier_models_data(X=X)
self.fourier_columns = fourier_features.columns
coords["fourier_mode"] = fourier_features.columns.to_numpy()
X_data = pd.concat([X_data, fourier_features], axis=1)

self.model_coords = coords
if self.validate_data:
self.validate("X", X_data)
@@ -452,38 +451,19 @@ def build_model(

mu_var += control_contributions.sum(axis=-1)

if (
hasattr(self, "fourier_columns")
and self.fourier_columns is not None
and len(self.fourier_columns) > 0
and all(
column in self.preprocessed_data["X"].columns
for column in self.fourier_columns
)
):
fourier_data_ = pm.Data(
name="fourier_data",
value=self.preprocessed_data["X"][self.fourier_columns],
dims=("date", "fourier_mode"),
mutable=True,
)
if self.model_config["gamma_fourier"].dims != ("fourier_mode",):
self.model_config["gamma_fourier"].dims = "fourier_mode"

gamma_fourier = self.model_config["gamma_fourier"].create_variable(
name="gamma_fourier"
)

fourier_contribution = pm.Deterministic(
name="fourier_contributions",
var=fourier_data_ * gamma_fourier,
dims=("date", "fourier_mode"),
wd60622 marked this conversation as resolved.
if self.yearly_seasonality is not None:
dayofyear = pm.Data(
name="dayofyear",
value=self.preprocessed_data["X"][
self.date_column
].dt.dayofyear.to_numpy(),

Collaborator:

Do we check somewhere that the date column indeed has a date type? 🤔

Contributor Author:

I don't think there is a check or coercion anywhere. It seems like the TVP code assumes it is datetime as well.

dims="date",
)

yearly_seasonality_contribution = pm.Deterministic(
name="yearly_seasonality_contribution",
var=fourier_contribution.sum(axis=-1),
dims=("date"),
var=self.yearly_fourier.apply(dayofyear),
dims="date",
)

mu_var += yearly_seasonality_contribution
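
On the reviewer's question above about the date column's dtype: nothing in this diff coerces it, and .dt.dayofyear raises on a plain string column. A purely illustrative guard (not part of this PR, name hypothetical) could look like:

```python
import pandas as pd


def ensure_datetime(X: pd.DataFrame, date_column: str) -> pd.DataFrame:
    """Coerce the date column to datetime, raising if values cannot be parsed."""
    if not pd.api.types.is_datetime64_any_dtype(X[date_column]):
        X = X.assign(**{date_column: pd.to_datetime(X[date_column])})
    return X
```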
@@ -536,36 +516,6 @@ def default_model_config(self) -> dict:
**self.saturation.model_config,
}

def _get_fourier_models_data(self, X) -> pd.DataFrame:
"""Generates fourier modes to model seasonality.

Parameters
----------
X : Union[pd.DataFrame, pd.Series], shape (n_obs, n_features)
Input data for the model. To generate the Fourier modes, it must contain a date column.

Returns
-------
pd.DataFrame
Fourier modes (sin and cos with different frequencies) as columns in a dataframe.

References
----------
https://www.pymc.io/projects/examples/en/latest/time_series/Air_passengers-Prophet_with_Bayesian_workflow.html
"""
if self.yearly_seasonality is None:
raise ValueError("yearly_seasonality must be specified.")
date_data: pd.Series = pd.to_datetime(
arg=X[self.date_column], format="%Y-%m-%d"
)
periods: npt.NDArray[np.float64] = (
date_data.dt.dayofyear.to_numpy() / DAYS_IN_YEAR
)
return generate_fourier_modes(
periods=periods,
n_order=self.yearly_seasonality,
)

def channel_contributions_forward_pass(
self, channel_data: npt.NDArray[np.float64]
) -> npt.NDArray[np.float64]:
@@ -735,8 +685,8 @@ def identity(x):
)
data["control_data"] = control_transformation(control_data)

if hasattr(self, "fourier_columns"):
data["fourier_data"] = self._get_fourier_models_data(X)
if self.yearly_seasonality is not None:
data["dayofyear"] = X[self.date_column].dt.dayofyear.to_numpy()

if self.time_varying_intercept | self.time_varying_media:
data["time_index"] = infer_time_index(
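Pulling the pieces of this diff together, a hedged, self-contained sketch of the new wiring in plain PyMC (without the MMM class or YearlyFourier itself): day-of-year enters as pm.Data, the Fourier coefficients get a Laplace prior over the mode dimension, and the summed contribution is stored as a Deterministic over "date", matching the block added to build_model. All names, coordinate labels, and the n_order=2 choice are illustrative.

```python
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt

DAYS_IN_YEAR = 365.25
n_order = 2
dates = pd.date_range("2024-01-01", periods=365, freq="D")

coords = {
    "date": dates,
    "fourier_mode": [f"{fn}_{k}" for fn in ("sin", "cos") for k in range(1, n_order + 1)],
}
with pm.Model(coords=coords):
    dayofyear = pm.Data("dayofyear", dates.dayofyear.to_numpy(), dims="date")
    angles = 2 * np.pi * pt.arange(1, n_order + 1) * dayofyear[:, None] / DAYS_IN_YEAR
    modes = pt.concatenate([pt.sin(angles), pt.cos(angles)], axis=1)
    gamma_fourier = pm.Laplace("gamma_fourier", mu=0, b=1, dims="fourier_mode")
    pm.Deterministic(
        "yearly_seasonality_contribution",
        (modes * gamma_fourier).sum(axis=-1),
        dims="date",
    )
```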