Merge pull request #141 from Modalities/dev_experiments
Towards stable modalities version
le1nux authored Jul 10, 2024
2 parents f25c018 + 2de5ab4 commit 1b57ef5
Showing 101 changed files with 4,074 additions and 2,092 deletions.
41 changes: 41 additions & 0 deletions .github/ISSUE_TEMPLATE/bug-report.yaml
@@ -0,0 +1,41 @@
name: 🐛 Bug Report
description: Submit a bug report to help improve modalities
labels: [ "bug" ]

body:
- type: markdown
attributes:
value: >
#### Before submitting a bug report, please make sure the issue hasn't already been addressed by searching through [the existing and past issues](https://github.com/Modalities/modalities/issues).
- type: textarea
id: system-info
attributes:
label: System Info
description: Please share your system info with us.
placeholder: modalities version, platform, python version, ...
validations:
required: true

- type: textarea
attributes:
label: 🐛 Describe the bug
description: |
Please provide a clear and concise description of what the bug is. If relevant, add a minimal example so that we can reproduce the error by running the code. Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception.
placeholder: |
A clear and concise description of what the bug is.
```python
# Sample code to reproduce the problem
```
```
The error message you got, with the full traceback.
```
validations:
required: true

- type: markdown
attributes:
value: >
Thanks for contributing 🎉!
1 change: 1 addition & 0 deletions .github/ISSUE_TEMPLATE/config.yaml
@@ -0,0 +1 @@
blank_issues_enabled: true
23 changes: 23 additions & 0 deletions .github/ISSUE_TEMPLATE/documentation.yaml
@@ -0,0 +1,23 @@
name: 📚 Documentation
description: Report an issue related to https://modalities.github.io/modalities/
labels: [ "documentation" ]

body:
- type: textarea
attributes:
label: 📚 The doc issue
description: >
A clear and concise description of what content in https://modalities.github.io/modalities/ is an issue.
validations:
required: true

- type: textarea
attributes:
label: Suggest a potential alternative/fix
description: >
Tell us how we could improve the documentation in this regard.
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!
27 changes: 27 additions & 0 deletions .github/ISSUE_TEMPLATE/feature-request.yaml
@@ -0,0 +1,27 @@
name: 🚀 Feature Request
description: Submit a proposal/request for a new modalities feature
labels: [ "feature" ]

body:
- type: textarea
id: feature-request
validations:
required: true
attributes:
label: Feature request
description: |
A clear and concise description of the feature proposal.
- type: textarea
id: motivation
validations:
required: true
attributes:
label: Motivation
description: |
Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link it here, too.
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!
16 changes: 16 additions & 0 deletions .github/pull_request_template.md
@@ -0,0 +1,16 @@
# What does this PR do?

This PR ..

## General Changes
* ..

## Breaking Changes
* ..

## Checklist before submitting final PR
- [ ] My PR is minimal and addresses one issue in isolation
- [ ] I have merged the latest version of the target branch into this feature branch
- [ ] I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
- [ ] I have run a sample config for model training
- [ ] I have checked that all tests run through (`python tests/tests.py`)
5 changes: 4 additions & 1 deletion .github/workflows/tests.yml
@@ -13,13 +13,16 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.10", "3.11"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Comment out dependencies that require GPU
run: |
sed -i 's/"flash-attn"/#"flash-attn"/g' pyproject.toml
- name: Install dependencies
run: |
python -m pip install --upgrade pip
39 changes: 39 additions & 0 deletions CHANGELOG_DEV.md
@@ -2,11 +2,50 @@

| PR | Type | Ref. Issue(s) | Breaking Changes |PR Description|
|------------------|------------|---------------|------------------|------------------------------------------------------------------------------------------------|
| [#141](#pr-141-towards-stable-modalities-version) | Bug Fix | [#129](https://github.com/Modalities/modalities/issues/129) | **Yes** | Towards stable modalities version |
| [#154](#pr-154-manual-swiglu-implementation) | Bug Fix | [#14](https://github.com/Modalities/modalities/issues/14) | **Yes** | Manual SwiGLU implementation |
| | | | | |




## PR #141 Towards stable modalities version

This PR further stabilises the codebase and makes training more robust, also w.r.t. loss spikes, which we fixed via scaled weight initialisation and an increased batch size in our experiments.
The PR also fixes all failing tests and adds a simple entrypoint for running cpu, single-gpu and multi-gpu tests. The PR contains multiple sub PRs.

**General changes:**
* Bug fix: the model evaluation mode is now properly deactivated after evaluation (see PR [#131](https://github.com/Modalities/modalities/pull/131))
* Bug fix: Fixed the implementation of Pre-LN for the GPT2 model (see PR [#136](https://github.com/Modalities/modalities/pull/136))
* Enhancement: Added further mixed precision strategies, including one matching MegatronLM's.
* Enhancement: Single, unified entrypoint for running cpu, single-gpu and multi-gpu tests. All tests fixed. (PR [#155](https://github.com/Modalities/modalities/pull/155))
* Enhancement: Previously, we would chunk the dataset into `block_size`-long chunks. Each chunk would then be used for training individually. As a result, the last token of a block would only be used as a target but never as an input. We changed this such that we reuse the last token of a batch as the first one of the subsequent batch. (PR [#158](https://github.com/Modalities/modalities/pull/158))
* Bug fix: Indexing of the original samples of the dataset pbin files had multiple bugs. The index tuples are now always in bytes, and the first sample in the data section now starts at byte 0 (previously, there was a wrong offset). (PR [#164](https://github.com/Modalities/modalities/pull/164))
* Enhancement: Improvements to the current pull request template and addition of several issue templates (bug report, documentation, feature request, blank). (PR [#172](https://github.com/Modalities/modalities/pull/172))
* Enhancement: Added components and factories for plain, scaled and scaled_embed initialisation. (PR [#161](https://github.com/Modalities/modalities/pull/161))
* In GPT2 model training configs, the standard deviation `std` can now be set to the string `auto`, in which case it equals `sqrt(2/(5*hidden_dim))` (see e.g. https://arxiv.org/abs/2312.16903 and the sketch below this list). (PR [#161](https://github.com/Modalities/modalities/pull/161))
* The CoCa model, which previously used a hardcoded (and probably not entirely correct) scaled initialization (see #165), can now only use plain initialization. (PR [#161](https://github.com/Modalities/modalities/pull/161))
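
A minimal sketch of the `auto` standard deviation mentioned above (illustrative only; the helper name is not part of the modalities API, and `hidden_dim` stands for the model's hidden dimension):

```python
import math

def auto_std(hidden_dim: int) -> float:
    # "auto" resolves to sqrt(2 / (5 * hidden_dim)), cf. https://arxiv.org/abs/2312.16903
    return math.sqrt(2 / (5 * hidden_dim))

# e.g. a model with hidden_dim=768 would be initialised with std ~= 0.0228
print(auto_std(768))
```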


**Breaking changes:**
* Enhancement: Logging is now always based on #training steps and #consumed tokens. (PR [#137](https://github.com/Modalities/modalities/pull/137))
This is a breaking change; the experiment configs need to be adapted as shown [here](https://github.com/Modalities/modalities/pull/137/files#diff-2bea5a6678ec91ea603cc2e80d17847360af5e9f7624c8e710f329ee1eb9b4f4).
* Enhancement: The model parameters are now grouped within the respective model. The optimizer can leverage these groups to, e.g., only apply weight decay to non-layer-norm weights. See [here](https://github.com/Modalities/modalities/pull/139/files#diff-2bea5a6678ec91ea603cc2e80d17847360af5e9f7624c8e710f329ee1eb9b4f4) for the necessary config changes. (PR [#139](https://github.com/Modalities/modalities/pull/139))
* Enhancement: We now support different attention implementations (manual, PyTorch flash, DAO flash). See [here](https://github.com/Modalities/modalities/pull/138/files#diff-2bea5a6678ec91ea603cc2e80d17847360af5e9f7624c8e710f329ee1eb9b4f4) for the respective config changes. (PR [#138](https://github.com/Modalities/modalities/pull/138))
* Enhancement: replaced `block_size` in `Dataset`, `Model` and `NumberConversion` with `sequence_length` (PR [#158](https://github.com/Modalities/modalities/pull/158))
* Enhancement: `block_size` is now `sequence_length + 1`, and `sequence_length` should always be specified as a power of 2 (see the sketch after this list). (PR [#158](https://github.com/Modalities/modalities/pull/158))
* Enhancement: Restricted the codebase to the officially supported Python versions 3.10 and 3.11. (PR [#174](https://github.com/Modalities/modalities/pull/174))
* All training configs require an additional component for initialization of the raw model (i.e. the model with random weights), as shown [here](https://github.com/Modalities/modalities/blob/7d26675051b918c3a2b98f32f50cb3ca8ef97d6f/config_files/training/config_lorem_ipsum.yaml#L181). (PR [#161](https://github.com/Modalities/modalities/pull/161))
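
The following is a minimal sketch of the `sequence_length`/`block_size` relationship described above (illustrative only; the function and variable names are assumptions, not the actual dataset implementation): each sample spans `sequence_length + 1` tokens, so the last token of one sample is reused as the first token of the next, and every token can serve as both input and target.

```python
def make_samples(token_stream: list[int], sequence_length: int) -> list[tuple[list[int], list[int]]]:
    # Each sample covers block_size = sequence_length + 1 tokens; consecutive
    # samples overlap by one token, so the last target of a sample is also the
    # first input of the subsequent one.
    samples = []
    start = 0
    while start + sequence_length + 1 <= len(token_stream):
        chunk = token_stream[start : start + sequence_length + 1]
        inputs, targets = chunk[:-1], chunk[1:]
        samples.append((inputs, targets))
        start += sequence_length
    return samples

# With sequence_length=4 and tokens 0..9 this yields
# ([0, 1, 2, 3], [1, 2, 3, 4]) and ([4, 5, 6, 7], [5, 6, 7, 8]).
print(make_samples(list(range(10)), 4))
```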

## Checklist before submitting final PR
- [ ] My PR is minimal and addresses one issue / enhancement in isolation
- [ ] I have merged main into this feature branch
- [ ] I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
- [ ] I have run a sample config for model training
- [ ] I have fixed all failing tests (`python tests/tests.py`)



## PR #154 Manual SwiGLU implementation

This [PR](https://github.com/Modalities/modalities/pull/154) adds a manual SwiGLU implementation. The original one from xops was incompatible with activation checkpointing (see issue [#14](https://github.com/Modalities/modalities/issues/14)).
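
As an illustration of the technique, here is a minimal PyTorch sketch of a SwiGLU feed-forward block (the class, layer names and dimensions are assumptions for illustration, not the actual modalities implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block: w_down(silu(x @ w_gate) * (x @ w_up))."""

    def __init__(self, hidden_dim: int, ffn_dim: int, bias: bool = False):
        super().__init__()
        self.w_gate = nn.Linear(hidden_dim, ffn_dim, bias=bias)
        self.w_up = nn.Linear(hidden_dim, ffn_dim, bias=bias)
        self.w_down = nn.Linear(ffn_dim, hidden_dim, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

# Usage: x has shape (batch, sequence_length, hidden_dim)
ffn = SwiGLUFFN(hidden_dim=768, ffn_dim=2048)
out = ffn(torch.randn(2, 16, 768))
```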
1 change: 1 addition & 0 deletions README.md
@@ -83,6 +83,7 @@ In the following, we list the already implemented, planned and in-progress featu
|--------------------------------|------------------|-------------------------------------------------------------------------------------------------------------------|
| SwiGLU | supported | A nonlinear activation function combining Gated Linear Units (GLU) with Swish for enhancing model capacity and learning efficiency. |
| Weight Decay | supported | Regularization technique that adds a penalty on the size of weights, encouraging smaller weights to reduce overfitting and improve generalization. |
| Weight Initialization | supported | Choose between different, configurable weight initialization techniques to stabilize training. |
| RMSNorm (pre-normalization) | supported | Normalizes the pre-activation weights in a layer to stabilize training, often used as an alternative to LayerNorm for improved training dynamics. |
| Rotary Positional Embeddings (RoPE) | supported | Encodes sequence position information into attention mechanisms, preserving relative positional information and improving model's understanding of sequence order. |
| Grouped-query Attention (GQA) | supported | Enhances attention mechanisms by grouping queries to reduce computation and memory footprint while maintaining or improving performance. |
@@ -4,7 +4,7 @@ settings:
prediction_key: logits
model_path: /raid/s3/opengptx/max_lue/modalities/data/checkpoints/2024-04-22__13-16-03/eid_2024-04-22__13-16-03-model-num_steps_1152.bin
device: 0
context_length: 2048
sequence_length: 2048

text_inference_component:
component_key: inference_component
@@ -17,7 +17,7 @@ text_inference_component:
tokenizer:
instance_key: tokenizer
pass_type: BY_REFERENCE
context_length: ${settings.context_length}
sequence_length: ${settings.sequence_length}
eod_token: <eod>
prompt_template: "{prompt_input}" # "<instruction> Du bist Moody, ein LLM welches Menschen helfen soll. user: {prompt_input}"
temperature: 0
@@ -44,7 +44,7 @@ raw_model:
config:
sample_key: ${settings.referencing_keys.sample_key}
poe_type: ABSOLUTE
block_size: ${settings.context_length}
sequence_length: ${settings.sequence_length}
prediction_key: ${settings.referencing_keys.prediction_key}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
@@ -59,9 +59,6 @@ raw_model:
- type_hint: IdentityTransform
config: {}
activation_type: gelu
weight_init:
mean: 0.0
std: 0.02
attention_norm:
component_key: layer_norm
variant_key: rms_norm
@@ -4,7 +4,7 @@ settings:
prediction_key: logits
model_path: /raid/s3/opengptx/max_lue/modalities/data/checkpoints/2024-04-28__13-06-00/eid_2024-04-28__13-06-00-model-num_steps_256.bin
device: 0
context_length: 2048
sequence_length: 2048

text_inference_component:
component_key: inference_component
@@ -17,7 +17,7 @@ text_inference_component:
tokenizer:
instance_key: tokenizer
pass_type: BY_REFERENCE
context_length: ${settings.context_length}
sequence_length: ${settings.sequence_length}
eod_token: <eod>
prompt_template: "{prompt_input}" # "<instruction> Du bist Moody, ein LLM welches Menschen helfen soll. user: {prompt_input}"
temperature: 0
@@ -44,7 +44,7 @@ raw_model:
config:
sample_key: ${settings.referencing_keys.sample_key}
poe_type: NOPE
block_size: ${settings.context_length}
sequence_length: ${settings.sequence_length}
prediction_key: ${settings.referencing_keys.prediction_key}
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
n_layer: 12
@@ -62,9 +62,6 @@ raw_model:
n_head: ${raw_model.config.n_head_q} #it has to be head_q here
seq_length_dim: -2
activation_type: gelu
weight_init:
mean: 0.0
std: 0.02
attention_norm:
component_key: layer_norm
variant_key: rms_norm
55 changes: 41 additions & 14 deletions config_files/training/config_example_coca.yaml
@@ -5,11 +5,10 @@ settings:
sample_key: input_ids
target_key: target_ids
training:
global_training_log_interval_in_steps: 2
global_checkpointing_interval_in_steps: 2
global_evaluation_interval_in_steps: 2
global_num_training_samples: 12
global_num_seen_steps: 0
training_log_interval_in_steps: 2
checkpointing_interval_in_steps: 2
evaluation_interval_in_steps: 2
global_num_seen_tokens: 0
activation_checkpointing_modules: []
gradient_acc_steps: 1
local_train_micro_batch_size: 3
@@ -42,7 +41,7 @@ train_dataset:
component_key: dataset
variant_key: dummy_dataset
config:
num_samples: 4
num_samples: 64
sample_definition:
- sample_key: images
sample_shape: [3, 224, 224]
@@ -55,7 +54,7 @@ val_dataset:
component_key: dataset
variant_key: dummy_dataset
config:
num_samples: 4
num_samples: 32
sample_definition:
- sample_key: images
sample_shape: [3, 224, 224]
@@ -144,7 +143,13 @@ checkpoint_saving:
checkpoint_path: ${settings.paths.checkpointing_path}
global_rank: ${settings.cuda_env.global_rank}
experiment_id: ${settings.experiment_id}

get_num_tokens_from_num_steps_callable:
component_key: number_conversion
variant_key: num_tokens_from_num_steps_callable
config:
num_ranks: ${settings.cuda_env.world_size}
local_micro_batch_size: ${settings.training.local_train_micro_batch_size}
sequence_length: ${settings.training.sequence_length}
loss_fn:
component_key: loss
variant_key: clm_cross_entropy_loss
@@ -164,7 +169,23 @@ wrapped_model:
sharding_strategy: FULL_SHARD
block_names: [TransformerBlock, VisionTransformerBlock]

model:
model:
component_key: model
variant_key: model_initialized
config:
model:
instance_key: model_raw
pass_type: BY_REFERENCE
model_initializer:
component_key: model_initialization
variant_key: composed
config:
model_type: coca
weight_init_type: plain
mean: 0.0
std: 0.02

model_raw:
component_key: model
variant_key: coca
config:
@@ -209,9 +230,6 @@ model:
n_vision_queries: 256
bias_attn_pool: False
epsilon_attn_pool: 1e-5
weight_init:
mean: 0.0
std: 0.02

scheduler:
component_key: scheduler
@@ -223,7 +241,7 @@ scheduler:
max_lr: 6e-4
div_factor: 10
final_div_factor: 1
total_steps: 4
total_steps: 64
pct_start: 0.01
anneal_strategy: cos

@@ -235,6 +253,7 @@ optimizer:
betas: [0.9, 0.95]
eps: 1e-8
weight_decay: 1e-1
weight_decay_groups_excluded: []
wrapped_model:
instance_key: wrapped_model
pass_type: BY_REFERENCE
@@ -254,7 +273,15 @@ batch_progress_subscriber:
variant_key: rich
config:
local_rank: ${settings.cuda_env.local_rank}
global_num_seen_steps: ${settings.training.global_num_seen_steps}
global_num_seen_steps:
component_key: number_conversion
variant_key: num_steps_from_num_tokens
config:
num_ranks: ${settings.cuda_env.world_size}
local_micro_batch_size: ${settings.training.local_train_micro_batch_size}
global_num_tokens: ${settings.training.global_num_seen_tokens}
sequence_length: ${settings.training.sequence_length}
gradient_acc_steps: ${settings.training.gradient_acc_steps}
train_dataloader:
instance_key: train_dataloader
pass_type: BY_REFERENCE
