1
+ import torch
2
+ import numpy as np
3
+ import torch .nn as nn
4
+ import torch .nn .functional as F
5
+
6
+ import operator
7
+ from functools import reduce
8
+ from functools import partial
9
+
10
+ from timeit import default_timer
11
+
12
+ torch .manual_seed (0 )
13
+ np .random .seed (0 )
14
+
15
+ class SpectralConv4d (nn .Module ):
16
+ def __init__ (self , in_channels , out_channels , modes1 , modes2 , modes3 , modes4 ):
17
+ super (SpectralConv4d , self ).__init__ ()
18
+
19
+ """
20
+ 4D Fourier layer. It does FFT, linear transform, and Inverse FFT.
21
+ """
22
+
23
+ self .in_channels = in_channels
24
+ self .out_channels = out_channels
25
+ self .modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
26
+ self .modes2 = modes2
27
+ self .modes3 = modes3
28
+ self .modes4 = modes4
29
+
30
+ self .scale = (1 / (in_channels * out_channels ))
31
+ self .weights1 = nn .Parameter (self .scale * torch .rand (in_channels , out_channels , self .modes1 , self .modes2 , self .modes3 , self .modes4 , dtype = torch .cfloat ))
32
+ self .weights2 = nn .Parameter (self .scale * torch .rand (in_channels , out_channels , self .modes1 , self .modes2 , self .modes3 , self .modes4 , dtype = torch .cfloat ))
33
+ self .weights3 = nn .Parameter (self .scale * torch .rand (in_channels , out_channels , self .modes1 , self .modes2 , self .modes3 , self .modes4 , dtype = torch .cfloat ))
34
+ self .weights4 = nn .Parameter (self .scale * torch .rand (in_channels , out_channels , self .modes1 , self .modes2 , self .modes3 , self .modes4 , dtype = torch .cfloat ))
35
+ self .weights5 = nn .Parameter (self .scale * torch .rand (in_channels , out_channels , self .modes1 , self .modes2 , self .modes3 , self .modes4 , dtype = torch .cfloat ))
36
+ self .weights6 = nn .Parameter (self .scale * torch .rand (in_channels , out_channels , self .modes1 , self .modes2 , self .modes3 , self .modes4 , dtype = torch .cfloat ))
37
+ self .weights7 = nn .Parameter (self .scale * torch .rand (in_channels , out_channels , self .modes1 , self .modes2 , self .modes3 , self .modes4 , dtype = torch .cfloat ))
38
+ self .weights8 = nn .Parameter (self .scale * torch .rand (in_channels , out_channels , self .modes1 , self .modes2 , self .modes3 , self .modes4 , dtype = torch .cfloat ))
39
+
40
+ # Complex multiplication
41
+ def compl_mul4d (self , input , weights ):
42
+ # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t)
43
+ return torch .einsum ("bixyzt,ioxyzt->boxyzt" , input , weights )
44
+
45
+ def forward (self , x ):
46
+ batchsize = x .shape [0 ]
47
+ #Compute Fourier coeffcients up to factor of e^(- something constant)
48
+ x_ft = torch .fft .rfftn (x , dim = [- 4 ,- 3 ,- 2 ,- 1 ])
49
+
50
+ # Multiply relevant Fourier modes
51
+ out_ft = torch .zeros (batchsize , self .out_channels , x .size (- 4 ), x .size (- 3 ), x .size (- 2 ), x .size (- 1 )// 2 + 1 , dtype = torch .cfloat , device = x .device )
52
+
53
+ out_ft [:, :, :self .modes1 , :self .modes2 , :self .modes3 , :self .modes4 ] = self .compl_mul4d (x_ft [:, :, :self .modes1 , :self .modes2 , :self .modes3 , :self .modes4 ], self .weights1 )
54
+ out_ft [:, :, - self .modes1 :, :self .modes2 , :self .modes3 , :self .modes4 ] = self .compl_mul4d (x_ft [:, :, - self .modes1 :, :self .modes2 , :self .modes3 , :self .modes4 ], self .weights2 )
55
+ out_ft [:, :, :self .modes1 , - self .modes2 :, :self .modes3 , :self .modes4 ] = self .compl_mul4d (x_ft [:, :, :self .modes1 , - self .modes2 :, :self .modes3 , :self .modes4 ], self .weights3 )
56
+ out_ft [:, :, :self .modes1 , :self .modes2 , - self .modes3 :, :self .modes4 ] = self .compl_mul4d (x_ft [:, :, :self .modes1 , :self .modes2 , - self .modes3 :, :self .modes4 ], self .weights4 )
57
+ out_ft [:, :, - self .modes1 :, - self .modes2 :, :self .modes3 , :self .modes4 ] = self .compl_mul4d (x_ft [:, :, - self .modes1 :, - self .modes2 :, :self .modes3 , :self .modes4 ], self .weights5 )
58
+ out_ft [:, :, - self .modes1 :, :self .modes2 , - self .modes3 :, :self .modes4 ] = self .compl_mul4d (x_ft [:, :, - self .modes1 :, :self .modes2 , - self .modes3 :, :self .modes4 ], self .weights6 )
59
+ out_ft [:, :, :self .modes1 , - self .modes2 :, - self .modes3 :, :self .modes4 ] = self .compl_mul4d (x_ft [:, :, :self .modes1 , - self .modes2 :, - self .modes3 :, :self .modes4 ], self .weights7 )
60
+ out_ft [:, :, - self .modes1 :, - self .modes2 :, - self .modes3 :, :self .modes4 ] = self .compl_mul4d (x_ft [:, :, - self .modes1 :, - self .modes2 :, - self .modes3 :, :self .modes4 ], self .weights8 )
61
+
62
+ #Return to physical space
63
+ x = torch .fft .irfftn (out_ft , s = (x .size (- 4 ), x .size (- 3 ), x .size (- 2 ), x .size (- 1 )))
64
+ return x
65
+
66
+ class Block4d (nn .Module ):
67
+ def __init__ (self , width , width2 , modes1 , modes2 , modes3 , modes4 , out_dim ):
68
+ super (Block4d , self ).__init__ ()
69
+ self .modes1 = modes1
70
+ self .modes2 = modes2
71
+ self .modes3 = modes3
72
+ self .modes4 = modes4
73
+
74
+ self .width = width
75
+ self .width2 = width2
76
+ self .out_dim = out_dim
77
+ self .padding = 8
78
+
79
+ # channel
80
+ self .conv0 = SpectralConv4d (self .width , self .width , self .modes1 , self .modes2 , self .modes3 , self .modes4 )
81
+ self .conv1 = SpectralConv4d (self .width , self .width , self .modes1 , self .modes2 , self .modes3 , self .modes4 )
82
+ self .conv2 = SpectralConv4d (self .width , self .width , self .modes1 , self .modes2 , self .modes3 , self .modes4 )
83
+ self .conv3 = SpectralConv4d (self .width , self .width , self .modes1 , self .modes2 , self .modes3 , self .modes4 )
84
+ self .w0 = nn .Conv1d (self .width , self .width , 1 )
85
+ self .w1 = nn .Conv1d (self .width , self .width , 1 )
86
+ self .w2 = nn .Conv1d (self .width , self .width , 1 )
87
+ self .w3 = nn .Conv1d (self .width , self .width , 1 )
88
+ self .fc1 = nn .Linear (self .width , self .width2 )
89
+ self .fc2 = nn .Linear (self .width2 , self .out_dim )
90
+
91
+ def forward (self , x ):
92
+ batchsize = x .shape [0 ]
93
+ size_x , size_y , size_z , size_t = x .shape [2 ], x .shape [3 ], x .shape [4 ], x .shape [5 ]
94
+ # print(size_x, size_y, size_z, size_t)
95
+ # channel
96
+ # print(x.shape)
97
+ x1 = self .conv0 (x )
98
+ # print(x1.shape)
99
+ x2 = self .w0 (x .view (batchsize , self .width , - 1 )).view (batchsize , self .width , size_x , size_y , size_z , size_t )
100
+ x = x1 + x2
101
+ x = F .gelu (x )
102
+
103
+ x1 = self .conv1 (x )
104
+ x2 = self .w1 (x .view (batchsize , self .width , - 1 )).view (batchsize , self .width , size_x , size_y , size_z , size_t )
105
+ x = x1 + x2
106
+ x = F .gelu (x )
107
+
108
+ x1 = self .conv2 (x )
109
+ x2 = self .w2 (x .view (batchsize , self .width , - 1 )).view (batchsize , self .width , size_x , size_y , size_z , size_t )
110
+ x = x1 + x2
111
+ x = F .gelu (x )
112
+
113
+ x1 = self .conv3 (x )
114
+ x2 = self .w3 (x .view (batchsize , self .width , - 1 )).view (batchsize , self .width , size_x , size_y , size_z , size_t )
115
+ x = x1 + x2
116
+
117
+ x = x [:, :, self .padding :- self .padding , self .padding * 2 :- self .padding * 2 ,
118
+ self .padding * 2 :- self .padding * 2 , self .padding :- self .padding ]
119
+
120
+ x = x .permute (0 , 2 , 3 , 4 , 5 , 1 ) # pad the domain if input is non-periodic
121
+ x1 = self .fc1 (x )
122
+ x = F .gelu (x1 )
123
+ x = self .fc2 (x )
124
+
125
+ return x
126
+
127
+ class FNO4d (nn .Module ):
128
+ def __init__ (self , modes1 , modes2 , modes3 , modes4 , width , in_dim ):
129
+ super (FNO4d , self ).__init__ ()
130
+
131
+ self .modes1 = modes1
132
+ self .modes2 = modes2
133
+ self .modes3 = modes3
134
+ self .modes4 = modes4
135
+ self .width = width
136
+ self .width2 = width * 4
137
+ self .in_dim = in_dim
138
+ self .out_dim = 1
139
+ self .padding = 8 # pad the domain if input is non-periodic
140
+
141
+ self .fc0 = nn .Linear (self .in_dim , self .width )
142
+ self .conv = Block4d (self .width , self .width2 ,
143
+ self .modes1 , self .modes2 , self .modes3 , self .modes4 , self .out_dim )
144
+
145
+ def forward (self , x , gradient = False ):
146
+ x = self .fc0 (x )
147
+ x = x .permute (0 , 5 , 1 , 2 , 3 , 4 )
148
+ x = F .pad (x , [self .padding , self .padding , self .padding * 2 , self .padding * 2 , self .padding * 2 ,
149
+ self .padding * 2 , self .padding , self .padding ])
150
+
151
+ x = self .conv (x )
152
+
153
+ return x
154
+
155
+
156
+
0 commit comments