-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New code and pretrained model updated
- Loading branch information
Showing
34 changed files
with
2,410 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .dataloaders import ModelNet40Data | ||
from .dataloaders import ClassificationData, RegistrationData, SegmentationData, FlowData, SceneflowDataset | ||
from .dataloaders import download_modelnet40, deg_to_rad, create_random_transform |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.utils.data import Dataset | ||
from torch.utils.data import DataLoader | ||
import numpy as np | ||
import os | ||
import h5py | ||
import subprocess | ||
import shlex | ||
import json | ||
import glob | ||
from .. ops import transform_functions, se3 | ||
|
||
def download_modelnet40(): | ||
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | ||
DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data') | ||
if not os.path.exists(DATA_DIR): | ||
os.mkdir(DATA_DIR) | ||
if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')): | ||
www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' | ||
zipfile = os.path.basename(www) | ||
os.system('wget %s; unzip %s' % (www, zipfile)) | ||
os.system('mv %s %s' % (zipfile[:-4], DATA_DIR)) | ||
os.system('rm %s' % (zipfile)) | ||
|
||
def load_data(train): | ||
if train: partition = 'train' | ||
else: partition = 'test' | ||
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | ||
DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data') | ||
all_data = [] | ||
all_label = [] | ||
for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)): | ||
f = h5py.File(h5_name) | ||
data = f['data'][:].astype('float32') | ||
label = f['label'][:].astype('int64') | ||
f.close() | ||
all_data.append(data) | ||
all_label.append(label) | ||
all_data = np.concatenate(all_data, axis=0) | ||
all_label = np.concatenate(all_label, axis=0) | ||
return all_data, all_label | ||
|
||
def deg_to_rad(deg): | ||
return np.pi / 180 * deg | ||
|
||
def create_random_transform(dtype, max_rotation_deg, max_translation): | ||
max_rotation = deg_to_rad(max_rotation_deg) | ||
rot = np.random.uniform(-max_rotation, max_rotation, [1, 3]) | ||
trans = np.random.uniform(-max_translation, max_translation, [1, 3]) | ||
quat = transform_functions.euler_to_quaternion(rot, "xyz") | ||
|
||
vec = np.concatenate([quat, trans], axis=1) | ||
vec = torch.tensor(vec, dtype=dtype) | ||
return vec | ||
|
||
class ModelNet40Data(Dataset): | ||
def __init__( | ||
self, | ||
train=True, | ||
num_points=1024, | ||
download=True, | ||
randomize_data=False | ||
): | ||
super(ModelNet40Data, self).__init__() | ||
if download: download_modelnet40() | ||
self.data, self.labels = load_data(train) | ||
if not train: self.shapes = self.read_classes_ModelNet40() | ||
self.num_points = num_points | ||
self.randomize_data = randomize_data | ||
|
||
def __getitem__(self, idx): | ||
if self.randomize_data: current_points = self.randomize(idx) | ||
else: current_points = self.data[idx].copy() | ||
|
||
current_points = torch.from_numpy(current_points).float() | ||
label = torch.from_numpy(self.labels[idx]).type(torch.LongTensor) | ||
|
||
return current_points, label | ||
|
||
def __len__(self): | ||
return self.data.shape[0] | ||
|
||
def randomize(self, idx): | ||
pt_idxs = np.arange(0, self.num_points) | ||
np.random.shuffle(pt_idxs) | ||
return self.data[idx, pt_idxs].copy() | ||
|
||
def get_shape(self, label): | ||
return self.shapes[label] | ||
|
||
def read_classes_ModelNet40(self): | ||
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | ||
DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data') | ||
file = open(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'shape_names.txt'), 'r') | ||
shape_names = file.read() | ||
shape_names = np.array(shape_names.split('\n')[:-1]) | ||
return shape_names | ||
|
||
class RegistrationData(Dataset): | ||
def __init__(self, algorithm, data_class=ModelNet40Data()): | ||
super(RegistrationData, self).__init__() | ||
self.algorithm = 'iPCRNet' | ||
|
||
self.set_class(data_class) | ||
if self.algorithm == 'PCRNet' or self.algorithm == 'iPCRNet': | ||
from .. ops.transform_functions import PCRNetTransform | ||
self.transforms = PCRNetTransform(len(data_class), angle_range=45, translation_range=1) | ||
|
||
def __len__(self): | ||
return len(self.data_class) | ||
|
||
def set_class(self, data_class): | ||
self.data_class = data_class | ||
|
||
def __getitem__(self, index): | ||
template, label = self.data_class[index] | ||
self.transforms.index = index # for fixed transformations in PCRNet. | ||
source = self.transforms(template) | ||
igt = self.transforms.igt | ||
return template, source, igt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
try: | ||
from .emd import EMDLoss | ||
except: | ||
print("Sorry EMD loss is not compatible with your system!") | ||
try: | ||
from .chamfer_distance import ChamferDistanceLoss | ||
except: | ||
print("Sorry ChamferDistance loss is not compatible with your system!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
def chamfer_distance(template: torch.Tensor, source: torch.Tensor): | ||
from .cuda.chamfer_distance import ChamferDistance | ||
cost_p0_p1, cost_p1_p0 = ChamferDistance()(template, source) | ||
cost_p0_p1 = torch.mean(torch.sqrt(cost_p0_p1)) | ||
cost_p1_p0 = torch.mean(torch.sqrt(cost_p1_p0)) | ||
chamfer_loss = (cost_p0_p1 + cost_p1_p0)/2.0 | ||
return chamfer_loss | ||
|
||
|
||
class ChamferDistanceLoss(nn.Module): | ||
def __init__(self): | ||
super(ChamferDistanceLoss, self).__init__() | ||
|
||
def forward(self, template, source): | ||
return chamfer_distance(template, source) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .chamfer_distance import ChamferDistance |
185 changes: 185 additions & 0 deletions
185
pcrnet/losses/cuda/chamfer_distance/chamfer_distance.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
#include <torch/torch.h> | ||
|
||
// CUDA forward declarations | ||
int ChamferDistanceKernelLauncher( | ||
const int b, const int n, | ||
const float* xyz, | ||
const int m, | ||
const float* xyz2, | ||
float* result, | ||
int* result_i, | ||
float* result2, | ||
int* result2_i); | ||
|
||
int ChamferDistanceGradKernelLauncher( | ||
const int b, const int n, | ||
const float* xyz1, | ||
const int m, | ||
const float* xyz2, | ||
const float* grad_dist1, | ||
const int* idx1, | ||
const float* grad_dist2, | ||
const int* idx2, | ||
float* grad_xyz1, | ||
float* grad_xyz2); | ||
|
||
|
||
void chamfer_distance_forward_cuda( | ||
const at::Tensor xyz1, | ||
const at::Tensor xyz2, | ||
const at::Tensor dist1, | ||
const at::Tensor dist2, | ||
const at::Tensor idx1, | ||
const at::Tensor idx2) | ||
{ | ||
ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(), | ||
xyz2.size(1), xyz2.data<float>(), | ||
dist1.data<float>(), idx1.data<int>(), | ||
dist2.data<float>(), idx2.data<int>()); | ||
} | ||
|
||
void chamfer_distance_backward_cuda( | ||
const at::Tensor xyz1, | ||
const at::Tensor xyz2, | ||
at::Tensor gradxyz1, | ||
at::Tensor gradxyz2, | ||
at::Tensor graddist1, | ||
at::Tensor graddist2, | ||
at::Tensor idx1, | ||
at::Tensor idx2) | ||
{ | ||
ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(), | ||
xyz2.size(1), xyz2.data<float>(), | ||
graddist1.data<float>(), idx1.data<int>(), | ||
graddist2.data<float>(), idx2.data<int>(), | ||
gradxyz1.data<float>(), gradxyz2.data<float>()); | ||
} | ||
|
||
|
||
void nnsearch( | ||
const int b, const int n, const int m, | ||
const float* xyz1, | ||
const float* xyz2, | ||
float* dist, | ||
int* idx) | ||
{ | ||
for (int i = 0; i < b; i++) { | ||
for (int j = 0; j < n; j++) { | ||
const float x1 = xyz1[(i*n+j)*3+0]; | ||
const float y1 = xyz1[(i*n+j)*3+1]; | ||
const float z1 = xyz1[(i*n+j)*3+2]; | ||
double best = 0; | ||
int besti = 0; | ||
for (int k = 0; k < m; k++) { | ||
const float x2 = xyz2[(i*m+k)*3+0] - x1; | ||
const float y2 = xyz2[(i*m+k)*3+1] - y1; | ||
const float z2 = xyz2[(i*m+k)*3+2] - z1; | ||
const double d=x2*x2+y2*y2+z2*z2; | ||
if (k==0 || d < best){ | ||
best = d; | ||
besti = k; | ||
} | ||
} | ||
dist[i*n+j] = best; | ||
idx[i*n+j] = besti; | ||
} | ||
} | ||
} | ||
|
||
|
||
void chamfer_distance_forward( | ||
const at::Tensor xyz1, | ||
const at::Tensor xyz2, | ||
const at::Tensor dist1, | ||
const at::Tensor dist2, | ||
const at::Tensor idx1, | ||
const at::Tensor idx2) | ||
{ | ||
const int batchsize = xyz1.size(0); | ||
const int n = xyz1.size(1); | ||
const int m = xyz2.size(1); | ||
|
||
const float* xyz1_data = xyz1.data<float>(); | ||
const float* xyz2_data = xyz2.data<float>(); | ||
float* dist1_data = dist1.data<float>(); | ||
float* dist2_data = dist2.data<float>(); | ||
int* idx1_data = idx1.data<int>(); | ||
int* idx2_data = idx2.data<int>(); | ||
|
||
nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); | ||
nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); | ||
} | ||
|
||
|
||
void chamfer_distance_backward( | ||
const at::Tensor xyz1, | ||
const at::Tensor xyz2, | ||
at::Tensor gradxyz1, | ||
at::Tensor gradxyz2, | ||
at::Tensor graddist1, | ||
at::Tensor graddist2, | ||
at::Tensor idx1, | ||
at::Tensor idx2) | ||
{ | ||
const int b = xyz1.size(0); | ||
const int n = xyz1.size(1); | ||
const int m = xyz2.size(1); | ||
|
||
const float* xyz1_data = xyz1.data<float>(); | ||
const float* xyz2_data = xyz2.data<float>(); | ||
float* gradxyz1_data = gradxyz1.data<float>(); | ||
float* gradxyz2_data = gradxyz2.data<float>(); | ||
float* graddist1_data = graddist1.data<float>(); | ||
float* graddist2_data = graddist2.data<float>(); | ||
const int* idx1_data = idx1.data<int>(); | ||
const int* idx2_data = idx2.data<int>(); | ||
|
||
for (int i = 0; i < b*n*3; i++) | ||
gradxyz1_data[i] = 0; | ||
for (int i = 0; i < b*m*3; i++) | ||
gradxyz2_data[i] = 0; | ||
for (int i = 0;i < b; i++) { | ||
for (int j = 0; j < n; j++) { | ||
const float x1 = xyz1_data[(i*n+j)*3+0]; | ||
const float y1 = xyz1_data[(i*n+j)*3+1]; | ||
const float z1 = xyz1_data[(i*n+j)*3+2]; | ||
const int j2 = idx1_data[i*n+j]; | ||
|
||
const float x2 = xyz2_data[(i*m+j2)*3+0]; | ||
const float y2 = xyz2_data[(i*m+j2)*3+1]; | ||
const float z2 = xyz2_data[(i*m+j2)*3+2]; | ||
const float g = graddist1_data[i*n+j]*2; | ||
|
||
gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2); | ||
gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2); | ||
gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2); | ||
gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2)); | ||
gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2)); | ||
gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2)); | ||
} | ||
for (int j = 0; j < m; j++) { | ||
const float x1 = xyz2_data[(i*m+j)*3+0]; | ||
const float y1 = xyz2_data[(i*m+j)*3+1]; | ||
const float z1 = xyz2_data[(i*m+j)*3+2]; | ||
const int j2 = idx2_data[i*m+j]; | ||
const float x2 = xyz1_data[(i*n+j2)*3+0]; | ||
const float y2 = xyz1_data[(i*n+j2)*3+1]; | ||
const float z2 = xyz1_data[(i*n+j2)*3+2]; | ||
const float g = graddist2_data[i*m+j]*2; | ||
gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2); | ||
gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2); | ||
gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2); | ||
gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2)); | ||
gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2)); | ||
gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2)); | ||
} | ||
} | ||
} | ||
|
||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||
m.def("forward", &chamfer_distance_forward, "ChamferDistance forward"); | ||
m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)"); | ||
m.def("backward", &chamfer_distance_backward, "ChamferDistance backward"); | ||
m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)"); | ||
} |
Oops, something went wrong.