Skip to content

Commit

Permalink
update surrogate case
Browse files Browse the repository at this point in the history
  • Loading branch information
SoloWayG committed Aug 2, 2024
1 parent 38af6c9 commit 52a6160
Show file tree
Hide file tree
Showing 23 changed files with 71 additions and 32,273 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ share/python-wheels/
*.egg
MANIFEST
.idea/

data_from_comsol/data/*.txt
result_tests
data_from_comsol/*.zip
#Train data for surrogate models
data_from_comsol/generated_data
data_from_comsol/test_gen_data
Expand Down
Binary file added data_from_comsol/comsol_files/Untitled.mph
Binary file not shown.
File renamed without changes.
File renamed without changes.
File renamed without changes.
7 changes: 0 additions & 7 deletions data_from_comsol/coords.txt

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
50 changes: 50 additions & 0 deletions gefest/surrogate_models/darcy/darcy_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import matplotlib.pyplot as plt
from neuralop.datasets import load_darcy_flow_small

train_loader, test_loaders, data_processor = load_darcy_flow_small(
n_train=100, batch_size=4,
test_resolutions=[32, 32], n_tests=[50, 50], test_batch_sizes=[4, 2],
)

train_dataset = train_loader.dataset
for res, test_loader in test_loaders.items():
print(res)
# Get first batch
batch = next(iter(test_loader))
x = batch['x']
y = batch['y']

print(f'Testing samples for res {res} have shape {x.shape[1:]}')


data = train_dataset[0]
x = data['x']
y = data['y']

print(f'Training sample have shape {x.shape[1:]}')


# Which sample to view
index = 0

data = train_dataset[index]
data = data_processor.preprocess(data, batched=False)
x = data['x']
y = data['y']
fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(2, 2, 1)
ax.imshow(x[0], cmap='gray')
ax.set_title('input x')
ax = fig.add_subplot(2, 2, 2)
ax.imshow(y.squeeze())
ax.set_title('input y')
ax = fig.add_subplot(2, 2, 3)
ax.imshow(x[1])
ax.set_title('x: 1st pos embedding')
ax = fig.add_subplot(2, 2, 4)
ax.imshow(x[2])
ax.set_title('x: 2nd pos embedding')
fig.suptitle('Visualizing one input sample', y=0.98)
plt.tight_layout()
fig.show()
print()
10 changes: 10 additions & 0 deletions gefest/surrogate_models/darcy/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
wandb
ruamel.yaml
configmypy
tensorly
tensorly-torch
torch-harmonics
matplotlib
opt-einsum
h5py
zarr
6 changes: 4 additions & 2 deletions gefest/surrogate_models/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
device = torch.device("cpu" if not torch.cuda.is_available() else 'cuda')
dataloader = create_single_dataloader(path_to_dir='data_from_comsol//test_gen_data',batch_size=10,shuffle=False)
model = AttU_Net(img_ch=1,output_ch=1).to(device)#UNet(in_channels=1,out_channels=1).to(device)
model.load_state_dict(torch.load(r'weights\unet_11_adam_Accum_2.pt',map_location=torch.device(device)))
CASE = 'att_11'
model.load_state_dict(torch.load(r'D:\Projects\GEFEST\GEFEST_surr\GEFEST\weights\best\unet_69_Adam_from01_lr20to3e104_bs4.pt',map_location=torch.device(device)))
CASE = 'unet_69_Adam_from01_lr20to3e104_bs4'
model.eval()
#predicts = []
truth = []
Expand All @@ -37,5 +37,7 @@
masks = np.concatenate((masks,x.cpu().numpy()))
#predicts+=y_pred.cpu().numpy()
data_to_save = {'flow_predict':predicts,"flow":truth,'mask':masks}
if not os.path.isdir(f'gefest/surrogate_models/gendata/{CASE}'):
os.mkdir(f'gefest/surrogate_models/gendata/{CASE}')
np.savez(f'gefest/surrogate_models/gendata/{CASE}/data',data_to_save,allow_pickle=False)
print()
6 changes: 4 additions & 2 deletions gefest/surrogate_models/utils/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def animation_npz(path_to_dir):
axs[2].clear()


time.sleep(2.2)
#time.sleep(2.2)
a1.remove(),a2.remove(),a3.remove()
print(summ_mae)
print(cnt)
Expand All @@ -91,4 +91,6 @@ def animation_npz(path_to_dir):
#animation_npz(path_to_dir='gefest\surrogate_models\gendata/ssim_23')
#animation_npz(path_to_dir='gefest\surrogate_models\gendata/ssim_plus_57')
#animation_data_npz(path_to_dir='data_from_comsol/gen_data_extend')
animation_npz(path_to_dir='gefest\surrogate_models\gendata/att_11')
#animation_npz(path_to_dir='gefest\surrogate_models\gendata/unet_68_Adam_att_unet_bi_wu_f01_t01_10ep_bs32_30_to_0001_ssim_2')
if __name__=='__main__':
animation_npz(path_to_dir=r'D:\Projects\GEFEST\GEFEST_surr\GEFEST\gefest\surrogate_models\gendata\unet_69_Adam_from01_lr20to3e104_bs4')
Loading

0 comments on commit 52a6160

Please sign in to comment.