diff --git a/pyrenew/latent/infectionswithfeedback.py b/pyrenew/latent/infectionswithfeedback.py index b344a1b5..2d8cc46d 100644 --- a/pyrenew/latent/infectionswithfeedback.py +++ b/pyrenew/latent/infectionswithfeedback.py @@ -5,7 +5,6 @@ import jax.numpy as jnp from numpy.typing import ArrayLike -import pyrenew.arrayutils as au import pyrenew.latent.infection_functions as inf from pyrenew.metaclass import RandomVariable @@ -168,23 +167,17 @@ def sample( ) ) - if inf_feedback_strength.ndim == Rt.ndim - 1: - inf_feedback_strength = inf_feedback_strength[jnp.newaxis] - - # Making sure inf_feedback_strength spans the Rt length - if inf_feedback_strength.shape[0] == 1: - inf_feedback_strength = au.pad_edges_to_match( - x=inf_feedback_strength, - y=Rt, - axis=0, - )[0] - if inf_feedback_strength.shape != Rt.shape: - raise ValueError( - "Infection feedback strength must be of length 1 " - "or the same length as the reproduction number array. " - f"Got {inf_feedback_strength.shape} " - f"and {Rt.shape} respectively." + try: + inf_feedback_strength = jnp.broadcast_to( + inf_feedback_strength, Rt.shape ) + except Exception as e: + raise ValueError( + "Could not broadcast inf_feedback_strength " + f"(shape {inf_feedback_strength.shape}) " + "to the shape of Rt" + f"{Rt.shape}" + ) from e # Sampling inf feedback pmf inf_feedback_pmf = self.infection_feedback_pmf(**kwargs)