Skip to content

Commit

Permalink
Update tfp.md
Browse files Browse the repository at this point in the history
No need to import tensorflow as everything is already running in JAX
  • Loading branch information
junpenglao committed Sep 20, 2022
1 parent 138844e commit 12badae
Showing 1 changed file with 1 addition and 8 deletions.
9 changes: 1 addition & 8 deletions examples/tfp.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,14 @@ plt.show()

## Compare Sampling Time with TFP

```{code-cell} ipython3
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
```

```{code-cell} ipython3
%%time
num_results = 500_000
num_burnin_steps = 0
# Improve performance by tracing the sampler using `tf.function`
# and compiling it using XLA.
@tf.function(autograph=False, experimental_compile=True, experimental_relax_shapes=True)
@jax.jit
def do_sampling():
return tfp.mcmc.sample_chain(
num_results=num_results,
Expand Down

0 comments on commit 12badae

Please sign in to comment.