diff --git a/src/llamafactory/data/converter.py b/src/llamafactory/data/converter.py
index ac3735e648..6f7710a379 100644
--- a/src/llamafactory/data/converter.py
+++ b/src/llamafactory/data/converter.py
@@ -75,6 +75,21 @@ def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> O
         return medias
 
+    def _extract_loss_mask(self, example: dict[str, Any], prompt: list[dict[str, str]], response: list[dict[str, str]]) -> Optional[list[int]]:
+        loss_mask = example.get("loss_mask")
+        if loss_mask is None:
+            return None
+        if not isinstance(loss_mask, list):
+            logger.warning_rank0_once("`loss_mask` should be a list. Ignoring this field.")
+            return None
+        total_turns = len(prompt) + len(response)
+        if len(loss_mask) != total_turns:
+            logger.warning_rank0_once(
+                f"`loss_mask` length {len(loss_mask)} does not match the total turns {total_turns}. Ignoring this field."
+            )
+            return None
+        return loss_mask
+
     @abstractmethod
     def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
         r"""Convert a single example in the dataset to the standard format."""
@@ -127,6 +142,7 @@ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
             "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
             "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
             "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
+            "_loss_mask": self._extract_loss_mask(example, prompt, response),
         }
         return output
 
@@ -223,6 +239,7 @@ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
             "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
             "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
             "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
+            "_loss_mask": self._extract_loss_mask(example, prompt, response),
         }
         return output
 
@@ -363,6 +380,7 @@ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
             "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
             "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
             "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
+            "_loss_mask": self._extract_loss_mask(example, prompt, response),
         }
         return output
 
diff --git a/src/llamafactory/data/processor/supervised.py b/src/llamafactory/data/processor/supervised.py
index b5aba11b65..a8e537de59 100644
--- a/src/llamafactory/data/processor/supervised.py
+++ b/src/llamafactory/data/processor/supervised.py
@@ -18,6 +18,7 @@
 
 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
+from ..data_utils import Role
 from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen
 
 
@@ -39,6 +40,7 @@ def _encode_data_example(
         images: list["ImageInput"],
         videos: list["VideoInput"],
         audios: list["AudioInput"],
+        loss_mask: Optional[list[int]] = None,
    ) -> tuple[list[int], list[int]]:
         messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
         input_ids, labels = self.template.mm_plugin.process_token_ids(
@@ -49,6 +51,26 @@
         if self.data_args.mask_history:
             encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
 
+        assistant_loss_mask: Optional[list[int]] = None
+        if loss_mask is not None:
+            if len(loss_mask) != len(prompt) + len(response):
+                logger.warning_rank0_once(
+                    f"Dropped invalid `loss_mask` with length {len(loss_mask)} for this example."
+                )
+            else:
+                assistant_loss_mask = []
+                for mask_value, message in zip(loss_mask, prompt + response):
+                    if message.get("role") == Role.ASSISTANT.value:
+                        assistant_loss_mask.append(1 if mask_value else 0)
+                if len(assistant_loss_mask) != len(encoded_pairs):
+                    logger.warning_rank0_once(
+                        "Mismatch between assistant turns and `loss_mask`. Ignoring the provided mask."
+                    )
+                    assistant_loss_mask = None
+
+        if self.data_args.mask_history and assistant_loss_mask is not None:
+            assistant_loss_mask = assistant_loss_mask[::-1]  # align with reversed encoded_pairs
+
         for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
             if total_length >= self.data_args.cutoff_len:
                 break
@@ -72,6 +94,10 @@
             else:
                 target_label = target_ids
 
+            if assistant_loss_mask is not None and turn_idx < len(assistant_loss_mask):
+                if assistant_loss_mask[turn_idx] == 0:
+                    target_label = [IGNORE_INDEX] * target_len
+
             if self.data_args.mask_history:  # reversed sequences
                 input_ids = source_ids + target_ids + input_ids
                 labels = source_label + target_label + labels
@@ -89,6 +115,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A
         # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
         # for multiturn examples, we only mask the prompt part in each prompt-response pair.
         model_inputs = defaultdict(list)
+        loss_masks = examples.get("_loss_mask")
         for i in range(len(examples["_prompt"])):
             if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                 logger.warning_rank0(
@@ -96,6 +123,7 @@
                 )
                 continue
 
+            example_loss_mask = loss_masks[i] if loss_masks else None
             input_ids, labels = self._encode_data_example(
                 prompt=examples["_prompt"][i],
                 response=examples["_response"][i],
@@ -104,6 +132,7 @@
                 images=examples["_images"][i] or [],
                 videos=examples["_videos"][i] or [],
                 audios=examples["_audios"][i] or [],
+                loss_mask=example_loss_mask,
             )
             model_inputs["input_ids"].append(input_ids)
             model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -132,6 +161,7 @@
         batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
         lengths = []
         length2indexes = defaultdict(list)
+        loss_masks = examples.get("_loss_mask")
         for i in range(len(examples["_prompt"])):
             if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                 logger.warning_rank0(
@@ -139,6 +169,7 @@
                 )
                 continue
 
+            example_loss_mask = loss_masks[i] if loss_masks else None
             input_ids, labels = self._encode_data_example(
                 prompt=examples["_prompt"][i],
                 response=examples["_response"][i],
@@ -147,6 +178,7 @@
                 images=examples["_images"][i] or [],
                 videos=examples["_videos"][i] or [],
                 audios=examples["_audios"][i] or [],
+                loss_mask=example_loss_mask,
             )
             length = len(input_ids)
             if length > self.data_args.cutoff_len:
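
Reviewer note, not part of the patch: the new `loss_mask` field carries one entry per conversation turn, in order, over `prompt + response`; only the entries that fall on assistant turns end up gating labels, one per (source_ids, target_ids) pair. Below is a minimal standalone sketch of that alignment, using a hypothetical sample record and a plain "assistant" string in place of Role.ASSISTANT.value:

    # Illustrative only: mirrors the assistant-mask extraction added to
    # _encode_data_example, without any llamafactory imports.
    ASSISTANT = "assistant"  # stands in for Role.ASSISTANT.value

    prompt = [
        {"role": "user", "content": "Hi."},
        {"role": ASSISTANT, "content": "Hello!"},  # history turn to exclude from loss
        {"role": "user", "content": "What is 2 + 2?"},
    ]
    response = [{"role": ASSISTANT, "content": "4"}]
    loss_mask = [1, 0, 1, 1]  # one entry per turn: len == len(prompt) + len(response)

    # Keep only the entries on assistant turns; entry i then gates the labels
    # of the i-th (source_ids, target_ids) pair produced by the template.
    assistant_loss_mask = [
        1 if value else 0
        for value, message in zip(loss_mask, prompt + response)
        if message.get("role") == ASSISTANT
    ]
    assert assistant_loss_mask == [0, 1]  # mask the first assistant turn, train on the last

With `mask_history` enabled, `encoded_pairs` is processed in reverse, so the patch also reverses `assistant_loss_mask` to keep the two aligned.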