-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathintegral_operators.py
368 lines (318 loc) · 15.7 KB
/
integral_operators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class SpectralConv1d_Uno(nn.Module):
    """
    1D Fourier layer. It does FFT, linear transform, and inverse FFT.

    dim1 = Default output grid size along x (or 1st dimension of output domain).
        The ratio of the input grid size to the output grid size implicitly
        sets the expansion or contraction factor along each dimension of the domain.
    modes1 = Number of Fourier modes to consider for the integral operator.
        The number of modes must be compatible with the input grid size
        and the desired output grid size,
        i.e., modes1 <= min(dim1/2, input_dim1/2).
        Here "input_dim1" is the grid size along x axis (or first dimension)
        of the input domain. Defaults to dim1//2.
    in_codim = Input co-domain dimension
    out_codim = Output co-domain dimension
    """

    def __init__(self, in_codim, out_codim, dim1, modes1=None):
        super(SpectralConv1d_Uno, self).__init__()
        in_codim = int(in_codim)
        out_codim = int(out_codim)
        self.in_channels = in_codim
        self.out_channels = out_codim
        self.dim1 = dim1  # default output grid size
        # Number of Fourier modes to multiply, at most floor(N/2) + 1.
        self.modes1 = modes1 if modes1 is not None else dim1 // 2
        self.scale = (1 / (2 * in_codim)) ** (1.0 / 2.0)
        self.weights1 = nn.Parameter(
            self.scale * torch.randn(in_codim, out_codim, self.modes1, dtype=torch.cfloat)
        )

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x, dim1=None):
        """
        input shape = (batch, in_codim, input_dim1)
        output shape = (batch, out_codim, dim1)

        dim1 overrides the output grid size for this call only; when None the
        default given at construction (self.dim1) is used.
        """
        # Resolve the output size locally rather than overwriting self.dim1, so a
        # one-off size does not leak into later calls (the pointwise path of an
        # operator block never stores it, and the two outputs are summed).
        if dim1 is None:
            dim1 = self.dim1
        batchsize = x.shape[0]
        x_ft = torch.fft.rfft(x, norm='forward')
        # Multiply relevant Fourier modes; all higher modes stay zero.
        out_ft = torch.zeros(batchsize, self.out_channels, dim1 // 2 + 1,
                             dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1)
        # Return to physical space on the requested output grid.
        return torch.fft.irfft(out_ft, n=dim1, norm='forward')
class pointwise_op_1D(nn.Module):
    """
    Pointwise (1x1 convolution) operator followed by linear interpolation
    onto the output grid.
    All variables are consistent with the SpectralConv1d_Uno class.
    """

    def __init__(self, in_codim, out_codim, dim1):
        super(pointwise_op_1D, self).__init__()
        self.conv = nn.Conv1d(int(in_codim), int(out_codim), 1)
        self.dim1 = int(dim1)

    def forward(self, x, dim1=None):
        """
        input shape = (batch, in_codim, input_dim1)
        output shape = (batch, out_codim, dim1)

        dim1 overrides the output grid size for this call; when None the
        default given at construction is used.
        """
        if dim1 is None:
            dim1 = self.dim1
        x_out = self.conv(x)
        # No antialias here: F.interpolate only implements it for
        # bilinear/bicubic on 4-D input and raises ValueError for mode='linear'.
        x_out = F.interpolate(x_out, size=dim1, mode='linear', align_corners=True)
        return x_out
