diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index f1188b53d814..468b65b46775 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -1664,6 +1664,40 @@ def test_foreach_rnn(): check_foreach_rnn(cell_type, num_states) +@with_seed() +def test_cut_subgraph_while_loop(): + class TestLayer(gluon.HybridBlock): + def __init__(self, prefix=None, params=None): + super(TestLayer, self).__init__(prefix=prefix, params=params) + def hybrid_forward(self, F, data): + out1, data1 = F.contrib.while_loop( + cond=lambda i: i <= 5, + func=lambda i: (None, (i + 1, )), + loop_vars=(data, ), + max_iterations=10, + ) + out2, data2 = F.contrib.while_loop( + cond=lambda i: data1[0], + func=lambda i: (None, (i + 1, )), + loop_vars=data1[0], + max_iterations=10, + ) + return data2[0] + data = mx.nd.normal(loc=0, scale=1, shape=(1, )) + layer = TestLayer() + layer.initialize(ctx=default_context()) + res1 = layer(data) + with mx.autograd.record(): + res1 = layer(data) + layer = TestLayer() + layer.initialize(ctx=default_context()) + layer.hybridize() + res2 = layer(data) + with mx.autograd.record(): + res2 = layer(data) + assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001) + + if __name__ == '__main__': import nose nose.runmodule()