@Hazel-Heejeong-Nam

What does this PR do?

This PR adds two new datasets to stable-datasets:

  • NotMNIST: Image classification dataset with letters A-J (10 classes)
  • CLEVRER: Video reasoning dataset with 20,000 synthetic videos

Requirements

⚠️ Important: CLEVRER dataset requires additional dependencies for video decoding:

conda install ffmpeg
uv pip install torch torchcodec 
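Before loading CLEVRER, it can help to verify that the video-decoding prerequisites are actually in place. A minimal sketch (the `check_video_deps` helper is illustrative and not part of this PR):

```python
import shutil

def check_video_deps() -> dict:
    """Report whether the CLEVRER video-decoding prerequisites are available."""
    # ffmpeg must be discoverable on PATH (e.g. via `conda install ffmpeg`)
    deps = {"ffmpeg": shutil.which("ffmpeg") is not None}
    try:
        import torchcodec  # noqa: F401  # required for decoding CLEVRER videos
        deps["torchcodec"] = True
    except ImportError:
        deps["torchcodec"] = False
    return deps

print(check_video_deps())
```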

Note

⚠️ Note on Download Implementation: CLEVRER video files are very large (~12GB for train.zip, ~6GB each for validation/test). During implementation, multiple download approaches were attempted:

  • bulk_download with ProcessPoolExecutor → Process killed
  • ThreadPoolExecutor → Process killed
  • Standard download() function → Process killed

All three approaches ended with the process being terminated (Killed) due to memory pressure or system limits when handling files this large.

Solution: Implemented a custom _wget_download() function that uses the system wget command with the -c flag, which:

  1. Handles HTTP redirects properly (the CLEVRER server uses redirects)
  2. Supports resume - if the download is interrupted, running again will continue from where it left off
  3. Does not load file contents into Python memory, avoiding OOM kills
  4. Shows download progress with --progress=bar:force:noscroll

Usage Examples

NotMNIST

from stable_datasets.images.not_mnist import NotMNIST

print("Loading NotMNIST dataset...")
notmnist_train = NotMNIST(split="train")
notmnist_test = NotMNIST(split="test")
notmnist_all = NotMNIST(split=None)

print("\nDataset Metadata:")
print(f"  - Homepage: {notmnist_train.info.homepage}")
print(f"  - Description: {notmnist_train.info.description}")
print(f"  - Citation:\n{notmnist_train.info.citation}")

print("\nDataset Statistics:")
print(f"  - Train samples: {len(notmnist_train)}")
print(f"  - Test samples: {len(notmnist_test)}")
print(f"  - Total splits: {len(notmnist_all)}")
print(f"  - Total samples: {len(notmnist_all['train']) + len(notmnist_all['test'])}")
print(f"  - Number of classes: {notmnist_train.features['label'].num_classes}")

sample = notmnist_train[0]
print("\nSample Information:")
print(f"  - Keys: {list(sample.keys())}")
print(f"  - Image type: {type(sample['image'])}")
print(f"  - Image size: {sample['image'].size}")
print(f"  - Label (int): {sample['label']}")
print(f"  - Label (string): {notmnist_train.features['label'].int2str(sample['label'])}")

print("\nAll classes (A-J):")
for i in range(10):
    print(f"  {i}: {notmnist_train.features['label'].names[i]}")

print("\nNotMNIST dataset loaded successfully!")

CLEVRER

import json

from stable_datasets.images.clevrer import CLEVRER

print("Loading CLEVRER dataset...")
clevrer_train = CLEVRER(split="train")
clevrer_val = CLEVRER(split="validation")
clevrer_test = CLEVRER(split="test")
clevrer_all = CLEVRER(split=None)

print("\nDataset Metadata:")
print(f"  - Homepage: {clevrer_train.info.homepage}")
print(f"  - Description: {clevrer_train.info.description}")
print(f"  - Citation:\n{clevrer_train.info.citation}")

print("\nDataset Statistics:")
print(f"  - Train samples: {len(clevrer_train)}")
print(f"  - Validation samples: {len(clevrer_val)}")
print(f"  - Test samples: {len(clevrer_test)}")
print(f"  - Total splits: {len(clevrer_all)}")
print(f"  - Total samples: {len(clevrer_all['train']) + len(clevrer_all['validation']) + len(clevrer_all['test'])}")

sample = clevrer_train[0]
print("\nSample Information:")
print(f"  - Keys: {list(sample.keys())}")
print(f"  - Video type: {type(sample['video'])}")
print(f"  - Scene index: {sample['scene_index']}")
print(f"  - Video filename: {sample['video_filename']}")

# Parse questions JSON
questions = json.loads(sample["questions_json"])
print("\nQuestions Information:")
print(f"  - Number of questions: {len(questions)}")
if len(questions) > 0:
    q = questions[0]
    print(f"  - First question: {q['question']}")
    print(f"  - Question type: {q['question_type']}")
    print(f"  - Answer: {q.get('answer', 'N/A')}")

# Parse annotations JSON
annotations = json.loads(sample["annotations_json"])
print("\nAnnotations Information:")
if annotations:
    print(f"  - Number of objects: {len(annotations.get('object_property', []))}")
    print(f"  - Number of collisions: {len(annotations.get('collision', []))}")
    if annotations.get('object_property'):
        obj = annotations['object_property'][0]
        print(f"  - First object: {obj['color']} {obj['material']} {obj['shape']}")
else:
    print("  - No annotations available (test split)")

print("\nCLEVRER dataset loaded successfully!")

Who can review?

@RandallBalestriero
