diff --git a/elegy/hooks_test.py b/elegy/hooks_test.py index 9049a5d2..fb6ca577 100644 --- a/elegy/hooks_test.py +++ b/elegy/hooks_test.py @@ -75,7 +75,7 @@ def call(self, x): return x m = elegy.Model(Module0()) - m.predict(jnp.ones(4)) # init + m.init(jnp.ones(4)) with elegy.hooks.context(named_call=True): jaxpr = jax.make_jaxpr( diff --git a/elegy/slicing.py b/elegy/slicing.py index 7ac7ab3e..9e2a2590 100644 --- a/elegy/slicing.py +++ b/elegy/slicing.py @@ -18,6 +18,7 @@ def slice_model( if not model.initialized: model.predict(sample_input, initialize=True) + model.update_modules() with hooks.context(named_call=True), jax.disable_jit(): jaxpr = jax.make_jaxpr(model.pred_step, static_argnums=[2, 3])( diff --git a/elegy/slicing_test.py b/elegy/slicing_test.py index a42d16cc..a407720a 100644 --- a/elegy/slicing_test.py +++ b/elegy/slicing_test.py @@ -136,6 +136,19 @@ def test_basic_nested(self): assert "module1_linear2" not in submodel.states["net_params"].keys() +def test_no_default_parameters(): + x = np.random.random((32, 100)).astype("float32") + module = BasicModule0() + model = elegy.Model(module, seed=np.random.randint(100, 100000)) + model.init(x) + + submodel = elegy.Model(model.slice("linear0", "linear1", x)) + assert submodel.predict(x, initialize=True).shape == (32, 10) + + model.update_modules() + assert jnp.allclose(submodel.predict(x), module.test_call0(x)) + + class BasicModule0(elegy.Module): def call(self, x): x = x / 255.0