diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index 3a7cf7cc1..725eeb6db 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -47,6 +47,46 @@ def check_array(X, **kwargs): return X +def create_sample_kwargs( + sampler_config: dict[str, Any], + progressbar: bool | None, + random_seed, + **kwargs, +) -> dict[str, Any]: + """Create the dictionary of keyword arguments for `pm.sample`. + + Parameters + ---------- + sampler_config : dict + The configuration dictionary for the sampler. + progressbar : bool, optional + Whether to show the progress bar during sampling. Defaults to True. + random_seed : RandomState + The random seed for the sampler. + **kwargs : Any + Additional keyword arguments to pass to the sampler. + + Returns + ------- + dict + The dictionary of keyword arguments for `pm.sample`. + + """ + sampler_config = sampler_config.copy() + + if progressbar is not None: + sampler_config["progressbar"] = progressbar + else: + sampler_config["progressbar"] = sampler_config.get("progressbar", True) + + if random_seed is not None: + sampler_config["random_seed"] = random_seed + + sampler_config.update(**kwargs) + + return sampler_config + + class ModelBuilder(ABC): """Base class for building models with PyMC Marketing. @@ -501,7 +541,7 @@ def fit( self, X: pd.DataFrame, y: pd.Series | np.ndarray | None = None, - progressbar: bool = True, + progressbar: bool | None = None, predictor_names: list[str] | None = None, random_seed: RandomState | None = None, **kwargs: Any, @@ -516,8 +556,8 @@ def fit( The training input samples. If scikit-learn is available, array-like, otherwise array. y : array-like | array, shape (n_obs,) The target values (real numbers). If scikit-learn is available, array-like, otherwise array. - progressbar : bool - Specifies whether the fit progress bar should be displayed. + progressbar : bool, optional + Specifies whether the fit progress bar should be displayed. Defaults to True. predictor_names : Optional[List[str]] = None, Allows for custom naming of predictors when given in a form of a 2D array. Allows for naming of predictors when given in a form of np.ndarray, if not provided @@ -556,14 +596,14 @@ def fit( if not hasattr(self, "model"): self.build_model(self.X, self.y) - sampler_config = self.sampler_config.copy() - sampler_config["progressbar"] = progressbar - sampler_config["random_seed"] = random_seed - sampler_config.update(**kwargs) - - sampler_args = {**self.sampler_config, **kwargs} + sampler_kwargs = create_sample_kwargs( + self.sampler_config, + progressbar, + random_seed, + **kwargs, + ) with self.model: - idata = pm.sample(**sampler_args) + idata = pm.sample(**sampler_kwargs) if self.idata: self.idata = self.idata.copy() diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 2aff3d79e..ec2f9e90b 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -30,13 +30,14 @@ import sys import tempfile +import arviz as az import numpy as np import pandas as pd import pymc as pm import pytest import xarray as xr -from pymc_marketing.model_builder import ModelBuilder +from pymc_marketing.model_builder import ModelBuilder, create_sample_kwargs @pytest.fixture(scope="module") @@ -446,3 +447,127 @@ def create_idata_attrs(self) -> dict: match = "Missing required keys in attrs" with pytest.raises(ValueError, match=match): model.sample_prior_predictive(X_pred=X_pred) + + +@pytest.mark.parametrize( + "sampler_config, fit_kwargs, expected", + [ + ( + {}, + { + "progressbar": None, + "random_seed": None, + }, + { + "progressbar": True, + }, + ), + ( + { + "random_seed": 52, + "progressbar": False, + }, + { + "progressbar": None, + "random_seed": None, + }, + { + "progressbar": False, + "random_seed": 52, + }, + ), + ( + { + "random_seed": 52, + "progressbar": True, + }, + { + "progressbar": False, + "random_seed": 42, + }, + { + "progressbar": False, + "random_seed": 42, + }, + ), + ], + ids=[ + "no_sampler_config/defaults", + "use_sampler_config", + "override_sampler_config", + ], +) +def test_create_sample_kwargs(sampler_config, fit_kwargs, expected) -> None: + sampler_config_before = sampler_config.copy() + assert create_sample_kwargs(sampler_config, **fit_kwargs) == expected + + # Doesn't override + assert sampler_config_before == sampler_config + + +def create_int_seed(): + return 42 + + +def create_rng_seed(): + return np.random.default_rng(42) + + +@pytest.mark.parametrize( + "create_random_seed", + [ + create_int_seed, + create_rng_seed, + ], + ids=["int", "rng"], +) +def test_fit_random_seed_reproducibility(toy_X, toy_y, create_random_seed) -> None: + sampler_config = { + "chains": 1, + "draws": 10, + "tune": 5, + } + model = ModelBuilderTest(sampler_config=sampler_config) + + idata = model.fit(toy_X, toy_y, random_seed=create_random_seed()) + idata2 = model.fit(toy_X, toy_y, random_seed=create_random_seed()) + + assert idata.posterior.equals(idata2.posterior) + + sizes = idata.posterior.sizes + assert sizes["chain"] == 1 + assert sizes["draw"] == 10 + + +def test_fit_sampler_config_seed_reproducibility(toy_X, toy_y) -> None: + sampler_config = { + "chains": 1, + "draws": 10, + "tune": 5, + "random_seed": 42, + } + model = ModelBuilderTest(sampler_config=sampler_config) + + idata = model.fit(toy_X, toy_y) + idata2 = model.fit(toy_X, toy_y) + + assert idata.posterior.equals(idata2.posterior) + + +def test_fit_sampler_config_with_rng_fails(mocker, toy_X, toy_y) -> None: + def mock_sample(*args, **kwargs): + idata = pm.sample_prior_predictive(10) + return az.InferenceData(posterior=idata.prior) + + mocker.patch("pymc.sample", mock_sample) + sampler_config = { + "chains": 1, + "draws": 10, + "tune": 5, + "random_seed": np.random.default_rng(42), + } + model = ModelBuilderTest(sampler_config=sampler_config) + + match = "Object of type Generator is not JSON serializable" + with pytest.raises(TypeError, match=match): + model.fit(toy_X, toy_y)