@@ -31,7 +31,6 @@
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.processing_utils import Unpack
from transformers.utils import (
TransformersKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -958,7 +957,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
**kwargs,
) -> CausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): # noqa: E501
147 changes: 0 additions & 147 deletions angelslim/compressor/speculative/train/data/chat_templates.py

This file was deleted.

101 changes: 25 additions & 76 deletions angelslim/compressor/speculative/train/data/online_dataset.py
@@ -12,38 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple

import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoTokenizer

from .chat_templates import (
ChatTemplateType,
string_to_chat_template_type,
template_manager,
)


class DatasetBuilder:
def __init__(
self,
tokenizer: AutoTokenizer,
max_length: int = 2048,
shuffle_seed: int = 42,
chat_template_type: ChatTemplateType = ChatTemplateType.QWEN3,
):
self.tokenizer = tokenizer
self.max_length = max_length
self.shuffle_seed = shuffle_seed
self.chat_template_type = chat_template_type

# Get chat template
template = template_manager.get_template_dict(chat_template_type)
self.user_header = template["user_header"]
self.assistant_header = template["assistant_header"]

def build_dataset(self, datapath: str, num_proc: int = 8) -> Dataset:
try:
@@ -108,28 +94,33 @@ def _process_single_conversation(
if not messages:
return None

# Apply chat template
conversation = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
)
input_ids_list = []
loss_mask_list = []

# Tokenize conversation
encoding = self.tokenizer(
conversation,
return_offsets_mapping=True,
max_length=self.max_length,
truncation=True,
padding=False,
)
for message in messages:
message_tokens = self.tokenizer.apply_chat_template(
[message],
tokenize=True,
add_generation_prompt=False,
return_tensors="pt",
).squeeze(0)

# Determine the loss mask based on the role
if message["role"] in ["system", "user"]:
mask = torch.zeros_like(message_tokens)
else: # assistant
mask = torch.ones_like(message_tokens)
Comment on lines +100 to +112

Copilot AI Nov 3, 2025

The per-message tokenization approach may produce incorrect results. Calling apply_chat_template on individual messages wrapped in a list does not guarantee the same tokenization as applying it to the full conversation at once. Chat templates often have context-dependent formatting (e.g., different handling for first vs. subsequent messages, or special tokens between messages) that will be lost when processing messages independently. This could lead to incorrect token boundaries and potentially malformed chat formatting.

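To illustrate the concern, here is a minimal sketch (not part of the PR; the checkpoint name and messages are placeholders) comparing whole-conversation tokenization against per-message tokenization. For most chat templates the two token sequences differ:

# Illustrative sketch only: assumes an HF tokenizer that ships a chat template;
# the checkpoint name below is just an example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]

# Tokens from rendering the full conversation once.
full_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=False
)

# Tokens from rendering each message independently and concatenating.
per_message_ids = []
for m in messages:
    per_message_ids += tokenizer.apply_chat_template(
        [m], tokenize=True, add_generation_prompt=False
    )

# Often False: a template may insert a default system block, BOS tokens,
# or turn separators for each standalone rendering, so concatenated
# per-message tokens need not match the full-conversation tokens.
print(full_ids == per_message_ids)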
Comment on lines +101 to +112

Copilot AI Nov 3, 2025

The loss mask now includes special tokens and chat template formatting tokens for assistant messages, which is likely incorrect. The old implementation used offset mapping to identify only the actual response content, excluding the assistant header (e.g., <|im_start|>assistant\n). The new approach will apply loss to all tokens in the assistant message including template-specific tokens, which may degrade training quality by forcing the model to learn to predict formatting tokens rather than just content.

Suggested change (the current lines shown first, then the proposed replacement):
message_tokens = self.tokenizer.apply_chat_template(
[message],
tokenize=True,
add_generation_prompt=False,
return_tensors="pt",
).squeeze(0)
# Determine the loss mask based on the role
if message["role"] in ["system", "user"]:
mask = torch.zeros_like(message_tokens)
else: # assistant
mask = torch.ones_like(message_tokens)
# Get the formatted string for the message
formatted_str = self.tokenizer.apply_chat_template(
[message],
tokenize=False,
add_generation_prompt=False,
)
# Tokenize with offsets
encoding = self.tokenizer(
formatted_str,
return_tensors="pt",
return_offsets_mapping=True,
)
message_tokens = encoding["input_ids"].squeeze(0)
offsets = encoding["offset_mapping"].squeeze(0)
# Determine the loss mask based on the role
if message["role"] in ["system", "user"]:
mask = torch.zeros_like(message_tokens)
else: # assistant
# Find the span of the assistant's response in the formatted string
# Heuristic: find the first occurrence of message["content"] in formatted_str
content = message.get("content", "")
if not content:
mask = torch.zeros_like(message_tokens)
else:
start_idx = formatted_str.find(content)
end_idx = start_idx + len(content)
# Set mask to 1 for tokens whose offsets are within the content span
mask = torch.tensor(
[
1 if (offset[0] >= start_idx and offset[1] <= end_idx and offset[0] < offset[1])
else 0
for offset in offsets.tolist()
],
dtype=torch.long,
)

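As a quick sanity check for either masking variant (an illustrative sketch, not part of the suggestion; `input_ids` and `loss_mask` are the tensors built in `_process_single_conversation`, and `tokenizer` stands for `self.tokenizer`), decode only the positions where the mask is 1 and confirm the result is the assistant reply text rather than template markers:

# Sketch: decode only the tokens that contribute to the loss.
kept = [t for t, m in zip(input_ids.tolist(), loss_mask.tolist()) if m == 1]
print(tokenizer.decode(kept))
# With content-only masking this prints just the assistant replies; if
# markers such as "<|im_start|>assistant" appear, formatting tokens are
# still being included in the loss.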

input_ids_list.append(message_tokens)
loss_mask_list.append(mask)

input_ids = torch.cat(input_ids_list, dim=0)
loss_mask = torch.cat(loss_mask_list, dim=0)

input_ids = encoding.input_ids
offsets = encoding.offset_mapping
if len(input_ids) > self.max_length:
input_ids = input_ids[: self.max_length]
loss_mask = loss_mask[: self.max_length]

# Create loss mask for assistant responses
loss_mask = self._create_loss_mask_from_offsets(conversation, offsets)
input_ids = torch.tensor(input_ids)
attention_mask = torch.ones_like(input_ids)

return {
@@ -143,34 +134,6 @@ def _process_single_conversation(
print(f"Error processing conversation: {e}")
return None

# Copied from https://github.com/NickL77/BaldEagle/blob/master/generate_data/generate_data.py # noqa: E501
def _create_loss_mask_from_offsets(
self, conversation: str, offsets: torch.Tensor
) -> torch.Tensor:
loss_mask = torch.zeros(len(offsets), dtype=torch.long)

# Find all assistant response spans
assistant_pattern = (
re.escape(self.assistant_header)
+ r"(.*?)(?="
+ re.escape(self.user_header)
+ "|$)"
)

for match in re.finditer(assistant_pattern, conversation, re.DOTALL):
# Get the actual response content (excluding header)
response_start = match.start(1)
response_end = match.end(1)

# Mark tokens that overlap with assistant response
for idx, (token_start, token_end) in enumerate(offsets):

# Check if token overlaps with assistant response span
if not (token_end <= response_start or token_start > response_end):
loss_mask[idx] = 1

return loss_mask

def _build_messages(self, source: List[Dict]) -> List[Dict]:
# System message
messages = [{"role": "system", "content": self._get_system_prompt()}]
@@ -267,7 +230,6 @@ def __init__(
data_args,
tokenizer: AutoTokenizer,
model_max_length: int = 2048,
chat_template_type: Optional[Union[str, ChatTemplateType]] = None,
):
"""
Initialize DatasetManager with DataArguments.
@@ -276,29 +238,16 @@ def __init__(
data_args: DataArguments object from train_eagle3_online.py
tokenizer: Tokenizer for the model
model_max_length: Maximum sequence length
chat_template_type: Chat template type. Can be:
- ChatTemplateType enum value (e.g., ChatTemplateType.QWEN3)
- String (e.g., "llama", "qwen")
- None (will default to LLAMA)
"""
self.data_args = data_args
self.tokenizer = tokenizer
self.model_max_length = model_max_length

# Convert chat_template_type to ChatTemplateType enum
if chat_template_type is None:
# Default to QWEN3
chat_template_type = ChatTemplateType.QWEN3
elif isinstance(chat_template_type, str):
# Convert string to enum
chat_template_type = string_to_chat_template_type(chat_template_type)

# Create dataset builder
self.dataset_builder = DatasetBuilder(
tokenizer=tokenizer,
max_length=model_max_length,
shuffle_seed=data_args.shuffle_seed,
chat_template_type=chat_template_type,
)

def create_datasets(self) -> Tuple[Dataset, Optional[Dataset]]:
18 changes: 1 addition & 17 deletions tools/train_eagle3_online.py
@@ -10,9 +10,6 @@
DataCollatorWithPadding,
DatasetManager,
)
from angelslim.compressor.speculative.train.data.chat_templates import (
get_supported_chat_template_type_strings,
)
from angelslim.compressor.speculative.train.models.draft import (
DraftModelConfig,
create_draft_model,
@@ -96,15 +93,6 @@ def parse_args():
default=None,
help="Path to evaluation data file (JSON format)",
)
data_group.add_argument(
"--chat_template_type",
type=str,
default="llama",
help=(
f"Chat template type for conversation formatting. "
f"Supported types: {', '.join(get_supported_chat_template_type_strings())}"
),
)
data_group.add_argument(
"--num_proc",
type=int,
@@ -298,15 +286,11 @@ def train_eagle3_online():
rank0_print("Draft model loaded successfully")

# Create datasets using DatasetManager
rank0_print(
"Creating training and evaluation datasets "
f"with chat template type: {args.chat_template_type}..."
)
rank0_print("Creating training and evaluation datasets")
dataset_manager = DatasetManager(
data_args=args,
tokenizer=target_model.tokenizer,
model_max_length=args.model_max_length,
chat_template_type=args.chat_template_type,
)
train_dataset, eval_dataset = dataset_manager.create_datasets()
rank0_print(