2 changes: 1 addition & 1 deletion src/llamafactory/model/loader.py
@@ -230,7 +230,7 @@ def load_model(
)
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels

model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
model = apply_default_kernels(model, include_kernels=model_args.use_v1_kernels)

trainable_params, all_param = count_parameters(model)
if is_trainable:
53 changes: 20 additions & 33 deletions src/llamafactory/v1/core/base_sampler.py
@@ -15,11 +15,11 @@
import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from collections.abc import AsyncGenerator
from threading import Thread

import torch
from transformers import TextIteratorStreamer
from transformers import AsyncTextIteratorStreamer

from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend
@@ -88,39 +88,26 @@ def __init__(
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

@torch.inference_mode()
def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = TextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()

def stream():
try:
return streamer.__next__()
except StopIteration:
raise StopAsyncIteration()

return stream

async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
async with self.semaphore:
response = self.get_response(messages, tools)
while True:
try:
yield await asyncio.to_thread(response)
except StopAsyncIteration:
break
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = AsyncTextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()

async for token in streamer:
yield token

async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
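For reference, a minimal standalone sketch of the AsyncTextIteratorStreamer pattern adopted above; the checkpoint name and generation settings are illustrative placeholders, not values taken from this PR.

```python
import asyncio
from threading import Thread

from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer


async def stream_demo(prompt: str) -> None:
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer(prompt, return_tensors="pt")
    # The async streamer must be created inside a running event loop.
    streamer = AsyncTextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so it runs in a daemon thread and feeds tokens into the streamer.
    Thread(
        target=model.generate,
        kwargs={**inputs, "streamer": streamer, "max_new_tokens": 32},
        daemon=True,
    ).start()
    async for token in streamer:  # decoded text chunks arrive as they are produced
        print(token, end="", flush=True)


asyncio.run(stream_demo("Hello"))
```

This removes the manual `StopIteration`-to-`StopAsyncIteration` bridging and the `asyncio.to_thread` polling loop that the old sync streamer required.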
13 changes: 7 additions & 6 deletions src/llamafactory/v1/core/base_trainer.py
@@ -28,30 +28,31 @@
"""

from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator
from ..utils.types import HFModel, TorchDataset
from .utils.data_collator import DataCollator
from .utils.rendering import Renderer


class BaseTrainer:
def __init__(
self,
args: TrainingArguments,
model: HFModel,
processor: Processor,
renderer: Renderer,
dataset: TorchDataset,
) -> None:
self.args = args
self.model = model
self.processor = processor
self.renderer = renderer
self.dataset = dataset
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None

def init_model_and_optimizer(self) -> None:
def _create_dataloader(self) -> None:
pass

def create_dataloader(self) -> None:
def _init_model_and_optimizer(self) -> None:
pass

def fit(self) -> None:
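A hedged sketch of how a concrete trainer might fill in the renamed private hooks. `SFTTrainer`, the method bodies, and the `batch_size`/`learning_rate` argument fields are hypothetical; only the hook names and constructor attributes come from this diff.

```python
import torch
from torch.utils.data import DataLoader

from llamafactory.v1.core.base_trainer import BaseTrainer


class SFTTrainer(BaseTrainer):  # hypothetical subclass for illustration
    def _create_dataloader(self) -> None:
        # Batch the dataset with the collator created in BaseTrainer.__init__.
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.args.batch_size,  # assumed TrainingArguments field
            collate_fn=self.data_collator,
        )

    def _init_model_and_optimizer(self) -> None:
        # Create the optimizer over model parameters; the scheduler is left to fit().
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.args.learning_rate,  # assumed TrainingArguments field
        )
```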
4 changes: 2 additions & 2 deletions src/llamafactory/v1/core/model_engine.py
@@ -87,7 +87,7 @@ def _init_model_config(self) -> HFConfig:
def _init_model(self) -> HFModel:
"""Init model.

Let transformers handle the model init context.
Transformers can choose the proper model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
"""
if self.args.model_class == ModelClass.LLM:
@@ -141,7 +141,7 @@ def _init_model(self) -> HFModel:
from ..plugins.model_plugins.kernels.interface import KernelPlugin

model = KernelPlugin(self.args.kernel_config.name)(
model=model, include_kernels=self.args.kernel_config.get("include_kernels")
model, include_kernels=self.args.kernel_config.get("include_kernels")
)

return model
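As a usage note, the call above dispatches through the plugin registry by name. A small hedged sketch of that pattern, assuming `"auto"` resolves to the `apply_default_kernels` entry point registered below:

```python
from llamafactory.v1.plugins.model_plugins.kernels.interface import KernelPlugin
from llamafactory.v1.utils.types import HFModel


def patch_with_default_kernels(model: HFModel) -> HFModel:
    # The model is passed positionally, matching the updated call sites in this PR.
    return KernelPlugin("auto")(model, include_kernels="auto")
```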
26 changes: 15 additions & 11 deletions src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
@@ -24,12 +24,13 @@
import importlib
from pathlib import Path

from ....utils.logging import get_logger
from ....utils import logging
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
from .registry import Registry


logger = get_logger(__name__)
logger = logging.get_logger(__name__)


def scan_all_kernels():
@@ -110,27 +111,30 @@ class KernelPlugin(BasePlugin):


@KernelPlugin("auto").register()
def apply_default_kernels(**kwargs):
def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
"""Applies all default registered kernels to the model.

Args:
**kwargs: Keyword arguments passed to the kernel application function.
Typically includes the model instance and the include_kernels configuration.
model (HFModel): The model instance to apply kernels to.
include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
If "auto" or True, applies all default kernels.
If None or False, no kernels are applied.
Defaults to None.

Returns:
HFModel: The model with applied kernels.
"""
if not kwargs.get("include_kernels"): # None/False/empty string
return kwargs.get("model")
elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto
if not include_kernels:
return model
elif include_kernels == "auto" or include_kernels is True:
use_kernels = default_kernels.keys()
else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
use_kernels = include_kernels.split(",") # "kernel_id1,kernel_id2,kernel_id3"

for kernel in use_kernels:
if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found")

apply_kernel(kernel, **kwargs)
apply_kernel(kernel, model=model)

return kwargs.get("model")
return model
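To make the new explicit signature concrete, a brief hedged usage sketch; the kernel IDs in the last call are placeholders, not names registered by this PR.

```python
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels
from llamafactory.v1.utils.types import HFModel


def demo(model: HFModel) -> HFModel:
    model = apply_default_kernels(model, include_kernels=None)    # falsy value: model returned unchanged
    model = apply_default_kernels(model, include_kernels="auto")  # apply every registered default kernel
    # Comma-separated IDs select specific kernels; unknown IDs raise ValueError.
    model = apply_default_kernels(model, include_kernels="kernel_a,kernel_b")  # placeholder IDs
    return model
```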
6 changes: 2 additions & 4 deletions src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
@@ -20,8 +20,6 @@

"""

from typing import Optional

from ....accelerator.helper import get_current_accelerator
from .base import BaseKernel

@@ -73,14 +71,14 @@ def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
return kernel_cls

@classmethod
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
def get(cls, kernel_id: str) -> type[BaseKernel] | None:
"""Retrieves a registered kernel implementation by its ID.

Args:
kernel_id (str): The ID of the kernel to retrieve.

Returns:
Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
type[BaseKernel] | None: The kernel class if found, else ``None``.
"""
return cls._kernels.get(kernel_id)

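A brief hedged sketch of the lookup path after the typing change, assuming `Registry` (the name imported by interface.py) is the class that owns the `get()` classmethod shown above; the wrapper function is illustrative only.

```python
from llamafactory.v1.plugins.model_plugins.kernels.base import BaseKernel
from llamafactory.v1.plugins.model_plugins.kernels.registry import Registry


def require_kernel(kernel_id: str) -> type[BaseKernel]:
    kernel_cls = Registry.get(kernel_id)  # the kernel class, or None if the ID was never registered
    if kernel_cls is None:
        raise ValueError(f"Kernel {kernel_id} not found")
    return kernel_cls
```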