Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions simple_shapes_dataset/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(
)

self.dataset_path = Path(dataset_path)
self.domain_classes = domain_classes
self.domain_proportions = domain_proportions
self.seed = seed
self.ood_seed = ood_seed
Expand All @@ -70,6 +69,7 @@ def __init__(
self._train_transform = train_transforms or {}
self._val_transform = val_transforms or {}
self._use_default_transforms = use_default_transforms
self.domain_classes = self.get_domain_classes(domain_classes)

self.max_train_size = max_train_size
self.batch_size = batch_size
Expand All @@ -85,6 +85,34 @@ def __init__(

self._collate_fn = collate_fn

def get_domain_classes(
    self, domain_classes: Mapping[DomainDesc, type[DataDomain]]
) -> dict[str, dict[DomainDesc, DataDomain]]:
    """Instantiate every requested domain once per split.

    Also records the set of selected domain kinds on ``self.domains`` as a
    side effect (used later when building datasets).

    Args:
        domain_classes: mapping from a domain description to the domain
            class that should be constructed for it.

    Returns:
        A dict keyed by split name ("train"/"val"/"test"), each value
        mapping the domain description to its constructed domain instance.
    """
    self.domains = {desc.kind for desc in domain_classes}

    instantiated: dict[str, dict[DomainDesc, DataDomain]] = {}
    for split in ("train", "val", "test"):
        split_transforms = self._get_transforms(self.domains, split)
        per_split: dict[DomainDesc, DataDomain] = {}
        for desc, domain_cls in domain_classes.items():
            # A domain only gets a transform when one was registered for
            # its kind; otherwise the constructor receives None.
            chosen = None
            if split_transforms is not None and desc.kind in split_transforms:
                chosen = split_transforms[desc.kind]
            per_split[desc] = domain_cls(
                self.dataset_path,
                split,
                chosen,
                self.domain_args.get(desc.kind, None),
            )
        instantiated[split] = per_split

    return instantiated

def _get_transforms(
self, domains: Iterable[str], mode: str
) -> dict[str, Callable[[Any], Any]]:
Expand Down Expand Up @@ -117,22 +145,22 @@ def _requires_aligned_dataset(self) -> bool:
return True
return False

def _get_selected_domains(self) -> set[str]:
return {domain.kind for domain in self.domain_classes}
def _get_selected_domains(self, split: str) -> set[str]:
return {domain.kind for domain in self.domain_classes[split]}

def _get_dataset(self, split: str) -> Mapping[frozenset[str], DatasetT]:
assert split in ("train", "val", "test")

domains = self._get_selected_domains()
domains = self.domains

if split == "train" and self._requires_aligned_dataset():
if self._requires_aligned_dataset():
if self.seed is None:
raise ValueError("Seed must be provided when using aligned dataset")

return get_aligned_datasets(
self.dataset_path,
split,
self.domain_classes,
self.domain_classes[split],
self.domain_proportions,
self.seed,
self.max_train_size,
Expand All @@ -145,7 +173,7 @@ def _get_dataset(self, split: str) -> Mapping[frozenset[str], DatasetT]:
frozenset(domains): SimpleShapesDataset(
self.dataset_path,
split,
self.domain_classes,
self.domain_classes[split],
transforms=self._get_transforms(domains, split),
domain_args=self.domain_args,
)
Expand All @@ -156,7 +184,7 @@ def _get_dataset(self, split: str) -> Mapping[frozenset[str], DatasetT]:
split,
{
domain_type: domain_cls
for domain_type, domain_cls in self.domain_classes.items()
for domain_type, domain_cls in self.domain_classes[split].items()
if domain_type.kind == domain
},
self.max_train_size,
Expand Down Expand Up @@ -258,9 +286,10 @@ def val_dataloader(
assert self.val_dataset is not None

dataloaders = {}
max_sized_dataset = max(len(dataset) for dataset in self.val_dataset.values())
for domain, dataset in self.val_dataset.items():
dataloaders[domain] = DataLoader(
dataset,
RepeatedDataset(dataset, max_sized_dataset, drop_last=False),
pin_memory=True,
batch_size=self.batch_size,
num_workers=self.num_workers,
Expand All @@ -275,17 +304,18 @@ def val_dataloader(
num_workers=self.num_workers,
collate_fn=self._collate_fn,
)
return CombinedLoader(dataloaders, mode="sequential")
return CombinedLoader(dataloaders, mode="min_size")

def test_dataloader(
self,
) -> CombinedLoader:
assert self.test_dataset is not None

dataloaders = {}
max_sized_dataset = max(len(dataset) for dataset in self.test_dataset.values())
for domain, dataset in self.test_dataset.items():
dataloaders[domain] = DataLoader(
dataset,
RepeatedDataset(dataset, max_sized_dataset, drop_last=False),
pin_memory=True,
batch_size=self.batch_size,
num_workers=self.num_workers,
Expand All @@ -300,7 +330,7 @@ def test_dataloader(
num_workers=self.num_workers,
collate_fn=self._collate_fn,
)
return CombinedLoader(dataloaders, mode="sequential")
return CombinedLoader(dataloaders, mode="min_size")

def predict_dataloader(self):
assert self.val_dataset is not None
Expand Down
13 changes: 2 additions & 11 deletions simple_shapes_dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self,
dataset_path: str | Path,
split: str,
domain_classes: Mapping[DomainDesc, type[DataDomain]],
domain_classes: dict[DomainDesc, DataDomain],
max_size: int | None = None,
transforms: Mapping[str, Callable[[Any], Any]] | None = None,
domain_args: Mapping[str, Any] | None = None,
Expand All @@ -84,16 +84,7 @@ def __init__(
self.domain_args = domain_args or {}

for domain, domain_cls in domain_classes.items():
transform = None
if transforms is not None and domain.kind in transforms:
transform = transforms[domain.kind]

self.domains[domain.kind] = domain_cls(
dataset_path,
split,
transform,
self.domain_args.get(domain.kind, None),
)
self.domains[domain.kind] = domain_cls

lengths = {len(domain) for domain in self.domains.values()}
min_length = min(lengths)
Expand Down
22 changes: 21 additions & 1 deletion simple_shapes_dataset/domain.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -387,3 +387,23 @@ def get_default_domains(
domain = DomainType[domain].value
domain_classes[domain] = DEFAULT_DOMAINS[domain.kind]
return domain_classes


def get_default_domains_dataset(
    domains: Iterable[DomainDesc | str],
    dataset_path: str | Path,
    split: str,
    transforms: Mapping[str, Callable[[Any], Any]],
    domain_args: Mapping[str, Any] | None = None,
) -> dict[DomainDesc, DataDomain]:
    """Instantiate the default domain class for each requested domain.

    Args:
        domains: domain descriptions, or shorthand names (e.g. "v", "attr")
            resolved through ``DomainType``.
        dataset_path: root path of the dataset on disk.
        split: dataset split to load ("train", "val" or "test").
        transforms: per-kind transforms; a domain whose kind is absent
            receives ``None`` as its transform.
        domain_args: optional per-kind extra constructor arguments.

    Returns:
        Mapping from each domain description to its constructed instance.
    """
    # Use a None sentinel instead of a mutable `{}` default: a shared dict
    # default would be reused (and mutable) across calls.
    if domain_args is None:
        domain_args = {}
    domain_classes: dict[DomainDesc, DataDomain] = {}
    for domain in domains:
        if isinstance(domain, str):
            # Resolve shorthand names such as "v" or "attr" to a DomainDesc.
            domain = DomainType[domain].value
        domain_classes[domain] = DEFAULT_DOMAINS[domain.kind](
            dataset_path,
            split,
            # .get keeps behavior consistent with the data module, which
            # passes None when no transform is registered for a kind,
            # instead of raising KeyError.
            transforms.get(domain.kind),
            domain_args.get(domain.kind, None),
        )
    return domain_classes
2 changes: 1 addition & 1 deletion simple_shapes_dataset/domain_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_alignment(
def get_aligned_datasets(
dataset_path: str | Path,
split: str,
domain_classes: Mapping[DomainDesc, type[DataDomain]],
domain_classes: Mapping[DomainDesc, DataDomain],
domain_proportions: Mapping[frozenset[str], float],
seed: int,
max_size: int | None = None,
Expand Down
36 changes: 31 additions & 5 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
from torch.utils.data.dataloader import DataLoader
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose, ToTensor
from utils import PROJECT_DIR

from simple_shapes_dataset.data_module import SimpleShapesDataModule
from simple_shapes_dataset.dataset import SimpleShapesDataset
from simple_shapes_dataset.domain import (
get_default_domains,
get_default_domains_dataset,
)
from simple_shapes_dataset.domain_alignment import get_aligned_datasets
from simple_shapes_dataset.pre_process import attribute_to_tensor


def test_dataset():
transform = {"attr": Compose([]), "v": Compose([])}

dataset = SimpleShapesDataset(
PROJECT_DIR / "sample_dataset",
split="train",
domain_classes=get_default_domains(["v", "attr"]),
domain_classes=get_default_domains_dataset(
["v", "attr"],
PROJECT_DIR / "sample_dataset",
split="train",
transforms=transform,
),
)

assert len(dataset) == 4
Expand All @@ -26,10 +34,16 @@ def test_dataset():


def test_dataset_val():
transform = {"attr": Compose([]), "v": Compose([])}
dataset = SimpleShapesDataset(
PROJECT_DIR / "sample_dataset",
split="val",
domain_classes=get_default_domains(["v", "attr"]),
domain_classes=get_default_domains_dataset(
["v", "attr"],
PROJECT_DIR / "sample_dataset",
split="val",
transforms=transform,
),
)

assert len(dataset) == 2
Expand All @@ -44,7 +58,12 @@ def test_dataloader():
dataset = SimpleShapesDataset(
PROJECT_DIR / "sample_dataset",
split="train",
domain_classes=get_default_domains(["v", "attr"]),
domain_classes=get_default_domains_dataset(
["v", "attr"],
PROJECT_DIR / "sample_dataset",
split="train",
transforms=transform,
),
transforms=transform,
)

Expand All @@ -55,10 +74,17 @@ def test_dataloader():


def test_get_aligned_datasets():
transform = {"t": Compose([]), "v": Compose([])}

datasets = get_aligned_datasets(
PROJECT_DIR / "sample_dataset",
"train",
domain_classes=get_default_domains(["v", "t"]),
domain_classes=get_default_domains_dataset(
["v", "t"],
PROJECT_DIR / "sample_dataset",
split="train",
transforms=transform,
),
domain_proportions={
frozenset(["v", "t"]): 0.5,
frozenset("v"): 1.0,
Expand Down
33 changes: 7 additions & 26 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from utils import PROJECT_DIR

from simple_shapes_dataset.dataset import SimpleShapesDataset
from simple_shapes_dataset.domain import get_default_domains
from simple_shapes_dataset.domain import get_default_domains_dataset
from simple_shapes_dataset.pre_process import attribute_to_tensor


Expand All @@ -14,7 +14,12 @@ def test_attr_preprocess():
dataset = SimpleShapesDataset(
PROJECT_DIR / "sample_dataset",
split="train",
domain_classes=get_default_domains(["attr"]),
domain_classes=get_default_domains_dataset(
["attr"],
PROJECT_DIR / "sample_dataset",
split="train",
transforms=transform,
),
transforms=transform,
)

Expand All @@ -26,27 +31,3 @@ def test_attr_preprocess():
assert item["attr"][0].shape == (2, 3)
assert isinstance(item["attr"][1], torch.Tensor)
assert item["attr"][1].shape == (2, 8)


def test_attr_preprocess_with_unpaired():
transform = {
"attr": attribute_to_tensor,
}
dataset = SimpleShapesDataset(
PROJECT_DIR / "sample_dataset",
split="train",
domain_classes=get_default_domains(["attr"]),
domain_args={"attr": {"n_unpaired": 1}},
transforms=transform,
)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
item = next(iter(dataloader))
assert isinstance(item["attr"], list)
assert len(item["attr"]) == 3
assert isinstance(item["attr"][0], torch.Tensor)
assert item["attr"][0].shape == (2, 3)
assert isinstance(item["attr"][1], torch.Tensor)
assert item["attr"][1].shape == (2, 8)
assert isinstance(item["attr"][2], torch.Tensor)
assert item["attr"][2].shape == (2, 1)