42 changes: 42 additions & 0 deletions examples/train_lora/llama3.2_1B_lora_sft.sh
@@ -0,0 +1,42 @@
#!/bin/bash

set -x
OUTPUT="saves/llama3.2-1b/lora/sft"
mkdir -p "$OUTPUT"
echo "Logging to: $OUTPUT"

MODEL_PATH=meta-llama/Llama-3.2-1B

llamafactory-cli train \
    --model_name_or_path ${MODEL_PATH} \
    --trust_remote_code \
    --stage sft \
    --do_train \
    --finetuning_type lora \
    --lora_rank 16 \
    --lora_target all \
    --dataset identity,alpaca_en_demo \
    --template llama3 \
    --cutoff_len 2048 \
    --max_samples 1000 \
    --overwrite_cache \
    --preprocessing_num_workers 16 \
    --dataloader_num_workers 4 \
    --output_dir ${OUTPUT} \
    --logging_steps 10 \
    --save_steps 500 \
    --plot_loss \
    --overwrite_output_dir \
    --save_only_model false \
    --report_to none \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --learning_rate 2e-4 \
    --num_train_epochs 3.0 \
    --lr_scheduler_type cosine \
    --warmup_ratio 0.1 \
    --bf16 \
    --log_entropy \
    --ddp_timeout 180000000 > "$OUTPUT/train.log" 2>&1

echo "Training completed. Logs are saved to: $OUTPUT/train.log"
4 changes: 4 additions & 0 deletions src/llamafactory/hparams/finetuning_args.py
@@ -522,6 +522,10 @@ class FinetuningArguments(
        default=False,
        metadata={"help": "Whether or not to compute effective tokens per second."},
    )
    log_entropy: bool = field(
        default=False,
        metadata={"help": "Whether or not to log entropy during training."},
    )

    def __post_init__(self):
        def split_arg(arg):
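
A quick way to confirm the new option is wired into argument parsing (a minimal sketch, assuming an editable install of this branch; `HfArgumentParser.parse_dict` is standard `transformers` API):

```python
from transformers import HfArgumentParser

from llamafactory.hparams.finetuning_args import FinetuningArguments

# The new field behaves like any other finetuning option: off by default,
# enabled via `--log_entropy` on the CLI (as in the example script above).
parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_dict({"log_entropy": True})
print(finetuning_args.log_entropy)  # True
```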
33 changes: 31 additions & 2 deletions src/llamafactory/train/sft/trainer.py
@@ -111,8 +111,29 @@ def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Samp
        return super()._get_train_sampler(*args, **kwargs)

    @override
    def compute_loss(self, model, inputs, *args, **kwargs):
        return super().compute_loss(model, inputs, *args, **kwargs)
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Always get outputs if we need entropy, otherwise follow the request
        need_outputs = return_outputs or getattr(self.finetuning_args, 'log_entropy', False)

        if need_outputs:
            loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
        else:
            loss = super().compute_loss(model, inputs, return_outputs=False, **kwargs)
            outputs = None

        # Compute entropy if enabled
        if getattr(self.finetuning_args, 'log_entropy', False) and outputs is not None:
            if hasattr(outputs, 'logits') and 'labels' in inputs:
                from ..trainer_utils import compute_entropy

                with torch.no_grad():
                    # Use the already-computed logits (detached to avoid affecting gradients)
                    entropy = compute_entropy(outputs.logits.detach(), inputs['labels'])
                self._current_entropy = entropy.item()

        if return_outputs:
            return loss, outputs
        return loss

    @override
    def prediction_step(
@@ -141,6 +162,14 @@ def prediction_step(

        return loss, generated_tokens, labels

    @override
    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
        r"""Override to add entropy to logs if computed."""
        if hasattr(self, '_current_entropy'):
            logs['entropy'] = self._current_entropy
            del self._current_entropy
        return super().log(logs, *args, **kwargs)

    def save_predictions(
        self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True
    ) -> None:
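
Because the value is injected through `Trainer.log`, it lands in the trainer's `log_history` alongside the loss. A small sketch for pulling it back out after a run (assumes the output directory from the example script above):

```python
import json
from pathlib import Path

# trainer_state.json is written by the HF Trainer; log_history holds every logged dict.
state = json.loads(Path("saves/llama3.2-1b/lora/sft/trainer_state.json").read_text())

for entry in state["log_history"]:
    if "entropy" in entry:  # only training-log entries carry the new key
        loss = entry.get("loss", float("nan"))
        print(f"step {entry['step']:>5}  loss {loss:.4f}  entropy {entry['entropy']:.4f}")
```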
35 changes: 35 additions & 0 deletions src/llamafactory/train/trainer_utils.py
@@ -777,3 +777,38 @@ def get_ray_trainer(
        ),
    )
    return trainer


def compute_entropy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Compute mean entropy over valid tokens.

    Args:
        logits: Model output logits of shape (batch_size, seq_len, vocab_size)
        labels: Target labels of shape (batch_size, seq_len)
        ignore_index: Index to ignore in labels (typically -100)

    Returns:
        Mean entropy over valid tokens
    """
    # Shift logits and labels for next-token prediction
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Create mask for valid tokens
    valid_mask = shift_labels != ignore_index

    if not valid_mask.any():
        return torch.tensor(0.0, device=logits.device)

    # Compute probabilities via softmax
    probs = torch.nn.functional.softmax(shift_logits, dim=-1)

    # Compute entropy: -sum(p * log(p))
    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
    entropy = -torch.sum(probs * log_probs, dim=-1)  # (batch_size, seq_len-1)

    # Average over valid tokens only
    valid_entropy = entropy[valid_mask]
    mean_entropy = valid_entropy.mean()

    return mean_entropy
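
A quick sanity check for `compute_entropy` (a sketch, assuming the package is importable, e.g. via an editable install): with uniform logits the per-token entropy should equal ln(vocab_size), and `ignore_index` positions must not contribute to the mean.

```python
import math

import torch

from llamafactory.train.trainer_utils import compute_entropy

batch_size, seq_len, vocab_size = 2, 8, 32
logits = torch.zeros(batch_size, seq_len, vocab_size)  # uniform distribution at every position
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, :4] = -100  # mask prompt tokens, as the data collator would

entropy = compute_entropy(logits, labels)
assert torch.isclose(entropy, torch.tensor(math.log(vocab_size)), atol=1e-4)
print(f"mean entropy: {entropy.item():.4f}  (ln(V) = {math.log(vocab_size):.4f})")
```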