Skip to content

Commit

Permalink
Update track_points to pad with last chosen point if fewer than req…
Browse files Browse the repository at this point in the history
…uested points are chosen (#333)

PiperOrigin-RevId: 681430762
  • Loading branch information
Qwlouse authored Oct 7, 2024
1 parent c65de56 commit 3f87a26
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions challenges/point_tracking/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,29 @@ def get_camera(fr=None):

all_relative_depth = all_reproj_depth / chosen_points_depth[..., tf.newaxis]

return tf.cast(chosen_points, tf.float32), tf.cast(all_reproj,
tf.float32), all_occ, all_relative_depth
# Pad to tracks_to_sample by repeating last row if necessary.
# Note this shouldn't be necessary in most cases but sometimes we end up with
# fewer points than requested.
chosen_points = _pad_with_last_row(chosen_points, tracks_to_sample)
all_reproj = _pad_with_last_row(all_reproj, tracks_to_sample)
all_occ = _pad_with_last_row(all_occ, tracks_to_sample)
all_relative_depth = _pad_with_last_row(all_relative_depth, tracks_to_sample)

return (
tf.cast(chosen_points, tf.float32),
tf.cast(all_reproj, tf.float32),
all_occ,
all_relative_depth,
)


def _pad_with_last_row(arr, desired_length):
"""Pad the array with the last row to the desired first axis length."""
last_row = arr[-2:-1]
last_row_rep = tf.tile(
last_row, [desired_length - tf.shape(arr)[0]] + [1] * (arr.ndim - 1)
)
return tf.concat([arr, last_row_rep], axis=0)


def _get_distorted_bounding_box(
Expand Down

0 comments on commit 3f87a26

Please sign in to comment.