From 9c480499a79c8f80bbc53f4959ed7e9d000eef21 Mon Sep 17 00:00:00 2001
From: Pankaj Dixit
Date: Thu, 1 Jan 2026 09:01:24 -0800
Subject: [PATCH 1/2] Fix #9306: add entropy logging for SFT training path

Signed-off-by: Pankaj Dixit
---
 src/llamafactory/hparams/finetuning_args.py |  4 +++
 src/llamafactory/train/sft/trainer.py       | 33 +++++++++++++++++--
 src/llamafactory/train/trainer_utils.py     | 35 +++++++++++++++++++++
 3 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 7ab2ce3bc8..27cb58099c 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -522,6 +522,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to compute effective tokens per second."},
     )
+    log_entropy: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to log entropy during training."},
+    )
 
     def __post_init__(self):
         def split_arg(arg):
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 0ee389b3c6..d5028b37c4 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -111,8 +111,29 @@ def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Samp
         return super()._get_train_sampler(*args, **kwargs)
 
     @override
-    def compute_loss(self, model, inputs, *args, **kwargs):
-        return super().compute_loss(model, inputs, *args, **kwargs)
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        # Always get outputs if we need entropy, otherwise follow the request
+        need_outputs = return_outputs or getattr(self.finetuning_args, "log_entropy", False)
+
+        if need_outputs:
+            loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
+        else:
+            loss = super().compute_loss(model, inputs, return_outputs=False, **kwargs)
+            outputs = None
+
+        # Compute entropy if enabled
+        if getattr(self.finetuning_args, "log_entropy", False) and outputs is not None:
+            if hasattr(outputs, "logits") and "labels" in inputs:
+                from ..trainer_utils import compute_entropy
+
+                with torch.no_grad():
+                    # Use the already-computed logits (detached to avoid affecting gradients)
+                    entropy = compute_entropy(outputs.logits.detach(), inputs["labels"])
+                    self._current_entropy = entropy.item()
+
+        if return_outputs:
+            return loss, outputs
+        return loss
 
     @override
     def prediction_step(
@@ -141,6 +162,14 @@ def prediction_step(
 
         return loss, generated_tokens, labels
 
+    @override
+    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
+        r"""Override to add entropy to logs if computed."""
+        if hasattr(self, "_current_entropy"):
+            logs["entropy"] = self._current_entropy
+            del self._current_entropy
+        return super().log(logs, *args, **kwargs)
+
     def save_predictions(
         self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
     ) -> None:
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index ec291e447a..778912de11 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -777,3 +777,38 @@ def get_ray_trainer(
         ),
     )
     return trainer
+
+
+def compute_entropy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
+    """Compute mean entropy over valid tokens.
+
+    Args:
+        logits: Model output logits of shape (batch_size, seq_len, vocab_size)
+        labels: Target labels of shape (batch_size, seq_len)
+        ignore_index: Index to ignore in labels (typically -100)
+
+    Returns:
+        Mean entropy over valid tokens
+    """
+    # Shift logits and labels for next-token prediction
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+
+    # Create mask for valid tokens
+    valid_mask = shift_labels != ignore_index
+
+    if not valid_mask.any():
+        return torch.tensor(0.0, device=logits.device)
+
+    # Compute probabilities via softmax
+    probs = torch.nn.functional.softmax(shift_logits, dim=-1)
+
+    # Compute entropy: -sum(p * log(p))
+    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
+    entropy = -torch.sum(probs * log_probs, dim=-1)  # (batch_size, seq_len-1)
+
+    # Average over valid tokens only
+    valid_entropy = entropy[valid_mask]
+    mean_entropy = valid_entropy.mean()
+
+    return mean_entropy

From 42da0ceb12cd2ea0e5f179094d656169a67be42a Mon Sep 17 00:00:00 2001
From: Pankaj Dixit
Date: Thu, 1 Jan 2026 09:01:24 -0800
Subject: [PATCH 2/2] Fix #9306: add LoRA SFT example script with entropy logging

Signed-off-by: Pankaj Dixit
---
 examples/train_lora/llama3.2_1B_lora_sft.sh | 42 +++++++++++++++++++++
 1 file changed, 42 insertions(+)
 create mode 100644 examples/train_lora/llama3.2_1B_lora_sft.sh

diff --git a/examples/train_lora/llama3.2_1B_lora_sft.sh b/examples/train_lora/llama3.2_1B_lora_sft.sh
new file mode 100644
index 0000000000..715fffd4d6
--- /dev/null
+++ b/examples/train_lora/llama3.2_1B_lora_sft.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+set -x
+OUTPUT="saves/llama3.2-1b/lora/sft"
+mkdir -p "$OUTPUT"
+echo "Logging to: $OUTPUT"
+
+MODEL_PATH=meta-llama/Llama-3.2-1B
+
+llamafactory-cli train \
+    --model_name_or_path ${MODEL_PATH} \
+    --trust_remote_code \
+    --stage sft \
+    --do_train \
+    --finetuning_type lora \
+    --lora_rank 16 \
+    --lora_target all \
+    --dataset identity,alpaca_en_demo \
+    --template llama3 \
+    --cutoff_len 2048 \
+    --max_samples 1000 \
+    --overwrite_cache \
+    --preprocessing_num_workers 16 \
+    --dataloader_num_workers 4 \
+    --output_dir ${OUTPUT} \
+    --logging_steps 10 \
+    --save_steps 500 \
+    --plot_loss \
+    --overwrite_output_dir \
+    --save_only_model false \
+    --report_to none \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 4 \
+    --learning_rate 2e-4 \
+    --num_train_epochs 3.0 \
+    --lr_scheduler_type cosine \
+    --warmup_ratio 0.1 \
+    --bf16 \
+    --log_entropy \
+    --ddp_timeout 180000000 > "$OUTPUT/train.log" 2>&1
+
+echo "Training completed. Logs are saved to: $OUTPUT/train.log"
\ No newline at end of file
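For reference, the quantity logged by the first patch is the mean Shannon entropy of the model's next-token distribution over the non-masked label positions. The standalone sketch below mirrors the compute_entropy helper (re-implemented here so it runs without installing the patched package; the token_entropy name and the toy tensor shapes are illustrative, not part of the patch) and checks the two boundary cases: uniform logits give log(vocab_size) nats, while a sharply peaked distribution gives an entropy near zero.

import math

import torch


def token_entropy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Same quantity as the patched compute_entropy: shift for next-token
    # prediction, drop ignored positions, average -sum(p * log p) per token.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    valid_mask = shift_labels != ignore_index
    log_probs = torch.log_softmax(shift_logits, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)
    return entropy[valid_mask].mean()


batch_size, seq_len, vocab_size = 2, 8, 32
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, :3] = -100  # mask the prompt portion, as the SFT data collator does

uniform_logits = torch.zeros(batch_size, seq_len, vocab_size)
print(token_entropy(uniform_logits, labels).item(), math.log(vocab_size))  # both ~3.4657

peaked_logits = torch.full((batch_size, seq_len, vocab_size), -20.0)
peaked_logits[..., 0] = 20.0
print(token_entropy(peaked_logits, labels).item())  # ~0.0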
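Once the example script has finished, the logged entropy can be read back next to the loss. Assuming the usual Hugging Face Trainer artifacts, values passed through the patched log() override land in the log_history list of trainer_state.json, written under --output_dir (and inside each checkpoint-*/ directory). The path below is an assumption that simply mirrors the OUTPUT variable from the script:

import json
from pathlib import Path

# Assumes the run above completed and wrote trainer_state.json into --output_dir.
state = json.loads(Path("saves/llama3.2-1b/lora/sft/trainer_state.json").read_text())
for entry in state["log_history"]:
    if "entropy" in entry:  # only present on steps where compute_loss stored a value
        loss = entry.get("loss", float("nan"))
        print(f"step {entry['step']:>5}: loss={loss:.4f}  entropy={entry['entropy']:.4f}")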