diff --git a/unsloth/models/auto_sequence_classification.py b/unsloth/models/auto_sequence_classification.py
new file mode 100644
index 000000000..df694bde8
--- /dev/null
+++ b/unsloth/models/auto_sequence_classification.py
@@ -0,0 +1,268 @@
+import torch
+import torch.nn as nn
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoConfig,
+    PreTrainedModel,
+    MllamaForConditionalGeneration,
+    MllamaConfig,
+    LlavaNextForConditionalGeneration,
+    LlavaNextConfig,
+    AutoTokenizer
+)
+from transformers.modeling_outputs import SequenceClassifierOutput
+from typing import Optional, Union, Tuple
+import warnings
+
+
+class SequenceClassificationMixin:
+    """
+    Mixin class containing common methods for sequence classification models.
+    """
+
+    @staticmethod
+    def compute_classification_loss(logits, labels, num_labels, config):
+        """Compute loss based on problem type."""
+        if labels is None:
+            return None
+
+        # Determine problem type if not set
+        if config.problem_type is None:
+            if num_labels == 1:
+                config.problem_type = "regression"
+            elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                config.problem_type = "single_label_classification"
+            else:
+                config.problem_type = "multi_label_classification"
+
+        # Compute loss based on problem type
+        if config.problem_type == "regression":
+            loss_fct = nn.MSELoss()
+            if num_labels == 1:
+                loss = loss_fct(logits.squeeze(), labels.squeeze())
+            else:
+                loss = loss_fct(logits, labels)
+        elif config.problem_type == "single_label_classification":
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))
+        elif config.problem_type == "multi_label_classification":
+            loss_fct = nn.BCEWithLogitsLoss()
+            loss = loss_fct(logits, labels)
+        else:
+            raise ValueError(f"Unknown problem type: {config.problem_type}")
+
+        return loss
+
+    @staticmethod
+    def pool_sequence(last_hidden_state, attention_mask=None):
+        """Pool the sequence representation using the last non-padded token."""
+        if attention_mask is not None:
+            # Find the last non-padded token for each sequence; keep the index
+            # tensor on the same device as the hidden states
+            batch_size = last_hidden_state.shape[0]
+            sequence_lengths = attention_mask.sum(dim=1) - 1
+            pooled_output = last_hidden_state[
+                torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths
+            ]
+        else:
+            # Use the last token
+            pooled_output = last_hidden_state[:, -1, :]
+
+        return pooled_output
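+
+# Editor's illustrative note (not part of the public API): for
+# attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]], sequence_lengths is
+# tensor([2, 1]), so pool_sequence gathers the hidden state of the last
+# real (non-padded) token of each row, mirroring the pooling used by
+# decoder-only *ForSequenceClassification heads in transformers.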
+
+
+class MllamaForSequenceClassification(PreTrainedModel, SequenceClassificationMixin):
+    """
+    Mllama model with a sequence classification head on top (a linear layer on top of the pooled output).
+    """
+    config_class = MllamaConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        # Load the base vision model
+        self.mllama = MllamaForConditionalGeneration(config)
+
+        # Get the hidden size from the language model
+        if hasattr(config, 'text_config') and config.text_config is not None:
+            hidden_size = config.text_config.hidden_size
+        elif hasattr(config, 'hidden_size'):
+            hidden_size = config.hidden_size
+        else:
+            # Fallback - get from the actual model
+            hidden_size = self.mllama.language_model.config.hidden_size
+
+        # Classification head
+        self.score = nn.Linear(hidden_size, config.num_labels)
+        self.dropout = nn.Dropout(config.classifier_dropout if hasattr(config, 'classifier_dropout') else 0.1)
+
+        # Initialize weights
+        self.post_init()
+
+    def enable_input_require_grads(self):
+        """
+        Enables the gradients for the input embeddings. This is useful for
+        fine-tuning adapter weights while keeping the model weights fixed.
+        """
+        def make_inputs_require_grads(module, input, output):
+            output.requires_grad_(True)
+
+        # Access embeddings through the language model
+        embedding_layer = self.mllama.model.language_model.embed_tokens
+        self._require_grads_hook = embedding_layer.register_forward_hook(make_inputs_require_grads)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        aspect_ratio_ids: Optional[torch.LongTensor] = None,
+        aspect_ratio_mask: Optional[torch.LongTensor] = None,
+        cross_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Get outputs from the language model part only (ignore vision for sequence classification).
+        # Always request a ModelOutput so `.last_hidden_state` below works; a tuple
+        # is rebuilt at the end if the caller asked for one.
+        language_model_outputs = self.mllama.language_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+            return_dict=True
+        )
+
+        # Get the last hidden state and pool it
+        last_hidden_state = language_model_outputs.last_hidden_state
+        pooled_output = self.pool_sequence(last_hidden_state, attention_mask)
+
+        # Apply dropout and classification
+        pooled_output = self.dropout(pooled_output)
+        logits = self.score(pooled_output)
+
+        # Compute loss using the mixin method
+        loss = self.compute_classification_loss(logits, labels, self.num_labels, self.config)
+
+        if not return_dict:
+            output = (logits,) + language_model_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=language_model_outputs.hidden_states,
+            attentions=language_model_outputs.attentions,
+        )
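+
+# Usage sketch (editor's addition; checkpoint name and tensors are illustrative):
+#   config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
+#   config.num_labels = 2
+#   model = MllamaForSequenceClassification(config)
+#   out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+#   out.loss, out.logits  # SequenceClassifierOutput fields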
+ """ + config_class = LlavaNextConfig + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + # Load the base vision model + self.llava_next = LlavaNextForConditionalGeneration(config) + + # Get the hidden size from the language model + if hasattr(config, 'text_config') and config.text_config is not None: + hidden_size = config.text_config.hidden_size + elif hasattr(config, 'hidden_size'): + hidden_size = config.hidden_size + else: + # Fallback - get from the actual model + hidden_size = self.llava_next.language_model.config.hidden_size + + # Classification head - handle quantization + self.score = self._create_classification_head(hidden_size, config.num_labels) + self.dropout = nn.Dropout(config.classifier_dropout if hasattr(config, 'classifier_dropout') else 0.1) + + # Initialize weights + self.post_init() + + def _create_classification_head(self, hidden_size, num_labels): + """Create classification head with quantization support""" + try: + import bitsandbytes as bnb + from transformers.utils import is_bitsandbytes_available + if is_bitsandbytes_available() and hasattr(self.llava_next, 'language_model'): + # Check if the base model is quantized + if hasattr(self.llava_next.language_model, 'model'): + first_layer = next(iter(self.llava_next.language_model.model.layers)) + if hasattr(first_layer, 'self_attn') and hasattr(first_layer.self_attn, 'q_proj'): + if hasattr(first_layer.self_attn.q_proj, 'quant_state'): + # Model is quantized, use Linear8bitLt for the classification head + return bnb.nn.Linear8bitLt(hidden_size, num_labels, has_fp16_weights=False) + except (ImportError, AttributeError, StopIteration): + pass + + # Default to regular Linear layer + return nn.Linear(hidden_size, num_labels) + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping + the model weights fixed. 
+ """ + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + # Access embeddings through the language model + embedding_layer = self.llava_next.model.language_model.embed_tokens + self._require_grads_hook = embedding_layer.register_forward_hook(make_inputs_require_grads) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, SequenceClassifierOutput]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Get outputs from the language model part only (ignore vision for sequence classification) + language_model_outputs = self.llava_next.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict + ) + + # Get the last hidden state and pool it + last_hidden_state = language_model_outputs.last_hidden_state + pooled_output = self.pool_sequence(last_hidden_state, attention_mask) + + # Apply dropout and classification + pooled_output = self.dropout(pooled_output) + logits = self.score(pooled_output) + + # Compute loss using the mixin method + loss = self.compute_classification_loss(logits, labels, self.num_labels, self.config) + + if not return_dict: + output = (logits,) + language_model_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=language_model_outputs.hidden_states, + attentions=language_model_outputs.attentions, + ) \ No newline at end of file diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 8a4902698..b31a60edd 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -453,6 +453,7 @@ def from_pretrained( from .vision import FastBaseModel from transformers import ( AutoModelForCausalLM, + AutoModelForSequenceClassification, ) try: from transformers import AutoModelForImageTextToText @@ -461,6 +462,7 @@ def from_pretrained( from transformers import AutoModelForVision2Seq pass + DISABLE_COMPILE_MODEL_NAMES = [ "aya-vision", "modernbert", @@ -743,7 +745,6 @@ def from_pretrained( is_vlm = is_vlm or hasattr(model_config, "vision_config") if auto_model is None: auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM - model, tokenizer = FastBaseModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 4466128a2..742afe70e 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -18,6 +18,7 @@ AutoProcessor, AutoTokenizer, AutoModelForCausalLM, + AutoModelForSequenceClassification, ) try: from transformers import AutoModelForImageTextToText @@ -31,6 +32,7 @@ from ._utils import __version__ from ._utils import * from ..save import patch_saving_functions +from .auto_sequence_classification import MllamaForSequenceClassification, LlavaNextForSequenceClassification from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model from peft import PeftModelForCausalLM from transformers import set_seed as transformers_set_seed @@ -39,6 +41,11 @@ SKIP_QUANTIZATION_MODULES, 
 from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
 from peft import PeftModelForCausalLM
 from transformers import set_seed as transformers_set_seed
@@ -39,6 +41,11 @@
    SKIP_QUANTIZATION_MODULES,
    requires_grad_for_gradient_checkpointing,
 )
+from transformers import (
+    AutoConfig,
+    MllamaConfig,
+    LlavaNextConfig,
+)
 from transformers.models.llama.modeling_llama import logger
 from transformers import __version__ as transformers_version
 from triton import __version__ as triton_version
@@ -84,6 +91,54 @@
    return_lora_modules,
 )
 
+
+def patch_vision_models_for_sequence_classification():
+    """
+    Patch function to register both MllamaForSequenceClassification and
+    LlavaNextForSequenceClassification with AutoModelForSequenceClassification.
+    """
+    # Register the model classes. This also updates the underlying
+    # MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, so no extra bookkeeping is needed.
+    AutoModelForSequenceClassification.register(MllamaConfig, MllamaForSequenceClassification)
+    AutoModelForSequenceClassification.register(LlavaNextConfig, LlavaNextForSequenceClassification)
+
+    print("✅ Successfully patched MllamaForSequenceClassification and LlavaNextForSequenceClassification!")
+
+
+# Legacy function for backward compatibility
+def patch_mllama_for_sequence_classification():
+    """
+    Legacy patch function - now calls the main patch function.
+    """
+    patch_vision_models_for_sequence_classification()
+
+
+def create_config_for_classification(model_name: str, num_labels: int, **kwargs):
+    """
+    Create a proper config for sequence classification.
+    """
+    # Load the original config
+    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+
+    # Add classification-specific parameters
+    config.num_labels = num_labels
+    config.problem_type = kwargs.get('problem_type', None)
+    config.classifier_dropout = kwargs.get('classifier_dropout', 0.1)
+
+    return config
+
+
+def get_base_model(model):
+    # Get the name of the first top-level submodule (e.g. "mllama" or
+    # "llava_next"); named_modules() yields the root module itself as "" first,
+    # which is skipped
+    for name, _ in model.named_modules():
+        base_name = name.split(".")[0]
+        if base_name:
+            return base_name
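+
+# Editor's usage sketch (checkpoint name is illustrative, not a tested invocation):
+#   patch_vision_models_for_sequence_classification()
+#   config = create_config_for_classification("llava-hf/llava-v1.6-mistral-7b-hf", num_labels = 3)
+#   model = AutoModelForSequenceClassification.from_pretrained(
+#       "llava-hf/llava-v1.6-mistral-7b-hf", config = config,
+#   )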
+
+
 def unsloth_base_fast_generate(
    self,
    *args,
@@ -261,6 +316,7 @@ def from_pretrained(
        whisper_task = None,
        **kwargs,
    ):
+
        if model_types is None:
            raise RuntimeError(
                "Unsloth: Please use FastModel or FastVisionModel and not use FastBaseModel directly!"
@@ -339,7 +395,7 @@ def from_pretrained(
            correct_dtype = bnb_compute_dtype
        pass
 
-        # Stop SDPA for some archs like Pixtral / Mistral3
+        # Stop SDPA for some archs like Pixtral / Mistral3 / SequenceClassification
        if not ("attn_implementation" in kwargs):
            kwargs["attn_implementation"] = "sdpa"
        if not supports_sdpa:
@@ -389,20 +445,42 @@ def from_pretrained(
        # Check if using forced float32 - we load it in bfloat16, then cast to float16!
        torch_dtype = dtype
        if do_forced_float32: torch_dtype = torch.bfloat16
-
-        model = auto_model.from_pretrained(
-            model_name,
-            device_map = device_map,
-            torch_dtype = torch_dtype,
-            # quantization_config = bnb_config,
-            token = token,
-            trust_remote_code = trust_remote_code,
-            # attn_implementation = attn_implementation,
-            **kwargs,
-        )
+        if auto_model.__name__.endswith("ForSequenceClassification"):
+            if not "num_labels" in kwargs:
+                raise ValueError(
+                    "Unsloth: Could not find `num_labels` in kwargs. "
+                    "Please pass `num_labels` when loading a sequence classification "
+                    "model so the head is created with the correct number of output labels."
+                )
+            patch_mllama_for_sequence_classification()
+            # Create config with classification parameters
+            config = create_config_for_classification(model_name, **kwargs)
+            # Consumed above; the patched classification classes do not accept them
+            del kwargs["attn_implementation"]
+            del kwargs["num_labels"]
+            model = auto_model.from_pretrained(
+                model_name,
+                config = config,
+                device_map = device_map,
+                torch_dtype = torch_dtype,
+                # quantization_config = bnb_config,
+                token = token,
+                trust_remote_code = trust_remote_code,
+                # attn_implementation = attn_implementation,
+                **kwargs,
+            )
+        else:
+            model = auto_model.from_pretrained(
+                model_name,
+                device_map = device_map,
+                torch_dtype = torch_dtype,
+                # quantization_config = bnb_config,
+                token = token,
+                trust_remote_code = trust_remote_code,
+                # attn_implementation = attn_implementation,
+                **kwargs,
+            )
 
        # Return old flag
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer
-
        # Edit data-types
        if custom_datatype is not None:
            for name, module in model.named_modules():
@@ -483,9 +561,17 @@ def from_pretrained(
            # Also set is_loaded_in_8bit to disable incorrect DDP
            m.is_loaded_in_8bit = True if not full_finetuning else False
 
+
        # Patch generate
        if os.environ.get("UNSLOTH_DISABLE_FAST_GENERATION", "0") == "0":
-            if model.generate.__name__ != "unsloth_base_fast_generate":
+            if model.__class__.__name__.endswith("ForSequenceClassification"):
+                # generate() lives on the wrapped base model (e.g. model.mllama),
+                # so patch it there; bind to base_model so self._old_generate resolves
+                base_model_name = get_base_model(model)
+                base_model = getattr(model, base_model_name)
+                if base_model.generate.__name__ != "unsloth_base_fast_generate":
+                    base_model._old_generate = base_model.generate
+                    unsloth_base_fast_generate.__doc__ = base_model._old_generate.__doc__
+                    base_model.generate = types.MethodType(unsloth_base_fast_generate, base_model)
+            elif model.generate.__name__ != "unsloth_base_fast_generate":
                model._old_generate = model.generate
                unsloth_base_fast_generate.__doc__ = model._old_generate.__doc__
                model.generate = types.MethodType(unsloth_base_fast_generate, model)