This directory contains usage examples demonstrating different patterns and use cases.
The simplest pattern: cache a frozen feature extractor during training.
```bash
python basic_usage.py
```

What it demonstrates:
- Creating a cache backend
- Wrapping a feature extractor with caching
- Using cached features in a training loop
- First epoch populates cache, later epochs reuse it
Key takeaway: Run the script multiple times - subsequent runs will be much faster as they reuse the cache!
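To verify the speedup yourself, time a pass over the data and run the script twice. A minimal sketch, assuming the `cached_extractor` and `dataloader` that `basic_usage.py` sets up:

```python
import time

start = time.perf_counter()
for batch in dataloader:
    # First run: computed and written to the cache. Second run: served from the cache.
    features = cached_extractor(batch["images"], cache_ids=batch["sample_ids"])
print(f"Pass over the dataset took {time.perf_counter() - start:.1f}s")
```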
More complex patterns for real-world scenarios.
```bash
python advanced_usage.py
```

What it demonstrates:
- K-Fold Cross-Validation: Share cache across folds for overlapping samples
- DDP Training: Single-writer pattern for distributed training
- Multiple Models: Multiple models sharing the same feature cache
- Complex Outputs: Caching modules with dict/tuple outputs
- Progressive Enrichment: Cache grows across multiple runs (see the sketch below)
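The progressive-enrichment pattern deserves a quick illustration: because entries are tied to the stable `module_id` and per-sample `cache_ids`, separate runs that touch different samples all add to, and reuse, the same cache. A minimal sketch, assuming a frozen `feature_extractor` and a `subset_loader` that covers only part of the dataset:

```python
# Every run with the same module_id adds to (and reuses) the same cache,
# so the cache grows progressively across runs.
backend = ArrowIPCCacheBackend(cache_dir="./cache", module_id="features_v1")
cached_extractor = CacheModuleDecorator(feature_extractor, backend, enabled=True)

for batch in subset_loader:  # only part of the full dataset
    features = cached_extractor(batch["input"], cache_ids=batch["ids"])

backend.flush()  # persist what this run added
```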
Use case: You have a pretrained model (ResNet, BERT, etc.) and want to cache its outputs.
```python
from torchcachex import ArrowIPCCacheBackend, CacheModuleDecorator

# Create backend
backend = ArrowIPCCacheBackend(
    cache_dir="./cache",
    module_id="resnet50_imagenet_v1",  # Stable ID
    lru_size=4096,
    async_write=True,
)

# Wrap frozen feature extractor
feature_extractor = ResNet50Pretrained()  # Your frozen model
cached_extractor = CacheModuleDecorator(
    module=feature_extractor,
    cache_backend=backend,
    enabled=True,
    enforce_stateless=True,  # Verify no trainable params
)

# Use in training
for batch in dataloader:
    # Cache lookup (or compute if miss)
    features = cached_extractor(
        batch["images"],
        cache_ids=batch["sample_ids"],  # Must be stable!
    )

    # Train your classifier
    logits = classifier(features)
    loss = criterion(logits, batch["labels"])
    # ...
```

Use case: Run K-fold CV without recomputing features for overlapping samples.
```python
# Single cache shared across all folds
backend = ArrowIPCCacheBackend(
    cache_dir="./cache",
    module_id="features_v1",
)
cached_extractor = CacheModuleDecorator(feature_extractor, backend, enabled=True)

for fold in range(K):
    train_loader, val_loader = get_fold_loaders(fold)

    # Train fold (features cached progressively)
    for batch in train_loader:
        features = cached_extractor(batch["input"], cache_ids=batch["ids"])
        # ... train ...

    # Validate (reuses cached features)
    for batch in val_loader:
        features = cached_extractor(batch["input"], cache_ids=batch["ids"])
        # ... evaluate ...
```

Benefit: Fold N reuses all features computed in folds 0..N-1.
Use case: Train with multiple GPUs, cache on shared filesystem.
```python
import os

backend = ArrowIPCCacheBackend(
    cache_dir="/shared/cache",               # Shared across ranks
    module_id="features_v1",
    writer_rank=0,                           # Only rank 0 writes
    current_rank=int(os.getenv("RANK", 0)),  # From DDP
)
cached_extractor = CacheModuleDecorator(feature_extractor, backend, enabled=True)

# All ranks compute, only rank 0 writes cache
for batch in dataloader:
    features = cached_extractor(batch["input"], cache_ids=batch["ids"])
    # ... train ...
```

Benefit: The first epoch runs at the usual speed; all subsequent epochs are fast on all ranks.
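Note: launchers such as `torchrun` set the `RANK` environment variable for each process automatically, so the `current_rank` lookup above works out of the box; with a custom launcher you may need to supply the rank yourself.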
Use case: Train multiple models (e.g., ensembles) that share the same features.
```python
# Single cache for shared features
backend = ArrowIPCCacheBackend(
    cache_dir="./cache",
    module_id="resnet50_features_v1",  # Same ID!
)

# Multiple models use the same cached features
cached_extractor = CacheModuleDecorator(feature_extractor, backend, enabled=True)
model_a = ModelA(features=cached_extractor)
model_b = ModelB(features=cached_extractor)
model_c = ModelC(features=cached_extractor)

# Train model A (populates cache)
train(model_a, train_loader)

# Train models B and C (reuse cache from A)
train(model_b, train_loader)  # Fast!
train(model_c, train_loader)  # Fast!
```

Use case: Your module returns dicts, tuples, or mixed types.
```python
class MultiHeadExtractor(nn.Module):
    def forward(self, x):
        return {
            "visual": self.visual_head(x),       # Tensor
            "semantic": self.semantic_head(x),   # Tensor
            "metadata": {"width": x.shape[-1]},  # Non-tensor (pickled)
        }

# Decorator handles complex structures automatically
cached = CacheModuleDecorator(
    MultiHeadExtractor(),
    backend,
    enabled=True,
)

outputs = cached(batch["input"], cache_ids=batch["ids"])
# outputs["visual"]   → Tensor (native storage, dtype preserved)
# outputs["semantic"] → Tensor (native storage, dtype preserved)
# outputs["metadata"] → dict (pickled)
```

Common pitfalls to avoid:

❌ Wrong:
```python
# This breaks caching across runs!
cache_ids = [f"sample_{time.time()}_{i}" for i in range(len(batch))]
```

✅ Correct:
```python
# Use dataset index or UUID
cache_ids = [f"sample_{dataset_idx}" for dataset_idx in batch["indices"]]
```
❌ Wrong:

```python
# Different ID every run!
module_id = f"features_{datetime.now()}"
```

✅ Correct:

```python
# Semantic versioning
module_id = "resnet50_imagenet_v1"  # Bump to v2 when the model changes
```
❌ Wrong:

```python
for epoch in range(10):
    for batch in loader:
        cached_module(batch["input"], cache_ids=batch["ids"])

# Cache may be lost if the process crashes!
```

✅ Correct:

```python
for epoch in range(10):
    for batch in loader:
        cached_module(batch["input"], cache_ids=batch["ids"])
    backend.flush()  # Force persistence at the end of each epoch
```
❌ Wrong:

```python
# This module has trainable parameters!
trainable_module = nn.Linear(512, 10)
cached = CacheModuleDecorator(trainable_module, backend, enabled=True)
# Will raise an error if enforce_stateless=True
```

✅ Correct:

```python
# Freeze the module first
for param in trainable_module.parameters():
    param.requires_grad = False

cached = CacheModuleDecorator(trainable_module, backend, enabled=True)
```
LRU size:

```python
# Small dataset (< 10k samples): cache everything in memory
backend = ArrowIPCCacheBackend(..., lru_size=10000)

# Large dataset: size the LRU for your working set
backend = ArrowIPCCacheBackend(..., lru_size=4096)

# Very large dataset: minimal LRU, rely on the disk cache
backend = ArrowIPCCacheBackend(..., lru_size=1024)
```

Flush frequency:

```python
# Large batches: flush more frequently
backend = ArrowIPCCacheBackend(..., flush_every=512)

# Small batches: flush less frequently
backend = ArrowIPCCacheBackend(..., flush_every=4096)
```

Async writes:

```python
# Training (non-blocking writes)
backend = ArrowIPCCacheBackend(..., async_write=True)

# Testing/debugging (immediate persistence)
backend = ArrowIPCCacheBackend(..., async_write=False)
```
- See `../benchmark.py` for performance benchmarks
- See `../ARCHITECTURE.md` for a technical deep-dive
- See `../README.md` for the full API reference