Skip to content

Commit

Permalink
DelayedSaturatedMMM deprecations and moving files (#965)
Browse files Browse the repository at this point in the history
* deprecations and moving files

* Update UML Diagrams

* change the imports in notebooks

* push up the code / test changes. need to run

* remove _get_\w*_function tests

* rerun the tvp notebook

* remove stale test

* move away from string initialization

* change the tvp media example
  • Loading branch information
wd60622 authored Aug 24, 2024
1 parent 7d3b832 commit 288b8e8
Show file tree
Hide file tree
Showing 16 changed files with 393 additions and 740 deletions.
9 changes: 5 additions & 4 deletions docs/source/notebooks/mmm/mmm_budget_allocation_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from pymc_marketing.mmm.delayed_saturated_mmm import MMM\n",
"from pymc_marketing.mmm import MMM\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
Expand Down Expand Up @@ -89,7 +89,7 @@
"Once the model has been trained, it is easy to save for later use. An example of the \".save\" method is demonstrated below to store the model at a designated [location](https://github.com/pymc-labs/pymc-marketing/tree/main/data).\n",
"\n",
"## Loading a Pre-Trained Model\n",
"To utilize a saved model, load it into a new instance of the DelayedSaturatedMMM class using the load method below."
"To utilize a saved model, load it into a new instance of the MMM class using the load method below."
]
},
{
Expand Down Expand Up @@ -1738,7 +1738,8 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
Expand All @@ -1755,5 +1756,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 0
"nbformat_minor": 4
}
30 changes: 14 additions & 16 deletions docs/source/notebooks/mmm/mmm_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
"import pymc as pm\n",
"import seaborn as sns\n",
"\n",
"from pymc_marketing.mmm.delayed_saturated_mmm import MMM\n",
"from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation\n",
"from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation\n",
"\n",
"warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
Expand Down Expand Up @@ -979,15 +979,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We can specify the model structure using the {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` class. This class, handles a lot of internal boilerplate code for us such us scaling the data (see details below) and handy diagnostics and reporting plots. One great feature is that we can specify the channel priors distributions ourselves, which fundamental component of the [bayesian workflow](https://arxiv.org/abs/2011.01808) as we can incorporate our prior knowledge into the model. This is one of the most important advantages of using a bayesian approach. Let's see how we can do it.\n",
"We can specify the model structure using the {class}`MMM <pymc_marketing.mmm.mmm.MMM>` class. This class, handles a lot of internal boilerplate code for us such us scaling the data (see details below) and handy diagnostics and reporting plots. One great feature is that we can specify the channel priors distributions ourselves, which fundamental component of the [bayesian workflow](https://arxiv.org/abs/2011.01808) as we can incorporate our prior knowledge into the model. This is one of the most important advantages of using a bayesian approach. Let's see how we can do it.\n",
"\n",
"As we do not know much more about the channels, we start with a simple heuristic: \n",
"\n",
"1. The channel contributions should be positive, so we can for example use a {class}`HalfNormal <pymc.distributions.continuous.HalfNormal>` distribution as prior. We need to set the `sigma` parameter per channel. The higher the `sigma`, the more \"freedom\" it has to fit the data. To specify `sigma` we can use the following point.\n",
"\n",
"2. We expect channels where we spend the most to have more attributed sales , before seeing the data. This is a very reasonable assumption (note that we are not imposing anything at the level of efficiency!).\n",
"\n",
"How to incorporate this heuristic into the model? To begin with, it is important to note that the {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` class scales the target and input variables through an [`MaxAbsScaler`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html) transformer from [`scikit-learn`](https://scikit-learn.org/stable/), its important to specify the priors in the scaled space (i.e. between 0 and 1). One way to do it is to use the spend share as the `sigma` parameter for the `HalfNormal` distribution. We can actually add a scaling factor to take into account the support of the distribution.\n",
"How to incorporate this heuristic into the model? To begin with, it is important to note that the {class}`MMM <pymc_marketing.mmm.mmm.MMM>` class scales the target and input variables through an [`MaxAbsScaler`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html) transformer from [`scikit-learn`](https://scikit-learn.org/stable/), its important to specify the priors in the scaled space (i.e. between 0 and 1). One way to do it is to use the spend share as the `sigma` parameter for the `HalfNormal` distribution. We can actually add a scaling factor to take into account the support of the distribution.\n",
"\n",
"First, let's compute the share of spend per channel:"
]
Expand Down Expand Up @@ -1072,7 +1072,7 @@
"source": [
"You can use the optional parameter 'model_config' to apply your own priors to the model. Each entry in the 'model_config' contains a key that corresponds to a registered distribution name in our model. The value of the key is a dictionary that describes the input parameters of that specific distribution.\n",
"\n",
"If you're unsure how to define your own priors, you can use the 'default_model_config' property of {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` to see the required structure."
"If you're unsure how to define your own priors, you can use the 'default_model_config' property of {class}`MMM <pymc_marketing.mmm.mmm.MMM>` to see the required structure."
]
},
{
Expand Down Expand Up @@ -1101,9 +1101,8 @@
"dummy_model = MMM(\n",
" date_column=\"\",\n",
" channel_columns=[\"\"],\n",
" adstock=\"geometric\",\n",
" saturation=\"logistic\",\n",
" adstock_max_lag=4,\n",
" adstock=GeometricAdstock(l_max=4),\n",
" saturation=LogisticSaturation(),\n",
")\n",
"dummy_model.default_model_config"
]
Expand Down Expand Up @@ -1150,14 +1149,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"**Remark:** For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` class has some default priors that you can use as a starting point."
"**Remark:** For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the {class}`MMM <pymc_marketing.mmm.mmm.MMM>` class has some default priors that you can use as a starting point."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Model sampler allows specifying set of parameters that will be passed to fit the same way as the `kwargs` are getting passed so far. It doesn't disable the fit kwargs, but rather extend them, to enable customizable and preservable configuration. By default the sampler_config for {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` is empty. But if you'd like to use it, you can define it like showed below: "
"Model sampler allows specifying set of parameters that will be passed to fit the same way as the `kwargs` are getting passed so far. It doesn't disable the fit kwargs, but rather extend them, to enable customizable and preservable configuration. By default the sampler_config for {class}`MMM <pymc_marketing.mmm.mmm.MMM>` is empty. But if you'd like to use it, you can define it like showed below: "
]
},
{
Expand All @@ -1173,7 +1172,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we are ready to use the {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` class to define the model."
"Now we are ready to use the {class}`MMM <pymc_marketing.mmm.mmm.MMM>` class to define the model."
]
},
{
Expand All @@ -1186,15 +1185,14 @@
" model_config=my_model_config,\n",
" sampler_config=my_sampler_config,\n",
" date_column=\"date_week\",\n",
" adstock=\"geometric\",\n",
" saturation=\"logistic\",\n",
" adstock=GeometricAdstock(l_max=8),\n",
" saturation=LogisticSaturation(),\n",
" channel_columns=[\"x1\", \"x2\"],\n",
" control_columns=[\n",
" \"event_1\",\n",
" \"event_2\",\n",
" \"t\",\n",
" ],\n",
" adstock_max_lag=8,\n",
" yearly_seasonality=2,\n",
")"
]
Expand Down Expand Up @@ -6348,7 +6346,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The {func}`fit_result <pymc_marketing.mmm.delayed_saturated_mmm.MMM.fit_result>` attribute contains the `pymc` trace object."
"The {func}`fit_result <pymc_marketing.mmm.mmm.MMM.fit_result>` attribute contains the `pymc` trace object."
]
},
{
Expand Down Expand Up @@ -9400,7 +9398,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The results look great! We therefore successfully recovered the true values from the data generation process. We have also seen how easy is to use the {class}`MMM <pymc_marketing.mmm.delayed_saturated_mmm.MMM>` class to fit media mix models! It takes over the model specification and the media transformations, while having all the flexibility of `pymc`!"
"The results look great! We therefore successfully recovered the true values from the data generation process. We have also seen how easy is to use the {class}`MMM <pymc_marketing.mmm.mmm.MMM>` class to fit media mix models! It takes over the model specification and the media transformations, while having all the flexibility of `pymc`!"
]
},
{
Expand Down Expand Up @@ -10443,7 +10441,7 @@
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand Down
11 changes: 5 additions & 6 deletions docs/source/notebooks/mmm/mmm_lift_test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"import pandas as pd\n",
"import pymc as pm\n",
"\n",
"from pymc_marketing.mmm import MMM\n",
"from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation\n",
"from pymc_marketing.mmm.transformers import logistic_saturation"
]
},
Expand Down Expand Up @@ -228,9 +228,8 @@
"mmm = MMM(\n",
" date_column=\"date\",\n",
" channel_columns=[\"channel 1\", \"channel 2\"],\n",
" adstock_max_lag=6,\n",
" adstock=\"geometric\",\n",
" saturation=\"logistic\",\n",
" adstock=GeometricAdstock(l_max=6),\n",
" saturation=LogisticSaturation(),\n",
")"
]
},
Expand Down Expand Up @@ -1795,7 +1794,7 @@
],
"source": [
"%load_ext watermark\n",
"%watermark -n -u -v -iv -w -p pymc_marketing,pytensor"
"%watermark -n -u -v -iv -w -p pymc_marketing -p pytensor"
]
}
],
Expand All @@ -1815,7 +1814,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks/mmm/mmm_roas.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"import seaborn as sns\n",
"\n",
"from pymc_marketing.hsgp_kwargs import HSGPKwargs\n",
"from pymc_marketing.mmm.delayed_saturated_mmm import (\n",
"from pymc_marketing.mmm import (\n",
" MMM,\n",
" GeometricAdstock,\n",
" LogisticSaturation,\n",
Expand Down
24 changes: 10 additions & 14 deletions docs/source/notebooks/mmm/mmm_time_varying_media_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
"import pymc as pm\n",
"import seaborn as sns\n",
"\n",
"from pymc_marketing.mmm import MMM\n",
"from pymc_marketing.mmm import MMM, GeometricAdstock, MichaelisMentenSaturation\n",
"from pymc_marketing.prior import Prior\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
Expand Down Expand Up @@ -292,10 +292,9 @@
" date_column=\"date_week\",\n",
" channel_columns=[\"x1\", \"x2\"],\n",
" control_columns=[\"event_1\", \"event_2\"],\n",
" adstock_max_lag=adstock_max_lag,\n",
" yearly_seasonality=yearly_seasonality,\n",
" adstock=\"geometric\",\n",
" saturation=\"michaelis_menten\",\n",
" adstock=GeometricAdstock(l_max=adstock_max_lag),\n",
" saturation=MichaelisMentenSaturation(),\n",
" time_varying_media=True,\n",
")"
]
Expand Down Expand Up @@ -4443,10 +4442,9 @@
" date_column=\"date_week\",\n",
" channel_columns=[\"x1\", \"x2\"],\n",
" control_columns=[\"event_1\", \"event_2\"],\n",
" adstock_max_lag=adstock_max_lag,\n",
" yearly_seasonality=yearly_seasonality,\n",
" adstock=\"geometric\",\n",
" saturation=\"michaelis_menten\",\n",
" adstock=GeometricAdstock(l_max=adstock_max_lag),\n",
" saturation=MichaelisMentenSaturation(),\n",
")\n",
"\n",
"basic_mmm.fit(\n",
Expand Down Expand Up @@ -4686,10 +4684,9 @@
" date_column=\"date_week\",\n",
" channel_columns=[\"x1\", \"x2\"],\n",
" control_columns=[\"event_1\", \"event_2\"],\n",
" adstock_max_lag=adstock_max_lag,\n",
" yearly_seasonality=yearly_seasonality,\n",
" adstock=\"geometric\",\n",
" saturation=\"michaelis_menten\",\n",
" adstock=GeometricAdstock(l_max=adstock_max_lag),\n",
" saturation=MichaelisMentenSaturation(),\n",
" time_varying_media=True,\n",
")"
]
Expand Down Expand Up @@ -9385,10 +9382,9 @@
" date_column=\"date_week\",\n",
" channel_columns=[\"x1\", \"x2\"],\n",
" control_columns=[\"event_1\", \"event_2\"],\n",
" adstock_max_lag=adstock_max_lag,\n",
" yearly_seasonality=yearly_seasonality,\n",
" adstock=\"geometric\",\n",
" saturation=\"michaelis_menten\",\n",
" adstock=GeometricAdstock(l_max=adstock_max_lag),\n",
" saturation=MichaelisMentenSaturation(),\n",
" time_varying_media=True,\n",
")"
]
Expand Down Expand Up @@ -9444,7 +9440,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
767 changes: 342 additions & 425 deletions docs/source/notebooks/mmm/mmm_tvp_example.ipynb

Large diffs are not rendered by default.

Binary file modified docs/source/uml/classes_mmm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/uml/packages_mmm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 3 additions & 4 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Marketing Mix Models (MMM)."""

from pymc_marketing.mmm import base, delayed_saturated_mmm, preprocessing, validating
from pymc_marketing.mmm import base, mmm, preprocessing, validating
from pymc_marketing.mmm.base import BaseValidateMMM, MMMModelBuilder
from pymc_marketing.mmm.components.adstock import (
AdstockTransformation,
Expand All @@ -37,8 +37,8 @@
register_saturation_transformation,
saturation_from_dict,
)
from pymc_marketing.mmm.delayed_saturated_mmm import MMM, DelayedSaturatedMMM
from pymc_marketing.mmm.fourier import MonthlyFourier, YearlyFourier
from pymc_marketing.mmm.mmm import MMM
from pymc_marketing.mmm.preprocessing import (
preprocessing_method_X,
preprocessing_method_y,
Expand All @@ -49,7 +49,6 @@
"AdstockTransformation",
"BaseValidateMMM",
"DelayedAdstock",
"DelayedSaturatedMMM",
"GeometricAdstock",
"HillSaturation",
"HillSaturationSigmoid",
Expand All @@ -71,7 +70,7 @@
"register_adstock_transformation",
"YearlyFourier",
"base",
"delayed_saturated_mmm",
"mmm",
"preprocessing",
"preprocessing_method_X",
"preprocessing_method_y",
Expand Down
39 changes: 0 additions & 39 deletions pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def function(self, x, alpha):
"""

import warnings

import numpy as np
import xarray as xr
from pydantic import Field, InstanceOf, validate_call
Expand Down Expand Up @@ -345,40 +343,3 @@ def adstock_from_dict(data: dict) -> AdstockTransformation:
if "priors" in data:
data["priors"] = {k: Prior.from_json(v) for k, v in data["priors"].items()}
return cls(**data)


def _get_adstock_function(
function: str | AdstockTransformation,
**kwargs,
) -> AdstockTransformation:
"""Get an adstock function.
Helper for use in the MMM to get an adstock function from the if registered.
"""
if isinstance(function, AdstockTransformation):
return function

elif isinstance(function, str):
if function not in ADSTOCK_TRANSFORMATIONS:
raise ValueError(
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
)

if kwargs:
msg = (
"The preferred method of initializing a "
"lagging function is to use the class directly. "
"String support will deprecate in 0.9.0."
)
warnings.warn(
msg,
DeprecationWarning,
stacklevel=1,
)

return ADSTOCK_TRANSFORMATIONS[function](**kwargs)

else:
raise ValueError(
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
)
19 changes: 0 additions & 19 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,22 +480,3 @@ def saturation_from_dict(data: dict) -> SaturationTransformation:
key: Prior.from_json(value) for key, value in data["priors"].items()
}
return cls(**data)


def _get_saturation_function(
function: str | SaturationTransformation,
) -> SaturationTransformation:
"""
Get a saturation function.
Helper for use in the MMM to get a saturation function.
"""
if isinstance(function, SaturationTransformation):
return function

if function not in SATURATION_TRANSFORMATIONS:
raise ValueError(
f"Unknown saturation function: {function}. Choose from {list(SATURATION_TRANSFORMATIONS.keys())}"
)

return SATURATION_TRANSFORMATIONS[function]()
Loading

0 comments on commit 288b8e8

Please sign in to comment.