diff --git a/scripts/run_finetune.sh b/scripts/run_finetune.sh
index 2043800aa..d78458e20 100755
--- a/scripts/run_finetune.sh
+++ b/scripts/run_finetune.sh
@@ -74,5 +74,5 @@ deepspeed ${deepspeed_args} \
     --ddp_timeout 72000 \
     --save_steps 5000 \
     --dataloader_num_workers 1 \
-    | tee ${log_dir}/train.log \
-    2> ${log_dir}/train.err
+    > >(tee ${log_dir}/train.log) \
+    2> >(tee ${log_dir}/train.err >&2)
diff --git a/scripts/run_finetune_with_custom_optim.sh b/scripts/run_finetune_with_custom_optim.sh
new file mode 100755
index 000000000..b33988978
--- /dev/null
+++ b/scripts/run_finetune_with_custom_optim.sh
@@ -0,0 +1,209 @@
+#!/bin/bash
+# Please run this script under ${project_id} in project directory of
+# https://github.com/shizhediao/llm-ft
+# COMMIT: d5fecf30ba8011067b10cf51fede53a5ab6574e4
+
+# Parses arguments
+model_name_or_path=gpt2
+dataset_path=data/alpaca/train_conversation
+
+# Other optional arguments that can improve memory saving
+gradient_checkpointing=True
+use_flash_attention=0
+gradient_accumulation_steps=1
+batch_size=1
+block_size=256
+per_device_train_batch_size=1
+conversation_template=llama2
+optim=dummy
+learning_rate=1e-5
+lr_schedule=cosine
+beta1=0.9
+beta2=0.999
+num_epoch=3
+use_deepspeed=1
+seed=42
+
+# Safety related arguments
+trust_remote_code=0
+
+# Enable model parallelism for multiple gpus, modify this if you prefer
+# customized deepspeed zero-redundancy optimization settings
+num_gpu=$(python -c "import torch; print(torch.cuda.device_count())")
+ds_config_file=configs/ds_config_zero0_no_offload.json
+if [[ ${num_gpu} -ge 2 ]]; then
+  ds_config_file=configs/ds_config_zero2_no_offload.json
+fi
+
+while [[ $# -ge 1 ]]; do
+  key="$1"
+  case ${key} in
+    -m|--model_name_or_path)
+      model_name_or_path="$2"
+      shift
+      ;;
+    -d|--dataset_path)
+      dataset_path="$2"
+      shift
+      ;;
+    -o|--output_model_path)
+      output_dir="$2"
+      shift
+      ;;
+    --lisa_activated_layers)
+      lisa_activated_layers="$2"
+      shift
+      ;;
+    --lisa_interval_steps)
+      lisa_interval_steps="$2"
+      shift
+      ;;
+    --gradient_checkpointing)
+      gradient_checkpointing="$2"
+      shift
+      ;;
+    --deepspeed)
+      ds_config_file="$2"
+      shift
+      ;;
+    --use_flash_attention)
+      use_flash_attention="$2"
+      shift
+      ;;
+    --gradient_accumulation_steps)
+      gradient_accumulation_steps="$2"
+      shift
+      ;;
+    --block_size)
+      block_size="$2"
+      shift
+      ;;
+    --conversation_template)
+      conversation_template="$2"
+      shift
+      ;;
+    --per_device_train_batch_size|--batch_size)
+      per_device_train_batch_size="$2"
+      batch_size="$2"
+      shift
+      ;;
+    --trust_remote_code)
+      trust_remote_code="$2"
+      shift
+      ;;
+    --run_name)
+      run_name="$2"
+      shift
+      ;;
+    --optim)
+      optim="$2"
+      shift
+      ;;
+    --lr)
+      learning_rate=$2
+      shift
+      ;;
+    --beta1)
+      beta1=$2
+      shift
+      ;;
+    --beta2)
+      beta2=$2
+      shift
+      ;;
+    -n|--num_epoch)
+      num_epoch=$2
+      shift
+      ;;
+    --lr_schedule)
+      lr_schedule=$2
+      shift
+      ;;
+    --use_deepspeed)
+      use_deepspeed=$2
+      shift
+      ;;
+    --seed)
+      seed=$2
+      shift
+      ;;
+    *)
+      echo "error: unknown option \"${key}\"" 1>&2
+      exit 1
+  esac
+  shift
+done
+
+gpu_id=${CUDA_VISIBLE_DEVICES}
+deepspeed_args="--master_port=1103${gpu_id::1} --hostfile configs/hostfile --include localhost:${gpu_id}"
+
+optim_suffix_args=""
+if [ "${optim}" == "dummy" ]; then
+  optim_suffix_args="--use_customized_optim 1"
+  optim_suffix_args+=" --customized_optim ${optim}"
+  optim_suffix_args+=" --optim_dummy_beta1 ${beta1}"
+  optim_suffix_args+=" --optim_dummy_beta2 ${beta2}"
+else
+  optim_suffix_args="--optim ${optim}"
+  optim_suffix_args+=" --adam_beta1 ${beta1}"
+  optim_suffix_args+=" --adam_beta2 ${beta2}"
+fi
+
+# Finetune
+exp_id=alpaca_${optim}_lr-${learning_rate}_beta1-${beta1}_beta2-${beta2}_lr-sched-${lr_schedule}_model-$(basename ${model_name_or_path})_batch-size-${batch_size}x${gradient_accumulation_steps}_seed-${seed}
+echo "$(date): ${exp_id}..."
+
+tmp_dir=tmp
+mkdir -p ${tmp_dir}
+
+prefix=${exp_id}
+if [ -f ${tmp_dir}/${prefix}.mark ]; then
+  exit 0
+fi
+
+trap "rm -f ${tmp_dir}/${prefix}.mark" SIGINT SIGTERM SIGKILL
+touch ${tmp_dir}/${prefix}.mark
+
+project_dir=$(cd "$(dirname $0)"/..; pwd)
+log_dir=${project_dir}/log/${exp_id}
+output_dir=output_models/${exp_id}
+mkdir -p ${output_dir} ${log_dir}
+
+exe="deepspeed ${deepspeed_args}"
+if [[ ${use_deepspeed} -eq 0 ]]; then
+  exe=python
+fi
+${exe} examples/finetune.py \
+    --model_name_or_path ${model_name_or_path} \
+    --trust_remote_code ${trust_remote_code} \
+    --dataset_path ${dataset_path} \
+    --output_dir ${output_dir} --overwrite_output_dir \
+    --conversation_template ${conversation_template} \
+    --num_train_epochs ${num_epoch} \
+    --learning_rate ${learning_rate} \
+    --lr_scheduler_type ${lr_schedule} \
+    --disable_group_texts 1 \
+    --block_size ${block_size} \
+    --per_device_train_batch_size ${per_device_train_batch_size} \
+    --bf16 \
+    --deepspeed ${ds_config_file} \
+    --torch_dtype bfloat16 \
+    --run_name ${exp_id} \
+    --validation_split_percentage 0 \
+    --logging_steps 1 \
+    --do_train \
+    --ddp_timeout 72000 \
+    --save_steps 5000 \
+    --dataloader_num_workers 1 \
+    --gradient_checkpointing ${gradient_checkpointing} \
+    --use_flash_attention ${use_flash_attention} \
+    --gradient_accumulation_steps ${gradient_accumulation_steps} \
+    --seed ${seed} \
+    ${optim_suffix_args} \
+    > >(tee ${log_dir}/train.log) \
+    2> >(tee ${log_dir}/train.err >&2)
+
+if [[ $? -ne 0 ]]; then
+  echo "$(date): failed"
+  rm -f ${tmp_dir}/${prefix}.mark
+fi
diff --git a/src/lmflow/args.py b/src/lmflow/args.py
index 56d8d43e3..c4e87491c 100644
--- a/src/lmflow/args.py
+++ b/src/lmflow/args.py
@@ -29,6 +29,10 @@
 logger = logging.getLogger(__name__)
 
 
