-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_loaders.py
35 lines (31 loc) · 1.1 KB
/
data_loaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
from os.path import exists
from typing import Tuple

from torch.utils import data
from torchvision import datasets
from torchvision.transforms import transforms
# slightly modified version of the one used in another project
def get_data_loaders(
    batch_size: int,
    image_size: int,
    *,
    dataset_dir: str = "dataset",
    test_set_size: int = 4,
    num_workers: int = 8,
) -> Tuple[data.DataLoader, data.DataLoader]:
    """Build a training and a test data loader from an ``ImageFolder`` directory.

    Images are loaded with ``torchvision.datasets.ImageFolder`` from
    ``dataset_dir``. Each image is resized so its shorter side is
    ``image_size``, center-cropped to ``image_size x image_size`` and converted
    to a tensor. The dataset is then randomly split into a held-out test set of
    ``test_set_size`` images and a training set containing the rest.

    Args:
        batch_size: Batch size of the training loader.
        image_size: Side length, in pixels, of the square images produced.
        dataset_dir: Root directory handed to ``ImageFolder``.
        test_set_size: Number of images held out for the test loader. The test
            loader yields all of them as a single batch.
        num_workers: Number of worker processes used by the training loader.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles each
        epoch; the test loader keeps a fixed order.

    Raises:
        ValueError: If ``test_set_size`` is not positive, or the dataset does
            not contain more than ``test_set_size`` images (the training set
            would otherwise be empty).
        FileNotFoundError: If ``dataset_dir`` is not an existing directory.
    """
    if test_set_size < 1:
        raise ValueError(f"test_set_size must be positive, got {test_set_size}")
    # isdir rather than exists: a plain file with this name is not a usable dataset.
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError(
            "No dataset found. You need to put your directory with the images "
            f"inside the {dataset_dir} directory"
        )

    # Resize(int) only fixes the shorter side; the crop gives every image the
    # same square shape so the default collate function can stack a batch.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])
    dataset = datasets.ImageFolder(dataset_dir, transform)

    if len(dataset) <= test_set_size:
        raise ValueError(
            f"Dataset in {dataset_dir!r} has {len(dataset)} images; need more than "
            f"test_set_size={test_set_size} so the training set is not empty"
        )
    test_set, train_set = data.random_split(
        dataset,
        [test_set_size, len(dataset) - test_set_size],
    )

    data_loader_train = data.DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=True,
    )
    # The whole held-out split is delivered as one batch.
    data_loader_test = data.DataLoader(
        test_set,
        batch_size=test_set_size,
        shuffle=False,
        pin_memory=True,
    )
    return data_loader_train, data_loader_test