decoder.py
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F


class SdfDecoder(nn.Module):
    """MLP SDF decoder with skip connections and geometric initialization."""

    def __init__(self,
                 d_in,
                 d_out,
                 d_hidden,
                 n_layers,
                 skip_in=(),
                 bias=0.5,
                 scale=1,
                 geometric_init=True,
                 weight_norm=True,
                 inside_outside=False):
        super(SdfDecoder, self).__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.d_out = d_out

        for l in range(0, self.num_layers - 1):
            # Layers feeding a skip connection output fewer channels so that
            # concatenating the raw input keeps the hidden width constant.
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                # Geometric initialization: the network starts out approximating
                # the SDF of a sphere of radius `bias`.
                if l == self.num_layers - 2:
                    if not inside_outside:
                        torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, -bias)
                    else:
                        torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                        torch.nn.init.constant_(lin.bias, bias)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)

    def forward(self, inputs):
        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))
            if l in self.skip_in:
                # Re-inject the raw input and rescale to keep the variance roughly constant.
                x = torch.cat([x, inputs], 1) / np.sqrt(2)
            x = lin(x)
            if l < self.num_layers - 2:
                x = self.activation(x)
        if self.d_out == 3:
            # When the decoder predicts RGB instead of SDF, squash to [0, 1].
            x = torch.sigmoid(x)
        return x
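
# Usage sketch (illustrative, not taken from this repository): a stand-alone
# decoder over raw xyz coordinates could be built and queried roughly like
#
#   decoder = SdfDecoder(d_in=3, d_out=1, d_hidden=256, n_layers=8, skip_in=(4,))
#   sdf = decoder(torch.rand(1024, 3) * 2 - 1)   # -> (1024, 1) signed distances
#
# The hyper-parameters here (256 hidden units, 8 layers, skip at layer 4) are
# assumed values for illustration; SdfModel below reads its own from a config.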


class SdfModel(nn.Module):
    def __init__(self, config_json):
        super().__init__()
        self.model = SdfDecoder(d_in=config_json['channels'],
                                d_out=1,
                                d_hidden=config_json['width'],
                                n_layers=config_json['n_layers'],
                                skip_in=config_json['skip_in'],
                                ).cuda()
        self.model.load_state_dict(torch.load(config_json['ckpt_path'], map_location='cuda'))
        self.model.eval()

    def configure_optimizers(self):
        # Note: expects `self.specs` (a dict with an "sdf_lr" entry) to be set
        # by the surrounding training code; it is not defined in this module.
        optimizer = torch.optim.Adam(self.parameters(), self.specs["sdf_lr"])
        return optimizer

    def normalize_coordinate2(self, p, padding=0.1, plane='xz'):
        ''' Select the 2D coordinates of the query points for a given plane
        (despite the name, no rescaling to [0, 1] is performed here).
        Args:
            p (tensor): points, shape (B, N, 3)
            padding (float): conventional padding parameter of ONet for the unit cube,
                so [-0.5, 0.5] -> [-0.55, 0.55]; unused in this implementation
            plane (str): plane feature type, one of ['zx', 'yx', 'yz']
        '''
        if plane == 'zx':
            xy = p[:, :, [2, 0]]
        elif plane == 'yx':
            xy = p[:, :, [1, 0]]
        else:
            xy = p[:, :, [1, 2]]
        return xy

    def sample_plane_feature(self, query, plane_feature, plane, padding=0.1):
        # Project the query points onto the given plane and bilinearly sample the
        # feature map; coordinates are assumed to already lie in [-1, 1].
        xy = self.normalize_coordinate2(query.clone(), plane=plane, padding=padding)
        xy = xy[:, :, None].float()
        # vgrid = 2.0 * xy - 1.0  # normalize to (-1, 1)
        # vgrid = xy - 1.0
        vgrid = xy
        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True,
                                     mode='bilinear').squeeze(-1)
        return sampled_feat

    def forward_with_plane_features(self, plane_features, xyz):
        '''
        plane_features: B, 3, D, res, res (e.g. B, 3, 256, 64, 64), one feature map per plane
        xyz: B, N, 3
        '''
        point_features = self.get_points_plane_features(plane_features, xyz)  # point_features: B, N, D
        pred_sdf = self.model(point_features)
        return pred_sdf  # [B, num_points]

    def get_points_plane_features(self, plane_features, query):
        # plane_features shape: (B, 3, D, res, res); split into the three per-plane maps.
        fea = {}
        fea['yx'], fea['zx'], fea['yz'] = (plane_features[:, 0, ...],
                                           plane_features[:, 1, ...],
                                           plane_features[:, 2, ...])
        # print("shapes: ", fea['yx'].shape, fea['zx'].shape, fea['yz'].shape)  # ([1, 256, 64, 64])
        # Sum the features sampled from the three planes at each query point.
        plane_feat_sum = 0
        plane_feat_sum += self.sample_plane_feature(query, fea['yx'], 'yx')
        plane_feat_sum += self.sample_plane_feature(query, fea['zx'], 'zx')
        plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')
        return plane_feat_sum.transpose(2, 1)  # (B, N, D)

    def forward(self, plane_features, xyz):
        '''
        plane_features: B, 3, D, res, res (e.g. B, 3, 256, 64, 64), one feature map per plane
        xyz: B, N, 3
        '''
        # Identical to forward_with_plane_features; kept as the nn.Module entry point.
        return self.forward_with_plane_features(plane_features, xyz)  # [B, num_points]
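

if __name__ == "__main__":
    # CPU smoke test (illustrative sketch, not part of the original training code).
    # It replays the plane-feature lookup performed by SdfModel.get_points_plane_features;
    # SdfModel itself is not constructed here because it additionally needs a config
    # with 'channels', 'width', 'n_layers', 'skip_in' and a trained 'ckpt_path', plus
    # a CUDA device. All sizes below are assumed values for the demo.
    B, D, res, N = 2, 256, 64, 1024
    plane_features = torch.rand(B, 3, D, res, res)   # one (D, res, res) feature map per plane
    xyz = torch.rand(B, N, 3) * 2 - 1                # query points in [-1, 1]

    plane_feat_sum = 0
    for idx, plane in enumerate(['yx', 'zx', 'yz']):
        # Same projection + bilinear lookup as SdfModel.sample_plane_feature.
        if plane == 'zx':
            uv = xyz[:, :, [2, 0]]
        elif plane == 'yx':
            uv = xyz[:, :, [1, 0]]
        else:
            uv = xyz[:, :, [1, 2]]
        vgrid = uv[:, :, None].float()               # (B, N, 1, 2) sampling grid
        feat = F.grid_sample(plane_features[:, idx], vgrid, padding_mode='border',
                             align_corners=True, mode='bilinear').squeeze(-1)  # (B, D, N)
        plane_feat_sum = plane_feat_sum + feat
    point_features = plane_feat_sum.transpose(2, 1)  # (B, N, D)
    print(point_features.shape)                      # torch.Size([2, 1024, 256])

    # An untrained decoder with matching input width, only to check shapes
    # (the real SdfModel loads trained weights from config_json['ckpt_path']).
    decoder = SdfDecoder(d_in=D, d_out=1, d_hidden=128, n_layers=4)
    sdf = decoder(point_features)
    print(sdf.shape)                                 # torch.Size([2, 1024, 1])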