diff --git a/numpyro/contrib/ecs_proxies.py b/numpyro/contrib/ecs_proxies.py index 421b981a4..0ed852af4 100644 --- a/numpyro/contrib/ecs_proxies.py +++ b/numpyro/contrib/ecs_proxies.py @@ -162,15 +162,14 @@ def log_likelihood_sum(params_flat, subsample_indices=None): for k, v in log_likelihood(params_flat, subsample_indices).items() } - match degree: - case 2: - TPState = TaylorTwoProxyState - case 1: - TPState = TaylorOneProxyState - case _: - raise ValueError( - "Taylor proxy only defined for first and second degree." - ) + if degree == 2: + TPState = TaylorTwoProxyState + elif 1: + TPState = TaylorOneProxyState + else: + raise ValueError( + "Taylor proxy only defined for first and second degree." + ) # those stats are dict keyed by subsample names ref_sum_log_lik = log_likelihood_sum(ref_params_flat)