-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
50 lines (43 loc) · 1.53 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
def _sigmoid(x):
y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
return y
def _gather_feat(feat, ind, mask=None):
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
if mask is not None:
mask = mask.unsqueeze(2).expand_as(feat)
feat = feat[mask]
feat = feat.view(-1, dim)
return feat
def _tranpose_and_gather_feat(feat, ind):
    """Gather feature vectors at flattened spatial positions of a 4-D map.

    Dim 1 of ``feat`` is moved last (``permute(0, 2, 3, 1)``). The two spatial
    dims are then flattened, giving ``(B, D2 * D3, D1)``, and the result is
    handed to ``_gather_feat``. ``ind`` therefore indexes that flattened axis
    in row-major order (``i2 * D3 + i3``). This presumes a channels-first
    ``(B, C, H, W)`` map -- confirm with callers.
    """
    batch_size, num_channels = feat.size(0), feat.size(1)
    channels_last = feat.permute(0, 2, 3, 1).contiguous()
    flattened = channels_last.view(batch_size, -1, num_channels)
    return _gather_feat(flattened, ind)
def flip_tensor(x):
    """Return a copy of ``x`` reversed along dimension 3.

    For a 4-D map this mirrors the last axis (presumably width, i.e. a
    horizontal flip). ``torch.flip`` always copies, so ``x`` is unchanged.
    """
    return x.flip(dims=(3,))
def flip_lr(x, flip_idx):
    """Mirror ``x`` along its last dimension and swap paired channels.

    Args:
        x: tensor to mirror; the swaps index dim 1 (the channel dim).
        flip_idx: iterable of pairs ``e``; channels ``e[0]`` and ``e[1]`` are
            exchanged, applying the pairs in order.

    Returns:
        A new, detached tensor with ``x``'s dtype on ``x.device``; ``x`` is
        not modified.

    The work stays in torch on ``x``'s own device instead of round-tripping
    through host numpy. That avoids a device->host->device copy and also
    supports dtypes numpy lacks, such as bfloat16.
    """
    flipped = x.detach().flip(-1)
    # Apply the same sequential swaps to an index list. out[:, i] then equals
    # flipped[:, perm[i]], which matches swapping the channels one pair at a
    # time.
    perm = None
    for e in flip_idx:
        if perm is None:
            perm = list(range(flipped.size(1)))
        perm[e[0]], perm[e[1]] = perm[e[1]], perm[e[0]]
    if perm is None:
        # No pairs: only the mirror applies (also works for 1-D input).
        return flipped
    return flipped[:, perm]
def flip_lr_off(x, flip_idx):
    """Horizontally mirror a per-joint 2-channel offset map.

    ``x`` is treated as ``(N, 2 * J, H, W)``, with each joint owning two
    consecutive channels. The steps are:

    - mirror along the last dimension;
    - negate the first channel of every pair (presumably the horizontal
      offset, whose sign a left-right mirror reverses);
    - exchange the joints named by each ``flip_idx`` pair, in order.

    ``J`` used to be hard-coded to 17. It is now derived as ``C // 2``, which
    gives identical results for 34-channel input and lets other joint counts
    work.

    Args:
        x: 4-D tensor with an even channel count on dim 1.
        flip_idx: iterable of joint-index pairs ``e`` (``e[0]`` <-> ``e[1]``).

    Returns:
        A new, detached tensor shaped like ``x`` with ``x``'s dtype, on
        ``x.device``; ``x`` is not modified.

    Raises:
        ValueError: if the channel count is odd.
    """
    flipped = x.detach().flip(-1)
    n, c = flipped.shape[0], flipped.shape[1]
    h, w = flipped.shape[2], flipped.shape[3]
    if c % 2:
        raise ValueError(
            f"flip_lr_off expects an even channel count (2 per joint), got {c}")
    # (N, J, 2, H, W): axis 2 selects the first/second channel of each joint.
    offsets = flipped.reshape(n, c // 2, 2, h, w)
    offsets[:, :, 0].mul_(-1)  # safe: `flipped` is a fresh copy, not x
    # Fold the sequential joint swaps into one permutation over joints.
    perm = list(range(c // 2))
    for e in flip_idx:
        perm[e[0]], perm[e[1]] = perm[e[1]], perm[e[0]]
    return offsets[:, perm].reshape(n, c, h, w)