class OperatorBlock_1D(nn.Module):
    """
    One UNO layer on a 1D domain: a spectral integral operator
    (SpectralConv1d_Uno) and a pointwise operator (pointwise_op_1D) are both
    applied to the input and their outputs are summed.

    Normalize = if true performs InstanceNorm1d on the output.
    Non_Lin = if true, applies point wise nonlinearity (GELU).
    All other variables are consistent with the SpectralConv1d_Uno class.
    """

    def __init__(self, in_codim, out_codim, dim1, modes1, Normalize=True, Non_Lin=True):
        super().__init__()
        # Sub-module creation order is kept (spectral first) so random
        # initialisation consumes the RNG in the same sequence.
        self.conv = SpectralConv1d_Uno(in_codim, out_codim, dim1, modes1)
        self.w = pointwise_op_1D(in_codim, out_codim, dim1)
        self.normalize, self.non_lin = Normalize, Non_Lin
        if self.normalize:
            self.normalize_layer = torch.nn.InstanceNorm1d(int(out_codim), affine=True)

    def forward(self, x, dim1=None):
        """
        input shape = (batch, in_codim, input_dim1)
        output shape = (batch, out_codim, dim1)
        """
        spectral = self.conv(x, dim1)
        local = self.w(x, dim1)
        combined = spectral + local
        if self.normalize:
            combined = self.normalize_layer(combined)
        return F.gelu(combined) if self.non_lin else combined
class SpectralConv2d_Uno(nn.Module):
    """
    2D Fourier layer. It does FFT, linear transform, and inverse FFT.

    dim1 = Default output grid size along x (or 1st dimension of output domain)
    dim2 = Default output grid size along y (or 2nd dimension of output domain)
        The ratio of the input grid size to the output grid size implicitly
        sets the expansion or contraction factor along each dimension.
    modes1, modes2 = Number of Fourier modes to consider for the integral operator.
        The number of modes must be compatible with the input grid size
        and the desired output grid size,
        i.e., modes1 <= min(dim1/2, input_dim1/2).
        Here "input_dim1" is the grid size along x axis (or first dimension)
        of the input domain. The other mode count has the same constraint.
        Each defaults independently: modes1 -> dim1//2 - 1, modes2 -> dim2//2.
    in_codim = Input co-domain dimension
    out_codim = Output co-domain dimension
    """

    def __init__(self, in_codim, out_codim, dim1, dim2, modes1=None, modes2=None):
        super(SpectralConv2d_Uno, self).__init__()
        in_codim = int(in_codim)
        out_codim = int(out_codim)
        self.in_channels = in_codim
        self.out_channels = out_codim
        self.dim1 = dim1
        self.dim2 = dim2
        # Resolve each mode count on its own so a single supplied value is
        # neither ignored nor paired with None.
        self.modes1 = modes1 if modes1 is not None else dim1 // 2 - 1
        self.modes2 = modes2 if modes2 is not None else dim2 // 2
        self.scale = (1 / (2 * in_codim)) ** (1.0 / 2.0)
        self.weights1 = nn.Parameter(self.scale * (torch.randn(
            in_codim, out_codim, self.modes1, self.modes2, dtype=torch.cfloat)))
        self.weights2 = nn.Parameter(self.scale * (torch.randn(
            in_codim, out_codim, self.modes1, self.modes2, dtype=torch.cfloat)))

    # Complex multiplication
    def compl_mul2d(self, input, weights):
        # (batch, in_channel, x, y), (in_channel, out_channel, x, y) -> (batch, out_channel, x, y)
        return torch.einsum("bixy,ioxy->boxy", input, weights)

    def forward(self, x, dim1=None, dim2=None):
        """
        input shape = (batch, in_codim, input_dim1, input_dim2)
        output shape = (batch, out_codim, dim1, dim2)

        dim1 / dim2 override the output grid size for this call only; each
        one that is None falls back to the default given at construction.
        """
        # Resolve sizes locally instead of overwriting self.dim1/self.dim2, so a
        # one-off size does not leak into later calls.
        if dim1 is None:
            dim1 = self.dim1
        if dim2 is None:
            dim2 = self.dim2
        batchsize = x.shape[0]
        # norm='forward' puts the 1/n factor on the forward transform and none on
        # the inverse, matching the irfft2 call below.
        x_ft = torch.fft.rfft2(x, norm='forward')
        # Multiply relevant Fourier modes: the positive- and negative-frequency
        # corners along x, the first modes2 entries of the rfft half-spectrum along y.
        out_ft = torch.zeros(batchsize, self.out_channels, dim1, dim2 // 2 + 1,
                             dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)
        # Return to physical space on the requested output grid.
        return torch.fft.irfft2(out_ft, s=(dim1, dim2), norm='forward')
