use chat_template in speculative training #123
Conversation
Pull Request Overview
This PR removes the custom chat template system and simplifies the dataset processing by relying on the tokenizer's built-in apply_chat_template functionality. The change eliminates the need for maintaining custom chat template configurations and regex-based loss mask creation.
- Removed the entire `chat_templates.py` module and its associated CLI argument
- Refactored dataset processing to use the tokenizer's `apply_chat_template` on a per-message basis
- Simplified loss mask creation by applying masks based on message roles instead of using regex pattern matching on the formatted text
- Removed unused `TransformersKwargs` import
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tools/train_eagle3_online.py | Removed --chat_template_type CLI argument and its usage in dataset creation |
| angelslim/compressor/speculative/train/data/online_dataset.py | Refactored message processing to use per-message tokenization and role-based loss masking |
| angelslim/compressor/speculative/train/data/chat_templates.py | Deleted entire custom chat template module |
| angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen3_kv.py | Removed unused TransformersKwargs import from type hints |
Comments suppressed due to low confidence (1)
angelslim/compressor/speculative/train/data/online_dataset.py:1
- Corrected the print message by removing the extra space before 'datasets' and simplifying the message structure.
```python
# Copyright 2025 Tencent Inc. All Rights Reserved.
```
```python
for message in messages:
    message_tokens = self.tokenizer.apply_chat_template(
        [message],
        tokenize=True,
        add_generation_prompt=False,
        return_tensors="pt",
    ).squeeze(0)

    # Determine the loss mask based on the role
    if message["role"] in ["system", "user"]:
        mask = torch.zeros_like(message_tokens)
    else:  # assistant
        mask = torch.ones_like(message_tokens)
```
Copilot AI commented on Nov 3, 2025
The per-message tokenization approach may produce incorrect results. Calling apply_chat_template on individual messages wrapped in a list does not guarantee the same tokenization as applying it to the full conversation at once. Chat templates often have context-dependent formatting (e.g., different handling for first vs. subsequent messages, or special tokens between messages) that will be lost when processing messages independently. This could lead to incorrect token boundaries and potentially malformed chat formatting.
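To make this concern concrete, a minimal sketch along the following lines compares the full-conversation rendering against the concatenation of per-message renderings. The checkpoint name is only an illustrative assumption, not something taken from this PR.

```python
# Sketch: per-message apply_chat_template vs. full-conversation apply_chat_template.
# The checkpoint name is an example only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi there"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]

# Reference: tokenize the whole conversation in one call.
full_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=False
)

# New approach: tokenize each message independently and concatenate.
per_message_ids = []
for message in messages:
    per_message_ids += tokenizer.apply_chat_template(
        [message], tokenize=True, add_generation_prompt=False
    )

# If the template inserts BOS tokens, a default system prompt, or formats the
# first message differently, the two sequences will disagree.
print(len(full_ids), len(per_message_ids), full_ids == per_message_ids)
```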
```python
message_tokens = self.tokenizer.apply_chat_template(
    [message],
    tokenize=True,
    add_generation_prompt=False,
    return_tensors="pt",
).squeeze(0)

# Determine the loss mask based on the role
if message["role"] in ["system", "user"]:
    mask = torch.zeros_like(message_tokens)
else:  # assistant
    mask = torch.ones_like(message_tokens)
```
Copilot AI commented on Nov 3, 2025
The loss mask now includes special tokens and chat template formatting tokens for assistant messages, which is likely incorrect. The old implementation used offset mapping to identify only the actual response content, excluding the assistant header (e.g., <|im_start|>assistant\n). The new approach will apply loss to all tokens in the assistant message including template-specific tokens, which may degrade training quality by forcing the model to learn to predict formatting tokens rather than just content.
Suggested change:

```diff
-message_tokens = self.tokenizer.apply_chat_template(
-    [message],
-    tokenize=True,
-    add_generation_prompt=False,
-    return_tensors="pt",
-).squeeze(0)
-# Determine the loss mask based on the role
-if message["role"] in ["system", "user"]:
-    mask = torch.zeros_like(message_tokens)
-else:  # assistant
-    mask = torch.ones_like(message_tokens)
+# Get the formatted string for the message
+formatted_str = self.tokenizer.apply_chat_template(
+    [message],
+    tokenize=False,
+    add_generation_prompt=False,
+)
+# Tokenize with offsets
+encoding = self.tokenizer(
+    formatted_str,
+    return_tensors="pt",
+    return_offsets_mapping=True,
+)
+message_tokens = encoding["input_ids"].squeeze(0)
+offsets = encoding["offset_mapping"].squeeze(0)
+# Determine the loss mask based on the role
+if message["role"] in ["system", "user"]:
+    mask = torch.zeros_like(message_tokens)
+else:  # assistant
+    # Find the span of the assistant's response in the formatted string
+    # Heuristic: find the first occurrence of message["content"] in formatted_str
+    content = message.get("content", "")
+    if not content:
+        mask = torch.zeros_like(message_tokens)
+    else:
+        start_idx = formatted_str.find(content)
+        end_idx = start_idx + len(content)
+        # Set mask to 1 for tokens whose offsets are within the content span
+        mask = torch.tensor(
+            [
+                1 if (offset[0] >= start_idx and offset[1] <= end_idx and offset[0] < offset[1])
+                else 0
+                for offset in offsets.tolist()
+            ],
+            dtype=torch.long,
+        )
```
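If the suggestion is adopted, a sanity check along the following lines can confirm that loss is applied only to the response content. This sketch is not part of the PR; it assumes `tokenizer`, `message_tokens`, and `mask` come from the suggested code above for a single assistant message.

```python
# Sanity-check sketch for the suggested offset-based mask (not part of the PR).
kept = tokenizer.decode(message_tokens[mask.bool()])
dropped = tokenizer.decode(message_tokens[~mask.bool()])
print("loss applied to:", repr(kept))        # expected: only the assistant's content
print("loss excluded from:", repr(dropped))  # expected: template tokens such as the assistant header
```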
This pull request removes the custom chat template management system from the data pipeline and simplifies how chat templates are applied during dataset preparation. The changes streamline the code by relying on the tokenizer's built-in chat template handling, and update the data processing logic to compute the loss mask per message, rather than using string pattern matching. Additionally, related command-line arguments and imports for chat template selection are removed.
Data Pipeline Simplification:
- Removed the `chat_templates.py` module, including all classes and functions for managing chat templates.
- Updated `DatasetBuilder` and `DatasetManager` in `online_dataset.py` to eliminate dependencies on chat template types and external template management. Chat templates are now applied directly via the tokenizer, and loss masks are computed per message based on role (see the sketch after this list). [1] [2] [3] [4] [5]

Training Script Adjustments:
- Removed the `--chat_template_type` argument and all related logic from `train_eagle3_online.py`, as chat template selection is no longer necessary. [1] [2] [3]

Model Code Minor Cleanup:
- Removed the unused `TransformersKwargs` import and relaxed the type of `**kwargs` in the `forward` method of `modeling_qwen3_kv.py`. [1] [2]
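For readers following the new data path, here is a minimal sketch of the per-message assembly described above. The function and variable names are illustrative and do not come from `online_dataset.py`; the approach mirrors the per-message tokenization that the review comments flag as a potential correctness risk.

```python
# Minimal sketch of per-message tokenization with role-based loss masks,
# assembled into one training sequence. Names are illustrative only and are
# not taken from online_dataset.py.
import torch


def build_example(tokenizer, messages):
    token_chunks, mask_chunks = [], []
    for message in messages:
        tokens = tokenizer.apply_chat_template(
            [message],
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
        ).squeeze(0)
        # Loss is computed only on assistant tokens; system/user tokens are masked out.
        if message["role"] in ("system", "user"):
            mask = torch.zeros_like(tokens)
        else:
            mask = torch.ones_like(tokens)
        token_chunks.append(tokens)
        mask_chunks.append(mask)
    input_ids = torch.cat(token_chunks)
    loss_mask = torch.cat(mask_chunks)
    return input_ids, loss_mask
```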