5 changes: 5 additions & 0 deletions src/llamafactory/v1/accelerator/interface.py
@@ -34,10 +34,14 @@
from torch.distributed import barrier, destroy_process_group, init_process_group
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

from ..utils import logging
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
from . import helper


logger = logging.get_logger(__name__)


class Dim(str, Enum):
"""Dimension names."""

@@ -157,6 +161,7 @@ def __init__(self, config: DistributedConfig | None = None) -> None:
self.data_device_mesh = None

self._initialized = True
logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")

def __str__(self) -> str:
return (
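The only functional change here is the new rank-aware log line. `info_rank0` in LLaMA-Factory's logging utilities emits a record only on the global rank-0 process, so multi-GPU runs do not print the message once per worker. A minimal sketch of that idea, assuming the rank comes from the usual RANK environment variable (an illustration, not the project's actual implementation):

import logging
import os


class Rank0Logger(logging.Logger):
    def info_rank0(self, msg: str) -> None:
        # Only the main process emits; all other ranks stay silent.
        if int(os.environ.get("RANK", "0")) == 0:
            self.info(msg)


logging.basicConfig(level=logging.INFO)
logging.setLoggerClass(Rank0Logger)
logger = logging.getLogger(__name__)
logger.info_rank0("DistributedInterface initialized.")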
32 changes: 32 additions & 0 deletions src/llamafactory/v1/config/__init__.py
@@ -0,0 +1,32 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .arg_parser import InputArgument, get_args
from .arg_utils import ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments


__all__ = [
"DataArguments",
"InputArgument",
"ModelArguments",
"ModelClass",
"SampleArguments",
"SampleBackend",
"TrainingArguments",
"get_args",
]
9 changes: 5 additions & 4 deletions src/llamafactory/v1/config/model_args.py
@@ -27,14 +27,14 @@ class ModelArguments:
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
use_fast_processor: bool = field(
default=True,
metadata={"help": "Use fast processor from Hugging Face."},
)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
)
init_config: PluginConfig | None = field(
default=None,
metadata={"help": "Initialization configuration for the model."},
)
peft_config: PluginConfig | None = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
@@ -49,6 +49,7 @@ class ModelArguments:
)

def __post_init__(self) -> None:
self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)
self.quant_config = get_plugin_config(self.quant_config)
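The new `init_config` field rides the same `PluginConfig` rails as `peft_config`, `kernel_config`, and `quant_config`: whatever the user supplies is normalized by `get_plugin_config` in `__post_init__`. A hedged usage sketch, assuming a plain dict with a `name` key is an accepted input, as `cli_sampler.py` below suggests (the model id is illustrative, not part of this PR):

from llamafactory.v1.config import ModelArguments

model_args = ModelArguments(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative model id
    init_config={"name": "init_on_meta"},  # defer weight allocation to the meta device
)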
2 changes: 1 addition & 1 deletion src/llamafactory/v1/config/training_args.py
@@ -22,7 +22,7 @@
@dataclass
class TrainingArguments:
output_dir: str = field(
default=os.path.join("outputs", str(uuid4())),
default=os.path.join("outputs", str(uuid4().hex)),
metadata={"help": "Path to the output directory."},
)
micro_batch_size: int = field(
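Switching the default from `str(uuid4())` to `uuid4().hex` drops the hyphens from the directory name (the outer `str()` is now redundant, since `.hex` is already a string). The difference in one snippet:

from uuid import uuid4

u = uuid4()
print(str(u))  # e.g. '9f1b2c3d-4e5f-4a7b-8c9d-0e1f2a3b4c5d' -- 36 chars, hyphenated
print(u.hex)   # e.g. '9f1b2c3d4e5f4a7b8c9d0e1f2a3b4c5d'     -- 32 hex chars, path-friendly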
77 changes: 77 additions & 0 deletions src/llamafactory/v1/core/base_sampler.py
@@ -0,0 +1,77 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Processor, TorchDataset


class BaseEngine(ABC):
@abstractmethod
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel = None,
processor: Processor = None,
) -> None:
"""Initialize the engine.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
processor: Processor.
"""
...

@abstractmethod
async def generate(self, messages):
pass

@abstractmethod
async def batch_infer(self, data: TorchDataset) -> None:
pass


class HuggingFaceEngine(BaseEngine):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
processor: Processor,
) -> None:
self.args = args


class BaseSampler:
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
processor: Processor,
) -> None:
if args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(args, model_args, model, processor)
else:
raise ValueError(f"Unknown sample backend: {args.sample_backend}")

async def generate(self, messages):
return await self.engine.generate(messages)

async def batch_infer(self, data: TorchDataset) -> None:
return await self.engine.batch_infer(data)
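`BaseSampler` is a thin dispatcher: it picks an engine from `sample_backend` and forwards the two coroutine methods, so callers are expected to drive it from an event loop. A minimal end-to-end sketch (the message format follows the OpenAI-style chat convention that HF chat templates accept; it assumes the HF backend and that `get_args()` runs with defaults):

import asyncio

from llamafactory.v1.config import get_args
from llamafactory.v1.core.base_sampler import BaseSampler
from llamafactory.v1.core.model_loader import ModelLoader


async def main() -> None:
    _, model_args, _, sample_args = get_args()
    loader = ModelLoader(model_args)
    sampler = BaseSampler(sample_args, model_args, loader.model, loader.processor)
    # generate() is a coroutine, so it must be awaited.
    reply = await sampler.generate([{"role": "user", "content": "Hello!"}])
    print(reply)


asyncio.run(main())

Note that `HuggingFaceEngine` as committed still needs to override `generate` and `batch_infer` before this can run: both are abstract on `BaseEngine`, so instantiating the engine would raise a TypeError until they are implemented.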
44 changes: 0 additions & 44 deletions src/llamafactory/v1/core/chat_sampler.py

This file was deleted.

44 changes: 32 additions & 12 deletions src/llamafactory/v1/core/model_loader.py
Expand Up @@ -14,17 +14,24 @@

"""The definition of model loader.

Init Phase:
How to use:
model_loader = ModelLoader(model_args, is_trainable=True)
model_loader.processor: Get the tokenizer or multi-modal processor.
model_loader.model_config: Get the model configuration.
model_loader.model: Get the HF model.

Init Workflow:
1. Init processor.
2. Init model config.
3. Init model.
4. Init adapter.

"""

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoProcessor

from ..accelerator.helper import DeviceType
from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
@@ -55,11 +62,14 @@ def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
"""HF model."""

def _init_processor(self) -> Processor:
"""Init processor."""
"""Init processor.

        NOTE: Transformers v5 always uses a fast tokenizer.
https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/auto/tokenization_auto.py#L642
"""
return AutoProcessor.from_pretrained(
self.args.model,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.use_fast_processor,
)

def _init_model_config(self) -> HFConfig:
@@ -92,14 +102,24 @@ def _init_model(self) -> HFModel:

AutoClass = AutoModel

# map the entire model to the current accelerator
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=DistributedInterface().current_accelerator,
trust_remote_code=self.args.trust_remote_code,
)
if self.args.init_config is not None:
from ..plugins.model_plugins.initialization import InitPlugin

init_device = InitPlugin(self.args.init_config.name)()
else:
init_device = DistributedInterface().current_accelerator

if init_device.type == DeviceType.META:
with init_empty_weights():
model = AutoClass.from_config(self.model_config)
else:
model = AutoClass.from_pretrained(
self.args.model,
config=self.model_config,
dtype="auto",
device_map=init_device,
trust_remote_code=self.args.trust_remote_code,
)

if self.args.peft_config is None:
if self.is_train:
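The rewritten `_init_model` branches on the device returned by the init plugin: a meta device means "build the module tree with shapes and dtypes but allocate no memory and read no checkpoint", which is what accelerate's `init_empty_weights` provides; any other device keeps the `from_pretrained` path with a concrete `device_map`. The meta branch in isolation (a standalone sketch; the model id is illustrative):

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Parameters exist only as meta tensors: shape/dtype metadata without storage.
print(next(model.parameters()).device)  # meta

Weights then have to be materialized later, e.g. by a backend that loads or shards them itself, which is presumably why non-HF sample backends request meta init in cli_sampler.py below.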
43 changes: 43 additions & 0 deletions src/llamafactory/v1/plugins/model_plugins/initialization.py
@@ -0,0 +1,43 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from ...accelerator.helper import DeviceType
from ...accelerator.interface import DistributedInterface
from ...utils.plugin import BasePlugin


class InitPlugin(BasePlugin):
def __call__(self) -> torch.device:
return super().__call__()


@InitPlugin("init_on_meta").register
def init_on_meta() -> torch.device:
return torch.device(DeviceType.META.value)


@InitPlugin("init_on_rank0").register
def init_on_rank0() -> torch.device:
if DistributedInterface().get_rank() == 0:
return torch.device(DeviceType.CPU.value)
else:
return torch.device(DeviceType.META.value)


@InitPlugin("init_on_default").register
def init_on_default() -> torch.device:
return DistributedInterface().current_accelerator
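Each strategy is registered under a string name, so selecting one reduces to instantiating the plugin with that name and calling it; this is exactly what `InitPlugin(self.args.init_config.name)()` does in model_loader.py. Expected behavior per the registrations above (assuming `BasePlugin.__call__` dispatches to the function registered under the given name):

import torch

from llamafactory.v1.plugins.model_plugins.initialization import InitPlugin

assert InitPlugin("init_on_meta")() == torch.device("meta")
# "init_on_rank0": a real (CPU) device on global rank 0, meta everywhere else.
# "init_on_default": whatever DistributedInterface reports as the current accelerator.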
35 changes: 35 additions & 0 deletions src/llamafactory/v1/samplers/cli_sampler.py
@@ -0,0 +1,35 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio

from ..config import InputArgument, SampleBackend, get_args
from ..core.base_sampler import BaseSampler
from ..core.model_loader import ModelLoader


def run_chat(args: InputArgument = None):
    data_args, model_args, _, sample_args = get_args(args)
    if sample_args.sample_backend != SampleBackend.HF:
        # ModelArguments defines `init_config` (read by ModelLoader._init_model);
        # non-HF backends load weights themselves, hence meta-device init here.
        model_args.init_config = {"name": "init_on_meta"}

    model_loader = ModelLoader(model_args)
    sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)
    # Both sampler entry points are coroutines and must run inside an event loop.
    if data_args.dataset is not None:
        asyncio.run(sampler.batch_infer(data_args.dataset))
    else:
        asyncio.run(sampler.generate())  # TODO: collect chat messages interactively


if __name__ == "__main__":
run_chat()