Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MMM NB Improvements (waterfall & error plots) #664

Merged
merged 30 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a889a22
model spec
juanitorduz May 2, 2024
0f18e5c
changes init
juanitorduz May 2, 2024
7aa5b34
improveements
juanitorduz May 2, 2024
621e4bf
try othere way of adding links
juanitorduz May 2, 2024
23269d5
make color cohorent with the color palette
juanitorduz May 2, 2024
73367ba
add link to clases
juanitorduz May 2, 2024
02b55e8
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 2, 2024
41aea18
add new spends plot
juanitorduz May 2, 2024
1811a7a
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 2, 2024
5540b0e
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 2, 2024
3b47eeb
add feedback part 1
juanitorduz May 2, 2024
26df3cf
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 3, 2024
f3cb12b
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 3, 2024
57c9bc0
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 6, 2024
792ea92
add errors plot
juanitorduz May 7, 2024
14680ed
Merge branch 'mmm_nb_improvements' of https://github.com/pymc-labs/py…
juanitorduz May 7, 2024
029da1f
typo
juanitorduz May 7, 2024
2111c25
Update pymc_marketing/mmm/base.py
juanitorduz May 7, 2024
d573082
modularize code
juanitorduz May 7, 2024
afab68e
Merge branch 'mmm_nb_improvements' of https://github.com/pymc-labs/py…
juanitorduz May 7, 2024
78b8627
clean code
juanitorduz May 7, 2024
0ade5f8
add some initial tests
juanitorduz May 7, 2024
771dec8
fix tests
juanitorduz May 7, 2024
a32d7d7
git test base class
juanitorduz May 7, 2024
e822a8b
improvee broadcasting
juanitorduz May 7, 2024
2de1778
add more tests
juanitorduz May 7, 2024
62acfa8
add errors formula
juanitorduz May 7, 2024
b33eeab
fix test
juanitorduz May 7, 2024
66b24ac
make dims consistent
juanitorduz May 8, 2024
2aef35f
Merge branch 'main' into mmm_nb_improvements
juanitorduz May 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/notebooks/general/other_nuts_samplers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"id": "51e3591e",
"metadata": {},
"source": [
"(other_nuts_samplers)=\n",
"# Other NUTS Samplers\n",
"\n",
"In this notebook we show how to fit a CLV model with other NUTS samplers. These alternative samplers can be significantly faster and also sample on the GPU.\n",
Expand Down
2,901 changes: 1,581 additions & 1,320 deletions docs/source/notebooks/mmm/mmm_example.ipynb

Large diffs are not rendered by default.

117 changes: 115 additions & 2 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from pymc_marketing.mmm.budget_optimizer import budget_allocator
from pymc_marketing.mmm.transformers import michaelis_menten
from pymc_marketing.mmm.utils import (
apply_sklearn_transformer_across_dim,
estimate_menten_parameters,
estimate_sigmoid_parameters,
find_sigmoid_inflection_point,
Expand Down Expand Up @@ -337,6 +338,19 @@
def plot_posterior_predictive(
self, original_scale: bool = False, ax: plt.Axes = None, **plt_kwargs: Any
) -> plt.Figure:
"""Plot posterior distribution from the model fit.

Parameters
----------
original_scale : bool, optional
Whether to plot in the original scale.
ax : plt.Axes, optional
Matplotlib axis object.
wd60622 marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
plt.Figure
"""
posterior_predictive_data: Dataset = self.posterior_predictive
likelihood_hdi_94: DataArray = az.hdi(
ary=posterior_predictive_data, hdi_prob=0.94
Expand Down Expand Up @@ -394,7 +408,9 @@
np.asarray(posterior_predictive_data.date),
target_to_plot,
color="black",
label="Observed",
)
ax.legend()
ax.set(
title="Posterior Predictive Check",
xlabel="date",
Expand All @@ -404,6 +420,97 @@
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
return fig

def plot_errors(
    self, original_scale: bool = False, ax: plt.Axes = None, **plt_kwargs: Any
) -> plt.Figure:
    """Plot model errors by taking the difference between true values and predicted.

    The errors are computed as ``target - posterior_predictive`` where
    ``target`` is ``self.y`` passed through the target transformer, so the
    subtraction happens in the transformed scale. The plot shows the
    94% and 50% HDI bands of the errors, their mean over the ``chain`` and
    ``draw`` dimensions, and a horizontal reference line at zero.

    Parameters
    ----------
    original_scale : bool, optional
        Whether to plot in the original scale. When ``True`` the errors and
        their HDI bounds are passed through the target transformer's
        ``inverse_transform`` before plotting.
    ax : plt.Axes, optional
        Matplotlib axis object. If not given, a new figure and axis are
        created with ``plt.subplots(**plt_kwargs)``.
    **plt_kwargs : Any
        Keyword arguments forwarded to ``plt.subplots``; only used when
        ``ax`` is ``None``.

    Returns
    -------
    plt.Figure
        The figure that contains the errors plot.

    Raises
    ------
    RuntimeError
        If ``self.X`` or ``self.y`` is ``None`` (the model hasn't been fit).
    ValueError
        If the length of the transformed target does not match the length
        of the ``date`` coordinate of the posterior predictive samples.
    """
    posterior_predictive_data: Dataset = self.posterior_predictive

    # Fail fast: checking here (instead of after all the computations) keeps
    # the intended error message reachable when `self.y` is missing and
    # avoids creating a figure that would be left dangling on failure.
    if self.X is None or self.y is None:
        raise RuntimeError("The model hasn't been fit yet, call .fit() first")

    target = np.asarray(
        transform_1d_array(self.get_target_transformer().transform, self.y)
    )

    if len(target) != len(posterior_predictive_data.date):
        raise ValueError(
            "The length of the target variable doesn't match the length of the date column. "
            "If you are computing out-of-sample errors, please overwrite `self.y` with the "
            "corresponding (non-transformed) target variable."
        )

    # Prepend two singleton axes so the 1-d target broadcasts against the
    # posterior predictive samples (presumably laid out as
    # (chain, draw, date) -- confirm against the dataset dims).
    target_broadcast = np.atleast_1d(target)[np.newaxis, np.newaxis, ...]
    errors = target_broadcast - posterior_predictive_data

    errors_hdi_94: DataArray = az.hdi(ary=errors, hdi_prob=0.94)[self.output_var]
    errors_hdi_50: DataArray = az.hdi(ary=errors, hdi_prob=0.50)[self.output_var]

    if original_scale:
        # The HDIs were computed in the transformed scale above; both the
        # raw errors and the HDI bounds are mapped back with the target
        # transformer's inverse.
        errors = apply_sklearn_transformer_across_dim(
            data=errors,
            func=self.get_target_transformer().inverse_transform,
            dim_name="date",
        )

        errors_hdi_94 = self.get_target_transformer().inverse_transform(
            Xt=errors_hdi_94
        )
        errors_hdi_50 = self.get_target_transformer().inverse_transform(
            Xt=errors_hdi_50
        )

    if ax is None:
        fig, ax = plt.subplots(**plt_kwargs)
    else:
        fig = ax.figure

    # Column 0 and column 1 of each HDI array are the two edges of the band.
    # Raw strings keep the literal backslash for the `\%` label text without
    # relying on an invalid escape sequence.
    ax.fill_between(
        x=posterior_predictive_data.date,
        y1=errors_hdi_94[:, 0],
        y2=errors_hdi_94[:, 1],
        color="C3",
        alpha=0.2,
        label=r"$94\%$ HDI",
    )

    ax.fill_between(
        x=posterior_predictive_data.date,
        y1=errors_hdi_50[:, 0],
        y2=errors_hdi_50[:, 1],
        color="C3",
        alpha=0.3,
        label=r"$50\%$ HDI",
    )

    ax.plot(
        posterior_predictive_data.date,
        errors[self.output_var].mean(dim=("chain", "draw")).to_numpy(),
        color="C3",
        label="Errors Mean",
    )

    ax.axhline(y=0.0, linestyle="--", color="black", label="zero")
    ax.legend()
    ax.set(
        title="Errors Posterior Distribution",
        xlabel="date",
        ylabel="true - predictions",
    )
    return fig

def _format_model_contributions(self, var_contribution: str) -> DataArray:
contributions = az.extract(
self.fit_result,
Expand Down Expand Up @@ -1411,14 +1518,20 @@
cumulative_contribution = 0

for index, row in dataframe.iterrows():
color = "lightblue" if row["contribution"] >= 0 else "salmon"
color = "C0" if row["contribution"] >= 0 else "C3"

bar_start = (
cumulative_contribution + row["contribution"]
if row["contribution"] < 0
else cumulative_contribution
)
ax.barh(row["component"], row["contribution"], left=bar_start, color=color)
ax.barh(
row["component"],
row["contribution"],
left=bar_start,
color=color,
alpha=0.5,
)

if row["contribution"] > 0:
cumulative_contribution += row["contribution"]
Expand Down
3 changes: 3 additions & 0 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
("plot_posterior_predictive", {}),
("plot_posterior_predictive", {"original_scale": True}),
("plot_posterior_predictive", {"ax": plt.subplots()[1]}),
("plot_errors", {}),
("plot_errors", {"original_scale": True}),
("plot_errors", {"ax": plt.subplots()[1]}),
("plot_components_contributions", {}),
("plot_channel_parameter", {"param_name": "alpha"}),
("plot_waterfall_components_decomposition", {"original_scale": True}),
Expand Down
Loading