diff --git a/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen3_kv.py b/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen3_kv.py
index b00f15b..af7c7fd 100644
--- a/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen3_kv.py
+++ b/angelslim/compressor/speculative/inference/models/eagle3/target/modeling_qwen3_kv.py
@@ -31,7 +31,6 @@
 from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
 from transformers.processing_utils import Unpack
 from transformers.utils import (
-    TransformersKwargs,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -958,7 +957,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs: Unpack[TransformersKwargs],
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):  # noqa: E501
diff --git a/angelslim/compressor/speculative/train/data/chat_templates.py b/angelslim/compressor/speculative/train/data/chat_templates.py
deleted file mode 100644
index bd84849..0000000
--- a/angelslim/compressor/speculative/train/data/chat_templates.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Copyright 2025 Tencent Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from enum import Enum
-from typing import Dict
-
-
-class ChatTemplateType(Enum):
-    """Supported chat template types."""
-
-    QWEN3 = "qwen3"
-
-
-# String to ChatTemplateType mapping
-CHAT_TEMPLATE_TYPE_MAPPING = {
-    "qwen3": ChatTemplateType.QWEN3,
-}
-
-
-class ChatTemplate:
-    """Chat template configuration for a specific model type."""
-
-    def __init__(self, user_header: str, assistant_header: str):
-        self.user_header = user_header
-        self.assistant_header = assistant_header
-
-    def to_dict(self) -> Dict[str, str]:
-        """Convert template to dictionary format."""
-        return {
-            "user_header": self.user_header,
-            "assistant_header": self.assistant_header,
-        }
-
-
-class ChatTemplateManager:
-    """Manager for chat templates of different model types."""
-
-    def __init__(self):
-        self._templates = self._initialize_templates()
-
-    def _initialize_templates(self) -> Dict[ChatTemplateType, ChatTemplate]:
-        """Initialize predefined chat templates."""
-        return {
-            ChatTemplateType.QWEN3: ChatTemplate(
-                user_header="<|im_start|>user\n",
-                assistant_header="<|im_start|>assistant\n",
-            )
-        }
-
-    def get_template(self, chat_template_type: ChatTemplateType) -> ChatTemplate:
-        """
-        Get chat template for specified chat template type.
-
-        Args:
-            chat_template_type: The chat template type to get template for
-
-        Returns:
-            ChatTemplate instance
-
-        Raises:
-            ValueError: If chat template type is not supported
-        """
-        if chat_template_type not in self._templates:
-            raise ValueError(f"Unsupported chat template type: {chat_template_type}")
-
-        return self._templates[chat_template_type]
-
-    def get_template_dict(self, chat_template_type: ChatTemplateType) -> Dict[str, str]:
-        """
-        Get chat template as dictionary for specified chat template type.
-
-        Args:
-            chat_template_type: The chat template type to get template for
-
-        Returns:
-            Dictionary containing template configuration
-        """
-        template = self.get_template(chat_template_type)
-        return template.to_dict()
-
-    def list_supported_types(self) -> list[str]:
-        """
-        List all supported chat template types.
-
-        Returns:
-            List of supported chat template type names
-        """
-        return [template_type.value for template_type in self._templates.keys()]
-
-
-# Global template manager instance
-template_manager = ChatTemplateManager()
-
-
-# Convenience functions for backward compatibility
-def get_template(chat_template_type: ChatTemplateType) -> Dict[str, str]:
-    """Get chat template dictionary for specified chat template type."""
-    return template_manager.get_template_dict(chat_template_type)
-
-
-def list_supported_chat_template_types() -> list[str]:
-    """List all supported chat template types."""
-    return template_manager.list_supported_types()
-
-
-def string_to_chat_template_type(template_type_str: str) -> ChatTemplateType:
-    """
-    Convert string to ChatTemplateType enum.
-
-    Args:
-        template_type_str: String representation of chat template type
-
-    Returns:
-        ChatTemplateType enum
-
-    Raises:
-        ValueError: If chat template type string is not supported
-    """
-    if template_type_str not in CHAT_TEMPLATE_TYPE_MAPPING:
-        supported_types = list(CHAT_TEMPLATE_TYPE_MAPPING.keys())
-        raise ValueError(
-            f"Unsupported chat template type: {template_type_str}. "
-            f"Supported types: {supported_types}"
-        )
-
-    return CHAT_TEMPLATE_TYPE_MAPPING[template_type_str]
-
-
-def get_supported_chat_template_type_strings() -> list[str]:
-    """
-    Get list of supported chat template type strings for command line arguments.
-
-    Returns:
-        List of supported chat template type strings
-    """
-    return list(CHAT_TEMPLATE_TYPE_MAPPING.keys())
diff --git a/angelslim/compressor/speculative/train/data/online_dataset.py b/angelslim/compressor/speculative/train/data/online_dataset.py
index 31d1228..febfd4b 100644
--- a/angelslim/compressor/speculative/train/data/online_dataset.py
+++ b/angelslim/compressor/speculative/train/data/online_dataset.py
@@ -12,20 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from datasets import load_dataset
 from torch.utils.data import Dataset
 from transformers import AutoTokenizer
 
-from .chat_templates import (
-    ChatTemplateType,
-    string_to_chat_template_type,
-    template_manager,
-)
-
 
 class DatasetBuilder:
     def __init__(
@@ -33,17 +26,10 @@ def __init__(
         tokenizer: AutoTokenizer,
         max_length: int = 2048,
         shuffle_seed: int = 42,
-        chat_template_type: ChatTemplateType = ChatTemplateType.QWEN3,
     ):
         self.tokenizer = tokenizer
         self.max_length = max_length
         self.shuffle_seed = shuffle_seed
-        self.chat_template_type = chat_template_type
-
-        # Get chat template
-        template = template_manager.get_template_dict(chat_template_type)
-        self.user_header = template["user_header"]
-        self.assistant_header = template["assistant_header"]
 
     def build_dataset(self, datapath: str, num_proc: int = 8) -> Dataset:
         try:
@@ -108,28 +94,33 @@ def _process_single_conversation(
         if not messages:
             return None
 
-        # Apply chat template
-        conversation = self.tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=False,
-        )
+        input_ids_list = []
+        loss_mask_list = []
 
-        # Tokenize conversation
-        encoding = self.tokenizer(
-            conversation,
-            return_offsets_mapping=True,
-            max_length=self.max_length,
-            truncation=True,
-            padding=False,
-        )
+        for message in messages:
+            message_tokens = self.tokenizer.apply_chat_template(
+                [message],
+                tokenize=True,
+                add_generation_prompt=False,
+                return_tensors="pt",
+            ).squeeze(0)
+
+            # Determine the loss mask based on the role
+            if message["role"] in ["system", "user"]:
+                mask = torch.zeros_like(message_tokens)
+            else:  # assistant
+                mask = torch.ones_like(message_tokens)
+
+            input_ids_list.append(message_tokens)
+            loss_mask_list.append(mask)
+
+        input_ids = torch.cat(input_ids_list, dim=0)
+        loss_mask = torch.cat(loss_mask_list, dim=0)
 
-        input_ids = encoding.input_ids
-        offsets = encoding.offset_mapping
+        if len(input_ids) > self.max_length:
+            input_ids = input_ids[: self.max_length]
+            loss_mask = loss_mask[: self.max_length]
 
-        # Create loss mask for assistant responses
-        loss_mask = self._create_loss_mask_from_offsets(conversation, offsets)
-        input_ids = torch.tensor(input_ids)
         attention_mask = torch.ones_like(input_ids)
 
         return {
@@ -143,34 +134,6 @@ def _process_single_conversation(
             print(f"Error processing conversation: {e}")
             return None
 
-    # Copied from https://github.com/NickL77/BaldEagle/blob/master/generate_data/generate_data.py  # noqa: E501
-    def _create_loss_mask_from_offsets(
-        self, conversation: str, offsets: torch.Tensor
-    ) -> torch.Tensor:
-        loss_mask = torch.zeros(len(offsets), dtype=torch.long)
-
-        # Find all assistant response spans
-        assistant_pattern = (
-            re.escape(self.assistant_header)
-            + r"(.*?)(?="
-            + re.escape(self.user_header)
-            + "|$)"
-        )
-
-        for match in re.finditer(assistant_pattern, conversation, re.DOTALL):
-            # Get the actual response content (excluding header)
-            response_start = match.start(1)
-            response_end = match.end(1)
-
-            # Mark tokens that overlap with assistant response
-            for idx, (token_start, token_end) in enumerate(offsets):
-
-                # Check if token overlaps with assistant response span
-                if not (token_end <= response_start or token_start > response_end):
-                    loss_mask[idx] = 1
-
-        return loss_mask
-
     def _build_messages(self, source: List[Dict]) -> List[Dict]:
         # System message
         messages = [{"role": "system", "content": self._get_system_prompt()}]
@@ -267,7 +230,6 @@ def __init__(
         data_args,
         tokenizer: AutoTokenizer,
         model_max_length: int = 2048,
-        chat_template_type: Optional[Union[str, ChatTemplateType]] = None,
     ):
         """
         Initialize DatasetManager with DataArguments.
@@ -276,29 +238,16 @@ def __init__(
             data_args: DataArguments object from train_eagle3_online.py
            tokenizer: Tokenizer for the model
            model_max_length: Maximum sequence length
-            chat_template_type: Chat template type. Can be:
-                - ChatTemplateType enum value (e.g., ChatTemplateType.QWEN3)
-                - String (e.g., "llama", "qwen")
-                - None (will default to LLAMA)
         """
         self.data_args = data_args
         self.tokenizer = tokenizer
         self.model_max_length = model_max_length
 
-        # Convert chat_template_type to ChatTemplateType enum
-        if chat_template_type is None:
-            # Default to QWEN3
-            chat_template_type = ChatTemplateType.QWEN3
-        elif isinstance(chat_template_type, str):
-            # Convert string to enum
-            chat_template_type = string_to_chat_template_type(chat_template_type)
-
         # Create dataset builder
         self.dataset_builder = DatasetBuilder(
             tokenizer=tokenizer,
             max_length=model_max_length,
             shuffle_seed=data_args.shuffle_seed,
-            chat_template_type=chat_template_type,
         )
 
     def create_datasets(self) -> Tuple[Dataset, Optional[Dataset]]:
diff --git a/tools/train_eagle3_online.py b/tools/train_eagle3_online.py
index 59e7f25..a205d20 100644
--- a/tools/train_eagle3_online.py
+++ b/tools/train_eagle3_online.py
@@ -10,9 +10,6 @@
     DataCollatorWithPadding,
     DatasetManager,
 )
-from angelslim.compressor.speculative.train.data.chat_templates import (
-    get_supported_chat_template_type_strings,
-)
 from angelslim.compressor.speculative.train.models.draft import (
     DraftModelConfig,
     create_draft_model,
@@ -96,15 +93,6 @@ def parse_args():
         default=None,
         help="Path to evaluation data file (JSON format)",
     )
-    data_group.add_argument(
-        "--chat_template_type",
-        type=str,
-        default="llama",
-        help=(
-            f"Chat template type for conversation formatting. "
-            f"Supported types: {', '.join(get_supported_chat_template_type_strings())}"
-        ),
-    )
     data_group.add_argument(
         "--num_proc",
         type=int,
@@ -298,15 +286,11 @@ def train_eagle3_online():
     rank0_print("Draft model loaded successfully")
 
     # Create datasets using DatasetManager
-    rank0_print(
-        "Creating training and evaluation datasets "
-        f"with chat template type: {args.chat_template_type}..."
-    )
+    rank0_print("Creating training and evaluation datasets")
     dataset_manager = DatasetManager(
         data_args=args,
         tokenizer=target_model.tokenizer,
         model_max_length=args.model_max_length,
-        chat_template_type=args.chat_template_type,
     )
     train_dataset, eval_dataset = dataset_manager.create_datasets()
     rank0_print(
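
Review note: the core behavioral change above replaces the regex/offset-based `_create_loss_mask_from_offsets` (and the whole `chat_templates.py` header registry) with per-message templating in `_process_single_conversation`. Below is a minimal standalone sketch of that new flow for reviewers. The checkpoint name and sample conversation are illustrative assumptions, not part of this change; it requires `torch` and a `transformers` tokenizer that ships a chat template.

```python
import torch
from transformers import AutoTokenizer

# Hypothetical checkpoint; any tokenizer with a ChatML-style template works.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "2 + 2 = 4."},
]

input_ids_list, loss_mask_list = [], []
for message in messages:
    # Render and tokenize each message on its own through the chat template.
    tokens = tokenizer.apply_chat_template(
        [message],
        tokenize=True,
        add_generation_prompt=False,
        return_tensors="pt",
    ).squeeze(0)
    # Supervise only assistant tokens: mask = 1 on assistant, 0 elsewhere.
    if message["role"] == "assistant":
        mask = torch.ones_like(tokens)
    else:
        mask = torch.zeros_like(tokens)
    input_ids_list.append(tokens)
    loss_mask_list.append(mask)

input_ids = torch.cat(input_ids_list, dim=0)
loss_mask = torch.cat(loss_mask_list, dim=0)

# Truncate both tensors together so mask positions stay aligned with tokens.
max_length = 2048
input_ids = input_ids[:max_length]
loss_mask = loss_mask[:max_length]
assert input_ids.shape == loss_mask.shape
```

One caveat worth verifying: this approach assumes the chat template renders a single message identically to how it renders that message inside a full conversation. That holds for Qwen3's `<|im_start|>{role}\n...<|im_end|>\n` format, but templates that inject a default system prompt or a BOS token on every call would produce a token stream that differs from whole-conversation templating.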