msr.py
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import sys
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOT_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'modules'))
from point_4d_convolution import *
from transformer import *
class P4Transformer(nn.Module):
    def __init__(self, radius, nsamples, spatial_stride,          # P4DConv: spatial
                 temporal_kernel_size, temporal_stride,           # P4DConv: temporal
                 emb_relu,                                         # embedding: relu
                 dim, depth, heads, dim_head,                      # transformer
                 mlp_dim, num_classes):                            # output
        super().__init__()

        self.tube_embedding = P4DConv(in_planes=0, mlp_planes=[dim], mlp_batch_norm=[False], mlp_activation=[False],
                                      spatial_kernel_size=[radius, nsamples], spatial_stride=spatial_stride,
                                      temporal_kernel_size=temporal_kernel_size, temporal_stride=temporal_stride,
                                      temporal_padding=[1, 0],
                                      operator='+', spatial_pooling='max', temporal_pooling='max')

        self.pos_embedding = nn.Conv1d(in_channels=4, out_channels=dim, kernel_size=1, stride=1, padding=0, bias=True)

        self.emb_relu = nn.ReLU() if emb_relu else False

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, num_classes),
        )
    def forward(self, input):                                      # [B, L, N, 3]
        # use the input tensor's device directly (robust on both CPU and GPU,
        # unlike tensor.get_device(), which returns -1 for CPU tensors)
        device = input.device

        # point 4D convolution: spatio-temporal anchors and their features
        xyzs, features = self.tube_embedding(input)                # [B, L, n, 3], [B, L, C, n]

        # append the (1-based) frame index to each anchor's xyz coordinates
        xyzts = []
        xyzs = torch.split(tensor=xyzs, split_size_or_sections=1, dim=1)
        xyzs = [torch.squeeze(input=xyz, dim=1).contiguous() for xyz in xyzs]
        for t, xyz in enumerate(xyzs):
            t = torch.ones((xyz.size()[0], xyz.size()[1], 1), dtype=torch.float32, device=device) * (t + 1)
            xyzt = torch.cat(tensors=(xyz, t), dim=2)
            xyzts.append(xyzt)
        xyzts = torch.stack(tensors=xyzts, dim=1)
        xyzts = torch.reshape(input=xyzts, shape=(xyzts.shape[0], xyzts.shape[1]*xyzts.shape[2], xyzts.shape[3]))                    # [B, L*n, 4]

        features = features.permute(0, 1, 3, 2)                                                                                      # [B, L, n, C]
        features = torch.reshape(input=features, shape=(features.shape[0], features.shape[1]*features.shape[2], features.shape[3]))  # [B, L*n, C]

        # 4D positional embedding added to the tube features
        xyzts = self.pos_embedding(xyzts.permute(0, 2, 1)).permute(0, 2, 1)

        embedding = xyzts + features

        if self.emb_relu:
            embedding = self.emb_relu(embedding)

        output = self.transformer(embedding)
        # global max pooling over all spatio-temporal tokens, then the classification head
        output = torch.max(input=output, dim=1, keepdim=False, out=None)[0]
        output = self.mlp_head(output)

        return output
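

# Minimal usage sketch (not part of the original file). The constructor arguments
# below are illustrative assumptions, not values taken from the repo's configs,
# and P4DConv typically relies on compiled CUDA point ops, so a GPU is usually required.
if __name__ == '__main__':
    model = P4Transformer(radius=0.7, nsamples=32, spatial_stride=32,
                          temporal_kernel_size=3, temporal_stride=2,
                          emb_relu=False,
                          dim=1024, depth=5, heads=8, dim_head=128,
                          mlp_dim=2048, num_classes=20)
    if torch.cuda.is_available():
        model = model.cuda()
        clips = torch.rand(2, 16, 1024, 3).cuda()  # [B, L, N, 3] point cloud clips
        logits = model(clips)                      # [B, num_classes]
        print(logits.shape)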