-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathssm_init.py
109 lines (92 loc) · 3.09 KB
/
ssm_init.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copied from https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/dplr.py
"""Initializations of structured state space models"""
import math
import torch
from einops import repeat
def dplr(
    scaling="linear",
    N=64,
    rank=1,
    H=1,
    dtype=torch.float,
    real_scale=1.0,
    imag_scale=1.0,
    random_real=False,
    random_imag=False,
    normalize=False,
    diagonal=True,
    random_B=False,
):
    """Build the ``(w, P, B, V)`` tensors of a DPLR state space initialization.

    Args:
        scaling: Layout of the imaginary part of ``w``. One of ``"random"``,
            ``"real"``, ``"linear"``/``"lin"``, ``"inverse"``/``"inv"``,
            ``"inverse2"``/``"inv2"`` or ``"quadratic"``/``"quad"``.
        N: State size; the returned tensors hold ``N // 2`` entries along
            their last axis.
        rank: Size of the leading axis of ``P``.
        H: Number of independent copies (size of the ``H`` axis).
        dtype: ``torch.float`` or ``torch.double``; selects complex64 or
            complex128 for every returned tensor.
        real_scale: Multiplier on the real part. Has no effect with
            ``scaling="real"``, which overwrites the real part afterwards.
        imag_scale: Multiplier on the imaginary part, applied after scaling.
        random_real: Draw the real part uniformly from [0, 1) instead of
            using the constant 0.5.
        random_imag: Draw the base imaginary part uniformly from
            [0, N // 2) instead of using 0, 1, ..., N // 2 - 1. Ignored by
            ``scaling="random"``, which replaces it with normal samples.
        normalize: Rescale ``B`` so that ``2 * sum(|B / w| ** 2)`` over the
            last axis equals 1.
        diagonal: Zero out ``P``.
        random_B: Draw ``B`` from a standard normal instead of all ones.

    Returns:
        Tuple ``(w, P, B, V)`` of complex tensors with shapes
        ``(H, N // 2)``, ``(rank, H, N // 2)``, ``(H, N // 2)`` and
        ``(H, N, N // 2)``. ``w`` is ``-real_part + 1j * imag_part`` and
        ``V`` holds the first ``N // 2`` columns of the identity matrix.

    Raises:
        AssertionError: If ``dtype`` is not ``torch.float``/``torch.double``.
        NotImplementedError: If ``scaling`` is not a recognised name.
    """
    assert dtype == torch.float or dtype == torch.double, (
        f"dtype must be torch.float or torch.double, got {dtype}"
    )
    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble
    half = N // 2
    pi = torch.tensor(math.pi)

    # NOTE: the random draws keep a fixed order (real, imag, scaling, B, P)
    # so that results under a fixed seed do not depend on code layout.
    if random_real:
        real_part = torch.rand(H, half)
    else:
        real_part = 0.5 * torch.ones(H, half)
    if random_imag:
        imag_part = half * torch.rand(H, half)
    else:
        # 0, 1, ..., N // 2 - 1 tiled over the H axis -> (H, N // 2)
        imag_part = torch.arange(half).repeat(H, 1)

    real_part = real_scale * real_part
    if scaling == "random":
        imag_part = torch.randn(H, half)
    elif scaling == "real":
        imag_part = 0 * imag_part
        real_part = 1 + torch.arange(half).repeat(H, 1)
    elif scaling in ["linear", "lin"]:
        imag_part = pi * imag_part
    elif scaling in ["inverse", "inv"]:
        # Based on asymptotics of the default HiPPO matrix
        imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1)
    elif scaling in ["inverse2", "inv2"]:
        imag_part = 1 / pi * N * (N / (1 + imag_part) - 1)
    elif scaling in ["quadratic", "quad"]:
        imag_part = 1 / pi * (1 + 2 * imag_part) ** 2
    else:
        raise NotImplementedError(f"unknown scaling: {scaling!r}")
    imag_part = imag_scale * imag_part

    # The parts above are computed in torch's default precision; cast so that
    # w carries the same complex dtype as P, B and V (otherwise
    # dtype=torch.double would return a complex64 w next to complex128 tensors).
    w = (-real_part + 1j * imag_part).to(cdtype)

    # Initialize B
    if random_B:
        B = torch.randn(H, half, dtype=cdtype)
    else:
        B = torch.ones(H, half, dtype=cdtype)

    if normalize:
        # (H, N // 2); result if you integrate the kernel with constant 1 function
        norm = -B / w
        # Variance with a random C vector
        zeta = 2 * torch.sum(torch.abs(norm) ** 2, dim=-1, keepdim=True)
        B = B / zeta**0.5

    # P is always sampled, then zeroed if requested, so the amount of
    # randomness consumed is the same for both values of `diagonal`.
    P = torch.randn(rank, H, half, dtype=cdtype)
    if diagonal:
        P = P * 0.0

    V = torch.eye(N, dtype=cdtype)[:, :half]  # Only used in testing
    V = V.repeat(H, 1, 1)  # (H, N, N // 2)
    return w, P, B, V
def ssm(measure, N, R, H, **ssm_args):
    """Dispatcher to create single SSM initialization

    measure: "dplr", or "diag-<scaling>" where <scaling> is forwarded to
        `dplr` as its `scaling` argument
    N: state size
    R: rank (for DPLR parameterization)
    H: number of independent SSM copies
    ssm_args: extra keyword arguments forwarded to `dplr`

    Returns the (w, P, B, V) tuple produced by `dplr`.

    Raises NotImplementedError for any other measure name, and
    AssertionError for a "diag..." name that is not of the form
    "diag-<scaling>".
    """
    if measure == "dplr":
        w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)
    elif measure.startswith("diag"):
        args = measure.split("-")
        assert args[0] == "diag" and len(args) > 1, (
            f"expected a measure of the form 'diag-<scaling>', got {measure!r}"
        )
        scaling = args[1]
        w, P, B, V = dplr(scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args)
    else:
        # Without this branch the return below would hit unbound locals and
        # surface as an opaque UnboundLocalError.
        raise NotImplementedError(f"unsupported SSM measure: {measure!r}")
    return w, P, B, V
# Aliases that expand to lists of measure names.
combinations = {
    "hippo": ["legs", "fourier"],
    "diag": ["diag-inv", "diag-lin"],
}
# "all" is the "hippo" measures followed by the "diag" measures.
combinations["all"] = combinations["hippo"] + combinations["diag"]
def combination(measure, N, R, S):
    """Create the initialization for `measure` by forwarding to `ssm`.

    measure: measure name understood by `ssm`
    N: state size
    R: rank
    S: passed to `ssm` as `H`, the number of independent SSM copies

    Returns the 4-tuple produced by `ssm`.
    """
    init = ssm(measure, N=N, R=R, H=S)
    w, P, B, V = init
    return w, P, B, V