Skip to content

Commit

Permalink
Allow plot MMM components in the original scale (#870)
Browse files Browse the repository at this point in the history
* add original scale implementation

* add plot nb

* change location

* undo

* make mypy happy

* test plot

* add test

* update plot readme

* fix test

* improve variable description
  • Loading branch information
juanitorduz authored Jul 25, 2024
1 parent 9129a9e commit de5679f
Show file tree
Hide file tree
Showing 5 changed files with 1,446 additions and 1,288 deletions.
Binary file modified docs/source/_static/mmm_plot_components_contributions.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2,545 changes: 1,257 additions & 1,288 deletions docs/source/notebooks/mmm/mmm_example.ipynb

Large diffs are not rendered by default.

162 changes: 162 additions & 0 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,168 @@ def plot_channel_parameter(self, param_name: str, **plt_kwargs: Any) -> plt.Figu
)
return fig

def get_ts_contribution_posterior(
    self, var_contribution: str, original_scale: bool = False
) -> DataArray:
    """Return the posterior of the time series contributions of a variable.

    Parameters
    ----------
    var_contribution : str
        Name of the contribution variable to retrieve. It must be a valid
        variable in the `fit_result` attribute.
    original_scale : bool, optional
        If True, the contributions are passed through the target
        transformer's ``inverse_transform`` along the ``"date"`` dimension
        before being returned. Defaults to False, in which case the
        contributions are returned as produced by
        ``_format_model_contributions``.

    Returns
    -------
    DataArray
        The posterior distribution of the time series contributions.
    """
    posterior = self._format_model_contributions(var_contribution=var_contribution)

    # Nothing to undo when the caller wants the model (scaled) space.
    if not original_scale:
        return posterior

    inverse_transform = self.get_target_transformer().inverse_transform
    return apply_sklearn_transformer_across_dim(
        data=posterior,
        func=inverse_transform,
        dim_name="date",
    )

