diff --git a/src/llamafactory/data/collator_tokenized.py b/src/llamafactory/data/collator_tokenized.py
new file mode 100644
index 0000000000..f1db1a25cf
--- /dev/null
+++ b/src/llamafactory/data/collator_tokenized.py
@@ -0,0 +1,85 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+import torch
+from transformers import DataCollatorForSeq2Seq
+
+from ..extras.constants import IGNORE_INDEX
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+
+
+def _resolve_pad_token_id(tokenizer: "PreTrainedTokenizer", model: "PreTrainedModel") -> int:
+    r"""Resolve the padding token ID from the model config or the tokenizer."""
+    pad_id = getattr(getattr(model, "config", None), "pad_token_id", None)
+    if pad_id is None and tokenizer is not None:
+        pad_id = getattr(tokenizer, "pad_token_id", None)
+    if pad_id is None:
+        pad_id = getattr(getattr(model, "config", None), "eos_token_id", None)
+    return 0 if pad_id is None else int(pad_id)
+
+
+@dataclass
+class TokenizedIdsCollator(DataCollatorForSeq2Seq):
+    r"""Collator for pre-tokenized LM data.
+
+    Expects features containing `input_ids` and optionally `attention_mask`.
+    Pads to the batch max length with `pad_token_id`, builds `labels`, and fills in a missing `attention_mask`.
+ """ + + strict: bool = True + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + pad_id = _resolve_pad_token_id(self.tokenizer, self.model) + + # Validate and compute max length + max_len = 0 + for f in features: + if "input_ids" not in f or not isinstance(f["input_ids"], list): + if self.strict: + raise ValueError("Each feature must contain list[int] `input_ids`.") + else: + f["input_ids"] = f.get("input_ids", []) or [] + max_len = max(max_len, len(f["input_ids"])) + + input_ids = [] + attention_mask = [] + labels = [] + for f in features: + ids = f["input_ids"] + pad_amt = max_len - len(ids) + row_ids = ids + [pad_id] * pad_amt + input_ids.append(row_ids) + + if "attention_mask" in f and isinstance(f["attention_mask"], list): + if self.strict and len(f["attention_mask"]) != len(ids): + raise ValueError("attention_mask length must match input_ids length.") + mask = f["attention_mask"] + [0] * pad_amt + else: + mask = [1] * len(ids) + [0] * pad_amt + attention_mask.append(mask) + + labels.append(ids + [IGNORE_INDEX] * pad_amt) + + batch = { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "attention_mask": torch.tensor(attention_mask, dtype=torch.long), + "labels": torch.tensor(labels, dtype=torch.long), + } + return batch diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index b5adc139e9..10aca16380 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -21,6 +21,7 @@ from ..extras import logging from ..extras.constants import FILEEXT2TYPE from ..extras.misc import check_version, has_tokenized_data +from .collator_tokenized import TokenizedIdsCollator from .converter import align_dataset from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset from .parser import get_dataset_list @@ -32,6 +33,7 @@ SupervisedDatasetProcessor, UnsupervisedDatasetProcessor, ) +from .tokenized_parquet import load_tokenized_parquet_dataset if TYPE_CHECKING: @@ -241,6 +243,10 @@ def _get_preprocessed_dataset( if dataset is None: return None + # Bypass tokenizer for pre-tokenized pathway + if data_args.dataset_format == "tokenized_ids": + return dataset + dataset_processor = _get_dataset_processor( data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval) ) @@ -301,15 +307,30 @@ def get_dataset( # Load and preprocess dataset with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)): - dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage) - eval_dataset = _get_merged_dataset( - data_args.eval_dataset, - model_args, - data_args, - training_args, - stage, - return_dict=data_args.eval_on_each_dataset, - ) + if data_args.dataset_format == "tokenized_ids": + # Load pre-tokenized parquet files + cols = data_args.dataset_columns or {} + ids_key = cols.get("ids", "input_ids") + mask_key = cols.get("mask", "attention_mask") + files = data_args.data_files + if isinstance(files, dict): + files = files.get("train", []) + if not isinstance(files, list) or len(files) == 0: + raise ValueError( + "For dataset_format=tokenized_ids, provide non-empty data_files list (parquet paths)." 
+                )
+            dataset = load_tokenized_parquet_dataset(files, ids_key=ids_key, mask_key=mask_key)
+            eval_dataset = None
+        else:
+            dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
+            eval_dataset = _get_merged_dataset(
+                data_args.eval_dataset,
+                model_args,
+                data_args,
+                training_args,
+                stage,
+                return_dict=data_args.eval_on_each_dataset,
+            )
 
     with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
         dataset = _get_preprocessed_dataset(
@@ -332,4 +353,9 @@ def get_dataset(
         logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
         logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.")
 
-    return get_dataset_module(dataset_dict)
+    module = get_dataset_module(dataset_dict)
+    # Use the pre-tokenized collator; the model is optional since pad_token_id is resolved from the tokenizer
+    if data_args.dataset_format == "tokenized_ids":
+        collator = TokenizedIdsCollator(tokenizer=tokenizer, model=None)
+        module["data_collator"] = collator
+    return module
diff --git a/src/llamafactory/data/tokenized_parquet.py b/src/llamafactory/data/tokenized_parquet.py
new file mode 100644
index 0000000000..b20c5d9cc4
--- /dev/null
+++ b/src/llamafactory/data/tokenized_parquet.py
@@ -0,0 +1,85 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any, Optional
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+from ..extras import logging
+
+
+if TYPE_CHECKING:
+    from datasets import IterableDataset
+
+
+logger = logging.get_logger(__name__)
+
+
+def _iter_parquet_rows(paths: list[str], ids_key: str, mask_key: Optional[str]) -> Iterable[dict[str, Any]]:
+    r"""Iterate over rows from multiple Parquet files, yielding pre-tokenized samples."""
+    for path in paths:
+        try:
+            pf = pq.ParquetFile(path)
+        except FileNotFoundError:
+            logger.warning(f"Parquet file not found, skipping: {path}")
+            continue
+
+        with pf:
+            for i in range(pf.num_row_groups):
+                table: pa.Table = pf.read_row_group(i)
+                ids_col = table[ids_key]
+                mask_col = table[mask_key] if mask_key and mask_key in table.column_names else None
+                ids_py = ids_col.to_pylist()
+                mask_py = mask_col.to_pylist() if mask_col is not None else itertools.repeat(None)
+                for ids, mask in zip(ids_py, mask_py):
+                    yield {
+                        "input_ids": list(ids) if isinstance(ids, (list, tuple)) else ids,
+                        **(
+                            {"attention_mask": (list(mask) if isinstance(mask, (list, tuple)) else mask)}
+                            if mask is not None
+                            else {}
+                        ),
+                    }
+
+
+def load_tokenized_parquet_dataset(
+    data_files: list[str],
+    ids_key: str = "input_ids",
+    mask_key: Optional[str] = "attention_mask",
+) -> "IterableDataset":
+    r"""Create a streaming HF IterableDataset over pre-tokenized Parquet samples.
+
+    Args:
+        data_files: List of local Parquet file paths.
+        ids_key: Column name for input token IDs.
+        mask_key: Column name for attention mask (optional).
+
+    Returns:
+        IterableDataset yielding dictionaries with `input_ids` and optionally `attention_mask`.
+
+    Note:
+        Always streams row groups to avoid materializing large corpora in memory.
+    """
+    from datasets import IterableDataset
+
+    if not data_files:
+        raise ValueError("data_files must be a non-empty list of Parquet paths.")
+
+    logger.info_rank0(f"Building streaming dataset from {len(data_files)} parquet file(s).")
+    return IterableDataset.from_generator(
+        _iter_parquet_rows, gen_kwargs={"paths": data_files, "ids_key": ids_key, "mask_key": mask_key}
+    )  # type: ignore
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index e6844733e5..6f8ea5ce80 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -137,6 +137,34 @@ class DataArguments:
         default=False,
         metadata={"help": "Whether or not to use a shared file system for the datasets."},
     )
+    dataset_format: Optional[Literal["default", "tokenized_ids"]] = field(
+        default="default",
+        metadata={
+            "help": (
+                "Format of the input dataset. Use 'tokenized_ids' for pre-tokenized parquet files "
+                "containing token IDs. This bypasses the tokenization step during training."
+            )
+        },
+    )
+    data_files: Optional[Any] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path(s) to the data files for the tokenized_ids format. "
+                "Can be a single path, comma-separated paths, or a list of paths."
+            )
+        },
+    )
+    dataset_columns: Optional[dict[str, str]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Column name mapping for tokenized datasets. "
+                "Example: {'ids': 'token_ids', 'mask': 'attn_mask'}. "
+                "Defaults to {'ids': 'input_ids', 'mask': 'attention_mask'}."
+            )
+        },
+    )
 
     def __post_init__(self):
         def split_arg(arg):
@@ -147,6 +175,12 @@ def split_arg(arg):
         self.dataset = split_arg(self.dataset)
         self.eval_dataset = split_arg(self.eval_dataset)
 
+        # Handle data_files for the tokenized_ids format
+        if self.dataset_format == "tokenized_ids":
+            if self.data_files is None:
+                raise ValueError("data_files must be specified when using dataset_format='tokenized_ids'.")
+            self.data_files = split_arg(self.data_files)
+
         if self.media_dir is None:
             self.media_dir = self.dataset_dir
 
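Usage sketch (illustrative, not part of the patch): the snippet below exercises TokenizedIdsCollator on a toy batch to show the padding and label-masking behaviour. DummyTokenizer is a hypothetical stand-in that only supplies `pad_token_id`; in training, get_dataset passes the real HF tokenizer.

# Minimal sketch, assuming this patch is applied.
# DummyTokenizer is a hypothetical stand-in providing only `pad_token_id`.
from dataclasses import dataclass

from llamafactory.data.collator_tokenized import TokenizedIdsCollator


@dataclass
class DummyTokenizer:
    pad_token_id: int = 0


features = [
    {"input_ids": [5, 6, 7, 8]},                       # no mask: inferred as all ones
    {"input_ids": [9, 10], "attention_mask": [1, 1]},  # explicit mask: right-padded with zeros
]
collator = TokenizedIdsCollator(tokenizer=DummyTokenizer(), model=None)
batch = collator(features)

print(batch["input_ids"])       # tensor([[5, 6, 7, 8], [9, 10, 0, 0]])
print(batch["attention_mask"])  # tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
print(batch["labels"])          # tensor([[5, 6, 7, 8], [9, 10, -100, -100]])

With a real tokenizer or model config, a non-zero pad_token_id would take precedence over the 0 fallback shown here, while padded label positions are always set to IGNORE_INDEX (-100) so they are excluded from the loss.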