Add entropy logging for SFT training path #9717
base: main
Conversation
Signed-off-by: Pankaj Dixit <[email protected]>
Summary of Changes (Gemini Code Assist): This pull request enhances the SFT training path by integrating entropy logging, which allows deeper insight into model uncertainty and prediction distribution during training. A new configuration argument and a small change to the training loop let users enable logging of mean entropy, facilitating better monitoring and analysis of model behavior and convergence. A new utility function handles the entropy calculation, and an example script showcases its usage.
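The `compute_entropy` helper mentioned above appears only as a fragment later on this page. The following is a minimal, self-contained sketch of what such a helper could look like; the causal shift matches the shapes shown in the diff below, while the `-100` ignore-label mask and the mean reduction are assumptions rather than details taken from the PR.

```python
import torch


def compute_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Mean next-token prediction entropy over supervised positions (sketch)."""
    # Align each position's logits with the token it predicts (causal shift).
    shift_logits = logits[:, :-1, :]  # (batch, seq_len - 1, vocab)
    shift_labels = labels[:, 1:]      # (batch, seq_len - 1)

    # Per-position entropy: -sum(p * log p) over the vocabulary.
    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)

    # Average only over positions assumed to contribute to the loss (label != -100).
    mask = shift_labels != -100
    return (entropy * mask).sum() / mask.sum().clamp(min=1)
```

Since `compute_loss` in the diff below reduces the result with `.item()`, the helper is assumed to return a scalar tensor.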
Code Review
This pull request introduces a valuable feature for logging token entropy during SFT training, which is great for monitoring model uncertainty. The implementation is solid, adding the necessary command-line argument and integrating the entropy calculation into the training loop.
I've provided a few suggestions to enhance the code. In the new example script, I've recommended quoting shell variables for better robustness. For the Python code, I've suggested a refactoring in compute_loss to improve clarity and a more efficient and stable way to calculate entropy in compute_entropy using PyTorch's built-in distribution utilities.
Overall, these are great additions. My feedback aims to refine the implementation for better maintainability and performance.
MODEL_PATH=meta-llama/Llama-3.2-1B

llamafactory-cli train \
    --model_name_or_path ${MODEL_PATH} \
    --overwrite_cache \
    --preprocessing_num_workers 16 \
    --dataloader_num_workers 4 \
    --output_dir ${OUTPUT} \
    --log_entropy \
    --ddp_timeout 180000000 > "$OUTPUT/train.log" 2>&1

echo "Training completed. Logs are saved to: $OUTPUT/train.log"
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    # Always get outputs if we need entropy, otherwise follow the request
    need_outputs = return_outputs or getattr(self.finetuning_args, 'log_entropy', False)

    if need_outputs:
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
    else:
        loss = super().compute_loss(model, inputs, return_outputs=False, **kwargs)
        outputs = None

    # Compute entropy if enabled
    if getattr(self.finetuning_args, 'log_entropy', False) and outputs is not None:
        if hasattr(outputs, 'logits') and 'labels' in inputs:
            from ..trainer_utils import compute_entropy

            with torch.no_grad():
                # Use the already-computed logits (detached to avoid affecting gradients)
                entropy = compute_entropy(outputs.logits.detach(), inputs['labels'])
                self._current_entropy = entropy.item()

    if return_outputs:
        return loss, outputs
    return loss
This method can be slightly refactored for better readability and to avoid a repeated getattr call. Storing the log_entropy flag in a local variable at the beginning makes the intent clearer throughout the method.
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    log_entropy = getattr(self.finetuning_args, "log_entropy", False)
    # Always get outputs if we need entropy, otherwise follow the request
    need_outputs = return_outputs or log_entropy

    if need_outputs:
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
    else:
        loss = super().compute_loss(model, inputs, return_outputs=False, **kwargs)
        outputs = None

    # Compute entropy if enabled
    if log_entropy and outputs is not None:
        if hasattr(outputs, "logits") and "labels" in inputs:
            from ..trainer_utils import compute_entropy

            with torch.no_grad():
                # Use the already-computed logits (detached to avoid affecting gradients)
                entropy = compute_entropy(outputs.logits.detach(), inputs["labels"])
                self._current_entropy = entropy.item()

    if return_outputs:
        return loss, outputs
    return loss

# Compute probabilities via softmax
probs = torch.nn.functional.softmax(shift_logits, dim=-1)

# Compute entropy: -sum(p * log(p))
log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
entropy = -torch.sum(probs * log_probs, dim=-1)  # (batch_size, seq_len-1)
Instead of manually calculating softmax and then the entropy formula, you can use torch.distributions.Categorical for a more concise, efficient, and numerically stable implementation. The entropy() method directly computes what you need.
Suggested change:

- # Compute probabilities via softmax
- probs = torch.nn.functional.softmax(shift_logits, dim=-1)
- # Compute entropy: -sum(p * log(p))
- log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
- entropy = -torch.sum(probs * log_probs, dim=-1)  # (batch_size, seq_len-1)
+ # Compute entropy using torch.distributions
+ dist = torch.distributions.Categorical(logits=shift_logits)
+ entropy = dist.entropy()  # (batch_size, seq_len-1)
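If reviewers want to sanity-check that the two formulations agree numerically, a quick comparison on random logits (made-up shapes) should print `True`:

```python
import torch

torch.manual_seed(0)
shift_logits = torch.randn(2, 7, 32)  # made-up (batch, seq_len - 1, vocab) sizes

# Manual formulation from the PR.
probs = torch.nn.functional.softmax(shift_logits, dim=-1)
log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
entropy_manual = -torch.sum(probs * log_probs, dim=-1)

# Suggested formulation via torch.distributions.
entropy_dist = torch.distributions.Categorical(logits=shift_logits).entropy()

print(torch.allclose(entropy_manual, entropy_dist, atol=1e-6))
```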
This PR introduces entropy logging in the SFT training path, enabling better monitoring and analysis of model behavior during training. A sample test with entropy logging enabled has been added.
Fixes #9306
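The diff above shows where `self._current_entropy` is populated but not how it reaches the training logs. One plausible way to surface it, sketched here as an assumption rather than the PR's actual implementation, is to extend the trainer's `log` method:

```python
from transformers import Trainer


class EntropyLoggingTrainer(Trainer):  # hypothetical subclass name, for illustration only
    def log(self, logs, *args, **kwargs):
        # Report the most recent per-batch entropy alongside loss, learning rate, etc.
        entropy = getattr(self, "_current_entropy", None)
        if entropy is not None:
            logs["entropy"] = round(entropy, 4)
        super().log(logs, *args, **kwargs)
```

With something like this in place, the mean entropy would appear next to the loss at each logging step, which matches the monitoring goal described in the summary.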