Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement labels argument for create_dataset (and downstream functions) #2418

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions timm/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
self,
root,
reader=None,
labels=None,
split='train',
class_map=None,
load_bytes=False,
Expand All @@ -36,6 +37,7 @@ def __init__(
reader = create_reader(
reader or '',
root=root,
labels=labels,
split=split,
class_map=class_map,
**kwargs,
Expand Down Expand Up @@ -89,6 +91,7 @@ def __init__(
self,
root,
reader=None,
labels=None,
split='train',
class_map=None,
is_training=False,
Expand All @@ -110,6 +113,7 @@ def __init__(
self.reader = create_reader(
reader,
root=root,
labels=labels,
split=split,
class_map=class_map,
is_training=is_training,
Expand Down
10 changes: 9 additions & 1 deletion timm/data/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Hacked together by / Copyright 2021, Ross Wightman
"""
import os
from typing import Optional
from typing import Optional, Union, Dict

from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
try:
Expand Down Expand Up @@ -63,6 +63,7 @@ def _try(syn):
def create_dataset(
name: str,
root: Optional[str] = None,
labels: Optional[Union[Dict, str]] = None,
split: str = 'validation',
search_split: bool = True,
class_map: dict = None,
Expand Down Expand Up @@ -91,6 +92,7 @@ def create_dataset(
Args:
name: Dataset name, empty is okay for folder based datasets
root: Root folder of dataset (All)
labels: Specify filename -> label mapping via file or dict (Folder)
split: Dataset split (All)
search_split: Search for split specific child fold from root so one can specify
`imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
Expand All @@ -112,6 +114,7 @@ def create_dataset(
kwargs = {k: v for k, v in kwargs.items() if v is not None}
name = name.lower()
if name.startswith('torch/'):
assert labels is None, "Argument 'labels' incompatible with name 'torch/...'"
name = name.split('/', 2)[-1]
torch_kwargs = dict(root=root, download=download, **kwargs)
if name in _TORCH_BASIC_DS:
Expand Down Expand Up @@ -162,6 +165,7 @@ def create_dataset(
ds = ImageDataset(
root,
reader=name,
labels=labels,
split=split,
class_map=class_map,
input_img_mode=input_img_mode,
Expand All @@ -172,6 +176,7 @@ def create_dataset(
ds = IterableImageDataset(
root,
reader=name,
labels=labels,
split=split,
class_map=class_map,
is_training=is_training,
Expand All @@ -188,6 +193,7 @@ def create_dataset(
ds = IterableImageDataset(
root,
reader=name,
labels=labels,
split=split,
class_map=class_map,
is_training=is_training,
Expand All @@ -203,6 +209,7 @@ def create_dataset(
ds = IterableImageDataset(
root,
reader=name,
labels=labels,
split=split,
class_map=class_map,
is_training=is_training,
Expand All @@ -221,6 +228,7 @@ def create_dataset(
ds = ImageDataset(
root,
reader=name,
labels=labels,
class_map=class_map,
load_bytes=load_bytes,
input_img_mode=input_img_mode,
Expand Down
10 changes: 9 additions & 1 deletion timm/data/readers/reader_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Optional, Union, Dict

from .reader_image_folder import ReaderImageFolder
from .reader_image_in_tar import ReaderImageInTar
Expand All @@ -8,6 +8,7 @@
def create_reader(
name: str,
root: Optional[str] = None,
labels: Optional[Union[Dict, str]] = None,
split: str = 'train',
**kwargs,
):
Expand All @@ -19,6 +20,13 @@ def create_reader(
prefix = name[0]
name = name[-1]

if isinstance(labels, str):
import json
with open(labels, 'r') as labels_file:
labels = json.load(labels_file)
if labels is not None:
kwargs["labels"] = labels

# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
# explicitly select other options shortly
if prefix == 'hfds':
Expand Down
21 changes: 14 additions & 7 deletions timm/data/readers/reader_image_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
def find_images_and_targets(
folder: str,
types: Optional[Union[List, Tuple, Set]] = None,
labels: Optional[Dict] = None,
class_to_idx: Optional[Dict] = None,
leaf_name_only: bool = True,
sort: bool = True
Expand All @@ -27,6 +28,7 @@ def find_images_and_targets(
Args:
folder: root of folder to recursively search
types: types (file extensions) to search for in path
labels: specify filename -> label mapping (and ignore the subfolder structure)
class_to_idx: specify mapping for class (folder name) to class index if set
leaf_name_only: use only leaf-name of folder walk for class names
sort: re-sort found images by name (for consistent ordering)
Expand All @@ -35,22 +37,25 @@ def find_images_and_targets(
A list of image and target tuples, class_to_idx mapping
"""
types = get_img_extensions(as_set=True) if not types else set(types)
labels = []
filenames = []
file_labels = []
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
if labels is None:
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
for f in files:
base, ext = os.path.splitext(f)
if ext.lower() in types:
if labels is not None:
label = labels[f]
filenames.append(os.path.join(root, f))
labels.append(label)
file_labels.append(label)
if class_to_idx is None:
# building class index
unique_labels = set(labels)
unique_labels = set(file_labels)
sorted_labels = list(sorted(unique_labels, key=natural_key))
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, file_labels) if l in class_to_idx]
if sort:
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
return images_and_targets, class_to_idx
Expand All @@ -61,6 +66,7 @@ class ReaderImageFolder(Reader):
def __init__(
self,
root,
labels=None,
class_map='',
input_key=None,
):
Expand All @@ -75,8 +81,9 @@ def __init__(
find_types = input_key.split(';')
self.samples, self.class_to_idx = find_images_and_targets(
root,
class_to_idx=class_to_idx,
types=find_types,
labels=labels,
class_to_idx=class_to_idx,
)
if len(self.samples) == 0:
raise RuntimeError(
Expand Down
4 changes: 4 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@
help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--labels', metavar='FILENAME',
help='File containing filename-to-label associations.')
group.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
Expand Down Expand Up @@ -656,6 +658,7 @@ def main():
dataset_train = create_dataset(
args.dataset,
root=args.data_dir,
labels=args.labels,
split=args.train_split,
is_training=True,
class_map=args.class_map,
Expand All @@ -674,6 +677,7 @@ def main():
dataset_eval = create_dataset(
args.dataset,
root=args.data_dir,
labels=args.labels,
split=args.val_split,
is_training=False,
class_map=args.class_map,
Expand Down