18 changes: 18 additions & 0 deletions src/llamafactory/data/converter.py
@@ -75,6 +75,21 @@ def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> O

        return medias

    def _extract_loss_mask(self, example: dict[str, Any], prompt: list[dict[str, str]], response: list[dict[str, str]]) -> Optional[list[int]]:
        loss_mask = example.get("loss_mask")
        if loss_mask is None:
            return None
        if not isinstance(loss_mask, list):
            logger.warning_rank0_once("`loss_mask` should be a list. Ignoring this field.")
            return None
        total_turns = len(prompt) + len(response)
        if len(loss_mask) != total_turns:
            logger.warning_rank0_once(
                f"`loss_mask` length {len(loss_mask)} does not match the total number of turns ({total_turns}). Ignoring this field."
            )
            return None
        return loss_mask

    @abstractmethod
    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
        r"""Convert a single example in the dataset to the standard format."""
@@ -127,6 +142,7 @@ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
"_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
"_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
"_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
"_loss_mask": self._extract_loss_mask(example, prompt, response),
}
return output

@@ -223,6 +239,7 @@ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
"_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
"_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
"_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
"_loss_mask": self._extract_loss_mask(example, prompt, response),
}
return output

@@ -363,6 +380,7 @@ def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
"_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
"_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
"_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
"_loss_mask": self._extract_loss_mask(example, prompt, response),
}
return output

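For reference, a minimal sketch of a sharegpt-style record carrying the new field. Only the "loss_mask" key is what `_extract_loss_mask` above reads; the surrounding keys follow the usual sharegpt layout, and the concrete values here are illustrative, not part of this PR.

# hypothetical training record; `loss_mask` must have one entry per turn
# (len(prompt) + len(response) after conversion), otherwise it is dropped with a warning
example = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "4."},
        {"from": "human", "value": "And 3 + 3?"},
        {"from": "gpt", "value": "6."},
    ],
    # per the processor change below, only the values at assistant turns are used:
    # a 0 turns that reply's labels into IGNORE_INDEX, a 1 keeps it in the loss
    "loss_mask": [0, 0, 0, 1],
}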
32 changes: 32 additions & 0 deletions src/llamafactory/data/processor/supervised.py
@@ -18,6 +19,7 @@

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ..data_utils import Role
from .processor_utils import DatasetProcessor, greedy_knapsack, infer_seqlen


@@ -39,6 +40,7 @@ def _encode_data_example(
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
loss_mask: Optional[list[int]] = None,
) -> tuple[list[int], list[int]]:
messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
input_ids, labels = self.template.mm_plugin.process_token_ids(
@@ -49,6 +51,26 @@
        if self.data_args.mask_history:
            encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

        assistant_loss_mask: Optional[list[int]] = None
        if loss_mask is not None:
            if len(loss_mask) != len(prompt) + len(response):
                logger.warning_rank0_once(
                    f"Dropped invalid `loss_mask` with length {len(loss_mask)} for this example."
                )
            else:
                assistant_loss_mask = []
                for mask_value, message in zip(loss_mask, prompt + response):
                    if message.get("role") == Role.ASSISTANT.value:
                        assistant_loss_mask.append(1 if mask_value else 0)
                if len(assistant_loss_mask) != len(encoded_pairs):
                    logger.warning_rank0_once(
                        "Mismatch between assistant turns and `loss_mask`. Ignoring the provided mask."
                    )
                    assistant_loss_mask = None

        if self.data_args.mask_history and assistant_loss_mask is not None:
            assistant_loss_mask = assistant_loss_mask[::-1]  # align with reversed encoded_pairs

        for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
            if total_length >= self.data_args.cutoff_len:
                break
@@ -72,6 +94,10 @@ def _encode_data_example(
            else:
                target_label = target_ids

            if assistant_loss_mask is not None and turn_idx < len(assistant_loss_mask):
                if assistant_loss_mask[turn_idx] == 0:
                    target_label = [IGNORE_INDEX] * target_len

Review comment from @zengxingchen, Jan 14, 2026:


@hiyouga @CjangCjengh
Hi! I just read through this PR and noticed a potential issue when mask_history=True and loss_mask is also used. In that case, mask_history sets IGNORE_INDEX before loss_mask is applied, so the loss mask may not take effect as intended.


In my opinion, when loss_mask is used, mask_history shouldn’t be responsible for setting IGNORE_INDEX. The main thing we still need from mask_history is reversing the IDs to avoid truncating the last turn. That said, I’m not sure whether this default behavior should also apply to loss_mask, since loss_mask is a fairly flexible option with different possible use cases.
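
One way this could be reconciled (a sketch only, not something the PR implements; `choose_target_label` is a hypothetical helper): let a provided per-example mask decide the target labels, and fall back to the existing `mask_history` "train on the last turn only" rule when no mask is given, keeping `mask_history` for its reversal behavior.

from typing import Optional

IGNORE_INDEX = -100  # same sentinel LLaMA-Factory uses for ignored label positions

def choose_target_label(
    target_ids: list[int],
    turn_idx: int,
    mask_history: bool,
    assistant_loss_mask: Optional[list[int]],
) -> list[int]:
    # a provided per-example mask takes precedence over mask_history
    if assistant_loss_mask is not None and turn_idx < len(assistant_loss_mask):
        return target_ids if assistant_loss_mask[turn_idx] else [IGNORE_INDEX] * len(target_ids)
    # with mask_history, encoded_pairs are reversed, so turn 0 is the last turn
    if mask_history and turn_idx != 0:
        return [IGNORE_INDEX] * len(target_ids)
    return target_ids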


            if self.data_args.mask_history:  # reversed sequences
                input_ids = source_ids + target_ids + input_ids
                labels = source_label + target_label + labels
@@ -89,13 +115,15 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
        model_inputs = defaultdict(list)
        loss_masks = examples.get("_loss_mask")
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            example_loss_mask = loss_masks[i] if loss_masks else None
            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
@@ -104,6 +132,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
loss_mask=example_loss_mask,
)
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -132,13 +161,15 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A
        batch_input_ids, batch_labels, batch_images, batch_videos, batch_audios = [], [], [], [], []
        lengths = []
        length2indexes = defaultdict(list)
        loss_masks = examples.get("_loss_mask")
        for i in range(len(examples["_prompt"])):
            if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
                logger.warning_rank0(
                    "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
                )
                continue

            example_loss_mask = loss_masks[i] if loss_masks else None
            input_ids, labels = self._encode_data_example(
                prompt=examples["_prompt"][i],
                response=examples["_response"][i],
@@ -147,6 +178,7 @@ def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[A
images=examples["_images"][i] or [],
videos=examples["_videos"][i] or [],
audios=examples["_audios"][i] or [],
loss_mask=example_loss_mask,
)
length = len(input_ids)
if length > self.data_args.cutoff_len:
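
To summarize the intended effect, here is a toy illustration (plain Python, not LLaMA-Factory code, and ignoring options such as train_on_prompt, efficient_eos, and truncation) of how a per-turn mask maps onto the label sequence: prompt tokens are ignored as usual, and an assistant turn whose mask value is 0 has its reply tokens ignored as well.

IGNORE_INDEX = -100

def build_labels(turns: list[tuple[list[int], list[int]]], assistant_loss_mask: list[int]) -> list[int]:
    labels: list[int] = []
    for turn_idx, (source_ids, target_ids) in enumerate(turns):
        labels += [IGNORE_INDEX] * len(source_ids)  # never train on the prompt side here
        if assistant_loss_mask[turn_idx]:
            labels += target_ids  # this reply contributes to the loss
        else:
            labels += [IGNORE_INDEX] * len(target_ids)  # masked-out reply
    return labels

# two turns; only the second assistant reply contributes to the loss
print(build_labels([([1, 2], [3, 4]), ([5], [6, 7])], assistant_loss_mask=[0, 1]))
# -> [-100, -100, -100, -100, -100, 6, 7]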