From 3f87a2659823eb3af12d6ce98fe986be266cb70e Mon Sep 17 00:00:00 2001 From: Klaus Greff Date: Mon, 7 Oct 2024 13:45:57 +0200 Subject: [PATCH] Update `track_points` to pad with last chosen point if fewer than requested points are chosen (#333) PiperOrigin-RevId: 681430762 --- challenges/point_tracking/dataset.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/challenges/point_tracking/dataset.py b/challenges/point_tracking/dataset.py index 0f4aca3d..43002a8b 100644 --- a/challenges/point_tracking/dataset.py +++ b/challenges/point_tracking/dataset.py @@ -802,8 +802,29 @@ def get_camera(fr=None): all_relative_depth = all_reproj_depth / chosen_points_depth[..., tf.newaxis] - return tf.cast(chosen_points, tf.float32), tf.cast(all_reproj, - tf.float32), all_occ, all_relative_depth + # Pad to tracks_to_sample by repeating last row if necessary. + # Note this shouldn't be necessary in most cases but sometimes we end up with + # fewer points than requested. + chosen_points = _pad_with_last_row(chosen_points, tracks_to_sample) + all_reproj = _pad_with_last_row(all_reproj, tracks_to_sample) + all_occ = _pad_with_last_row(all_occ, tracks_to_sample) + all_relative_depth = _pad_with_last_row(all_relative_depth, tracks_to_sample) + + return ( + tf.cast(chosen_points, tf.float32), + tf.cast(all_reproj, tf.float32), + all_occ, + all_relative_depth, + ) + + +def _pad_with_last_row(arr, desired_length): + """Pad the array with the last row to the desired first axis length.""" + last_row = arr[-2:-1] + last_row_rep = tf.tile( + last_row, [desired_length - tf.shape(arr)[0]] + [1] * (arr.ndim - 1) + ) + return tf.concat([arr, last_row_rep], axis=0) def _get_distorted_bounding_box(