Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/llamafactory/v1/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.

from .arg_parser import InputArgument, get_args
from .arg_utils import ModelClass, SampleBackend
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments


__all__ = [
"BatchingStrategy",
"DataArguments",
"InputArgument",
"ModelArguments",
Expand Down
8 changes: 8 additions & 0 deletions src/llamafactory/v1/config/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ class SampleBackend(StrEnum):
VLLM = "vllm"


@unique
class BatchingStrategy(StrEnum):
NORMAL = "normal"
PADDING_FREE = "padding_free"
DYNAMIC_BATCHING = "dynamic_batching"
DYNAMIC_PADDING_FREE = "dynamic_padding_free"


def _convert_str_dict(data: dict) -> dict:
"""Parse string representation inside the dictionary.

Expand Down
4 changes: 0 additions & 4 deletions src/llamafactory/v1/config/data_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,3 @@ class DataArguments:
default=None,
metadata={"help": "Path to the dataset."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Cutoff length for the dataset."},
)
20 changes: 16 additions & 4 deletions src/llamafactory/v1/config/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dataclasses import dataclass, field
from uuid import uuid4

from .arg_utils import PluginConfig, get_plugin_config
from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config


@dataclass
Expand All @@ -29,18 +29,30 @@ class TrainingArguments:
default=1,
metadata={"help": "Micro batch size for training."},
)
global_batch_size: int = field(
default=1,
metadata={"help": "Global batch size for training."},
global_batch_size: int | None = field(
default=None,
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
)
learning_rate: float = field(
default=1e-4,
metadata={"help": "Learning rate for training."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Maximum sequence length for training."},
)
bf16: bool = field(
default=False,
metadata={"help": "Use bf16 for training."},
)
batching_strategy: BatchingStrategy = field(
default=BatchingStrategy.NORMAL,
metadata={"help": "Batching strategy for training."},
)
batching_workers: int = field(
default=16,
metadata={"help": "Number of workers for batching."},
)
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},
Expand Down
2 changes: 0 additions & 2 deletions src/llamafactory/v1/core/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, TorchDataset
from .utils.data_collator import DataCollator
from .utils.rendering import Renderer


Expand All @@ -45,7 +44,6 @@ def __init__(
self.model = model
self.renderer = renderer
self.dataset = dataset
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None

Expand Down
16 changes: 9 additions & 7 deletions src/llamafactory/v1/core/data_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,17 @@ def _get_dataset_info(self) -> None:

def _load_dataset(self) -> None:
"""Load datasets according to dataset info."""
is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()]
self.streaming = any(is_streaming)
if all(is_streaming) != any(is_streaming):
raise ValueError("All datasets must be streaming or non-streaming.")

for dataset_name, dataset_info in self.dataset_infos.items():
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
self.streaming |= streaming
if dataset_info.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset

self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin

Expand All @@ -98,8 +101,7 @@ def _load_dataset(self) -> None:
def _build_data_index(self) -> None:
"""Build dataset index."""
for dataset_name, dataset in self.datasets.items():
streaming = self.dataset_infos[dataset_name].get("streaming", False)
if streaming:
if self.streaming:
data_index = [(dataset_name, -1) for _ in range(1000)]
else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
Expand Down Expand Up @@ -185,8 +187,8 @@ def __iter__(self) -> Iterable[Sample]:

if __name__ == "__main__":
"""
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml
"""
from ..config.arg_parser import get_args

Expand Down
244 changes: 244 additions & 0 deletions src/llamafactory/v1/core/utils/batching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Batching utils supports stateful dataloader.

1. Init stateful dataloader (tokenize)
2. Add to buffer
3. Yield batch indexes (micro batch * grad acc)
a) non pack + non dynamic
b) non pack + dynamic
c) pack + non dynamic
d) pack + dynamic
"""

from collections.abc import Iterator
from typing import Any

from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

from ...accelerator.interface import DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
from ...utils.types import BatchInput, ModelInput, TorchDataset
from .rendering import Renderer


logger = logging.get_logger(__name__)


def default_collate_fn(
buffer: list[ModelInput], buffer_tokens: int, micro_batch_size: int, num_micro_batch: int, cutoff_len: int
) -> tuple[list[ModelInput], int, list[BatchInput]]:
batch_size = micro_batch_size * num_micro_batch
if len(buffer) < batch_size:
return buffer, buffer_tokens, None

samples = buffer[:batch_size]
buffer = buffer[batch_size:]
buffer_tokens -= sum(len(sample["input_ids"]) for sample in samples)

batch = []
for i in range(num_micro_batch):
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))

