Return R2N2 voxel coordinates
Summary: Return R2N2's voxel coordinates.

Reviewed By: nikhilaravi

Differential Revision: D22462530

fbshipit-source-id: a995cfa0957b2561eb3b0f4591cb1db42170bc68
megluyagao authored and facebook-github-bot committed Aug 7, 2020
1 parent 326e4cc commit 63ba74f
Showing 7 changed files with 602 additions and 135 deletions.
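
The user-facing change is the new return_voxels flag on the R2N2 dataset. A minimal usage sketch, not part of this commit, assuming the existing constructor signature (split, shapenet_dir, r2n2_dir, splits_file) and hypothetical local paths; the stacked voxel shape is inferred from the __getitem__ changes below:

    from pytorch3d.datasets import R2N2

    SHAPENET_PATH = "/path/to/ShapeNetCore.v1"  # assumed local paths
    R2N2_PATH = "/path/to/r2n2"  # must contain a ShapeNetVox32 folder
    SPLITS_PATH = "/path/to/splits.json"

    dataset = R2N2("val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True)
    model = dataset[0]
    # One (D, D, D) voxel grid per returned view, stacked along dim 0.
    print(model["voxels"].shape)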
2 changes: 1 addition & 1 deletion pytorch3d/datasets/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from .r2n2 import R2N2, BlenderCamera
+from .r2n2 import R2N2, BlenderCamera, collate_batched_R2N2, render_cubified_voxels
 from .shapenet import ShapeNetCore
 from .utils import collate_batched_meshes

3 changes: 2 additions & 1 deletion pytorch3d/datasets/r2n2/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from .r2n2 import R2N2, BlenderCamera
+from .r2n2 import R2N2
+from .utils import BlenderCamera, collate_batched_R2N2, render_cubified_voxels
 
 
 __all__ = [k for k in globals().keys() if not k.startswith("_")]
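
Both __init__.py files now re-export collate_batched_R2N2 and render_cubified_voxels. A sketch of how the new collate function would plug into a standard DataLoader (hypothetical batch size; collate_batched_R2N2 is assumed to behave like the existing collate_batched_meshes but with R2N2's extra fields):

    from torch.utils.data import DataLoader
    from pytorch3d.datasets import R2N2, collate_batched_R2N2

    dataset = R2N2(
        "val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
    )  # paths as in the sketch above
    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_batched_R2N2)
    # Each batch is a dict with batched images, R, T, K and voxels entries.
    batch = next(iter(loader))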
126 changes: 75 additions & 51 deletions pytorch3d/datasets/r2n2/r2n2.py
@@ -10,20 +10,32 @@
 import torch
 from PIL import Image
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
-from pytorch3d.datasets.utils import compute_extrinsic_matrix
 from pytorch3d.io import load_obj
 from pytorch3d.renderer import HardPhongShader
-from pytorch3d.renderer.cameras import CamerasBase
-from pytorch3d.transforms import Transform3d
 from tabulate import tabulate
 
+from .utils import (
+    BlenderCamera,
+    align_bbox,
+    compute_extrinsic_matrix,
+    read_binvox_coords,
+    voxelize,
+)
 
-SYNSET_DICT_DIR = Path(__file__).resolve().parent
-
-# Default values of rotation, translation and intrinsic matrices for BlenderCamera.
-r = np.expand_dims(np.eye(3), axis=0)  # (1, 3, 3)
-t = np.expand_dims(np.zeros(3), axis=0)  # (1, 3)
-k = np.expand_dims(np.eye(4), axis=0)  # (1, 4, 4)
+SYNSET_DICT_DIR = Path(__file__).resolve().parent
+MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
+VOXEL_SIZE = 128
+# Intrinsic matrix extracted from Blender. Taken from meshrcnn codebase:
+# https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py
+BLENDER_INTRINSIC = torch.tensor(
+    [
+        [2.1875, 0.0, 0.0, 0.0],
+        [0.0, 2.1875, 0.0, 0.0],
+        [0.0, 0.0, -1.002002, -0.2002002],
+        [0.0, 0.0, -1.0, 0.0],
+    ]
+)
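
BLENDER_INTRINSIC is composed with each view's extrinsic matrix further down in __getitem__ (P = BLENDER_INTRINSIC.mm(RT)). A small sketch of that composition with a placeholder extrinsic, assuming the constants above are in scope; the identity RT stands in for a real compute_extrinsic_matrix output:

    import torch

    RT = torch.eye(4)  # placeholder (4, 4) world-to-camera extrinsic
    P = BLENDER_INTRINSIC.mm(RT)  # (4, 4) projection used for voxelization

    # Project one homogeneous world point and apply the perspective divide.
    point_h = torch.tensor([0.1, 0.2, 0.5, 1.0])
    projected = P @ point_h
    ndc_xy = projected[:2] / projected[3]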


class R2N2(ShapeNetBase):
@@ -42,6 +54,7 @@ def __init__(
         r2n2_dir,
         splits_file,
         return_all_views: bool = True,
+        return_voxels: bool = False,
     ):
"""
Store each object's synset id and models id the given directories.
@@ -54,6 +67,8 @@
             return_all_views (bool): Indicator of whether or not to load all the views in
                 the split. If set to False, one of the views in the split will be randomly
                 selected and loaded.
+            return_voxels(bool): Indicator of whether or not to return voxels as a tensor
+                of shape (D, D, D) where D is the number of voxels along each dimension.
         """
         super().__init__()
         self.shapenet_dir = shapenet_dir
@@ -83,6 +98,16 @@ def __init__(
             ) % (r2n2_dir)
             warnings.warn(msg)
 
+        self.return_voxels = return_voxels
+        # Check if the folder containing voxel coordinates is included in r2n2_dir.
+        if not path.isdir(path.join(r2n2_dir, "ShapeNetVox32")):
+            self.return_voxels = False
+            msg = (
+                "ShapeNetVox32 not found in %s. Voxel coordinates will "
+                "be skipped when returning models."
+            ) % (r2n2_dir)
+            warnings.warn(msg)
+
         synset_set = set()
         # Store lists of views of each model in a list.
         self.views_per_model_list = []
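
The constructor only disables voxels wholesale when ShapeNetVox32 is missing; individual models are validated later in __getitem__. A hypothetical standalone check (has_r2n2_voxels is not part of this commit) mirroring the per-model path built there:

    from os import path

    def has_r2n2_voxels(r2n2_dir: str, synset_id: str, model_id: str) -> bool:
        # Mirrors the path join used by __getitem__ when loading voxels.
        voxel_path = path.join(
            r2n2_dir, "ShapeNetVox32", synset_id, model_id, "model.binvox"
        )
        return path.isfile(voxel_path)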
@@ -173,6 +198,8 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
             - R: Rotation matrix of shape (V, 3, 3), where V is number of views returned.
             - T: Translation matrix of shape (V, 3), where V is number of views returned.
             - K: Intrinsic matrix of shape (V, 4, 4), where V is number of views returned.
+            - voxels: Voxels of shape (D, D, D), where D is the number of voxels along each
+                dimension.
         """
         if isinstance(model_idx, tuple):
             model_idx, view_idxs = model_idx
@@ -208,6 +235,7 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
         model["label"] = self.synset_dict[model["synset_id"]]
 
         model["images"] = None
+        images, Rs, Ts, voxel_RTs = [], [], [], []
         # Retrieve R2N2's renderings if required.
         if self.return_images:
             rendering_path = path.join(
@@ -217,12 +245,9 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
                 model["model_id"],
                 "rendering",
             )
-
             # Read metadata file to obtain params for calibration matrices.
             with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f:
                 metadata_lines = f.readlines()
-
-            images, Rs, Ts = [], [], []
             for i in model_views:
                 # Read image.
                 image_path = path.join(rendering_path, "%02d.png" % i)
@@ -234,9 +259,13 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
                 azim, elev, yaw, dist_ratio, fov = [
                     float(v) for v in metadata_lines[i].strip().split(" ")
                 ]
-                R, T = self._compute_camera_calibration(azim, elev, dist_ratio)
+                dist = dist_ratio * MAX_CAMERA_DISTANCE
+                # Extrinsic matrix before transformation to PyTorch3D world space.
+                RT = compute_extrinsic_matrix(azim, elev, dist)
+                R, T = self._compute_camera_calibration(RT)
                 Rs.append(R)
                 Ts.append(T)
+                voxel_RTs.append(RT)
 
             # Intrinsic matrix extracted from the Blender with slight modification to work with
             # PyTorch3D world space. Taken from meshrcnn codebase:
@@ -254,27 +283,48 @@ def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
             model["T"] = torch.stack(Ts)
             model["K"] = K.expand(len(model_views), 4, 4)
 
+        voxels_list = []
+        # Read voxels if required.
+        voxel_path = path.join(
+            self.r2n2_dir,
+            "ShapeNetVox32",
+            model["synset_id"],
+            model["model_id"],
+            "model.binvox",
+        )
+        if self.return_voxels:
+            if not path.isfile(voxel_path):
+                msg = "Voxel file not found for model %s from category %s."
+                raise FileNotFoundError(msg % (model["model_id"], model["synset_id"]))
+
+            with open(voxel_path, "rb") as f:
+                # Read voxel coordinates as a tensor of shape (N, 3).
+                voxel_coords = read_binvox_coords(f)
+            # Align voxels to the same coordinate system as mesh verts.
+            voxel_coords = align_bbox(voxel_coords, model["verts"])
+            for RT in voxel_RTs:
+                # Compute projection matrix.
+                P = BLENDER_INTRINSIC.mm(RT)
+                # Convert voxel coordinates of shape (N, 3) to voxels of shape (D, D, D).
+                voxels = voxelize(voxel_coords, P, VOXEL_SIZE)
+                voxels_list.append(voxels)
+            model["voxels"] = torch.stack(voxels_list)
+
         return model

-    def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float):
+    def _compute_camera_calibration(self, RT):
         """
-        Helper function for calculating rotation and translation matrices from azimuth
-        angle, elevation and distance ratio.
+        Helper function for calculating rotation and translation matrices from ShapeNet
+        to camera transformation and ShapeNet to PyTorch3D transformation.
         Args:
-            azim: Rotation about the z-axis, in degrees.
-            elev: Rotation above the xy-plane, in degrees.
-            dist_ratio: Ratio of distance from the origin to the maximum camera distance.
+            RT: Extrinsic matrix that performs ShapeNet world view to camera view
+                transformation.
         Returns:
-            - R: Rotation matrix of shape (3, 3).
-            - T: Translation matrix of shape (3).
+            R: Rotation matrix of shape (3, 3).
+            T: Translation matrix of shape (3).
         """
-        # Retrive R,T,K of the selected view(s) by reading the metadata.
-        MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
-        dist = dist_ratio * MAX_CAMERA_DISTANCE
-        RT = compute_extrinsic_matrix(azim, elev, dist)
-
         # Transform the mesh vertices from shapenet world to pytorch3d world.
         shapenet_to_pytorch3d = torch.tensor(
             [
@@ -285,9 +335,7 @@ def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float):
             ],
             dtype=torch.float32,
         )
-        RT = compute_extrinsic_matrix(azim, elev, dist)  # (4, 4)
         RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d)  # (4, 4)
-
         # Extract rotation and translation matrices from RT.
         R = RT[:3, :3]
         T = RT[3, :3]
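
The R and T extracted here are already in PyTorch3D conventions, so they can drive the BlenderCamera that this commit relocates to utils.py (see the removed class below). A hedged sketch, assuming a model dict returned by __getitem__ above:

    from pytorch3d.datasets import BlenderCamera

    # R, T, K are batched per view, shapes (V, 3, 3), (V, 3) and (V, 4, 4).
    cameras = BlenderCamera(R=model["R"], T=model["T"], K=model["K"])
    # BlenderCamera applies K directly as its projection transform.
    proj = cameras.get_projection_transform()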
@@ -348,27 +396,3 @@ def render(
         return super().render(
             idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs
         )
-
-
-class BlenderCamera(CamerasBase):
-    """
-    Camera for rendering objects with calibration matrices from the R2N2 dataset
-    (which uses Blender for rendering the views for each model).
-    """
-
-    def __init__(self, R=r, T=t, K=k, device="cpu"):
-        """
-        Args:
-            R: Rotation matrix of shape (N, 3, 3).
-            T: Translation matrix of shape (N, 3).
-            K: Intrinsic matrix of shape (N, 4, 4).
-            device: torch.device or str.
-        """
-        # The initializer formats all inputs to torch tensors and broadcasts
-        # all the inputs to have the same batch dimension where necessary.
-        super().__init__(device=device, R=R, T=T, K=K)
-
-    def get_projection_transform(self, **kwargs) -> Transform3d:
-        transform = Transform3d(device=self.device)
-        transform._matrix = self.K.transpose(1, 2).contiguous()  # pyre-ignore[16]
-        return transform