From f226253362051a2659c803c94cb43c6ae895d5ef Mon Sep 17 00:00:00 2001 From: Doan Minh Phuong <105905988+nRuaif@users.noreply.github.com> Date: Fri, 1 Mar 2024 21:48:50 +0700 Subject: [PATCH 1/6] Add new optims. --- src/axolotl/core/trainer_builder.py | 35 +++- src/axolotl/custom_optim/__init__.py | 0 src/axolotl/custom_optim/lion.py | 191 ++++++++++++++++++++ src/axolotl/custom_optim/prodigy.py | 251 +++++++++++++++++++++++++++ src/axolotl/custom_optim/sophia.py | 198 +++++++++++++++++++++ 5 files changed, 666 insertions(+), 9 deletions(-) create mode 100644 src/axolotl/custom_optim/__init__.py create mode 100644 src/axolotl/custom_optim/lion.py create mode 100644 src/axolotl/custom_optim/prodigy.py create mode 100644 src/axolotl/custom_optim/sophia.py diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 3502b229c..368b0ab02 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -972,29 +972,46 @@ def build(self, total_num_steps): trainer_kwargs = {} - if self.cfg.optimizer == "lion_pytorch": - from lion_pytorch import Lion - lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]} + if self.cfg.optimizer in ["lion_pytorch", "prodigy", "sophia"]: + + custom_optim_kwargs = {"lr": training_arguments_kwargs["learning_rate"]} if "weight_decay" in training_arguments_kwargs: - lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"] + custom_optim_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"] if ( "adam_beta1" in training_arguments_kwargs and "adam_beta2" in training_arguments_kwargs ): - lion_kwargs["betas"] = ( + custom_optim_kwargs["betas"] = ( training_arguments_kwargs["adam_beta1"], training_arguments_kwargs["adam_beta2"], ) - trainer_kwargs["optimizers"] = ( - Lion(params=self.model.parameters(), **lion_kwargs), - None, - ) + if self.cfg.optimizer == "lion_pytorch": + from axolotl.custom_optim.lion import Lion + trainer_kwargs["optimizers"] = ( + Lion(params=self.model.parameters(), **custom_optim_kwargs), + None, + ) + if self.cfg.optimizer == "sophia": + from axolotl.custom_optim.sophia import SophiaG + trainer_kwargs["optimizers"] = ( + SophiaG(params=self.model.parameters(), **custom_optim_kwargs), + None, + ) + if self.cfg.optimizer == "prodigy": + from axolotl.custom_optim.prodigy import Prodigy + trainer_kwargs["optimizers"] = ( + Prodigy(params=self.model.parameters(), **custom_optim_kwargs), + None, + ) + # Set default so transformers doesn't throw training_arguments_kwargs["optim"] = "adamw_hf" + + if self.cfg.optimizer == "adamw_anyprecision": if Path(self.cfg.torchdistx_path).exists(): sys.path.append(self.cfg.torchdistx_path) diff --git a/src/axolotl/custom_optim/__init__.py b/src/axolotl/custom_optim/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/custom_optim/lion.py b/src/axolotl/custom_optim/lion.py new file mode 100644 index 000000000..d27bc0710 --- /dev/null +++ b/src/axolotl/custom_optim/lion.py @@ -0,0 +1,191 @@ +from typing import Tuple, Optional, Callable + +import torch +from torch.optim.optimizer import Optimizer + +try: + import triton + import triton.language as tl +except ImportError as e: + print('triton is not installed, please install by running `pip install triton -U --pre`') + exit() + + +def exists(val): + return val is not None + + +# update functions + +def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): + # stepweight decay + + p.data.mul_(1 - lr * wd) + + # weight update + + update = 
exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_() + p.add_(update, alpha=-lr) + + # decay the momentum running average coefficient + + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + +def clone_inplace_updated_params(nargs): + nargs['p_ptr'] = nargs['p_ptr'].clone() + nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone() + + +# triton cuda kernel + +@triton.autotune(configs=[ + triton.Config({'BLOCK_SIZE': 128}, num_warps=4, pre_hook=clone_inplace_updated_params), + triton.Config({'BLOCK_SIZE': 1024}, num_warps=8, pre_hook=clone_inplace_updated_params), +], key=['n_elements']) +@triton.jit +def update_fn_kernel( + p_ptr, + grad_ptr, + exp_avg_ptr, + lr, + wd, + beta1, + beta2, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < n_elements + + # offsetted pointers + + offset_p_ptr = p_ptr + offsets + offset_grad_ptr = grad_ptr + offsets + offset_exp_avg_ptr = exp_avg_ptr + offsets + + # load + + p = tl.load(offset_p_ptr, mask=mask) + grad = tl.load(offset_grad_ptr, mask=mask) + exp_avg = tl.load(offset_exp_avg_ptr, mask=mask) + + # stepweight decay + + p = p * (1 - lr * wd) + + # diff between momentum running average and grad + + diff = exp_avg - grad + + # weight update + + update = diff * beta1 + grad + + # torch.sign + + can_update = update != 0 + update_sign = tl.where(update > 0, -lr, lr) + + p = p + update_sign * can_update + + # decay the momentum running average coefficient + + exp_avg = diff * beta2 + grad + + # store new params and momentum running average coefficient + + tl.store(offset_p_ptr, p, mask=mask) + tl.store(offset_exp_avg_ptr, exp_avg, mask=mask) + + +def update_fn( + p: torch.Tensor, + grad: torch.Tensor, + exp_avg: torch.Tensor, + lr: float, + wd: float, + beta1: float, + beta2: float +): + assert all([t.is_cuda for t in (p, grad, exp_avg)]) + n_elements = p.numel() + + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + + update_fn_kernel[grid]( + p, + grad, + exp_avg, + lr, + wd, + beta1, + beta2, + n_elements + ) + + +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + use_triton: bool = False + ): + assert lr > 0. + assert all([0. <= beta <= 1. 
for beta in betas]) + + defaults = dict( + lr=lr, + betas=betas, + weight_decay=weight_decay + ) + + super().__init__(params, defaults) + + self.update_fn = update_fn + + if use_triton: + self.update_fn = triton_update_fn + + @torch.no_grad() + def step( + self, + closure: Optional[Callable] = None + ): + + loss = None + if exists(closure): + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in filter(lambda p: exists(p.grad), group['params']): + + grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \ + self.state[p] + + # init state - exponential moving average of gradient values + + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + + self.update_fn( + p, + grad, + exp_avg, + lr, + wd, + beta1, + beta2 + ) + + return loss diff --git a/src/axolotl/custom_optim/prodigy.py b/src/axolotl/custom_optim/prodigy.py new file mode 100644 index 000000000..da9250eab --- /dev/null +++ b/src/axolotl/custom_optim/prodigy.py @@ -0,0 +1,251 @@ +import math +from typing import TYPE_CHECKING, Any, Callable, Optional + +import torch +import torch.optim +import logging +import os +import torch.distributed as dist + +if TYPE_CHECKING: + from torch.optim.optimizer import _params_t +else: + _params_t = Any + + +class Prodigy(torch.optim.Optimizer): + r""" + Implements Adam with Prodigy step-sizes. + Leave LR set to 1 unless you encounter instability. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate. + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + beta3 (float): + coefficients for computing the Prodidy stepsize using running averages. + If set to None, uses the value of square root of beta2 (default: None). + eps (float): + Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + decouple (boolean): + Use AdamW style decoupled weight decay + use_bias_correction (boolean): + Turn on Adam's bias correction. Off by default. + safeguard_warmup (boolean): + Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. + d0 (float): + Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. + d_coef (float): + Coefficient in the expression for the estimate of d (default 1.0). + Values such as 0.5 and 2.0 typically work as well. + Changing this parameter is the preferred way to tune the method. + growth_rate (float): + prevent the D estimate from growing faster than this multiplicative rate. + Default is inf, for unrestricted. Values like 1.02 give a kind of learning + rate warmup effect. + fsdp_in_use (bool): + If you're using sharded parameters, this should be set to True. The optimizer + will attempt to auto-detect this, but if you're using an implementation other + than PyTorch's builtin version, the auto-detection won't work. 
+ """ + + def __init__(self, params, lr=1.0, + betas=(0.9, 0.999), beta3=None, + eps=1e-8, weight_decay=0, decouple=True, + use_bias_correction=True, safeguard_warmup=True, + d0=1e-6, d_coef=1.0, growth_rate=float('inf'), + fsdp_in_use=False): + if not 0.0 < d0: + raise ValueError("Invalid d0 value: {}".format(d0)) + if not 0.0 < lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 < eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + if decouple and weight_decay > 0: + print(f"Using decoupled weight decay") + + defaults = dict(lr=lr, betas=betas, beta3=beta3, + eps=eps, weight_decay=weight_decay, + d=d0, d0=d0, d_max=d0, + d_numerator=0.0, d_coef=d_coef, + k=0, growth_rate=growth_rate, + use_bias_correction=use_bias_correction, + decouple=decouple, safeguard_warmup=safeguard_warmup, + fsdp_in_use=fsdp_in_use) + self.d0 = d0 + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return False + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + d_denom = 0.0 + + group = self.param_groups[0] + use_bias_correction = group['use_bias_correction'] + beta1, beta2 = group['betas'] + beta3 = group['beta3'] + if beta3 is None: + beta3 = math.sqrt(beta2) + k = group['k'] + + d = group['d'] + d_max = group['d_max'] + d_coef = group['d_coef'] + lr = max(group['lr'] for group in self.param_groups) + + if use_bias_correction: + bias_correction = ((1 - beta2 ** (k + 1)) ** 0.5) / (1 - beta1 ** (k + 1)) + else: + bias_correction = 1 + + dlr = d * lr * bias_correction + + growth_rate = group['growth_rate'] + decouple = group['decouple'] + fsdp_in_use = group['fsdp_in_use'] + + d_numerator = group['d_numerator'] + d_numerator *= beta3 + + for group in self.param_groups: + decay = group['weight_decay'] + k = group['k'] + eps = group['eps'] + group_lr = group['lr'] + d0 = group['d0'] + safeguard_warmup = group['safeguard_warmup'] + + if group_lr not in [lr, 0.0]: + raise RuntimeError( + f"Setting different lr values in different parameter groups is only supported for values of 0") + + for p in group['params']: + if p.grad is None: + continue + if hasattr(p, "_fsdp_flattened"): + fsdp_in_use = True + + grad = p.grad.data + + # Apply weight decay (coupled variant) + if decay != 0 and not decouple: + grad.add_(p.data, alpha=decay) + + state = self.state[p] + + # State initialization + if 'step' not in state: + state['step'] = 0 + state['s'] = torch.zeros_like(p.data).detach() + state['p0'] = p.detach().clone() + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data).detach() + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data).detach() + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + s = state['s'] + p0 = state['p0'] + + if group_lr > 0.0: + # we use d / d0 instead of just d to avoid getting values that are too small + d_numerator += (d / d0) * dlr * torch.dot(grad.flatten(), (p0.data - p.data).flatten()).item() + + # Adam EMA updates + 
exp_avg.mul_(beta1).add_(grad, alpha=d * (1 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1 - beta2)) + + if safeguard_warmup: + s.mul_(beta3).add_(grad, alpha=((d / d0) * d)) + else: + s.mul_(beta3).add_(grad, alpha=((d / d0) * dlr)) + d_denom += s.abs().sum().item() + + ###### + + d_hat = d + + # if we have not done any progres, return + # if we have any gradients available, will have d_denom > 0 (unless \|g\|=0) + if d_denom == 0: + return loss + + if lr > 0.0: + if fsdp_in_use: + dist_tensor = torch.zeros(2).cuda() + dist_tensor[0] = d_numerator + dist_tensor[1] = d_denom + dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) + global_d_numerator = dist_tensor[0] + global_d_denom = dist_tensor[1] + else: + global_d_numerator = d_numerator + global_d_denom = d_denom + + d_hat = d_coef * global_d_numerator / global_d_denom + if d == group['d0']: + d = max(d, d_hat) + d_max = max(d_max, d_hat) + d = min(d_max, d * growth_rate) + + for group in self.param_groups: + group['d_numerator'] = global_d_numerator + group['d_denom'] = global_d_denom + group['d'] = d + group['d_max'] = d_max + group['d_hat'] = d_hat + + decay = group['weight_decay'] + k = group['k'] + eps = group['eps'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + + state = self.state[p] + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + state['step'] += 1 + + denom = exp_avg_sq.sqrt().add_(d * eps) + + # Apply weight decay (decoupled variant) + if decay != 0 and decouple: + p.data.add_(p.data, alpha=-decay * dlr) + + ### Take step + p.data.addcdiv_(exp_avg, denom, value=-dlr) + + group['k'] = k + 1 + + return loss diff --git a/src/axolotl/custom_optim/sophia.py b/src/axolotl/custom_optim/sophia.py new file mode 100644 index 000000000..1dfcad29a --- /dev/null +++ b/src/axolotl/custom_optim/sophia.py @@ -0,0 +1,198 @@ +import math +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer +from typing import List, Optional + + +class SophiaG(Optimizer): + def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho=0.04, + weight_decay=1e-1, *, maximize: bool = False, + capturable: bool = False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= rho: + raise ValueError("Invalid rho parameter at index 1: {}".format(rho)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, rho=rho, + weight_decay=weight_decay, + maximize=maximize, capturable=capturable) + super(SophiaG, self).__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('maximize', False) + group.setdefault('capturable', False) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def update_hessian(self): + for group in self.param_groups: + beta1, beta2 = group['betas'] + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + + if len(state) == 0: + state['step'] = torch.zeros((1,), dtype=torch.float, 
device=p.device) \ + if self.defaults['capturable'] else torch.tensor(0.) + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + if 'hessian' not in state.keys(): + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2) + + @torch.no_grad() + def step(self, closure=None, bs=5120): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + state_steps = [] + hessian = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + + if p.grad.is_sparse: + raise RuntimeError('Hero does not support sparse gradients') + grads.append(p.grad) + state = self.state[p] + # State initialization + if len(state) == 0: + state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ + if self.defaults['capturable'] else torch.tensor(0.) + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + if 'hessian' not in state.keys(): + state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + state_steps.append(state['step']) + hessian.append(state['hessian']) + + if self.defaults['capturable']: + bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs + + sophiag(params_with_grad, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=group['rho'], + lr=group['lr'], + weight_decay=group['weight_decay'], + maximize=group['maximize'], + capturable=group['capturable']) + + return loss + + +def sophiag(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + capturable: bool = False, + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool): + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") + + func = _single_tensor_sophiag + + func(params, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=rho, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable) + + +def _single_tensor_sophiag(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool, + capturable: bool): + for i, param in enumerate(params): + grad = grads[i] if not maximize else -grads[i] + exp_avg = exp_avgs[i] + hess = hessian[i] + step_t = state_steps[i] + + if capturable: + assert param.is_cuda and step_t.is_cuda and bs.is_cuda + + if torch.is_complex(param): + grad = torch.view_as_real(grad) + exp_avg = torch.view_as_real(exp_avg) + hess = torch.view_as_real(hess) + param = torch.view_as_real(param) + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + if capturable: + step_size = lr + step_size_neg = 
step_size.neg() + + ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1) + param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) + else: + step_size_neg = - lr + + ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1) + param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) From a16079b51e4edc4a258a3a38f377a9a6fe821331 Mon Sep 17 00:00:00 2001 From: Doan Minh Phuong <105905988+nRuaif@users.noreply.github.com> Date: Fri, 1 Mar 2024 23:11:06 +0700 Subject: [PATCH 3/6] Set bias correction and safeguard_warmup to false --- src/axolotl/custom_optim/prodigy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/custom_optim/prodigy.py b/src/axolotl/custom_optim/prodigy.py index da9250eab..663f54e68 100644 --- a/src/axolotl/custom_optim/prodigy.py +++ b/src/axolotl/custom_optim/prodigy.py @@ -57,7 +57,7 @@ class Prodigy(torch.optim.Optimizer): def __init__(self, params, lr=1.0, betas=(0.9, 0.999), beta3=None, eps=1e-8, weight_decay=0, decouple=True, - use_bias_correction=True, safeguard_warmup=True, + use_bias_correction=False, safeguard_warmup=False, d0=1e-6, d_coef=1.0, growth_rate=float('inf'), fsdp_in_use=False): if not 0.0 < d0: From 93e95e099845b4765bc597c4a6711a7ca3492bb2 Mon Sep 17 00:00:00 2001 From: Doan Minh Phuong <105905988+nRuaif@users.noreply.github.com> Date: Fri, 1 Mar 2024 23:34:51 +0700 Subject: [PATCH 4/6] Test --- src/axolotl/core/trainer_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 368b0ab02..91d8e59a9 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1003,7 +1003,7 @@ def build(self, total_num_steps): if self.cfg.optimizer == "prodigy": from axolotl.custom_optim.prodigy import Prodigy trainer_kwargs["optimizers"] = ( - Prodigy(params=self.model.parameters(), **custom_optim_kwargs), + Prodigy(params=filter(lambda p: p.requires_grad, self.model.parameters()), **custom_optim_kwargs), None, ) From 398a94cf665538783a379bbdd3b0a6cd1a017086 Mon Sep 17 00:00:00 2001 From: Doan Minh Phuong <105905988+nRuaif@users.noreply.github.com> Date: Sat, 2 Mar 2024 10:53:32 +0700 Subject: [PATCH 5/6] fix typo --- src/axolotl/custom_optim/lion.py | 2 +- src/axolotl/custom_optim/sophia.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/custom_optim/lion.py b/src/axolotl/custom_optim/lion.py index d27bc0710..a15460552 100644 --- a/src/axolotl/custom_optim/lion.py +++ b/src/axolotl/custom_optim/lion.py @@ -103,7 +103,7 @@ def update_fn_kernel( tl.store(offset_exp_avg_ptr, exp_avg, mask=mask) -def update_fn( +def triton_update_fn( p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor, diff --git a/src/axolotl/custom_optim/sophia.py b/src/axolotl/custom_optim/sophia.py index 1dfcad29a..d8ea59649 100644 --- a/src/axolotl/custom_optim/sophia.py +++ b/src/axolotl/custom_optim/sophia.py @@ -168,7 +168,7 @@ def _single_tensor_sophiag(params: List[Tensor], step_t = state_steps[i] if capturable: - assert param.is_cuda and step_t.is_cuda and bs.is_cuda + assert param.is_cuda and step_t.is_cuda if torch.is_complex(param): grad = torch.view_as_real(grad) From 24459eeae6be8ca3307d4568a2b6c84d0af8cb7e Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 5 Mar 2024 20:08:47 -0500 Subject: [PATCH 6/6] chore: lint --- src/axolotl/core/trainer_builder.py | 18 +- src/axolotl/custom_optim/lion.py | 95 ++++--- src/axolotl/custom_optim/prodigy.py | 155 +++++++----- 
src/axolotl/custom_optim/sophia.py | 237 +++++++++++------- .../config/models/input/v0_4_1/__init__.py | 4 +- 5 files changed, 288 insertions(+), 221 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 91d8e59a9..62e4a788a 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -972,12 +972,12 @@ def build(self, total_num_steps): trainer_kwargs = {} - if self.cfg.optimizer in ["lion_pytorch", "prodigy", "sophia"]: - custom_optim_kwargs = {"lr": training_arguments_kwargs["learning_rate"]} if "weight_decay" in training_arguments_kwargs: - custom_optim_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"] + custom_optim_kwargs["weight_decay"] = training_arguments_kwargs[ + "weight_decay" + ] if ( "adam_beta1" in training_arguments_kwargs @@ -990,28 +990,34 @@ def build(self, total_num_steps): if self.cfg.optimizer == "lion_pytorch": from axolotl.custom_optim.lion import Lion + trainer_kwargs["optimizers"] = ( Lion(params=self.model.parameters(), **custom_optim_kwargs), None, ) if self.cfg.optimizer == "sophia": from axolotl.custom_optim.sophia import SophiaG + trainer_kwargs["optimizers"] = ( SophiaG(params=self.model.parameters(), **custom_optim_kwargs), None, ) if self.cfg.optimizer == "prodigy": from axolotl.custom_optim.prodigy import Prodigy + trainer_kwargs["optimizers"] = ( - Prodigy(params=filter(lambda p: p.requires_grad, self.model.parameters()), **custom_optim_kwargs), + Prodigy( + params=filter( + lambda p: p.requires_grad, self.model.parameters() + ), + **custom_optim_kwargs, + ), None, ) # Set default so transformers doesn't throw training_arguments_kwargs["optim"] = "adamw_hf" - - if self.cfg.optimizer == "adamw_anyprecision": if Path(self.cfg.torchdistx_path).exists(): sys.path.append(self.cfg.torchdistx_path) diff --git a/src/axolotl/custom_optim/lion.py b/src/axolotl/custom_optim/lion.py index a15460552..c3aaa786d 100644 --- a/src/axolotl/custom_optim/lion.py +++ b/src/axolotl/custom_optim/lion.py @@ -1,4 +1,4 @@ -from typing import Tuple, Optional, Callable +from typing import Callable, Optional, Tuple import torch from torch.optim.optimizer import Optimizer @@ -6,8 +6,10 @@ try: import triton import triton.language as tl -except ImportError as e: - print('triton is not installed, please install by running `pip install triton -U --pre`') +except ImportError: + print( + "triton is not installed, please install by running `pip install triton -U --pre`" + ) exit() @@ -17,6 +19,7 @@ def exists(val): # update functions + def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): # stepweight decay @@ -33,16 +36,24 @@ def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): def clone_inplace_updated_params(nargs): - nargs['p_ptr'] = nargs['p_ptr'].clone() - nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone() + nargs["p_ptr"] = nargs["p_ptr"].clone() + nargs["exp_avg_ptr"] = nargs["exp_avg_ptr"].clone() # triton cuda kernel -@triton.autotune(configs=[ - triton.Config({'BLOCK_SIZE': 128}, num_warps=4, pre_hook=clone_inplace_updated_params), - triton.Config({'BLOCK_SIZE': 1024}, num_warps=8, pre_hook=clone_inplace_updated_params), -], key=['n_elements']) + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE": 128}, num_warps=4, pre_hook=clone_inplace_updated_params + ), + triton.Config( + {"BLOCK_SIZE": 1024}, num_warps=8, pre_hook=clone_inplace_updated_params + ), + ], + key=["n_elements"], +) @triton.jit def update_fn_kernel( p_ptr, @@ -110,23 +121,15 @@ def 
triton_update_fn( lr: float, wd: float, beta1: float, - beta2: float + beta2: float, ): assert all([t.is_cuda for t in (p, grad, exp_avg)]) n_elements = p.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - - update_fn_kernel[grid]( - p, - grad, - exp_avg, - lr, - wd, - beta1, - beta2, - n_elements - ) + def grid(meta): + return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + + update_fn_kernel[grid](p, grad, exp_avg, lr, wd, beta1, beta2, n_elements) class Lion(Optimizer): @@ -136,16 +139,12 @@ def __init__( lr: float = 1e-4, betas: Tuple[float, float] = (0.9, 0.99), weight_decay: float = 0.0, - use_triton: bool = False + use_triton: bool = False, ): - assert lr > 0. - assert all([0. <= beta <= 1. for beta in betas]) + assert lr > 0.0 + assert all([0.0 <= beta <= 1.0 for beta in betas]) - defaults = dict( - lr=lr, - betas=betas, - weight_decay=weight_decay - ) + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) super().__init__(params, defaults) @@ -155,37 +154,29 @@ def __init__( self.update_fn = triton_update_fn @torch.no_grad() - def step( - self, - closure: Optional[Callable] = None - ): - + def step(self, closure: Optional[Callable] = None): loss = None if exists(closure): with torch.enable_grad(): loss = closure() for group in self.param_groups: - for p in filter(lambda p: exists(p.grad), group['params']): - - grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \ - self.state[p] + for p in filter(lambda p: exists(p.grad), group["params"]): + grad, lr, wd, beta1, beta2, state = ( + p.grad, + group["lr"], + group["weight_decay"], + *group["betas"], + self.state[p], + ) # init state - exponential moving average of gradient values if len(state) == 0: - state['exp_avg'] = torch.zeros_like(p) - - exp_avg = state['exp_avg'] - - self.update_fn( - p, - grad, - exp_avg, - lr, - wd, - beta1, - beta2 - ) + state["exp_avg"] = torch.zeros_like(p) + + exp_avg = state["exp_avg"] + + self.update_fn(p, grad, exp_avg, lr, wd, beta1, beta2) return loss diff --git a/src/axolotl/custom_optim/prodigy.py b/src/axolotl/custom_optim/prodigy.py index 663f54e68..01dda02b1 100644 --- a/src/axolotl/custom_optim/prodigy.py +++ b/src/axolotl/custom_optim/prodigy.py @@ -1,11 +1,9 @@ import math -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any import torch -import torch.optim -import logging -import os import torch.distributed as dist +import torch.optim if TYPE_CHECKING: from torch.optim.optimizer import _params_t @@ -54,12 +52,22 @@ class Prodigy(torch.optim.Optimizer): than PyTorch's builtin version, the auto-detection won't work. 
""" - def __init__(self, params, lr=1.0, - betas=(0.9, 0.999), beta3=None, - eps=1e-8, weight_decay=0, decouple=True, - use_bias_correction=False, safeguard_warmup=False, - d0=1e-6, d_coef=1.0, growth_rate=float('inf'), - fsdp_in_use=False): + def __init__( + self, + params, + lr=1.0, + betas=(0.9, 0.999), + beta3=None, + eps=1e-8, + weight_decay=0, + decouple=True, + use_bias_correction=False, + safeguard_warmup=False, + d0=1e-6, + d_coef=1.0, + growth_rate=float("inf"), + fsdp_in_use=False, + ): if not 0.0 < d0: raise ValueError("Invalid d0 value: {}".format(d0)) if not 0.0 < lr: @@ -72,16 +80,26 @@ def __init__(self, params, lr=1.0, raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if decouple and weight_decay > 0: - print(f"Using decoupled weight decay") - - defaults = dict(lr=lr, betas=betas, beta3=beta3, - eps=eps, weight_decay=weight_decay, - d=d0, d0=d0, d_max=d0, - d_numerator=0.0, d_coef=d_coef, - k=0, growth_rate=growth_rate, - use_bias_correction=use_bias_correction, - decouple=decouple, safeguard_warmup=safeguard_warmup, - fsdp_in_use=fsdp_in_use) + print("Using decoupled weight decay") + + defaults = dict( + lr=lr, + betas=betas, + beta3=beta3, + eps=eps, + weight_decay=weight_decay, + d=d0, + d0=d0, + d_max=d0, + d_numerator=0.0, + d_coef=d_coef, + k=0, + growth_rate=growth_rate, + use_bias_correction=use_bias_correction, + decouple=decouple, + safeguard_warmup=safeguard_warmup, + fsdp_in_use=fsdp_in_use, + ) self.d0 = d0 super().__init__(params, defaults) @@ -107,17 +125,17 @@ def step(self, closure=None): d_denom = 0.0 group = self.param_groups[0] - use_bias_correction = group['use_bias_correction'] - beta1, beta2 = group['betas'] - beta3 = group['beta3'] + use_bias_correction = group["use_bias_correction"] + beta1, beta2 = group["betas"] + beta3 = group["beta3"] if beta3 is None: beta3 = math.sqrt(beta2) - k = group['k'] + k = group["k"] - d = group['d'] - d_max = group['d_max'] - d_coef = group['d_coef'] - lr = max(group['lr'] for group in self.param_groups) + d = group["d"] + d_max = group["d_max"] + d_coef = group["d_coef"] + lr = max(group["lr"] for group in self.param_groups) if use_bias_correction: bias_correction = ((1 - beta2 ** (k + 1)) ** 0.5) / (1 - beta1 ** (k + 1)) @@ -126,26 +144,27 @@ def step(self, closure=None): dlr = d * lr * bias_correction - growth_rate = group['growth_rate'] - decouple = group['decouple'] - fsdp_in_use = group['fsdp_in_use'] + growth_rate = group["growth_rate"] + decouple = group["decouple"] + fsdp_in_use = group["fsdp_in_use"] - d_numerator = group['d_numerator'] + d_numerator = group["d_numerator"] d_numerator *= beta3 for group in self.param_groups: - decay = group['weight_decay'] - k = group['k'] - eps = group['eps'] - group_lr = group['lr'] - d0 = group['d0'] - safeguard_warmup = group['safeguard_warmup'] + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + group_lr = group["lr"] + d0 = group["d0"] + safeguard_warmup = group["safeguard_warmup"] if group_lr not in [lr, 0.0]: raise RuntimeError( - f"Setting different lr values in different parameter groups is only supported for values of 0") + "Setting different lr values in different parameter groups is only supported for values of 0" + ) - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if hasattr(p, "_fsdp_flattened"): @@ -160,27 +179,33 @@ def step(self, closure=None): state = self.state[p] # State initialization - if 'step' not in state: - state['step'] = 0 - state['s'] = 
torch.zeros_like(p.data).detach() - state['p0'] = p.detach().clone() + if "step" not in state: + state["step"] = 0 + state["s"] = torch.zeros_like(p.data).detach() + state["p0"] = p.detach().clone() # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data).detach() + state["exp_avg"] = torch.zeros_like(p.data).detach() # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data).detach() + state["exp_avg_sq"] = torch.zeros_like(p.data).detach() - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - s = state['s'] - p0 = state['p0'] + s = state["s"] + p0 = state["p0"] if group_lr > 0.0: # we use d / d0 instead of just d to avoid getting values that are too small - d_numerator += (d / d0) * dlr * torch.dot(grad.flatten(), (p0.data - p.data).flatten()).item() + d_numerator += ( + (d / d0) + * dlr + * torch.dot(grad.flatten(), (p0.data - p.data).flatten()).item() + ) # Adam EMA updates exp_avg.mul_(beta1).add_(grad, alpha=d * (1 - beta1)) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1 - beta2)) + exp_avg_sq.mul_(beta2).addcmul_( + grad, grad, value=d * d * (1 - beta2) + ) if safeguard_warmup: s.mul_(beta3).add_(grad, alpha=((d / d0) * d)) @@ -210,32 +235,32 @@ def step(self, closure=None): global_d_denom = d_denom d_hat = d_coef * global_d_numerator / global_d_denom - if d == group['d0']: + if d == group["d0"]: d = max(d, d_hat) d_max = max(d_max, d_hat) d = min(d_max, d * growth_rate) for group in self.param_groups: - group['d_numerator'] = global_d_numerator - group['d_denom'] = global_d_denom - group['d'] = d - group['d_max'] = d_max - group['d_hat'] = d_hat + group["d_numerator"] = global_d_numerator + group["d_denom"] = global_d_denom + group["d"] = d + group["d_max"] = d_max + group["d_hat"] = d_hat - decay = group['weight_decay'] - k = group['k'] - eps = group['eps'] + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data state = self.state[p] - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - state['step'] += 1 + state["step"] += 1 denom = exp_avg_sq.sqrt().add_(d * eps) @@ -243,9 +268,9 @@ def step(self, closure=None): if decay != 0 and decouple: p.data.add_(p.data, alpha=-decay * dlr) - ### Take step + # Take step p.data.addcdiv_(exp_avg, denom, value=-dlr) - group['k'] = k + 1 + group["k"] = k + 1 return loss diff --git a/src/axolotl/custom_optim/sophia.py b/src/axolotl/custom_optim/sophia.py index d8ea59649..85838ca97 100644 --- a/src/axolotl/custom_optim/sophia.py +++ b/src/axolotl/custom_optim/sophia.py @@ -1,14 +1,22 @@ -import math +from typing import List + import torch from torch import Tensor from torch.optim.optimizer import Optimizer -from typing import List, Optional class SophiaG(Optimizer): - def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho=0.04, - weight_decay=1e-1, *, maximize: bool = False, - capturable: bool = False): + def __init__( + self, + params, + lr=1e-4, + betas=(0.965, 0.99), + rho=0.04, + weight_decay=1e-1, + *, + maximize: bool = False, + capturable: bool = False + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= betas[0] < 1.0: @@ -19,41 +27,57 @@ def __init__(self, params, lr=1e-4, betas=(0.965, 0.99), rho=0.04, raise ValueError("Invalid rho 
parameter at index 1: {}".format(rho)) if not 0.0 <= weight_decay: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, betas=betas, rho=rho, - weight_decay=weight_decay, - maximize=maximize, capturable=capturable) + defaults = dict( + lr=lr, + betas=betas, + rho=rho, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable, + ) super(SophiaG, self).__init__(params, defaults) def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: - group.setdefault('maximize', False) - group.setdefault('capturable', False) + group.setdefault("maximize", False) + group.setdefault("capturable", False) state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step']) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]["step"] + ) if not step_is_tensor: for s in state_values: - s['step'] = torch.tensor(float(s['step'])) + s["step"] = torch.tensor(float(s["step"])) @torch.no_grad() def update_hessian(self): for group in self.param_groups: - beta1, beta2 = group['betas'] - for p in group['params']: + beta1, beta2 = group["betas"] + for p in group["params"]: if p.grad is None: continue state = self.state[p] if len(state) == 0: - state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ - if self.defaults['capturable'] else torch.tensor(0.) - state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) - state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) - - if 'hessian' not in state.keys(): - state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) - - state['hessian'].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2) + state["step"] = ( + torch.zeros((1,), dtype=torch.float, device=p.device) + if self.defaults["capturable"] + else torch.tensor(0.0) + ) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + if "hessian" not in state.keys(): + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + state["hessian"].mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2) @torch.no_grad() def step(self, closure=None, bs=5120): @@ -68,99 +92,118 @@ def step(self, closure=None, bs=5120): exp_avgs = [] state_steps = [] hessian = [] - beta1, beta2 = group['betas'] + beta1, beta2 = group["betas"] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue params_with_grad.append(p) if p.grad.is_sparse: - raise RuntimeError('Hero does not support sparse gradients') + raise RuntimeError("Hero does not support sparse gradients") grads.append(p.grad) state = self.state[p] # State initialization if len(state) == 0: - state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \ - if self.defaults['capturable'] else torch.tensor(0.) 
- state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) - state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) - - if 'hessian' not in state.keys(): - state['hessian'] = torch.zeros_like(p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - state_steps.append(state['step']) - hessian.append(state['hessian']) - - if self.defaults['capturable']: + state["step"] = ( + torch.zeros((1,), dtype=torch.float, device=p.device) + if self.defaults["capturable"] + else torch.tensor(0.0) + ) + state["exp_avg"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + if "hessian" not in state.keys(): + state["hessian"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + + exp_avgs.append(state["exp_avg"]) + state_steps.append(state["step"]) + hessian.append(state["hessian"]) + + if self.defaults["capturable"]: bs = torch.ones((1,), dtype=torch.float, device=p.device) * bs - sophiag(params_with_grad, - grads, - exp_avgs, - hessian, - state_steps, - bs=bs, - beta1=beta1, - beta2=beta2, - rho=group['rho'], - lr=group['lr'], - weight_decay=group['weight_decay'], - maximize=group['maximize'], - capturable=group['capturable']) + sophiag( + params_with_grad, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=group["rho"], + lr=group["lr"], + weight_decay=group["weight_decay"], + maximize=group["maximize"], + capturable=group["capturable"], + ) return loss -def sophiag(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - hessian: List[Tensor], - state_steps: List[Tensor], - capturable: bool = False, - *, - bs: int, - beta1: float, - beta2: float, - rho: float, - lr: float, - weight_decay: float, - maximize: bool): +def sophiag( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + capturable: bool = False, + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool +): if not all(isinstance(t, torch.Tensor) for t in state_steps): - raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") + raise RuntimeError( + "API has changed, `state_steps` argument must contain a list of singleton tensors" + ) func = _single_tensor_sophiag - func(params, - grads, - exp_avgs, - hessian, - state_steps, - bs=bs, - beta1=beta1, - beta2=beta2, - rho=rho, - lr=lr, - weight_decay=weight_decay, - maximize=maximize, - capturable=capturable) - - -def _single_tensor_sophiag(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - hessian: List[Tensor], - state_steps: List[Tensor], - *, - bs: int, - beta1: float, - beta2: float, - rho: float, - lr: float, - weight_decay: float, - maximize: bool, - capturable: bool): + func( + params, + grads, + exp_avgs, + hessian, + state_steps, + bs=bs, + beta1=beta1, + beta2=beta2, + rho=rho, + lr=lr, + weight_decay=weight_decay, + maximize=maximize, + capturable=capturable, + ) + + +def _single_tensor_sophiag( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + hessian: List[Tensor], + state_steps: List[Tensor], + *, + bs: int, + beta1: float, + beta2: float, + rho: float, + lr: float, + weight_decay: float, + maximize: bool, + capturable: bool +): for i, param in enumerate(params): grad = grads[i] if not maximize else -grads[i] exp_avg = exp_avgs[i] 
@@ -192,7 +235,7 @@ def _single_tensor_sophiag(params: List[Tensor], ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1) param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) else: - step_size_neg = - lr + step_size_neg = -lr ratio = (exp_avg.abs() / (rho * bs * hess + 1e-15)).clamp(None, 1) param.addcmul_(exp_avg.sign(), ratio, value=step_size_neg) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 1494ace2c..efd13bd19 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -282,7 +282,9 @@ class HyperparametersConfig(BaseModel): learning_rate: Union[str, float] weight_decay: Optional[float] = None - optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch", "prodigy", "sophia"]]] = None + optimizer: Optional[ + Union[OptimizerNames, Literal["lion_pytorch", "prodigy", "sophia"]] + ] = None torchdistx_path: Optional[str] = None lr_scheduler: Optional[SchedulerType] = None lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
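Note (illustrative, not part of the patch): the `optimizer: Literal["lion_pytorch", "prodigy", "sophia"]` values accepted by HyperparametersConfig map onto the classes under src/axolotl/custom_optim/. The sketch below shows how SophiaG's two-phase API from src/axolotl/custom_optim/sophia.py (a regular step() plus a periodic update_hessian() refresh) is typically driven; the toy model, data, and refresh interval `k` are assumptions for illustration only and do not come from this diff.

    import torch
    from axolotl.custom_optim.sophia import SophiaG

    # Hypothetical toy model and data, only to exercise the optimizer API.
    model = torch.nn.Linear(16, 4)
    data = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(100)]

    # Defaults mirror the SophiaG signature in this patch.
    optimizer = SophiaG(
        model.parameters(), lr=1e-4, betas=(0.965, 0.99), rho=0.04, weight_decay=1e-1
    )

    k = 10  # assumed interval (in steps) between Hessian EMA refreshes
    for step_idx, (inputs, targets) in enumerate(data):
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()  # uses the default bs=5120 in the ratio denominator
        optimizer.zero_grad(set_to_none=True)

        if step_idx % k == k - 1:
            # Periodically refresh the diagonal Hessian EMA: compute gradients
            # on a batch, then fold grad*grad into state["hessian"] via
            # update_hessian(), as defined in sophia.py.
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)
            loss.backward()
            optimizer.update_hessian()
            optimizer.zero_grad(set_to_none=True)

Prodigy and Lion, by contrast, expose only the standard single-call Optimizer.step() interface shown in their respective files, so they need no extra bookkeeping in the training loop.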