diff --git a/config_files/training/config_mem_map_mamba.yaml b/config_files/training/config_mem_map_mamba.yaml new file mode 100644 index 00000000..2bdd6c41 --- /dev/null +++ b/config_files/training/config_mem_map_mamba.yaml @@ -0,0 +1,217 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + training: + callback_interval_in_samples: 32768 + global_num_training_samples: 2048 + global_num_seen_samples: 0 + do_apply_activation_checkpointing: false + gradient_acc_steps: 1 + local_train_micro_batch_size: 16 + sequence_length: 4096 + gradient_clipping: + mode: NONE + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_megatron + config: + raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpyjama_v2_default_DE_num_docs_1024/redpyjama_v2_default_DE_num_docs_1024.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_megatron + config: + raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpyjama_v2_default_DE_num_docs_1024/redpyjama_v2_default_DE_num_docs_1024.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + 
checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [ MambaBlock ] + +model: + component_key: model + variant_key: mamba + config: + d_model: 16 + n_layer: 2 + vocab_size: 50257 + rms_norm: True + ssm_cfg: {} + residual_in_fp32: True + fused_add_norm: True + pad_vocab_size_multiple: 8 + tie_embeddings: True + prediction_key: logits + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [ MambaBlock ] + +scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: 64 + pct_start: 0.01 + anneal_strategy: cos + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [ 0.9, 0.95 ] + eps: 1e-8 + weight_decay: 1e-1 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_samples: ${settings.training.global_num_seen_samples} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: ONLINE + experiment_id: ${settings.experiment_id} + directory: "." 
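Note on the config mechanics (illustration, not part of the diff): the nested ${settings...} references and the ${modalities_env:...} / ${cuda_env:...} resolvers above follow OmegaConf-style interpolation. The sketch below shows how such references would resolve, assuming an OmegaConf backend with stubbed resolvers; the real modalities resolvers and component registry are not reproduced here.

from omegaconf import OmegaConf

# Stub resolver for illustration only; the real one reads the torchrun environment variables.
OmegaConf.register_new_resolver("cuda_env", lambda key: {"LOCAL_RANK": 0, "RANK": 0, "WORLD_SIZE": 1}[key])

cfg = OmegaConf.create(
    """
settings:
  training:
    local_train_micro_batch_size: 16
  cuda_env:
    world_size: ${cuda_env:WORLD_SIZE}
train_dataloader:
  config:
    batch_sampler:
      config:
        batch_size: ${settings.training.local_train_micro_batch_size}
"""
)
OmegaConf.resolve(cfg)
assert cfg.train_dataloader.config.batch_sampler.config.batch_size == 16
assert cfg.settings.cuda_env.world_size == 1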
\ No newline at end of file diff --git a/config_files/training/config_mem_map_mamba_overfitting.yaml b/config_files/training/config_mem_map_mamba_overfitting.yaml new file mode 100644 index 00000000..24ea0d1a --- /dev/null +++ b/config_files/training/config_mem_map_mamba_overfitting.yaml @@ -0,0 +1,247 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + training: + global_training_log_interval_in_steps: 10 + global_checkpointing_interval_in_steps: 1000 + global_evaluation_interval_in_steps: 64 + global_num_seen_steps: 0 + do_apply_activation_checkpointing: false + gradient_acc_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 2048 + gradient_clipping: + mode: p2_norm + threshold: 1.0 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: /raid/s3/opengptx/user/richard-rutmann/projects/Modalities/modalities/data/data_overfitting/data_overfitting_en.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +val_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: /raid/s3/opengptx/user/richard-rutmann/projects/Modalities/modalities/data/data_overfitting/data_overfitting_en.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: repeating_data_loader + config: + reshuffle_after_epoch: false + num_epochs: 1 # 100 epochs + dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: val + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + 
+eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + + +checkpointing: + component_key: checkpointing + variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [ MambaBlock ] + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [MambaBlock] + +model: + component_key: model + variant_key: mamba + config: + d_model: 768 + n_layer: 24 + vocab_size: 50257 + rms_norm: true + residual_in_fp32: true + fused_add_norm: true + pad_vocab_size_multiple: 8 + tie_embeddings: true + prediction_key: ${settings.referencing_keys.prediction_key} + sample_key: ${settings.referencing_keys.sample_key} + seed: null + dtype: null + initializer_cfg: {} + num_last_tokens: 0 + inference_params: {} + mixer_model_config: + norm_epsilon: 1e-5 + device: null + mamba_block_config: + d_state: 16 + d_conv: 4 + expand: 2 + dt_rank: auto + dt_min: 0.001 + dt_max: 0.1 + dt_init: random + dt_scale: 1.0 + dt_init_floor: 1e-4 + conv_bias: true + bias: false + use_fast_path: true + + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +scheduler: + component_key: scheduler + variant_key: dummy_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_steps: ${settings.training.global_num_seen_steps} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: ONLINE + experiment_id: ${settings.experiment_id} + directory: "." 
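A quick sanity check on the effective batch size implied by the overfitting config (illustration only, not part of the diff), assuming the conventional definition global batch = local micro batch * gradient accumulation steps * world size:

# Tokens seen per optimizer step for the overfitting config above (assumed formula, not modalities code).
local_train_micro_batch_size = 4
sequence_length = 2048
gradient_acc_steps = 1

def tokens_per_step(world_size: int) -> int:
    return local_train_micro_batch_size * gradient_acc_steps * world_size * sequence_length

print(tokens_per_step(world_size=1))  # 8192 tokens per step on a single GPU
print(tokens_per_step(world_size=8))  # 65536 tokens per step on one 8-GPU node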
\ No newline at end of file diff --git a/config_files/training/config_mem_map_mamba_small_scale.yaml b/config_files/training/config_mem_map_mamba_small_scale.yaml new file mode 100644 index 00000000..80e99357 --- /dev/null +++ b/config_files/training/config_mem_map_mamba_small_scale.yaml @@ -0,0 +1,241 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + training: + global_training_log_interval_in_steps: 10 + global_checkpointing_interval_in_steps: 1000 + global_evaluation_interval_in_steps: 64 + global_num_seen_steps: 0 + do_apply_activation_checkpointing: false + gradient_acc_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 0 # TODO: Is sequence_length used in training? + gradient_clipping: + mode: p2_norm + threshold: 1.0 + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpointing_path: data/checkpoints + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: /raid/fhgiais/opengptx/michaelf/git_repos/modalities/data-temp/en/modalities/2048/train_2048.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +val_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: /raid/fhgiais/opengptx/michaelf/git_repos/modalities/data-temp/en/modalities/2048/valid_2048.pbin + block_size: ${settings.training.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: repeating_data_loader + config: + reshuffle_after_epoch: false + num_epochs: 1 # 100 epochs + dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + shuffle: false + dataloader_tag: val + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.training.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + + +checkpointing: + component_key: checkpointing + 
variant_key: default + config: + checkpointing_strategy: + component_key: checkpointing_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: 3 # -1 to save all checkpoints + checkpointing_execution: + component_key: checkpointing_execution + variant_key: fsdp_to_disc_checkpointing + config: + checkpoint_path: ${settings.paths.checkpointing_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [ MambaBlock ] + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: BF_16 + sharding_strategy: FULL_SHARD + block_names: [ MambaBlock ] + +model: + component_key: model + variant_key: mamba + config: + d_model: 768 + n_layer: 24 + vocab_size: 50257 + rms_norm: true + residual_in_fp32: true + fused_add_norm: true + pad_vocab_size_multiple: 8 + tie_embeddings: true + prediction_key: ${settings.referencing_keys.prediction_key} + sample_key: ${settings.referencing_keys.sample_key} + seed: null + dtype: null + initializer_cfg: {} + num_last_tokens: 0 + inference_params: {} + mixer_model_config: + norm_epsilon: 1e-5 + device: null + mamba_block_config: + d_state: 16 + d_conv: 4 + expand: 2 + dt_rank: auto + dt_min: 0.001 + dt_max: 0.1 + dt_init: random + dt_scale: 1.0 + dt_init_floor: 1e-4 + conv_bias: true + bias: false + use_fast_path: true + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [ 0.9, 0.95 ] + eps: 1e-8 + weight_decay: 1e-1 + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +scheduler: + component_key: scheduler + variant_key: dummy_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + + +batch_progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + local_rank: ${settings.cuda_env.local_rank} + world_size: ${settings.cuda_env.world_size} + global_num_seen_steps: ${settings.training.global_num_seen_steps} + train_dataloader: + instance_key: train_dataloader + pass_type: BY_REFERENCE + eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + local_rank: ${settings.cuda_env.local_rank} + project: modalities + mode: ONLINE + experiment_id: ${settings.experiment_id} + directory: "." 
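The configs differ in how many checkpoints they retain (k: -1 in the two configs above versus k: 3 here). Below is a minimal sketch of what a save-k-most-recent policy amounts to; it illustrates the configured behaviour and is not the modalities implementation.

from typing import List

def checkpoints_to_delete(saved_steps: List[int], k: int) -> List[int]:
    """Return train steps whose checkpoints would be discarded under a keep-k-most-recent policy."""
    if k == -1:                      # k = -1 keeps every checkpoint (mamba / overfitting configs)
        return []
    newest_first = sorted(saved_steps, reverse=True)
    return newest_first[k:]          # drop everything beyond the k newest (k = 3 in this config)

assert checkpoints_to_delete([100, 200, 300, 400], k=3) == [100]
assert checkpoints_to_delete([100, 200, 300, 400], k=-1) == []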
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e9121daa..a425329b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,57 +26,57 @@ dependencies = [ "flash-attn", # install this directly via `pip install flash-attn --no-build-isolation` ] -[project.optional-dependencies] -linting = ["pre-commit"] -tests = ["pytest", "pytest-cov"] -install_helper = ["ninja"] - -[project.scripts] -modalities = "modalities.__main__:main" - -[build-system] -requires = ["setuptools >= 61.0.0"] -build-backend = "setuptools.build_meta" - -[tool.black] -target-version = ["py310"] -line-length = 120 - -[tool.isort] -profile = "black" -line_length = 120 - -[tool.ruff] -line-length = 120 - -[tool.pytest.ini_options] -addopts = "--cov=src --cov-report term --cov-report html" - -[tool.coverage.run] -branch = true -omit = ["*/src/modalities/dataloader/open_gptx_dataset/*"] - -[tool.coverage.report] -# Regexes for lines to exclude from consideration -exclude_also = [ - # Don't complain about missing debug-only code: - "def __repr__", - "if self\\.debug", - - # Don't complain if tests don't hit defensive assertion code: - "raise AssertionError", - "raise NotImplementedError", - - # Don't complain if non-runnable code isn't run: - "if 0:", - "if __name__ == .__main__.:", - - # Don't complain about abstract methods, they aren't run: - "@(abc\\.)?abstractmethod", -] - - -ignore_errors = true - -[tool.coverage.html] -directory = "coverage_html_report" +#[project.optional-dependencies] +#linting = ["pre-commit"] +#tests = ["pytest", "pytest-cov"] +#install_helper = ["ninja"] +# +#[project.scripts] +#modalities = "modalities.__main__:main" +# +#[build-system] +#requires = ["setuptools >= 61.0.0"] +#build-backend = "setuptools.build_meta" +# +#[tool.black] +#target-version = ["py310"] +#line-length = 120 +# +#[tool.isort] +#profile = "black" +#line_length = 120 +# +#[tool.ruff] +#line-length = 120 +# +#[tool.pytest.ini_options] +#addopts = "--cov=src --cov-report term --cov-report html" +# +#[tool.coverage.run] +#branch = true +#omit = ["*/src/modalities/dataloader/open_gptx_dataset/*"] +# +#[tool.coverage.report] +## Regexes for lines to exclude from consideration +#exclude_also = [ +# # Don't complain about missing debug-only code: +# "def __repr__", +# "if self\\.debug", +# +# # Don't complain if tests don't hit defensive assertion code: +# "raise AssertionError", +# "raise NotImplementedError", +# +# # Don't complain if non-runnable code isn't run: +# "if 0:", +# "if __name__ == .__main__.:", +# +# # Don't complain about abstract methods, they aren't run: +# "@(abc\\.)?abstractmethod", +#] +# +# +#ignore_errors = true +# +#[tool.coverage.html] +#directory = "coverage_html_report" diff --git a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py index ede0f4f4..9f2a26fd 100644 --- a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py +++ b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py @@ -77,6 +77,7 @@ def _save_checkpoint(self, model: FSDP, optimizer: Optimizer, train_step_id: int train_step_id=train_step_id, entity_type=CheckpointingEntityType.MODEL, ) + model_checkpoint_path.parent.mkdir(parents=True, exist_ok=True) torch.save(model_state, model_checkpoint_path) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index b531cf21..6ddd85fd 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -1,6 +1,6 @@ 
import math from functools import partial -from typing import Annotated, Dict, Tuple +from typing import Annotated, Dict, Tuple, List import torch from einops import repeat @@ -14,6 +14,7 @@ from modalities.models.model import NNModel from modalities.models.vision_transformer.vision_transformer_model import VisionTransformer, VisionTransformerConfig from modalities.nn.attention import AttentionConfig +from transformers import PreTrainedTokenizer class TextDecoderConfig(BaseModel): diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index cced19b4..d0d92ccb 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Dict +from typing import Dict, List import torch import xformers.ops as xops @@ -9,22 +9,23 @@ from modalities.models.model import NNModel from modalities.nn.attention import AttentionConfig, AttentionType, MultiHeadAttention from modalities.nn.mlp import MLP +from transformers import PreTrainedTokenizer class TransformerBlock(nn.Module): def __init__( - self, - n_embd: int, - bias: bool, - epsilon: float, - activation: ActivationType, - n_head: int, - dropout: float, - ffn_hidden: int, - with_context: bool, - attention_type: AttentionType, - attention_config: AttentionConfig = None, - add_extra_mlp: bool = False, + self, + n_embd: int, + bias: bool, + epsilon: float, + activation: ActivationType, + n_head: int, + dropout: float, + ffn_hidden: int, + with_context: bool, + attention_type: AttentionType, + attention_config: AttentionConfig = None, + add_extra_mlp: bool = False, ): super().__init__() self.with_context = with_context @@ -70,20 +71,20 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor class MultiModalTextDecoder(NNModel): def __init__( - self, - sample_key: str, - prediction_key: str, - block_size: int, - vocab_size: int, - n_layer: int, - n_head: int, - n_embd: int, - ffn_hidden: int, - dropout: float, - bias: bool, - activation: ActivationType, - epsilon: float, - attention_config: AttentionConfig, + self, + sample_key: str, + prediction_key: str, + block_size: int, + vocab_size: int, + n_layer: int, + n_head: int, + n_embd: int, + ffn_hidden: int, + dropout: float, + bias: bool, + activation: ActivationType, + epsilon: float, + attention_config: AttentionConfig, ): super().__init__() self.sample_key = sample_key diff --git a/src/modalities/models/coca/text_decoder.py b/src/modalities/models/coca/text_decoder.py index aab7c8f8..8ac11739 100644 --- a/src/modalities/models/coca/text_decoder.py +++ b/src/modalities/models/coca/text_decoder.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, List import torch from torch import nn @@ -7,24 +7,25 @@ from modalities.models.gpt2.gpt2_model import ActivationType from modalities.models.model import NNModel from modalities.nn.attention import AttentionConfig, AttentionType +from transformers import PreTrainedTokenizer class TextDecoder(NNModel): def __init__( - self, - sample_key: str, - prediction_key: str, - block_size: int, - vocab_size: int, - n_layer: int, - n_head: int, - n_embd: int, - ffn_hidden: int, - dropout: float, - bias: bool, - activation: ActivationType, - epsilon: float, - attention_config: AttentionConfig = None, + self, + sample_key: str, + prediction_key: str, + block_size: int, + vocab_size: int, + n_layer: int, + n_head: int, + n_embd: int, + ffn_hidden: int, + dropout: 
float, + bias: bool, + activation: ActivationType, + epsilon: float, + attention_config: AttentionConfig = None, ): super().__init__() self.sample_key = sample_key diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index ca83eb0d..9fa4063e 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -1,4 +1,5 @@ import math +import sys from copy import deepcopy from enum import Enum from functools import partial @@ -9,12 +10,15 @@ import xformers.ops as xops from flash_attn import flash_attn_func from pydantic import BaseModel, Field, model_validator, validator +from torch.nn import functional as F +from transformers import PreTrainedTokenizer from modalities.config.pydanctic_if_types import PydanticPytorchModuleType from modalities.config.utils import convert_base_model_config_to_dict from modalities.models.model import NNModel from modalities.util import parse_enum_by_name + # GPT2 implementation taken from nanogpt https://github.com/karpathy/nanoGPT @@ -25,20 +29,20 @@ class PositionTypes(str, Enum): class QueryKeyValueTransform(nn.Module): def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: pass class IdentityTransform(QueryKeyValueTransform): def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return q, k, v @@ -90,7 +94,7 @@ def apply_rotary_pos_emb(self, x, cos, sin): return (x * cos) + (self.rotate_half(x) * sin) def forward( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k) q = self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached) @@ -165,7 +169,7 @@ def check_divisibility(self) -> "GPT2LLMConfig": @model_validator(mode="after") def validate_sizes(self) -> "GPT2LLMConfig": for param, param_name in zip( - [self.ffn_hidden, self.vocab_size, self.n_embd], ["ffn_hidden", "vocab_size", "n_embd"] + [self.ffn_hidden, self.vocab_size, self.n_embd], ["ffn_hidden", "vocab_size", "n_embd"] ): if param % 128 != 0: # See https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc @@ -175,14 +179,14 @@ def validate_sizes(self) -> "GPT2LLMConfig": class CausalSelfAttention(nn.Module): def __init__( - self, - n_head_q: int, - n_head_kv: int, - n_embd: int, - attention_config: AttentionConfig, - bias: bool, - dropout: float, - block_size: int, + self, + n_head_q: int, + n_head_kv: int, + n_embd: int, + attention_config: AttentionConfig, + bias: bool, + dropout: float, + block_size: int, ): super().__init__() assert n_embd % n_head_q == 0, "`n_embd needs` to be divisible by `n_head_q`." 
@@ -237,7 +241,7 @@ def projection(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch @staticmethod def execute_qkv_transforms( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qkv_transforms: nn.ModuleList, n_head_q: int + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qkv_transforms: nn.ModuleList, n_head_q: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, block_size, embedding_dim = q.size() n_head_dim = embedding_dim // n_head_q @@ -296,18 +300,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class GPT2Block(nn.Module): def __init__( - self, - n_embd: int, - bias: bool, - n_head_q: int, - n_head_kv: int, - activation_type: ActivationType, - attention_config: AttentionConfig, - dropout: float, - block_size: int, - ffn_hidden: int, - attention_norm: nn.Module, - ffn_norm: nn.Module, + self, + n_embd: int, + bias: bool, + n_head_q: int, + n_head_kv: int, + activation_type: ActivationType, + attention_config: AttentionConfig, + dropout: float, + block_size: int, + ffn_hidden: int, + attention_norm: nn.Module, + ffn_norm: nn.Module, ): super().__init__() self.attention_norm = attention_norm @@ -338,28 +342,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class GPT2LLM(NNModel): + + def __init__( - self, - sample_key: str, - prediction_key: str, - poe_type: PositionTypes, - block_size: int, - vocab_size: int, - n_layer: int, - n_head_q: int, - n_head_kv: int, - n_embd: int, - ffn_hidden: int, - dropout: float, - bias: bool, - activation_type: ActivationType, - weight_init: WeightInitializationConfig, - attention_config: AttentionConfig, - attention_norm: nn.Module, - ffn_norm: nn.Module, - lm_head_norm: nn.Module, + self, + sample_key: str, + prediction_key: str, + poe_type: PositionTypes, + block_size: int, + vocab_size: int, + n_layer: int, + n_head_q: int, + n_head_kv: int, + n_embd: int, + ffn_hidden: int, + dropout: float, + bias: bool, + activation_type: ActivationType, + weight_init: WeightInitializationConfig, + attention_config: AttentionConfig, + attention_norm: nn.Module, + ffn_norm: nn.Module, + lm_head_norm: nn.Module, + seed: int = None ): - super().__init__() + + super().__init__(seed=seed) self.sample_key = sample_key self.prediction_key = prediction_key self.block_size = block_size diff --git a/src/modalities/models/huggingface/huggingface_models.py b/src/modalities/models/huggingface/huggingface_models.py index 4c66d46f..b80f222d 100644 --- a/src/modalities/models/huggingface/huggingface_models.py +++ b/src/modalities/models/huggingface/huggingface_models.py @@ -3,11 +3,12 @@ import torch from pydantic import BaseModel -from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer, PreTrainedTokenizer from modalities.config.lookup_enum import LookupEnum from modalities.models.model import NNModel + # Huggingface Model dependencies # # ModuleUtilsMixin @@ -37,15 +38,16 @@ class HuggingFacePretrainedModelConfig(BaseModel): class HuggingFacePretrainedModel(NNModel): + def __init__( - self, - model_type: HuggingFaceModelTypes, - model_name: str, - prediction_key: str, - huggingface_prediction_subscription_key: str, - sample_key: str, - model_args: Optional[Any] = None, - kwargs: Optional[Any] = None, + self, + model_type: HuggingFaceModelTypes, + model_name: str, + prediction_key: str, + huggingface_prediction_subscription_key: str, + sample_key: str, + model_args: Optional[Any] = None, + kwargs: Optional[Any] = 
None, ): super().__init__() if model_args is None: diff --git a/src/modalities/models/mamba/__init__.py b/src/modalities/models/mamba/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modalities/models/mamba/mamba_block.py b/src/modalities/models/mamba/mamba_block.py new file mode 100644 index 00000000..1bc71bb1 --- /dev/null +++ b/src/modalities/models/mamba/mamba_block.py @@ -0,0 +1,364 @@ +# Copyright (c) 2023, Tri Dao, Albert Gu. + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pydantic import BaseModel +from torch import Tensor + +from einops import rearrange, repeat + +from modalities.models.mamba.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn + +try: + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +except ImportError: + causal_conv1d_fn, causal_conv1d_update = None, None + +try: + from modalities.models.mamba.ops.triton.selective_state_update import selective_state_update +except ImportError: + selective_state_update = None + +try: + from modalities.models.mamba.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + + +class MambaBlock(nn.Module): + def __init__( + self, + d_model: int, + d_state: int, + d_conv: int, + expand: int, + dt_rank: str, + dt_min: float, + dt_max: float, + dt_init: str, + dt_scale: float, + dt_init_floor: float, + conv_bias: bool, + bias: bool, + use_fast_path: bool, + layer_idx: int, + device: Optional[str], + dtype: Optional[str], + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) + + self.conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=conv_bias, + kernel_size=d_conv, + groups=self.d_inner, + padding=d_conv - 1, + **factory_kwargs, + ) + + self.activation = "silu" + self.act = nn.SiLU() + + self.x_proj = nn.Linear( + self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs + ) + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) + + # Initialize special dt projection to preserve variance at initialization + dt_init_std = self.dt_rank ** -0.5 * dt_scale + if dt_init == "constant": + nn.init.constant_(self.dt_proj.weight, dt_init_std) + elif dt_init == "random": + nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + else: + raise NotImplementedError + + # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + dt = torch.exp( + torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ).clamp(min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + self.dt_proj.bias.copy_(inv_dt) + # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + self.dt_proj.bias._no_reinit = True + + # S4D real initialization + A = repeat( + torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), + "n -> d n", + 
d=self.d_inner, + ).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) + + def forward(self, hidden_states: torch.Tensor, inference_params: Optional[dict] = None) -> torch.Tensor: + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + batch, seqlen, dim = hidden_states.shape + + conv_state, ssm_state = None, None + if inference_params: + conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, _, _ = self.step(hidden_states, conv_state, ssm_state) + return out + + # We do matmul and transpose BLH -> HBL at the same time (Batch size, sequence length, hidden dim) + xz = rearrange( + self.in_proj.weight @ rearrange(hidden_states.to(self.in_proj.weight.dtype), "b l d -> d (b l)"), + "d (b l) -> b d l", + l=seqlen, + ) + if self.in_proj.bias is not None: + xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") + + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + # In the backward pass we write dx and dz next to each other to avoid torch.cat + if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states + out = mamba_inner_fn( + xz, + self.conv1d.weight, + self.conv1d.bias, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias, + A, + None, # input-dependent B + None, # input-dependent C + self.D.float(), + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + ) + else: + x, z = xz.chunk(2, dim=1) + # Compute short convolution + if conv_state is not None: + # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) + if causal_conv1d_fn is None: + x = self.act(self.conv1d(x)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + x = causal_conv1d_fn( + x=x, + weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
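+            # x_proj produces dt, B and C in a single matmul; the split into (dt_rank, d_state, d_state) happens below.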
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + assert self.activation in ["silu", "swish"] + y = selective_scan_fn( + x, + dt, + A, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=ssm_state is not None, + ) + if ssm_state is not None: + y, last_state = y + ssm_state.copy_(last_state) + y = rearrange(y, "b d l -> b l d") + out = self.out_proj(y) + return out + + def step(self, hidden_states: torch.Tensor, conv_state: torch.Tensor, ssm_state: torch.Tensor) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor]: + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" + xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + + # x goes to the left and z goes to the right in the mamba block + x, z = xz.chunk(2, dim=-1) # (B D) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv1d.bias is not None: + x = x + self.conv1d.bias + x = self.act(x).to(dtype=dtype) + else: + x = causal_conv1d_update( + x, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x_db = self.x_proj(x) # (B dt_rank+2*d_state) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + # Don't add dt_bias here + dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + + # SSM step + if selective_state_update is None: + # Discretize A and B + dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) + dB = torch.einsum("bd,bn->bdn", dt, B) + ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) + y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) + y = y + self.D.to(dtype) * x + y = y * self.act(z) # (B D) + else: + y = selective_state_update( + ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ) + + out = self.out_proj(y) + return out.unsqueeze(1), conv_state, ssm_state + + def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: Optional[str] = None, **kwargs) -> \ + Tuple[torch.Tensor, torch.Tensor]: + device = self.out_proj.weight.device + conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype + conv_state = torch.zeros( + batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype + ) + ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype + # ssm_dtype = torch.float32 + ssm_state = torch.zeros( + batch_size, self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype + ) + return conv_state, ssm_state + + def _get_states_from_cache(self, inference_params: Optional[dict], batch_size: int, + initialize_states: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + batch_shape = (batch_size,) + conv_state = torch.zeros( + 
batch_size, + self.d_model * self.expand, + self.d_conv, + device=self.conv1d.weight.device, + dtype=self.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.d_model * self.expand, + self.d_state, + device=self.dt_proj.weight.device, + dtype=self.dt_proj.weight.dtype, + # dtype=torch.float32, + ) + inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) + else: + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + # TODO: What if batch size changes between generation, and we reuse the same states? + if initialize_states: + conv_state.zero_() + ssm_state.zero_() + return conv_state, ssm_state + + +class Block(nn.Module): + def __init__( + self, + d_model: int, + mixer_cls: MambaBlock, + norm_cls: nn.LayerNorm, + fused_add_norm: bool, + residual_in_fp32: bool, + ): + """ + Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" + + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA/MLP -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Add -> LN -> Mixer, returning both + the hidden_states (output of the mixer) and the residual. + This is purely for performance reasons, as we can fuse add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.mixer = mixer_cls(d_model) + self.norm = norm_cls(d_model) + if self.fused_add_norm: + assert RMSNorm is not None, "RMSNorm import fails" + assert isinstance( + self.norm, (nn.LayerNorm, RMSNorm) + ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" + + def forward( + self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). 
+ residual: hidden_states = Mixer(LN(residual)) + """ + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn + hidden_states, residual = fused_add_norm_fn( + hidden_states, + self.norm.weight, + self.norm.bias, + residual=residual, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + eps=self.norm.eps, + ) + hidden_states = self.mixer(hidden_states, inference_params=inference_params) + return hidden_states, residual + + def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: Optional[str] = None, **kwargs) -> dict: + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) diff --git a/src/modalities/models/mamba/mamba_config.py b/src/modalities/models/mamba/mamba_config.py new file mode 100644 index 00000000..573e7681 --- /dev/null +++ b/src/modalities/models/mamba/mamba_config.py @@ -0,0 +1,43 @@ +from typing import Optional + +from pydantic import BaseModel + + +class MambaBlockConfig(BaseModel): + d_state: int + d_conv: int + expand: int + dt_rank: str + dt_min: float + dt_max: float + dt_init: str + dt_scale: float + dt_init_floor: float + conv_bias: bool + bias: bool + use_fast_path: bool + + +class MixerModelConfig(BaseModel): + norm_epsilon: float + device: Optional[str] + mamba_block_config: MambaBlockConfig + + +class MambaLLMConfig(BaseModel): + d_model: int + n_layer: int + vocab_size: int + rms_norm: bool + residual_in_fp32: bool + fused_add_norm: bool + pad_vocab_size_multiple: int + tie_embeddings: bool + prediction_key: str + sample_key: str + seed: Optional[int] + dtype: Optional[int] + initializer_cfg: dict + num_last_tokens: int + inference_params: dict + mixer_model_config: MixerModelConfig diff --git a/src/modalities/models/mamba/mamba_model.py b/src/modalities/models/mamba/mamba_model.py new file mode 100644 index 00000000..ed832a6c --- /dev/null +++ b/src/modalities/models/mamba/mamba_model.py @@ -0,0 +1,239 @@ +# Copyright (c) 2023, Albert Gu, Tri Dao. 
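Illustration (not part of the diff) of how the nested pydantic configs introduced in mamba_config.py fit together; the values mirror the mamba_block_config / mixer_model_config sections of the YAML configs above and assume modalities is installed so the new module is importable.

from modalities.models.mamba.mamba_config import MambaBlockConfig, MixerModelConfig

# Inner SSM block hyperparameters, matching the defaults used in this PR's YAML configs.
mamba_block_config = MambaBlockConfig(
    d_state=16, d_conv=4, expand=2, dt_rank="auto",
    dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4,
    conv_bias=True, bias=False, use_fast_path=True,
)

# MixerModelConfig nests the block config and adds the backbone-level settings.
mixer_model_config = MixerModelConfig(norm_epsilon=1e-5, device=None, mamba_block_config=mamba_block_config)
print(mixer_model_config.mamba_block_config.dt_rank)  # "auto" -> ceil(d_model / 16) inside MambaBlock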
+ +import math +import sys +from functools import partial +from typing import Dict, Optional, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from modalities.models.mamba.mamba_block import Block, MambaBlock +from modalities.models.mamba.mamba_config import MixerModelConfig, MambaBlockConfig +from modalities.models.model import NNModel +from transformers import PreTrainedTokenizer + +try: + from modalities.models.mamba.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + + +def create_block( + d_model: int, + ssm_cfg: dict, + norm_epsilon: float, + rms_norm: bool, + residual_in_fp32: bool, + fused_add_norm: bool, + layer_idx: int, + device: str, + dtype: str, +) -> Block: + factory_kwargs = {"device": device, "dtype": dtype} + mixer_cls = partial(MambaBlock, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) + norm_cls = partial( + nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs + ) + block = Block(d_model=d_model, mixer_cls=mixer_cls, norm_cls=norm_cls, fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32) + block.layer_idx = layer_idx + return block + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights( + module: nn.Module, + n_layer: int, + initializer_range: float = 0.02, # Now only used for embedding layer. + rescale_prenorm_residual: bool = True, + n_residuals_per_layer: int = 1, # Change to 2 if we have MLP +) -> None: + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + + +class MixerModel(nn.Module): + def __init__( + self, + d_model: int, + n_layer: int, + vocab_size: int, + norm_epsilon: float, + rms_norm: bool, + initializer_cfg: dict, + fused_add_norm: bool, + residual_in_fp32: bool, + device: Optional[str], + dtype: Optional[str], + mamba_block_config: MambaBlockConfig, + ) -> None: + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + if self.fused_add_norm: + if layer_norm_fn is None or rms_norm_fn is None: + raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") + self.layers = nn.ModuleList( + [ + create_block( + d_model, + norm_epsilon=norm_epsilon, + rms_norm=rms_norm, + residual_in_fp32=residual_in_fp32, + fused_add_norm=fused_add_norm, + layer_idx=i, + ssm_cfg=mamba_block_config.model_dump(), + **factory_kwargs, + ) + for i in range(n_layer) + ] + ) + self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( + d_model, eps=norm_epsilon, **factory_kwargs + ) + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + ) + ) + + def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: Optional[str] = None, **kwargs) -> dict: + return { + i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + for i, layer in enumerate(self.layers) + } + + def forward(self, input_ids: torch.Tensor, inference_params: Optional[dict] = None) -> torch.Tensor: + hidden_states = self.embedding(input_ids) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + hidden_states, residual, inference_params=inference_params + ) + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn + hidden_states = fused_add_norm_fn( + hidden_states, + self.norm_f.weight, + self.norm_f.bias, + eps=self.norm_f.eps, + residual=residual, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, + ) + return hidden_states + + +class MambaLLM(NNModel): + + def __init__( + self, + d_model: int, + n_layer: int, + vocab_size: int, + rms_norm: bool, + residual_in_fp32: bool, + fused_add_norm: bool, + pad_vocab_size_multiple: int, + tie_embeddings: bool, + prediction_key: str, + sample_key: str, + seed: Optional[int], + dtype: Optional[str], + initializer_cfg: dict, + num_last_tokens: int, + inference_params: dict, + mixer_model_config: MixerModelConfig, + ): + super().__init__(seed=seed) + + self.d_model = d_model + self.n_layer = 
n_layer + self.vocab_size = vocab_size + self.rms_norm = rms_norm + self.residual_in_fp32 = residual_in_fp32 + self.fused_add_norm = fused_add_norm + self.pad_vocab_size_multiple = pad_vocab_size_multiple + self.tie_embeddings = tie_embeddings + self.prediction_key = prediction_key + self.sample_key = sample_key + self.dtype = dtype + self.initializer_cfg = initializer_cfg + self.mixer_model_config = mixer_model_config + + # todo: How to pass these variables in the forward method? + self.inference_params = inference_params + self.num_last_tokens = num_last_tokens + + if self.vocab_size % self.pad_vocab_size_multiple != 0: + self.vocab_size += self.pad_vocab_size_multiple - (self.vocab_size % self.pad_vocab_size_multiple) + self.backbone = MixerModel( + d_model=self.d_model, + n_layer=self.n_layer, + vocab_size=self.vocab_size, + rms_norm=self.rms_norm, + initializer_cfg=self.initializer_cfg, + fused_add_norm=self.fused_add_norm, + residual_in_fp32=self.residual_in_fp32, + dtype=self.dtype, + norm_epsilon=self.mixer_model_config.norm_epsilon, + device=self.mixer_model_config.device, + mamba_block_config=self.mixer_model_config.mamba_block_config + ) + self.lm_head = nn.Linear(self.d_model, self.vocab_size, bias=False, dtype=self.dtype) + self.apply( + partial( + _init_weights, + n_layer=self.n_layer, + **initializer_cfg, + ) + ) + self.tie_weights() + + def tie_weights(self) -> None: + if self.tie_embeddings: + self.lm_head.weight = self.backbone.embedding.weight + + def allocate_inference_cache(self, batch_size: int, max_seqlen: int, dtype: str = None, **kwargs) -> dict: + return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + num_last_tokens: if > 0, only return the logits for the last n tokens + """ + + hidden_states = self.backbone(inputs[self.sample_key], inference_params=self.inference_params) + if self.num_last_tokens > 0: + hidden_states = hidden_states[:, -self.num_last_tokens:] + lm_logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)) + return {self.prediction_key: lm_logits} diff --git a/src/modalities/models/mamba/ops/__init__.py b/src/modalities/models/mamba/ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modalities/models/mamba/ops/selective_scan_interface.py b/src/modalities/models/mamba/ops/selective_scan_interface.py new file mode 100644 index 00000000..c3596bfe --- /dev/null +++ b/src/modalities/models/mamba/ops/selective_scan_interface.py @@ -0,0 +1,357 @@ +# Copyright (c) 2023, Tri Dao, Albert Gu. 
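A short note on MambaLLM's interface (illustration only): the vocabulary is padded up to a multiple of pad_vocab_size_multiple before the embedding and lm_head are built, and forward exchanges dictionaries keyed by sample_key / prediction_key rather than raw tensors. The arithmetic below mirrors the code in MambaLLM.__init__ with the values from the YAML configs.

# Vocab padding as done in MambaLLM.__init__ (values taken from the configs in this PR).
vocab_size = 50257
pad_vocab_size_multiple = 8
if vocab_size % pad_vocab_size_multiple != 0:
    vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
assert vocab_size == 50264  # embedding and lm_head are sized to the padded vocabulary

# forward() consumes and returns keyed tensors:
#   inputs  = {"input_ids": LongTensor of shape (batch, seq_len)}
#   returns  {"logits":    FloatTensor of shape (batch, seq_len, padded_vocab_size)}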
+ +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd + +from einops import rearrange, repeat + +try: + from causal_conv1d import causal_conv1d_fn + import causal_conv1d_cuda +except ImportError: + causal_conv1d_fn = None + causal_conv1d_cuda = None + +import selective_scan_cuda + + +class SelectiveScanFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = rearrange(B, "b dstate l -> b 1 dstate l") + ctx.squeeze_B = True + if C.dim() == 3: + C = rearrange(C, "b dstate l -> b 1 dstate l") + ctx.squeeze_C = True + out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) + ctx.delta_softplus = delta_softplus + ctx.has_z = z is not None + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if not ctx.has_z: + ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) + return out if not return_last_state else (out, last_state) + else: + ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out) + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) + + @staticmethod + def backward(ctx, dout, *args): + if not ctx.has_z: + u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors + z = None + out = None + else: + u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors + if dout.stride(-1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + # Here we just pass in None and dz will be allocated in the C++ code. + du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( + u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, + False # option to recompute out_z, not used here + ) + dz = rest[0] if ctx.has_z else None + dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB + dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC + return (du, ddelta, dA, dB, dC, + dD if D is not None else None, + dz, + ddelta_bias if delta_bias is not None else None, + None, + None) + + +def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). Note that the gradient of the last state is + not considered in the backward pass. 
+ """ + return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) + + +def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, + return_last_state=False): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + if A.is_complex(): + if is_variable_B: + B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2)) + if is_variable_C: + C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2)) + else: + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + if y.is_complex(): + y = y.real * 2 + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +class MambaInnerFn(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1): + """ + xz: (batch, dim, seqlen) + """ + assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
+ assert checkpoint_lvl in [0, 1] + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + if torch.is_autocast_enabled(): + x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype()) + out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype()) + if out_proj_bias is not None else None) + if xz.stride(-1) != 1: + xz = xz.contiguous() + conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") + x, z = xz.chunk(2, dim=1) + conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x, conv1d_weight, conv1d_bias, None, None, None, True + ) + # We're being very careful here about the layout, to avoid extra transposes. + # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) + ctx.is_variable_B = B is None + ctx.is_variable_C = C is None + ctx.B_proj_bias_is_None = B_proj_bias is None + ctx.C_proj_bias_is_None = C_proj_bias is None + if B is None: # variable B + B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if B.stride(-1) != 1: + B = B.contiguous() + if C is None: # variable C + C = x_dbl[:, -d_state:] # (bl dstate) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous() + else: + if C.stride(-1) != 1: + C = C.contiguous() + if D is not None: + D = D.contiguous() + out, scan_intermediates, out_z = selective_scan_cuda.fwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus + ) + ctx.delta_softplus = delta_softplus + ctx.out_proj_bias_is_None = out_proj_bias is None + ctx.checkpoint_lvl = checkpoint_lvl + if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass + conv1d_out, delta = None, None + ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, + delta_proj_weight, out_proj_weight, conv1d_out, delta, + A, B, C, D, delta_bias, scan_intermediates, out) + return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) + + @staticmethod + @custom_bwd + def backward(ctx, dout): + # dout: (batch, seqlen, dim) + assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
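+        # conv1d_out and delta in the saved tensors are None when checkpoint_lvl == 1; they are
+        # recomputed from x and x_dbl below before calling the scan backward kernel.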
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, + conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + x, z = xz.chunk(2, dim=1) + if dout.stride(-1) != 1: + dout = dout.contiguous() + if ctx.checkpoint_lvl == 1: + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x, conv1d_weight, conv1d_bias, None, None, None, True + ) + delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), + "d (b l) -> b d l", l = L) + # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the + # backward of selective_scan_cuda with the backward of chunk). + dxz = torch.empty_like(xz) # (batch, dim, seqlen) + dx, dz = dxz.chunk(2, dim=1) + dout = rearrange(dout, "b l e -> e (b l)") + dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L) + dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd( + conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz, + ctx.delta_softplus, + True # option to recompute out_z + ) + dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)")) + dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None + dD = dD if D is not None else None + dx_dbl = torch.empty_like(x_dbl) + dB_proj_bias = None + if ctx.is_variable_B: + if not A.is_complex(): + dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None + dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d) + dB = None + dC_proj_bias = None + if ctx.is_variable_C: + if not A.is_complex(): + dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous() + else: + dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous() + dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None + dx_dbl[:, -d_state:] = dC # (bl d) + dC = None + ddelta = rearrange(ddelta, "b d l -> d (b l)") + ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank]) + dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight) + dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)") + dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d")) + dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out) + dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). 
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( + x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True + ) + dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None + dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") + return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight, + dout_proj_weight, dout_proj_bias, + dA, dB, dC, dD, + ddelta_bias if delta_bias is not None else None, + dB_proj_bias, dC_proj_bias, None) + + +def mamba_inner_fn( + xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True +): + return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus) + + +def mamba_inner_ref( + xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, + out_proj_weight, out_proj_bias, + A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, + C_proj_bias=None, delta_softplus=True +): + assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." + L = xz.shape[-1] + delta_rank = delta_proj_weight.shape[1] + d_state = A.shape[-1] * (1 if not A.is_complex() else 2) + x, z = xz.chunk(2, dim=1) + x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu") + # We're being very careful here about the layout, to avoid extra transposes. + # We want delta to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d) + delta = delta_proj_weight @ x_dbl[:, :delta_rank].t() + delta = rearrange(delta, "d (b l) -> b d l", l=L) + if B is None: # variable B + B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d) + if B_proj_bias is not None: + B = B + B_proj_bias.to(dtype=B.dtype) + if not A.is_complex(): + B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous() + else: + B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + if C is None: # variable B + C = x_dbl[:, -d_state:] # (bl d) + if C_proj_bias is not None: + C = C + C_proj_bias.to(dtype=C.dtype) + if not A.is_complex(): + C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous() + else: + C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous() + y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True) + return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias) diff --git a/src/modalities/models/mamba/ops/triton/__init__.py b/src/modalities/models/mamba/ops/triton/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modalities/models/mamba/ops/triton/layernorm.py b/src/modalities/models/mamba/ops/triton/layernorm.py new file mode 100644 index 00000000..ba33ce1e --- /dev/null +++ b/src/modalities/models/mamba/ops/triton/layernorm.py @@ -0,0 +1,635 @@ +# Copyright (c) 2023, Tri Dao. +# Implement residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. 
+# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F +from torch.cuda.amp import custom_fwd, custom_bwd + +import triton +import triton.language as tl + + +def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + return out if not prenorm else (out, x) + + +def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): + dtype = x.dtype + if upcast: + weight = weight.float() + bias = bias.float() if bias is not None else None + if upcast: + x = x.float() + residual = residual.float() if residual is not None else residual + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) + out = out.to(dtype) + return out if not prenorm else (out, x) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + RESIDUAL_OUT, # pointer to the residual + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + N, # number of columns in X + eps, # epsilon to avoid division by zero + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
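+    # Each program instance normalizes one row of X; the host code picks BLOCK_N >= N, so the whole
+    # feature dimension fits into a single (masked) block and the reduction is done in one pass.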
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd( + x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + assert y.stride(-1) == 1 + if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): + residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + y, + weight, + bias, + residual, + residual_out, + mean, + rstd, + x.stride(0), + y.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + N, + eps, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype + return y, mean, rstd, residual_out if residual_out is not None else x + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is 
not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + DRESIDUAL_IN, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + has_residual=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + assert 
weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + dresidual_in, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype: + dresidual_in = dx + return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm + ) + ctx.save_for_backward(residual_out, weight, bias, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + y = y.reshape(x_shape_og) + return y if not prenorm else (y, residual_out.reshape(x_shape_og)) + + @staticmethod + def backward(ctx, dy, *args): + x, weight, bias, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dw, db, dresidual_in = _layer_norm_bwd( + dy, + x, + 
weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) + + +def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6): + return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + y, mean, rstd, residual_out = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == 
x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual, + ctx.has_residual, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/src/modalities/models/mamba/ops/triton/selective_state_update.py b/src/modalities/models/mamba/ops/triton/selective_state_update.py new file mode 100644 index 00000000..193552a0 --- /dev/null +++ b/src/modalities/models/mamba/ops/triton/selective_state_update.py @@ -0,0 +1,263 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this +""" + +import math +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + + +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, + # Matrix dimensions + batch, nheads, dim, dstate, nheads_ngroups_ratio, + # Strides + stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate, + stride_x_batch, stride_x_head, stride_x_dim, + stride_dt_batch, stride_dt_head, stride_dt_dim, + stride_dt_bias_head, stride_dt_bias_dim, + stride_A_head, stride_A_dim, stride_A_dstate, + stride_B_batch, stride_B_group, stride_B_dstate, + stride_C_batch, stride_C_group, stride_C_dstate, + stride_D_head, stride_D_dim, + stride_z_batch, stride_z_head, stride_z_dim, + stride_out_batch, stride_out_head, stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + 
pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + if not TIE_HDIM: + dB = B[None, :] * dt[:, None] + else: + dB = B * dt # vector of size (dstate,) + state = state * dA + dB * x[:, None] + tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, 
dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 + else ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else + ((4, 8)))))) + tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0 + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, x, dt, dt_bias, A, B, C, D, z, out, + batch, nheads, dim, dstate, nheads // ngroups, + state.stride(0), state.stride(1), state.stride(2), state.stride(3), + x.stride(0), x.stride(1), x.stride(2), + dt.stride(0), dt.stride(1), dt.stride(2), + *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), A.stride(1), A.stride(2), + B.stride(0), B.stride(1), B.stride(2), + C.stride(0), C.stride(1), C.stride(2), + *(D.stride(0), D.stride(1)) if D is not None else 0, + z_strides[0], z_strides[1], z_strides[2], + out.stride(0), out.stride(1), out.stride(2), + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # 
(batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) + state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate + out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + out = (out if z is None else out * F.silu(z)).to(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out diff --git a/src/modalities/models/mamba/utils/__init__.py b/src/modalities/models/mamba/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/modalities/models/mamba/utils/generation.py b/src/modalities/models/mamba/utils/generation.py new file mode 100644 index 00000000..369c7a14 --- /dev/null +++ b/src/modalities/models/mamba/utils/generation.py @@ -0,0 +1,387 @@ +# Copyright (c) 2023, Albert Gu, Tri Dao. +import gc +import time +from collections import namedtuple +from dataclasses import dataclass, field +from functools import partial +from typing import Callable, Optional, Sequence, Union + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import Tensor +from torch.profiler import ProfilerActivity, profile, record_function +from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer + + +@dataclass +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + max_seqlen: int + max_batch_size: int + seqlen_offset: int = 0 + batch_size_offset: int = 0 + key_value_memory_dict: dict = field(default_factory=dict) + lengths_per_sample: Optional[Tensor] = None + + def reset(self, max_seqlen, max_batch_size): + self.max_seqlen = max_seqlen + self.max_batch_size = max_batch_size + self.seqlen_offset = 0 + if self.lengths_per_sample is not None: + self.lengths_per_sample.zero_() + + +def modify_logits_for_min_p_filtering(logits, min_p): + """Set the logits for none min_p values to -inf. Done in-place.""" + if min_p <= 0.0 or min_p >= 1.0: + return + indices_to_remove = logits < min_p + logits.masked_fill_(indices_to_remove, float("-Inf")) +# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py +# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 +def modify_logits_for_top_k_filtering(logits, top_k): + """Set the logits for none top-k values to -inf. Done in-place.""" + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits.masked_fill_(indices_to_remove, float("-Inf")) + + +# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py +# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 +def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf. Done in-place.""" + if top_p <= 0.0 or top_p >= 1.0: + return + # First sort and calculate cumulative sum of probabilities. 
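+    # (ascending sort, so tokens whose cumulative probability stays <= 1 - top_p form the
+    # low-probability tail that gets masked; the surviving tokens cover at least top_p of the mass)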
+    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
+    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+    # scatter sorted tensors to original indexing
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        1, sorted_indices, sorted_indices_to_remove
+    )
+    logits.masked_fill_(indices_to_remove, float("-inf"))
+
+
+def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
+    """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
+    logits: (batch_size, vocab_size)
+    prev_output_tokens: (batch_size, seq_len)
+    """
+    if repetition_penalty == 1.0:
+        return logits
+    score = torch.gather(logits, 1, prev_output_tokens)
+    # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+    logits.scatter_(1, prev_output_tokens, score)
+    return logits
+
+
+def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
+    """Sample from top-k logits.
+    Arguments:
+        logits: Tensor of shape (batch_size, vocab_size)
+    """
+    if top_k == 1:  # Short-circuit for greedy decoding
+        return logits.argmax(dim=-1)
+    else:
+        if top_p > 0.0:
+            assert top_p <= 1.0, "top-p should be in (0, 1]."
+        if top_k > 0:
+            top_k = min(top_k, logits.size(-1))  # Safety check
+            logits_top, indices = torch.topk(logits, top_k, dim=-1)
+            if temperature != 1.0:
+                logits_top /= temperature
+            modify_logits_for_top_p_filtering(logits_top, top_p)
+            return indices[
+                torch.arange(indices.shape[0], device=indices.device),
+                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
+            ]
+        else:
+            if min_p > 0.0:
+                logits_top = logits.clone()
+                max_prob = logits_top[..., 0].item()
+                min_prob = max_prob * min_p
+                modify_logits_for_min_p_filtering(logits_top, min_prob)
+                if temperature != 1.0:
+                    logits_top /= temperature
+                return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
+            # Clone so that when we modify for top_p we don't change the original logits
+            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
+            modify_logits_for_top_p_filtering(logits_top, top_p)
+            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
+                dim=-1
+            )
+
+
+@torch.inference_mode()
+def decode(
+    input_ids,
+    model,
+    max_length,
+    top_k=1,
+    top_p=0.0,
+    min_p=0.0,
+    temperature=1.0,
+    repetition_penalty=1.0,
+    eos_token_id=None,
+    teacher_outputs=None,
+    vocab_size=None,
+    cg=False,
+    enable_timing=False,
+    streamer: Optional[TextStreamer] = None
+):
+    """Decoding, either greedy or with top-k or top-p sampling.
+    If top-k = 0, don't limit the number of candidates (pure sampling).
+    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
+    then top-p.
+    We assume that all sequences in the same batch have the same length.
+
+    Arguments:
+        input_ids: (batch, seq_len)
+        max_length: int
+        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
+            logits, the next token is taken from the teacher_outputs. Useful for testing.
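+        cg: if True, capture the single-token decoding step in a CUDA graph (see update_graph_cache
+            and capture_graph below) and replay it for subsequent steps to reduce launch overhead.
+        streamer (optional): a transformers TextStreamer; prompt and sampled tokens are pushed to it
+            as they become available.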
+ Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: + sequences: (batch, max_length) + scores: tuples of (batch, vocab_size) + """ + if streamer is not None: + streamer.put(input_ids.cpu()) + + batch_size, seqlen_og = input_ids.shape + teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 + if cg: + if not hasattr(model, "_decoding_cache"): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, + model._decoding_cache, + batch_size, + seqlen_og, + max_length, + ) + inference_params = model._decoding_cache.inference_params + inference_params.reset(max_length, batch_size) + else: + inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) + + def get_logits(input_ids, inference_params): + decoding = inference_params.seqlen_offset > 0 + if decoding: + position_ids = torch.full( + (batch_size, 1), + inference_params.seqlen_offset, + dtype=torch.long, + device=input_ids.device, + ) + else: + position_ids = None + if not cg or not decoding: + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=1, + ).logits.squeeze(dim=1) + else: + logits = model._decoding_cache.run( + input_ids, position_ids, inference_params.seqlen_offset + ).squeeze(dim=1) + return logits[..., :vocab_size] if vocab_size is not None else logits + + def sample_tokens(logits, inference_params): + if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: + token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature) + else: + token = teacher_outputs[:, inference_params.seqlen_offset] + # return rearrange(token, "b -> b 1") + return token.unsqueeze(1) + + def should_stop(current_token, inference_params): + if inference_params.seqlen_offset == 0: + return False + if eos_token_id is not None and (current_token == eos_token_id).all(): + return True + if inference_params.seqlen_offset >= max_length - 1: + return True + return False + + start = torch.cuda.Event(enable_timing=enable_timing) + end = torch.cuda.Event(enable_timing=enable_timing) + + if enable_timing: + start.record() + scores, sequences = [], [input_ids] + sequences_cat = input_ids + while not should_stop(sequences[-1], inference_params): + scores.append(get_logits(sequences[-1], inference_params)) + inference_params.seqlen_offset += sequences[-1].shape[1] + if repetition_penalty == 1.0: + sampled_tokens = sample_tokens(scores[-1], inference_params) + else: + logits = modify_logit_for_repetition_penalty( + scores[-1].clone(), sequences_cat, repetition_penalty + ) + sampled_tokens = sample_tokens(logits, inference_params) + sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1) + sequences.append(sampled_tokens) + if streamer is not None: + streamer.put(sampled_tokens.cpu()) + if streamer is not None: + streamer.end() + if enable_timing: + end.record() + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") + output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput + return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) + + +class GenerationMixin: + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + raise NotImplementedError + + def generate( + self, + input_ids, + max_length, + top_k=1, + top_p=0.0, + min_p=0.0, + temperature=1.0, + return_dict_in_generate=False, + 
output_scores=False, + **kwargs, + ): + output = decode( + input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, **kwargs + ) + if not output_scores: + output.scores = None + return output if return_dict_in_generate else output.sequences + + +@dataclass +class DecodingCGCache: + max_batch_size: int = 0 + max_seqlen: int = 0 + device = None + dtype = None + callables: dict = field(default_factory=dict) + mempool = None + inference_params: Optional[InferenceParams] = None + run: Optional[Callable] = None + + +@torch.inference_mode() +def update_graph_cache( + model, + cache, + batch_size, + seqlen_og, + max_seqlen, + decoding_seqlens=(1,), + dtype=None, + n_warmups=2, +): + if cache is None: + cache = DecodingCGCache() + param_example = next(iter(model.parameters())) + device = param_example.device + if dtype is None: + dtype = param_example.dtype + if ( + (device, dtype) != (cache.device, cache.dtype) + or batch_size > cache.max_batch_size + or max_seqlen > cache.max_seqlen + ): # Invalidate the cache + cache.callables = {} + cache.mempool = None + cache.inference_params = None + gc.collect() + cache.device, cache.dtype = device, dtype + cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen + assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache" + inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) + lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) + cache.inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_og, + key_value_memory_dict=inf_cache, + lengths_per_sample=lengths_per_sample, + ) + cache.mempool = torch.cuda.graphs.graph_pool_handle() + for decoding_seqlen in decoding_seqlens: + if (batch_size, decoding_seqlen) not in cache.callables: + cache.callables[batch_size, decoding_seqlen] = capture_graph( + model, + cache.inference_params, + batch_size, + max_seqlen, + decoding_seqlen=decoding_seqlen, + mempool=cache.mempool, + n_warmups=n_warmups, + ) + + def dispatch(input_ids, position_ids, seqlen): + batch_size, decoding_seqlen = input_ids.shape[:2] + return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) + + cache.run = dispatch + cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing + return cache + + +def capture_graph( + model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 +): + device = next(iter(model.parameters())).device + input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) + seqlen_offset_og = inference_params.seqlen_offset + inference_params.seqlen_offset = max_seqlen - decoding_seqlen + inference_params.lengths_per_sample[:] = inference_params.seqlen_offset + + # Warmup before capture + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(n_warmups): + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=decoding_seqlen, + ).logits + s.synchronize() + # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, + # which requires that graph launch and non-captured launch to not overlap (I think, + # that's how I interpret the documentation). 
I'm not sure if this is required. + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.cuda.current_stream().wait_stream(s) + # Captures the graph + # To allow capture, automatically sets a side stream as the current stream in the context + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=mempool): + logits = model( + input_ids, + position_ids=position_ids, + inference_params=inference_params, + num_last_tokens=decoding_seqlen, + ).logits + + def run(new_input_ids, new_position_ids, seqlen): + inference_params.lengths_per_sample[:] = seqlen + input_ids.copy_(new_input_ids) + position_ids.copy_(new_position_ids) + graph.replay() + return logits.clone() + + inference_params.seqlen_offset = seqlen_offset_og + return run diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index dfcb02f6..a20cfd47 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -1,10 +1,11 @@ from abc import abstractmethod -from typing import Dict +from typing import Dict, List import torch import torch.nn as nn from modalities.batch import DatasetBatch, InferenceResultBatch +from transformers import PreTrainedTokenizer class NNModel(nn.Module): @@ -21,6 +22,7 @@ def get_parameters(self) -> Dict[str, torch.Tensor]: return {name: param for name, param in self.named_parameters()} + def model_predict_batch(model: nn.Module, batch: DatasetBatch) -> InferenceResultBatch: forward_result = model.forward(batch.samples) result_batch = InferenceResultBatch(targets=batch.targets, predictions=forward_result) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index a9258008..33f014df 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -66,6 +66,7 @@ HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig, ) +from modalities.models.mamba.mamba_config import MambaLLMConfig from modalities.models.model_factory import ModelFactory from modalities.optimizers.lr_schedulers import DummyLRScheduler from modalities.optimizers.optimizer_factory import OptimizerFactory @@ -81,6 +82,8 @@ FSDPGradientClipperConfig, ) +from modalities.models.mamba.mamba_model import MambaLLM + @dataclass class ComponentEntity: @@ -93,6 +96,7 @@ class ComponentEntity: COMPONENTS = [ # models ComponentEntity("model", "gpt2", GPT2LLM, GPT2LLMConfig), + ComponentEntity("model", "mamba", MambaLLM, MambaLLMConfig), ComponentEntity( "model", "huggingface_pretrained_model", HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig ), diff --git a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py index 155d5343..bd1db928 100644 --- a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py +++ b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py @@ -28,7 +28,11 @@ def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: List[str]) for cls_block_name in block_names: # TODO FullyShardedDataParallelPlugin from Accelerate uses string matching to find the correct # block class. In the long-term we should implmement this ourselves in a robuster fashion. 
- block_type = get_module_class_from_name(model, cls_block_name) + try: + block_type = FullyShardedDataParallelPlugin.get_module_class_from_name(model, cls_block_name) + except AttributeError: + from accelerate.utils.dataclasses import get_module_class_from_name + block_type = get_module_class_from_name(model, cls_block_name) if block_type is None: raise ValueError(f"Could not find block with name {cls_block_name} in model") fsdp_block_types.append(block_type) diff --git a/tests/models/mamba/__init__.py b/tests/models/mamba/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/mamba/conftest.py b/tests/models/mamba/conftest.py new file mode 100644 index 00000000..8f518ebb --- /dev/null +++ b/tests/models/mamba/conftest.py @@ -0,0 +1,266 @@ +import torch +from torch import nn + +from modalities.models.mamba.mamba_block import MambaBlock, Block +import pytest + +from functools import partial +from modalities.models.mamba.mamba_config import MambaBlockConfig, MixerModelConfig +from modalities.models.mamba.mamba_model import MambaLLM, MixerModel + + +@pytest.fixture() +def batch_size(): + return 2 + + +@pytest.fixture() +def expand(): + return 2 + + +@pytest.fixture(scope="session") +def d_model(): + return 16 + + +@pytest.fixture() +def d_state(): + return 3 + + +@pytest.fixture(scope="session") +def d_conv(): + return 4 + + +@pytest.fixture() +def vocab_size(): + return 1024 + + +@pytest.fixture() +def sequence_length(): + return 64 + + +@pytest.fixture() +def n_layer(): + return 2 + + +@pytest.fixture() +def conv_state(batch_size, expand, d_model, d_conv): + return torch.rand((batch_size, expand * d_model, d_conv)) + + +@pytest.fixture() +def ssm_state(batch_size, expand, d_model, d_state): + return torch.rand((batch_size, expand * d_model, d_state)) + + +@pytest.fixture() +def layer_idx(): + return 0 + + +@pytest.fixture() +def dt_rank(): + return "auto" + + +@pytest.fixture() +def dt_min(): + return 0.001 + + +@pytest.fixture() +def dt_max(): + return 0.1 + + +@pytest.fixture() +def dt_init(): + return "random" + + +@pytest.fixture() +def dt_scale(): + return 1.0 + + +@pytest.fixture() +def dt_init_floor(): + return 1e-4 + + +@pytest.fixture() +def conv_bias(): + return True + + +@pytest.fixture() +def bias(): + return False + + +@pytest.fixture() +def use_fast_path(): + return True + + +@pytest.fixture() +def device(): + return None + + +@pytest.fixture() +def dtype(): + return None + + +@pytest.fixture() +def ssm_cfg(mamba_block_config): + return mamba_block_config.model_dump() + + +@pytest.fixture() +def mamba_block_config(d_state, + d_conv, + expand, + dt_rank, + dt_min, + dt_max, + dt_init, + dt_scale, + dt_init_floor, + conv_bias, + bias, + use_fast_path): + return MambaBlockConfig(d_state=d_state, + d_conv=d_conv, + expand=expand, + dt_rank=dt_rank, + dt_min=dt_min, + dt_max=dt_max, + dt_init=dt_init, + dt_scale=dt_scale, + dt_init_floor=dt_init_floor, + conv_bias=conv_bias, + bias=bias, + use_fast_path=use_fast_path) + + +@pytest.fixture() +def mixer_model(d_model, n_layer, vocab_size, norm_epsilon, rms_norm, initializer_cfg, fused_add_norm, residual_in_fp32, + device, dtype, mamba_block_config): + return MixerModel(d_model=d_model, n_layer=n_layer, vocab_size=vocab_size, norm_epsilon=norm_epsilon, + rms_norm=rms_norm, initializer_cfg=initializer_cfg, fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, device=device, dtype=dtype, + mamba_block_config=mamba_block_config) + + +@pytest.fixture() +def factory_kwargs(device, dtype): + 
return {"device": device, "dtype": dtype} + + +@pytest.fixture() +def norm_epsilon(): + return 1e-5 + + +@pytest.fixture() +def rms_norm(): + return False + + +@pytest.fixture() +def initializer_cfg(): + return {} + + +@pytest.fixture() +def mixer_cls(layer_idx, factory_kwargs, ssm_cfg): + return partial(MambaBlock, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) + + +@pytest.fixture() +def mamba_block(d_model, d_state, d_conv, expand, dt_rank, dt_min, dt_max, dt_init, dt_init_floor, dt_scale, + bias, conv_bias, use_fast_path, layer_idx, dtype, device): + return MambaBlock(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand, dt_rank=dt_rank, + dt_min=dt_min, dt_max=dt_max, dt_init=dt_init, dt_scale=dt_scale, dt_init_floor=dt_init_floor, + bias=bias, conv_bias=conv_bias, use_fast_path=use_fast_path, layer_idx=layer_idx, dtype=dtype, + device=device) + + +@pytest.fixture() +def norm_cls(d_model, norm_epsilon, factory_kwargs): + return partial( + nn.LayerNorm, eps=norm_epsilon, **factory_kwargs + ) + + +@pytest.fixture() +def fused_add_norm(): + return True + + +@pytest.fixture() +def residual_in_fp32(): + return True + + +@pytest.fixture() +def block(d_model, mixer_cls, norm_cls, fused_add_norm, residual_in_fp32): + return Block(d_model=d_model, + mixer_cls=mixer_cls, + norm_cls=norm_cls, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32) + + +@pytest.fixture() +def hidden_states(d_model, + batch_size): + return torch.randn(batch_size, 1, d_model) + + +@pytest.fixture() +def prediction_key(): + return "logits" + + +@pytest.fixture() +def sample_key(): + return "input_ids" + + +@pytest.fixture() +def seed(): + return 42 + + +@pytest.fixture() +def mixer_model_config(norm_epsilon, device, mamba_block_config): + return MixerModelConfig(norm_epsilon=norm_epsilon, device=device, mamba_block_config=mamba_block_config) + + +@pytest.fixture() +def mamba_llm(d_model, n_layer, vocab_size, rms_norm, residual_in_fp32, fused_add_norm, prediction_key, sample_key, + seed, dtype, initializer_cfg, mixer_model_config): + return MambaLLM(d_model=d_model, n_layer=n_layer, vocab_size=vocab_size, rms_norm=rms_norm, + residual_in_fp32=residual_in_fp32, fused_add_norm=fused_add_norm, pad_vocab_size_multiple=1, + tie_embeddings=False, prediction_key=prediction_key, sample_key=sample_key, seed=seed, dtype=dtype, + initializer_cfg=initializer_cfg, num_last_tokens=0, inference_params={}, + mixer_model_config=mixer_model_config) + + +@pytest.fixture() +def linear_layer(): + return nn.Linear(in_features=16, out_features=24) + + +@pytest.fixture() +def embedding_layer(vocab_size, d_model): + return nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model) diff --git a/tests/models/mamba/test_mamba_block.py b/tests/models/mamba/test_mamba_block.py new file mode 100644 index 00000000..fc63aebc --- /dev/null +++ b/tests/models/mamba/test_mamba_block.py @@ -0,0 +1,75 @@ +import torch +import pytest +from modalities.models.mamba.utils.generation import InferenceParams + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="We need cuda to run Mamba.") +def test_mamba_block_forward(batch_size, + sequence_length, + d_model, + d_state, + d_conv, + expand, + mamba_block): + x = torch.randn(batch_size, sequence_length, d_model).to("cuda") + mamba_block = mamba_block.to("cuda") + y = mamba_block(x) + assert y.shape == x.shape + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="We need cuda to run Mamba.") +def test_block_forward(hidden_states, block): + block =
block.to("cuda") + hidden_states = hidden_states.to("cuda") + computed_hidden_states, computed_residuals = block(hidden_states) + assert (hidden_states == computed_residuals).all() + assert hidden_states.shape == computed_hidden_states.shape + assert (hidden_states != computed_hidden_states).any() + + +def test_get_states_from_cache(conv_state, + ssm_state, + batch_size, + expand, + d_model, + d_state, + d_conv, + mamba_block, + layer_idx + ): + inference_params = InferenceParams(max_seqlen=16, max_batch_size=3, seqlen_offset=0, batch_size_offset=0, + key_value_memory_dict={layer_idx: (conv_state, ssm_state)}, + lengths_per_sample=None) + computed_conv_state, computed_ssm_state = mamba_block._get_states_from_cache(inference_params, batch_size) + assert (conv_state == computed_conv_state).all() + assert (ssm_state == computed_ssm_state).all() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="We need cuda to run Mamba.") +def test_step(conv_state, ssm_state, mamba_block, hidden_states): + device = "cuda" + mamba_block = mamba_block.to(device) + hidden_states = hidden_states.to(device) + conv_state = conv_state.to(device) + ssm_state = ssm_state.to(device) + computed_hidden_states, computed_conv_state, computed_ssm_state = mamba_block.step( + hidden_states=hidden_states.detach().clone(), + conv_state=conv_state.detach().clone(), + ssm_state=ssm_state.detach().clone()) + assert computed_hidden_states.shape == hidden_states.shape + assert computed_conv_state.shape == conv_state.shape + assert computed_ssm_state.shape == ssm_state.shape + assert (computed_hidden_states != hidden_states).any() + assert (computed_conv_state != conv_state).any() + assert (computed_ssm_state != ssm_state).any() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="We need cuda to run Mamba.") +def test_allocate_inference_cache(mamba_block, batch_size, sequence_length, conv_state, ssm_state): + device = "cuda" + mamba_block.to(device) + computed_conv_state, computed_ssm_state = mamba_block.allocate_inference_cache(batch_size=batch_size, + max_seqlen=sequence_length, + dtype=torch.float32) + assert (computed_conv_state == torch.zeros(conv_state.shape).to(device)).all() + assert (computed_ssm_state == torch.zeros(ssm_state.shape).to(device)).all() diff --git a/tests/models/mamba/test_mamba_model.py b/tests/models/mamba/test_mamba_model.py new file mode 100644 index 00000000..e176af78 --- /dev/null +++ b/tests/models/mamba/test_mamba_model.py @@ -0,0 +1,57 @@ +import pytest +import torch +from transformers import AutoTokenizer +from modalities.models.mamba.mamba_model import _init_weights, create_block, MambaLLM +from tests.conftest import _ROOT_DIR + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="We need cuda to run Mamba.") +def test_mixer_model_forward(batch_size, sequence_length, vocab_size, mixer_model, d_model): + x = torch.randint(size=(batch_size, sequence_length), high=vocab_size).to("cuda") + mixer_model = mixer_model.to("cuda") + y = mixer_model(x) + assert y.shape == (batch_size, sequence_length, d_model) + assert y.shape != x.shape + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="We need cuda to run Mamba.") +def test_mixer_model_allocate_inference_cache(batch_size, sequence_length, mixer_model, n_layer): + mixer_model = mixer_model.to("cuda") + computed_inference_cache = mixer_model.allocate_inference_cache(batch_size, sequence_length) + assert len(computed_inference_cache) == n_layer + + +def test__init_weights(linear_layer, embedding_layer, n_layer): + _init_weights(linear_layer, n_layer) + assert int(linear_layer.bias.sum()) == 0 + +
embedding_layer_weights_before = embedding_layer.weight.clone().detach() + _init_weights(embedding_layer, n_layer) + embedding_layer_weights_after = embedding_layer.weight + assert (embedding_layer_weights_before != embedding_layer_weights_after).any() + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="We need cuda to run Mamba.") +def test_mamba_llm_forward(mamba_llm, batch_size, sequence_length, vocab_size, prediction_key): + mamba_llm = mamba_llm.to("cuda") + x = torch.randint(size=(batch_size, sequence_length), high=vocab_size).to("cuda") + inputs = {"input_ids": x} + y = mamba_llm(inputs) + assert prediction_key in y.keys() + assert y[prediction_key].shape == (batch_size, sequence_length, vocab_size) + + +def test__create_block(d_model, ssm_cfg, norm_epsilon, rms_norm, residual_in_fp32, fused_add_norm, layer_idx, device, + dtype): + test_block = create_block(d_model=d_model, ssm_cfg=ssm_cfg, norm_epsilon=norm_epsilon, rms_norm=rms_norm, + residual_in_fp32=residual_in_fp32, fused_add_norm=fused_add_norm, layer_idx=layer_idx, + device=device, dtype=dtype) + assert test_block.norm.normalized_shape[0] == d_model + assert test_block.mixer.d_model == d_model + + +def test_tie_weights(mamba_llm): + assert (mamba_llm.lm_head.weight != mamba_llm.backbone.embedding.weight).any() + mamba_llm.tie_embeddings = True + mamba_llm.tie_weights() + assert (mamba_llm.lm_head.weight == mamba_llm.backbone.embedding.weight).all()
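
For reference, the following is a minimal, illustrative sketch of how the MambaLLM registered above can be constructed and called directly. It mirrors the mamba_llm, mixer_model_config and mamba_block_config fixtures and test_mamba_llm_forward; the constructor arguments and the dict-in/dict-out forward contract are taken from those tests, and a CUDA device is assumed because the fused Mamba kernels require one.

# Illustrative sketch only; hyperparameters mirror the test fixtures above (d_model=16, n_layer=2, vocab_size=1024).
import torch

from modalities.models.mamba.mamba_config import MambaBlockConfig, MixerModelConfig
from modalities.models.mamba.mamba_model import MambaLLM

mamba_block_config = MambaBlockConfig(
    d_state=3, d_conv=4, expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1,
    dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, conv_bias=True, bias=False,
    use_fast_path=True,
)
mixer_model_config = MixerModelConfig(
    norm_epsilon=1e-5, device=None, mamba_block_config=mamba_block_config
)
mamba_llm = MambaLLM(
    d_model=16, n_layer=2, vocab_size=1024, rms_norm=False, residual_in_fp32=True,
    fused_add_norm=True, pad_vocab_size_multiple=1, tie_embeddings=False,
    prediction_key="logits", sample_key="input_ids", seed=42, dtype=None,
    initializer_cfg={}, num_last_tokens=0, inference_params={},
    mixer_model_config=mixer_model_config,
).to("cuda")

# The model consumes a dict keyed by sample_key and returns a dict keyed by prediction_key.
batch = {"input_ids": torch.randint(high=1024, size=(2, 64)).to("cuda")}
logits = mamba_llm(batch)["logits"]
assert logits.shape == (2, 64, 1024)  # (batch_size, sequence_length, vocab_size)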
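
The new test modules live under tests/models/mamba and can be run with pytest tests/models/mamba. Tests that move modules or tensors onto the GPU are guarded with pytest.mark.skipif(not torch.cuda.is_available(), ...) and are skipped on machines without CUDA; the remaining tests (state-cache lookup, weight initialization, block construction and weight tying) do not move anything onto the GPU.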