From 2dea62bca569679bb8f1000d8c786e8220a27abb Mon Sep 17 00:00:00 2001 From: doersch Date: Wed, 22 May 2024 17:12:08 +0100 Subject: [PATCH] Fix a bug in resizing in Kubric point tracking data augmentation. PiperOrigin-RevId: 636189132 --- challenges/point_tracking/dataset.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/challenges/point_tracking/dataset.py b/challenges/point_tracking/dataset.py index 88e9c64b..c08c66a5 100644 --- a/challenges/point_tracking/dataset.py +++ b/challenges/point_tracking/dataset.py @@ -921,15 +921,14 @@ def add_tracks(data, # Crop the video to the sampled window, in a way which matches the coordinate # frame produced the track_points functions. - crop_window = crop_window / ( - np.array(shp[1:3] + shp[1:3]).astype(np.float32) - 1) - crop_window = tf.tile(crop_window[tf.newaxis, :], [num_frames, 1]) - video = tf.image.crop_and_resize( - video, - tf.cast(crop_window, tf.float32), - tf.range(num_frames), - train_size, + start = tf.tensor_scatter_nd_update( + [0, 0, 0, 0], [[1], [2]], crop_window[0:2] ) + size = tf.tensor_scatter_nd_update( + tf.shape(video), [[1], [2]], crop_window[2:4] - crop_window[0:2] + ) + video = tf.slice(video, start, size) + video = tf.image.resize(tf.cast(video, tf.float32), train_size) if vflip: video = video[:, ::-1, :, :] target_points = target_points * np.array([1, -1])