-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathsoft_dtw.py
101 lines (87 loc) · 3.04 KB
/
soft_dtw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import torch
from numba import jit
from torch.autograd import Function
def pairwise_distances(x, y=None):
    '''
    Squared Euclidean distances between all row pairs of two matrices.

    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the square norm
            between x[i,:] and y[j,:], i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2.
    If y is not given then 'y=x' is used (within-x distances).
    '''
    sq_x = x.pow(2).sum(dim=1).view(-1, 1)
    if y is None:
        other = x
        sq_y = sq_x.view(1, -1)
    else:
        other = y
        sq_y = other.pow(2).sum(dim=1).view(1, -1)
    # ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a.b, expanded via one matrix product.
    dist = sq_x + sq_y - 2.0 * x.mm(other.t())
    # Floating-point error can push exact-zero distances slightly negative.
    return torch.clamp(dist, min=0.0)
@jit(nopython = True)
def compute_softdtw(D, gamma):
    """
    Forward dynamic-programming pass of soft-DTW.

    D:     (N, M) pairwise cost matrix.
    gamma: smoothing parameter (> 0); smaller values approach hard DTW.

    Returns the (N+2, M+2) accumulated-cost table R; the soft-DTW value
    itself is R[-2, -2] (= R[N, M]).  The extra border rows/columns are
    padding consumed by the backward pass.
    """
    n = D.shape[0]
    m = D.shape[1]
    # Borders start at a huge cost so the recursion cannot step off the grid.
    R = np.zeros((n + 2, m + 2)) + 1e8
    R[0, 0] = 0
    for j in range(1, m + 1):
        for i in range(1, n + 1):
            # Soft-min over the three predecessor cells, evaluated in
            # log-space with the max subtracted for numerical stability.
            z0 = -R[i - 1, j - 1] / gamma
            z1 = -R[i - 1, j] / gamma
            z2 = -R[i, j - 1] / gamma
            zmax = max(max(z0, z1), z2)
            zsum = np.exp(z0 - zmax) + np.exp(z1 - zmax) + np.exp(z2 - zmax)
            R[i, j] = D[i - 1, j - 1] - gamma * (np.log(zsum) + zmax)
    return R
@jit(nopython = True)
def compute_softdtw_backward(D_, R, gamma):
    """
    Backward pass of soft-DTW: gradient of the soft-DTW value w.r.t. D_.

    D_:    (N, M) pairwise cost matrix used in the forward pass.
    R:     (N+2, M+2) accumulated-cost table produced by compute_softdtw.
           NOTE(review): R is mutated in place below (its last row/column
           are overwritten) — callers must not reuse R afterwards; verify
           this is acceptable at every call site.
    gamma: smoothing parameter used in the forward pass.

    Returns the (N, M) gradient matrix (interior of the padded E table).
    """
    N = D_.shape[0]
    M = D_.shape[1]
    # Padded copies of the cost matrix and the gradient accumulator so the
    # recursion below can read one cell past each edge without bounds checks.
    D = np.zeros((N + 2, M + 2))
    E = np.zeros((N + 2, M + 2))
    D[1:N + 1, 1:M + 1] = D_
    # Seed: the terminal cell's gradient is 1.
    E[-1, -1] = 1
    # Overwrite R's border so out-of-grid transitions get zero weight
    # (exp of a very large negative argument), then pin the terminal cell.
    R[:, -1] = -1e8
    R[-1, :] = -1e8
    R[-1, -1] = R[-2, -2]
    # Sweep anti-chronologically, distributing each cell's gradient to its
    # three predecessors with soft-min transition weights a, b, c.
    for j in range(M, 0, -1):
        for i in range(N, 0, -1):
            a0 = (R[i + 1, j] - R[i, j] - D[i + 1, j]) / gamma
            b0 = (R[i, j + 1] - R[i, j] - D[i, j + 1]) / gamma
            c0 = (R[i + 1, j + 1] - R[i, j] - D[i + 1, j + 1]) / gamma
            a = np.exp(a0)
            b = np.exp(b0)
            c = np.exp(c0)
            E[i, j] = E[i + 1, j] * a + E[i, j + 1] * b + E[i + 1, j + 1] * c
    # Strip the padding before returning.
    return E[1:N + 1, 1:M + 1]
class SoftDTWBatch(Function):
    """
    Autograd wrapper around the numba soft-DTW kernels.

    forward(D, gamma) takes a batch of pairwise cost matrices D of shape
    [batch_size, N, N] and returns the MEAN soft-DTW value over the batch;
    backward returns the matching gradient w.r.t. D.
    """

    @staticmethod
    def forward(ctx, D, gamma = 1.0):  # D.shape: [batch_size, N, N]
        dev = D.device
        batch_size, N, _ = D.shape
        # Stored as a tensor so it survives in ctx.save_for_backward.
        gamma = torch.FloatTensor([gamma]).to(dev)
        D_ = D.detach().cpu().numpy()
        g_ = gamma.item()
        total_loss = 0
        # One (N+2, M+2) accumulated-cost table per batch element; the
        # padding borders are consumed by the backward pass.
        R = torch.zeros((batch_size, N + 2, N + 2)).to(dev)
        for k in range(0, batch_size):  # loop over all D in the batch
            Rk = torch.FloatTensor(compute_softdtw(D_[k, :, :], g_)).to(dev)
            R[k:k + 1, :, :] = Rk
            total_loss = total_loss + Rk[-2, -2]  # soft-DTW value sits at R[N, M]
        ctx.save_for_backward(D, R, gamma)
        return total_loss / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        dev = grad_output.device
        D, R, gamma = ctx.saved_tensors
        batch_size, N, _ = D.shape
        D_ = D.detach().cpu().numpy()
        R_ = R.detach().cpu().numpy()
        g_ = gamma.item()
        E = torch.zeros((batch_size, N, N)).to(dev)
        for k in range(batch_size):
            Ek = torch.FloatTensor(
                compute_softdtw_backward(D_[k, :, :], R_[k, :, :], g_)).to(dev)
            E[k:k + 1, :, :] = Ek
        # BUG FIX: forward returns the batch MEAN, so d(loss)/dD_k must carry
        # the same 1/batch_size factor; the previous code returned gradients
        # batch_size times too large.
        return grad_output * E / batch_size, None