diff --git a/jax_ml_stack/tests/test_nnx_with_optax.py b/jax_ml_stack/tests/test_nnx_with_optax.py index 0c0871b..d571e76 100644 --- a/jax_ml_stack/tests/test_nnx_with_optax.py +++ b/jax_ml_stack/tests/test_nnx_with_optax.py @@ -51,3 +51,7 @@ def loss(model, x=x, y=y): final_loss = loss(model) self.assertNotAlmostEqual(initial_loss, final_loss) + + +if __name__ == '__main__': + unittest.main() diff --git a/jax_ml_stack/tests/test_nnx_with_orbax.py b/jax_ml_stack/tests/test_nnx_with_orbax.py new file mode 100644 index 0000000..f470cb4 --- /dev/null +++ b/jax_ml_stack/tests/test_nnx_with_orbax.py @@ -0,0 +1,80 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import contextlib +import platform +import tempfile +import unittest + +from flax import nnx +import jax +import numpy as np +import orbax.checkpoint + + +class SimpleModel(nnx.Module): + + def __init__(self, rngs): + self.layer1 = nnx.Linear(2, 5, rngs=rngs) + self.layer2 = nnx.Linear(5, 3, rngs=rngs) + + def __call__(self, x): + for layer in [self.layer1, self.layer2]: + x = layer(x) + return x + + +class NNXOrbaxTest(unittest.TestCase): + + def setUp(self): + self.tmp_dir = tempfile.TemporaryDirectory() + + if hasattr(self, 'enterContext'): # Python 3.11 or newer + self.enterContext(self.tmp_dir) + else: + with contextlib.ExitStack() as stack: + stack.enter_context(self.tmp_dir) + self.addCleanup(stack.pop_all().close) + + # TODO(jakevdp): https://github.com/google/orbax/pull/1087 + @unittest.skipIf( + platform.system() == 'Windows', 'orbax divide-by-zero error on Windows' + ) + def test_nnx_orbax_checkpoint(self): + model = SimpleModel(nnx.Rngs(0)) + + # Create the checkpoint + state = nnx.state(model) + checkpointer = orbax.checkpoint.PyTreeCheckpointer() + checkpointer.save(f'{self.tmp_dir.name}/state', item=state) + restore_args = orbax.checkpoint.checkpoint_utils.construct_restore_args( + state + ) + + # update the model with the loaded state + restored_model = nnx.eval_shape(SimpleModel, nnx.Rngs(1)) + restored_state = checkpointer.restore( + f'{self.tmp_dir.name}/state', + item=nnx.state(restored_model), + restore_args=restore_args, + ) + nnx.update(restored_model, restored_state) + + self.assertEqual(type(model), type(restored_model)) + jax.tree.map(np.testing.assert_array_equal, state, restored_state) + + +if __name__ == '__main__': + unittest.main()