Skip to content

Commit

Permalink
grokfast plugin fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
winglian committed Sep 17, 2024
1 parent 5244066 commit f7897f7
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
8 changes: 6 additions & 2 deletions src/axolotl/integrations/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -364,7 +364,9 @@ def add_callbacks_pre_trainer(self, cfg, model):
"""
callbacks = []
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
if plugin_callbacks: # if the plugin returned a list of callbacks
callbacks.extend(plugin_callbacks)
return callbacks

def add_callbacks_post_trainer(self, cfg, trainer):
Expand All @@ -380,5 +382,7 @@ def add_callbacks_post_trainer(self, cfg, trainer):
"""
callbacks = []
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
if plugin_callbacks:
callbacks.extend(plugin_callbacks)
return callbacks
18 changes: 12 additions & 6 deletions src/axolotl/integrations/grokfast/__init__.py
Original file line number | Diff line number | Diff line change
@@ -1,30 +1,35 @@
"""
Grokfast plugin for Axolotl
"""
from transformers.trainer_callback import CallbackHandler
import logging

from transformers.trainer_callback import TrainerCallback

from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .optimizer import gradfilter_ema

LOG = logging.getLogger("axolotl.integrations.grokfast")


class GrokfastCallbackHandler(CallbackHandler):
class GrokfastCallbackHandler(TrainerCallback):
"""
Transformer trainer callbacks for Grokfast
"""

def __init__(self, *args, alpha=0.98, lamb=2.0, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
super().__init__(*args_, **kwargs)
self.grads = None
self.alpha = alpha
self.lamb = lamb

def on_train_begin(self, args, state): # pylint: disable=unused-argument
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
self.grads = None

def on_pre_optimizer_step(
self, args, state, control, model
self, args_, state, control, **kwargs
): # pylint: disable=unused-argument
model = kwargs.pop("model")
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
return control

Expand All @@ -38,6 +43,7 @@ def get_input_args(self):
return "axolotl.integrations.grokfast.GrokfastArgs"

def add_callbacks_post_trainer(self, cfg, trainer):
LOG.info("Adding Grokfast callback to the trainer")
callback = GrokfastCallbackHandler(
alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
)
Expand Down
8 changes: 4 additions & 4 deletions src/axolotl/prompt_strategies/chat_template.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -227,10 +227,10 @@ def tokenize_prompt(self, prompt):
}

return tokenized_prompt
LOG.info(self.roles_to_train)
LOG.info(self.train_on_eos)
LOG.info(self.prompter.message_field_training)
LOG.info(self.prompter.message_field_training_detail)
LOG.debug(self.roles_to_train)
LOG.debug(self.train_on_eos)
LOG.debug(self.prompter.message_field_training)
LOG.debug(self.prompter.message_field_training_detail)

turns = prompt[self.messages]
input_ids = self.prompter.build_prompt(turns)
Expand Down
2 changes: 2 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -699,6 +699,8 @@ class Config:
is_mistral_derived_model: Optional[bool] = Field(default=None)
is_qwen_derived_model: Optional[bool] = Field(default=None)

plugins: Optional[List[str]] = Field(default=None)

@field_validator("datasets", mode="before")
@classmethod
def fix_sharegpt_datasets(cls, datasets):
Expand Down

0 comments on commit f7897f7

Please sign in to comment.