def plot_components_contributions(
    self, original_scale: bool = False, **plt_kwargs: Any
) -> plt.Figure:
    """Plot the target variable and the posterior predictive model components.

    For every available component (channels, and optionally controls and
    yearly seasonality) the posterior mean and its 94% HDI are drawn,
    followed by the intercept and the observed target.

    Parameters
    ----------
    original_scale : bool, optional
        Whether to plot in the original scale of the target instead of the
        scaled space. Defaults to False.
    **plt_kwargs
        Additional keyword arguments to pass to `plt.subplots`.

    Returns
    -------
    plt.Figure
        The figure holding the components plot. If ``self.X`` is None the
        figure is returned without anything drawn on it.

    Raises
    ------
    ValueError
        If the extracted intercept has neither 2 nor 3 dimensions.
    """
    # NOTE: the legend labels below hard-code "94%"; keep them in sync.
    hdi_prob = 0.94

    # (legend label, posterior variable name). Each label travels with its
    # own variable so that a model with seasonality but no controls is not
    # mislabelled, as happened when labels were paired by position.
    components = [("channel_contribution", "channel_contributions")]
    optional_components = (
        ("control_columns", "control_contribution", "control_contributions"),
        (
            "yearly_seasonality",
            "fourier_contribution",
            "fourier_contributions",
        ),
    )
    components.extend(
        (label, var_name)
        for attr, label, var_name in optional_components
        if getattr(self, attr, None)
    )

    summaries = []
    for label, var_name in components:
        posterior = self.get_ts_contribution_posterior(
            var_contribution=var_name, original_scale=original_scale
        )
        summaries.append(
            (
                label,
                posterior.mean(["chain", "draw"]),
                az.hdi(posterior, hdi_prob=hdi_prob)[var_name],
            )
        )

    fig, ax = plt.subplots(**plt_kwargs)

    # Everything drawn below needs the date axis taken from X.
    if self.X is None:
        return fig

    dates = self.X[self.date_column]
    dates_array = np.asarray(dates)

    for color_idx, (label, mean, hdi) in enumerate(summaries):
        color = f"C{color_idx}"
        ax.fill_between(
            x=dates,
            y1=hdi.isel(hdi=0),
            y2=hdi.isel(hdi=1),
            color=color,
            alpha=0.25,
            label=f"$94\\%$ HDI ({label})",
        )
        ax.plot(dates_array, np.asarray(mean), color=color)

    intercept = az.extract(self.fit_result, var_names=["intercept"], combined=False)

    if original_scale:
        intercept = apply_sklearn_transformer_across_dim(
            data=intercept,
            func=self.get_target_transformer().inverse_transform,
            dim_name="chain",
        )

    if intercept.ndim == 2:
        # Intercept has a stationary prior: one interval and one mean,
        # repeated for every date.
        intercept_hdi = np.repeat(
            a=az.hdi(intercept, hdi_prob=hdi_prob).intercept.data[None, ...],
            repeats=dates.shape[0],
            axis=0,
        )
        intercept_mean = np.full(len(dates), intercept.mean().data)
    elif intercept.ndim == 3:
        # Intercept has a time-varying prior: keep the remaining (time)
        # dimension so the mean line follows the HDI band instead of
        # collapsing to a single flat value.
        intercept_hdi = az.hdi(intercept, hdi_prob=hdi_prob).intercept.data
        intercept_mean = np.asarray(intercept.mean(["chain", "draw"]))
    else:
        raise ValueError(
            "Expected the intercept posterior to have 2 or 3 dimensions, "
            f"got {intercept.ndim}."
        )

    # The intercept takes the next colour after the last component.
    intercept_color = f"C{len(summaries)}"
    ax.plot(dates_array, intercept_mean, color=intercept_color)
    ax.fill_between(
        x=dates,
        y1=intercept_hdi[:, 0],
        y2=intercept_hdi[:, 1],
        color=intercept_color,
        alpha=0.25,
        label="$94\\%$ HDI (intercept)",
    )

    y_to_plot = (
        self.get_target_transformer().inverse_transform(
            np.asarray(self.preprocessed_data["y"]).reshape(-1, 1)
        )
        if original_scale
        else np.asarray(self.preprocessed_data["y"])
    )

    ylabel = self.output_var if original_scale else f"{self.output_var} scaled"

    ax.plot(
        dates_array,
        y_to_plot,
        label=ylabel,
        color="black",
    )
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.1), ncol=3)
    ax.set(
        title="Posterior Predictive Model Components",
        xlabel="date",
        ylabel=ylabel,
    )
    return fig

def plot_channel_contributions_grid(
self,
start: float,
Expand Down
25 changes: 25 additions & 0 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,31 @@ def test_channel_contributions_forward_pass_recovers_contribution(
y=mmm_fitted.y.max(),
)

@pytest.mark.parametrize(
    "original_scale",
    [False, True],
    ids=["scaled", "original-scale"],
)
@pytest.mark.parametrize(
    "var_contribution",
    ["channel_contributions", "control_contributions"],
    ids=["channel_contribution", "control_contribution"],
)
def test_get_ts_contribution_posterior(
    self,
    mmm_fitted_with_posterior_predictive: MMM,
    var_contribution: str,
    original_scale: bool,
) -> None:
    """Check dims and sample sizes of the contribution posterior in both scales."""
    model = mmm_fitted_with_posterior_predictive
    posterior = model.get_ts_contribution_posterior(
        var_contribution=var_contribution,
        original_scale=original_scale,
    )

    # Rescaling must not reorder or drop dimensions.
    assert posterior.dims == ("chain", "draw", "date")
    # NOTE(review): presumably the fixture samples 1 chain x 500 draws —
    # confirm against the fixture definition if these numbers change.
    assert posterior.chain.size == 1
    assert posterior.draw.size == 500

@pytest.mark.parametrize(
argnames="original_scale",
argvalues=[False, True],
Expand Down
2 changes: 2 additions & 0 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ def mock_fitted_mmm(mock_mmm, toy_X, toy_y):
("plot_direct_contribution_curves", {"same_axes": True}),
("plot_direct_contribution_curves", {"channels": ["channel_2"]}),
("plot_channel_parameter", {"param_name": "adstock_alpha"}),
("plot_components_contributions", {}),
("plot_components_contributions", {"original_scale": True}),
],
)
def test_delayed_saturated_mmm_plots(
Expand Down

0 comments on commit de5679f

Please sign in to comment.