Skip to content

Commit

Permalink
Liger Kernel integration (#1861)
Browse files Browse the repository at this point in the history
* add initial plugin support w Liger kernel patches

* integrate the input args classes

* fix liger plugin and dynamic configuration class

* drop untrainable samples and refactor config plugins integration

* fix incorrect inputs and circular imports

* fix bool comparison

* fix for dropping untraibable tokens

* fix licensing so liger integration is Apache 2.0

* add jamba support

* pylint ignore
  • Loading branch information
winglian authored Aug 23, 2024
1 parent e8ff5d5 commit 1f686c5
Show file tree
Hide file tree
Showing 12 changed files with 1,010 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ ignore_errors = True
[mypy-axolotl.models.mixtral.*]
ignore_errors = True

[mypy-axolotl.integrations.liger.models.*]
ignore_errors = True

[mypy-axolotl.models.phi.*]
ignore_errors = True

Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=2.3.0
liger-kernel

mamba-ssm==1.2.0.post1

Expand Down
6 changes: 6 additions & 0 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from transformers.utils.import_utils import _is_package_available

from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import (
Expand Down Expand Up @@ -365,6 +366,11 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):

cfg.axolotl_config_path = config

if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)

try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
Expand Down
Loading

0 comments on commit 1f686c5

Please sign in to comment.