[Feature] Support "auto" fp16/bf16 for DeepSpeed (InternLM#195)
* support deepspeed auto dtype

* update ds configs
LZHgrla authored Nov 2, 2023
1 parent 8d57f35 commit 0b84391
Showing 6 changed files with 41 additions and 23 deletions.
12 changes: 6 additions & 6 deletions xtuner/configs/deepspeed/deepspeed_zero2.json
@@ -3,16 +3,16 @@
     "train_micro_batch_size_per_gpu": "auto",
     "gradient_clipping": "auto",
     "zero_allow_untested_optimizer": true,
+    "zero_force_ds_cpu_optimizer": false,
     "zero_optimization": {
         "stage": 2,
-        "contiguous_gradients": false,
-        "allgather_bucket_size": 1e8,
-        "reduce_bucket_size": 1e8,
-        "overlap_comm": true,
-        "reduce_scatter": true
+        "overlap_comm": true
     },
     "fp16": {
-        "enabled": true,
+        "enabled": "auto",
         "initial_scale_power": 16
+    },
+    "bf16": {
+        "enabled": "auto"
     }
 }
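
Note: the "auto" values in the fp16/bf16 blocks are not left for DeepSpeed to fill in; the commit resolves them in Python before the config reaches DeepSpeedStrategy (see the new auto_dtype_of_deepspeed_config helper in xtuner/tools/utils.py below). A minimal sketch of the resolution rule, assuming a CUDA build of PyTorch; the same fp16/bf16 tail recurs in all four configs in this commit:

import torch

# fp16 "auto" resolves to whether a CUDA device is present;
# bf16 "auto" resolves to whether that device supports bfloat16.
fp16_enabled = torch.cuda.is_available()
bf16_enabled = torch.cuda.is_bf16_supported()

# If both resolve to True, bf16 is preferred and fp16 is switched off.
if fp16_enabled and bf16_enabled:
    fp16_enabled = False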
9 changes: 4 additions & 5 deletions xtuner/configs/deepspeed/deepspeed_zero2_offload.json
@@ -6,18 +6,17 @@
     "zero_force_ds_cpu_optimizer": false,
     "zero_optimization": {
         "stage": 2,
-        "contiguous_gradients": false,
-        "allgather_bucket_size": 1e8,
-        "reduce_bucket_size": 1e8,
         "overlap_comm": true,
-        "reduce_scatter": true,
         "offload_optimizer": {
             "device": "cpu",
             "pin_memory": true
         }
     },
     "fp16": {
-        "enabled": true,
+        "enabled": "auto",
         "initial_scale_power": 16
+    },
+    "bf16": {
+        "enabled": "auto"
     }
 }
10 changes: 4 additions & 6 deletions xtuner/configs/deepspeed/deepspeed_zero3.json
@@ -6,16 +6,14 @@
     "zero_force_ds_cpu_optimizer": false,
     "zero_optimization": {
         "stage": 3,
-        "contiguous_gradients": false,
-        "allgather_bucket_size": 3e8,
-        "reduce_bucket_size": 3e8,
         "overlap_comm": true,
-        "reduce_scatter": true,
         "stage3_gather_16bit_weights_on_model_save": true
     },
-    "low_cpu_mem_usage": false,
     "fp16": {
-        "enabled": true,
+        "enabled": "auto",
         "initial_scale_power": 16
+    },
+    "bf16": {
+        "enabled": "auto"
     }
 }
9 changes: 4 additions & 5 deletions xtuner/configs/deepspeed/deepspeed_zero3_offload.json
@@ -6,11 +6,7 @@
     "zero_force_ds_cpu_optimizer": false,
     "zero_optimization": {
         "stage": 3,
-        "contiguous_gradients": false,
-        "allgather_bucket_size": 3e8,
-        "reduce_bucket_size": 3e8,
         "overlap_comm": true,
-        "reduce_scatter": true,
         "offload_optimizer": {
             "device": "cpu",
             "pin_memory": true
@@ -22,7 +18,10 @@
         "stage3_gather_16bit_weights_on_model_save": true
     },
     "fp16": {
-        "enabled": true,
+        "enabled": "auto",
         "initial_scale_power": 16
+    },
+    "bf16": {
+        "enabled": "auto"
     }
 }
4 changes: 3 additions & 1 deletion xtuner/tools/train.py
@@ -19,6 +19,7 @@
 from xtuner.model.modules import dispatch_modules
 from xtuner.model.utils import LoadWoInit, find_all_linear_names, traverse_dict
 from xtuner.registry import BUILDER, MAP_FUNC
+from xtuner.tools.utils import auto_dtype_of_deepspeed_config


 def parse_args():
@@ -222,9 +223,10 @@ def main():
                 logger='current',
                 level=logging.WARNING)
             grad_clip = mm_max_norm
+        ds_cfg = auto_dtype_of_deepspeed_config(ds_cfg)
         strategy = dict(
             type='DeepSpeedStrategy',
-            config=args.deepspeed,
+            config=ds_cfg,
             gradient_accumulation_steps=grad_accum,
             train_micro_batch_size_per_gpu=train_bs,
             gradient_clipping=grad_clip)
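
Taken together, the two hunks route the DeepSpeed config through the new helper before it reaches the strategy. A rough sketch of the resulting flow in main(), assuming ds_cfg is loaded from the --deepspeed JSON path earlier in the function (that code sits outside the displayed hunk, so the json.load step here is an assumption):

import json

# Hypothetical reconstruction: load the config file given on the CLI ...
with open(args.deepspeed) as f:
    ds_cfg = json.load(f)

# ... resolve any "auto" fp16/bf16 flags for the current hardware ...
ds_cfg = auto_dtype_of_deepspeed_config(ds_cfg)

# ... and hand the resolved dict (not the raw path) to the strategy.
strategy = dict(
    type='DeepSpeedStrategy',
    config=ds_cfg,
    gradient_accumulation_steps=grad_accum,
    train_micro_batch_size_per_gpu=train_bs,
    gradient_clipping=grad_clip)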
20 changes: 20 additions & 0 deletions xtuner/tools/utils.py
@@ -2,6 +2,7 @@
 import copy
 import re

+import torch
 from transformers import (PreTrainedTokenizerFast, StoppingCriteria,
                           StoppingCriteriaList)
 from transformers.generation.streamers import BaseStreamer
@@ -138,3 +139,22 @@ def update_stop_criteria(base,
     if answer_stop_word is not None:
         answer.append(StopWordStoppingCriteria(tokenizer, answer_stop_word))
     return command, answer
+
+
+def auto_dtype_of_deepspeed_config(ds_config):
+    if ds_config.get('fp16') and not ds_config.get('bf16'):
+        if ds_config.get('fp16').get('enabled') == 'auto':
+            ds_config['fp16']['enabled'] = torch.cuda.is_available()
+    elif not ds_config.get('fp16') and ds_config.get('bf16'):
+        if ds_config.get('bf16').get('enabled') == 'auto':
+            ds_config['bf16']['enabled'] = torch.cuda.is_bf16_supported()
+    elif ds_config.get('fp16') and ds_config.get('bf16'):
+        if ds_config.get('fp16').get('enabled') == 'auto':
+            ds_config['fp16']['enabled'] = torch.cuda.is_available()
+        if ds_config.get('bf16').get('enabled') == 'auto':
+            ds_config['bf16']['enabled'] = torch.cuda.is_bf16_supported()
+        if (ds_config['fp16']['enabled'] is True
+                and ds_config['bf16']['enabled'] is True):
+            ds_config['fp16']['enabled'] = False
+            ds_config['bf16']['enabled'] = True
+    return ds_config
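
A quick usage sketch of the new helper; the config literal is illustrative, not one of the shipped files:

from xtuner.tools.utils import auto_dtype_of_deepspeed_config

cfg = {
    'fp16': {'enabled': 'auto', 'initial_scale_power': 16},
    'bf16': {'enabled': 'auto'},
}
cfg = auto_dtype_of_deepspeed_config(cfg)

# On a bf16-capable GPU (e.g. A100) both "auto" values first resolve to
# True, then the tie-break disables fp16 in favour of bf16; on an older
# CUDA GPU without bf16 support, fp16 stays enabled and bf16 is False.
print(cfg['fp16']['enabled'], cfg['bf16']['enabled'])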
