diff --git a/challenges/point_tracking/dataset.py b/challenges/point_tracking/dataset.py index 0f4aca3d..43002a8b 100644 --- a/challenges/point_tracking/dataset.py +++ b/challenges/point_tracking/dataset.py @@ -802,8 +802,29 @@ def get_camera(fr=None): all_relative_depth = all_reproj_depth / chosen_points_depth[..., tf.newaxis] - return tf.cast(chosen_points, tf.float32), tf.cast(all_reproj, - tf.float32), all_occ, all_relative_depth + # Pad to tracks_to_sample by repeating last row if necessary. + # Note this shouldn't be necessary in most cases but sometimes we end up with + # fewer points than requested. + chosen_points = _pad_with_last_row(chosen_points, tracks_to_sample) + all_reproj = _pad_with_last_row(all_reproj, tracks_to_sample) + all_occ = _pad_with_last_row(all_occ, tracks_to_sample) + all_relative_depth = _pad_with_last_row(all_relative_depth, tracks_to_sample) + + return ( + tf.cast(chosen_points, tf.float32), + tf.cast(all_reproj, tf.float32), + all_occ, + all_relative_depth, + ) + + +def _pad_with_last_row(arr, desired_length): + """Pad the array with the last row to the desired first axis length.""" + last_row = arr[-2:-1] + last_row_rep = tf.tile( + last_row, [desired_length - tf.shape(arr)[0]] + [1] * (arr.ndim - 1) + ) + return tf.concat([arr, last_row_rep], axis=0) def _get_distorted_bounding_box(