Skip to content

Commit

Permalink
Ensure fit reproducibility (#963)
Browse files Browse the repository at this point in the history
* use get instead of assign to default to sampler_config if exists

* default to what is given

* write tests based on the issue

* have defaults while not overriding
  • Loading branch information
wd60622 authored and twiecki committed Sep 10, 2024
1 parent 11cb80e commit bcd429b
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 11 deletions.
60 changes: 50 additions & 10 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,46 @@ def check_array(X, **kwargs):
return X


def create_sample_kwargs(
sampler_config: dict[str, Any],
progressbar: bool | None,
random_seed,
**kwargs,
) -> dict[str, Any]:
"""Create the dictionary of keyword arguments for `pm.sample`.
Parameters
----------
sampler_config : dict
The configuration dictionary for the sampler.
progressbar : bool, optional
Whether to show the progress bar during sampling. Defaults to True.
random_seed : RandomState
The random seed for the sampler.
**kwargs : Any
Additional keyword arguments to pass to the sampler.
Returns
-------
dict
The dictionary of keyword arguments for `pm.sample`.
"""
sampler_config = sampler_config.copy()

if progressbar is not None:
sampler_config["progressbar"] = progressbar
else:
sampler_config["progressbar"] = sampler_config.get("progressbar", True)

if random_seed is not None:
sampler_config["random_seed"] = random_seed

sampler_config.update(**kwargs)

return sampler_config


class ModelBuilder(ABC):
"""Base class for building models with PyMC Marketing.
Expand Down Expand Up @@ -501,7 +541,7 @@ def fit(
self,
X: pd.DataFrame,
y: pd.Series | np.ndarray | None = None,
progressbar: bool = True,
progressbar: bool | None = None,
predictor_names: list[str] | None = None,
random_seed: RandomState | None = None,
**kwargs: Any,
Expand All @@ -516,8 +556,8 @@ def fit(
The training input samples. If scikit-learn is available, array-like, otherwise array.
y : array-like | array, shape (n_obs,)
The target values (real numbers). If scikit-learn is available, array-like, otherwise array.
progressbar : bool
Specifies whether the fit progress bar should be displayed.
progressbar : bool, optional
Specifies whether the fit progress bar should be displayed. Defaults to True.
predictor_names : Optional[List[str]] = None,
Allows for custom naming of predictors when given in a form of a 2D array.
Allows for naming of predictors when given in a form of np.ndarray, if not provided
Expand Down Expand Up @@ -556,14 +596,14 @@ def fit(
if not hasattr(self, "model"):
self.build_model(self.X, self.y)

sampler_config = self.sampler_config.copy()
sampler_config["progressbar"] = progressbar
sampler_config["random_seed"] = random_seed
sampler_config.update(**kwargs)

sampler_args = {**self.sampler_config, **kwargs}
sampler_kwargs = create_sample_kwargs(
self.sampler_config,
progressbar,
random_seed,
**kwargs,
)
with self.model:
idata = pm.sample(**sampler_args)
idata = pm.sample(**sampler_kwargs)

if self.idata:
self.idata = self.idata.copy()
Expand Down
127 changes: 126 additions & 1 deletion tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@
import sys
import tempfile

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
import pytest
import xarray as xr

from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_builder import ModelBuilder, create_sample_kwargs


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -446,3 +447,127 @@ def create_idata_attrs(self) -> dict:
match = "Missing required keys in attrs"
with pytest.raises(ValueError, match=match):
model.sample_prior_predictive(X_pred=X_pred)


@pytest.mark.parametrize(
    "sampler_config, fit_kwargs, expected",
    [
        pytest.param(
            {},
            {"progressbar": None, "random_seed": None},
            {"progressbar": True},
            id="no_sampler_config/defaults",
        ),
        pytest.param(
            {"random_seed": 52, "progressbar": False},
            {"progressbar": None, "random_seed": None},
            {"progressbar": False, "random_seed": 52},
            id="use_sampler_config",
        ),
        pytest.param(
            {"random_seed": 52, "progressbar": True},
            {"progressbar": False, "random_seed": 42},
            {"progressbar": False, "random_seed": 42},
            id="override_sampler_config",
        ),
    ],
)
def test_create_sample_kwargs(sampler_config, fit_kwargs, expected) -> None:
    """Check how `create_sample_kwargs` merges config and fit arguments.

    Covers three situations: an empty configuration (defaults apply), a
    configuration used as is because the fit arguments are None, and a
    configuration overridden by explicit fit arguments.
    """
    snapshot = dict(sampler_config)

    result = create_sample_kwargs(sampler_config, **fit_kwargs)

    assert result == expected
    # The caller's configuration must come back unchanged.
    assert snapshot == sampler_config


def create_int_seed() -> int:
    """Return the fixed integer seed (42) used by the seed tests."""
    return 42


def create_rng_seed() -> np.random.Generator:
    """Return a newly constructed NumPy ``Generator`` seeded with 42.

    Each call builds a fresh generator, so two calls yield independent
    objects that start from the same state.
    """
    return np.random.default_rng(42)


@pytest.mark.parametrize(
    "create_random_seed",
    [create_int_seed, create_rng_seed],
    ids=["int", "rng"],
)
def test_fit_random_seed_reproducibility(toy_X, toy_y, create_random_seed) -> None:
    """Fitting twice with an equivalent ``random_seed`` gives equal posteriors.

    The seed is produced by a factory so that each fit receives its own
    seed object. The test also checks that the chain and draw counts from
    the sampler configuration show up in the posterior sizes.
    """
    model = ModelBuilderTest(
        sampler_config={
            "chains": 1,
            "draws": 10,
            "tune": 5,
        }
    )

    first_fit = model.fit(toy_X, toy_y, random_seed=create_random_seed())
    second_fit = model.fit(toy_X, toy_y, random_seed=create_random_seed())

    assert first_fit.posterior.equals(second_fit.posterior)

    # The sampler configuration must still be honoured when a seed is given.
    posterior_sizes = first_fit.posterior.sizes
    assert posterior_sizes["chain"] == 1
    assert posterior_sizes["draw"] == 10


def test_fit_sampler_config_seed_reproducibility(toy_X, toy_y) -> None:
    """Fitting twice with a seed set only in ``sampler_config`` is reproducible.

    No ``random_seed`` is passed to ``fit``; the integer seed comes solely
    from the sampler configuration.
    """
    model = ModelBuilderTest(
        sampler_config={
            "chains": 1,
            "draws": 10,
            "tune": 5,
            "random_seed": 42,
        }
    )

    first_fit = model.fit(toy_X, toy_y)
    second_fit = model.fit(toy_X, toy_y)

    assert first_fit.posterior.equals(second_fit.posterior)


def test_fit_sampler_config_with_rng_fails(mocker, toy_X, toy_y) -> None:
    """A ``Generator`` stored in ``sampler_config`` makes ``fit`` raise.

    ``pymc.sample`` is replaced by a cheap stand-in so that no real sampling
    happens; the test only checks that ``fit`` raises the JSON serialization
    ``TypeError``.
    """

    def fake_sample(*args, **kwargs):
        # Build an InferenceData whose posterior group is just prior draws.
        prior_idata = pm.sample_prior_predictive(10)
        return az.InferenceData(posterior=prior_idata.prior)

    mocker.patch("pymc.sample", fake_sample)

    model = ModelBuilderTest(
        sampler_config={
            "chains": 1,
            "draws": 10,
            "tune": 5,
            "random_seed": np.random.default_rng(42),
        }
    )

    # NOTE(review): presumably raised when the sampler configuration is
    # serialized into the fit result -- confirm against ModelBuilder.fit.
    expected_message = "Object of type Generator is not JSON serializable"
    with pytest.raises(TypeError, match=expected_message):
        model.fit(toy_X, toy_y)

0 comments on commit bcd429b

Please sign in to comment.