Merge pull request #113 from Modalities/mamba
Mamba
rrutmann committed May 27, 2024
2 parents 2b562b1 + 8c7f9bd commit d2f6dd9
Showing 28 changed files with 3,572 additions and 159 deletions.
config_files/training/config_mem_map_mamba.yaml (217 additions, 0 deletions)
@@ -0,0 +1,217 @@
settings:
  experiment_id: ${modalities_env:experiment_id}
  referencing_keys:
    sample_key: input_ids
    target_key: target_ids
    prediction_key: logits
  training:
    callback_interval_in_samples: 32768
    global_num_training_samples: 2048
    global_num_seen_samples: 0
    do_apply_activation_checkpointing: false
    gradient_acc_steps: 1
    local_train_micro_batch_size: 16
    sequence_length: 4096
    gradient_clipping:
      mode: NONE
  cuda_env:
    local_rank: ${cuda_env:LOCAL_RANK}
    global_rank: ${cuda_env:RANK}
    world_size: ${cuda_env:WORLD_SIZE}
  paths:
    checkpointing_path: data/checkpoints


collate_fn:
  component_key: collate_fn
  variant_key: gpt_2_llm_collator
  config:
    sample_key: ${settings.referencing_keys.sample_key}
    target_key: ${settings.referencing_keys.target_key}

train_dataset:
  component_key: dataset
  variant_key: packed_mem_map_dataset_megatron
  config:
    raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpyjama_v2_default_DE_num_docs_1024/redpyjama_v2_default_DE_num_docs_1024.pbin
    block_size: ${settings.training.sequence_length}
    sample_key: ${settings.referencing_keys.sample_key}

train_dataloader:
  component_key: data_loader
  variant_key: default
  config:
    num_workers: 2
    pin_memory: true
    shuffle: false
    dataloader_tag: "train"
    dataset:
      instance_key: train_dataset
      pass_type: BY_REFERENCE
    batch_sampler:
      component_key: batch_sampler
      variant_key: default
      config:
        batch_size: ${settings.training.local_train_micro_batch_size}
        drop_last: true
        sampler:
          component_key: sampler
          variant_key: distributed_sampler
          config:
            rank: ${settings.cuda_env.global_rank}
            num_replicas: ${settings.cuda_env.world_size}
            shuffle: true
            dataset:
              instance_key: train_dataset
              pass_type: BY_REFERENCE
    collate_fn:
      instance_key: collate_fn
      pass_type: BY_REFERENCE

val_dataset:
  component_key: dataset
  variant_key: packed_mem_map_dataset_megatron
  config:
    raw_data_path: /raid/s3/opengptx/max_lue/modalities/data/sample_datasets/redpajama_v2/mem_map/redpyjama_v2_default_DE_num_docs_1024/redpyjama_v2_default_DE_num_docs_1024.pbin
    block_size: ${settings.training.sequence_length}
    sample_key: ${settings.referencing_keys.sample_key}

val_dataloader:
  component_key: data_loader
  variant_key: default
  config:
    num_workers: 2
    pin_memory: true
    shuffle: false
    dataloader_tag: "val"
    dataset:
      instance_key: val_dataset
      pass_type: BY_REFERENCE
    batch_sampler:
      component_key: batch_sampler
      variant_key: default
      config:
        batch_size: ${settings.training.local_train_micro_batch_size}
        drop_last: true
        sampler:
          component_key: sampler
          variant_key: distributed_sampler
          config:
            rank: ${settings.cuda_env.global_rank}
            num_replicas: ${settings.cuda_env.world_size}
            shuffle: true
            dataset:
              instance_key: val_dataset
              pass_type: BY_REFERENCE
    collate_fn:
      instance_key: collate_fn
      pass_type: BY_REFERENCE

eval_dataloaders:
  - instance_key: val_dataloader
    pass_type: BY_REFERENCE

checkpointing:
  component_key: checkpointing
  variant_key: default
  config:
    checkpointing_strategy:
      component_key: checkpointing_strategy
      variant_key: save_k_most_recent_checkpoints_strategy
      config:
        k: -1 # -1 to save all checkpoints
    checkpointing_execution:
      component_key: checkpointing_execution
      variant_key: fsdp_to_disc_checkpointing
      config:
        checkpoint_path: ${settings.paths.checkpointing_path}
        global_rank: ${settings.cuda_env.global_rank}
        experiment_id: ${settings.experiment_id}
        mixed_precision_settings: BF_16
        sharding_strategy: FULL_SHARD
        block_names: [ MambaBlock ]

model:
  component_key: model
  variant_key: mamba
  config:
    d_model: 16
    n_layer: 2
    vocab_size: 50257
    rms_norm: True
    ssm_cfg: {}
    residual_in_fp32: True
    fused_add_norm: True
    pad_vocab_size_multiple: 8
    tie_embeddings: True
    prediction_key: logits
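    # Note: d_model 16 with n_layer 2 is a very small Mamba configuration,
    # presumably sized for smoke-testing the pipeline rather than for a real pretraining run.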

wrapped_model:
  component_key: model
  variant_key: fsdp_wrapped
  config:
    model:
      instance_key: model
      pass_type: BY_REFERENCE
    sync_module_states: true
    mixed_precision_settings: BF_16
    sharding_strategy: FULL_SHARD
    block_names: [ MambaBlock ]

scheduler:
  component_key: scheduler
  variant_key: onecycle_lr
  config:
    optimizer:
      instance_key: optimizer
      pass_type: BY_REFERENCE
    max_lr: 6e-4
    div_factor: 10
    final_div_factor: 1
    total_steps: 64
    pct_start: 0.01
    anneal_strategy: cos
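    # Note (assuming PyTorch OneCycleLR semantics): the learning rate starts at
    # max_lr / div_factor = 6e-5, peaks at 6e-4 within the first step
    # (pct_start 0.01 of 64 total steps), and final_div_factor 1 anneals it
    # back down to the same 6e-5 by the end of the cycle.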

loss_fn:
  component_key: loss
  variant_key: clm_cross_entropy_loss
  config:
    target_key: ${settings.referencing_keys.target_key}
    prediction_key: ${settings.referencing_keys.prediction_key}

optimizer:
  component_key: optimizer
  variant_key: adam_w
  config:
    lr: 0.0001
    betas: [ 0.9, 0.95 ]
    eps: 1e-8
    weight_decay: 1e-1
    wrapped_model:
      instance_key: wrapped_model
      pass_type: BY_REFERENCE

batch_progress_subscriber:
  component_key: progress_subscriber
  variant_key: rich
  config:
    local_rank: ${settings.cuda_env.local_rank}
    world_size: ${settings.cuda_env.world_size}
    global_num_seen_samples: ${settings.training.global_num_seen_samples}
    train_dataloader:
      instance_key: train_dataloader
      pass_type: BY_REFERENCE
    eval_dataloaders:
      - instance_key: val_dataloader
        pass_type: BY_REFERENCE


evaluation_subscriber:
  component_key: results_subscriber
  variant_key: wandb
  config:
    local_rank: ${settings.cuda_env.local_rank}
    project: modalities
    mode: ONLINE
    experiment_id: ${settings.experiment_id}
    directory: "."
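
The ${...} values above are OmegaConf-style interpolations: ${settings.training.sequence_length} points back at another key in the same file, while ${modalities_env:experiment_id} and ${cuda_env:LOCAL_RANK} are custom resolvers filled in at runtime. As a rough sketch only (this is not the Modalities loading code; the stand-in resolver implementations below are assumptions), a file like this could be resolved with plain OmegaConf as follows:

import os

from omegaconf import OmegaConf

# Stand-in resolvers -- the real ones are registered inside Modalities and will differ.
OmegaConf.register_new_resolver("modalities_env", lambda key: {"experiment_id": "demo_run"}[key])
OmegaConf.register_new_resolver("cuda_env", lambda var: int(os.environ.get(var, 0)))

cfg = OmegaConf.load("config_files/training/config_mem_map_mamba.yaml")

# Interpolations resolve lazily on access; to_container(..., resolve=True) forces
# everything down to plain Python values.
resolved = OmegaConf.to_container(cfg, resolve=True)
print(resolved["train_dataset"]["config"]["block_size"])  # 4096, pulled from settings.training.sequence_length

With the values resolved, each instance_key / pass_type: BY_REFERENCE pair reads as a reference to another top-level component (for example, both dataloaders reuse the single collate_fn instance), which is how the config wires the dependency graph without repeating definitions.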