return buffer, buffer_tokens, batch


class BatchGenerator(Iterator):
def __init__(
self,
dataset: TorchDataset,
renderer: Renderer,
micro_batch_size: int = 1,
global_batch_size: int | None = None,
cutoff_len: int = 2048,
batching_workers: int = 0,
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
pin_memory: bool = True,
drop_last: bool = True,
) -> None:
self.dataset = dataset
self.renderer = renderer

self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.cutoff_len = cutoff_len
self.batching_workers = batching_workers
self.batching_strategy = batching_strategy
self.pin_memory = pin_memory
self.drop_last = drop_last
# TODO: support length and infinity

dp_size = DistributedInterface().get_world_size("dp")

if self.global_batch_size is None:
self.global_batch_size = dp_size * micro_batch_size
self.num_micro_batch = 1
elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
else:
raise ValueError(
"Global batch size must be divisible by DP size and micro batch size. "
f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
)

if not self.drop_last:
raise ValueError("Drop last must be True.")

self._init_data_provider()

self._is_resuming: bool = False
self._data_iter = iter(self._data_provider)
self._buffer: list[ModelInput] = []
self._buffer_tokens: int = 0
self._max_buffer_tokens: int = self.micro_batch_size * self.num_micro_batch * self.cutoff_len

logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, "
f"num micro batch {self.num_micro_batch}, "
f"cutoff len {self.cutoff_len}, "
f"batching workers {self.batching_workers}, "
f"batching strategy {self.batching_strategy}."
)

def _init_data_provider(self) -> None:
if len(self.dataset) != -1:
sampler = StatefulDistributedSampler(
self.dataset,
num_replicas=DistributedInterface().get_world_size("dp"),
rank=DistributedInterface().get_rank("dp"),
shuffle=True,
seed=0,
drop_last=self.drop_last,
)
else:
raise NotImplementedError("Iterable dataset is not supported yet.")

self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
sampler=sampler,
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
pin_memory=self.pin_memory,
drop_last=self.drop_last,
)
if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin

self._length = BatchingPlugin(self.batching_strategy).compute_length()
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")

def __len__(self) -> int:
return self._length

def __iter__(self):
if not self._is_resuming:
self._buffer.clear()
self._buffer_tokens = 0

self._data_iter = iter(self._data_provider)
self._is_resuming = False
return self

def __next__(self):
batch = self._next_batch()
if batch is None:
raise StopIteration

return batch

def _next_batch(self) -> list[BatchInput] | None:
while self._buffer_tokens < self._max_buffer_tokens:
try:
samples: list[ModelInput] = next(self._data_iter)
except StopIteration:
break

num_tokens = sum(len(sample["input_ids"]) for sample in samples)
self._buffer.extend(samples)
self._buffer_tokens += num_tokens

return self._build_batch()

def _build_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL:
self._buffer, self._buffer_tokens, batch = default_collate_fn(
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
)
return batch
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin

self._buffer, self._buffer_tokens, batch = BatchingPlugin(self.batching_strategy)(
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
)
return batch

def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer,
"buffer_tokens": self._buffer_tokens,
"data_provider": self._data_provider.state_dict(),
}

def load_state_dict(self, state: dict[str, Any]) -> None:
self._buffer = state["buffer"]
self._buffer_tokens = state["buffer_tokens"]
self._data_provider.load_state_dict(state["data_provider"])
self._is_resuming = True

def set_epoch(self, epoch: int) -> None:
if hasattr(self._data_provider.sampler, "set_epoch"):
self._data_provider.sampler.set_epoch(epoch)


if __name__ == "__main__":
"""
python -m llamafactory.v1.core.utils.batching \
--model llamafactory/tiny-random-qwen2.5 \
--dataset data/v1_sft_demo.yaml \
--micro_batch_size 2 \
--global_batch_size 4 \
--batching_workers 0
"""
from ...config.arg_parser import get_args
from ..data_engine import DataEngine
from ..model_engine import ModelEngine

data_args, model_args, training_args, _ = get_args()
data_engine = DataEngine(data_args=data_args)
model_engine = ModelEngine(model_args=model_args)
batch_generator = BatchGenerator(
data_engine,
model_engine.renderer,
micro_batch_size=training_args.micro_batch_size,
global_batch_size=training_args.global_batch_size,
cutoff_len=training_args.cutoff_len,
batching_workers=training_args.batching_workers,
batching_strategy=training_args.batching_strategy,
)
for batch in batch_generator:
print(batch)
print(len(batch))
print(batch[0]["input_ids"].shape)
break
Loading
Loading