From b97c945feaf3af713e3c7ac1225915b668ea826d Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Oct 2024 13:44:41 -0400 Subject: [PATCH 1/4] use the modelbuilder mixin --- pymc_marketing/model_builder.py | 8 +- .../product_incrementality/mv_its.py | 263 ++++++++++-------- 2 files changed, 160 insertions(+), 111 deletions(-) diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index 725eeb6db..96077fda9 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -580,11 +580,17 @@ def fit( Initializing NUTS using jitter+adapt_diag... """ + if isinstance(y, pd.Series) and not X.index.equals(y.index): + raise ValueError( # pragma: no cover + "Index of X and y must match." + ) + if predictor_names is None: predictor_names = [] if y is None: y = np.zeros(X.shape[0]) - y_df = pd.DataFrame({self.output_var: y}) + + y_df = pd.DataFrame({self.output_var: y}, index=X.index) self._generate_and_preprocess_model_data(X, y_df.values.flatten()) if self.X is None or self.y is None: raise ValueError("X and y must be set before calling build_model!") diff --git a/pymc_marketing/product_incrementality/mv_its.py b/pymc_marketing/product_incrementality/mv_its.py index 462c24ee4..ea42c8e21 100644 --- a/pymc_marketing/product_incrementality/mv_its.py +++ b/pymc_marketing/product_incrementality/mv_its.py @@ -13,6 +13,9 @@ # limitations under the License. """Multivariate Interrupted Time Series Analysis for Product Incrementality.""" +import json +from typing import Any + import arviz as az import matplotlib.pyplot as plt import matplotlib.ticker as mtick @@ -20,169 +23,207 @@ import pandas as pd import pymc as pm +from pymc_marketing.model_builder import ModelBuilder +from pymc_marketing.prior import Prior + HDI_ALPHA = 0.5 -class MVITS: +class MVITS(ModelBuilder): """Multivariate Interrupted Time Series class. Class to perform a multivariate interrupted time series analysis with the specific intent of determining where the sales of a new product came from. """ + _model_type = "Multivariate Interrupted Time Series" + def __init__( self, - data: pd.DataFrame, - treatment_time, existing_sales: list[str], - treatment_sales: str, market_saturated: bool = True, - rng=42, - sample_kwargs: dict | None = None, + model_config: dict | None = None, + sampler_config: dict | None = None, ): - self.data = data - self.treatment_time = treatment_time + super().__init__(model_config=model_config, sampler_config=sampler_config) + self.existing_sales = existing_sales - self.treatment_sales = treatment_sales - self.rng = rng - self.sample_kwargs = sample_kwargs if sample_kwargs is not None else {} self.market_saturated = market_saturated - self.model = self.build_model( - self.data[self.existing_sales], - self.data[self.treatment_sales], - self.market_saturated, - treatment_time=self.treatment_time, - ) - self.sample_prior_predictive() - self.fit() - self.sample_posterior_predictive() - self.calculate_counterfactual() - return + def create_idata_attrs(self) -> dict[str, str]: + """Create the attributes for the InferenceData object.""" + attrs = super().create_idata_attrs() + attrs["existing_sales"] = json.dumps(self.existing_sales) + attrs["market_saturated"] = json.dumps(self.market_saturated) - @staticmethod - def build_model( - existing_sales: pd.DataFrame, - treatment_sales: pd.Series, - market_saturated: bool, - treatment_time, - *, - alpha_background=0.5, - ): - """Return a PyMC model for a multivariate interrupted time series analysis.""" - if not existing_sales.index.equals(treatment_sales.index): - raise ValueError( # pragma: no cover - "Index of existing_sales and treatment_sales must match." - ) + return attrs + + @classmethod + def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]: + """Convert the attributes of the InferenceData object to the __init__ kwargs.""" + return { + "existing_sales": json.loads(attrs["existing_sales"]), + "market_saturated": json.loads(attrs["market_saturated"]), + } - # note: type hints for coords required for mypi to not get confused - coords: dict[str, list[str]] = { - "background_product": list(existing_sales.columns), - "time": list(existing_sales.index.values), + @property + def default_model_config(self) -> dict: + """Default model configuration.""" + return { + "intercept": Prior("Normal", dims="background_product"), + "likelihood": Prior( + "TruncatedNormal", + lower=0, + sigma=Prior("HalfNormal", dims="background_product"), + dims=("time", "background_product"), + ), + "alpha_background": 0.5, + } + + @property + def default_sampler_config(self) -> dict: + """Default sampler configuration.""" + return {} + + @property + def output_var(self) -> str: + """The output variable of the model.""" + return "y" + + def _serializable_model_config(self) -> dict[str, int | float | dict]: # type: ignore + result: dict[str, int | float | dict] = { + "intercept": self.model_config["intercept"].to_json(), + "likelihood": self.model_config["likelihood"].to_json(), + "alpha_background": self.model_config["alpha_background"], + } + + return result + + def _generate_and_preprocess_model_data( + self, + X: pd.DataFrame | pd.Series, + y: np.ndarray, + ) -> None: + if isinstance(X, pd.Series): + raise ValueError("X must be a DataFrame, not a Series") # pragma: no cover + + self.X = X[self.existing_sales] + self.y = pd.Series(y, index=X.index) + + # note: type hints for coords required for mypy to not get confused + self.coords: dict[str, list[str]] = { + "background_product": list(self.existing_sales), + "time": list(X.index.values), "all_sources": [ - *list(existing_sales.columns), + *list(self.existing_sales), "new", ], } - with pm.Model(coords=coords) as model: + def build_model( + self, + X: pd.DataFrame, + y: pd.Series | np.ndarray, + **kwargs, + ) -> None: + """Build a PyMC model for a multivariate interrupted time series analysis.""" + self._generate_and_preprocess_model_data(X, y) # type: ignore + + with pm.Model(coords=self.coords) as model: # data _existing_sales = pm.Data( "existing_sales", - existing_sales.values, + X.values, dims=("time", "background_product"), ) - treatment_sales = pm.Data( - "treatment_sales", treatment_sales.values, dims=("time",) + y = pm.Data( + "treatment_sales", + y if not isinstance(y, pd.Series) else y.values, + dims="time", ) # priors - intercept = pm.Normal( - "intercept", - mu=pm.math.mean(existing_sales[:treatment_time], axis=0), - sigma=np.std(existing_sales[:treatment_time], axis=0), - dims="background_product", - ) + intercept = self.model_config["intercept"].create_variable(name="intercept") + alpha_background = self.model_config["alpha_background"] - sigma = pm.HalfNormal( - "background_product_sigma", - sigma=pm.math.mean(existing_sales.std().values), - dims="background_product", - ) - - if market_saturated: + if self.market_saturated: """We assume the market is saturated. The sum of the beta's will be 1. This means that the reduction in sales of existing products will equal the increase in sales of the new product, such that the total sales remain constant.""" - alpha = np.full(len(coords["background_product"]), alpha_background) + alpha = np.full( + len(self.coords["background_product"]), + alpha_background, + ) beta = pm.Dirichlet("beta", a=alpha, dims="background_product") else: """We assume the market is not saturated. The sum of the beta's will be less than 1. This means that the reduction in sales of existing products will be less than the increase in sales of the new product.""" - alpha_all = np.full(len(coords["all_sources"]), alpha_background) + alpha_all = np.full(len(self.coords["all_sources"]), alpha_background) beta_all = pm.Dirichlet("beta_all", a=alpha_all, dims="all_sources") beta = pm.Deterministic( - "beta", beta_all[:-1], dims="background_product" + "beta", + beta_all[:-1], + dims="background_product", ) pm.Deterministic("new sales", beta_all[-1]) # expectation mu = pm.Deterministic( "mu", - intercept[None, :] - treatment_sales[:, None] * beta[None, :], + intercept[None, :] - y[:, None] * beta[None, :], dims=("time", "background_product"), ) # likelihood - normal_dist = pm.Normal.dist(mu=mu, sigma=sigma) - pm.Truncated( - "y", - normal_dist, - lower=0, + self.model_config["likelihood"].create_likelihood_variable( + name=self.output_var, + mu=mu, observed=_existing_sales, - dims=("time", "background_product"), ) - return model + self.model = model - def sample_prior_predictive(self): - """Sample from the prior predictive distribution.""" - with self.model: - self.idata = pm.sample_prior_predictive(random_seed=self.rng) + def _data_setter( + self, + X: np.ndarray | pd.DataFrame, + y: np.ndarray | pd.Series | None = None, + ) -> None: + """Set the data. - def fit(self): - """Fit the model to the data.""" - with self.model: - self.idata.extend(pm.sample(**self.sample_kwargs, random_seed=self.rng)) + Required from the parent class - def sample_posterior_predictive(self): - """Sample from the posterior predictive distribution.""" - with self.model: - self.idata.extend( - pm.sample_posterior_predictive( - self.idata, - var_names=["mu", "y"], - random_seed=self.rng, - ) - ) + """ - def calculate_counterfactual(self): + def calculate_counterfactual(self, random_seed: int | None = None): """Calculate the counterfactual scenario of never releasing the new product.""" - zero_sales = np.zeros(self.data[self.treatment_sales].shape, dtype=np.int32) + zero_sales = np.zeros_like(self.y, dtype=np.int32) self.counterfactual_model = pm.do(self.model, {"treatment_sales": zero_sales}) with self.counterfactual_model: - self.idata.extend( + self.idata.extend( # type: ignore pm.sample_posterior_predictive( self.idata, - var_names=["mu", "y"], - random_seed=self.rng, + var_names=["mu", self.output_var], + random_seed=random_seed, predictions=True, ) ) - def causal_impact(self, variable="mu"): + def sample(self, X, y, random_seed: int | None = None): + """Sample all the things.""" + self.sample_prior_predictive(X, random_seed=random_seed) + self.fit(X, y, random_seed=random_seed) + self.sample_posterior_predictive( + X, + random_seed=random_seed, + var_names=[self.output_var, "mu"], + ) + self.calculate_counterfactual(random_seed=random_seed) + + return self + + def causal_impact(self, variable: str = "mu"): """Calculate the causal impact of the new product on the background products. Note: if we compare "mu" then we are comparing the expected sales, if we compare @@ -194,28 +235,28 @@ def causal_impact(self, variable="mu"): ) # pragma: no cover return ( - self.idata.posterior_predictive[variable] - self.idata.predictions[variable] + self.idata.posterior_predictive[variable] - self.idata.predictions[variable] # type: ignore ) - def plot_fit(self, variable="mu"): + def plot_fit(self, variable: str = "mu"): """Plot the model fit (posterior predictive) of the background products.""" if variable not in ["mu", "y"]: raise ValueError( f"variable must be either 'mu' or 'y', not {variable}" ) # pragma: no cover - fig, ax = plt.subplots() + _, ax = plt.subplots() # plot data - self.plot_data(self.data, ax) + self.plot_data(ax=ax) # plot posterior predictive distribution of sales for each of the background products - x = self.data.index.values - background_products = list(self.idata.observed_data.background_product.data) + x = self.X.index.values # type: ignore + background_products = list(self.idata.observed_data.background_product.data) # type: ignore for i, background_product in enumerate(background_products): az.plot_hdi( x, - self.idata.posterior_predictive[variable] + self.idata.posterior_predictive[variable] # type: ignore .transpose(..., "time") .sel(background_product=background_product), fill_kwargs={ @@ -237,7 +278,7 @@ def plot_counterfactual(self, variable="mu"): Plot the predicted sales of the background products under the counterfactual scenario of never releasing the new product. """ - fig, ax = plt.subplots() + _, ax = plt.subplots() if variable not in ["mu", "y"]: raise ValueError( @@ -245,10 +286,10 @@ def plot_counterfactual(self, variable="mu"): ) # pragma: no cover # plot data - self.plot_data(self.data, ax) + self.plot_data(ax=ax) # plot posterior predictive distribution of sales for each of the background products - x = self.data.index.values + x = self.X.index.values background_products = list(self.idata.observed_data.background_product.data) for i, background_product in enumerate(background_products): az.plot_hdi( @@ -280,10 +321,10 @@ def plot_causal_impact_sales(self, variable="mu"): Note: if we compare "mu" then we are comparing the expected sales, if we compare "y" then we are comparing the actual sales """ - fig, ax = plt.subplots() + _, ax = plt.subplots() # plot posterior predictive distribution of sales for each of the background products - x = self.data.index.values + x = self.X.index.values background_products = list(self.idata.observed_data.background_product.data) for i, background_product in enumerate(background_products): @@ -312,10 +353,10 @@ def plot_causal_impact_market_share(self, variable="mu"): Note: if we compare "mu" then we are comparing the expected sales, if we compare "y" then we are comparing the actual sales """ - fig, ax = plt.subplots() + _, ax = plt.subplots() # plot posterior predictive distribution of sales for each of the background products - x = self.data.index.values + x = self.X.index.values background_products = list(self.idata.observed_data.background_product.data) # divide the causal impact change in sales by the counterfactual predicted sales @@ -351,11 +392,13 @@ def plot_causal_impact_market_share(self, variable="mu"): ax.set(title="Estimated causal impact of new product upon existing products") return ax - @staticmethod - def plot_data(data, ax=None): + def plot_data(self, ax=None): """Plot the observed data.""" + data = pd.concat([self.X, self.y], axis=1) + if ax is None: - fig, ax = plt.subplots() + _, ax = plt.subplots() + data.plot(ax=ax) data.sum(axis=1).plot(label="total sales", color="black", ax=ax) ax.set_ylim(bottom=0) From b7ebdd0092094caa14f7b394ab1f54688fbbb311 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Oct 2024 18:19:46 -0400 Subject: [PATCH 2/4] use Prior class for the background_distribution --- .../product_incrementality/mv_its.py | 42 +++++++++++++------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/pymc_marketing/product_incrementality/mv_its.py b/pymc_marketing/product_incrementality/mv_its.py index ea42c8e21..8a1b517d6 100644 --- a/pymc_marketing/product_incrementality/mv_its.py +++ b/pymc_marketing/product_incrementality/mv_its.py @@ -45,11 +45,24 @@ def __init__( model_config: dict | None = None, sampler_config: dict | None = None, ): - super().__init__(model_config=model_config, sampler_config=sampler_config) - self.existing_sales = existing_sales self.market_saturated = market_saturated + super().__init__(model_config=model_config, sampler_config=sampler_config) + + self._distribution_checks() + + def _distribution_checks(self): + if self.model_config["market_distribution"].distribution != "Dirichlet": + raise ValueError("market_distribution must be a Dirichlet distribution") # + + dims = "background_product" if self.market_saturated else "all_sources" + + if dims not in self.model_config["market_distribution"].dims: + raise ValueError( + f"market_distribution must have dims='{dims}', not {self.model_config['market_distribution'].dims}" + ) + def create_idata_attrs(self) -> dict[str, str]: """Create the attributes for the InferenceData object.""" attrs = super().create_idata_attrs() @@ -69,6 +82,15 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]: @property def default_model_config(self) -> dict: """Default model configuration.""" + if self.market_saturated: + a = np.full(len(self.existing_sales), 0.5) + dims = "background_product" + else: + a = np.full(len(self.existing_sales) + 1, 0.5) + dims = "all_sources" + + market_distribution = Prior("Dirichlet", a=a, dims=dims) + return { "intercept": Prior("Normal", dims="background_product"), "likelihood": Prior( @@ -77,7 +99,7 @@ def default_model_config(self) -> dict: sigma=Prior("HalfNormal", dims="background_product"), dims=("time", "background_product"), ), - "alpha_background": 0.5, + "market_distribution": market_distribution, } @property @@ -94,7 +116,7 @@ def _serializable_model_config(self) -> dict[str, int | float | dict]: # type: result: dict[str, int | float | dict] = { "intercept": self.model_config["intercept"].to_json(), "likelihood": self.model_config["likelihood"].to_json(), - "alpha_background": self.model_config["alpha_background"], + "market_distribution": self.model_config["market_distribution"].to_json(), } return result @@ -144,24 +166,20 @@ def build_model( # priors intercept = self.model_config["intercept"].create_variable(name="intercept") - alpha_background = self.model_config["alpha_background"] if self.market_saturated: """We assume the market is saturated. The sum of the beta's will be 1. This means that the reduction in sales of existing products will equal the increase in sales of the new product, such that the total sales remain constant.""" - alpha = np.full( - len(self.coords["background_product"]), - alpha_background, - ) - beta = pm.Dirichlet("beta", a=alpha, dims="background_product") + beta = self.model_config["market_distribution"].create_variable("beta") else: """We assume the market is not saturated. The sum of the beta's will be less than 1. This means that the reduction in sales of existing products will be less than the increase in sales of the new product.""" - alpha_all = np.full(len(self.coords["all_sources"]), alpha_background) - beta_all = pm.Dirichlet("beta_all", a=alpha_all, dims="all_sources") + beta_all = self.model_config["market_distribution"].create_variable( + "beta_all", + ) beta = pm.Deterministic( "beta", beta_all[:-1], From 4ee1b92455636b4e22ce22a8e83945ee3dc0beb0 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Oct 2024 22:23:27 -0400 Subject: [PATCH 3/4] modify the tests --- .../product_incrementality/mv_its.py | 34 ++- .../test_incrementality.py | 130 ----------- tests/product_incrementality/test_mv_its.py | 220 ++++++++++++++++++ 3 files changed, 247 insertions(+), 137 deletions(-) delete mode 100644 tests/product_incrementality/test_incrementality.py create mode 100644 tests/product_incrementality/test_mv_its.py diff --git a/pymc_marketing/product_incrementality/mv_its.py b/pymc_marketing/product_incrementality/mv_its.py index 8a1b517d6..177eab764 100644 --- a/pymc_marketing/product_incrementality/mv_its.py +++ b/pymc_marketing/product_incrementality/mv_its.py @@ -22,6 +22,7 @@ import numpy as np import pandas as pd import pymc as pm +from typing_extensions import Self from pymc_marketing.model_builder import ModelBuilder from pymc_marketing.prior import Prior @@ -37,6 +38,7 @@ class MVITS(ModelBuilder): """ _model_type = "Multivariate Interrupted Time Series" + version = "0.1.0" def __init__( self, @@ -214,7 +216,10 @@ def _data_setter( """ - def calculate_counterfactual(self, random_seed: int | None = None): + def calculate_counterfactual( + self, + random_seed: np.random.Generator | int | None = None, + ): """Calculate the counterfactual scenario of never releasing the new product.""" zero_sales = np.zeros_like(self.y, dtype=np.int32) self.counterfactual_model = pm.do(self.model, {"treatment_sales": zero_sales}) @@ -228,14 +233,29 @@ def calculate_counterfactual(self, random_seed: int | None = None): ) ) - def sample(self, X, y, random_seed: int | None = None): + def sample( + self, + X, + y, + random_seed: np.random.Generator | int | None = None, + sample_prior_predictive_kwargs: dict | None = None, + fit_kwargs: dict | None = None, + sample_posterior_predictive_kwargs: dict | None = None, + ) -> Self: """Sample all the things.""" - self.sample_prior_predictive(X, random_seed=random_seed) - self.fit(X, y, random_seed=random_seed) + sample_prior_predictive_kwargs = sample_prior_predictive_kwargs or {} + fit_kwargs = fit_kwargs or {} + sample_posterior_predictive_kwargs = sample_posterior_predictive_kwargs or {} + + self.sample_prior_predictive( + X, random_seed=random_seed, **sample_prior_predictive_kwargs + ) + self.fit(X, y, random_seed=random_seed, **fit_kwargs) self.sample_posterior_predictive( X, random_seed=random_seed, var_names=[self.output_var, "mu"], + **sample_posterior_predictive_kwargs, ) self.calculate_counterfactual(random_seed=random_seed) @@ -270,7 +290,7 @@ def plot_fit(self, variable: str = "mu"): # plot posterior predictive distribution of sales for each of the background products x = self.X.index.values # type: ignore - background_products = list(self.idata.observed_data.background_product.data) # type: ignore + background_products = self.coords["background_product"] for i, background_product in enumerate(background_products): az.plot_hdi( x, @@ -308,7 +328,7 @@ def plot_counterfactual(self, variable="mu"): # plot posterior predictive distribution of sales for each of the background products x = self.X.index.values - background_products = list(self.idata.observed_data.background_product.data) + background_products = self.coords["background_product"] for i, background_product in enumerate(background_products): az.plot_hdi( x, @@ -343,7 +363,7 @@ def plot_causal_impact_sales(self, variable="mu"): # plot posterior predictive distribution of sales for each of the background products x = self.X.index.values - background_products = list(self.idata.observed_data.background_product.data) + background_products = self.coords["background_product"] for i, background_product in enumerate(background_products): az.plot_hdi( diff --git a/tests/product_incrementality/test_incrementality.py b/tests/product_incrementality/test_incrementality.py deleted file mode 100644 index bc8fd5b6e..000000000 --- a/tests/product_incrementality/test_incrementality.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2024 The PyMC Labs Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -import pandas as pd -import pytest -from matplotlib import pyplot as plt - -from pymc_marketing.product_incrementality.mv_its import ( - MVITS, - generate_saturated_data, - generate_unsaturated_data, -) - -rng = np.random.default_rng(123) - -scenario_saturated = { - "total_sales_mu": 1000, - "total_sales_sigma": 5, - "treatment_time": 40, - "n_observations": 100, - "market_shares_before": [[0.7, 0.3, 0]], - "market_shares_after": [[0.65, 0.25, 0.1]], - "market_share_labels": ["competitor", "own", "new"], - "rng": rng, -} - -scenario_unsaturated_bad = { - "total_sales_before": [1000], - "total_sales_after": [1400], - "total_sales_sigma": 20, - "treatment_time": 40, - "n_observations": 100, - "market_shares_before": [[0.7, 0.3, 0]], - "market_shares_after": [[0.65, 0.25, 0.1]], - "market_share_labels": ["competitor", "own", "new"], - "rng": rng, -} - -scenario_unsaturated_good = { - "total_sales_before": [800], - "total_sales_after": [950], - "total_sales_sigma": 10, - "treatment_time": 40, - "n_observations": 100, - "market_shares_before": [[500 / 800, 300 / 800, 0]], - "market_shares_after": [[400 / 950, 200 / 950, 350 / 950]], - "market_share_labels": ["competitor", "own", "new"], - "rng": rng, -} - -sample_kwargs = {"tune": 100, "draws": 100} - - -@pytest.fixture(scope="module") -def saturated_data_fixture(): - return generate_saturated_data(**scenario_saturated) - - -def test_plot_data(saturated_data_fixture): - ax = MVITS.plot_data(saturated_data_fixture) - assert isinstance(ax, plt.Axes) - - -def test_MVITS_saturated(saturated_data_fixture): - result = MVITS( - saturated_data_fixture, - treatment_time=scenario_saturated["treatment_time"], - existing_sales=["competitor", "own"], - treatment_sales="new", - rng=rng, - sample_kwargs=sample_kwargs, - ) - assert isinstance(result, MVITS) - - ax = result.plot_fit() - assert isinstance(ax, plt.Axes) - - ax = result.plot_counterfactual() - assert isinstance(ax, plt.Axes) - - ax = result.plot_causal_impact_sales() - assert isinstance(ax, plt.Axes) - - ax = result.plot_causal_impact_market_share() - assert isinstance(ax, plt.Axes) - - -@pytest.mark.parametrize( - "scenario", [scenario_unsaturated_bad, scenario_unsaturated_good] -) -def test_MVITS_unsaturated(scenario): - """We will test the `unsaturated` version of the MVITS model. And we will do this - with multiple scenarios.""" - - data = generate_unsaturated_data(**scenario) - assert isinstance(data, pd.DataFrame) - - result = MVITS( - data, - treatment_time=scenario_saturated["treatment_time"], - existing_sales=["competitor", "own"], - market_saturated=False, - treatment_sales="new", - rng=rng, - sample_kwargs=sample_kwargs, - ) - assert isinstance(result, MVITS) - - ax = result.plot_fit() - assert isinstance(ax, plt.Axes) - - ax = result.plot_counterfactual() - assert isinstance(ax, plt.Axes) - - ax = result.plot_causal_impact_sales() - assert isinstance(ax, plt.Axes) - - ax = result.plot_causal_impact_market_share() - assert isinstance(ax, plt.Axes) diff --git a/tests/product_incrementality/test_mv_its.py b/tests/product_incrementality/test_mv_its.py new file mode 100644 index 000000000..1b01e37cd --- /dev/null +++ b/tests/product_incrementality/test_mv_its.py @@ -0,0 +1,220 @@ +# Copyright 2024 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import numpy as np +import pandas as pd +import pytest +from matplotlib import pyplot as plt + +from pymc_marketing.product_incrementality.mv_its import ( + MVITS, + generate_saturated_data, + generate_unsaturated_data, +) + +seed = sum(map(ord, "Product Incrementality")) +rng = np.random.default_rng(seed) + + +scenario_saturated = { + "total_sales_mu": 1000, + "total_sales_sigma": 5, + "treatment_time": 40, + "n_observations": 100, + "market_shares_before": [[0.7, 0.3, 0]], + "market_shares_after": [[0.65, 0.25, 0.1]], + "market_share_labels": ["competitor", "own", "new"], + "rng": rng, +} + +scenario_unsaturated_bad = { + "total_sales_before": [1000], + "total_sales_after": [1400], + "total_sales_sigma": 20, + "treatment_time": 40, + "n_observations": 100, + "market_shares_before": [[0.7, 0.3, 0]], + "market_shares_after": [[0.65, 0.25, 0.1]], + "market_share_labels": ["competitor", "own", "new"], + "rng": rng, +} + +scenario_unsaturated_good = { + "total_sales_before": [800], + "total_sales_after": [950], + "total_sales_sigma": 10, + "treatment_time": 40, + "n_observations": 100, + "market_shares_before": [[500 / 800, 300 / 800, 0]], + "market_shares_after": [[400 / 950, 200 / 950, 350 / 950]], + "market_share_labels": ["competitor", "own", "new"], + "rng": rng, +} + + +@pytest.fixture(scope="module") +def saturated_data(): + return generate_saturated_data(**scenario_saturated) + + +def test_plot_data(saturated_data): + model = MVITS(existing_sales=["competitor", "own"]) + model.X = saturated_data.loc[:, ["competitor", "own"]] + model.y = saturated_data["new"] + + ax = model.plot_data() + assert isinstance(ax, plt.Axes) + plt.close() + + +def mock_fit(self, X, y, **kwargs): + self.idata.add_groups( + { + "posterior": self.idata.prior, + } + ) + + combined_data = pd.concat([X, y.rename(self.output_var)], axis=1) + + if "fit_data" in self.idata: + del self.idata.fit_data + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="The group fit_data is not defined in the InferenceData scheme", + ) + self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore + + return self + + +@pytest.fixture(scope="module") +def fit_model(module_mocker, saturated_data): + model = MVITS(existing_sales=["competitor", "own"], market_saturated=True) + + module_mocker.patch( + "pymc_marketing.product_incrementality.mv_its.MVITS.fit", + mock_fit, + ) + + model.sample( + saturated_data.loc[:, ["competitor", "own"]], + saturated_data["new"], + random_seed=rng, + sample_prior_predictive_kwargs={"samples": 10}, + ) + return model + + +@pytest.mark.parametrize( + "plot_method", + [ + "plot_fit", + "plot_counterfactual", + "plot_causal_impact_sales", + "plot_causal_impact_market_share", + ], +) +def test_MVITS_saturated(fit_model, plot_method): + ax = getattr(fit_model, plot_method)() + assert isinstance(ax, plt.Axes) + plt.close() + + +@pytest.fixture(scope="module") +def unsaturated_data_bad(): + return generate_unsaturated_data(**scenario_unsaturated_bad) + + +@pytest.fixture(scope="module") +def unsaturated_data_good(): + return generate_unsaturated_data(**scenario_unsaturated_good) + + +@pytest.fixture(scope="module") +def unsaturated_model_bad(module_mocker, unsaturated_data_bad): + model = MVITS(existing_sales=["competitor", "own"], market_saturated=False) + + module_mocker.patch( + "pymc_marketing.product_incrementality.mv_its.MVITS.fit", + mock_fit, + ) + + model.sample( + unsaturated_data_bad.loc[:, ["competitor", "own"]], + unsaturated_data_bad["new"], + random_seed=rng, + sample_prior_predictive_kwargs={"samples": 10}, + ) + return model + + +@pytest.fixture(scope="module") +def unsaturated_model_good(module_mocker, unsaturated_data_good): + model = MVITS(existing_sales=["competitor", "own"], market_saturated=False) + + module_mocker.patch( + "pymc_marketing.product_incrementality.mv_its.MVITS.fit", + mock_fit, + ) + + model.sample( + unsaturated_data_good.loc[:, ["competitor", "own"]], + unsaturated_data_good["new"], + random_seed=rng, + sample_prior_predictive_kwargs={"samples": 10}, + ) + return model + + +@pytest.mark.parametrize( + "model_name", ["unsaturated_model_bad", "unsaturated_model_good"] +) +@pytest.mark.parametrize( + "plot_method", + [ + "plot_fit", + "plot_counterfactual", + "plot_causal_impact_sales", + "plot_causal_impact_market_share", + ], +) +def test_MVITS_unsaturated(request, model_name, plot_method): + """We will test the `unsaturated` version of the MVITS model. And we will do this + with multiple scenarios.""" + + model = request.getfixturevalue(model_name) + + ax = getattr(model, plot_method)() + assert isinstance(ax, plt.Axes) + plt.close() + + +def test_save_load(fit_model, saturated_data) -> None: + test_file = "test-mvits.nc" + fit_model.save(test_file) + + loaded = MVITS.load(test_file) + + assert loaded.model_config == fit_model.model_config + assert loaded.existing_sales == fit_model.existing_sales + assert loaded.market_saturated == fit_model.market_saturated + assert loaded.X.columns.name is None + pd.testing.assert_frame_equal(loaded.X, fit_model.X, check_names=False) + assert loaded.y.name == fit_model.output_var + pd.testing.assert_series_equal(loaded.y.rename("new"), saturated_data["new"]) From 8a284dc65515e2900d7b63d0d94be8f4a02611b2 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 22 Oct 2024 22:27:49 -0400 Subject: [PATCH 4/4] test for mismatch --- pymc_marketing/model_builder.py | 4 +--- tests/test_model_builder.py | 9 +++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index 96077fda9..2ab4f4323 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -581,9 +581,7 @@ def fit( """ if isinstance(y, pd.Series) and not X.index.equals(y.index): - raise ValueError( # pragma: no cover - "Index of X and y must match." - ) + raise ValueError("Index of X and y must match.") if predictor_names is None: predictor_names = [] diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index ec2f9e90b..add74eff1 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -571,3 +571,12 @@ def mock_sample(*args, **kwargs): match = "Object of type Generator is not JSON serializable" with pytest.raises(TypeError, match=match): model.fit(toy_X, toy_y) + + +def test_unmatched_index(toy_X, toy_y) -> None: + model = ModelBuilderTest() + toy_X = toy_X.copy() + toy_X.index = toy_X.index + 1 + match = "Index of X and y must match" + with pytest.raises(ValueError, match=match): + model.fit(toy_X, toy_y)