Skip to content

Commit

Permalink
[JAX] move example libraries from jax.experimental to `jax.example_…
Browse files Browse the repository at this point in the history
…libraries`

PiperOrigin-RevId: 404416764
  • Loading branch information
froystig authored and OptaxDev committed Nov 19, 2021
1 parent 46fbf8a commit 95c3d00
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/differentially_private_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from absl import flags
import datasets
import jax
from jax.experimental import stax
from jax.example_libraries import stax
import jax.numpy as jnp
import optax
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
Expand Down
6 changes: 3 additions & 3 deletions optax/_src/equivalence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import chex
from flax import optim
from jax.experimental import optimizers
from jax.example_libraries import optimizers
import jax.numpy as jnp

from optax._src import alias
Expand All @@ -31,7 +31,7 @@
LR_SCHED = lambda _: LR # Trivial constant "schedule".


class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
class OptimizersEquivalenceTest(chex.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -66,7 +66,7 @@ def setUp(self):
)
def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer, rtol):

# experimental/optimizers.py
# example_libraries/optimizers.py
jax_params = self.init_params
opt_init, opt_update, get_params = jax_optimizer
state = opt_init(jax_params)
Expand Down

0 comments on commit 95c3d00

Please sign in to comment.