Merged
Changes from 2 commits
3 changes: 2 additions & 1 deletion src/llamafactory/v1/config/__init__.py
@@ -13,14 +13,15 @@
# limitations under the License.

from .arg_parser import InputArgument, get_args
from .arg_utils import ModelClass, SampleBackend
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
from .training_args import TrainingArguments


__all__ = [
"BatchingStrategy",
"DataArguments",
"InputArgument",
"ModelArguments",
8 changes: 8 additions & 0 deletions src/llamafactory/v1/config/arg_utils.py
@@ -50,6 +50,14 @@ class SampleBackend(StrEnum):
VLLM = "vllm"


@unique
class BatchingStrategy(StrEnum):
NORMAL = "normal"
PADDING_FREE = "padding_free"
DYNAMIC_BATCHING = "dynamic_batching"
DYNAMIC_PADDING_FREE = "dynamic_padding_free"


def _convert_str_dict(data: dict) -> dict:
"""Parse string representation inside the dictionary.

4 changes: 0 additions & 4 deletions src/llamafactory/v1/config/data_args.py
@@ -22,7 +22,3 @@ class DataArguments:
default=None,
metadata={"help": "Path to the dataset."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Cutoff length for the dataset."},
)
20 changes: 16 additions & 4 deletions src/llamafactory/v1/config/training_args.py
@@ -16,7 +16,7 @@
from dataclasses import dataclass, field
from uuid import uuid4

from .arg_utils import PluginConfig, get_plugin_config
from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config


@dataclass
@@ -29,18 +29,30 @@ class TrainingArguments:
default=1,
metadata={"help": "Micro batch size for training."},
)
global_batch_size: int = field(
default=1,
metadata={"help": "Global batch size for training."},
global_batch_size: int | None = field(
default=None,
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
)
learning_rate: float = field(
default=1e-4,
metadata={"help": "Learning rate for training."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Maximum sequence length for training."},
)
bf16: bool = field(
default=False,
metadata={"help": "Use bf16 for training."},
)
batching_strategy: BatchingStrategy = field(
default=BatchingStrategy.NORMAL,
metadata={"help": "Batching strategy for training."},
)
batching_workers: int = field(
default=16,
metadata={"help": "Number of workers for batching."},
)
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},
2 changes: 0 additions & 2 deletions src/llamafactory/v1/core/base_trainer.py
@@ -29,7 +29,6 @@

from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, TorchDataset
from .utils.data_collator import DataCollator
from .utils.rendering import Renderer


@@ -45,7 +44,6 @@ def __init__(
self.model = model
self.renderer = renderer
self.dataset = dataset
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None

16 changes: 9 additions & 7 deletions src/llamafactory/v1/core/data_engine.py
@@ -82,14 +82,17 @@ def _get_dataset_info(self) -> None:

def _load_dataset(self) -> None:
"""Load datasets according to dataset info."""
is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()]
self.streaming = any(is_streaming)
if all(is_streaming) != any(is_streaming):
raise ValueError("All datasets must be streaming or non-streaming.")

for dataset_name, dataset_info in self.dataset_infos.items():
split = dataset_info.get("split", "train")
streaming = dataset_info.get("streaming", False)
self.streaming |= streaming
if dataset_info.get("source", "hf_hub") == "hf_hub":
from datasets import load_dataset

self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming)
self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming)
else: # data loader plugin
from ..plugins.data_plugins.loader import DataLoaderPlugin

@@ -98,8 +101,7 @@ def _build_data_index(self) -> None:
def _build_data_index(self) -> None:
"""Build dataset index."""
for dataset_name, dataset in self.datasets.items():
streaming = self.dataset_infos[dataset_name].get("streaming", False)
if streaming:
if self.streaming:
data_index = [(dataset_name, -1) for _ in range(1000)]
else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
@@ -185,8 +187,8 @@ def __iter__(self) -> Iterable[Sample]:

if __name__ == "__main__":
"""
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml
python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml
python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml
"""
from ..config.arg_parser import get_args

244 changes: 244 additions & 0 deletions src/llamafactory/v1/core/utils/batching.py
@@ -0,0 +1,244 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Batching utils supports stateful dataloader.

1. Init stateful dataloader (tokenize)
2. Add to buffer
3. Yield batch indexes (micro batch * grad acc)
a) non pack + non dynamic
b) non pack + dynamic
c) pack + non dynamic
d) pack + dynamic
"""

from collections.abc import Iterator
from typing import Any

from torch.utils.data import default_collate
from torchdata.stateful_dataloader import StatefulDataLoader
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

from ...accelerator.interface import DistributedInterface
from ...config import BatchingStrategy
from ...utils import logging
from ...utils.helper import pad_and_truncate
from ...utils.types import BatchInput, ModelInput, TorchDataset
from .rendering import Renderer


logger = logging.get_logger(__name__)


def default_collate_fn(
buffer: list[ModelInput], buffer_tokens: int, micro_batch_size: int, num_micro_batch: int, cutoff_len: int
) -> tuple[list[ModelInput], int, list[BatchInput]]:
batch_size = micro_batch_size * num_micro_batch
if len(buffer) < batch_size:
return buffer, buffer_tokens, None

samples = buffer[:batch_size]
buffer = buffer[batch_size:]
buffer_tokens -= sum(len(sample["input_ids"]) for sample in samples)

batch = []
for i in range(num_micro_batch):
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))

return buffer, buffer_tokens, batch
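# Worked example (hypothetical numbers): with micro_batch_size=2 and num_micro_batch=2 the
# function returns None until the buffer holds 4 samples, then pops them, pads/truncates each
# pair to cutoff_len via pad_and_truncate, and returns two collated micro batches together
# with the reduced buffer and token count.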


class BatchGenerator(Iterator):
def __init__(
self,
dataset: TorchDataset,
renderer: Renderer,
micro_batch_size: int = 1,
global_batch_size: int | None = None,
cutoff_len: int = 2048,
batching_workers: int = 0,
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
pin_memory: bool = True,
drop_last: bool = True,
) -> None:
self.dataset = dataset
self.renderer = renderer

self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.cutoff_len = cutoff_len
self.batching_workers = batching_workers
self.batching_strategy = batching_strategy
self.pin_memory = pin_memory
self.drop_last = drop_last
# TODO: support length and infinity

dp_size = DistributedInterface().get_world_size("dp")

if self.global_batch_size is None:
self.global_batch_size = dp_size * micro_batch_size
self.num_micro_batch = 1
elif self.global_batch_size % (dp_size * micro_batch_size) == 0:
self.num_micro_batch = global_batch_size // dp_size // micro_batch_size
else:
raise ValueError(
"Global batch size must be divisible by DP size and micro batch size. "
f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0."
)
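# Example (hypothetical numbers): dp_size=2, micro_batch_size=2, global_batch_size=8
# -> num_micro_batch = 8 // 2 // 2 = 2 micro batches accumulated per optimizer step on each rank.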

if not self.drop_last:
raise ValueError("Drop last must be True.")

self._init_data_provider()

self._is_resuming: bool = False
self._data_iter = iter(self._data_provider)
self._buffer: list[ModelInput] = []
self._buffer_tokens: int = 0
self._max_buffer_tokens: int = self.micro_batch_size * self.num_micro_batch * self.cutoff_len
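# The buffer counts as full once it holds one optimizer step worth of tokens for this rank,
# i.e. micro_batch_size * num_micro_batch * cutoff_len.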

logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, "
f"num micro batch {self.num_micro_batch}, "
f"cutoff len {self.cutoff_len}, "
f"batching workers {self.batching_workers}, "
f"batching strategy {self.batching_strategy}."
)

def _init_data_provider(self) -> None:
if len(self.dataset) != -1:
sampler = StatefulDistributedSampler(
self.dataset,
num_replicas=DistributedInterface().get_world_size("dp"),
rank=DistributedInterface().get_rank("dp"),
shuffle=True,
seed=0,
drop_last=self.drop_last,
)
else:
raise NotImplementedError("Iterable dataset is not supported yet.")

self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
sampler=sampler,
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
pin_memory=self.pin_memory,
drop_last=self.drop_last,
)
if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider)
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin

self._length = BatchingPlugin(self.batching_strategy).compute_length()
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")

def __len__(self) -> int:
return self._length

def __iter__(self):
if not self._is_resuming:
self._data_iter = iter(self._data_provider)
self._buffer.clear()
self._buffer_tokens = 0

self._is_resuming = False
return self

def __next__(self):
batch = self._next_batch()
if batch is None:
raise StopIteration

return batch

def _next_batch(self) -> list[BatchInput] | None:
while self._buffer_tokens < self._max_buffer_tokens:
try:
samples: list[ModelInput] = next(self._data_iter)
except StopIteration:
break

num_tokens = sum(len(sample["input_ids"]) for sample in samples)
self._buffer.extend(samples)
self._buffer_tokens += num_tokens

return self._build_batch()

def _build_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL:
self._buffer, self._buffer_tokens, batch = default_collate_fn(
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
)
return batch
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin

self._buffer, self._buffer_tokens, batch = BatchingPlugin(self.batching_strategy)(
self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len
)
return batch

def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer,
"buffer_tokens": self._buffer_tokens,
"data_provider": self._data_provider.state_dict(),
}

def load_state_dict(self, state: dict[str, Any]) -> None:
self._buffer = state["buffer"]
self._buffer_tokens = state["buffer_tokens"]
self._data_provider.load_state_dict(state["data_provider"])
self._is_resuming = True
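# Resume sketch (assumed trainer-side usage): persist state_dict() with the checkpoint and call
# load_state_dict() before iterating again; _is_resuming prevents __iter__ from resetting the
# restored buffer and dataloader position.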

def set_epoch(self, epoch: int) -> None:
if hasattr(self._data_provider.sampler, "set_epoch"):
self._data_provider.sampler.set_epoch(epoch)


if __name__ == "__main__":
"""
python -m llamafactory.v1.core.utils.batching \
--model llamafactory/tiny-random-qwen2.5 \
--dataset data/v1_sft_demo.yaml \
--micro_batch_size 2 \
--global_batch_size 4 \
--batching_workers 0
"""
from ...config.arg_parser import get_args
from ..data_engine import DataEngine
from ..model_engine import ModelEngine

data_args, model_args, training_args, _ = get_args()
data_engine = DataEngine(data_args=data_args)
model_engine = ModelEngine(model_args=model_args)
batch_generator = BatchGenerator(
data_engine,
model_engine.renderer,
micro_batch_size=training_args.micro_batch_size,
global_batch_size=training_args.global_batch_size,
cutoff_len=training_args.cutoff_len,
batching_workers=training_args.batching_workers,
batching_strategy=training_args.batching_strategy,
)
for batch in batch_generator:
print(batch)
print(len(batch))
print(batch[0]["input_ids"].shape)
break
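For context, a minimal sketch of how a trainer might drive BatchGenerator. The loop below is an assumption for illustration only; BatchGenerator, its constructor arguments, and set_epoch come from this diff, while num_epochs, model, and optimizer are hypothetical trainer state:

    for epoch in range(num_epochs):
        batch_generator.set_epoch(epoch)
        for micro_batches in batch_generator:  # a list of num_micro_batch collated micro batches
            for micro_batch in micro_batches:
                loss = model(**micro_batch).loss / len(micro_batches)  # scale for gradient accumulation
                loss.backward()
            optimizer.step()
            optimizer.zero_grad()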