-
-
Notifications
You must be signed in to change notification settings - Fork 795
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Prodigy, SophiaG optimizers #1350
Draft
Kimiko-AI
wants to merge
9
commits into
axolotl-ai-cloud:main
Choose a base branch
from
Kimiko-AI:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 7 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
f226253
Add new optims.
Kimiko-AI f426aad
Add new optims.
Kimiko-AI 650f820
Merge remote-tracking branch 'origin/main'
Kimiko-AI c1c1361
Fix val check
Kimiko-AI f7f9351
Merge remote-tracking branch 'origin/main'
Kimiko-AI a16079b
Set bias correction and safeguard_warmup to false
Kimiko-AI 93e95e0
Test
Kimiko-AI 398a94c
fix typo
Kimiko-AI 24459ee
chore: lint
winglian File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
from typing import Tuple, Optional, Callable | ||
|
||
import torch | ||
from torch.optim.optimizer import Optimizer | ||
|
||
try: | ||
import triton | ||
import triton.language as tl | ||
except ImportError as e: | ||
print('triton is not installed, please install by running `pip install triton -U --pre`') | ||
exit() | ||
|
||
|
||
def exists(val): | ||
return val is not None | ||
|
||
|
||
# update functions | ||
|
||
def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): | ||
# stepweight decay | ||
|
||
p.data.mul_(1 - lr * wd) | ||
|
||
# weight update | ||
|
||
update = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_() | ||
p.add_(update, alpha=-lr) | ||
|
||
# decay the momentum running average coefficient | ||
|
||
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) | ||
|
||
|
||
def clone_inplace_updated_params(nargs): | ||
nargs['p_ptr'] = nargs['p_ptr'].clone() | ||
nargs['exp_avg_ptr'] = nargs['exp_avg_ptr'].clone() | ||
|
||
|
||
# triton cuda kernel | ||
|
||
@triton.autotune(configs=[ | ||
triton.Config({'BLOCK_SIZE': 128}, num_warps=4, pre_hook=clone_inplace_updated_params), | ||
triton.Config({'BLOCK_SIZE': 1024}, num_warps=8, pre_hook=clone_inplace_updated_params), | ||
], key=['n_elements']) | ||
@triton.jit | ||
def update_fn_kernel( | ||
p_ptr, | ||
grad_ptr, | ||
exp_avg_ptr, | ||
lr, | ||
wd, | ||
beta1, | ||
beta2, | ||
n_elements, | ||
BLOCK_SIZE: tl.constexpr, | ||
): | ||
pid = tl.program_id(axis=0) | ||
|
||
block_start = pid * BLOCK_SIZE | ||
offsets = block_start + tl.arange(0, BLOCK_SIZE) | ||
|
||
mask = offsets < n_elements | ||
|
||
# offsetted pointers | ||
|
||
offset_p_ptr = p_ptr + offsets | ||
offset_grad_ptr = grad_ptr + offsets | ||
offset_exp_avg_ptr = exp_avg_ptr + offsets | ||
|
||
# load | ||
|
||
p = tl.load(offset_p_ptr, mask=mask) | ||
grad = tl.load(offset_grad_ptr, mask=mask) | ||
exp_avg = tl.load(offset_exp_avg_ptr, mask=mask) | ||
|
||
# stepweight decay | ||
|
||
p = p * (1 - lr * wd) | ||
|
||
# diff between momentum running average and grad | ||
|
||
diff = exp_avg - grad | ||
|
||
# weight update | ||
|
||
update = diff * beta1 + grad | ||
|
||
# torch.sign | ||
|
||
can_update = update != 0 | ||
update_sign = tl.where(update > 0, -lr, lr) | ||
|
||
p = p + update_sign * can_update | ||
|
||
# decay the momentum running average coefficient | ||
|
||
exp_avg = diff * beta2 + grad | ||
|
||
# store new params and momentum running average coefficient | ||
|
||
tl.store(offset_p_ptr, p, mask=mask) | ||
tl.store(offset_exp_avg_ptr, exp_avg, mask=mask) | ||
|
||
|
||
def update_fn( | ||
p: torch.Tensor, | ||
grad: torch.Tensor, | ||
exp_avg: torch.Tensor, | ||
lr: float, | ||
wd: float, | ||
beta1: float, | ||
beta2: float | ||
): | ||
assert all([t.is_cuda for t in (p, grad, exp_avg)]) | ||
n_elements = p.numel() | ||
|
||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) | ||
|
||
update_fn_kernel[grid]( | ||
p, | ||
grad, | ||
exp_avg, | ||
lr, | ||
wd, | ||
beta1, | ||
beta2, | ||
n_elements | ||
) | ||
|
||
|
||
class Lion(Optimizer): | ||
def __init__( | ||
self, | ||
params, | ||
lr: float = 1e-4, | ||
betas: Tuple[float, float] = (0.9, 0.99), | ||
weight_decay: float = 0.0, | ||
use_triton: bool = False | ||
): | ||
assert lr > 0. | ||
assert all([0. <= beta <= 1. for beta in betas]) | ||
|
||
defaults = dict( | ||
lr=lr, | ||
betas=betas, | ||
weight_decay=weight_decay | ||
) | ||
|
||
super().__init__(params, defaults) | ||
|
||
self.update_fn = update_fn | ||
|
||
if use_triton: | ||
self.update_fn = triton_update_fn | ||
|
||
@torch.no_grad() | ||
def step( | ||
self, | ||
closure: Optional[Callable] = None | ||
): | ||
|
||
loss = None | ||
if exists(closure): | ||
with torch.enable_grad(): | ||
loss = closure() | ||
|
||
for group in self.param_groups: | ||
for p in filter(lambda p: exists(p.grad), group['params']): | ||
|
||
grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \ | ||
self.state[p] | ||
|
||
# init state - exponential moving average of gradient values | ||
|
||
if len(state) == 0: | ||
state['exp_avg'] = torch.zeros_like(p) | ||
|
||
exp_avg = state['exp_avg'] | ||
|
||
self.update_fn( | ||
p, | ||
grad, | ||
exp_avg, | ||
lr, | ||
wd, | ||
beta1, | ||
beta2 | ||
) | ||
|
||
return loss |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thie
update_fn
function is redefined below on line 106.