diff --git a/config_files/config_example_coca.yaml b/config_files/config_example_coca.yaml new file mode 100644 index 00000000..a61257e1 --- /dev/null +++ b/config_files/config_example_coca.yaml @@ -0,0 +1,269 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + training: + callback_interval_in_samples: 6 + global_num_training_samples: 12 + global_num_seen_samples: 0 + do_apply_activation_checkpointing: true + gradient_acc_steps: 1 + local_train_micro_batch_size: 3 + sequence_length: 256 + gradient_clipping: + mode: p2_norm + threshold: 1.0 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +tokenizer: + component_key: tokenizer + variant_key: gpt2_tokenizer_fast + config: + tokenizer_file: data/tokenizer/tokenizer_gpt2.json + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - images + - ${settings.referencing_keys.sample_key} + target_keys: [] + text_sample_key: ${settings.referencing_keys.sample_key} + text_target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: dummy_dataset + config: + num_samples: 4 + sample_definition: + - sample_key: images + sample_shape: [3, 224, 224] + sample_type: float + - sample_key: input_ids + sample_shape: [1024] + sample_type: int + +val_dataset: + component_key: dataset + variant_key: dummy_dataset + config: + num_samples: 4 + sample_definition: + - sample_key: images + sample_shape: [3, 224, 224] + sample_type: float + - sample_key: input_ids + sample_shape: [1024] + sample_type: int + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: false + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: 
checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: FP_16 + sharding_strategy: FULL_SHARD + block_names: [TransformerBlock, VisionTransformerBlock] + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: logits + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: FP_16 + sharding_strategy: FULL_SHARD + block_names: [TransformerBlock, VisionTransformerBlock] + +model: + component_key: model + variant_key: coca + config: + prediction_key: logits + vision_embd_prediction_key: vision_embeddings + text_embd_prediction_key: text_embeddings + vision_cls_prediction_key: vision_cls + text_cls_prediction_key: text_cls + vision_encoder_config: + sample_key: images + prediction_key: vision_embeddings + img_size: 224 + n_classes: Null # Disable vision transformer head + n_layer: 12 + attention_config: + attention_engine_type: default_attention + n_head: 12 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True + text_decoder_config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${loss_fn.config.prediction_key} + block_size: 1024 + vocab_size: 50304 + n_layer_text: 12 + n_layer_multimodal_text: 12 + attention_config: + attention_engine_type: default_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: fused_swiglu + epsilon: 1e-5 + n_pool_head: 8 + n_vision_queries: 256 + bias_attn_pool: False + epsilon_attn_pool: 1e-5 + weight_init: + mean: 0.0 + std: 0.02 + +scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: 4 + pct_start: 0.01 + anneal_strategy: cos + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: "." 
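Taken together, the config above wires a `DummyDataset` (random `images` and `input_ids` samples) into the CoCa collator, which derives `target_ids` by shifting the text tokens one position to the left. Below is a minimal sketch of that data path outside the trainer; it mirrors the `train_dataset`/`collate_fn` sections of the config and the collator added further down in this diff, and it assumes that `DatasetBatch` exposes `samples`/`targets` dictionaries the way `collator.py` constructs them.

```python
from modalities.dataloader.dataset import DummyDataset, DummySampleConfig, DummySampleDataType
from modalities.models.coca.collator import CoCaCollatorFn

# Mirrors train_dataset and collate_fn from config_example_coca.yaml (shapes are illustrative).
dataset = DummyDataset(
    num_samples=4,
    sample_definition=[
        DummySampleConfig(sample_key="images", sample_shape=(3, 224, 224), sample_type=DummySampleDataType.FLOAT),
        DummySampleConfig(sample_key="input_ids", sample_shape=(1024,), sample_type=DummySampleDataType.INT),
    ],
)
collate_fn = CoCaCollatorFn(
    sample_keys=["images", "input_ids"],
    target_keys=[],
    text_sample_key="input_ids",
    text_target_key="target_ids",
)

batch = collate_fn([dataset[i] for i in range(3)])  # local_train_micro_batch_size: 3
assert batch.samples["images"].shape == (3, 3, 224, 224)
assert batch.samples["input_ids"].shape == (3, 1023)   # last token dropped from the inputs
assert batch.targets["target_ids"].shape == (3, 1023)  # inputs shifted left by one token
```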
diff --git a/pyproject.toml b/pyproject.toml index 00b31e63..e9121daa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ dependencies = [ "xformers", "class_resolver", "wandb", - "flash-attn" # install this directly via `pip install flash-attn --no-build-isolation` - + "einops>=0.7.0", + "flash-attn", # install this directly via `pip install flash-attn --no-build-isolation` ] [project.optional-dependencies] @@ -73,10 +73,10 @@ exclude_also = [ # Don't complain about abstract methods, they aren't run: "@(abc\\.)?abstractmethod", - ] +] ignore_errors = true [tool.coverage.html] -directory = "coverage_html_report" \ No newline at end of file +directory = "coverage_html_report" diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index ef0ae2ad..95959d57 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -1,10 +1,12 @@ from __future__ import annotations +from enum import Enum from pathlib import Path -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import jq import numpy as np +from pydantic import BaseModel from torch.utils.data.dataset import Dataset as TorchdataSet from tqdm import tqdm from transformers import BatchEncoding, PreTrainedTokenizer @@ -24,6 +26,52 @@ def _check_if_inbounds(self, idx: int): raise IndexError +class DummySampleDataType(str, Enum): + FLOAT = "float" + INT = "int" + + +class DummySampleConfig(BaseModel): + sample_key: str + sample_shape: Tuple[int, ...] + sample_type: DummySampleDataType + + +class DummyDataset(Dataset): + def __init__(self, num_samples: int, sample_definition: Tuple[DummySampleConfig]): + """ + :param num_samples: Number of samples the dataset should generate. + :param sample_definition: A list of DummySampleConfig entries defining the dataset output. + Each entry contains the sample key, shape and data type.
+ """ + super().__init__(raw_data_path=None, block_size=None, sample_key=None) + self.num_samples = num_samples + self.sample_definition = sample_definition + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> Dict: + return self._create_random_sample() + + def _create_random_sample(self): + sample = dict() + for s in self.sample_definition: + if s.sample_type == DummySampleDataType.FLOAT: + data = np.random.randn(*s.sample_shape) + elif s.sample_type == DummySampleDataType.INT: + data = np.random.randint(low=0, high=512, size=s.sample_shape) + else: + raise NotImplementedError(f"DummyDataset does not support type { s.sample_type}") + sample[s.sample_key] = data + return sample + + class MemMapDataset(Dataset): def __init__( self, diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py index 157e98d0..972c09c6 100644 --- a/src/modalities/dataloader/dataset_factory.py +++ b/src/modalities/dataloader/dataset_factory.py @@ -1,11 +1,17 @@ from pathlib import Path -from typing import Optional +from typing import Optional, Tuple from pydantic import FilePath from torch.utils.data.dataset import Dataset from transformers import PreTrainedTokenizer -from modalities.dataloader.dataset import MemMapDataset, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron +from modalities.dataloader.dataset import ( + DummyDataset, + DummySampleConfig, + MemMapDataset, + PackedMemMapDatasetContinuous, + PackedMemMapDatasetMegatron, +) from modalities.dataloader.open_gptx_dataset.open_gptx_dataset import OpenGPTXMMapDataset @@ -26,6 +32,11 @@ def __getitem__(self, idx: int): class DatasetFactory: + @staticmethod + def get_dummy_dataset(num_samples: int, sample_definition: Tuple[DummySampleConfig]) -> DummyDataset: + dataset = DummyDataset(num_samples=num_samples, sample_definition=sample_definition) + return dataset + @staticmethod def get_mem_map_dataset( raw_data_path: Path, diff --git a/src/modalities/models/coca/__init__.py b/src/modalities/models/coca/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modalities/models/coca/attention_pooling.py b/src/modalities/models/coca/attention_pooling.py new file mode 100644 index 00000000..0a6c1039 --- /dev/null +++ b/src/modalities/models/coca/attention_pooling.py @@ -0,0 +1,23 @@ +import torch +from torch import nn + +from modalities.nn.attention import AttentionConfig, AttentionType, MultiHeadAttention + + +class AttentionPooling(nn.Module): + def __init__(self, n_embd: int, n_head: int, bias: bool, epsilon: float, attention_config: AttentionConfig = None): + super().__init__() + self.ln_1 = nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon) + self.attn = MultiHeadAttention( + n_embd=n_embd, + n_head=n_head, + attention_config=attention_config, + attention_type=AttentionType.CROSS_ATTENTION, + ) + self.ln_2 = nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon) + + def forward(self, queries: torch.Tensor, context: torch.Tensor) -> torch.Tensor: + x = self.ln_1(context) + x = self.attn(queries, context=x) + x = self.ln_2(x) + return x diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py new file mode 100644 index 00000000..b531cf21 --- /dev/null +++ b/src/modalities/models/coca/coca_model.py @@ -0,0 +1,178 @@ +import math +from functools import partial +from typing import Annotated, Dict, Tuple + +import torch +from einops import repeat +from pydantic import BaseModel, Field +from 
torch import nn + +from modalities.models.coca.attention_pooling import AttentionPooling +from modalities.models.coca.multi_modal_decoder import MultiModalTextDecoder +from modalities.models.coca.text_decoder import TextDecoder +from modalities.models.gpt2.gpt2_model import ActivationType, WeightInitializationConfig +from modalities.models.model import NNModel +from modalities.models.vision_transformer.vision_transformer_model import VisionTransformer, VisionTransformerConfig +from modalities.nn.attention import AttentionConfig + + +class TextDecoderConfig(BaseModel): + sample_key: str + prediction_key: str + block_size: Annotated[int, Field(ge=1)] + vocab_size: Annotated[int, Field(ge=1)] + n_layer_text: Annotated[int, Field(ge=1)] + n_layer_multimodal_text: Annotated[int, Field(ge=1)] + n_head: Annotated[int, Field(ge=1)] + n_embd: Annotated[int, Field(ge=1)] + ffn_hidden: Annotated[int, Field(ge=1)] + dropout: Annotated[float, Field(ge=0.0)] + bias: bool + attention_config: AttentionConfig + activation: ActivationType + epsilon: Annotated[float, Field(ge=0.0)] + + +class CoCaConfig(BaseModel): + prediction_key: str = "logits" + vision_embd_prediction_key: str # same key as vision encoder + text_embd_prediction_key: str + vision_cls_prediction_key: str + text_cls_prediction_key: str + vision_encoder_config: VisionTransformerConfig + text_decoder_config: TextDecoderConfig + n_pool_head: Annotated[int, Field(ge=1)] + n_vision_queries: Annotated[int, Field(ge=1)] + bias_attn_pool: bool + epsilon_attn_pool: Annotated[float, Field(ge=0.0)] + weight_init: WeightInitializationConfig + + +class CoCa(NNModel): + """CoCa + + The Contrastive Captioner (CoCa) is an encoder-decoder model that integrates the concepts of CLIP + and generative models such as SimVLM by using contrastive and captioning losses for training. 
+ + Paper: `CoCa: Contrastive Captioners are Image-Text Foundation Models` + Link: https://arxiv.org/abs/2205.01917 + """ + + def __init__( + self, + prediction_key: str, + vision_cls_prediction_key: str, + text_cls_prediction_key: str, + vision_embd_prediction_key: str, + text_embd_prediction_key: str, + n_vision_queries: int, + n_pool_head: int, + bias_attn_pool: bool, + epsilon_attn_pool: float, + vision_encoder_config: VisionTransformerConfig, + text_decoder_config: TextDecoderConfig, + weight_init: WeightInitializationConfig, + ) -> None: + super().__init__() + self.prediction_key = prediction_key + self.vision_cls_prediction_key = vision_cls_prediction_key + self.text_cls_prediction_key = text_cls_prediction_key + self.vision_embd_prediction_key = vision_embd_prediction_key + self.text_embd_prediction_key = text_embd_prediction_key + + self.vision_encoder = VisionTransformer(**dict(vision_encoder_config)) + self.text_decoder = TextDecoder( + sample_key=text_decoder_config.sample_key, + prediction_key=text_embd_prediction_key, + block_size=text_decoder_config.block_size + 1, # +1 for the class token + n_layer=text_decoder_config.n_layer_text, + vocab_size=text_decoder_config.vocab_size, + n_head=text_decoder_config.n_head, + n_embd=text_decoder_config.n_embd, + ffn_hidden=text_decoder_config.ffn_hidden, + dropout=text_decoder_config.dropout, + bias=text_decoder_config.bias, + attention_config=text_decoder_config.attention_config, + activation=text_decoder_config.activation, + epsilon=text_decoder_config.epsilon, + ) + self.multimodal_decoder = MultiModalTextDecoder( + sample_key=text_embd_prediction_key, + prediction_key=text_decoder_config.prediction_key, + block_size=text_decoder_config.block_size, + n_layer=text_decoder_config.n_layer_multimodal_text, + vocab_size=text_decoder_config.vocab_size, + n_head=text_decoder_config.n_head, + n_embd=text_decoder_config.n_embd, + ffn_hidden=text_decoder_config.ffn_hidden, + dropout=text_decoder_config.dropout, + bias=text_decoder_config.bias, + attention_config=text_decoder_config.attention_config, + activation=text_decoder_config.activation, + epsilon=text_decoder_config.epsilon, + ) + + self.text_decoder.transformer.wte.weight = ( + self.multimodal_decoder.lm_head.weight + ) # https://paperswithcode.com/method/weight-tying + + # vision_queries: 256 queries for multimodal cross attention and 1 as vision cls token for contrastive learning + self.vision_queries = nn.Parameter(torch.randn(n_vision_queries + 1, vision_encoder_config.n_embd)) + self.attn_pool = AttentionPooling( + n_embd=vision_encoder_config.n_embd, + n_head=n_pool_head, + bias=bias_attn_pool, + epsilon=epsilon_attn_pool, + attention_config=text_decoder_config.attention_config, + ) + + # init all weights + self.apply(partial(self._init_weights, weight_init=weight_init)) + # apply special scaled init to the residual projections, per GPT-2 paper + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, + mean=weight_init.mean, + std=weight_init.std + / math.sqrt(2 * (text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text)), + ) + + def _init_weights(self, module: nn.Module, weight_init: WeightInitializationConfig): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=weight_init.mean, std=weight_init.std) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=weight_init.mean, 
std=weight_init.std) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + vision_embd, vision_cls_token = self._forward_encode_vision(inputs) + text_embd, text_cls_token = self._forward_encode_text(inputs) + logits = self._forward_decode(text_embd, vision_embd) + return { + self.prediction_key: logits, + self.vision_cls_prediction_key: vision_cls_token, + self.text_cls_prediction_key: text_cls_token, + } + + def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + vision_embd = self.vision_encoder(inputs)[self.vision_embd_prediction_key] + queries = repeat(self.vision_queries, "n d -> b n d", b=vision_embd.shape[0]) + vision_embd = self.attn_pool(queries, context=vision_embd) + vision_embd, vision_cls_token = vision_embd[:, :-1, :], vision_embd[:, -1:, :] + return vision_embd, vision_cls_token + + def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + text_embd = self.text_decoder(inputs)[self.text_embd_prediction_key] + text_embd, text_cls_token = text_embd[:, :-1, :], text_embd[:, -1:, :] + return text_embd, text_cls_token + + def _forward_decode(self, text_embd: torch.Tensor, vision_embd: torch.Tensor) -> torch.Tensor: + decoder_inputs = { + self.text_embd_prediction_key: text_embd, + "context": vision_embd, + } + decoder_outputs = self.multimodal_decoder(decoder_inputs) + logits = decoder_outputs[self.multimodal_decoder.prediction_key] + return logits diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py new file mode 100644 index 00000000..0c9584ca --- /dev/null +++ b/src/modalities/models/coca/collator.py @@ -0,0 +1,43 @@ +from typing import Dict, List + +import torch +from pydantic import BaseModel + +from modalities.batch import DatasetBatch +from modalities.models.gpt2.collator import CollateFnIF + + +class CoCaCollateFnConfig(BaseModel): + sample_keys: List[str] + target_keys: List[str] + text_sample_key: str + text_target_key: str + + +class CoCaCollatorFn(CollateFnIF): + def __init__(self, sample_keys: List[str], target_keys: List[str], text_sample_key: str, text_target_key: str): + self.device: torch.device = torch.device("cpu") + if text_sample_key not in sample_keys: + raise ValueError(f"{text_sample_key} is not part of sample keys {sample_keys}") + if text_target_key in target_keys: + raise ValueError( + f"{text_target_key} should not be part of target keys {target_keys}, " + f"because {text_target_key} will be generated based on {text_sample_key}" + ) + self.sample_keys = sample_keys + self.target_keys = target_keys + self.text_sample_key = text_sample_key + self.text_target_key = text_target_key + + def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + samples = { + sample_key: torch.stack([torch.tensor(d[sample_key]) for d in batch]) for sample_key in self.sample_keys + } + targets = { + target_key: torch.stack([torch.tensor(d[target_key]) for d in batch]) for target_key in self.target_keys + } + + # Create the text target by shifting the text sample one token to the left (next-token prediction) + targets[self.text_target_key] = samples[self.text_sample_key][:, 1:].clone().detach() + samples[self.text_sample_key] = samples[self.text_sample_key][:, :-1].clone().detach() + return DatasetBatch(targets=targets, samples=samples) diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py new file mode 100644 index 00000000..cced19b4 --- 
/dev/null +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -0,0 +1,124 @@ +from functools import partial +from typing import Dict + +import torch +import xformers.ops as xops +from torch import nn + +from modalities.models.gpt2.gpt2_model import ActivationType +from modalities.models.model import NNModel +from modalities.nn.attention import AttentionConfig, AttentionType, MultiHeadAttention +from modalities.nn.mlp import MLP + + +class TransformerBlock(nn.Module): + def __init__( + self, + n_embd: int, + bias: bool, + epsilon: float, + activation: ActivationType, + n_head: int, + dropout: float, + ffn_hidden: int, + with_context: bool, + attention_type: AttentionType, + attention_config: AttentionConfig = None, + add_extra_mlp: bool = False, + ): + super().__init__() + self.with_context = with_context + self.add_extra_mlp = add_extra_mlp + + if activation == ActivationType.GELU: + mlp = partial(MLP, in_features=n_embd, hidden_features=ffn_hidden, bias=bias, dropout=dropout) + elif activation == ActivationType.FUSED_SWIGLU: + mlp = partial(xops.SwiGLU, in_features=n_embd, hidden_features=ffn_hidden, bias=bias) + else: + raise NotImplementedError(f"activation type {activation} not implemented") + + self.ln_1 = nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon) + self.attn = MultiHeadAttention( + n_embd=n_embd, n_head=n_head, bias=bias, attention_config=attention_config, attention_type=attention_type + ) + + if not self.with_context or self.add_extra_mlp: + self.ln_2 = nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon) + self.mlp = mlp() + + if self.with_context: + self.ln_3 = nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon) + self.cross_attn = MultiHeadAttention( + n_embd=n_embd, + n_head=n_head, + bias=bias, + attention_config=attention_config, + attention_type=AttentionType.CROSS_ATTENTION, + ) + self.ln_4 = nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon) + self.mlp_2 = mlp() + + def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + x = x + self.attn(self.ln_1(x)) + if not self.with_context or self.add_extra_mlp: + x = x + self.mlp(self.ln_2(x)) + if self.with_context: + x = x + self.cross_attn(self.ln_3(x), context=context) + x = x + self.mlp_2(self.ln_4(x)) + return x + + +class MultiModalTextDecoder(NNModel): + def __init__( + self, + sample_key: str, + prediction_key: str, + block_size: int, + vocab_size: int, + n_layer: int, + n_head: int, + n_embd: int, + ffn_hidden: int, + dropout: float, + bias: bool, + activation: ActivationType, + epsilon: float, + attention_config: AttentionConfig, + ): + super().__init__() + self.sample_key = sample_key + self.prediction_key = prediction_key + self.block_size = block_size + + self.transformer = nn.ModuleDict( + dict( + h=nn.ModuleList( + [ + TransformerBlock( + n_embd=n_embd, + bias=bias, + epsilon=epsilon, + activation=activation, + n_head=n_head, + dropout=dropout, + ffn_hidden=ffn_hidden, + with_context=True, + attention_type=AttentionType.CAUSAL_SELF_ATTENTION, + attention_config=attention_config, + add_extra_mlp=False, + ) + for _ in range(n_layer) + ] + ), + ln_f=nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon), + ) + ) + self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size, bias=False) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + x = inputs[self.sample_key] + for block in self.transformer.h: + x = block(x, context=inputs["context"]) + x = self.transformer.ln_f(x) + logits = 
self.lm_head(x) + return {self.prediction_key: logits} diff --git a/src/modalities/models/coca/text_decoder.py b/src/modalities/models/coca/text_decoder.py new file mode 100644 index 00000000..aab7c8f8 --- /dev/null +++ b/src/modalities/models/coca/text_decoder.py @@ -0,0 +1,73 @@ +from typing import Dict + +import torch +from torch import nn + +from modalities.models.coca.multi_modal_decoder import TransformerBlock +from modalities.models.gpt2.gpt2_model import ActivationType +from modalities.models.model import NNModel +from modalities.nn.attention import AttentionConfig, AttentionType + + +class TextDecoder(NNModel): + def __init__( + self, + sample_key: str, + prediction_key: str, + block_size: int, + vocab_size: int, + n_layer: int, + n_head: int, + n_embd: int, + ffn_hidden: int, + dropout: float, + bias: bool, + activation: ActivationType, + epsilon: float, + attention_config: AttentionConfig = None, + ): + super().__init__() + self.sample_key = sample_key + self.prediction_key = prediction_key + self.block_size = block_size + + self.cls_token = nn.Parameter(torch.empty(1, 1, n_embd)) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_embd), + wpe=nn.Embedding(num_embeddings=block_size, embedding_dim=n_embd), + drop=nn.Dropout(dropout), + h=nn.ModuleList( + [ + TransformerBlock( + n_embd=n_embd, + bias=bias, + epsilon=epsilon, + activation=activation, + n_head=n_head, + dropout=dropout, + ffn_hidden=ffn_hidden, + with_context=False, + attention_type=AttentionType.CAUSAL_SELF_ATTENTION, + attention_config=attention_config, + ) + for _ in range(n_layer) + ] + ), + ) + ) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + input_ids = inputs[self.sample_key] + device = input_ids.device + B, T = input_ids.size() + assert T <= self.block_size, f"Cannot forward sequence of length {T}, block size is only {self.block_size}" + pos = torch.arange(0, T + 1, dtype=torch.long, device=device) + + tok_emb = self.transformer.wte(input_ids) + tok_emb = torch.cat([tok_emb, self.cls_token.repeat(B, 1, 1)], dim=1) + pos_emb = self.transformer.wpe(pos) + x = self.transformer.drop(tok_emb + pos_emb) + for block in self.transformer.h: + x = block(x) + return {self.prediction_key: x} diff --git a/src/modalities/models/vision_transformer/__init__.py b/src/modalities/models/vision_transformer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modalities/models/vision_transformer/vision_transformer_model.py b/src/modalities/models/vision_transformer/vision_transformer_model.py new file mode 100644 index 00000000..78cbbb7a --- /dev/null +++ b/src/modalities/models/vision_transformer/vision_transformer_model.py @@ -0,0 +1,171 @@ +from math import floor +from typing import Annotated, Dict, Optional, Tuple, Union + +import torch +from einops.layers.torch import Rearrange +from pydantic import BaseModel, Field +from torch import nn + +from modalities.nn.attention import AttentionConfig, AttentionType, MultiHeadAttention +from modalities.nn.mlp import MLP + + +class VisionTransformerConfig(BaseModel): + sample_key: str + prediction_key: str + img_size: Annotated[Union[Tuple[int, int], int], Field(ge=1)] = 224 + n_classes: Optional[Annotated[int, Field(ge=1)]] = 1000 + n_layer: Annotated[int, Field(ge=1)] = 12 + attention_config: AttentionConfig = None + n_head: Annotated[int, Field(ge=1)] = 8 + n_embd: Annotated[int, Field(ge=1)] = 768 + dropout: Annotated[float, Field(ge=0.0)] = 0.0 + patch_size: 
Annotated[int, Field(ge=1)] = 16 + patch_stride: Annotated[int, Field(ge=1)] = 16 + n_img_channels: Annotated[int, Field(ge=1)] = 3 + add_cls_token: bool = True + bias: bool = True + + +class ImagePatchEmbedding(nn.Module): + def __init__( + self, + n_img_channels: int = 3, + n_embd: int = 768, + patch_size: int = 16, + patch_stride: int = 16, + add_cls_token: bool = True, + ) -> None: + super().__init__() + self.conv = nn.Conv2d( + in_channels=n_img_channels, out_channels=n_embd, kernel_size=patch_size, stride=patch_stride + ) + + # Define a rearrangement operation to reshape the tensor from + # batched 4D format (batch_size, channels, height, width) to + # batched 3D format (batch_size, height*width, channels). + # This is required to support torch.compile. + # See https://github.com/arogozhnikov/einops/wiki/Using-torch.compile-with-einops + self.rearrange = Rearrange("b c h w -> b (h w) c") + + self.cls_token = None + if add_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, n_embd)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B = x.shape[0] + x = self.conv(x) + x = self.rearrange(x) + if self.cls_token is not None: + x = torch.cat([self.cls_token.repeat(B, 1, 1), x], dim=1) + return x + + +class VisionTransformerBlock(nn.Module): + def __init__( + self, + n_embd: int = 768, + n_head: int = 8, + ffn_hidden: int = 3072, + bias: bool = True, + dropout: float = 0.0, + attention_config: AttentionConfig = None, + ) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(n_embd) + self.attention = MultiHeadAttention( + n_embd=n_embd, + n_head=n_head, + attention_config=attention_config, + attention_type=AttentionType.NON_CAUSAL_SELF_ATTENTION, + ) + self.norm2 = nn.LayerNorm(n_embd) + self.mlp = MLP(in_features=n_embd, hidden_features=ffn_hidden, bias=bias, dropout=dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attention(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class VisionTransformer(nn.Module): + """ViT + + The Vision Transformer (ViT) is a pure transformer architecture + that applies attention mechanisms directly to sequences of image patches for image classification tasks. 
+ + Paper: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + Link: https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, + sample_key: str, + prediction_key: str, + img_size: Union[Tuple[int, int], int] = 224, + n_classes: int = 1000, + n_layer: int = 12, + attention_config: AttentionConfig = None, + n_head: int = 8, + n_embd: int = 768, + ffn_hidden: int = 3072, + dropout: float = 0.0, + patch_size: int = 16, + patch_stride: int = 16, + n_img_channels: int = 3, + add_cls_token: bool = True, + bias: bool = True, + ) -> None: + super().__init__() + self.sample_key = sample_key + self.prediction_key = prediction_key + self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) + self.block_size = self._calculate_block_size(self.img_size, patch_size, patch_stride, add_cls_token) + + self.embedding_fn = ImagePatchEmbedding(n_img_channels, n_embd, patch_size, patch_stride, add_cls_token) + self.positional_embedding_fn = nn.Embedding(num_embeddings=self.block_size, embedding_dim=n_embd) + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList( + [ + VisionTransformerBlock( + n_embd=n_embd, + n_head=n_head, + ffn_hidden=ffn_hidden, + bias=bias, + dropout=dropout, + attention_config=attention_config, + ) + for _ in range(n_layer) + ] + ) + + self.head = None + if n_classes is not None: + self.norm = nn.LayerNorm(n_embd) + self.head = nn.Linear(in_features=n_embd, out_features=n_classes, bias=bias) + + def forward_images(self, x: torch.Tensor) -> torch.Tensor: + x = self.embedding_fn(x) + x = self.dropout(x + self.positional_embedding_fn.weight) + for block in self.blocks: + x = block(x) + return x + + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + x = inputs[self.sample_key] + x = self.forward_images(x) + if self.head: + if self.embedding_fn.cls_token is not None: + x = x[:, 0] + else: + x = x.mean(dim=1) + x = self.head(self.norm(x)) + return {self.prediction_key: x} + + @staticmethod + def _calculate_block_size(img_size: Tuple[int, int], patch_size: int, patch_stride: int, add_cls_token: bool): + # See https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html for details + block_size = (floor((img_size[0] - patch_size) / patch_stride) + 1) * ( + floor((img_size[1] - patch_size) / patch_stride) + 1 + ) + int(add_cls_token) + return block_size diff --git a/src/modalities/nn/__init__.py b/src/modalities/nn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modalities/nn/attention.py b/src/modalities/nn/attention.py new file mode 100644 index 00000000..dd8b5db5 --- /dev/null +++ b/src/modalities/nn/attention.py @@ -0,0 +1,98 @@ +import math +from enum import Enum +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from pydantic import BaseModel +from torch import Tensor, nn + + +class AttentionEngineType(str, Enum): + DEFAULT_ATTENTION = "default_attention" + PYTORCH_FLASH_ATTENTION = "pytorch_flash_attention" + + +class AttentionType(str, Enum): + CAUSAL_SELF_ATTENTION = "causal_self_attention" + NON_CAUSAL_SELF_ATTENTION = "non_causal_self_attention" + CROSS_ATTENTION = "cross_attention" + + +class AttentionConfig(BaseModel): + attention_engine_type: AttentionEngineType + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + attention_config: AttentionConfig = None, + attention_type: AttentionType = AttentionType.CAUSAL_SELF_ATTENTION, + n_embd: int = 768, + n_head: int = 8, + bias: bool = True, + dropout: float = 0.0, 
+ block_size: int = 1024, + ): + super().__init__() + if n_embd % n_head != 0: + raise ValueError("n_embd needs to be divisible by n_head") + if attention_config is None: + attention_config = AttentionConfig(attention_engine_type=AttentionEngineType.DEFAULT_ATTENTION) + self.n_head = n_head + self.n_embd = n_embd + self.dropout = dropout + self.use_flash = attention_config.attention_engine_type == AttentionEngineType.PYTORCH_FLASH_ATTENTION + self.is_causal = attention_type == AttentionType.CAUSAL_SELF_ATTENTION + self.use_cross_attention = attention_type == AttentionType.CROSS_ATTENTION + + self.wq = nn.Linear(in_features=n_embd, out_features=n_embd, bias=bias) + self.wk = nn.Linear(in_features=n_embd, out_features=n_embd, bias=bias) + self.wv = nn.Linear(in_features=n_embd, out_features=n_embd, bias=bias) + self.c_proj = nn.Linear(in_features=n_embd, out_features=n_embd, bias=bias) + + if not self.use_flash: + self.attn_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + self.register_buffer( + "bias", + torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size), + ) + self.resid_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + + def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor: + context = context if self.use_cross_attention else x + B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd) + q, k, v = self._forward_input_projection(x, context=context) + if self.use_flash: + y = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=None, + dropout_p=self.dropout if self.training else 0, + is_causal=self.is_causal, + ) + else: + y = self._forward_attention(query=q, key=k, value=v) + y = y.transpose(1, 2).contiguous().view(B, T, C) + y = self.resid_dropout(self.c_proj(y)) + return y + + def _forward_input_projection(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd) + _, Tc, Cc = context.shape # batch size, context length, context embedding dimensionality + # Note that the context length (Tc), sequence length (T) and embedding dimensionalities (C and Cc) + # are the same for self-attention and can only differ for cross-attention + q = self.wq(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + k = self.wk(context).view(B, Tc, self.n_head, Cc // self.n_head).transpose(1, 2) + v = self.wv(context).view(B, Tc, self.n_head, Cc // self.n_head).transpose(1, 2) + return q, k, v + + def _forward_attention(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor: + att = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1))) + if self.is_causal: + T = query.size(2) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + return att @ value diff --git a/src/modalities/nn/mlp.py b/src/modalities/nn/mlp.py new file mode 100644 index 00000000..4ef55d56 --- /dev/null +++ b/src/modalities/nn/mlp.py @@ -0,0 +1,31 @@ +from typing import Callable, Optional + +from torch import Tensor, nn + + +class MLP(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + bias: bool = True, + dropout: float = 0.0, + act_fn: Callable[[], nn.Module] = nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or 4 * in_features + self.fc1 = nn.Linear(in_features, 
hidden_features, bias=bias) + self.act = act_fn() + self.drop1 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop2 = nn.Dropout(dropout) if dropout > 0 else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 28fa6372..6628762d 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -45,12 +45,15 @@ WandBEvaluationResultSubscriberConfig, ) from modalities.dataloader.dataloader_factory import DataloaderFactory +from modalities.dataloader.dataset import DummyDatasetConfig from modalities.dataloader.dataset_factory import DatasetFactory from modalities.logging_broker.subscriber_impl.subscriber_factory import ( ProgressSubscriberFactory, ResultsSubscriberFactory, ) from modalities.loss_functions import CLMCrossEntropyLoss +from modalities.models.coca.coca_model import CoCa, CoCaConfig +from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig from modalities.models.gpt2.collator import GPT2LLMCollateFn from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2LLMConfig @@ -79,6 +82,7 @@ class ComponentEntity: ), ComponentEntity("model", "checkpointed", ModelFactory.get_checkpointed_model, CheckpointedModelConfig), ComponentEntity("model", "fsdp_wrapped", ModelFactory.get_fsdp_wrapped_model, FSDPWrappedModelConfig), + ComponentEntity("model", "coca", CoCa, CoCaConfig), # losses ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), # optmizers @@ -115,12 +119,14 @@ class ComponentEntity: ComponentEntity( "dataset", "open_gptx_mmap_dataset", DatasetFactory.get_open_gptx_mmap_dataset, OpenGPTXMMapDatasetConfig ), + ComponentEntity("dataset", "dummy_dataset", DatasetFactory.get_dummy_dataset, DummyDatasetConfig), # samplers ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig), # batch samplers ComponentEntity("batch_sampler", "default", BatchSampler, BatchSamplerConfig), # collators ComponentEntity("collate_fn", "gpt_2_llm_collator", GPT2LLMCollateFn, GPT2LLMCollateFnConfig), + ComponentEntity("collate_fn", "coca_collator", CoCaCollatorFn, CoCaCollateFnConfig), # data loaders ComponentEntity("data_loader", "default", DataloaderFactory.get_dataloader, LLMDataLoaderConfig), # ComponentEntity("data_loader", "repeating_data_loader",(RepeatingDataLoader, None), # TODO diff --git a/tests/dataloader/test_dummy_dataset.py b/tests/dataloader/test_dummy_dataset.py new file mode 100644 index 00000000..7c103699 --- /dev/null +++ b/tests/dataloader/test_dummy_dataset.py @@ -0,0 +1,21 @@ +import numpy as np + +from modalities.dataloader.dataset import DummyDataset, DummySampleConfig, DummySampleDataType + + +def test_dummy_dataset(): + dataset = DummyDataset( + num_samples=50, + sample_definition=[ + DummySampleConfig(sample_key="input_ids", sample_shape=(512,), sample_type=DummySampleDataType.INT), + DummySampleConfig(sample_key="images", sample_shape=(3, 224, 224), sample_type=DummySampleDataType.FLOAT), + ], + ) + assert len(dataset) == 50 + sample = next(iter(dataset)) + assert "input_ids" in sample + assert sample["input_ids"].shape == (512,) + assert 
sample["input_ids"].dtype == np.int64 + assert "images" in sample + assert sample["images"].shape == (3, 224, 224) + assert sample["images"].dtype == np.float64 diff --git a/tests/models/coca/__init__.py b/tests/models/coca/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/coca/coca_config.yaml b/tests/models/coca/coca_config.yaml new file mode 100644 index 00000000..952cda66 --- /dev/null +++ b/tests/models/coca/coca_config.yaml @@ -0,0 +1,44 @@ +prediction_key: logits +vision_embd_prediction_key: vision_embeddings +text_embd_prediction_key: text_embeddings +vision_cls_prediction_key: vision_cls +text_cls_prediction_key: text_cls +vision_encoder_config: + sample_key: images + prediction_key: vision_embeddings + img_size: 224 + n_classes: Null # Disable vision transformer head + n_layer: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True +text_decoder_config: + sample_key: input_ids + prediction_key: text_embeddings + block_size: 1024 + vocab_size: 50304 + n_layer_text: 6 + n_layer_multimodal_text: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: fused_swiglu + epsilon: 1e-5 +n_pool_head: 8 +n_vision_queries: 256 +bias_attn_pool: False +epsilon_attn_pool: 1e-5 +weight_init: + mean: 0.0 + std: 0.02 diff --git a/tests/models/coca/test_attention_pooling.py b/tests/models/coca/test_attention_pooling.py new file mode 100644 index 00000000..f781d3b7 --- /dev/null +++ b/tests/models/coca/test_attention_pooling.py @@ -0,0 +1,11 @@ +import torch + +from modalities.models.coca.attention_pooling import AttentionPooling + + +def test_attention_pooling_forward(): + model = AttentionPooling(n_embd=768, n_head=8, bias=False, epsilon=1e-5) + dummy_input = torch.randn(1, 256, 768) + dummy_queries = torch.randn(1, 257, 768) + out = model(dummy_queries, dummy_input) + assert out.shape == (1, 257, 768) diff --git a/tests/models/coca/test_coca.py b/tests/models/coca/test_coca.py new file mode 100644 index 00000000..9fa12c6d --- /dev/null +++ b/tests/models/coca/test_coca.py @@ -0,0 +1,60 @@ +from pathlib import Path + +import pytest +import torch + +from modalities.__main__ import Main, load_app_config_dict +from modalities.models.coca.coca_model import CoCa, CoCaConfig +from tests.conftest import _ROOT_DIR + + +def test_coca(): + # Create model + config_file_path = _ROOT_DIR / Path("tests/models/coca/coca_config.yaml") + config_dict = load_app_config_dict(config_file_path=config_file_path) + coca_config = CoCaConfig.model_validate(config_dict) + model = CoCa(**dict(coca_config)) + + # Create dummy inputs + dummy_input_image = torch.randn(1, 3, 224, 224) + dummy_input_text = torch.randint( + 0, coca_config.text_decoder_config.vocab_size, (1, coca_config.text_decoder_config.block_size) + ) + dummy_input = dict(images=dummy_input_image, input_ids=dummy_input_text) + + # Create optimizer + optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) + + # Run one training step + optimizer.zero_grad() + out = model(dummy_input) + loss = out["logits"].sum() + loss.backward() + optimizer.step() + + # Test outputs + assert "logits" in out + assert "vision_cls" in out + assert "text_cls" in out + assert out["logits"].shape == (1, 1024, 50304) + assert out["vision_cls"].shape == (1, 1, 768) + assert 
out["text_cls"].shape == (1, 1, 768) + + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="This e2e test requires 1 GPU.") +def test_e2e_coca_training_run_without_checkpoint(monkeypatch): + monkeypatch.setenv("RANK", "0") + monkeypatch.setenv("LOCAL_RANK", "0") + monkeypatch.setenv("WORLD_SIZE", "1") + monkeypatch.setenv("MASTER_ADDR", "localhost") + monkeypatch.setenv("MASTER_PORT", "9948") + + # Load config + dummy_config_path = _ROOT_DIR / Path("config_files/config_example_coca.yaml") + config_dict = load_app_config_dict(dummy_config_path) + + # Disable checkpointing + config_dict["checkpointing"]["config"]["checkpointing_strategy"]["config"]["k"] = 0 + + main = Main(config_dict, dummy_config_path) + main.run() diff --git a/tests/models/vision_transformer/test_vision_transformer.py b/tests/models/vision_transformer/test_vision_transformer.py new file mode 100644 index 00000000..24b03921 --- /dev/null +++ b/tests/models/vision_transformer/test_vision_transformer.py @@ -0,0 +1,52 @@ +from pathlib import Path + +import pytest +import torch + +from modalities.__main__ import load_app_config_dict +from modalities.models.vision_transformer.vision_transformer_model import VisionTransformer, VisionTransformerConfig +from tests.conftest import _ROOT_DIR + + +def test_vision_transformer(): + # Create model + config_file_path = _ROOT_DIR / Path("tests/models/vision_transformer/vision_transformer_config.yaml") + config_dict = load_app_config_dict(config_file_path=config_file_path) + config = VisionTransformerConfig.model_validate(config_dict) + model = VisionTransformer(**dict(config)) + + # Create dummy inputs + dummy_input_image = torch.randn(1, 3, 224, 224) + dummy_input = dict(images=dummy_input_image) + + # Create optimizer + optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) + + # Run one training step + optimizer.zero_grad() + out = model(dummy_input) + loss = out["logits"].sum() + loss.backward() + optimizer.step() + + # Test outputs + assert "logits" in out + assert out["logits"].shape == (1, 1000) + + +@pytest.mark.parametrize( + "img_size,patch_size,patch_stride,add_cls_token,target_block_size", + [ + ((224, 224), 16, 16, True, 197), + ((224, 224), 16, 16, False, 196), + ((224, 112), 16, 16, False, 98), + ((480, 480), 16, 16, False, 900), + ((480 + 1, 480 + 1), 16, 16, False, 900), + ((224, 224), 8, 16, True, 197), + ((224, 224), 16, 8, True, 730), + ((224, 224), 8, 8, True, 785), + ], +) +def test_vision_transformer_block_size(img_size, patch_size, patch_stride, add_cls_token, target_block_size): + block_size = VisionTransformer._calculate_block_size(img_size, patch_size, patch_stride, add_cls_token) + assert block_size == target_block_size diff --git a/tests/models/vision_transformer/vision_transformer_config.yaml b/tests/models/vision_transformer/vision_transformer_config.yaml new file mode 100644 index 00000000..d6657c5c --- /dev/null +++ b/tests/models/vision_transformer/vision_transformer_config.yaml @@ -0,0 +1,13 @@ +sample_key: images +prediction_key: logits +img_size: 224 +n_classes: 1000 +n_layer: 6 +n_head: 8 +n_embd: 768 +dropout: 0.0 +patch_size: 16 +patch_stride: 16 +n_img_channels: 3 +add_cls_token: True +bias: True diff --git a/tests/nn/test_attention.py b/tests/nn/test_attention.py new file mode 100644 index 00000000..cdd8d314 --- /dev/null +++ b/tests/nn/test_attention.py @@ -0,0 +1,22 @@ +import pytest +import torch + +from modalities.nn.attention import AttentionType, MultiHeadAttention + + +@pytest.mark.parametrize( + "attention_type", 
[AttentionType.CAUSAL_SELF_ATTENTION, AttentionType.NON_CAUSAL_SELF_ATTENTION] +) +def test_attention_forward(attention_type): + model = MultiHeadAttention(n_embd=64, n_head=8, attention_type=attention_type) + dummy_input = torch.randn(1, 256, 64) + out = model(dummy_input) + assert out.shape == (1, 256, 64) + + +def test_attention_with_cross_attention_forward(): + model = MultiHeadAttention(n_embd=64, n_head=8, attention_type=AttentionType.CROSS_ATTENTION) + dummy_input = torch.randn(1, 256, 64) + dummy_context = torch.randn(1, 16, 64) + out = model(dummy_input, context=dummy_context) + assert out.shape == (1, 256, 64) diff --git a/tests/nn/test_mlp.py b/tests/nn/test_mlp.py new file mode 100644 index 00000000..0c94e69b --- /dev/null +++ b/tests/nn/test_mlp.py @@ -0,0 +1,10 @@ +import torch + +from modalities.nn.mlp import MLP + + +def test_mlp_forward(): + model = MLP(in_features=64, hidden_features=256) + dummy_input = torch.randn(1, 10, 64) + out = model(dummy_input) + assert out.shape == (1, 10, 64)
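The unit tests above exercise the attention and MLP building blocks individually. For a picture of how they compose in `CoCa._forward_encode_vision`, the sketch below pools ViT patch tokens with learned queries and splits off the contrastive cls token. It is illustrative only (not part of this PR's test suite): the small `n_layer`, the batch size, and the standalone `vision_queries` parameter are assumptions, while the class names, arguments, and shapes come from the code added in this diff.

```python
import torch
from einops import repeat

from modalities.models.coca.attention_pooling import AttentionPooling
from modalities.models.vision_transformer.vision_transformer_model import VisionTransformer

# Patch encoder without classification head or cls token: 224 / 16 = 14, so 14 * 14 = 196 patch tokens.
encoder = VisionTransformer(
    sample_key="images", prediction_key="vision_embeddings", n_classes=None, n_layer=2, add_cls_token=False
)
attn_pool = AttentionPooling(n_embd=768, n_head=8, bias=False, epsilon=1e-5)
# 256 queries feed the multimodal decoder's cross-attention; one extra query becomes the contrastive cls token.
vision_queries = torch.nn.Parameter(torch.randn(256 + 1, 768))

images = torch.randn(2, 3, 224, 224)
patch_tokens = encoder({"images": images})["vision_embeddings"]                 # (2, 196, 768)
pooled = attn_pool(repeat(vision_queries, "n d -> b n d", b=2), patch_tokens)   # (2, 257, 768)
vision_embd, vision_cls_token = pooled[:, :-1, :], pooled[:, -1:, :]            # (2, 256, 768) and (2, 1, 768)
```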