From 1981a785292be82395284cf0a8ee9f0c34a9e86c Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 11 Jun 2024 14:16:44 +0200 Subject: [PATCH 1/9] add plotting methods --- pymc_marketing/mmm/components/adstock.py | 92 ++++++++++ pymc_marketing/mmm/components/base.py | 175 ++++++++++++++++++-- pymc_marketing/mmm/components/saturation.py | 139 +++++++++++++++- pymc_marketing/mmm/delayed_saturated_mmm.py | 2 +- pymc_marketing/mmm/transformers.py | 16 +- tests/mmm/components/test_adstock.py | 2 +- tests/mmm/components/test_base.py | 2 +- tests/mmm/components/test_saturation.py | 2 +- 8 files changed, 406 insertions(+), 24 deletions(-) diff --git a/pymc_marketing/mmm/components/adstock.py b/pymc_marketing/mmm/components/adstock.py index 7d1880e94..dfcc9eb5c 100644 --- a/pymc_marketing/mmm/components/adstock.py +++ b/pymc_marketing/mmm/components/adstock.py @@ -25,6 +25,8 @@ .. code-block:: python + from pymc_marketing.mmm import AdstockTransformation + class MyAdstock(AdstockTransformation): def function(self, x, alpha): return x * alpha @@ -35,6 +37,10 @@ def function(self, x, alpha): import warnings +import numpy as np +import pymc as pm +import xarray as xr + from pymc_marketing.mmm.components.base import Transformation from pymc_marketing.mmm.transformers import ( ConvMode, @@ -74,12 +80,68 @@ def __init__( super().__init__(priors=priors, prefix=prefix) + def sample_curve( + self, + parameters: xr.Dataset, + amount: float = 1.0, + ) -> xr.DataArray: + """Sample the adstock transformation given parameters. + + Parameters + ---------- + parameters : xr.Dataset + Dataset with parameter values. + amount : float, optional + Amount to apply the adstock transformation to, by default 1.0. + + Returns + ------- + xr.DataArray + Adstocked version of the amount. + + """ + + time_since = np.arange(0, self.l_max) + coords = { + "time since exposure": time_since, + } + x = np.zeros(self.l_max) + x[0] = amount + + with pm.Model(coords=coords): + var_name = "adstock" + pm.Deterministic( + var_name, + self.apply(x), + dims="time since exposure", + ) + + return pm.sample_posterior_predictive( + parameters, + var_names=[var_name], + ).posterior_predictive[var_name] + class GeometricAdstock(AdstockTransformation): """Wrapper around geometric adstock function. For more information, see :func:`pymc_marketing.mmm.transformers.geometric_adstock`. + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import GeometricAdstock + + rng = np.random.default_rng(0) + + adstock = GeometricAdstock(l_max=10) + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + """ lookup_name = "geometric" @@ -97,6 +159,21 @@ class DelayedAdstock(AdstockTransformation): For more information, see :func:`pymc_marketing.mmm.transformers.delayed_adstock`. + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import DelayedAdstock + + rng = np.random.default_rng(0) + + adstock = DelayedAdstock(l_max=10) + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + """ lookup_name = "delayed" @@ -122,6 +199,21 @@ class WeibullAdstock(AdstockTransformation): For more information, see :func:`pymc_marketing.mmm.transformers.weibull_adstock`. + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import WeibullAdstock + + rng = np.random.default_rng(0) + + adstock = WeibullAdstock(l_max=10, kind="CDF") + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + """ lookup_name = "weibull" diff --git a/pymc_marketing/mmm/components/base.py b/pymc_marketing/mmm/components/base.py index 6bc2059d8..941806d65 100644 --- a/pymc_marketing/mmm/components/base.py +++ b/pymc_marketing/mmm/components/base.py @@ -24,6 +24,12 @@ from inspect import signature from typing import Any +import arviz as az +import matplotlib.pyplot as plt +import numpy as np +import pymc as pm +import xarray as xr +from pymc.distributions.shape_utils import Dims from pytensor import tensor as pt from pymc_marketing.mmm.utils import _get_distribution_from_dict @@ -212,7 +218,9 @@ def variable_mapping(self) -> dict[str, str]: for parameter in self.default_priors.keys() } - def _create_distributions(self, dim_name: str) -> dict[str, pt.TensorVariable]: + def _create_distributions( + self, dims: Dims | None = None + ) -> dict[str, pt.TensorVariable]: distributions: dict[str, pt.TensorVariable] = {} for parameter_name, variable_name in self.variable_mapping.items(): parameter_prior = self.function_priors[parameter_name] @@ -223,13 +231,159 @@ def _create_distributions(self, dim_name: str) -> dict[str, pt.TensorVariable]: distributions[parameter_name] = distribution( name=variable_name, - dims=dim_name, + dims=dims, **parameter_prior["kwargs"], ) return distributions - def apply(self, x: pt.TensorLike, dim_name: str = "channel") -> pt.TensorVariable: + def sample_prior(self, **sample_prior_predictive_kwargs) -> xr.Dataset: + """Sample the priors for the transformation. + + Parameters + ---------- + **sample_prior_predictive_kwargs + Keyword arguments for the pm.sample_prior_predictive function. + + Returns + ------- + xr.Dataset + The dataset with the sampled priors. + + """ + with pm.Model(): + self._create_distributions() + return pm.sample_prior_predictive(**sample_prior_predictive_kwargs).prior + + def plot_curve( + self, + curve: xr.DataArray, + color: str = "C0", + ax: plt.Axes | None = None, + sample_kwargs: dict | None = None, + hdi_kwargs: dict | None = None, + ) -> plt.Axes: + """Plot curve HDI and samples. + + Parameters + ---------- + curve : xr.DataArray + The curve to plot. + color : str, optional + The color of the curve. Defaults to "C0". + ax : plt.Axes, optional + The axes to plot on. Defaults to None. + sample_kwargs : dict, optional + Keyword arguments for the plot_curve_sample function. Defaults to None. + hdi_kwargs : dict, optional + Keyword arguments for the plot_curve_hdi function. Defaults to None. + + Returns + ------- + plt.Axes + The axes with the plot. + + """ + hdi_kwargs = hdi_kwargs or {} + sample_kwargs = sample_kwargs or {} + ax = self.plot_curve_hdi(curve, color=color, ax=ax, **hdi_kwargs) + ax = self.plot_curve_sample(curve, color=color, ax=ax, **sample_kwargs) + return ax + + def plot_curve_sample( + self, + curve: xr.DataArray, + color: str = "C0", + ax: plt.Axes | None = None, + n: int = 10, + rng: np.random.Generator | None = None, + plot_kwargs: dict | None = None, + ) -> plt.Axes: + """Plot samples from the curve. + + Parameters + ---------- + curve : xr.DataArray + The curve to plot. + color : str, optional + The color of the curve. Defaults to "C0". + ax : plt.Axes, optional + The axes to plot on. Defaults to None. + n : int, optional + The number of samples to plot. Defaults to 10. + rng : np.random.Generator, optional + The random number generator to use. Defaults to None. + plot_kwargs : dict, optional + Keyword arguments for the plot function. Defaults to None. + + Returns + ------- + plt.Axes + The axes with the plot. + + """ + df_curve = curve.to_series().unstack() + + df_sample = df_curve.sample(n=n, random_state=rng) + + ax = ax or plt.gca() + plot_kwargs = plot_kwargs or {} + plot_kwargs["color"] = color + plot_kwargs["alpha"] = plot_kwargs.get("alpha", 0.3) + plot_kwargs["legend"] = False + df_sample.T.plot(ax=ax, **plot_kwargs) + + return ax + + def plot_curve_hdi( + self, + curve: xr.DataArray, + color: str = "C0", + ax: plt.Axes | None = None, + hdi_kwargs: dict | None = None, + plot_kwargs: dict | None = None, + ) -> plt.Axes: + """Plot the HDI of the curve. + + Parameters + ---------- + curve : xr.DataArray + The curve to plot. + color : str, optional + The color of the curve. Defaults to "C0". + ax : plt.Axes, optional + The axes to plot on. Defaults to None. + hdi_kwargs : dict, optional + Keyword arguments for the az.hdi function. Defaults to None. + plot_kwargs : dict, optional + Keyword arguments for the fill_between function. Defaults to None. + + Returns + ------- + plt.Axes + The axes with the plot. + + """ + hdi_kwargs = hdi_kwargs or {} + conf = az.hdi(curve, **hdi_kwargs) + + df_conf = conf[curve.name].to_series().unstack() + + plot_kwargs = plot_kwargs or {} + plot_kwargs["color"] = color + plot_kwargs["alpha"] = plot_kwargs.get("alpha", 0.3) + + ax = ax or plt.gca() + ax.fill_between( + df_conf.index, + df_conf["lower"], + df_conf["higher"], + **plot_kwargs, + ) + + return ax + + def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> pt.TensorVariable: """Called within a model context. Used internally of the MMM to apply the transformation to the data. @@ -238,30 +392,29 @@ def apply(self, x: pt.TensorLike, dim_name: str = "channel") -> pt.TensorVariabl ---------- x : pt.TensorLike The data to be transformed. - dim_name : str, optional - The name of the dimension associated with the columns of the data. - Defaults to "channel". + dims : str, sequence[str], optional + The name of the dimension associated with the columns of the + data. Defaults to None Returns ------- pt.TensorVariable The transformed data. - Examples -------- Call the function for custom use-case - import pymc as pm - .. code-block:: python + import pymc as pm + transformation = ... coords = {"channel": ["TV", "Radio", "Digital"]} with pm.Model(coords=coords): - transformed_data = transformation.apply(data, dim_name="channel") + transformed_data = transformation.apply(data, dims="channel") """ - kwargs = self._create_distributions(dim_name=dim_name) + kwargs = self._create_distributions(dims=dims) return self.function(x, **kwargs) diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index cf83b04f2..b3b9ac6be 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -22,10 +22,10 @@ -------- Create a new saturation transformation: -from pymc_marketing.mmm.components.saturation import SaturationTransformation - .. code-block:: python + from pymc_marketing.mmm import SaturationTransformation + class InfiniteReturns(SaturationTransformation): def function(self, x, b): return b * x @@ -34,6 +34,10 @@ def function(self, x, b): """ +import numpy as np +import pymc as pm +import xarray as xr + from pymc_marketing.mmm.components.base import Transformation from pymc_marketing.mmm.transformers import ( hill_saturation, @@ -69,16 +73,87 @@ class InfiniteReturns(SaturationTransformation): function = infinite_returns default_priors = {"b": {"dist": "HalfNormal", "kwargs": {"sigma": 1}}} + Make use of plotting capabilities to understand the transformation and its + priors + + .. code-block:: python + + import matplotlib.pyplot as plt + import numpy as np + + saturation = InfiniteReturns() + + rng = np.random.default_rng(0) + + prior = saturation.sample_prior(random_seed=rng) + curve = saturation.sample_curve(prior) + saturation.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + """ prefix: str = "saturation" + def sample_curve( + self, + parameters: xr.Dataset, + max_value: float = 1.0, + ) -> xr.DataArray: + """Sample the curve of the saturation transformation given parameters. + + Parameters + ---------- + parameters : xr.Dataset + Dataset with the parameters of the saturation transformation. + max_value : float, optional + Maximum value of the curve, by default 1.0. + + Returns + ------- + xr.DataArray + Curve of the saturation transformation. + + """ + x = np.linspace(0, max_value, 100) + + coords = { + "x": x, + } + + with pm.Model(coords=coords): + var_name = "saturation" + pm.Deterministic( + var_name, + self.apply(x), + dims="x", + ) + + return pm.sample_posterior_predictive( + parameters, + var_names=[var_name], + ).posterior_predictive[var_name] + class LogisticSaturation(SaturationTransformation): """Wrapper around logistic saturation function. For more information, see :func:`pymc_marketing.mmm.transformers.logistic_saturation`. + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import LogisticSaturation + + rng = np.random.default_rng(0) + + adstock = LogisticSaturation() + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + """ lookup_name = "logistic" @@ -97,6 +172,21 @@ class TanhSaturation(SaturationTransformation): For more information, see :func:`pymc_marketing.mmm.transformers.tanh_saturation`. + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import TanhSaturation + + rng = np.random.default_rng(0) + + adstock = TanhSaturation() + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + """ lookup_name = "tanh" @@ -116,6 +206,21 @@ class TanhSaturationBaselined(SaturationTransformation): For more information, see :func:`pymc_marketing.mmm.transformers.tanh_saturation_baselined`. + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import TanhSaturationBaselined + + rng = np.random.default_rng(0) + + adstock = TanhSaturationBaselined() + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + """ lookup_name = "tanh_baselined" @@ -136,6 +241,21 @@ class MichaelisMentenSaturation(SaturationTransformation): For more information, see :func:`pymc_marketing.mmm.transformers.michaelis_menten`. + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import MichaelisMentenSaturation + + rng = np.random.default_rng(0) + + adstock = MichaelisMentenSaturation() + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + """ lookup_name = "michaelis_menten" @@ -153,6 +273,21 @@ class HillSaturation(SaturationTransformation): For more information, see :func:`pymc_marketing.mmm.transformers.hill_saturation`. + .. plot:: + :context: close-figs + + import matplotlib.pyplot as plt + import numpy as np + from pymc_marketing.mmm import HillSaturation + + rng = np.random.default_rng(0) + + adstock = HillSaturation() + prior = adstock.sample_prior(random_seed=rng) + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve, sample_kwargs={"rng": rng}) + plt.show() + """ lookup_name = "hill" diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 1b6b40f6d..2923fa544 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -373,7 +373,7 @@ def forward_pass( else (self.saturation, self.adstock) ) - return second.apply(x=first.apply(x=x)) + return second.apply(x=first.apply(x=x, dims="channel"), dims="channel") def build_model( self, diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py index cfb1a650c..1ce97da4c 100644 --- a/pymc_marketing/mmm/transformers.py +++ b/pymc_marketing/mmm/transformers.py @@ -850,21 +850,23 @@ def hill_saturation( beta: pt.TensorLike, lam: pt.TensorLike, ) -> pt.TensorVariable: - r""" - Hill Saturation Function + r"""Hill Saturation Function + .. math:: - f(x) = \frac{\\sigma}{1 + e^{-\beta(x - \\lambda)}} + f(x) = \frac{\sigma}{1 + e^{-\beta(x - \lambda)}} + where: - - :math:`\\sigma` is the maximum value (upper asymptote), - - :math:`\beta` is the slope parameter, - - :math:`\\lambda` is the transition point on the X-axis, - - :math:`x` is the independent variable. + - :math:`\sigma` is the maximum value (upper asymptote) + - :math:`\beta` is the slope parameter + - :math:`\lambda` is the transition point on the X-axis + - :math:`x` is the independent variable This function computes the Hill sigmoidal response curve, which is commonly used to describe the saturation effect in biological systems. The curve is characterized by its sigmoidal shape, representing a gradual transition from a low, nearly zero level to a high plateau, the maximum value the function will approach as the independent variable grows large. + .. plot:: :context: close-figs import numpy as np diff --git a/tests/mmm/components/test_adstock.py b/tests/mmm/components/test_adstock.py index 6559deed0..af94ad778 100644 --- a/tests/mmm/components/test_adstock.py +++ b/tests/mmm/components/test_adstock.py @@ -57,7 +57,7 @@ def model() -> pm.Model: ) def test_apply(model, adstock, x, dims) -> None: with model: - y = adstock.apply(x, dim_name=dims) + y = adstock.apply(x, dims=dims) assert isinstance(y, pt.TensorVariable) assert y.eval().shape == x.shape diff --git a/tests/mmm/components/test_base.py b/tests/mmm/components/test_base.py index dfac7ba72..78a67bda0 100644 --- a/tests/mmm/components/test_base.py +++ b/tests/mmm/components/test_base.py @@ -166,7 +166,7 @@ def test_apply(new_transformation): x = np.array([1, 2, 3]) expected = np.array([6, 12, 18]) with pm.Model() as generative_model: - pm.Deterministic("y", new_transformation.apply(x, dim_name=None)) + pm.Deterministic("y", new_transformation.apply(x)) fixed_model = pm.do(generative_model, {"new_a": 2, "new_b": 3}) np.testing.assert_allclose(fixed_model["y"].eval(), expected) diff --git a/tests/mmm/components/test_saturation.py b/tests/mmm/components/test_saturation.py index ecef704cb..d571fc379 100644 --- a/tests/mmm/components/test_saturation.py +++ b/tests/mmm/components/test_saturation.py @@ -57,7 +57,7 @@ def saturation_functions(): ) def test_apply_method(model, saturation, x, dims) -> None: with model: - y = saturation.apply(x, dim_name=dims) + y = saturation.apply(x, dims=dims) assert isinstance(y, pt.TensorVariable) assert y.eval().shape == x.shape From f3e956dc6688db07a3372b1872047f4f20528d2d Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 11 Jun 2024 14:50:20 +0200 Subject: [PATCH 2/9] add tests for new methods --- tests/mmm/components/test_adstock.py | 14 +++++++++++ tests/mmm/components/test_base.py | 31 +++++++++++++++++++++++++ tests/mmm/components/test_saturation.py | 16 +++++++++---- 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/tests/mmm/components/test_adstock.py b/tests/mmm/components/test_adstock.py index af94ad778..914955e9d 100644 --- a/tests/mmm/components/test_adstock.py +++ b/tests/mmm/components/test_adstock.py @@ -15,6 +15,7 @@ import pymc as pm import pytensor.tensor as pt import pytest +import xarray as xr from pymc_marketing.mmm.components.adstock import ( AdstockTransformation, @@ -105,3 +106,16 @@ def test_get_adstock_function_unknown(): ValueError, match="Unknown adstock function: Unknown. Choose from" ): _get_adstock_function(function="Unknown") + + +@pytest.mark.parametrize( + "adstock", + adstocks(), +) +def test_adstock_sample_curve(adstock) -> None: + prior = adstock.sample_prior() + assert isinstance(prior, xr.Dataset) + curve = adstock.sample_curve(prior) + assert isinstance(curve, xr.DataArray) + assert curve.name == "adstock" + assert curve.shape == (1, 500, adstock.l_max) diff --git a/tests/mmm/components/test_base.py b/tests/mmm/components/test_base.py index 78a67bda0..5d159a4cd 100644 --- a/tests/mmm/components/test_base.py +++ b/tests/mmm/components/test_base.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import matplotlib.pyplot as plt import numpy as np import pymc as pm import pytest +import xarray as xr from pymc_marketing.mmm.components.base import ( MissingDataParameter, @@ -206,3 +208,32 @@ def test_new_transformation_warning_no_priors_updated(new_transformation) -> Non new_transformation.update_priors( {"new_c": {"dist": "HalfNormal", "kwargs": {"sigma": 1}}} ) + + +def test_new_transformation_sample_prior(new_transformation) -> None: + prior = new_transformation.sample_prior() + + assert isinstance(prior, xr.Dataset) + assert dict(prior.coords.sizes) == { + "chain": 1, + "draw": 500, + } + + assert set(prior.keys()) == {"new_a", "new_b"} + + +@pytest.fixture +def curve() -> xr.DataArray: + return xr.DataArray( + np.ones((1, 500, 10)), + dims=["chain", "draw", "time"], + coords={"time": np.arange(10), "draw": np.arange(500), "chain": np.arange(1)}, + ) + + +def test_new_transformation_plot_curve(new_transformation, curve) -> None: + ax = new_transformation.plot_curve(curve) + + assert isinstance(ax, plt.Axes) + + plt.close() diff --git a/tests/mmm/components/test_saturation.py b/tests/mmm/components/test_saturation.py index d571fc379..d9781bb4e 100644 --- a/tests/mmm/components/test_saturation.py +++ b/tests/mmm/components/test_saturation.py @@ -17,6 +17,7 @@ import pymc as pm import pytensor.tensor as pt import pytest +import xarray as xr from pymc_marketing.mmm.components.saturation import ( HillSaturation, @@ -103,10 +104,7 @@ def test_get_saturation_function(name, saturation_cls) -> None: assert isinstance(saturation, saturation_cls) -@pytest.mark.parametrize( - "saturation", - saturation_functions(), -) +@pytest.mark.parametrize("saturation", saturation_functions()) def test_get_saturation_function_passthrough(saturation) -> None: id_before = id(saturation) id_after = id(_get_saturation_function(saturation)) @@ -119,3 +117,13 @@ def test_get_saturation_function_unknown() -> None: ValueError, match="Unknown saturation function: unknown. Choose from" ): _get_saturation_function("unknown") + + +@pytest.mark.parametrize("saturation", saturation_functions()) +def test_sample_curve(saturation) -> None: + prior = saturation.sample_prior() + assert isinstance(prior, xr.Dataset) + curve = saturation.sample_curve(prior) + assert isinstance(curve, xr.DataArray) + assert curve.name == "saturation" + assert curve.shape == (1, 500, 100) From ef7aeaf6d8d59a7584793048b1ed53f7d7f5761d Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 11 Jun 2024 15:47:59 +0200 Subject: [PATCH 3/9] saturation support for additional variable dims --- pymc_marketing/mmm/components/saturation.py | 27 +++++++- tests/mmm/components/test_saturation.py | 75 +++++++++++++++++++++ 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index b3b9ac6be..ffac93dac 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -114,22 +114,43 @@ def sample_curve( Curve of the saturation transformation. """ + required_vars = list(self.variable_mapping.values()) + + function_parameters = parameters[required_vars] + x = np.linspace(0, max_value, 100) coords = { "x": x, } + parameter_coords = function_parameters.coords + + additional_coords = { + coord: parameter_coords[coord].to_numpy() + for coord in parameter_coords.keys() + if coord not in {"chain", "draw"} + } + + dims = tuple(additional_coords.keys()) + # Allow broadcasting + x = np.expand_dims( + x, + axis=tuple(range(1, len(dims) + 1)), + ) + + coords.update(additional_coords) + with pm.Model(coords=coords): var_name = "saturation" pm.Deterministic( var_name, - self.apply(x), - dims="x", + self.apply(x, dims=dims), + dims=("x", *dims), ) return pm.sample_posterior_predictive( - parameters, + function_parameters, var_names=[var_name], ).posterior_predictive[var_name] diff --git a/tests/mmm/components/test_saturation.py b/tests/mmm/components/test_saturation.py index d9781bb4e..cb6f4f512 100644 --- a/tests/mmm/components/test_saturation.py +++ b/tests/mmm/components/test_saturation.py @@ -127,3 +127,78 @@ def test_sample_curve(saturation) -> None: assert isinstance(curve, xr.DataArray) assert curve.name == "saturation" assert curve.shape == (1, 500, 100) + + +def create_mock_parameters( + coords: dict[str, list], + variable_dim_mapping: dict[str, tuple[str]], +) -> xr.Dataset: + dim_sizes = {coord: len(values) for coord, values in coords.items()} + return xr.Dataset( + { + name: xr.DataArray( + np.ones(tuple(dim_sizes[coord] for coord in dims)), + dims=dims, + coords={coord: coords[coord] for coord in dims}, + ) + for name, dims in variable_dim_mapping.items() + } + ) + + +@pytest.fixture +def mock_menten_parameters() -> xr.Dataset: + coords = { + "chain": np.arange(1), + "draw": np.arange(500), + } + + variable_dim_mapping = { + "saturation_alpha": ("chain", "draw"), + "saturation_lam": ("chain", "draw"), + "another_random_variable": ("chain", "draw"), + } + + return create_mock_parameters(coords, variable_dim_mapping) + + +def test_sample_curve_additional_dataset_variables(mock_menten_parameters) -> None: + """Case when the parameter dataset has additional variables.""" + saturation = MichaelisMentenSaturation() + + try: + curve = saturation.sample_curve(parameters=mock_menten_parameters) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + assert isinstance(curve, xr.DataArray) + assert curve.name == "saturation" + + +@pytest.fixture +def mock_menten_parameters_with_additional_dim() -> xr.Dataset: + coords = { + "chain": np.arange(1), + "draw": np.arange(500), + "channel": ["C1", "C2", "C3"], + "random_dim": ["R1", "R2"], + } + variable_dim_mapping = { + "saturation_alpha": ("chain", "draw", "channel"), + "saturation_lam": ("chain", "draw", "channel"), + "another_random_variable": ("chain", "draw", "channel", "random_dim"), + } + + return create_mock_parameters(coords, variable_dim_mapping) + + +def test_sample_curve_with_additional_dims( + mock_menten_parameters_with_additional_dim, +) -> None: + saturation = MichaelisMentenSaturation() + curve = saturation.sample_curve( + parameters=mock_menten_parameters_with_additional_dim + ) + + assert curve.coords["channel"].to_numpy().tolist() == ["C1", "C2", "C3"] + assert "random_dim" not in curve.coords From 4018c8d2c314b8d0f17018a42ddf661e23f244ef Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 11 Jun 2024 16:28:05 +0200 Subject: [PATCH 4/9] consolidate the logic of sampling --- pymc_marketing/mmm/components/adstock.py | 19 +++------ pymc_marketing/mmm/components/base.py | 46 +++++++++++++++++++++ pymc_marketing/mmm/components/saturation.py | 38 +++-------------- 3 files changed, 57 insertions(+), 46 deletions(-) diff --git a/pymc_marketing/mmm/components/adstock.py b/pymc_marketing/mmm/components/adstock.py index dfcc9eb5c..01504e2b1 100644 --- a/pymc_marketing/mmm/components/adstock.py +++ b/pymc_marketing/mmm/components/adstock.py @@ -38,7 +38,6 @@ def function(self, x, alpha): import warnings import numpy as np -import pymc as pm import xarray as xr from pymc_marketing.mmm.components.base import Transformation @@ -108,18 +107,12 @@ def sample_curve( x = np.zeros(self.l_max) x[0] = amount - with pm.Model(coords=coords): - var_name = "adstock" - pm.Deterministic( - var_name, - self.apply(x), - dims="time since exposure", - ) - - return pm.sample_posterior_predictive( - parameters, - var_names=[var_name], - ).posterior_predictive[var_name] + return self._sample_curve( + var_name="adstock", + parameters=parameters, + x=x, + coords=coords, + ) class GeometricAdstock(AdstockTransformation): diff --git a/pymc_marketing/mmm/components/base.py b/pymc_marketing/mmm/components/base.py index 941806d65..1aa50ed90 100644 --- a/pymc_marketing/mmm/components/base.py +++ b/pymc_marketing/mmm/components/base.py @@ -290,6 +290,52 @@ def plot_curve( ax = self.plot_curve_sample(curve, color=color, ax=ax, **sample_kwargs) return ax + def _sample_curve( + self, + var_name: str, + parameters: xr.Dataset, + x: pt.TensorLike, + coords: dict[str, Any], + ) -> xr.DataArray: + required_vars = list(self.variable_mapping.values()) + + keys = list(coords.keys()) + if len(keys) != 1: + msg = "The coords should only have one key." + raise ValueError(msg) + x_dim = keys[0] + + function_parameters = parameters[required_vars] + + parameter_coords = function_parameters.coords + + additional_coords = { + coord: parameter_coords[coord].to_numpy() + for coord in parameter_coords.keys() + if coord not in {"chain", "draw"} + } + + dims = tuple(additional_coords.keys()) + # Allow broadcasting + x = np.expand_dims( + x, + axis=tuple(range(1, len(dims) + 1)), + ) + + coords.update(additional_coords) + + with pm.Model(coords=coords): + pm.Deterministic( + var_name, + self.apply(x, dims=dims), + dims=(x_dim, *dims), + ) + + return pm.sample_posterior_predictive( + function_parameters, + var_names=[var_name], + ).posterior_predictive[var_name] + def plot_curve_sample( self, curve: xr.DataArray, diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index ffac93dac..ad6c97bf5 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -35,7 +35,6 @@ def function(self, x, b): """ import numpy as np -import pymc as pm import xarray as xr from pymc_marketing.mmm.components.base import Transformation @@ -114,46 +113,19 @@ def sample_curve( Curve of the saturation transformation. """ - required_vars = list(self.variable_mapping.values()) - - function_parameters = parameters[required_vars] - x = np.linspace(0, max_value, 100) coords = { "x": x, } - parameter_coords = function_parameters.coords - - additional_coords = { - coord: parameter_coords[coord].to_numpy() - for coord in parameter_coords.keys() - if coord not in {"chain", "draw"} - } - - dims = tuple(additional_coords.keys()) - # Allow broadcasting - x = np.expand_dims( - x, - axis=tuple(range(1, len(dims) + 1)), + return self._sample_curve( + var_name="saturation", + parameters=parameters, + x=x, + coords=coords, ) - coords.update(additional_coords) - - with pm.Model(coords=coords): - var_name = "saturation" - pm.Deterministic( - var_name, - self.apply(x, dims=dims), - dims=("x", *dims), - ) - - return pm.sample_posterior_predictive( - function_parameters, - var_names=[var_name], - ).posterior_predictive[var_name] - class LogisticSaturation(SaturationTransformation): """Wrapper around logistic saturation function. From 820e496eded7510bdb7f06abff2b326898a6dd8b Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 12 Jun 2024 12:13:02 +0200 Subject: [PATCH 5/9] change warning --- pymc_marketing/mmm/components/base.py | 189 +++++++++++++++++++------- tests/mmm/components/test_base.py | 58 ++++++-- 2 files changed, 189 insertions(+), 58 deletions(-) diff --git a/pymc_marketing/mmm/components/base.py b/pymc_marketing/mmm/components/base.py index 1aa50ed90..c247b5d91 100644 --- a/pymc_marketing/mmm/components/base.py +++ b/pymc_marketing/mmm/components/base.py @@ -21,12 +21,15 @@ """ import warnings +from collections.abc import Generator, MutableMapping, Sequence from inspect import signature +from itertools import product from typing import Any import arviz as az import matplotlib.pyplot as plt import numpy as np +import numpy.typing as npt import pymc as pm import xarray as xr from pymc.distributions.shape_utils import Dims @@ -34,6 +37,49 @@ from pymc_marketing.mmm.utils import _get_distribution_from_dict +Values = Sequence[Any] | npt.NDArray[Any] +Coords = dict[str, Values] + + +def get_plot_coords(coords: Coords) -> Coords: + plot_coord_names = list(coords.keys())[3:] + return {name: np.array(coords[name]) for name in plot_coord_names} + + +def get_total_coord_size(coords: Coords) -> int: + total_size: int = ( + 1 if coords == {} else np.prod([len(values) for values in coords.values()]) # type: ignore + ) + if total_size >= 12: + warnings.warn("Large number of coordinates!", stacklevel=2) + + return total_size + + +def set_subplot_kwargs_defaults( + subplot_kwargs: MutableMapping[str, Any], + total_size: int, +) -> None: + if "ncols" in subplot_kwargs and "nrows" in subplot_kwargs: + raise ValueError("Only specify one") + + if "ncols" not in subplot_kwargs and "nrows" not in subplot_kwargs: + subplot_kwargs["ncols"] = total_size + + if "ncols" in subplot_kwargs: + subplot_kwargs["nrows"] = total_size // subplot_kwargs["ncols"] + elif "nrows" in subplot_kwargs: + subplot_kwargs["ncols"] = total_size // subplot_kwargs["nrows"] + + +def selections( + coords: Coords, +) -> Generator[dict[str, Any], None, None]: + """Helper to create generator of selections.""" + coord_names = coords.keys() + for values in product(*coords.values()): + yield {name: value for name, value in zip(coord_names, values, strict=True)} + class ParameterPriorException(Exception): """Error when the functions and specified priors don't match up.""" @@ -237,11 +283,15 @@ def _create_distributions( return distributions - def sample_prior(self, **sample_prior_predictive_kwargs) -> xr.Dataset: + def sample_prior( + self, coords: dict | None = None, **sample_prior_predictive_kwargs + ) -> xr.Dataset: """Sample the priors for the transformation. Parameters ---------- + coords : dict, optional + The coordinates for the associated with dims **sample_prior_predictive_kwargs Keyword arguments for the pm.sample_prior_predictive function. @@ -251,28 +301,27 @@ def sample_prior(self, **sample_prior_predictive_kwargs) -> xr.Dataset: The dataset with the sampled priors. """ - with pm.Model(): - self._create_distributions() + coords = coords or {} + dims = tuple(coords.keys()) + with pm.Model(coords=coords): + self._create_distributions(dims=dims) return pm.sample_prior_predictive(**sample_prior_predictive_kwargs).prior def plot_curve( self, curve: xr.DataArray, - color: str = "C0", - ax: plt.Axes | None = None, + subplot_kwargs: dict | None = None, sample_kwargs: dict | None = None, hdi_kwargs: dict | None = None, - ) -> plt.Axes: + ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: """Plot curve HDI and samples. Parameters ---------- curve : xr.DataArray The curve to plot. - color : str, optional - The color of the curve. Defaults to "C0". - ax : plt.Axes, optional - The axes to plot on. Defaults to None. + subplot_kwargs : dict, optional + Keyword arguments for plt.subplots sample_kwargs : dict, optional Keyword arguments for the plot_curve_sample function. Defaults to None. hdi_kwargs : dict, optional @@ -280,15 +329,19 @@ def plot_curve( Returns ------- - plt.Axes - The axes with the plot. + tuple[plt.Figure, npt.NDArray[plt.Axes]] """ hdi_kwargs = hdi_kwargs or {} sample_kwargs = sample_kwargs or {} - ax = self.plot_curve_hdi(curve, color=color, ax=ax, **hdi_kwargs) - ax = self.plot_curve_sample(curve, color=color, ax=ax, **sample_kwargs) - return ax + + if "subplot_kwargs" not in hdi_kwargs: + hdi_kwargs["subplot_kwargs"] = subplot_kwargs + + fig, axes = self.plot_curve_hdi(curve, **hdi_kwargs) + fig, axes = self.plot_curve_samples(curve, axes=axes, **sample_kwargs) + + return fig, axes def _sample_curve( self, @@ -336,98 +389,134 @@ def _sample_curve( var_names=[var_name], ).posterior_predictive[var_name] - def plot_curve_sample( + def plot_curve_samples( self, curve: xr.DataArray, - color: str = "C0", - ax: plt.Axes | None = None, n: int = 10, rng: np.random.Generator | None = None, plot_kwargs: dict | None = None, - ) -> plt.Axes: + subplot_kwargs: dict | None = None, + axes: npt.NDArray[plt.Axes] | None = None, + ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: """Plot samples from the curve. Parameters ---------- curve : xr.DataArray The curve to plot. - color : str, optional - The color of the curve. Defaults to "C0". - ax : plt.Axes, optional - The axes to plot on. Defaults to None. n : int, optional The number of samples to plot. Defaults to 10. rng : np.random.Generator, optional The random number generator to use. Defaults to None. plot_kwargs : dict, optional - Keyword arguments for the plot function. Defaults to None. + Keyword arguments for the DataFrame plot function. Defaults to None. + subplot_kwargs : dict, optional + Keyword arguments for plt.subplots + axes : npt.NDArray[plt.Axes], optional + The exact axes to plot on. Overrides any subplot_kwargs Returns ------- + tuple[plt.Figure, npt.NDArray[plt.Axes]] plt.Axes The axes with the plot. """ - df_curve = curve.to_series().unstack() + plot_coords = get_plot_coords(curve.coords) + total_size = get_total_coord_size(plot_coords) - df_sample = df_curve.sample(n=n, random_state=rng) + if axes is None: + subplot_kwargs = subplot_kwargs or {} + set_subplot_kwargs_defaults(subplot_kwargs, total_size) + fig, axes = plt.subplots(**subplot_kwargs) + else: + fig = plt.gcf() - ax = ax or plt.gca() plot_kwargs = plot_kwargs or {} - plot_kwargs["color"] = color plot_kwargs["alpha"] = plot_kwargs.get("alpha", 0.3) plot_kwargs["legend"] = False - df_sample.T.plot(ax=ax, **plot_kwargs) - return ax + for i, (ax, sel) in enumerate( + zip(np.ravel(axes), selections(plot_coords), strict=False) + ): + color = f"C{i}" + + df_curve = curve.sel(sel).to_series().unstack() + df_sample = df_curve.sample(n=n, random_state=rng) + + df_sample.T.plot(ax=ax, color=color, **plot_kwargs) + title = ", ".join(f"{name}={value}" for name, value in sel.items()) + ax.set_title(title) + + if not isinstance(axes, np.ndarray): + axes = np.array([axes]) + + return fig, axes def plot_curve_hdi( self, curve: xr.DataArray, - color: str = "C0", - ax: plt.Axes | None = None, hdi_kwargs: dict | None = None, plot_kwargs: dict | None = None, - ) -> plt.Axes: + subplot_kwargs: dict | None = None, + axes: npt.NDArray[plt.Axes] | None = None, + ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: """Plot the HDI of the curve. Parameters ---------- curve : xr.DataArray The curve to plot. - color : str, optional - The color of the curve. Defaults to "C0". - ax : plt.Axes, optional - The axes to plot on. Defaults to None. hdi_kwargs : dict, optional Keyword arguments for the az.hdi function. Defaults to None. plot_kwargs : dict, optional Keyword arguments for the fill_between function. Defaults to None. + subplot_kwargs : dict, optional + Keyword arguments for plt.subplots + axes : npt.NDArray[plt.Axes], optional + The exact axes to plot on. Overrides any subplot_kwargs Returns ------- - plt.Axes - The axes with the plot. + tuple[plt.Figure, npt.NDArray[plt.Axes]] """ + plot_coords = get_plot_coords(curve.coords) + total_size = get_total_coord_size(plot_coords) + hdi_kwargs = hdi_kwargs or {} - conf = az.hdi(curve, **hdi_kwargs) + conf = az.hdi(curve, **hdi_kwargs)[curve.name] - df_conf = conf[curve.name].to_series().unstack() + if axes is None: + subplot_kwargs = subplot_kwargs or {} + set_subplot_kwargs_defaults(subplot_kwargs, total_size) + fig, axes = plt.subplots(**subplot_kwargs) + else: + fig = plt.gcf() plot_kwargs = plot_kwargs or {} - plot_kwargs["color"] = color plot_kwargs["alpha"] = plot_kwargs.get("alpha", 0.3) - ax = ax or plt.gca() - ax.fill_between( - df_conf.index, - df_conf["lower"], - df_conf["higher"], - **plot_kwargs, - ) + for i, (ax, sel) in enumerate( + zip(np.ravel(axes), selections(plot_coords), strict=False) + ): + color = f"C{i}" + df_conf = conf.sel(sel).to_series().unstack() + + ax.fill_between( + x=df_conf.index, + y1=df_conf["lower"], + y2=df_conf["higher"], + color=color, + **plot_kwargs, + ) + title = ", ".join(f"{name}={value}" for name, value in sel.items()) + ax.set_title(title) + + if not isinstance(axes, np.ndarray): + axes = np.array([axes]) - return ax + return fig, axes def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> pt.TensorVariable: """Called within a model context. diff --git a/tests/mmm/components/test_base.py b/tests/mmm/components/test_base.py index 5d159a4cd..364bd271b 100644 --- a/tests/mmm/components/test_base.py +++ b/tests/mmm/components/test_base.py @@ -21,6 +21,7 @@ MissingDataParameter, ParameterPriorException, Transformation, + selections, ) @@ -222,18 +223,59 @@ def test_new_transformation_sample_prior(new_transformation) -> None: assert set(prior.keys()) == {"new_a", "new_b"} -@pytest.fixture -def curve() -> xr.DataArray: +def create_curve(coords) -> xr.DataArray: + size = [len(values) for values in coords.values()] + dims = list(coords.keys()) + data = np.ones(size) return xr.DataArray( - np.ones((1, 500, 10)), - dims=["chain", "draw", "time"], - coords={"time": np.arange(10), "draw": np.arange(500), "chain": np.arange(1)}, + data, + dims=dims, + coords=coords, ) -def test_new_transformation_plot_curve(new_transformation, curve) -> None: - ax = new_transformation.plot_curve(curve) +@pytest.mark.parametrize( + "coords, expected_size", + [ + ({"chain": np.arange(1), "draw": np.arange(250), "time": np.arange(10)}, 1), + ( + { + "chain": np.arange(1), + "draw": np.arange(250), + "time": np.arange(10), + "channel": ["A", "B", "C"], + }, + 3, + ), + ], +) +def test_new_transformation_plot_curve( + new_transformation, coords, expected_size +) -> None: + curve = create_curve(coords) + fig, axes = new_transformation.plot_curve(curve) - assert isinstance(ax, plt.Axes) + assert isinstance(fig, plt.Figure) + assert len(axes) == expected_size plt.close() + + +@pytest.mark.parametrize( + "coords, expected", + [ + ({}, [{}]), + ({"channel": [1, 2, 3]}, [{"channel": 1}, {"channel": 2}, {"channel": 3}]), + ( + {"channel": [1, 2], "country": ["A", "B"]}, + [ + {"channel": 1, "country": "A"}, + {"channel": 1, "country": "B"}, + {"channel": 2, "country": "A"}, + {"channel": 2, "country": "B"}, + ], + ), + ], +) +def test_selections(coords, expected) -> None: + assert list(selections(coords)) == expected From e0000e418b9c80886ca61e91248a6e00ac47f992 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 12 Jun 2024 12:26:51 +0200 Subject: [PATCH 6/9] workflow from a fitted model --- tests/mmm/test_delayed_saturated_mmm.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index 7b85e7742..5b9250a25 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -1088,3 +1088,15 @@ def test_initialize_alternative_with_classes() -> None: assert isinstance(mmm.adstock, DelayedAdstock) assert mmm.adstock.l_max == 10 assert isinstance(mmm.saturation, MichaelisMentenSaturation) + + +@pytest.mark.parametrize("media_transform", ["adstock", "saturation"]) +def test_plotting_media_transform_workflow(mmm_fitted, media_transform) -> None: + transform = getattr(mmm_fitted, media_transform) + curve = transform.sample_curve(mmm_fitted.fit_result) + fig, axes = transform.plot_curve(curve) + + assert isinstance(fig, plt.Figure) + assert len(axes) == mmm_fitted.fit_result["channel"].size + + plt.close() From 62d47dd289343c552bb945a1b93eeb6aa3bea89a Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 12 Jun 2024 13:35:35 +0200 Subject: [PATCH 7/9] change order of tests --- tests/mmm/test_delayed_saturated_mmm.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index 791d94b44..2a9091323 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -168,6 +168,18 @@ def mmm_fitted_with_fourier_features( return mock_fit(mmm_with_fourier_features, toy_X, toy_y) +@pytest.mark.parametrize("media_transform", ["adstock", "saturation"]) +def test_plotting_media_transform_workflow(mmm_fitted, media_transform) -> None: + transform = getattr(mmm_fitted, media_transform) + curve = transform.sample_curve(mmm_fitted.fit_result) + fig, axes = transform.plot_curve(curve) + + assert isinstance(fig, plt.Figure) + assert len(axes) == mmm_fitted.fit_result["channel"].size + + plt.close() + + class TestDelayedSaturatedMMM: def test_save_load_with_not_serializable_model_config( self, model_config_requiring_serialization, toy_X, toy_y @@ -1085,15 +1097,3 @@ def test_initialize_alternative_with_classes() -> None: assert isinstance(mmm.adstock, DelayedAdstock) assert mmm.adstock.l_max == 10 assert isinstance(mmm.saturation, MichaelisMentenSaturation) - - -@pytest.mark.parametrize("media_transform", ["adstock", "saturation"]) -def test_plotting_media_transform_workflow(mmm_fitted, media_transform) -> None: - transform = getattr(mmm_fitted, media_transform) - curve = transform.sample_curve(mmm_fitted.fit_result) - fig, axes = transform.plot_curve(curve) - - assert isinstance(fig, plt.Figure) - assert len(axes) == mmm_fitted.fit_result["channel"].size - - plt.close() From 1bc93aa6423dfceacee0f6adf6d5a62dff590334 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 12 Jun 2024 15:45:18 +0200 Subject: [PATCH 8/9] suggestion to use names --- pymc_marketing/mmm/components/adstock.py | 14 ++++++++++++++ pymc_marketing/mmm/components/base.py | 6 +++++- pymc_marketing/mmm/components/saturation.py | 14 ++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/pymc_marketing/mmm/components/adstock.py b/pymc_marketing/mmm/components/adstock.py index 01504e2b1..791a2828c 100644 --- a/pymc_marketing/mmm/components/adstock.py +++ b/pymc_marketing/mmm/components/adstock.py @@ -33,6 +33,20 @@ def function(self, x, alpha): default_priors = {"alpha": {"dist": "HalfNormal", "kwargs": {"sigma": 1}}} +Plot the default priors for an adstock transformation: + +.. code-block:: python + + from pymc_marketing.mmm import GeometricAdstock + + import matplotlib.pyplot as plt + + adstock = GeometricAdstock(l_max=15) + prior = adstock.sample_prior() + curve = adstock.sample_curve(prior) + adstock.plot_curve(curve) + plt.show() + """ import warnings diff --git a/pymc_marketing/mmm/components/base.py b/pymc_marketing/mmm/components/base.py index c247b5d91..d57a993e9 100644 --- a/pymc_marketing/mmm/components/base.py +++ b/pymc_marketing/mmm/components/base.py @@ -40,9 +40,13 @@ Values = Sequence[Any] | npt.NDArray[Any] Coords = dict[str, Values] +# chain and draw from sampling +# "x" for saturation, "time since exposure" for adstock +NON_GRID_NAMES = {"chain", "draw", "x", "time since exposure"} + def get_plot_coords(coords: Coords) -> Coords: - plot_coord_names = list(coords.keys())[3:] + plot_coord_names = list(key for key in coords.keys() if key not in NON_GRID_NAMES) return {name: np.array(coords[name]) for name in plot_coord_names} diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index ad6c97bf5..215b3f551 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -32,6 +32,20 @@ def function(self, x, b): default_priors = {"b": {"dist": "HalfNormal", "kwargs": {"sigma": 1}}} +Plot the default priors for a saturation transformation: + +.. code-block:: python + + from pymc_marketing.mmm import HillSaturation + + import matplotlib.pyplot as plt + + saturation = HillSaturation() + prior = saturation.sample_prior() + curve = saturation.sample_curve(prior) + saturation.plot_curve(curve) + plt.show() + """ import numpy as np From e2484c3944fc892bfe6e38131ebad8eecf81d982 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 12 Jun 2024 16:26:04 +0200 Subject: [PATCH 9/9] because of new data --- tests/mmm/components/test_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/mmm/components/test_base.py b/tests/mmm/components/test_base.py index 364bd271b..519c3174e 100644 --- a/tests/mmm/components/test_base.py +++ b/tests/mmm/components/test_base.py @@ -231,18 +231,19 @@ def create_curve(coords) -> xr.DataArray: data, dims=dims, coords=coords, + name="data", ) @pytest.mark.parametrize( "coords, expected_size", [ - ({"chain": np.arange(1), "draw": np.arange(250), "time": np.arange(10)}, 1), + ({"chain": np.arange(1), "draw": np.arange(250), "x": np.arange(10)}, 1), ( { "chain": np.arange(1), "draw": np.arange(250), - "time": np.arange(10), + "x": np.arange(10), "channel": ["A", "B", "C"], }, 3,