From 9f5c36a70e9a9a5595a191b4166f7275a95abf66 Mon Sep 17 00:00:00 2001 From: CjangCjengh <101577701+CjangCjengh@users.noreply.github.com> Date: Thu, 18 Dec 2025 19:31:45 +0800 Subject: [PATCH 1/2] support loss mask --- src/llamafactory/data/converter.py | 18 ++++++++++++ src/llamafactory/data/processor/supervised.py | 29 +++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/src/llamafactory/data/converter.py b/src/llamafactory/data/converter.py index ac3735e648..6f7710a379 100644 --- a/src/llamafactory/data/converter.py +++ b/src/llamafactory/data/converter.py @@ -75,6 +75,21 @@ def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> O return medias + def _extract_loss_mask(self, example: dict[str, Any], prompt: list[dict[str, str]], response: list[dict[str, str]]) -> Optional[list[int]]: + loss_mask = example.get("loss_mask") + if loss_mask is None: + return None + if not isinstance(loss_mask, list): + logger.warning_rank0_once("`loss_mask` should be a list. Ignore this field.") + return None + total_turns = len(prompt) + len(response) + if len(loss_mask) != total_turns: + logger.warning_rank0_once( + f"`loss_mask` length {len(loss_mask)} mismatches total turns {total_turns}. Ignore this field." + ) + return None + return loss_mask + @abstractmethod def __call__(self, example: dict[str, Any]) -> dict[str, Any]: r"""Convert a single example in the dataset to the standard format.""" @@ -127,6 +142,7 @@ def __call__(self, example: dict[str, Any]) -> dict[str, Any]: "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None, "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None, "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None, + "_loss_mask": self._extract_loss_mask(example, prompt, response), } return output @@ -223,6 +239,7 @@ def __call__(self, example: dict[str, Any]) -> dict[str, Any]: "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None, "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None, "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None, + "_loss_mask": self._extract_loss_mask(example, prompt, response), } return output @@ -363,6 +380,7 @@ def __call__(self, example: dict[str, Any]) -> dict[str, Any]: "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None, "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None, "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None, + "_loss_mask": self._extract_loss_mask(example, prompt, response), } return output diff --git a/src/llamafactory/data/processor/supervised.py b/src/llamafactory/data/processor/supervised.py index b5aba11b65..cd0721c637 100644 --- a/src/llamafactory/data/processor/supervised.py +++ b/src/llamafactory/data/processor/supervised.py @@ -18,6 +18,7 @@ from ...extras import logging from ...extras.constants import IGNORE_INDEX +from ..data_utils import Role from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen @@ -39,6 +40,7 @@ def _encode_data_example( images: list["ImageInput"], videos: list["VideoInput"], audios: list["AudioInput"], + loss_mask: Optional[list[int]] = None, ) -> tuple[list[int], list[int]]: messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor) input_ids, labels = self.template.mm_plugin.process_token_ids( @@ -49,6 +51,23 @@ def _encode_data_example( if self.data_args.mask_history: encoded_pairs = encoded_pairs[::-1] # high priority for last turns + assistant_loss_mask: Optional[list[int]] = None + if loss_mask is not None: + if len(loss_mask) != len(prompt) + len(response): + logger.warning_rank0_once( + f"Dropped invalid `loss_mask` with length {len(loss_mask)} for example." + ) + else: + assistant_loss_mask = [] + for mask_value, message in zip(loss_mask, prompt + response): + if message.get("role") == Role.ASSISTANT.value: + assistant_loss_mask.append(1 if mask_value else 0) + if len(assistant_loss_mask) != len(encoded_pairs): + logger.warning_rank0_once( + "Mismatch between assistant turns and `loss_mask`. Ignoring provided mask." + ) + assistant_loss_mask = None + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): if total_length >= self.data_args.cutoff_len: break @@ -72,6 +91,10 @@ def _encode_data_example( else: target_label = target_ids + if assistant_loss_mask is not None and turn_idx < len(assistant_loss_mask): + if assistant_loss_mask[turn_idx] == 0: + target_label = [IGNORE_INDEX] * target_len + if self.data_args.mask_history: # reversed sequences input_ids = source_ids + target_ids + input_ids labels = source_label + target_label + labels @@ -89,6 +112,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A # build inputs with format ` X Y ` and labels with format ` ... Y ` # for multiturn examples, we only mask the prompt part in each prompt-response pair. model_inputs = defaultdict(list) + loss_masks = examples.get("_loss_mask") for i in range(len(examples["_prompt"])): if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: logger.warning_rank0( @@ -96,6 +120,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A ) continue + example_loss_mask = loss_masks[i] if loss_masks else None input_ids, labels = self._encode_data_example( prompt=examples["_prompt"][i], response=examples["_response"][i], @@ -104,6 +129,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A images=examples["_images"][i] or [], videos=examples["_videos"][i] or [], audios=examples["_audios"][i] or [], + loss_mask=example_loss_mask, ) model_inputs["input_ids"].append(input_ids) model_inputs["attention_mask"].append([1] * len(input_ids)) @@ -132,6 +158,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], [] lengths = [] length2indexes = defaultdict(list) + loss_masks = examples.get("_loss_mask") for i in range(len(examples["_prompt"])): if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: logger.warning_rank0( @@ -139,6 +166,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A ) continue + example_loss_mask = loss_masks[i] if loss_masks else None input_ids, labels = self._encode_data_example( prompt=examples["_prompt"][i], response=examples["_response"][i], @@ -147,6 +175,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A images=examples["_images"][i] or [], videos=examples["_videos"][i] or [], audios=examples["_audios"][i] or [], + loss_mask=example_loss_mask, ) length = len(input_ids) if length > self.data_args.cutoff_len: From d98dc064a82e6d3f31770e2fc5fcada10c57cd9c Mon Sep 17 00:00:00 2001 From: CjangCjengh <101577701+CjangCjengh@users.noreply.github.com> Date: Thu, 18 Dec 2025 19:49:29 +0800 Subject: [PATCH 2/2] fix reverse bug --- src/llamafactory/data/processor/supervised.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/llamafactory/data/processor/supervised.py b/src/llamafactory/data/processor/supervised.py index cd0721c637..a8e537de59 100644 --- a/src/llamafactory/data/processor/supervised.py +++ b/src/llamafactory/data/processor/supervised.py @@ -68,6 +68,9 @@ def _encode_data_example( ) assistant_loss_mask = None + if self.data_args.mask_history and assistant_loss_mask is not None: + assistant_loss_mask = assistant_loss_mask[::-1] # align with reversed encoded_pairs + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): if total_length >= self.data_args.cutoff_len: break