Skip to content

Commit

Permalink
Media transformation sampling & plotting methods (#734)
Browse files Browse the repository at this point in the history
* add plotting methods

* add tests for new methods

* saturation support for additional variable dims

* consolidate the logic of sampling

* change warning

* workflow from a fitted model

* change order of tests

* suggestion to use names

* because of new data

---------

Co-authored-by: Juan Orduz <[email protected]>
  • Loading branch information
2 people authored and twiecki committed Sep 10, 2024
1 parent 5768e51 commit bbf7adf
Show file tree
Hide file tree
Showing 9 changed files with 746 additions and 28 deletions.
99 changes: 99 additions & 0 deletions pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,35 @@
.. code-block:: python
from pymc_marketing.mmm import AdstockTransformation
class MyAdstock(AdstockTransformation):
def function(self, x, alpha):
return x * alpha
default_priors = {"alpha": {"dist": "HalfNormal", "kwargs": {"sigma": 1}}}
Plot the default priors for an adstock transformation:
.. code-block:: python
from pymc_marketing.mmm import GeometricAdstock
import matplotlib.pyplot as plt
adstock = GeometricAdstock(l_max=15)
prior = adstock.sample_prior()
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve)
plt.show()
"""

import warnings

import numpy as np
import xarray as xr

from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
ConvMode,
Expand Down Expand Up @@ -74,12 +93,62 @@ def __init__(

super().__init__(priors=priors, prefix=prefix)

def sample_curve(
self,
parameters: xr.Dataset,
amount: float = 1.0,
) -> xr.DataArray:
"""Sample the adstock transformation given parameters.
Parameters
----------
parameters : xr.Dataset
Dataset with parameter values.
amount : float, optional
Amount to apply the adstock transformation to, by default 1.0.
Returns
-------
xr.DataArray
Adstocked version of the amount.
"""

time_since = np.arange(0, self.l_max)
coords = {
"time since exposure": time_since,
}
x = np.zeros(self.l_max)
x[0] = amount

return self._sample_curve(
var_name="adstock",
parameters=parameters,
x=x,
coords=coords,
)


class GeometricAdstock(AdstockTransformation):
"""Wrapper around geometric adstock function.
For more information, see :func:`pymc_marketing.mmm.transformers.geometric_adstock`.
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import GeometricAdstock
rng = np.random.default_rng(0)
adstock = GeometricAdstock(l_max=10)
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()
"""

lookup_name = "geometric"
Expand All @@ -97,6 +166,21 @@ class DelayedAdstock(AdstockTransformation):
For more information, see :func:`pymc_marketing.mmm.transformers.delayed_adstock`.
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import DelayedAdstock
rng = np.random.default_rng(0)
adstock = DelayedAdstock(l_max=10)
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()
"""

lookup_name = "delayed"
Expand All @@ -122,6 +206,21 @@ class WeibullAdstock(AdstockTransformation):
For more information, see :func:`pymc_marketing.mmm.transformers.weibull_adstock`.
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import WeibullAdstock
rng = np.random.default_rng(0)
adstock = WeibullAdstock(l_max=10, kind="CDF")
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()
"""

lookup_name = "weibull"
Expand Down
Loading

0 comments on commit bbf7adf

Please sign in to comment.