From 947486582113757f85aed11140ea726c06347910 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 8 Aug 2023 11:11:01 +0200 Subject: [PATCH 1/2] sbvm sampler fix by comparison with R code --- numpyro/distributions/directional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index b3e1dcd1a..eb3de3ecb 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -477,7 +477,7 @@ def update_fn(curr): phi_key, key = random.split(key) accept_key, acg_key, phi_key = random.split(phi_key, 3) - x = jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape) + x = 1./jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape) x /= jnp.linalg.norm( x, axis=1, keepdims=True ) # Angular Central Gaussian distribution From 0672ab9c00975e498373f0d18121f6c8025440d9 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 17 Aug 2023 15:46:55 +0200 Subject: [PATCH 2/2] use rsqrt instead of 1/sqrt --- numpyro/distributions/directional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index eb3de3ecb..67a91a8d1 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -477,7 +477,7 @@ def update_fn(curr): phi_key, key = random.split(key) accept_key, acg_key, phi_key = random.split(phi_key, 3) - x = 1./jnp.sqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape) + x = lax.rsqrt(1 + 2 * eig / b0) * random.normal(acg_key, shape) x /= jnp.linalg.norm( x, axis=1, keepdims=True ) # Angular Central Gaussian distribution