Catchup #2

Merged 2 commits on Feb 3, 2023.
128 changes: 90 additions & 38 deletions challenges/point_tracking/dataset.py
@@ -11,13 +11,6 @@
from tensorflow_graphics.geometry.transformation import rotation_matrix_3d


-TOTAL_TRACKS = 256
-MAX_SAMPLED_FRAC = .1
-MAX_SEG_ID = 22
-INPUT_SIZE = (None, 256, 256)
-STRIDE = 4 # Make sure this divides all axes of INPUT_SIZE
-
-
def project_point(cam, point3d, num_frames):
"""Compute the image space coordinates [0, 1] for a set of points.

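The constants deleted above do not disappear: later hunks in this diff reintroduce each one as a per-call keyword argument. A small reference map of old name to new argument and default (values taken from this diff; note the max_seg_id default moves from 22 to 25, and INPUT_SIZE is now inferred per video):

# Mapping of deleted module constants to the new keyword arguments
# (defaults as introduced by this PR).
RENAMED_DEFAULTS = {
    'TOTAL_TRACKS': ('tracks_to_sample', 256),
    'MAX_SAMPLED_FRAC': ('max_sampled_frac', 0.1),
    'MAX_SEG_ID': ('max_seg_id', 25),  # default was 22
    'STRIDE': ('sampling_stride', 4),
    # INPUT_SIZE is now derived from object_coordinates.shape per video.
}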
@@ -188,6 +181,7 @@ def get_camera_matrices(
cam_positions,
cam_quaternions,
cam_sensor_width,
+input_size,
num_frames=None,
):
"""Tf function that converts camera positions into projection matrices."""
@@ -198,7 +192,7 @@
focal_length = tf.cast(cam_focal_length, tf.float32)
sensor_width = tf.cast(cam_sensor_width, tf.float32)
f_x = focal_length / sensor_width
-f_y = focal_length / sensor_width * INPUT_SIZE[1] / INPUT_SIZE[2]
+f_y = focal_length / sensor_width * input_size[0] / input_size[1]
p_x = 0.5
p_y = 0.5
intrinsics.append(
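For reference, a minimal plain-Python sketch of the normalized intrinsics this hunk parameterizes; make_intrinsics is a hypothetical helper, not part of the diff. Expressing f_x in units of image width and f_y in units of image height keeps projected coordinates in [0, 1] for non-square inputs.

def make_intrinsics(focal_length, sensor_width, input_size):
  """Normalized intrinsics [f_x, f_y, p_x, p_y]; input_size is [height, width]."""
  f_x = focal_length / sensor_width
  # f_y picks up the aspect ratio, exactly as in the hunk above.
  f_y = focal_length / sensor_width * input_size[0] / input_size[1]
  return [f_x, f_y, 0.5, 0.5]

# A 256x512 (height x width) input halves f_y relative to f_x.
assert make_intrinsics(32.0, 32.0, [256, 512]) == [1.0, 0.5, 0.5, 0.5]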
@@ -235,6 +229,7 @@ def single_object_reproject(
num_frames=None,
depth_map=None,
window=None,
+input_size=None,
):
"""Reproject points for a single object.

@@ -247,6 +242,7 @@
num_frames: Number of frames
depth_map: Depth map video for the camera
window: the window inside which we're sampling points
+input_size: [height, width] of the input images.

Returns:
Position for each point, of shape [num_points, num_frames, 2], in pixel
@@ -265,8 +261,8 @@
)

occluded = tf.less(reproj[:, :, 2], 0)
-reproj = reproj[:, :, 0:2] * np.array(INPUT_SIZE[2:0:-1])[np.newaxis,
-np.newaxis, :]
+reproj = reproj[:, :, 0:2] * np.array(input_size[::-1])[np.newaxis,
+np.newaxis, :]
occluded = tf.logical_or(
occluded,
tf.less(
@@ -285,19 +281,23 @@
return obj_reproj, obj_occ


-def get_num_to_sample(counts):
+def get_num_to_sample(counts, max_seg_id, max_sampled_frac, tracks_to_sample):
"""Computes the number of points to sample for each object.

Args:
counts: The number of points available per object. An int array of length
n, where n is the number of objects.
+max_seg_id: The maximum number of segment id's in the video.
+max_sampled_frac: The maximum fraction of points to sample from each
+object, out of all points that lie on the sampling grid.
+tracks_to_sample: Total number of tracks to sample per video.

Returns:
The number of points to sample for each object. An int array of length n.
"""
seg_order = tf.argsort(counts)
sorted_counts = tf.gather(counts, seg_order)
-initializer = (0, TOTAL_TRACKS, 0)
+initializer = (0, tracks_to_sample, 0)

def scan_fn(prev_output, count_seg):
index = prev_output[0]
@@ -308,7 +308,7 @@ def scan_fn(prev_output, count_seg):
tf.cast(desired_frac, tf.float32))
want_to_sample = tf.cast(tf.round(want_to_sample), tf.int32)
max_to_sample = (
-tf.cast(count_seg, tf.float32) * tf.cast(MAX_SAMPLED_FRAC, tf.float32))
+tf.cast(count_seg, tf.float32) * tf.cast(max_sampled_frac, tf.float32))
max_to_sample = tf.cast(tf.round(max_to_sample), tf.int32)
num_to_sample = tf.minimum(want_to_sample, max_to_sample)

@@ -323,7 +323,7 @@ def scan_fn(prev_output, count_seg):
num_to_sample = tf.concat(
[
num_to_sample,
-tf.zeros([MAX_SEG_ID - tf.shape(num_to_sample)[0]], dtype=tf.int32),
+tf.zeros([max_seg_id - tf.shape(num_to_sample)[0]], dtype=tf.int32),
],
axis=0,
)
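An illustrative plain-Python restatement of the allocation rule these hunks parameterize. The real code runs tf.scan over counts sorted ascending and pads the result to max_seg_id; allocate_tracks below is hypothetical and only mirrors the logic visible in this diff.

def allocate_tracks(counts, max_seg_id, max_sampled_frac, tracks_to_sample):
  # Visit objects from smallest to largest point count. Each object gets the
  # smaller of (a) an even share of the remaining track budget and (b)
  # max_sampled_frac of its own points; unused budget rolls over to the
  # larger objects visited later.
  order = sorted(range(len(counts)), key=lambda i: counts[i])
  budget = tracks_to_sample
  out = [0] * max_seg_id
  for idx, obj in enumerate(order):
    want = round(budget / (len(counts) - idx))
    cap = round(counts[obj] * max_sampled_frac)
    out[obj] = min(want, cap)
    budget -= out[obj]
  return out

# The small object hits its 10% cap; its leftover share goes to the others.
print(allocate_tracks([40, 4000, 10000], 5, 0.1, 256))  # [4, 126, 126, 0, 0]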
@@ -344,7 +344,10 @@ def track_points(
cam_quaternions,
cam_sensor_width,
window,
-num_frames=None,
+tracks_to_sample=256,
+sampling_stride=4,
+max_seg_id=25,
+max_sampled_frac=0.1,
):
"""Track points in 2D using Kubric data.

@@ -364,7 +367,12 @@
window: the window inside which we're sampling points. Integer valued
in the format [x_min, y_min, x_max, y_max], where min is inclusive and
max is exclusive.
-num_frames: number of frames in the video
+tracks_to_sample: Total number of tracks to sample per video.
+sampling_stride: For efficiency, query points are sampled from a random grid
+of this stride.
+max_seg_id: The maximum segment id in the video.
+max_sampled_frac: The maximum fraction of points to sample from each
+object, out of all points that lie on the sampling grid.

Returns:
A set of queries, randomly sampled from the video (with a bias toward
@@ -388,22 +396,25 @@
depth_f32 = tf.cast(depth, tf.float32)
depth_map = depth_min + depth_f32 * (depth_max-depth_min) / 65535

+input_size = object_coordinates.shape.as_list()[1:3]
+num_frames = object_coordinates.shape.as_list()[0]
+
# We first sample query points within the given window. That means first
# extracting the window from the segmentation tensor, because we want to have
# a bias toward moving objects.
# Note: for speed we sample points on a grid. The grid start position is
# randomized within the window.
start_vec = [
-tf.random.uniform([], minval=0, maxval=STRIDE, dtype=tf.int32)
-for _ in range(len(INPUT_SIZE))
+tf.random.uniform([], minval=0, maxval=sampling_stride, dtype=tf.int32)
+for _ in range(3)
]
start_vec[1] += window[0]
start_vec[2] += window[1]
end_vec = [num_frames, window[2], window[3]]

def extract_box(x):
-x = x[start_vec[0]::STRIDE, start_vec[1]:window[2]:STRIDE,
-start_vec[2]:window[3]:STRIDE]
+x = x[start_vec[0]::sampling_stride, start_vec[1]:window[2]:sampling_stride,
+start_vec[2]:window[3]:sampling_stride]
return x

segmentations_box = extract_box(segmentations)
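A NumPy sketch of the strided grid extraction above, with illustrative shapes. Note the window components are used as [y_min, x_min, y_max, x_max] by extract_box, and the random start offset keeps the grid from landing on the same pixels in every video.

import numpy as np

sampling_stride = 4
video = np.zeros([24, 256, 256], dtype=np.uint8)  # [frames, height, width]
window = [32, 64, 224, 192]                       # [y_min, x_min, y_max, x_max]
start = np.random.randint(0, sampling_stride, size=3)
box = video[start[0]::sampling_stride,
            window[0] + start[1]:window[2]:sampling_stride,
            window[1] + start[2]:window[3]:sampling_stride]
print(box.shape)  # (6, 48, 32) for any start offset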
@@ -413,13 +424,19 @@ def extract_box(x):
# how many points are available for each object.

cnt = tf.math.bincount(tf.cast(tf.reshape(segmentations_box, [-1]), tf.int32))
-num_to_sample = get_num_to_sample(cnt)
-num_to_sample.set_shape([MAX_SEG_ID])
+num_to_sample = get_num_to_sample(
+cnt,
+max_seg_id,
+max_sampled_frac,
+tracks_to_sample,
+)
+num_to_sample.set_shape([max_seg_id])
intrinsics, matrix_world = get_camera_matrices(
cam_focal_length,
cam_positions,
cam_quaternions,
cam_sensor_width,
+input_size,
num_frames=num_frames,
)

@@ -431,11 +448,14 @@ def get_camera(fr=None):
# Construct pixel coordinates for each pixel within the window.
window = tf.cast(window, tf.float32)
z, y, x = tf.meshgrid(
-*[tf.range(st, ed, STRIDE) for st, ed in zip(start_vec, end_vec)],
+*[
+tf.range(st, ed, sampling_stride)
+for st, ed in zip(start_vec, end_vec)
+],
indexing='ij')
pix_coords = tf.reshape(tf.stack([z, y, x], axis=-1), [-1, 3])

-for i in range(MAX_SEG_ID):
+for i in range(max_seg_id):
# sample points on object i in the first frame. obj_id is the position
# within the object_coordinates array, which is one lower than the value
# in the segmentation mask (0 in the segmentation mask is the background
@@ -492,6 +512,7 @@ def get_camera(fr=None):
num_frames=num_frames,
depth_map=depth_map,
window=window,
+input_size=input_size,
),
lambda: # pylint: disable=g-long-lambda
(tf.zeros([0, num_frames, 2], dtype=tf.float32),
@@ -508,7 +529,7 @@ def get_camera(fr=None):
np.array([num_frames]), window[2:4]],
axis=0)
wd = wd[tf.newaxis, tf.newaxis, :]
-coord_multiplier = [num_frames, INPUT_SIZE[1], INPUT_SIZE[2]]
+coord_multiplier = [num_frames, input_size[0], input_size[1]]
all_reproj = tf.concat(all_reproj, axis=0)
# We need to extract x,y, but the format of the window is [t1,y1,x1,t2,y2,x2]
window_size = wd[:, :, 5:3:-1] - wd[:, :, 2:0:-1]
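The reversed slices are easy to misread; a quick NumPy check with illustrative values shows what they select out of [t1, y1, x1, t2, y2, x2]:

import numpy as np

# wd[..., 5:3:-1] selects [x2, y2] and wd[..., 2:0:-1] selects [x1, y1],
# so the difference is the window size in (width, height) order.
wd = np.array([0, 32, 64, 24, 224, 192])[np.newaxis, np.newaxis, :]
print(wd[:, :, 5:3:-1] - wd[:, :, 2:0:-1])  # [[[128 192]]]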
@@ -558,17 +579,28 @@ def _get_distorted_bounding_box(


def add_tracks(data,
-train_size=(200, 200),
+train_size=(256, 256),
vflip=False,
-random_crop=True):
+random_crop=True,
+tracks_to_sample=256,
+sampling_stride=4,
+max_seg_id=25,
+max_sampled_frac=0.1):
"""Track points in 2D using Kubric data.

Args:
-data: kubric data, including RGB/depth/object coordinate/segmentation
+data: Kubric data, including RGB/depth/object coordinate/segmentation
videos and camera parameters.
-train_size: cropped output will be at this resolution
+train_size: Cropped output will be at this resolution. Ignored if
+random_crop is False.
vflip: whether to vertically flip images and tracks (to test generalization)
-random_crop: whether to randomly crop videos
+random_crop: Whether to randomly crop videos
+tracks_to_sample: Total number of tracks to sample per video.
+sampling_stride: For efficiency, query points are sampled from a random grid
+of this stride.
+max_seg_id: The maximum segment id in the video.
+max_sampled_frac: The maximum fraction of points to sample from each
+object, out of all points that lie on the sampling grid.

Returns:
A dict with the following keys:
@@ -589,6 +621,8 @@
"""
shp = data['video'].shape.as_list()
num_frames = shp[0]
+if any([s % sampling_stride != 0 for s in shp[:-1]]):
+raise ValueError('All video dims must be a multiple of sampling_stride.')

bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
min_area = 0.3
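A quick illustration of the new sampling_stride guard, with hypothetical shapes; every axis except channels must be tileable by the stride:

sampling_stride = 4
for shp in [(24, 256, 256, 3), (24, 250, 256, 3)]:
  ok = not any(s % sampling_stride != 0 for s in shp[:-1])
  print(shp, 'ok' if ok else 'would raise ValueError')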
@@ -604,7 +638,7 @@
area_range=(min_area, max_area),
max_attempts=20)
else:
-crop_window = tf.constant([0, 0, INPUT_SIZE[1], INPUT_SIZE[2]],
+crop_window = tf.constant([0, 0, shp[1], shp[2]],
dtype=tf.int32,
shape=[4])

@@ -613,14 +647,14 @@
data['metadata']['depth_range'], data['segmentations'],
data['instances']['bboxes_3d'], data['camera']['focal_length'],
data['camera']['positions'], data['camera']['quaternions'],
-data['camera']['sensor_width'], crop_window, num_frames)
+data['camera']['sensor_width'], crop_window, tracks_to_sample,
+sampling_stride, max_seg_id, max_sampled_frac)
video = data['video']

shp = video.shape.as_list()
num_frames = shp[0]
-query_points.set_shape([TOTAL_TRACKS, 3])
-target_points.set_shape([TOTAL_TRACKS, num_frames, 2])
-occluded.set_shape([TOTAL_TRACKS, num_frames])
+query_points.set_shape([tracks_to_sample, 3])
+target_points.set_shape([tracks_to_sample, num_frames, 2])
+occluded.set_shape([tracks_to_sample, num_frames])

# Crop the video to the sampled window, in a way which matches the coordinate
# frame produced by the track_points function.
@@ -654,6 +688,11 @@ def create_point_tracking_dataset(
repeat=True,
vflip=False,
random_crop=True,
+tracks_to_sample=256,
+sampling_stride=4,
+max_seg_id=25,
+max_sampled_frac=0.1,
+num_parallel_point_extraction_calls=16,
**kwargs):
"""Construct a dataset for point tracking using Kubric: go/kubric.

@@ -667,6 +706,15 @@
repeat: Bool. whether to repeat the dataset.
vflip: Bool. whether to vertically flip the dataset to test generalization.
random_crop: Bool. whether to randomly crop videos
+tracks_to_sample: Int. Total number of tracks to sample per video.
+sampling_stride: Int. For efficiency, query points are sampled from a
+random grid of this stride.
+max_seg_id: Int. The maximum segment id in the video. Note the size of
+the graph is proportional to this number, so prefer small values.
+max_sampled_frac: Float. The maximum fraction of points to sample from each
+object, out of all points that lie on the sampling grid.
+num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
+map function for point extraction.
**kwargs: additional args to pass to tfds.load.

Returns:
@@ -686,8 +734,12 @@
add_tracks,
train_size=train_size,
vflip=vflip,
-random_crop=random_crop),
-num_parallel_calls=2)
+random_crop=random_crop,
+tracks_to_sample=tracks_to_sample,
+sampling_stride=sampling_stride,
+max_seg_id=max_seg_id,
+max_sampled_frac=max_sampled_frac),
+num_parallel_calls=num_parallel_point_extraction_calls)
if shuffle_buffer_size is not None:
ds = ds.shuffle(shuffle_buffer_size)

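For context, a usage sketch of the new knobs end to end. Assumptions not shown in this diff: the import path mirrors the file path, a Kubric TFDS build is reachable by tfds.load, and dataset elements expose the tensors that add_tracks set_shapes under like-named keys.

# Usage sketch; see the assumptions in the paragraph above.
from challenges.point_tracking import dataset as point_tracking_dataset

ds = point_tracking_dataset.create_point_tracking_dataset(
    train_size=(256, 256),
    shuffle_buffer_size=None,
    repeat=False,
    random_crop=True,
    tracks_to_sample=512,               # was the TOTAL_TRACKS constant
    sampling_stride=4,                  # must divide every video axis
    max_seg_id=25,                      # keep small; graph size grows with it
    max_sampled_frac=0.1,
    num_parallel_point_extraction_calls=4)
for example in ds.take(1):
  print(example['query_points'].shape)   # (tracks_to_sample, 3)
  print(example['target_points'].shape)  # (tracks_to_sample, num_frames, 2)
  print(example['occluded'].shape)       # (tracks_to_sample, num_frames)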