class pointwise_op_2D(nn.Module):
    """
    Pointwise (1x1 convolution) operator followed by antialiased bicubic
    interpolation onto the output grid.

    dim1 = Default output grid size along x (or 1st dimension)
    dim2 = Default output grid size along y (or 2nd dimension)
    in_codim = Input co-domain dimension
    out_codim = Output co-domain dimension
    """

    def __init__(self, in_codim, out_codim, dim1, dim2):
        super(pointwise_op_2D, self).__init__()
        self.conv = nn.Conv2d(int(in_codim), int(out_codim), 1)
        self.dim1 = int(dim1)
        self.dim2 = int(dim2)

    def forward(self, x, dim1=None, dim2=None):
        """
        input shape = (batch, in_codim, input_dim1, input_dim2)
        output shape = (batch, out_codim, dim1, dim2)

        Each of dim1 / dim2 that is None falls back to its construction default.
        """
        # Resolve each size independently so that passing only dim1 (or only
        # dim2) neither produces size=(dim1, None) nor silently drops dim2.
        if dim1 is None:
            dim1 = self.dim1
        if dim2 is None:
            dim2 = self.dim2
        x_out = self.conv(x)
        x_out = F.interpolate(x_out, size=(dim1, dim2), mode='bicubic',
                              align_corners=True, antialias=True)
        return x_out
class OperatorBlock_2D(nn.Module):
    """
    One UNO layer on a 2D domain: the input goes through a spectral integral
    operator (SpectralConv2d_Uno) and a pointwise operator (pointwise_op_2D);
    the two results are added, then optionally normalized and activated.

    Normalize = if true performs InstanceNorm2d on the output.
    Non_Lin = if true, applies point wise nonlinearity (GELU).
    All other variables are consistent with the SpectralConv2d_Uno class.
    """

    def __init__(self, in_codim, out_codim, dim1, dim2, modes1, modes2, Normalize=False, Non_Lin=True):
        super().__init__()
        # Spectral operator is built before the pointwise one, as in the
        # original, so parameter initialisation draws from the RNG identically.
        self.conv = SpectralConv2d_Uno(in_codim, out_codim, dim1, dim2, modes1, modes2)
        self.w = pointwise_op_2D(in_codim, out_codim, dim1, dim2)
        self.normalize, self.non_lin = Normalize, Non_Lin
        if self.normalize:
            self.normalize_layer = torch.nn.InstanceNorm2d(int(out_codim), affine=True)

    def forward(self, x, dim1=None, dim2=None):
        """
        input shape = (batch, in_codim, input_dim1, input_dim2)
        output shape = (batch, out_codim, dim1, dim2)
        """
        out = self.conv(x, dim1, dim2) + self.w(x, dim1, dim2)
        if self.normalize:
            out = self.normalize_layer(out)
        if not self.non_lin:
            return out
        return F.gelu(out)
