-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_list.py
72 lines (58 loc) · 2 KB
/
image_list.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
from PIL import Image
from torch.utils.data import Dataset
def load_image(img_path):
    """Load an image from disk and return it converted to RGB mode.

    Args:
        img_path: path of the image file to open.

    Returns:
        A new ``PIL.Image.Image`` in ``"RGB"`` mode whose pixel data is fully
        loaded, so it does not depend on the source file staying open.
    """
    # ``Image.open`` is lazy and may keep the file handle open (e.g. for
    # multi-frame formats). Opening inside a context manager closes it as soon
    # as ``convert`` has produced a new, loaded image, so loading many images
    # does not leak file descriptors.
    with Image.open(img_path) as img:
        return img.convert("RGB")
class ImageList(Dataset):
    """Map-style dataset of images listed in a label file or a pseudo-label list.

    Items are stored in ``self.item_list``. Items built from a label file are
    ``(image_path, label, image_file)`` tuples, where ``image_path`` is
    ``image_file`` joined onto ``image_root``.
    """

    def __init__(
        self,
        image_root: str,
        label_file: str,
        transform=None,
        pseudo_item_list=None,
    ):
        """Create the dataset.

        Args:
            image_root: directory prepended to every image path read from
                ``label_file``.
            label_file: path to a text file with one ``<image path> <int label>``
                entry per line. When falsy, ``pseudo_item_list`` is used instead.
            transform: optional callable applied to each loaded image in
                ``__getitem__`` (skipped when falsy).
            pseudo_item_list: items used as-is when ``label_file`` is falsy.
                Each item must unpack into three values
                ``(image_path, label, _)``; ``image_path`` is NOT joined with
                ``image_root``.

        Raises:
            AssertionError: if neither ``label_file`` nor ``pseudo_item_list``
                is provided (truthy).
        """
        self.image_root = image_root
        self._label_file = label_file
        self.transform = transform
        assert (
            label_file or pseudo_item_list
        ), "Must provide either label file or pseudo labels."
        self.item_list = (
            self.build_index(label_file) if label_file else pseudo_item_list
        )

    def build_index(self, label_file):
        """Build a list of ``(image path, class label, image file)`` items.

        Args:
            label_file: path to the label file (e.g. a domain-net label file).
                Each non-blank line holds a relative image path followed by an
                integer class label, separated by whitespace.

        Returns:
            item_list: list of ``(os.path.join(image_root, img_file), label,
            img_file)`` tuples in file order, with ``label`` as an ``int``.
        """
        item_list = []
        with open(label_file, "r") as fd:
            for line in fd:
                line = line.strip()
                # Skip blank lines (e.g. extra newlines at the end of the file);
                # they would otherwise fail to unpack below.
                if not line:
                    continue
                # Split on the last whitespace run so the label is always the
                # final token, even when the image path itself contains spaces.
                img_file, label = line.rsplit(maxsplit=1)
                img_path = os.path.join(self.image_root, img_file)
                item_list.append((img_path, int(label), img_file))
        return item_list

    def __getitem__(self, idx):
        """Retrieve data for one item.

        Args:
            idx: index of the dataset item.

        Returns:
            A 4-tuple ``(img_path, img, label, idx)``:
              img_path: path the image was loaded from.
              img: the RGB image returned by ``load_image``, passed through
                ``self.transform`` when one is set (presumably producing a
                <C, H, W> tensor; depends on the transform supplied).
              label: the stored label. An ``int`` when built from a label file;
                for a pseudo-label list it is whatever the caller supplied
                (presumably a <C, > tensor; confirm with the caller).
              idx: the requested index, echoed back.
        """
        img_path, label, _ = self.item_list[idx]
        img = load_image(img_path)
        if self.transform:
            img = self.transform(img)
        return img_path, img, label, idx

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.item_list)