Commit aa65652

Author: Gege Wen
Commit message: training code
Parent: 293d228

40 files changed: +4073 -37 lines changed

.gitignore (+5)
@@ -0,0 +1,5 @@
__pycache__/

ECLIPSE/meta_data/*.npy
*.pt
logs/*

.ipynb_checkpoints/eval_sequential_prediction_dp-checkpoint.ipynb (+399)
Large diffs are not rendered by default.

.ipynb_checkpoints/eval_sequential_prediction_sg-checkpoint.ipynb (+412)
Large diffs are not rendered by default.

Adam.py (+168)
@@ -0,0 +1,168 @@
import math
import torch
from torch import Tensor
from typing import List, Optional
from torch.optim.optimizer import Optimizer


def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[int],
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float):
    r"""Functional API that performs Adam algorithm computation.
    See :class:`~torch.optim.Adam` for details.
    """

    for i, param in enumerate(params):

        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
        if amsgrad:
            # Maintains the maximum of all 2nd moment running avg. till now
            torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
            # Use the max. for normalizing running avg. of gradient
            denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
        else:
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

        step_size = lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)


class Adam(Optimizer):
    r"""Implements Adam algorithm.
    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
    The implementation of the L2 penalty follows changes proposed in
    `Decoupled Weight Decay Regularization`_.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Adam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    if p.grad.is_sparse:
                        raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                    grads.append(p.grad)

                    state = self.state[p]
                    # Lazy state initialization
                    if len(state) == 0:
                        state['step'] = 0
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        # Exponential moving average of squared gradient values
                        state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        if group['amsgrad']:
                            # Maintains max of all exp. moving avg. of sq. grad. values
                            state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                    exp_avgs.append(state['exp_avg'])
                    exp_avg_sqs.append(state['exp_avg_sq'])

                    if group['amsgrad']:
                        max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                    # update the steps for each param group update
                    state['step'] += 1
                    # record the step after step update
                    state_steps.append(state['step'])

            adam(params_with_grad,
                 grads,
                 exp_avgs,
                 exp_avg_sqs,
                 max_exp_avg_sqs,
                 state_steps,
                 amsgrad=group['amsgrad'],
                 beta1=beta1,
                 beta2=beta2,
                 lr=group['lr'],
                 weight_decay=group['weight_decay'],
                 eps=group['eps'])
        return loss
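Adam.py mirrors PyTorch's built-in Adam optimizer (the functional adam update plus the Adam Optimizer subclass), so it can be swapped in wherever torch.optim.Adam would be used. A minimal usage sketch, not part of this commit; the model and hyperparameters are placeholders:

import torch
import torch.nn as nn
from Adam import Adam  # the local copy added in this commit

model = nn.Linear(10, 1)            # stand-in model; any nn.Module works the same way
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

for _ in range(5):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()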

CustomDataset.py (+47)
@@ -0,0 +1,47 @@
from torch.utils.data import Dataset
import random
import os
import torch

def GLOBAL_to_LGR_path(global_lists, key, names, var):
    # Collect the per-well LGR file names (dP or SG) at refinement level `key`
    # that actually exist in `names`, keyed off each global-grid case path.
    lgr_list = []
    for path in global_lists:
        case = path.split('/')[-1]
        slope = case[:7]
        idx = case.split('_')[2]
        for nwell in range(1, 5):
            if var == 'dP':
                string = f'{slope}_{idx}_{key}_WELL{nwell}_DP.pt'
                if string in names:
                    home_path = f'/dP_{key}/'
                    lgr_list.append(home_path + string)
            elif var == 'SG':
                string = f'{slope}_{idx}_{key}_WELL{nwell}_SG.pt'
                if string in names:
                    home_path = f'/SG_{key}/'
                    lgr_list.append(home_path + string)

    return lgr_list

class CustomDataset(Dataset):
    def __init__(self, root_path, names):
        self.names = names
        self.root_path = root_path

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        path = self.names[idx]
        data = torch.load(self.root_path + path)

        # Parse case identifiers from the file name
        name = path.split('/')[-1]
        slope, idx, well = name[:7], name.split('_')[2], name.split('_')[-2]

        # Reorder dimensions, drop the leading (batch) dimension stored in the
        # file, and keep only the first slice of the last output dimension as the target
        x = data['input'].permute(0, 4, 1, 2, 3, 5)[0, ...]
        y = data['output'].permute(0, 4, 1, 2, 3, 5)[0, ..., :1]

        D = {'x': x,
             'y': y,
             'path': [slope, idx, well]}
        return D
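A minimal usage sketch, not part of this commit. The root path and file name below are placeholders; per GLOBAL_to_LGR_path, real entries are relative paths of the form '/dP_<key>/...' or '/SG_<key>/...', and each .pt file is expected to hold 6-D 'input' and 'output' tensors:

import torch
from torch.utils.data import DataLoader
from CustomDataset import CustomDataset

names = ['/dP_LGR1/slope02_17_LGR1_WELL1_DP.pt']   # hypothetical entry
dataset = CustomDataset(root_path='/path/to/data', names=names)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

for batch in loader:
    x, y = batch['x'], batch['y']   # model input and single-channel target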

DATA_LOADER_DICT.pth (20.9 KB)
Binary file not shown.
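The binary artifact is not rendered in the diff; it is presumably a pickled object consumed by the training and evaluation scripts. An inspection sketch that assumes nothing beyond the file being loadable with torch.load:

import torch

data_loader_dict = torch.load('DATA_LOADER_DICT.pth')
print(type(data_loader_dict))
if isinstance(data_loader_dict, dict):
    print(list(data_loader_dict.keys()))   # contents are not documented in this commit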

FNO4D.py (+156)
@@ -0,0 +1,156 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import operator
from functools import reduce
from functools import partial

from timeit import default_timer

torch.manual_seed(0)
np.random.seed(0)

class SpectralConv4d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2, modes3, modes4):
        super(SpectralConv4d, self).__init__()

        """
        4D Fourier layer. It does FFT, linear transform, and Inverse FFT.
        """

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.modes3 = modes3
        self.modes4 = modes4

        # One weight tensor per low/high-frequency corner of the first three mode
        # dimensions; the last dimension keeps only low modes because of rfft.
        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat))
        self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat))
        self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat))
        self.weights5 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat))
        self.weights6 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat))
        self.weights7 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat))
        self.weights8 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, self.modes4, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul4d(self, input, weights):
        # (batch, in_channel, x, y, z, t), (in_channel, out_channel, x, y, z, t) -> (batch, out_channel, x, y, z, t)
        return torch.einsum("bixyzt,ioxyzt->boxyzt", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfftn(x, dim=[-4, -3, -2, -1])

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-4), x.size(-3), x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device)

        out_ft[:, :, :self.modes1, :self.modes2, :self.modes3, :self.modes4] = self.compl_mul4d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3, :self.modes4], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3, :self.modes4] = self.compl_mul4d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3, :self.modes4], self.weights2)
        out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3, :self.modes4] = self.compl_mul4d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3, :self.modes4], self.weights3)
        out_ft[:, :, :self.modes1, :self.modes2, -self.modes3:, :self.modes4] = self.compl_mul4d(x_ft[:, :, :self.modes1, :self.modes2, -self.modes3:, :self.modes4], self.weights4)
        out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3, :self.modes4] = self.compl_mul4d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3, :self.modes4], self.weights5)
        out_ft[:, :, -self.modes1:, :self.modes2, -self.modes3:, :self.modes4] = self.compl_mul4d(x_ft[:, :, -self.modes1:, :self.modes2, -self.modes3:, :self.modes4], self.weights6)
        out_ft[:, :, :self.modes1, -self.modes2:, -self.modes3:, :self.modes4] = self.compl_mul4d(x_ft[:, :, :self.modes1, -self.modes2:, -self.modes3:, :self.modes4], self.weights7)
        out_ft[:, :, -self.modes1:, -self.modes2:, -self.modes3:, :self.modes4] = self.compl_mul4d(x_ft[:, :, -self.modes1:, -self.modes2:, -self.modes3:, :self.modes4], self.weights8)

        # Return to physical space
        x = torch.fft.irfftn(out_ft, s=(x.size(-4), x.size(-3), x.size(-2), x.size(-1)))
        return x

class Block4d(nn.Module):
    def __init__(self, width, width2, modes1, modes2, modes3, modes4, out_dim):
        super(Block4d, self).__init__()
        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.modes4 = modes4

        self.width = width
        self.width2 = width2
        self.out_dim = out_dim
        self.padding = 8

        # Spectral convolutions, pointwise (1x1) channel mixers, and the projection MLP
        self.conv0 = SpectralConv4d(self.width, self.width, self.modes1, self.modes2, self.modes3, self.modes4)
        self.conv1 = SpectralConv4d(self.width, self.width, self.modes1, self.modes2, self.modes3, self.modes4)
        self.conv2 = SpectralConv4d(self.width, self.width, self.modes1, self.modes2, self.modes3, self.modes4)
        self.conv3 = SpectralConv4d(self.width, self.width, self.modes1, self.modes2, self.modes3, self.modes4)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)
        self.fc1 = nn.Linear(self.width, self.width2)
        self.fc2 = nn.Linear(self.width2, self.out_dim)

    def forward(self, x):
        batchsize = x.shape[0]
        size_x, size_y, size_z, size_t = x.shape[2], x.shape[3], x.shape[4], x.shape[5]

        # Four Fourier layers: spectral convolution plus a pointwise linear path
        x1 = self.conv0(x)
        x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z, size_t)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv1(x)
        x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z, size_t)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv2(x)
        x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z, size_t)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv3(x)
        x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y, size_z, size_t)
        x = x1 + x2

        # Remove the padding added in FNO4d.forward
        x = x[:, :, self.padding:-self.padding, self.padding*2:-self.padding*2,
              self.padding*2:-self.padding*2, self.padding:-self.padding]

        x = x.permute(0, 2, 3, 4, 5, 1)  # move channels back to the last dimension
        x1 = self.fc1(x)
        x = F.gelu(x1)
        x = self.fc2(x)

        return x

class FNO4d(nn.Module):
    def __init__(self, modes1, modes2, modes3, modes4, width, in_dim):
        super(FNO4d, self).__init__()

        self.modes1 = modes1
        self.modes2 = modes2
        self.modes3 = modes3
        self.modes4 = modes4
        self.width = width
        self.width2 = width*4
        self.in_dim = in_dim
        self.out_dim = 1
        self.padding = 8  # pad the domain if input is non-periodic

        self.fc0 = nn.Linear(self.in_dim, self.width)
        self.conv = Block4d(self.width, self.width2,
                            self.modes1, self.modes2, self.modes3, self.modes4, self.out_dim)

    def forward(self, x, gradient=False):
        x = self.fc0(x)
        x = x.permute(0, 5, 1, 2, 3, 4)
        x = F.pad(x, [self.padding, self.padding, self.padding*2, self.padding*2, self.padding*2,
                      self.padding*2, self.padding, self.padding])

        x = self.conv(x)

        return x
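A minimal instantiation sketch, not part of this commit; the mode counts, width, channel count, and grid shape are illustrative only. The expected input layout, implied by fc0 and the permute in forward, is (batch, x, y, z, t, in_dim); the padding added in FNO4d.forward is cropped off again inside Block4d, so the output keeps the input grid shape with a single output channel:

import torch
from FNO4D import FNO4d

model = FNO4d(modes1=4, modes2=4, modes3=4, modes4=4, width=8, in_dim=6)
x = torch.randn(1, 16, 16, 8, 8, 6)     # (batch, x, y, z, t, channels), placeholder sizes
y = model(x)
print(y.shape)                          # torch.Size([1, 16, 16, 8, 8, 1])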