class SpectralConv3d_Uno(nn.Module):
    """
    3D Fourier layer. It does FFT, linear transform, and inverse FFT.

    dim1 = Default output grid size along x (or 1st dimension of output domain)
    dim2 = Default output grid size along y (or 2nd dimension of output domain)
    dim3 = Default output grid size along time t (or 3rd dimension of output domain)
        The ratio of the input grid size to the output grid size (dim1, dim2, dim3)
        implicitly sets the expansion or contraction factor along each dimension.
    modes1, modes2, modes3 = Number of Fourier modes to consider for the integral operator.
        The number of modes must be compatible with the input grid size
        and the desired output grid size,
        i.e., modes1 <= min(dim1/2, input_dim1/2)
              modes2 <= min(dim2/2, input_dim2/2)
        Here input_dim1, input_dim2 are respectively the grid size along
        x axis and y axis (or first and second dimension) of the input domain.
        modes3 indexes the half-spectrum of the real FFT along t, so it can be
        at most min(dim3//2 + 1, input_dim3//2 + 1).
        Each defaults independently to dim1//2, dim2//2 and dim3//2 + 1.
    in_codim = Input co-domain dimension
    out_codim = Output co-domain dimension
    """

    def __init__(self, in_codim, out_codim, dim1, dim2, dim3, modes1=None, modes2=None, modes3=None):
        super(SpectralConv3d_Uno, self).__init__()
        in_codim = int(in_codim)
        out_codim = int(out_codim)
        self.in_channels = in_codim
        self.out_channels = out_codim
        self.dim1 = dim1
        self.dim2 = dim2
        self.dim3 = dim3
        # Defaults for x and y are half the grid: with the full grid size the
        # positive- and negative-frequency corner writes in forward() would cover
        # the whole axis and overwrite one another, leaving only weights4 in effect.
        self.modes1 = modes1 if modes1 is not None else dim1 // 2
        self.modes2 = modes2 if modes2 is not None else dim2 // 2
        self.modes3 = modes3 if modes3 is not None else dim3 // 2 + 1
        self.scale = (1 / (2 * in_codim)) ** (1.0 / 2.0)
        shape = (in_codim, out_codim, self.modes1, self.modes2, self.modes3)
        self.weights1 = nn.Parameter(self.scale * torch.randn(*shape, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.randn(*shape, dtype=torch.cfloat))
        self.weights3 = nn.Parameter(self.scale * torch.randn(*shape, dtype=torch.cfloat))
        self.weights4 = nn.Parameter(self.scale * torch.randn(*shape, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul3d(self, input, weights):
        # (batch, in_channel, x, y, t), (in_channel, out_channel, x, y, t) -> (batch, out_channel, x, y, t)
        return torch.einsum("bixyz,ioxyz->boxyz", input, weights)

    def forward(self, x, dim1=None, dim2=None, dim3=None):
        """
        dim1, dim2, dim3 are the output grid size along (x, y, t); each one that
        is None falls back to the default given at construction.
        input shape = (batch, in_codim, input_dim1, input_dim2, input_dim3)
        output shape = (batch, out_codim, dim1, dim2, dim3)
        """
        # Resolve sizes locally instead of overwriting the stored defaults, so a
        # one-off size does not leak into later calls.
        if dim1 is None:
            dim1 = self.dim1
        if dim2 is None:
            dim2 = self.dim2
        if dim3 is None:
            dim3 = self.dim3
        batchsize = x.shape[0]
        x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1], norm='forward')
        out_ft = torch.zeros(batchsize, self.out_channels, dim1, dim2, dim3 // 2 + 1,
                             dtype=torch.cfloat, device=x.device)
        m1, m2, m3 = self.modes1, self.modes2, self.modes3
        # Four low-frequency corners: (+/-) frequencies along x and along y,
        # leading modes3 entries of the rfft half-spectrum along t.
        out_ft[:, :, :m1, :m2, :m3] = \
            self.compl_mul3d(x_ft[:, :, :m1, :m2, :m3], self.weights1)
        out_ft[:, :, -m1:, :m2, :m3] = \
            self.compl_mul3d(x_ft[:, :, -m1:, :m2, :m3], self.weights2)
        out_ft[:, :, :m1, -m2:, :m3] = \
            self.compl_mul3d(x_ft[:, :, :m1, -m2:, :m3], self.weights3)
        out_ft[:, :, -m1:, -m2:, :m3] = \
            self.compl_mul3d(x_ft[:, :, -m1:, -m2:, :m3], self.weights4)
        # Return to physical space on the requested output grid.
        return torch.fft.irfftn(out_ft, s=(dim1, dim2, dim3), norm='forward')
class pointwise_op_3D(nn.Module):
    """
    Pointwise (1x1x1 convolution) operator on a 3D domain, resampled onto the
    output grid (dim1, dim2, dim3) through a truncated Fourier series.

    in_codim = Input co-domain dimension
    out_codim = Output co-domain dimension
    dim1, dim2, dim3 = Default output grid size along (x, y, t)
    """
    def __init__(self, in_codim, out_codim,dim1, dim2,dim3):
        super(pointwise_op_3D,self).__init__()
        self.conv = nn.Conv3d(int(in_codim), int(out_codim), 1)
        self.dim1 = int(dim1)
        self.dim2 = int(dim2)
        self.dim3 = int(dim3)
    def forward(self,x, dim1 = None, dim2 = None, dim3 = None):
        """
        dim1,dim2,dim3 are the output dimensions (x,y,t)

        NOTE(review): the stored defaults are used only when dim1 is None; if dim1
        is given, dim2 and dim3 are taken as passed (a None there will fail).
        """
        if dim1 is None:
            dim1 = self.dim1
            dim2 = self.dim2
            dim3 = self.dim3
        x_out = self.conv(x)
        # Real FFT over the last three axes with the default norm ('backward'),
        # unlike the spectral convolution layers, which use norm='forward'.
        ft = torch.fft.rfftn(x_out,dim=[-3,-2,-1])
        ft_u = torch.zeros_like(ft)
        # Keep only the four low-frequency corners (+/- along x and y, leading
        # entries along t), sized from the OUTPUT grid but indexing the INPUT
        # spectrum. NOTE(review): if any dim//2 is 0, a slice like -(0): selects
        # the whole axis rather than nothing — confirm intended.
        ft_u[:, :, :(dim1//2), :(dim2//2), :(dim3//2)] = ft[:, :, :(dim1//2), :(dim2//2), :(dim3//2)]
        ft_u[:, :, -(dim1//2):, :(dim2//2), :(dim3//2)] = ft[:, :, -(dim1//2):, :(dim2//2), :(dim3//2)]
        ft_u[:, :, :(dim1//2), -(dim2//2):, :(dim3//2)] = ft[:, :, :(dim1//2), -(dim2//2):, :(dim3//2)]
        ft_u[:, :, -(dim1//2):, -(dim2//2):, :(dim3//2)] = ft[:, :, -(dim1//2):, -(dim2//2):, :(dim3//2)]
        # Inverse real FFT onto the (dim1, dim2, dim3) grid; s zero-pads or trims
        # the spectrum to that size. NOTE(review): trimming keeps leading entries,
        # so the negative-frequency corners may be dropped when downsampling, and
        # the default 'backward' norm rescales amplitudes by the input/output size
        # ratio — confirm both are intended.
        x_out = torch.fft.irfftn(ft_u, s=(dim1, dim2, dim3))
        # NOTE(review): x_out already has spatial size (dim1, dim2, dim3) here, so
        # this align_corners trilinear resize appears to be an identity.
        x_out = torch.nn.functional.interpolate(x_out, size = (dim1, dim2,dim3),mode = 'trilinear',align_corners=True)
        return x_out
class OperatorBlock_3D(nn.Module):
    """
    One UNO layer on a 3D (x, y, t) domain: spectral integral operator
    (SpectralConv3d_Uno) plus pointwise operator (pointwise_op_3D), summed,
    then optionally instance-normalized and passed through GELU.

    Normalize = if true performs InstanceNorm3d on the output.
    Non_Lin = if true, applies point wise nonlinearity.
    All other variables are consistent with the SpectralConv3d_Uno class.
    """

    def __init__(self, in_codim, out_codim, dim1, dim2, dim3, modes1, modes2, modes3, Normalize=False, Non_Lin=True):
        super().__init__()
        grid = (dim1, dim2, dim3)
        # Same construction order as before (spectral, then pointwise) so the
        # RNG is consumed identically during parameter initialisation.
        self.conv = SpectralConv3d_Uno(in_codim, out_codim, *grid, modes1, modes2, modes3)
        self.w = pointwise_op_3D(in_codim, out_codim, *grid)
        self.normalize, self.non_lin = Normalize, Non_Lin
        if self.normalize:
            self.normalize_layer = torch.nn.InstanceNorm3d(int(out_codim), affine=True)

    def forward(self, x, dim1=None, dim2=None, dim3=None):
        """
        input shape = (batch, in_codim, input_dim1, input_dim2, input_dim3)
        output shape = (batch, out_codim, dim1, dim2, dim3)
        """
        size = (dim1, dim2, dim3)
        result = self.conv(x, *size) + self.w(x, *size)
        if self.normalize:
            result = self.normalize_layer(result)
        return F.gelu(result) if self.non_lin else result