+class OptimizerNames():
+    DUMMY = "dummy"
+
+
 @dataclass
 class ModelArguments:
     """
@@ -645,6 +649,36 @@ class FinetunerArguments(TrainingArguments):
             "help": "where the layer attribute stores, e.g. model.model.layers"
         }
     )
+    use_customized_optim: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use a customized optimizer."
+        }
+    )
+    customized_optim: str = field(
+        default="sign_sgd",
+        metadata={
+            "help": "Name of the customized optimizer."
+        }
+    )
+    customized_optim_args: str = field(
+        default=None,
+        metadata={
+            "help": "Optional comma-separated key=value arguments passed to the customized optimizer."
+        }
+    )
+    optim_dummy_beta1: float = field(
+        default=0.9,
+        metadata={
+            "help": "A placeholder beta1 for the dummy optimizer; it has no effect and exists only for the tutorial."
+        }
+    )
+    optim_dummy_beta2: float = field(
+        default=0.999,
+        metadata={
+            "help": "A placeholder beta2 for the dummy optimizer; it has no effect and exists only for the tutorial."
+        }
+    )
 
 
 @dataclass
diff --git a/src/lmflow/optim/__init__.py b/src/lmflow/optim/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/lmflow/optim/dummy.py b/src/lmflow/optim/dummy.py
new file mode 100644
index 000000000..bb922199d
--- /dev/null
+++ b/src/lmflow/optim/dummy.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""Dummy Optimizer.
+"""
+import math
+import warnings
+from typing import Callable, Iterable, Tuple
+
+import torch
+from torch import nn
+from torch.optim import Optimizer
+
+class Dummy(Optimizer):
+    """
+    A dummy optimizer that does nothing.
+
+    Parameters:
+        params (:obj:`Iterable[nn.parameter.Parameter]`):
+            Iterable of parameters to optimize or dictionaries defining parameter groups.
+        lr (:obj:`float`, `optional`, defaults to 0):
+            The learning rate to use.
+    """
+
+    def __init__(
+        self,
+        params: Iterable[nn.parameter.Parameter],
+        lr: float = 0.,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        weight_decay: float = 0.0,
+    ):
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
+        defaults = {"lr": lr, "betas": betas, "weight_decay": weight_decay}
+        super().__init__(params, defaults)
+
+
+    @torch.no_grad()
+    def step(self, closure: Callable=None):
+        """
+        Performs a single optimization step.
+
+        Arguments:
+            closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError("Dummy does not support sparse gradients yet")
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    state["exp_avg"] = torch.zeros_like(p)
+                    state["exp_avg2"] = torch.zeros_like(p)
+
+                # v := exp_avg
+                # m := exp_avg2
+                v, m = state["exp_avg"], state["exp_avg2"]
+                beta1, beta2 = group["betas"]
+                step_size = group["lr"]
+
+                state["step"] += 1
+
+                p.add_(m, alpha=-0.0)
+                if group["weight_decay"] > 0.0:
+                    p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
+        return loss
diff --git a/src/lmflow/optim/optimizers.py b/src/lmflow/optim/optimizers.py
new file mode 100644
index 000000000..3917436a9
--- /dev/null
+++ b/src/lmflow/optim/optimizers.py
@@ -0,0 +1,5 @@
+#!/usr/bin/env python
+# coding=utf-8
+"""All optimizers.
+""" +from lmflow.optim.dummy import Dummy diff --git a/src/lmflow/pipeline/finetuner.py b/src/lmflow/pipeline/finetuner.py index 192dbae4d..c913ac1c8 100644 --- a/src/lmflow/pipeline/finetuner.py +++ b/src/lmflow/pipeline/finetuner.py @@ -7,6 +7,7 @@ import logging import os import sys +from typing import Any, Iterable, Optional, Tuple import datasets import transformers @@ -18,14 +19,21 @@ set_seed, ) from copy import deepcopy -from transformers.utils import send_example_telemetry +from transformers import PreTrainedModel, TrainingArguments from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_callback import ( TrainerCallback, TrainerControl, TrainerState, ) +from transformers.utils import ( + is_sagemaker_mp_enabled, + send_example_telemetry, +) import numpy as np + +import lmflow.optim.optimizers as optim +from lmflow.args import OptimizerNames from lmflow.datasets.dataset import Dataset from lmflow.pipeline.base_tuner import BaseTuner from lmflow.pipeline.utils.peft_trainer import PeftTrainer, PeftSavingCallback @@ -203,6 +211,84 @@ def group_texts(examples): return lm_datasets + def create_customized_optimizer(self, base_trainer_class, model_args): + class CustomizedOptimTrainer(base_trainer_class): + + @staticmethod + def get_optimizer_cls_and_kwargs( + args: TrainingArguments, + model: Optional[PreTrainedModel] = None, + ) -> Tuple[Any, Any]: + # parse args.optim_args + optim_args = {} + if args.customized_optim_args: + for mapping in args.customized_optim_args.replace(" ", "").split(","): + key, value = mapping.split("=") + optim_args[key] = value + + optimizer_kwargs = {"lr": args.learning_rate} + + if args.customized_optim == OptimizerNames.DUMMY: + optimizer_cls = optim.Dummy + dummy_kwargs = { + "betas": (args.optim_dummy_beta1, args.optim_dummy_beta2), + } + optimizer_kwargs.update(dummy_kwargs) + else: + raise ValueError( + f"Trainer cannot instantiate unsupported optimizer: " + f" {args.customized_optim}" + ) + return optimizer_cls, optimizer_kwargs + + + def create_optimizer(self): + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + + if self.optimizer is None: + decay_parameters = self.get_decay_parameter_names(opt_model) + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() + if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() + if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + + optimizer_cls, optimizer_kwargs = CustomizedOptimTrainer.get_optimizer_cls_and_kwargs(self.args, opt_model) + + # Overwrite `params` in case it's created by + # `get_optimizer_cls_and_kwargs` e.g. for GaLore optimizer. + if "params" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop( + "params" + ) + + # For layer-wise dummy optimizers we overwrite + # optimizer_grouped_parameters with `optimizer_dict` to + # avoid arguments conflicts. 
+ if "optimizer_dict" in optimizer_kwargs: + optimizer_grouped_parameters = optimizer_kwargs.pop( + "optimizer_dict" + ) + + self.optimizer = optimizer_cls( + optimizer_grouped_parameters, + **optimizer_kwargs + ) + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer(self.optimizer) + + return CustomizedOptimTrainer def tune(self, model, @@ -297,6 +383,12 @@ def compute_metrics(eval_preds): if data_collator is None: data_collator = default_data_collator + if training_args.use_customized_optim: + BaseTrainer = FinetuningTrainer + FinetuningTrainer = self.create_customized_optimizer( + BaseTrainer, model_args + ) + if training_args.use_lisa: class DynamicLayerActivationCallback(TrainerCallback): def __init__(self, n_layers, interval_steps, model):