From fe6a64d1d06e25c5404b403721c12f6b30dadf3f Mon Sep 17 00:00:00 2001 From: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Date: Tue, 13 Aug 2024 01:34:03 -0400 Subject: [PATCH] Harmonize Quickstart example --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a8d847cf9..9590b4cd6 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,10 @@ state = nuts.init(initial_position) # Iterate rng_key = jax.random.key(0) +step = jax.jit(nuts.step) for step in range(100): nuts_key = jax.random.fold_in(rng_key, step) - state, _ = nuts.step(nuts_key, state) + state, _ = step(nuts_key, state) ``` See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.