Skip to content

Commit

Permalink
Convert explicit looping to bit twiddling for nuts u-turn calculations (
Browse files Browse the repository at this point in the history
pyro-ppl#1818)

* convert looping to twiddling

* convert looping to twiddling
  • Loading branch information
andrewdipper authored Jun 23, 2024
1 parent 9785376 commit 0924135
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions numpyro/infer/hmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,14 +941,10 @@ def _double_tree(
def _leaf_idx_to_ckpt_idxs(n):
# computes the number of non-zero bits except the last bit
# e.g. 6 -> 2, 7 -> 2, 13 -> 2
_, idx_max = while_loop(
lambda nc: nc[0] > 0, lambda nc: (nc[0] >> 1, nc[1] + (nc[0] & 1)), (n >> 1, 0)
)
idx_max = jnp.bitwise_count(n >> 1).astype(jnp.int32)
# computes the number of contiguous last non-zero bits
# e.g. 6 -> 0, 7 -> 3, 13 -> 1
_, num_subtrees = while_loop(
lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0)
)
num_subtrees = jnp.bitwise_count((~n & (n + 1)) - 1).astype(jnp.int32)
# TODO: explore the potential of setting idx_min=0 to allow more turning checks
# It will be useful in case: e.g. assume a tree 0 -> 7 is a circle,
# subtrees 0 -> 3, 4 -> 7 are half-circles, which two leaves might not
Expand Down

0 comments on commit 0924135

Please sign in to comment.