-
Notifications
You must be signed in to change notification settings - Fork 0
/
Data.py
91 lines (79 loc) · 3.12 KB
/
Data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose
from skimage import io
from skimage import img_as_float
import os
import numpy as np
from skimage.transform import downscale_local_mean
import ipdb
from DataUtils import *
from TrainUtils import *
from UNet import *
##########################################################
def readImages(path):
    """Read every file in ``path`` as an image and return them as a list.

    Files are visited in sorted filename order. ``os.listdir`` returns entries
    in arbitrary order, and this function is called separately for the image
    directory and the label directory. Sorting is what keeps image ``i`` paired
    with label ``i``; this assumes matching files have matching, sortable names
    (NOTE(review): confirm the naming scheme of the data directories).

    Args:
        path: directory containing only image files. A trailing path separator
            is optional.

    Returns:
        list of the arrays returned by ``skimage.io.imread``, one per file.
    """
    # os.path.join works whether or not ``path`` ends with a separator,
    # unlike plain string concatenation.
    return [io.imread(os.path.join(path, name)) for name in sorted(os.listdir(path))]
def stackImages(image_list):
    """Stack a list of equally-shaped images into one float64 array.

    Args:
        image_list: non-empty sequence of arrays that all share the same shape.
            Non-square images are supported.

    Returns:
        A new ``np.ndarray`` of shape ``(len(image_list), *image_list[0].shape)``
        with dtype float64. The inputs are not modified.

    Raises:
        ValueError: if ``image_list`` is empty or the images differ in shape.
    """
    # len() rather than truthiness so a NumPy array argument is also accepted.
    if len(image_list) == 0:
        raise ValueError("stackImages() needs at least one image")
    # np.stack keeps each image's full shape. The old code assumed square
    # images (H == W) and failed on anything else. Cast to float64 to match
    # the dtype of the np.zeros buffer the previous implementation filled.
    return np.stack(image_list).astype(np.float64, copy=False)
def fixLabeling(labels):
    """Shift label values down by one, clamping at zero.

    Values 0 and 1 both become 0, and every value ``k >= 1`` becomes ``k - 1``.
    The input array is left unchanged. The previous implementation overwrote
    the caller's zeros in place before computing the result.

    Args:
        labels: array-like of non-negative label values.

    Returns:
        A new array of the same shape (and, for arrays, the same dtype).
    """
    # Clamp to >= 1 first so unsigned dtypes cannot wrap around on "0 - 1".
    return np.maximum(labels, 1) - 1
class ParhyaleDataset(Dataset):
    """In-memory dataset of (image, label) pairs loaded from two directories.

    Construction reads every file in ``image_path`` and ``label_path``, stacks
    each set into one array, and passes the labels through ``fixLabeling``.
    When ``factor`` is truthy, both stacks are then passed to
    ``downsize(..., factor)``. The optional ``transform`` is applied to the
    image only (never to the label) inside ``__getitem__``.
    """

    def __init__(self, image_path, label_path, factor=5, transform=None):
        self.transform = transform
        self.images = stackImages(readImages(image_path))
        self.labels = fixLabeling(stackImages(readImages(label_path)))
        if factor:
            self.images, self.labels = (
                downsize(self.images, factor),
                downsize(self.labels, factor),
            )
        # Sanity diagnostics on the first sample, computed before any transform.
        # The second value is the mean label value. It reads as a fraction only
        # if labels are 0/1 (NOTE(review): confirm).
        print("Mean pixel value-before transforms: ", np.mean(self.images[0]))
        print("Percentage of cells in first image: ", np.mean(self.labels[0]))

    def fit(self, scalers):
        """Call ``fit`` on each scaler with the full, un-transformed image stack."""
        for fitter in scalers:
            fitter.fit(self.images)

    def __len__(self):
        """Return ``len(self.images)``."""
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(imageToTorch(image), labelToTorch(label))`` for ``index``.

        If a transform was supplied, it is applied to the image before the
        conversion.
        """
        image, label = self.images[index], self.labels[index]
        if self.transform:
            image = self.transform(image)
        return imageToTorch(image), labelToTorch(label)
if __name__ == '__main__':
    # ----------------------------- build datasets -----------------------------
    train_images_path = '/data/bbli/gryllus_disk_images/train/images/'
    train_labels_path = '/data/bbli/gryllus_disk_images/train/labels/'
    test_images_path = '/data/bbli/gryllus_disk_images/val/images/'
    test_labels_path = '/data/bbli/gryllus_disk_images/val/labels/'

    pad_size = 160
    center = Standarize()
    pad = Padder(pad_size)
    # Standardise, then pad. The dataset applies this pipeline to images only,
    # inside __getitem__.
    transforms = Compose([center, pad])

    train_dataset = ParhyaleDataset(train_images_path, train_labels_path, transform=transforms)
    # `center` is fitted on the training images only. Because `transforms` is
    # shared, the validation set is then standardised with training statistics.
    train_dataset.fit([center])
    checkTrainSetMean(train_dataset)
    test_dataset = ParhyaleDataset(test_images_path, test_labels_path, transform=transforms)

    # ----------------------------- data loaders -------------------------------
    train_loader, test_loader = (
        DataLoader(dataset, shuffle=True) for dataset in (train_dataset, test_dataset)
    )

    # Pull one batch as a smoke test of the full load/transform pipeline.
    img, label = next(iter(train_loader))
    print("Pad size: ", pad_size)