diff --git a/challenges/point_tracking/dataset.py b/challenges/point_tracking/dataset.py
index 7012f06c..0e1a3732 100644
--- a/challenges/point_tracking/dataset.py
+++ b/challenges/point_tracking/dataset.py
@@ -11,13 +11,6 @@
 from tensorflow_graphics.geometry.transformation import rotation_matrix_3d
 
-TOTAL_TRACKS = 256
-MAX_SAMPLED_FRAC = .1
-MAX_SEG_ID = 22
-INPUT_SIZE = (None, 256, 256)
-STRIDE = 4  # Make sure this divides all axes of INPUT_SIZE
-
-
 def project_point(cam, point3d, num_frames):
   """Compute the image space coordinates [0, 1] for a set of points.
 
@@ -188,6 +181,7 @@ def get_camera_matrices(
     cam_positions,
     cam_quaternions,
     cam_sensor_width,
+    input_size,
     num_frames=None,
 ):
   """Tf function that converts camera positions into projection matrices."""
@@ -198,7 +192,7 @@ def get_camera_matrices(
     focal_length = tf.cast(cam_focal_length, tf.float32)
     sensor_width = tf.cast(cam_sensor_width, tf.float32)
     f_x = focal_length / sensor_width
-    f_y = focal_length / sensor_width * INPUT_SIZE[1] / INPUT_SIZE[2]
+    f_y = focal_length / sensor_width * input_size[0] / input_size[1]
     p_x = 0.5
     p_y = 0.5
     intrinsics.append(
@@ -235,6 +229,7 @@ def single_object_reproject(
     num_frames=None,
     depth_map=None,
     window=None,
+    input_size=None,
 ):
   """Reproject points for a single object.
 
@@ -265,8 +260,8 @@ def single_object_reproject(
   )
   occluded = tf.less(reproj[:, :, 2], 0)
-  reproj = reproj[:, :, 0:2] * np.array(INPUT_SIZE[2:0:-1])[np.newaxis,
-                                                            np.newaxis, :]
+  reproj = reproj[:, :, 0:2] * np.array(input_size[::-1])[np.newaxis,
+                                                          np.newaxis, :]
   occluded = tf.logical_or(
       occluded,
       tf.less(
@@ -285,19 +280,23 @@ def single_object_reproject(
   return obj_reproj, obj_occ
 
 
-def get_num_to_sample(counts):
+def get_num_to_sample(counts, max_seg_id, max_sampled_frac, tracks_to_sample):
   """Computes the number of points to sample for each object.
 
   Args:
     counts: The number of points available per object. An int array of length
       n, where n is the number of objects.
+    max_seg_id: The maximum number of segment ids in the video.
+    max_sampled_frac: The maximum fraction of points to sample from each
+      object, out of all points that lie on the sampling grid.
+    tracks_to_sample: Total number of tracks to sample per video.
 
   Returns:
     The number of points to sample for each object. An int array of length n.
   """
   seg_order = tf.argsort(counts)
   sorted_counts = tf.gather(counts, seg_order)
-  initializer = (0, TOTAL_TRACKS, 0)
+  initializer = (0, tracks_to_sample, 0)
 
   def scan_fn(prev_output, count_seg):
     index = prev_output[0]
@@ -308,7 +307,7 @@ def scan_fn(prev_output, count_seg):
                                tf.cast(desired_frac, tf.float32))
     want_to_sample = tf.cast(tf.round(want_to_sample), tf.int32)
     max_to_sample = (
-        tf.cast(count_seg, tf.float32) * tf.cast(MAX_SAMPLED_FRAC, tf.float32))
+        tf.cast(count_seg, tf.float32) * tf.cast(max_sampled_frac, tf.float32))
     max_to_sample = tf.cast(tf.round(max_to_sample), tf.int32)
     num_to_sample = tf.minimum(want_to_sample, max_to_sample)
 
@@ -323,7 +322,7 @@ def scan_fn(prev_output, count_seg):
   num_to_sample = tf.concat(
       [
           num_to_sample,
-          tf.zeros([MAX_SEG_ID - tf.shape(num_to_sample)[0]], dtype=tf.int32),
+          tf.zeros([max_seg_id - tf.shape(num_to_sample)[0]], dtype=tf.int32),
       ],
       axis=0,
   )
@@ -344,7 +343,10 @@ def track_points(
     cam_quaternions,
     cam_sensor_width,
     window,
-    num_frames=None,
+    tracks_to_sample=256,
+    sampling_stride=4,
+    max_seg_id=25,
+    max_sampled_frac=0.1,
 ):
   """Track points in 2D using Kubric data.
 
@@ -364,7 +366,12 @@ def track_points(
     window: the window inside which we're sampling points.
       Integer valued in the format [x_min, y_min, x_max, y_max], where min is
       inclusive and max is exclusive.
-    num_frames: number of frames in the video
+    tracks_to_sample: Total number of tracks to sample per video.
+    sampling_stride: For efficiency, query points are sampled from a random
+      grid of this stride.
+    max_seg_id: The maximum segment id in the video.
+    max_sampled_frac: The maximum fraction of points to sample from each
+      object, out of all points that lie on the sampling grid.
 
   Returns:
     A set of queries, randomly sampled from the video (with a bias toward
@@ -388,22 +395,28 @@ def track_points(
   depth_f32 = tf.cast(depth, tf.float32)
   depth_map = depth_min + depth_f32 * (depth_max-depth_min) / 65535
 
+  input_size = object_coordinates.shape.as_list()[1:3]
+  num_frames = object_coordinates.shape.as_list()[0]
+
   # We first sample query points within the given window. That means first
   # extracting the window from the segmentation tensor, because we want to have
   # a bias toward moving objects.
   # Note: for speed we sample points on a grid. The grid start position is
   # randomized within the window.
   start_vec = [
-      tf.random.uniform([], minval=0, maxval=STRIDE, dtype=tf.int32)
-      for _ in range(len(INPUT_SIZE))
+      tf.random.uniform([], minval=0, maxval=sampling_stride, dtype=tf.int32)
+      for _ in range(3)
   ]
   start_vec[1] += window[0]
   start_vec[2] += window[1]
   end_vec = [num_frames, window[2], window[3]]
 
   def extract_box(x):
-    x = x[start_vec[0]::STRIDE, start_vec[1]:window[2]:STRIDE,
-          start_vec[2]:window[3]:STRIDE]
+    x = x[
+        start_vec[0]::sampling_stride,
+        start_vec[1]:window[2]:sampling_stride,
+        start_vec[2]:window[3]:sampling_stride,
+    ]
     return x
 
   segmentations_box = extract_box(segmentations)
@@ -413,13 +426,19 @@ def extract_box(x):
   # how many points are available for each object.
   cnt = tf.math.bincount(
       tf.cast(tf.reshape(segmentations_box, [-1]), tf.int32))
-  num_to_sample = get_num_to_sample(cnt)
-  num_to_sample.set_shape([MAX_SEG_ID])
+  num_to_sample = get_num_to_sample(
+      cnt,
+      max_seg_id,
+      max_sampled_frac,
+      tracks_to_sample,
+  )
+  num_to_sample.set_shape([max_seg_id])
   intrinsics, matrix_world = get_camera_matrices(
       cam_focal_length,
       cam_positions,
       cam_quaternions,
       cam_sensor_width,
+      input_size,
       num_frames=num_frames,
   )
 
@@ -431,11 +450,11 @@ def get_camera(fr=None):
   # Construct pixel coordinates for each pixel within the window.
   window = tf.cast(window, tf.float32)
   z, y, x = tf.meshgrid(
-      *[tf.range(st, ed, STRIDE) for st, ed in zip(start_vec, end_vec)],
+      *[tf.range(st, ed, sampling_stride) for st, ed in zip(start_vec, end_vec)],
       indexing='ij')
   pix_coords = tf.reshape(tf.stack([z, y, x], axis=-1), [-1, 3])
 
-  for i in range(MAX_SEG_ID):
+  for i in range(max_seg_id):
     # sample points on object i in the first frame.  obj_id is the position
    # within the object_coordinates array, which is one lower than the value
    # in the segmentation mask (0 in the segmentation mask is the background
@@ -492,6 +511,7 @@ def get_camera(fr=None):
             num_frames=num_frames,
             depth_map=depth_map,
             window=window,
+            input_size=input_size,
         ),
         lambda:  # pylint: disable=g-long-lambda
         (tf.zeros([0, num_frames, 2], dtype=tf.float32),
@@ -508,7 +528,7 @@ def get_camera(fr=None):
                   np.array([num_frames]), window[2:4]], axis=0)
   wd = wd[tf.newaxis, tf.newaxis, :]
-  coord_multiplier = [num_frames, INPUT_SIZE[1], INPUT_SIZE[2]]
+  coord_multiplier = [num_frames, input_size[0], input_size[1]]
   all_reproj = tf.concat(all_reproj, axis=0)
   # We need to extract x,y, but the format of the window is [t1,y1,x1,t2,y2,x2]
   window_size = wd[:, :, 5:3:-1] - wd[:, :, 2:0:-1]
@@ -558,17 +578,28 @@ def _get_distorted_bounding_box(
 
 def add_tracks(data,
-               train_size=(200, 200),
+               train_size=(256, 256),
                vflip=False,
-               random_crop=True):
+               random_crop=True,
+               tracks_to_sample=256,
+               sampling_stride=4,
+               max_seg_id=25,
+               max_sampled_frac=0.1):
   """Track points in 2D using Kubric data.
 
   Args:
-    data: kubric data, including RGB/depth/object coordinate/segmentation
-      videos and camera parameters.
-    train_size: cropped output will be at this resolution
+    data: Kubric data, including RGB/depth/object coordinate/segmentation
+      videos and camera parameters.
+    train_size: Cropped output will be at this resolution. Ignored if
+      random_crop is False.
     vflip: whether to vertically flip images and tracks (to test
       generalization)
-    random_crop: whether to randomly crop videos
+    random_crop: Whether to randomly crop videos.
+    tracks_to_sample: Total number of tracks to sample per video.
+    sampling_stride: For efficiency, query points are sampled from a random
+      grid of this stride.
+    max_seg_id: The maximum segment id in the video.
+    max_sampled_frac: The maximum fraction of points to sample from each
+      object, out of all points that lie on the sampling grid.
 
   Returns:
     A dict with the following keys:
@@ -589,6 +620,8 @@ def add_tracks(data,
   """
   shp = data['video'].shape.as_list()
   num_frames = shp[0]
+  if any([s % sampling_stride != 0 for s in shp[:-1]]):
+    raise ValueError("All video dims must be a multiple of sampling_stride.")
   bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
 
   min_area = 0.3
@@ -604,7 +637,7 @@ def add_tracks(data,
         area_range=(min_area, max_area),
         max_attempts=20)
   else:
-    crop_window = tf.constant([0, 0, INPUT_SIZE[1], INPUT_SIZE[2]],
+    crop_window = tf.constant([0, 0, shp[1], shp[2]],
                               dtype=tf.int32,
                               shape=[4])
 
@@ -613,14 +646,14 @@ def add_tracks(data,
       data['metadata']['depth_range'], data['segmentations'],
       data['instances']['bboxes_3d'], data['camera']['focal_length'],
       data['camera']['positions'], data['camera']['quaternions'],
-      data['camera']['sensor_width'], crop_window, num_frames)
+      data['camera']['sensor_width'], crop_window, tracks_to_sample,
+      sampling_stride, max_seg_id, max_sampled_frac)
   video = data['video']
 
   shp = video.shape.as_list()
-  num_frames = shp[0]
-  query_points.set_shape([TOTAL_TRACKS, 3])
-  target_points.set_shape([TOTAL_TRACKS, num_frames, 2])
-  occluded.set_shape([TOTAL_TRACKS, num_frames])
+  query_points.set_shape([tracks_to_sample, 3])
+  target_points.set_shape([tracks_to_sample, num_frames, 2])
+  occluded.set_shape([tracks_to_sample, num_frames])
 
   # Crop the video to the sampled window, in a way which matches the coordinate
   # frame produced the track_points functions.
@@ -654,6 +687,10 @@ def create_point_tracking_dataset(
     repeat=True,
     vflip=False,
     random_crop=True,
+    tracks_to_sample=256,
+    sampling_stride=4,
+    max_seg_id=25,
+    max_sampled_frac=0.1,
     **kwargs):
   """Construct a dataset for point tracking using Kubric: go/kubric.
 
@@ -667,6 +704,13 @@ def create_point_tracking_dataset(
     repeat: Bool. whether to repeat the dataset.
     vflip: Bool. whether to vertically flip the dataset to test generalization.
     random_crop: Bool. whether to randomly crop videos
+    tracks_to_sample: Int. Total number of tracks to sample per video.
+    sampling_stride: Int. For efficiency, query points are sampled from a
+      random grid of this stride.
+    max_seg_id: Int. The maximum segment id in the video. Note the size of
+      the graph is proportional to this number, so prefer small values.
+    max_sampled_frac: Float. The maximum fraction of points to sample from
+      each object, out of all points that lie on the sampling grid.
     **kwargs: additional args to pass to tfds.load.
 
   Returns:
@@ -686,7 +730,11 @@ def create_point_tracking_dataset(
           add_tracks,
           train_size=train_size,
          vflip=vflip,
-          random_crop=random_crop),
+          random_crop=random_crop,
+          tracks_to_sample=tracks_to_sample,
+          sampling_stride=sampling_stride,
+          max_seg_id=max_seg_id,
+          max_sampled_frac=max_sampled_frac),
       num_parallel_calls=2)
   if shuffle_buffer_size is not None:
     ds = ds.shuffle(shuffle_buffer_size)
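
Usage sketch (reviewer note, not part of the patch): with the module-level constants gone, sampling behavior is now configured per call. The snippet below is a minimal smoke test under two assumptions: the module is importable as challenges.point_tracking.dataset, and shuffle_buffer_size is an existing keyword argument of create_point_tracking_dataset, consistent with the shuffle logic in the final hunk. The Kubric TFDS data that the loader fetches internally via tfds.load must be available, and all parameter values are illustrative.

    from challenges.point_tracking import dataset

    ds = dataset.create_point_tracking_dataset(
        train_size=(256, 256),
        shuffle_buffer_size=None,  # assumed existing kwarg; None skips ds.shuffle
        repeat=False,
        random_crop=True,
        tracks_to_sample=512,  # was hard-coded as TOTAL_TRACKS = 256
        sampling_stride=4,     # must divide all video dims or add_tracks raises
        max_seg_id=25,         # keep small: graph size grows with this value
        max_sampled_frac=0.1,  # was MAX_SAMPLED_FRAC = .1
    )

    # Output shapes now follow tracks_to_sample instead of TOTAL_TRACKS:
    # query_points [N, 3], target_points [N, num_frames, 2], and
    # occluded [N, num_frames], where N == tracks_to_sample.
    for example in ds.take(1):
      print({key: value.shape for key, value in example.items()})

Since add_tracks now raises ValueError when any dimension of the raw video is not a multiple of sampling_stride, choose a stride that divides both the frame count and the frame resolution.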