From 7571cca0ecb3f7b30d454b9b5027346bd0901530 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Thu, 20 Jun 2024 08:00:54 -0700 Subject: [PATCH 1/2] convert looping to twiddling --- numpyro/infer/hmc_util.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index 51e628148..b331540d1 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -941,14 +941,10 @@ def _double_tree( def _leaf_idx_to_ckpt_idxs(n): # computes the number of non-zero bits except the last bit # e.g. 6 -> 2, 7 -> 2, 13 -> 2 - _, idx_max = while_loop( - lambda nc: nc[0] > 0, lambda nc: (nc[0] >> 1, nc[1] + (nc[0] & 1)), (n >> 1, 0) - ) + idx_max = jnp.bitwise_count(n >> 1).astype(jnp.int32) # computes the number of contiguous last non-zero bits # e.g. 6 -> 0, 7 -> 3, 13 -> 1 - _, num_subtrees = while_loop( - lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0) - ) + num_subtrees = jnp.bitwise_count((~n & (n + 1)) - 1).astype(jnp.int32) # TODO: explore the potential of setting idx_min=0 to allow more turning checks # It will be useful in case: e.g. assume a tree 0 -> 7 is a circle, # subtrees 0 -> 3, 4 -> 7 are half-circles, which two leaves might not From 1927b5853760a1abe40691c6e598442315bae1ff Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Thu, 20 Jun 2024 08:00:54 -0700 Subject: [PATCH 2/2] convert looping to twiddling --- numpyro/infer/hmc_util.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index 51e628148..b331540d1 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -941,14 +941,10 @@ def _double_tree( def _leaf_idx_to_ckpt_idxs(n): # computes the number of non-zero bits except the last bit # e.g. 6 -> 2, 7 -> 2, 13 -> 2 - _, idx_max = while_loop( - lambda nc: nc[0] > 0, lambda nc: (nc[0] >> 1, nc[1] + (nc[0] & 1)), (n >> 1, 0) - ) + idx_max = jnp.bitwise_count(n >> 1).astype(jnp.int32) # computes the number of contiguous last non-zero bits # e.g. 6 -> 0, 7 -> 3, 13 -> 1 - _, num_subtrees = while_loop( - lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0) - ) + num_subtrees = jnp.bitwise_count((~n & (n + 1)) - 1).astype(jnp.int32) # TODO: explore the potential of setting idx_min=0 to allow more turning checks # It will be useful in case: e.g. assume a tree 0 -> 7 is a circle, # subtrees 0 -> 3, 4 -> 7 are half-circles, which two leaves might not