Merge pull request #16 from Modalities/feat/coca
Add Contrastive Captioner (CoCa)
le1nux committed Mar 27, 2024
2 parents f8e34c3 + cc1754e commit 60feafe
Showing 25 changed files with 1,316 additions and 7 deletions.
269 changes: 269 additions & 0 deletions config_files/config_example_coca.yaml
@@ -0,0 +1,269 @@
settings:
experiment_id: ${modalities_env:experiment_id}
referencing_keys:
sample_key: input_ids
target_key: target_ids
training:
callback_interval_in_samples: 6
global_num_training_samples: 12
global_num_seen_samples: 0
do_apply_activation_checkpointing: true
gradient_acc_steps: 1
local_train_micro_batch_size: 3
sequence_length: 256
gradient_clipping:
mode: p2_norm
threshold: 1.0
cuda_env:
local_rank: ${cuda_env:LOCAL_RANK}
global_rank: ${cuda_env:RANK}
world_size: ${cuda_env:WORLD_SIZE}
paths:
checkpointing_path: data/checkpoints

tokenizer:
component_key: tokenizer
variant_key: gpt2_tokenizer_fast
config:
tokenizer_file: data/tokenizer/tokenizer_gpt2.json

collate_fn:
component_key: collate_fn
variant_key: coca_collator
config:
sample_keys:
- images
- ${settings.referencing_keys.sample_key}
target_keys: []
text_sample_key: ${settings.referencing_keys.sample_key}
text_target_key: ${settings.referencing_keys.target_key}

train_dataset:
component_key: dataset
variant_key: dummy_dataset
config:
num_samples: 4
sample_definition:
- sample_key: images
sample_shape: [3, 224, 224]
sample_type: float
- sample_key: input_ids
sample_shape: [1024]
sample_type: int

val_dataset:
component_key: dataset
variant_key: dummy_dataset
config:
num_samples: 4
sample_definition:
- sample_key: images
sample_shape: [3, 224, 224]
sample_type: float
- sample_key: input_ids
sample_shape: [1024]
sample_type: int

train_dataloader:
component_key: data_loader
variant_key: default
config:
num_workers: 2
pin_memory: true
shuffle: false
dataloader_tag: "train"
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
batch_sampler:
component_key: batch_sampler
variant_key: default
config:
batch_size: ${settings.training.local_train_micro_batch_size}
drop_last: false
sampler:
component_key: sampler
variant_key: distributed_sampler
config:
rank: ${settings.cuda_env.global_rank}
num_replicas: ${settings.cuda_env.world_size}
shuffle: true
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
collate_fn:
instance_key: collate_fn
pass_type: BY_REFERENCE

val_dataloader:
component_key: data_loader
variant_key: default
config:
num_workers: 2
pin_memory: true
shuffle: false
dataloader_tag: "val"
dataset:
instance_key: val_dataset
pass_type: BY_REFERENCE
batch_sampler:
component_key: batch_sampler
variant_key: default
config:
batch_size: ${settings.training.local_train_micro_batch_size}
drop_last: false
sampler:
component_key: sampler
variant_key: distributed_sampler
config:
rank: ${settings.cuda_env.global_rank}
num_replicas: ${settings.cuda_env.world_size}
shuffle: false
dataset:
instance_key: val_dataset
pass_type: BY_REFERENCE
collate_fn:
instance_key: collate_fn
pass_type: BY_REFERENCE

eval_dataloaders:
- instance_key: val_dataloader
pass_type: BY_REFERENCE

checkpointing:
component_key: checkpointing
variant_key: default
config:
checkpointing_strategy:
component_key: checkpointing_strategy
variant_key: save_k_most_recent_checkpoints_strategy
config:
k: -1 # -1 to save all checkpoints
checkpointing_execution:
component_key: checkpointing_execution
variant_key: fsdp_to_disc_checkpointing
config:
checkpoint_path: ${settings.paths.checkpointing_path}
global_rank: ${settings.cuda_env.global_rank}
experiment_id: ${settings.experiment_id}
mixed_precision_settings: FP_16
sharding_strategy: FULL_SHARD
block_names: [TransformerBlock, VisionTransformerBlock]

loss_fn:
component_key: loss
variant_key: clm_cross_entropy_loss
config:
target_key: ${settings.referencing_keys.target_key}
prediction_key: logits

wrapped_model:
component_key: model
variant_key: fsdp_wrapped
config:
model:
instance_key: model
pass_type: BY_REFERENCE
sync_module_states: true
mixed_precision_settings: FP_16
sharding_strategy: FULL_SHARD
block_names: [TransformerBlock, VisionTransformerBlock]

model:
component_key: model
variant_key: coca
config:
prediction_key: logits
vision_embd_prediction_key: vision_embeddings
text_embd_prediction_key: text_embeddings
vision_cls_prediction_key: vision_cls
text_cls_prediction_key: text_cls
vision_encoder_config:
sample_key: images
prediction_key: vision_embeddings
img_size: 224
n_classes: Null # Disable vision transformer head
n_layer: 12
attention_config:
attention_engine_type: default_attention
n_head: 12
n_embd: 768
dropout: 0.0
patch_size: 16
patch_stride: 16
n_img_channels: 3
add_cls_token: False
bias: True
text_decoder_config:
sample_key: ${settings.referencing_keys.sample_key}
prediction_key: ${loss_fn.config.prediction_key}
block_size: 1024
vocab_size: 50304
n_layer_text: 12
n_layer_multimodal_text: 12
attention_config:
attention_engine_type: default_attention
n_head: 12
ffn_hidden: 2048
n_embd: 768
dropout: 0.0
bias: true
activation: fused_swiglu
epsilon: 1e-5
n_pool_head: 8
n_vision_queries: 256
bias_attn_pool: False
epsilon_attn_pool: 1e-5
weight_init:
mean: 0.0
std: 0.02

scheduler:
component_key: scheduler
variant_key: onecycle_lr
config:
optimizer:
instance_key: optimizer
pass_type: BY_REFERENCE
max_lr: 6e-4
div_factor: 10
final_div_factor: 1
total_steps: 4
pct_start: 0.01
anneal_strategy: cos

optimizer:
component_key: optimizer
variant_key: adam_w
config:
lr: 0.0001
betas: [0.9, 0.95]
eps: 1e-8
weight_decay: 1e-1
wrapped_model:
instance_key: wrapped_model
pass_type: BY_REFERENCE

batch_progress_subscriber:
component_key: progress_subscriber
variant_key: rich
config:
local_rank: ${settings.cuda_env.local_rank}
world_size: ${settings.cuda_env.world_size}
global_num_seen_samples: ${settings.training.global_num_seen_samples}
train_dataloader:
instance_key: train_dataloader
pass_type: BY_REFERENCE
eval_dataloaders:
instance_key: eval_dataloaders
pass_type: BY_REFERENCE

evaluation_subscriber:
component_key: results_subscriber
variant_key: wandb
config:
local_rank: ${settings.cuda_env.local_rank}
project: modalities
mode: OFFLINE
experiment_id: ${settings.experiment_id}
directory: "."
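For orientation, the tensor shapes implied by this example config can be checked by hand: the dummy dataset emits images of shape [3, 224, 224] and input_ids of shape [1024], the collator stacks local_train_micro_batch_size = 3 samples, and the vision encoder with img_size 224 and patch_size 16 sees 14 x 14 = 196 patch tokens per image. The snippet below is only a sketch of that batching in plain NumPy/PyTorch; it is not the Modalities coca_collator implementation, and the target-shifting step is an assumption about how input_ids / target_ids are derived.

import numpy as np
import torch

micro_batch_size = 3  # settings.training.local_train_micro_batch_size

# One dummy sample, mirroring the sample_definition above (illustration only,
# not the Modalities DummyDataset / coca_collator implementation).
def make_dummy_sample():
    return {
        "images": np.random.randn(3, 224, 224).astype(np.float32),
        "input_ids": np.random.randint(low=0, high=512, size=(1024,)),
    }

samples = [make_dummy_sample() for _ in range(micro_batch_size)]
images = torch.stack([torch.from_numpy(s["images"]) for s in samples])        # [3, 3, 224, 224]
input_ids = torch.stack([torch.from_numpy(s["input_ids"]) for s in samples])  # [3, 1024]

# With img_size 224 and patch_size / patch_stride 16, the vision encoder sees
# (224 / 16) ** 2 = 196 patch tokens per image.
num_patches = (224 // 16) ** 2

# Assumption: targets are the inputs shifted by one position, as is usual for
# causal language modelling with sample_key input_ids / target_key target_ids.
sample_ids = input_ids[:, :-1]
target_ids = input_ids[:, 1:]
print(images.shape, sample_ids.shape, target_ids.shape, num_patches)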
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -22,8 +22,8 @@ dependencies = [
"xformers",
"class_resolver",
"wandb",
"flash-attn" # install this directly via `pip install flash-attn --no-build-isolation`

"einops>=0.7.0",
"flash-attn", # install this directly via `pip install flash-attn --no-build-isolation`
]

[project.optional-dependencies]
@@ -73,10 +73,10 @@ exclude_also = [

# Don't complain about abstract methods, they aren't run:
"@(abc\\.)?abstractmethod",
]
]


ignore_errors = true

[tool.coverage.html]
directory = "coverage_html_report"
directory = "coverage_html_report"
50 changes: 49 additions & 1 deletion src/modalities/dataloader/dataset.py
@@ -1,10 +1,12 @@
from __future__ import annotations

from enum import Enum
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import jq
import numpy as np
from pydantic import BaseModel
from torch.utils.data.dataset import Dataset as TorchdataSet
from tqdm import tqdm
from transformers import BatchEncoding, PreTrainedTokenizer
@@ -24,6 +26,52 @@ def _check_if_inbounds(self, idx: int):
raise IndexError


class DummySampleDataType(str, Enum):
FLOAT = "float"
INT = "int"


class DummySampleConfig(BaseModel):
sample_key: str
sample_shape: Tuple[int, ...]
sample_type: DummySampleDataType


class DummyDatasetConfig(BaseModel):
num_samples: int
sample_definition: List[DummySampleConfig]


class DummyDataset(Dataset):
def __init__(self, num_samples: int, sample_definition: Tuple[DummySampleConfig]):
"""
:param num_samples: Number of samples the dataset should generate.
:param sample_definition: A list of DummySampleConfig entries defining the dataset output.
Each entry specifies the sample key, shape, and data type.
"""
super().__init__(raw_data_path=None, block_size=None, sample_key=None)
self.num_samples = num_samples
self.sample_definition = sample_definition

def __len__(self) -> int:
return self.num_samples

def __getitem__(self, idx: int) -> Dict:
return self._create_random_sample()

def _create_random_sample(self):
sample = dict()
for s in self.sample_definition:
if s.sample_type == DummySampleDataType.FLOAT:
data = np.random.randn(*s.sample_shape)
elif s.sample_type == DummySampleDataType.INT:
data = np.random.randint(low=0, high=512, size=s.sample_shape)
else:
raise NotImplementedError(f"DummyDataset does not support type {s.sample_type}")
sample[s.sample_key] = data
return sample


class MemMapDataset(Dataset):
def __init__(
self,
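A minimal usage sketch of the new DummyDataset, mirroring the train_dataset block from the example config (the import path follows this file, src/modalities/dataloader/dataset.py):

from modalities.dataloader.dataset import DummyDataset, DummySampleConfig, DummySampleDataType

dataset = DummyDataset(
    num_samples=4,
    sample_definition=[
        DummySampleConfig(sample_key="images", sample_shape=(3, 224, 224), sample_type=DummySampleDataType.FLOAT),
        DummySampleConfig(sample_key="input_ids", sample_shape=(1024,), sample_type=DummySampleDataType.INT),
    ],
)

sample = dataset[0]
print(len(dataset))               # 4
print(sample["images"].shape)     # (3, 224, 224)
print(sample["input_ids"].shape)  # (1024,)

Note that each __getitem__ call draws fresh random data, so two reads of the same index will differ.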
15 changes: 13 additions & 2 deletions src/modalities/dataloader/dataset_factory.py
@@ -1,11 +1,17 @@
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple

from pydantic import FilePath
from torch.utils.data.dataset import Dataset
from transformers import PreTrainedTokenizer

from modalities.dataloader.dataset import MemMapDataset, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron
from modalities.dataloader.dataset import (
DummyDataset,
DummySampleConfig,
MemMapDataset,
PackedMemMapDatasetContinuous,
PackedMemMapDatasetMegatron,
)
from modalities.dataloader.open_gptx_dataset.open_gptx_dataset import OpenGPTXMMapDataset


@@ -26,6 +32,11 @@ def __getitem__(self, idx: int):


class DatasetFactory:
@staticmethod
def get_dummy_dataset(num_samples: int, sample_definition: Tuple[DummySampleConfig]) -> DummyDataset:
dataset = DummyDataset(num_samples=num_samples, sample_definition=sample_definition)
return dataset

@staticmethod
def get_mem_map_dataset(
raw_data_path: Path,
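The train_dataset and val_dataset blocks in config_example_coca.yaml ultimately resolve to this factory method. A hedged sketch of the equivalent direct call, with the YAML dictionaries validated through the Pydantic DummyDatasetConfig shown in dataset.py:

from modalities.dataloader.dataset import DummyDatasetConfig
from modalities.dataloader.dataset_factory import DatasetFactory

# The dicts below mirror the YAML sample_definition entries; Pydantic coerces
# them into DummySampleConfig objects and "float"/"int" into DummySampleDataType.
config = DummyDatasetConfig(
    num_samples=4,
    sample_definition=[
        {"sample_key": "images", "sample_shape": (3, 224, 224), "sample_type": "float"},
        {"sample_key": "input_ids", "sample_shape": (1024,), "sample_type": "int"},
    ],
)

train_dataset = DatasetFactory.get_dummy_dataset(
    num_samples=config.num_samples,
    sample_definition=config.sample_definition,
)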