
# Pydantic and ClassResolver

The mechanism introduced to instantiate classes via `type_hint` in the `config.yaml` relies on
1) OmegaConf to load the config YAML file,
2) Pydantic to validate the config, and
3) ClassResolver to instantiate the correct, concrete class of a class hierarchy.

First, OmegaConf loads the config YAML file and resolves internal references such as `${subconfig.attribute}`.
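
To illustrate this first step, here is a minimal, standalone sketch (not modalities code; the YAML content is made up) of how OmegaConf resolves such a reference:

```python
from omegaconf import OmegaConf

# Hypothetical config with an internal reference, mirroring ${subconfig.attribute}.
yaml_str = """
subconfig:
  attribute: 42
other:
  copied_value: ${subconfig.attribute}
"""
cfg = OmegaConf.create(yaml_str)
print(OmegaConf.to_container(cfg, resolve=True))
# {'subconfig': {'attribute': 42}, 'other': {'copied_value': 42}}
```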

Then, Pydantic validates the whole config as-is and checks that each sub-config is a `pydantic.BaseModel` class.
For configs that allow different concrete classes to be instantiated by `ClassResolver`, the special member names `type_hint` and `config` are introduced.
This lets us use Pydantic's ability to auto-select a fitting type based on the keys present in the config YAML file.

`ClassResolver` replaces large if-else control structures for inferring the correct concrete type; the `type_hint` determines which class is selected:
```python
from class_resolver import ClassResolver
from torch import nn

activation_resolver = ClassResolver(
    [nn.ReLU, nn.Tanh, nn.Hardtanh],
    base=nn.Module,
    default=nn.ReLU,
)
type_hint = "ReLU"
activation_kwargs = {...}  # placeholder for the selected activation's keyword arguments
activation = activation_resolver.make(type_hint, activation_kwargs)
```

In our implementation we go a step further:
* a `type_hint` in a `BaseModel` config must be of type `modalities.config.lookup_types.LookupEnum`, and
* `config` is a union of the allowed concrete configs, each of base type `BaseModel`.

Here, `config` takes the role of `activation_kwargs` from the example above, replacing the plain keyword-argument dict with Pydantic-validated `BaseModel` configs.

This establishes a mapping from the type-hint strings required by `class-resolver` to the concrete classes, while letting Pydantic select the correct concrete config:

```python
from enum import Enum
from typing import Annotated

import torch
from pydantic import BaseModel, PositiveInt, PositiveFloat, Field


class LookupEnum(Enum):
    @classmethod
    def _missing_(cls, value: str) -> type:
        """constructs Enum by member name, if not constructable by value"""
        return cls.__dict__[value]


class SchedulerTypes(LookupEnum):
    StepLR = torch.optim.lr_scheduler.StepLR
    ConstantLR = torch.optim.lr_scheduler.ConstantLR


class StepLRConfig(BaseModel):
    step_size: Annotated[int, Field(strict=True, ge=1)]
    gamma: Annotated[float, Field(strict=True, ge=0.0)]


class ConstantLRConfig(BaseModel):
    factor: PositiveFloat
    total_iters: PositiveInt


class SchedulerConfig(BaseModel):
    type_hint: SchedulerTypes
    config: StepLRConfig | ConstantLRConfig
```
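
For illustration, the snippet below is a minimal sketch (not part of the library's API) of how such a validated `SchedulerConfig` can be turned into an actual scheduler instance. The toy model, optimizer, example values, and the manual instantiation call are assumptions for demonstration; Pydantic v2 (`model_dump`) is assumed as well:

```python
import torch

# Hypothetical values as they might appear under a scheduler entry in the config YAML.
scheduler_config = SchedulerConfig(
    type_hint=SchedulerTypes.StepLR,
    config=StepLRConfig(step_size=10, gamma=0.1),
)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# `type_hint.value` holds the concrete class; the validated config provides its kwargs.
scheduler_cls = scheduler_config.type_hint.value  # torch.optim.lr_scheduler.StepLR
scheduler = scheduler_cls(optimizer, **scheduler_config.config.model_dump())
```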

To allow a user-friendly instantiation, all class resolvers are defined in the `ResolverRegister`, and the convenience function `build_component_by_config` is provided. Dependencies can be passed through via the `extra_kwargs` argument:
```python
resolvers = ResolverRegister(config=config)
optimizer = ... # our example dependency
scheduler = resolvers.build_component_by_config(config=config.scheduler, extra_kwargs=dict(optimizer=optimizer))
```

To add a new resolver, use `add_resolver`; the added resolver is then accessible via the `register_key` given at registration time.
To access it, use the `build_component_by_key_query` function of the `ResolverRegister`.

# Supported Features
In the following, we list the already implemented, planned, and in-progress features with respect to improving downstream performance, throughput, multi-modality, and alignment.

## Throughput Features

| Name | Status | Description |
|---------------------------------------|------------------|-------------------------------------------------------------------------------------------------------------------|
| Mixed Precision Training              | supported        | Utilizes both single-precision (FP32) and half-precision (FP16, BF16) floating-point formats to speed up arithmetic computations while maintaining model accuracy. |
| Fully Sharded Data Parallel (FSDP) | supported | Optimizes distributed training by sharding the model parameters, gradients, and optimizer states across all GPUs, reducing memory overhead and enabling the training of larger models. |
| Gradient Accumulation                 | supported        | Allows for the use of larger effective batch sizes than fit in memory by accumulating gradients over multiple mini-batches before updating the model weights (see the sketch below the table). |
| CPU Offloading via FSDP | supported | Moves parts of the model or computation from GPU to CPU or other storage to manage GPU memory constraints. |
| Memmap for efficient data loading | supported | Optimizes the data pipeline to reduce I/O bottlenecks. |
| Activation Checkpointing | supported | Saves intermediate activations to memory only at certain points during the forward pass and recomputes them during the backward pass, reducing memory usage at the cost of additional computation. |
| Flash Attention | supported | A highly optimized attention mechanism that significantly reduces the computational burden and memory footprint of attention calculations, enabling faster training and inference on large models. |
| Adaptive Batch Size Exploration | planned | Dynamically increases the training batch size during the training process to identify the maximum batch size that can be accommodated by a given GPU setup without causing memory overflow or performance degradation. |
| Node Failure Recovery | planned | Implements mechanisms to automatically detect and recover from failures (e.g., node or GPU failures) in distributed training environments, ensuring that training can continue with minimal interruption even if one or more nodes / GPUs in the cluster fail. |
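
As an illustration of gradient accumulation, the following is a minimal, generic PyTorch sketch (not the modalities training loop); the `accum_steps` value, the dataloader, and the loss function are assumptions for demonstration:

```python
def train_with_grad_accumulation(model, dataloader, optimizer, loss_fn, accum_steps: int = 4):
    """Emulate a batch size of accum_steps * mini_batch_size with the memory of one mini-batch."""
    model.train()
    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(dataloader):
        loss = loss_fn(model(inputs), targets) / accum_steps  # scale so the accumulated gradient matches the large-batch mean
        loss.backward()                                       # gradients are summed into the parameters' .grad buffers
        if (step + 1) % accum_steps == 0:
            optimizer.step()        # one weight update per accum_steps mini-batches
            optimizer.zero_grad()
```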



## Downstream Performance Features

| Name | Status | Description |
|--------------------------------|------------------|-------------------------------------------------------------------------------------------------------------------|
| SwiGLU                         | supported        | A nonlinear activation function combining Gated Linear Units (GLU) with Swish, enhancing model capacity and learning efficiency (see the sketch below the table). |
| Weight Decay | supported | Regularization technique that adds a penalty on the size of weights, encouraging smaller weights to reduce overfitting and improve generalization. |
| RMSNorm (pre-normalization)    | supported        | Normalizes layer inputs by their root mean square to stabilize training, often used as a pre-normalization alternative to LayerNorm for improved training dynamics. |
| Rotary Positional Embeddings (RoPE) | supported   | Encodes sequence position information into the attention mechanism, preserving relative positional information and improving the model's understanding of sequence order. |
| Grouped-query Attention (GQA) | supported | Enhances attention mechanisms by grouping queries to reduce computation and memory footprint while maintaining or improving performance. |
| Learning Rate Scheduler | supported | Adjusts the learning rate during training according to a predefined schedule (e.g., step decay, exponential decay) to improve convergence and performance. |
| Gradient Clipping | supported | Prevents exploding gradients by clipping the gradients of an optimization algorithm to a maximum value, thereby stabilizing training. |
| Training Warmup | planned | Gradually increases the learning rate from a low to a high value during the initial phase of training to stabilize optimization. |
| Loss Masking                   | planned          | Ignores or down-weights certain data points in the loss function, often used with variable-length sequences to ignore padding tokens, or in more specific use cases such as GAtt. |
| Knowledge Distillation | planned | Transfers knowledge from a larger, complex model to a smaller, more efficient model, improving the smaller model's performance without the computational cost of the larger model.|
| Hyperparameter Optimization    | planned          | Grid search over hyperparameters such as the learning rate and optimizer arguments. Integrating µP might also be of interest. |
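
As an illustration of SwiGLU, here is a minimal sketch of a Swish-gated feed-forward block (not the modalities implementation; the module name and dimension arguments are made up):

```python
import torch
from torch import nn
import torch.nn.functional as F


class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward block: down_proj(silu(x @ W_gate) * (x @ W_up))."""

    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_hidden, bias=False)
        self.w_up = nn.Linear(d_model, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
```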


## Entry Points
Since `modalities` is registered as an entry point in our `pyproject.toml`, we can start main with just `modalities` (which does not require an explicit Python invocation).

Alternatively, directly use `src/modalities/__main__.py do_stuff --config_file_path config_files/config.yaml --my_cli_argument 3537`.

# MemMap Datasets

## MemMapDataset Index Generator

The `MemMapDataset` requires an index file that provides the necessary pointers into the raw data file. The `MemMapDataset` can create the index file lazily; however, it is advised to create it beforehand. This can be done by running

```sh
modalities data create_raw_index <path/to/jsonl/file>
```

The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities data create_raw_index --help`.

## Packed Dataset Generator

The `PackedMemMapDatasetContinuous` and `PackedMemMapDatasetMegatron` require a packed data file. To create the data file, you first have to generate a `MemMapDataset` index file as described [above](#memmapdataset-index-generator). Assuming the index and raw data are located in the same directory, you can simply execute the following command:

```sh
modalities data pack_encoded_data <path/to/jsonl/file>
```

The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities data pack_encoded_data --help`.

### Packed Data Format

The packed data file is a bytestream containing both the tokenized data as well as an index denoting the start and length of the tokenized documents inside the bytestream. The data file consists of 3 concatenated parts:

header segment | data segment | index segment

* **header segment**: This section is an 8-byte integer that encodes the length of the data segment in bytes.
* **data segment**: This section contains the concatenation of all documents in the form of 4-byte tokens.
  An end-of-sequence token is placed between consecutive documents.
* **index segment**: This section contains a pickled index that locates the documents inside the data segment.
  The index is essentially a list of tuples, where each tuple contains the start position and length in bytes of the
  corresponding document, e.g., `[(start_doc1, len_doc1), (start_doc2, len_doc2), ...]`.
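
As an illustration, the following is a minimal reading sketch derived from the format description above (not the library's own loader). The little-endian byte order of the header and the index offsets being relative to the start of the data segment are assumptions:

```python
import pickle


def read_packed_file(path: str):
    """Sketch only: parse the three segments described above."""
    with open(path, "rb") as f:
        raw = f.read()
    header_size = 8
    # Assumption: the 8-byte header is little-endian.
    data_len = int.from_bytes(raw[:header_size], byteorder="little")
    data_segment = raw[header_size : header_size + data_len]
    index = pickle.loads(raw[header_size + data_len :])  # [(start_doc1, len_doc1), ...]
    # Assumption: offsets are relative to the start of the data segment.
    start, length = index[0]
    first_doc = data_segment[start : start + length]  # raw bytes of the first document (4-byte tokens)
    return data_segment, index, first_doc
```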
