Skip to content

Commit

Permalink
Register and allow custom transform for Prior class (#972)
Browse files Browse the repository at this point in the history
* allow register and use custom transform

* add to the example block
  • Loading branch information
wd60622 authored and twiecki committed Sep 10, 2024
1 parent 82f67f4 commit 4ec3b4a
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 3 deletions.
64 changes: 62 additions & 2 deletions pymc_marketing/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@
dims="channel",
)
Create a prior with a custom transform function by registering it with
`register_tensor_transform`.
.. code-block:: python
from pymc_marketing.prior import register_tensor_transform
def custom_transform(x):
return x ** 2
register_tensor_transform("square", custom_transform)
custom_distribution = Prior("Normal", transform="square")
"""

from __future__ import annotations
Expand Down Expand Up @@ -198,18 +212,63 @@ def _get_pymc_distribution(name: str) -> type[pm.Distribution]:
return getattr(pm, name)


Transform = Callable[[pt.TensorLike], pt.TensorLike]

CUSTOM_TRANSFORMS: dict[str, Transform] = {}


def register_tensor_transform(name: str, transform: Transform) -> None:
"""Register a tensor transform function to be used in the `Prior` class.
Parameters
----------
name : str
The name of the transform.
func : Callable[[pt.TensorLike], pt.TensorLike]
The function to apply to the tensor.
Examples
--------
Register a custom transform function.
.. code-block:: python
from pymc_marketing.prior import (
Prior,
register_tensor_transform,
)
def custom_transform(x):
return x ** 2
register_tensor_transform("square", custom_transform)
custom_distribution = Prior("Normal", transform="square")
"""
CUSTOM_TRANSFORMS[name] = transform


def _get_transform(name: str):
if name in CUSTOM_TRANSFORMS:
return CUSTOM_TRANSFORMS[name]

for module in (pt, pm.math):
if hasattr(module, name):
break
else:
module = None

if not module:
raise UnknownTransformError(
f"Neither PyTensor or pm.math have the function {name!r}"
msg = (
f"Neither pytensor.tensor nor pymc.math have the function {name!r}. "
"If this is a custom function, register it with the "
"`pymc_marketing.prior.register_tensor_transform` function before "
"previous function call."
)

raise UnknownTransformError(msg)

return getattr(module, name)


Expand Down Expand Up @@ -243,6 +302,7 @@ class Prior:
transform : str, optional
The name of the transform to apply to the variable after it is
created, by default None or no transform. The transformation must
be registered with `register_tensor_transform` function or
be available in either `pytensor.tensor` or `pymc.math`.
"""
Expand Down
40 changes: 39 additions & 1 deletion tests/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
UnsupportedParameterizationError,
UnsupportedShapeError,
handle_dims,
register_tensor_transform,
)


Expand Down Expand Up @@ -72,7 +73,8 @@ def test_handle_dims(x, dims, desired_dims, expected_fn) -> None:


def test_missing_transform() -> None:
with pytest.raises(UnknownTransformError):
match = "Neither pytensor.tensor nor pymc.math have the function 'foo_bar'"
with pytest.raises(UnknownTransformError, match=match):
Prior("Normal", transform="foo_bar")


Expand Down Expand Up @@ -608,3 +610,39 @@ def test_checks_param_value_types() -> None:
def test_check_equality_with_numpy() -> None:
dist = Prior("Normal", mu=np.array([1, 2, 3]), sigma=1)
assert dist == dist.deepcopy()


def clear_custom_transforms() -> None:
global CUSTOM_TRANSFORMS
CUSTOM_TRANSFORMS = {}


def test_custom_transform() -> None:
new_transform_name = "foo_bar"
with pytest.raises(UnknownTransformError):
Prior("Normal", transform=new_transform_name)

register_tensor_transform(new_transform_name, lambda x: x**2)

dist = Prior("Normal", transform=new_transform_name)
prior = dist.sample_prior(samples=10)
df_prior = prior.to_dataframe()

np.testing.assert_array_equal(
df_prior["var"].to_numpy(), df_prior["var_raw"].to_numpy() ** 2
)


def test_custom_transform_comes_first() -> None:
# function in pytensor.tensor
register_tensor_transform("square", lambda x: 2 * x)

dist = Prior("Normal", transform="square")
prior = dist.sample_prior(samples=10)
df_prior = prior.to_dataframe()

np.testing.assert_array_equal(
df_prior["var"].to_numpy(), 2 * df_prior["var_raw"].to_numpy()
)

clear_custom_transforms()

0 comments on commit 4ec3b4a

Please sign in to comment.