-
Notifications
You must be signed in to change notification settings - Fork 12
/
dataloader.py
44 lines (34 loc) · 1.32 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import glob
import os
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision.transforms import (Compose, Normalize, Resize, ToPILImage,
ToTensor)
class SegmentationDataset(Dataset):
    """Map-style dataset pairing ``images/*.jpg`` files with mask files.

    Expected layout under ``folder_path``::

        images/<name>.jpg
        masks/<name>.jpg.png

    i.e. each mask filename is the *full* image filename with ``.png``
    appended (``a.jpg`` -> ``a.jpg.png``).

    Args:
        folder_path: Root directory containing ``images`` and ``masks``.
        transforms: Optional callable taking ``(img, target)`` (both PIL
            images) and returning a transformed ``(img, target)`` pair.
            When ``None``, the PIL images are returned unchanged.
    """

    def __init__(self, folder_path, transforms=None):
        super().__init__()
        # glob's result order is filesystem-dependent; sort it so index i
        # maps to the same file on every machine/run (needed for
        # reproducible index-based train/val splits).
        self.images = sorted(
            glob.glob(os.path.join(folder_path, 'images', '*.jpg')))
        # NOTE(review): masks are named '<image filename>.png' (e.g.
        # 'a.jpg.png'), not 'a.png' -- confirm this matches the dataset.
        self.masks = [
            os.path.join(folder_path, 'masks',
                         f'{os.path.basename(image)}.png')
            for image in self.images
        ]
        self.transforms = transforms

    def __getitem__(self, index):
        """Return the ``(img, target)`` pair at ``index``.

        ``img`` is converted to RGB; ``target`` keeps the mask file's own
        mode. After ``transforms``, any output with a 2-D ``shape`` ``s`` is
        reshaped to ``(1, *s)``; outputs without a ``shape`` (e.g. PIL images
        when ``transforms`` is ``None``) are returned as-is.
        """
        img_path = self.images[index]
        target_path = self.masks[index]
        # Context managers close the file handles deterministically (useful
        # with many DataLoader workers); convert()/load() read the pixel data
        # before the file is closed, so the images stay usable afterwards.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        with Image.open(target_path) as target:
            target.load()
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        target = self._add_channel_dim(target)
        img = self._add_channel_dim(img)
        return img, target

    @staticmethod
    def _add_channel_dim(x):
        """Prepend a size-1 axis to a 2-D array-like; return anything else as-is.

        Objects lacking a ``shape`` attribute are passed through instead of
        raising ``AttributeError``.
        """
        shape = getattr(x, 'shape', None)
        if shape is not None and len(shape) == 2:
            return x.reshape((1,) + tuple(shape))
        return x

    def __len__(self):
        """Return the number of images found under ``images/``."""
        return len(self.images)