Small experiments on estimating transformer memory from axolotl config files.
This is primarily a learning project, in which I am looking to predict the memory consumption of an LLM from an axolotl config file alone.
The hope is that, if we can get a reasonable solution for this, we can move it over to the axolotl project directly. Issue Here
Simply pass the path of an axolotl config file to the main script via the `--config` flag to estimate memory usage:
python main.py --config examples/code-llama/7b/lora.yml
Would return:
┌───────────────────────────────────────────────────────────────┐
│ Memory Estimate                                               │
├────────────────────────────────────────┬───────────┬──────────┤
│ Modelling                              │ Precision │ Memory   │
├────────────────────────────────────────┼───────────┼──────────┤
│ Base Model (codellama/CodeLlama-7b-hf) │ BIT8      │ 6.2GiB   │
│ LORA Adapter                           │ BIT16     │ 152.5MiB │
├────────────────────────────────────────┼───────────┼──────────┤
│ Training                               │ Precision │ Memory   │
├────────────────────────────────────────┼───────────┼──────────┤
│ Gradients                              │ BIT16     │ 152.5MiB │
│ Optimizer (adamw_bnb_8bit)             │ BIT8      │ 152.5MiB │
│ Activations                            │ MIXED     │ 201.8MiB │
└────────────────────────────────────────┴───────────┴──────────┘
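All of the inputs to the estimate come straight from the YAML file. A minimal sketch of pulling out the relevant fields (plain PyYAML parsing, not this project's actual code; key names such as `base_model`, `load_in_8bit`, `adapter`, and `optimizer` follow axolotl's config schema):

```python
# Minimal sketch: read an axolotl config and pull out the fields that
# drive the memory estimate. Parsing is plain PyYAML, nothing project-specific.
import yaml

def load_config(path: str) -> dict:
    """Parse an axolotl YAML config into a plain dict."""
    with open(path) as f:
        return yaml.safe_load(f)

cfg = load_config("examples/code-llama/7b/lora.yml")
print(cfg["base_model"])        # e.g. codellama/CodeLlama-7b-hf
print(cfg.get("load_in_8bit"))  # True  -> base weights loaded in 8-bit
print(cfg.get("adapter"))       # "lora" -> a LoRA adapter is trained on top
print(cfg.get("optimizer"))     # e.g. adamw_bnb_8bit
print(cfg.get("sequence_len"))  # feeds the activation estimate
```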
Borrowing from here, we group memory requirements into three broad buckets:
Modelling: the memory required to load the model itself, including the base model (quantized or unquantized) and any PEFT adapters. Coverage so far (a rough arithmetic sketch follows the table below):
Model | Base Model | 4bit | 8bit | LORA | QLORA | GPTQ | GPTQ w/Flash Attn | flash attn | xformers attn |
---|---|---|---|---|---|---|---|---|---|
Llama | ✔️ | ✔️ | ✔️ | ✔️ | ❌ | ❌ | ❌ | ❌ | ❌ |
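As a rough illustration of this bucket (not the project's exact code): the base model costs parameter count times bytes per parameter at its loading precision, and a LoRA adapter adds `r * (d_in + d_out)` parameters for each targeted weight matrix. The 6.7B parameter count, layer shapes, and rank below are assumptions chosen to land near the example output shown earlier.

```python
# Illustrative arithmetic for the model-load bucket. Parameter counts,
# shapes, and rank are assumptions, not values read from a real config.
BYTES_PER_PARAM = {"bit4": 0.5, "bit8": 1, "bit16": 2, "bit32": 4}
GIB, MIB = 1024 ** 3, 1024 ** 2

def base_model_bytes(n_params: float, precision: str) -> float:
    """Weights only: parameter count times bytes per parameter."""
    return n_params * BYTES_PER_PARAM[precision]

def lora_adapter_bytes(n_layers: int, target_shapes: list, r: int,
                       precision: str = "bit16") -> float:
    """Each targeted (d_in, d_out) matrix gains A (r x d_in) and B (d_out x r)."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in target_shapes)
    return n_layers * per_layer * BYTES_PER_PARAM[precision]

# ~6.7B parameters loaded in 8-bit.
print(f"base model:   {base_model_bytes(6.7e9, 'bit8') / GIB:.1f} GiB")

# 32 decoder layers, LoRA r=32 over the four attention projections and the
# three MLP projections of a Llama-7B-shaped model.
shapes = [(4096, 4096)] * 4 + [(4096, 11008), (4096, 11008), (11008, 4096)]
print(f"lora adapter: {lora_adapter_bytes(32, shapes, r=32) / MIB:.1f} MiB")
```

With these assumed shapes the two figures come out at roughly 6.2 GiB and 152.5 MiB, the same ballpark as the example output at the top.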
Training: the memory required for a single backward pass of the model, i.e. the gradients plus the optimizer state. Optimizers covered so far (a hedged sketch of the usual rules of thumb follows the table below):
Optimizer | Basic |
---|---|
sgd | ✔️ |
adamw_hf | ✔️ |
adamw_torch | ✔️ |
adamw_torch_fused | ✔️ |
adamw_apex_fused | ❌ |
adamw_anyprecision | ❌ |
adafactor | ❌ |
adamw_bnb_8bit | ✔️ |
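A hedged sketch of the conventional rules of thumb for this bucket, assuming only the trainable (adapter) parameters carry gradients and optimizer state; the bytes-per-parameter figures are standard values, not numbers taken from this project:

```python
# Sketch of the training bucket: gradients and optimizer state are kept
# only for trainable parameters. Bytes-per-parameter values are the usual
# rules of thumb, not values taken from this project.
OPTIMIZER_STATE_BYTES = {
    "sgd": 0,             # no state without momentum (4 bytes/param with it)
    "adamw_hf": 8,        # two fp32 moments (m, v)
    "adamw_torch": 8,
    "adamw_torch_fused": 8,
    "adamw_bnb_8bit": 2,  # two 8-bit quantized moments
}

def gradient_bytes(trainable_params: float, bytes_per_param: int = 2) -> float:
    """Gradients per trainable parameter (2 bytes for bf16/fp16)."""
    return trainable_params * bytes_per_param

def optimizer_bytes(trainable_params: float, optimizer: str) -> float:
    return trainable_params * OPTIMIZER_STATE_BYTES[optimizer]

MIB = 1024 ** 2
trainable = 80e6  # ~80M LoRA parameters, roughly the adapter size above
print(f"gradients: {gradient_bytes(trainable) / MIB:.1f} MiB")
print(f"optimizer: {optimizer_bytes(trainable, 'adamw_bnb_8bit') / MIB:.1f} MiB")
```

With ~80M trainable LoRA parameters this roughly reproduces the two 152.5 MiB figures from the example output, and it shows why 8-bit optimizer states are attractive: adamw_bnb_8bit needs a quarter of the state memory of full-precision AdamW.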
Activations: the memory required for a single forward pass of the model.
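Activation memory is the hardest bucket to pin down because it depends on sequence length, batch size, the attention implementation, and gradient checkpointing. As a hedged illustration only (not this project's formula), the per-layer approximation from Korthikanti et al. (2022) for a standard transformer layer with 16-bit activations and no recomputation is `s*b*h*(34 + 5*a*s/h)` bytes:

```python
# Hedged sketch of activation memory, NOT this project's formula.
# Per-layer approximation (16-bit activations, no recomputation) from
# Korthikanti et al., "Reducing Activation Recomputation in Large
# Transformer Models" (2022): s*b*h*(34 + 5*a*s/h) bytes.
def activation_bytes(seq_len: int, batch: int, hidden: int,
                     heads: int, layers: int) -> float:
    per_layer = seq_len * batch * hidden * (34 + 5 * heads * seq_len / hidden)
    return per_layer * layers

GIB = 1024 ** 3
# Llama-7B-shaped assumptions: 4096-token sequence, micro batch of 1.
print(f"upper bound: {activation_bytes(4096, 1, 4096, 32, 32) / GIB:.0f} GiB")
```

That un-checkpointed figure is orders of magnitude above the 201.8 MiB in the example output; in practice gradient checkpointing and fused attention kernels avoid materialising most of these tensors, so a useful estimate has to model what is actually kept in memory rather than apply the raw formula.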