-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathCustomDataset.py
47 lines (39 loc) · 1.43 KB
/
CustomDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from torch.utils.data import Dataset
import random
import os
import torch
def GLOBAL_to_LGR_path(global_lists, key, names, var, n_wells=4):
    """Map global-grid case paths to the matching LGR file paths present in ``names``.

    For every path in ``global_lists`` the case file name (text after the last
    '/') is parsed: ``slope`` is its first 7 characters and ``idx`` is its
    third '_'-separated field. For each well number 1..``n_wells`` the
    candidate file name ``{slope}_{idx}_{key}_WELL{n}_{DP|SG}.pt`` is built;
    it is kept only if it appears in ``names``, and is returned prefixed with
    ``/dP_{key}/`` or ``/SG_{key}/``.

    Args:
        global_lists: Iterable of path strings for global-grid cases.
        key: LGR identifier inserted into both the file name and the directory.
        names: Collection of available file names (bare names, no directory)
            used as the existence check.
        var: Which variable to collect; must be ``'dP'`` or ``'SG'``.
        n_wells: Number of wells to probe per case (default 4, the previously
            hard-coded count).

    Returns:
        List of relative path strings, ordered by ``global_lists`` then by
        well number.

    Raises:
        ValueError: If ``var`` is not ``'dP'`` or ``'SG'`` (this used to
            silently return an empty list).
    """
    # var -> (directory prefix, file-name suffix); the two differ for 'dP'.
    layouts = {
        'dP': ('dP', 'DP'),
        'SG': ('SG', 'SG'),
    }
    if var not in layouts:
        raise ValueError(f"var must be 'dP' or 'SG', got {var!r}")
    dir_prefix, file_suffix = layouts[var]
    home_path = f'/{dir_prefix}_{key}/'

    # Build the lookup set once: membership tests on a list are O(n) each.
    available = set(names)

    lgr_list = []
    for path in global_lists:
        case = path.split('/')[-1]
        slope = case[:7]
        idx = case.split('_')[2]
        for nwell in range(1, n_wells + 1):
            string = f'{slope}_{idx}_{key}_WELL{nwell}_{file_suffix}.pt'
            if string in available:
                lgr_list.append(home_path + string)
    return lgr_list
class CustomDataset(Dataset):
    """Map-style dataset that loads one ``.pt`` sample file per entry of ``names``.

    Args:
        root_path: String prefix joined to each entry of ``names`` by plain
            concatenation (no separator is inserted).
        names: Sequence of relative paths; presumably the list built by
            ``GLOBAL_to_LGR_path`` (e.g. ``'/dP_<key>/<file>.pt'``) — confirm
            against callers.
    """

    def __init__(self, root_path, names):
        self.root_path = root_path
        self.names = names

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return its tensors plus identifiers parsed from its file name.

        Returns:
            dict with
              ``'x'``: ``data['input']`` with axes reordered by
                  ``(0, 4, 1, 2, 3, 5)`` and the leading axis indexed at 0;
              ``'y'``: ``data['output']`` reordered the same way, keeping only
                  the first slot of the last axis (kept as size 1);
              ``'path'``: ``[slope, case_idx, well]`` strings.
        """
        rel_path = self.names[idx]
        # Plain concatenation on purpose: rel_path starts with '/', so
        # os.path.join would discard root_path.
        sample = torch.load(self.root_path + rel_path)

        file_name = rel_path.split('/')[-1]
        fields = file_name.split('_')
        # Same parsing convention as GLOBAL_to_LGR_path: slope is the first
        # 7 characters and the case index is the third '_'-separated field.
        slope = file_name[:7]
        case_idx = fields[2]
        well = fields[-2]

        # permute with 6 indices requires a 6-D tensor; [0] drops the leading axis.
        reorder = (0, 4, 1, 2, 3, 5)
        features = sample['input'].permute(*reorder)[0, ...]
        target = sample['output'].permute(*reorder)[0, ..., :1]

        return {
            'x': features,
            'y': target,
            'path': [slope, case_idx, well],
        }