From 12badae66a5b08d126ce705212d7541c72dae2c0 Mon Sep 17 00:00:00 2001 From: junpenglao Date: Tue, 20 Sep 2022 15:01:24 +0200 Subject: [PATCH] Update tfp.md No need to import tensorflow as everything is already running in JAX --- examples/tfp.md | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/examples/tfp.md b/examples/tfp.md index e828bbb86..67fb87640 100644 --- a/examples/tfp.md +++ b/examples/tfp.md @@ -166,11 +166,6 @@ plt.show() ## Compare Sampling Time with TFP -```{code-cell} ipython3 -import tensorflow.compat.v2 as tf - -tf.enable_v2_behavior() -``` ```{code-cell} ipython3 %%time @@ -178,9 +173,7 @@ tf.enable_v2_behavior() num_results = 500_000 num_burnin_steps = 0 -# Improve performance by tracing the sampler using `tf.function` -# and compiling it using XLA. -@tf.function(autograph=False, experimental_compile=True, experimental_relax_shapes=True) +@jax.jit def do_sampling(): return tfp.mcmc.sample_chain( num_results=num_results,