diff --git a/pymc_marketing/prior.py b/pymc_marketing/prior.py index 1ed68bb8f..608ea5066 100644 --- a/pymc_marketing/prior.py +++ b/pymc_marketing/prior.py @@ -78,6 +78,20 @@ dims="channel", ) +Create a prior with a custom transform function by registering it with +`register_tensor_transform`. + +.. code-block:: python + + from pymc_marketing.prior import register_tensor_transform + + def custom_transform(x): + return x ** 2 + + register_tensor_transform("square", custom_transform) + + custom_distribution = Prior("Normal", transform="square") + """ from __future__ import annotations @@ -198,7 +212,47 @@ def _get_pymc_distribution(name: str) -> type[pm.Distribution]: return getattr(pm, name) +Transform = Callable[[pt.TensorLike], pt.TensorLike] + +CUSTOM_TRANSFORMS: dict[str, Transform] = {} + + +def register_tensor_transform(name: str, transform: Transform) -> None: + """Register a tensor transform function to be used in the `Prior` class. + + Parameters + ---------- + name : str + The name of the transform. + func : Callable[[pt.TensorLike], pt.TensorLike] + The function to apply to the tensor. + + Examples + -------- + Register a custom transform function. + + .. code-block:: python + + from pymc_marketing.prior import ( + Prior, + register_tensor_transform, + ) + + def custom_transform(x): + return x ** 2 + + register_tensor_transform("square", custom_transform) + + custom_distribution = Prior("Normal", transform="square") + + """ + CUSTOM_TRANSFORMS[name] = transform + + def _get_transform(name: str): + if name in CUSTOM_TRANSFORMS: + return CUSTOM_TRANSFORMS[name] + for module in (pt, pm.math): if hasattr(module, name): break @@ -206,10 +260,15 @@ def _get_transform(name: str): module = None if not module: - raise UnknownTransformError( - f"Neither PyTensor or pm.math have the function {name!r}" + msg = ( + f"Neither pytensor.tensor nor pymc.math have the function {name!r}. " + "If this is a custom function, register it with the " + "`pymc_marketing.prior.register_tensor_transform` function before " + "previous function call." ) + raise UnknownTransformError(msg) + return getattr(module, name) @@ -243,6 +302,7 @@ class Prior: transform : str, optional The name of the transform to apply to the variable after it is created, by default None or no transform. The transformation must + be registered with `register_tensor_transform` function or be available in either `pytensor.tensor` or `pymc.math`. """ diff --git a/tests/test_prior.py b/tests/test_prior.py index 3ebae2816..943fdcd56 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -31,6 +31,7 @@ UnsupportedParameterizationError, UnsupportedShapeError, handle_dims, + register_tensor_transform, ) @@ -72,7 +73,8 @@ def test_handle_dims(x, dims, desired_dims, expected_fn) -> None: def test_missing_transform() -> None: - with pytest.raises(UnknownTransformError): + match = "Neither pytensor.tensor nor pymc.math have the function 'foo_bar'" + with pytest.raises(UnknownTransformError, match=match): Prior("Normal", transform="foo_bar") @@ -608,3 +610,39 @@ def test_checks_param_value_types() -> None: def test_check_equality_with_numpy() -> None: dist = Prior("Normal", mu=np.array([1, 2, 3]), sigma=1) assert dist == dist.deepcopy() + + +def clear_custom_transforms() -> None: + global CUSTOM_TRANSFORMS + CUSTOM_TRANSFORMS = {} + + +def test_custom_transform() -> None: + new_transform_name = "foo_bar" + with pytest.raises(UnknownTransformError): + Prior("Normal", transform=new_transform_name) + + register_tensor_transform(new_transform_name, lambda x: x**2) + + dist = Prior("Normal", transform=new_transform_name) + prior = dist.sample_prior(samples=10) + df_prior = prior.to_dataframe() + + np.testing.assert_array_equal( + df_prior["var"].to_numpy(), df_prior["var_raw"].to_numpy() ** 2 + ) + + +def test_custom_transform_comes_first() -> None: + # function in pytensor.tensor + register_tensor_transform("square", lambda x: 2 * x) + + dist = Prior("Normal", transform="square") + prior = dist.sample_prior(samples=10) + df_prior = prior.to_dataframe() + + np.testing.assert_array_equal( + df_prior["var"].to_numpy(), 2 * df_prior["var_raw"].to_numpy() + ) + + clear_custom_transforms()