diff --git a/tests/tensorflow2/test_should_save_tensor.py b/tests/tensorflow2/test_should_save_tensor.py index 6f09da2f12..0e3e4ba1db 100644 --- a/tests/tensorflow2/test_should_save_tensor.py +++ b/tests/tensorflow2/test_should_save_tensor.py @@ -4,6 +4,7 @@ # First Party import smdebug.tensorflow as smd from smdebug.core.collection import CollectionKeys +from smdebug.core.modes import ModeKeys from smdebug.tensorflow import SaveConfig from smdebug.tensorflow.constants import TF_DEFAULT_SAVED_COLLECTIONS @@ -27,6 +28,9 @@ def helper_create_hook(out_dir, collections, include_regex=None): hook.get_collection(collection).include(include_regex) hook.register_model(model) + hook.set_mode(ModeKeys.TRAIN) + hook._prepare_collections() + hook._increment_step() hook.on_train_begin() return hook