From af2127c8d6bbec5b3311bdf0aef43875c6abdebb Mon Sep 17 00:00:00 2001 From: Zhongqiang Huang Date: Tue, 20 Aug 2024 23:13:51 -0400 Subject: [PATCH] Initial commit of cformer adapter and input_kl loss --- ultravox/inference/infer_test.py | 2 +- ultravox/inference/ultravox_infer.py | 2 +- ultravox/model/data_processing.py | 23 +- ultravox/model/ultravox_adapter.py | 226 ++++++++++++++++++ ultravox/model/ultravox_config.py | 58 ++++- ultravox/model/ultravox_model.py | 174 +++++++------- ultravox/model/ultravox_pipeline.py | 2 +- ultravox/model/ultravox_processing.py | 29 ++- ultravox/training/config_base.py | 22 +- .../training/configs/asr_tinyllama_100s.yaml | 7 - .../training/configs/tinyllama_whisper.yaml | 29 +++ ultravox/training/train.py | 7 +- 12 files changed, 446 insertions(+), 135 deletions(-) create mode 100644 ultravox/model/ultravox_adapter.py delete mode 100644 ultravox/training/configs/asr_tinyllama_100s.yaml create mode 100644 ultravox/training/configs/tinyllama_whisper.yaml diff --git a/ultravox/inference/infer_test.py b/ultravox/inference/infer_test.py index c9768937..a3da135a 100644 --- a/ultravox/inference/infer_test.py +++ b/ultravox/inference/infer_test.py @@ -46,7 +46,7 @@ def fake_generate(**kwargs): return output processor = ultravox_processing.UltravoxProcessor( - audio_processor, tokenizer=tokenizer + audio_processor=audio_processor, tokenizer=tokenizer ) super().__init__( mock.MagicMock(), diff --git a/ultravox/inference/ultravox_infer.py b/ultravox/inference/ultravox_infer.py index 6765ece1..462e8286 100644 --- a/ultravox/inference/ultravox_infer.py +++ b/ultravox/inference/ultravox_infer.py @@ -58,7 +58,7 @@ def __init__( ) processor = ultravox_processing.UltravoxProcessor( - audio_processor, tokenizer=tokenizer, stack_factor=model.config.stack_factor + audio_processor=audio_processor, tokenizer=tokenizer, adapter=model.adapter ) super().__init__( diff --git a/ultravox/model/data_processing.py b/ultravox/model/data_processing.py index 6de95944..fa6159bc 100644 --- a/ultravox/model/data_processing.py +++ b/ultravox/model/data_processing.py @@ -57,6 +57,7 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]: inputs = self.processor( text=text, audio=audio, + transcript=sample.audio_transcript, return_tensors="pt", sampling_rate=sample.sample_rate, ) @@ -72,27 +73,22 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]: # No need to shift the labels as the model does it internally labels = input_ids.clone() - if not self.train_on_inputs: + if not self.train_on_inputs and sample.messages[-1]["role"] == "assistant": # Mask the prompt tokens and only compute loss on the assistant message, not the prompt. # The idea is that the model should only be able to predict the assistant message given the user message. # One reason is that there's very little randomness in the prompt, so the model would be forced to memorize it. # # Example (-100 is the ignore index): # Tokens: Transcribe\n<|audio|> Brown fox jumps over the lazy dog - # Labels: -100 -100 -100 -100 Brown fox jumps over the lazy dog + # Labels: -100 -100 -100 -100 Brown fox jumps over the lazy dog # # Note: The above might look weird because I'm mixing token IDs and text, but that's just for illustration. - input_text = self.processor.tokenizer.apply_chat_template( - sample.messages[:-1], tokenize=False - ) - # TODO: this might be slow due to calling audio_processor twice. We can compute modified input_text_len directly too. - # Revisit when using WhisperProcessor. 
- input_token_len = self.processor( - text=input_text, - audio=audio, - sampling_rate=sample.sample_rate, - )["input_ids"].shape[-1] + output_text = self.processor.tokenizer.apply_chat_template( + sample.messages[-1:], tokenize=False + ) + output_token_len = self.processor(text=output_text)["input_ids"].shape[-1] + input_token_len = len(input_ids) - output_token_len labels[:input_token_len] = -100 # If include_alt_fields is True, also include alt_input_ids, alt_attention_mask, and alt_labels @@ -102,14 +98,13 @@ def _process(self, sample: datasets.VoiceSample) -> Dict[str, Any]: alt_inputs = self.processor( text=alt_text, - audio=None, return_tensors="pt", ) alt_input_ids = alt_inputs["input_ids"].squeeze_(0) alt_inputs["attention_mask"].squeeze_(0) alt_labels = alt_input_ids.clone() - if not self.train_on_inputs: + if not self.train_on_inputs and sample.messages[-1]["role"] == "assistant": alt_input_token_len = ( input_token_len + len(alt_input_ids) - len(input_ids) ) diff --git a/ultravox/model/ultravox_adapter.py b/ultravox/model/ultravox_adapter.py new file mode 100644 index 00000000..02fda61e --- /dev/null +++ b/ultravox/model/ultravox_adapter.py @@ -0,0 +1,226 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +import transformers +import torch.nn.functional as F +import numpy as np + +import logging +from transformers import WhisperConfig + +from transformers.models.whisper import modeling_whisper as whisper +from transformers.models.wav2vec2 import modeling_wav2vec2 as wav2vec2 +from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config +from transformers.models.whisper.configuration_whisper import WhisperConfig + +from .ultravox_config import UltravoxConfig, UltravoxStackingAdapterConfig, UltravoxCFormerAdapterConfig + +logger = logging.getLogger(__name__) + +class RMSNorm(transformers.models.llama.modeling_llama.LlamaRMSNorm): + def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6): + super().__init__(hidden_size=hidden_size, eps=eps) + self.weight.data.fill_(init) + +# currently attention_mask is not yet implemented in the forward method +class UltravoxAdapter(nn.Module): + def __init__(self, config: UltravoxConfig): + super().__init__() + audio_config: Union[Wav2Vec2Config, WhisperConfig] = config.audio_config + text_config: transformers.LlamaConfig = config.text_config + + self.input_size = audio_config.hidden_size + # self.hidden_size always matches audio_config.hidden_size + self.hidden_size = audio_config.hidden_size + self.output_size = text_config.hidden_size + + self.post_ln = RMSNorm(self.hidden_size, init=config.norm_init) + self.text_proj = nn.Linear(self.hidden_size, self.output_size) + + def forward(self, audio_features: torch.Tensor, num_tokens: Optional[torch.Tensor]=None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + raise NotImplementedError("Subclasses must implement this method") + + def project_to_text(self, hidden_states): + hidden_states = self.post_ln(hidden_states) + hidden_states = self.text_proj(hidden_states) + return hidden_states + + def get_audio_token_len(self, audio_frame_len: int, token_len: int) -> int: + raise NotImplementedError("Subclasses must implement this method") + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, init: float = 1.0): + super().__init__() + self.eps = 1e-6 + self.weight = nn.Parameter(torch.ones(dim) * init) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] + rms = torch.sqrt(torch.sum(x * x, 
dim=-1, keepdim=True) / d + self.eps) + x = x / rms + return x * self.weight + +class SwiGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +class StackAudioFrames(nn.Module): + def __init__(self, stack_factor: int): + super().__init__() + self.stack_factor = stack_factor + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stack_factor == 1: + return x + b, t, d = x.shape + pad = (self.stack_factor - (t % self.stack_factor)) % self.stack_factor + x = torch.nn.functional.pad(x, (0, 0, 0, pad)) + return x.reshape(b, -1, d * self.stack_factor) + + +class StackingAdapter(UltravoxAdapter): + def __init__(self, config: UltravoxConfig): + super().__init__(config) + + self.adapter_config = UltravoxStackingAdapterConfig(**config.adapter_config) + + self._pad_and_stack = StackAudioFrames(self.adapter_config.stack_factor) + stacked_size = self.input_size * self.adapter_config.stack_factor + self.ln_pre = RMSNorm(stacked_size, init=config.norm_init) + # swiglu reduces dimension by 2, so we double it here before swigu to keep effective hidden size consistent. + intermediate_size = self.hidden_size * 2 if self.adapter_config.activation == "swiglu" else self.hidden_size + self.linear_1 = nn.Linear(stacked_size, intermediate_size, bias=False) + self.act = transformers.activations.get_activation(self.adapter_config.activation) + + def get_audio_token_len(self, audio_frame_len: int, token_len: int) -> int: + return int(np.ceil(audio_frame_len / self.adapter_config.stack_factor)) + + def forward(self, audio_features: torch.Tensor, num_tokens: Optional[torch.Tensor]=None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + hidden_states = self._pad_and_stack(audio_features) + hidden_states = self.ln_pre(hidden_states) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.project_to_text(hidden_states) + return hidden_states, None + + +class CFormerAdapter(UltravoxAdapter): + def __init__(self, config: UltravoxConfig): + super().__init__(config) + + adapter_config = UltravoxCFormerAdapterConfig(**config.adapter_config) + + self.num_pre_cif_layers = adapter_config.num_pre_cif_layers + self.num_post_cif_layers = adapter_config.num_post_cif_layers + + if self.num_pre_cif_layers or self.num_post_cif_layers: + if config.audio_config.model_type == "whisper": + transformer_layer_class = whisper.WhisperEncoderLayer + elif config.audio_config.model_type == "wav2vec2": + transformer_layer_class = wav2vec2.Wav2Vec2EncoderLayer + else: + raise ValueError(f"Unsupported audio model type: {config.audio_config.model_type}") + + if self.num_pre_cif_layers > 0: + self.pre_cif_layers = nn.ModuleList( + [transformer_layer_class(config.audio_config) for _ in range(self.num_pre_cif_layers)] + ) + + self.cif_proj = nn.Linear(self.hidden_size-1, self.hidden_size) + + if self.num_post_cif_layers > 0: + self.post_cif_layers = nn.ModuleList( + [transformer_layer_class(config.audio_config) for _ in range(self.num_post_cif_layers)] + ) + + def get_audio_token_len(self, audio_frame_len: int, token_len: int) -> int: + return token_len + + # This implements the continuous integrate-and-fire mechanism adapted from this paper: https://arxiv.org/abs/1905.11235 + # TODO: add support for attention_mask + def forward_cif(self, hidden_states: torch.Tensor, alphas: torch.Tensor, num_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + device = hidden_states.device + B, T, _ = hidden_states.size() + + max_num_tokens = 
num_tokens.max()
+
+        # loop vars
+        integrate = torch.zeros([B], device=device)  # accumulated alpha value that hasn't been fired yet
+        remainds = torch.zeros([B], device=device)  # remaining alpha value from the most recent firing
+        token_index = torch.zeros([B], dtype=torch.long, device=device)  # number of fires that have happened so far
+
+        # weights: B x max_num_tokens x T, weights[i, j, k] is the contribution of the k-th speech feature to the j-th text/speech token for the i-th sample
+        weights = torch.zeros((B, max_num_tokens, T), device=device)
+        for t in range(T):
+            if t > 0:
+                weights[:, :, t - 1].scatter_add_(dim=1, index=token_index.unsqueeze(1), src=remainds.unsqueeze(1))
+
+            alpha = alphas[:, t]
+            alpha_needed = 1 - integrate
+            integrate += alpha
+            ready_to_fire = integrate >= 1.0
+
+            while True:  # allow repeated firing if integrate > threshold
+                integrate = torch.where(ready_to_fire, integrate - 1, integrate)
+                alpha_integrated = torch.where(ready_to_fire, alpha_needed, alpha)
+
+                weights[:, :, t].scatter_(dim=1, index=token_index.unsqueeze(1), src=alpha_integrated.unsqueeze(1))
+                remainds = alpha - alpha_integrated
+
+                token_index = token_index + ready_to_fire.type_as(token_index)
+                token_index = torch.minimum(token_index, num_tokens - 1)
+
+                alpha = remainds
+                alpha_needed = 1
+                ready_to_fire = integrate >= 1.0
+                if not ready_to_fire.any():
+                    break
+
+        # the resulting hidden_states contains the hidden states of speech tokens right after the CIF mechanism
+        hidden_states = weights.type_as(hidden_states).bmm(hidden_states)
+
+        return hidden_states
+
+
+    def forward(self, audio_features: torch.Tensor, num_tokens: Optional[torch.Tensor]=None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        hidden_states = audio_features
+        T = hidden_states.size(1)
+
+        for layer in self.pre_cif_layers:
+            hidden_states = layer(hidden_states, None, None)[0]
+
+        # alphas are computed from the last channel of hidden_states with a sigmoid and are used to assign speech features to text/speech tokens.
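+        # Intuitively, each frame contributes a weight alpha_t in (0, 1); forward_cif above fires a new
+        # speech token every time the running sum of alphas crosses 1.0, so the alphas decide how many
+        # encoder frames are pooled into each output token.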
+ alphas = torch.sigmoid(hidden_states[:, :, -1]) + pred_num_tokens = alphas.sum(-1) + + if self.training: + if num_tokens is None: + raise ValueError("num_tokens must be provided in training mode") + else: + # num_tokens is determined by accumulated predicted alpha values in inference mode + num_tokens = torch.round(pred_num_tokens).int() + # force the number of predicted tokens to be at least 1 in non-streaming mode + # this will break streaming mode and needs to be updated + num_tokens[num_tokens < 1] = 1 + + # scale alphas so that the sum of alphas is equal to num_tokens + alphas = alphas * (num_tokens / pred_num_tokens)[:, None].repeat(1, T) + + # remove the last element of hidden_states and apply CIF mechanism + hidden_states = self.forward_cif(hidden_states[:, :, :-1], alphas, num_tokens) + # project back to self.hidden_size + hidden_states = self.cif_proj(hidden_states) + + for layer in self.post_cif_layers: + hidden_states = layer(hidden_states, None, None)[0] + + hidden_states = self.project_to_text(hidden_states) + + return hidden_states, pred_num_tokens + +transformers.activations.ACT2FN["swiglu"] = SwiGLU diff --git a/ultravox/model/ultravox_config.py b/ultravox/model/ultravox_config.py index ed294e35..147a93c0 100644 --- a/ultravox/model/ultravox_config.py +++ b/ultravox/model/ultravox_config.py @@ -1,6 +1,6 @@ import dataclasses from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import transformers @@ -20,21 +20,58 @@ class LoraConfigSimplified: default_factory=lambda: ["k_proj", "q_proj", "linear_k", "linear_q"] ) - class LossFunction(str, Enum): - CrossEntropy = "ce" - KL_Divergence = "kl" + Response_CE = "Response_CE" + Response_KL = "Response_KL" + Input_KL = "Input_KL" + CIF_L1 = "CIF_L1" +class AdapterType(str, Enum): + STACKING = "STACKING" + CFORMER = "CFORMER" @dataclasses.dataclass class LossConfig: - loss_function: LossFunction = LossFunction.KL_Divergence + loss_weights: Dict[LossFunction, float] = dataclasses.field(default_factory=lambda: {LossFunction.Response_KL: 1.0}) kl_temperature: float = 2.0 @property def requires_alt_fields(self): - return self.loss_function == LossFunction.KL_Divergence + return any(lf in self.loss_weights for lf in [LossFunction.Input_KL, LossFunction.Response_KL]) + + def add_adapter_losses(self, adapter_type: AdapterType): + if adapter_type == AdapterType.CFORMER and LossFunction.CIF_L1 not in self.loss_weights: + self.loss_weights[LossFunction.CIF_L1] = 1.0 + + @property + def contains_kl_loss(self): + return any(lf in self.loss_weights for lf in [LossFunction.Input_KL, LossFunction.Response_KL]) + +@dataclasses.dataclass +class UltravoxCFormerAdapterConfig: + """ + CFormer Adapter configuration. + CIF+Transformer-based adapter to segment speech into continuous speech tokens with 1:1 correspondence to text tokens. +""" + num_pre_cif_layers: int = 2 + num_post_cif_layers: int = 2 + + +@dataclasses.dataclass +class UltravoxStackingAdapterConfig: + """ + Stacking Adapter configuration. + + Stacking+Convolutions-based adapter to segment speech into continuous speech tokens at a fixed downsampling rate. 
+"""
+    stack_factor: int = 8
+    activation: str = "swiglu"
+
+ADAPTER_CONFIG_MAP: Dict[AdapterType, Any] = {
+    AdapterType.STACKING: UltravoxStackingAdapterConfig,
+    AdapterType.CFORMER: UltravoxCFormerAdapterConfig,
+}
 
 class UltravoxConfig(transformers.PretrainedConfig):
     r"""
@@ -64,7 +101,6 @@ class UltravoxConfig(transformers.PretrainedConfig):
         audio_model_lora_config (`LoraConfigSimplified`, *optional*):
             The LoRA configuration for finetuning the audio model.
 
-
     Example:
 
     ```python
@@ -96,8 +132,10 @@ def __init__(
         self,
         audio_config: Optional[Dict[str, Any]] = None,
         text_config: Optional[Dict[str, Any]] = None,
+        adapter_config: Optional[Union[UltravoxStackingAdapterConfig, UltravoxCFormerAdapterConfig]] = None,
         audio_model_id: Optional[str] = None,
         text_model_id: Optional[str] = None,
+        adapter_type: AdapterType = AdapterType.STACKING,
         ignore_index: int = -100,
         audio_token_index: int = 32000,
         hidden_size: int = 4096,
@@ -150,8 +188,14 @@ def __init__(
             else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified())
         )
 
+        self.adapter_type = adapter_type
+        self.adapter_config = dataclasses.asdict(adapter_config or ADAPTER_CONFIG_MAP[adapter_type]())
+
         self.vocab_size = self.text_config.vocab_size
 
         self.initializer_range = self.text_config.initializer_range
 
         super().__init__(**kwargs)
+
+UltravoxConfig.register_for_auto_class()
+transformers.AutoConfig.register("ultravox", UltravoxConfig)
diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py
index a6fa3445..50218eef 100644
--- a/ultravox/model/ultravox_model.py
+++ b/ultravox/model/ultravox_model.py
@@ -12,9 +12,8 @@
 
 # We must use relative import in this directory to allow uploading to HF Hub
 # Even "from . import X" pattern doesn't work (undocumented and unclear why)
-from .ultravox_config import LossConfig
-from .ultravox_config import LossFunction
-from .ultravox_config import UltravoxConfig
+from .ultravox_config import LossConfig, LossFunction, UltravoxConfig, AdapterType
+from .ultravox_adapter import UltravoxAdapter, StackingAdapter, CFormerAdapter
 from .whisper_model_modified import WhisperEncoder as ModifiedWhisperEncoder
 
 
@@ -51,7 +50,7 @@ def __init__(self, config: UltravoxConfig):
         self.vocab_size = config.vocab_size
 
         self.audio_tower = self._create_audio_tower(config)
-        self.multi_modal_projector = UltravoxProjector(config)
+        self.adapter = self._create_adapter(config)
         self.language_model = self._create_language_model(config)
 
         self.loss_config = LossConfig()
@@ -79,6 +78,10 @@ def tie_weights(self):
         return self.language_model.tie_weights()
 
     def set_loss_config(self, loss_config: LossConfig):
+        if LossFunction.Input_KL in loss_config.loss_weights and self.config.adapter_type != AdapterType.CFORMER:
+            raise ValueError(
+                f"Input KL loss is only supported for CFormer adapter, not {self.config.adapter_type}."
+ ) self.loss_config = loss_config def _setup_cache( @@ -107,6 +110,8 @@ def _compute_kl_loss( self, lm_output: transformers.modeling_outputs.CausalLMOutputWithPast, labels: Optional[torch.Tensor] = None, + audio_token_start_idx: Optional[torch.Tensor] = None, + audio_token_len: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None, alt_input_ids: Optional[torch.Tensor] = None, alt_attention_mask: Optional[torch.Tensor] = None, @@ -124,20 +129,43 @@ def _compute_kl_loss( past_key_values=past_key_values, **kwargs, ) + losses: Dict[LossFunction, torch.FloatTensor] = {} # compute the KL divergence loss between the two models - kl_loss = F.kl_div( - F.log_softmax( - lm_output.logits[labels != -100] / self.loss_config.kl_temperature, - dim=-1, - ), - F.softmax( - alt_lm_output.logits[alt_labels != -100] - / self.loss_config.kl_temperature, - dim=-1, - ), - reduction="batchmean", - ) - return {"loss": kl_loss} + if LossFunction.Response_KL in self.loss_config.loss_weights: + loss = F.kl_div( + F.log_softmax( + lm_output.logits[labels != -100] / self.loss_config.kl_temperature, + dim=-1, + ), + F.softmax( + alt_lm_output.logits[alt_labels != -100] + / self.loss_config.kl_temperature, + dim=-1, + ), + reduction="batchmean", + ) + losses[LossFunction.Response_KL] = loss + if LossFunction.Input_KL in self.loss_config.loss_weights and \ + audio_token_start_idx is not None and \ + audio_token_len is not None: + + # compute the KL divergence loss for audio tokens + audio_mask = ((audio_token_start_idx.unsqueeze(1) <= torch.arange(labels.size(1), device=labels.device)) & + (audio_token_start_idx.unsqueeze(1) + audio_token_len.unsqueeze(1) > torch.arange(labels.size(1), device=labels.device))) + + loss = F.kl_div( + F.log_softmax( + lm_output.logits[audio_mask] / self.loss_config.kl_temperature, + dim=-1, + ), + F.softmax( + alt_lm_output.logits[audio_mask] / self.loss_config.kl_temperature, + dim=-1, + ), + reduction="batchmean", + ) + losses[LossFunction.Input_KL] = loss + return losses def forward( self, @@ -192,7 +220,7 @@ def forward( ).last_hidden_state audio_tower_output = audio_tower_output.to(inputs_embeds.dtype) - audio_embeds = self.multi_modal_projector.forward(audio_tower_output) + audio_embeds, pred_num_tokens = self.adapter.forward(audio_tower_output, audio_token_len) # combine audio and text embeddings for i, (audio, start, length) in enumerate( @@ -209,24 +237,33 @@ def forward( **kwargs, ) if self.training: - if self.loss_config.loss_function == LossFunction.CrossEntropy: - return lm_output - elif self.loss_config.loss_function == LossFunction.KL_Divergence: - return self._compute_kl_loss( + total_loss = 0 + if self.loss_config.contains_kl_loss: + kl_loss = self._compute_kl_loss( lm_output=lm_output, labels=labels, + audio_token_start_idx=audio_token_start_idx, + audio_token_len=audio_token_len, past_key_values=past_key_values, alt_input_ids=alt_input_ids, alt_attention_mask=alt_attention_mask, alt_labels=alt_labels, - **kwargs, + **kwargs ) - else: - raise ValueError( - f"Unsupported loss function: {self.loss_config.loss_function}" - ) - else: - return lm_output + for loss, weight in self.loss_config.loss_weights.items(): + if loss == LossFunction.Response_CE: + total_loss += weight * lm_output.loss + elif loss == LossFunction.Response_KL: + total_loss += weight * kl_loss[LossFunction.Response_KL] + elif loss == LossFunction.Input_KL: + total_loss += weight * kl_loss[LossFunction.Input_KL] + elif loss == LossFunction.CIF_L1: + 
total_loss += weight * F.l1_loss(pred_num_tokens/audio_token_len, torch.ones_like(audio_token_len), reduction="mean") + else: + raise ValueError(f"Unsupported loss function: {loss}") + + lm_output.loss = total_loss + return lm_output def prepare_inputs_for_generation( self, @@ -265,6 +302,19 @@ def prepare_inputs_for_generation( return model_input + @classmethod + def _create_adapter( + cls, config: UltravoxConfig + ) -> UltravoxAdapter: + if config.adapter_type is AdapterType.STACKING: + adapter = StackingAdapter(config) + elif config.adapter_type is AdapterType.CFORMER: + adapter = CFormerAdapter(config) + else: + raise ValueError(f"Unsupported adapter type: {config.adapter_type}") + return adapter + + @classmethod def _create_audio_tower( cls, config: UltravoxConfig @@ -393,7 +443,7 @@ def print_trainable_parameters(self): lm_trainable_params, lm_all_params = count_params(self.language_model) audio_trainable_params, audio_all_params = count_params(self.audio_tower) - projector_trainable_params = ( + adapter_trainable_params = ( trainable_params - lm_trainable_params - audio_trainable_params ) projector_all_params = all_param - lm_all_params - audio_all_params @@ -402,7 +452,7 @@ def print_trainable_parameters(self): f"Trainable%: " f" LLM: {100 * lm_trainable_params / lm_all_params:.1f}%" f" || Audio Encoder: {100 * audio_trainable_params / audio_all_params:.1f}%" - f" || Projector: {100 * projector_trainable_params / projector_all_params:.1f}%" + f" || Adapter: {100 * adapter_trainable_params / projector_all_params:.1f}%" ) @@ -435,72 +485,10 @@ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module: return model -class StackAudioFrames(nn.Module): - """ - Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`. - - The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames. - NOTE: the extra +1 is intentional: in case the number of audio tokens are over-estimated by the processor, - we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings. - In most cases this extra padding will get removed in the model's forward function so it has no effect. 
- """ - - def __init__(self, stack_factor: int = 8): - super().__init__() - self.stack_factor = stack_factor - - def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: - B, T, C = audio_embeds.shape - T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor - audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor)) - B, T, C = audio_embeds.shape - audio_embeds = audio_embeds.view( - B, T // self.stack_factor, C * self.stack_factor - ) - return audio_embeds - - -class RMSNorm(transformers.models.llama.modeling_llama.LlamaRMSNorm): - def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6): - super().__init__(hidden_size=hidden_size, eps=eps) - self.weight.data.fill_(init) -class SwiGLU(nn.Module): - def forward(self, x): - x, gate = x.chunk(2, dim=-1) - return F.silu(gate) * x - -class UltravoxProjector(nn.Sequential): - def __init__(self, config: UltravoxConfig): - super().__init__() - self.hidden_dim = config.hidden_size - self._pad_and_stack = StackAudioFrames(config.stack_factor) - dim = config.audio_config.hidden_size * config.stack_factor - self.ln_pre = RMSNorm(dim, init=config.norm_init) - self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False) - dim = self.hidden_dim - self.act = transformers.activations.get_activation(config.projector_act) - dim = dim // 2 if config.projector_act == "swiglu" else dim - self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False) - self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init) - - def forward(self, audio_features: torch.Tensor) -> torch.Tensor: - audio_features = self._pad_and_stack(audio_features) - audio_features = self.ln_pre(audio_features) - hidden_states = self.linear_1(audio_features) - hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) - hidden_states = self.ln_post(hidden_states) - return hidden_states - - -UltravoxConfig.register_for_auto_class() UltravoxModel.register_for_auto_class() - -transformers.AutoConfig.register("ultravox", UltravoxConfig) transformers.AutoModel.register(UltravoxConfig, UltravoxModel) # transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor) # TODO: make processor work standalone -transformers.activations.ACT2FN["swiglu"] = SwiGLU diff --git a/ultravox/model/ultravox_pipeline.py b/ultravox/model/ultravox_pipeline.py index c9a8aaa1..e2632dde 100644 --- a/ultravox/model/ultravox_pipeline.py +++ b/ultravox/model/ultravox_pipeline.py @@ -36,7 +36,7 @@ def __init__( self.processor = UltravoxProcessor( audio_processor=audio_processor, tokenizer=tokenizer, - stack_factor=model.config.stack_factor, + adapter=model.adapter, ) super().__init__(model=model, tokenizer=tokenizer, **kwargs) diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py index 20b95611..20ee4c4b 100644 --- a/ultravox/model/ultravox_processing.py +++ b/ultravox/model/ultravox_processing.py @@ -4,7 +4,9 @@ import torch import transformers +from .ultravox_adapter import UltravoxAdapter +# TODO: update the comments to reflect the actual implementation class UltravoxProcessor(transformers.ProcessorMixin): """ Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor. 
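As a reading aid before the processor changes below: the processor is now constructed with the model's adapter rather than a bare stack_factor, as the ultravox_pipeline.py hunk above shows. A minimal usage sketch (not part of the diff; the checkpoint path is a placeholder, and the tokenizer/audio-processor IDs are simply the ones used by tinyllama_whisper.yaml later in this patch):

```python
import transformers

from ultravox.model import ultravox_model, ultravox_processing

# Placeholder checkpoint path, for illustration only.
model = ultravox_model.UltravoxModel.from_pretrained("path/to/ultravox-checkpoint")
tokenizer = transformers.AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
audio_processor = transformers.AutoProcessor.from_pretrained("openai/whisper-small")

# The adapter tells the processor how many LM positions the audio will occupy
# (see get_audio_token_len in ultravox_adapter.py).
processor = ultravox_processing.UltravoxProcessor(
    audio_processor=audio_processor,
    tokenizer=tokenizer,
    adapter=model.adapter,
)
```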
@@ -34,7 +36,7 @@ def __init__( tokenizer=None, audio_padding: str = "longest", encoder_ds_factor: int = 320, - stack_factor: int = 8, + adapter: UltravoxAdapter = None, audio_placeholder: str = "<|audio|>", ): """ @@ -48,7 +50,7 @@ def __init__( """ self.audio_padding = audio_padding self.encoder_ds_factor = encoder_ds_factor - self.stack_factor = stack_factor + self.adapter = adapter self.audio_placeholder = audio_placeholder self.audio_token_replacement = tokenizer.eos_token assert ( @@ -60,6 +62,7 @@ def __call__( self, text: Optional[str] = None, audio: Optional[Union[np.ndarray, torch.Tensor]] = None, + transcript: Optional[str] = None, sampling_rate: Optional[int] = None, return_tensors: Optional[ Union[str, transformers.TensorType] @@ -105,20 +108,32 @@ def __call__( """ # TODO: Add support for multiple audio and text inputs. data = {} - audio_embed_frames = 0 + audio_token_len = 0 if audio is not None and len(audio) > 0: + if not self.adapter: + raise ValueError("Adapter must be provided for determing audio_token_len.") + if self.audio_padding == "max_length": # 30 seconds is the expected length for Whisper assert sampling_rate is not None, "Sampling rate must be provided." audio_len = 30 * sampling_rate else: audio_len = audio.shape[-1] + + # num_encoder_frames is needed for the Stacking adapter to determine audio_token_len, both in training and inference. # It's guaranteed that the number of frames is less than or equal to this amount. # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound. # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings. - nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4)) - audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor)) - data["audio_token_len"] = [audio_embed_frames] + num_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4)) + # num_text_tokens is needed for the CFormer adapter to determine audio_token_len in training mode. + # In inference mode, the inferred transcript length during forward pass is used to determine audio_token_len. + if transcript: + num_text_tokens = len(self.tokenizer.encode(transcript, add_special_tokens=False)) + else: + num_text_tokens = 0 + # compute the audio_token_len based on the model's adapter + audio_token_len = self.adapter.get_audio_token_len(num_encoder_frames, num_text_tokens) + data["audio_token_len"] = [audio_token_len] # Main audio processing. The processor is model-specific. x = self.audio_processor( @@ -156,7 +171,7 @@ def __call__( # where the number of is the number of audio frames. text = text.replace( self.audio_placeholder, - self.audio_token_replacement * audio_embed_frames, + self.audio_token_replacement * audio_token_len, ) # Special tokens like BOS should already have been added by the caller. 
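To make the audio_token_len bookkeeping above concrete, here is a small sketch of the two get_audio_token_len policies from ultravox_adapter.py (illustrative only; the 10-second example assumes 16 kHz audio with the default encoder_ds_factor of 320 and stack_factor of 8):

```python
import math

def stacking_audio_token_len(num_encoder_frames: int, stack_factor: int = 8) -> int:
    # Stacking adapter: one speech token per stack_factor encoder frames, rounded up.
    return math.ceil(num_encoder_frames / stack_factor)

def cformer_audio_token_len(num_text_tokens: int) -> int:
    # CFormer adapter: one speech token per transcript text token in training mode;
    # at inference time the CIF head predicts the count instead.
    return num_text_tokens

# 10 s of 16 kHz audio -> 160_000 samples / 320 = 500 encoder frames.
print(stacking_audio_token_len(500))  # 63
print(cformer_audio_token_len(42))    # 42, the length of the tokenized transcript
```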
diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 182887a9..f8d844a1 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -3,7 +3,7 @@ import logging import os from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import simple_parsing import torch @@ -21,7 +21,10 @@ class TrainConfig: text_model: str # audio encoder model to use audio_model: str - + # adapter type: + adapter_type: ultravox_config.AdapterType = ultravox_config.AdapterType.STACKING + adapter_config: Optional[Dict[str, Any]] = None + # The data_dicts field complements data_sets, allowing for the inclusion of # new datasets in the config. # @@ -129,3 +132,18 @@ def __post_init__(self): "LayerDrop cannot be used in DDP when encoder is not frozen. Disabling LayerDrop." ) self.disable_layerdrop = True + + if self.adapter_type is ultravox_config.AdapterType.STACKING: + self.adapter_config = ultravox_config.UltravoxStackingAdapterConfig( + **(self.adapter_config or {}) + ) + elif self.adapter_type is ultravox_config.AdapterType.CFORMER: + self.adapter_config = ultravox_config.UltravoxCFormerAdapterConfig( + **(self.adapter_config or {}) + ) + else: + raise ValueError(f"Unsupported adapter type: {self.adapter_type}") + + if self.loss_config is None: + self.loss_config = ultravox_config.LossConfig() + self.loss_config.add_adapter_losses(self.adapter_type) \ No newline at end of file diff --git a/ultravox/training/configs/asr_tinyllama_100s.yaml b/ultravox/training/configs/asr_tinyllama_100s.yaml deleted file mode 100644 index 0cb38740..00000000 --- a/ultravox/training/configs/asr_tinyllama_100s.yaml +++ /dev/null @@ -1,7 +0,0 @@ -# test config for fast experimentation, only 100 steps - -text_model: "TinyLlama/TinyLlama-1.1B-Chat-v1.0" -exp_name: "tinyllama_asr_100s" - -max_steps: 100 -lr_warmup_steps: 10 diff --git a/ultravox/training/configs/tinyllama_whisper.yaml b/ultravox/training/configs/tinyllama_whisper.yaml new file mode 100644 index 00000000..d35f3a65 --- /dev/null +++ b/ultravox/training/configs/tinyllama_whisper.yaml @@ -0,0 +1,29 @@ +# test config for fast experimentation, only 100 steps +exp_name: "tinyllama_whisper_s" +text_model: "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +audio_model: "openai/whisper-small" + +lr_warmup_steps: 10 +max_steps: 100 + +report_logs_to: ["tensorboard"] + +device: "cpu" +adapter_type: "CFORMER" + +loss_config: + loss_weights: + Response_KL: 1 + Input_KL: 1 + +val_sets: ["anyinstruct"] +data_sets: [] +data_dicts: + - path: "fixie-ai/librispeech_asr" + name: "clean" + splits: + - "train.100" + user_template: "Continue the following text using less than 50 words:\n\n<|audio|>" + assistant_template: "{{ continuation }}" + transcript_template: "{{ text }}" + num_samples: 100_000 \ No newline at end of file diff --git a/ultravox/training/train.py b/ultravox/training/train.py index d76514b7..29f9a7dc 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -97,8 +97,6 @@ def main() -> None: ) text_tokenizer.padding_side = "right" text_tokenizer.pad_token = text_tokenizer.eos_token - audio_processor = transformers.AutoProcessor.from_pretrained(args.audio_model) - processor = ultravox_processing.UltravoxProcessor(audio_processor, text_tokenizer) # Instantiate the model and processor config = ultravox_config.UltravoxConfig( @@ -106,6 +104,8 @@ def main() -> None: text_model_id=args.text_model, text_model_lora_config=args.text_model_lora_config, 
audio_model_lora_config=args.audio_model_lora_config, + adapter_type = args.adapter_type, + adapter_config=args.adapter_config, ) logging.info("Instantiating model...") @@ -115,6 +115,9 @@ def main() -> None: with ddp_utils.run_on_master_first(is_master): model = ultravox_model.UltravoxModel(config) + audio_processor = transformers.AutoProcessor.from_pretrained(args.audio_model) + processor = ultravox_processing.UltravoxProcessor(audio_processor=audio_processor, tokenizer=text_tokenizer, adapter=model.adapter) + assert model.get_input_embeddings().num_embeddings == len( text_tokenizer ), f"Model and tokenizer mismatch: {model.get_input_embeddings().num_embeddings} != {len(text_tokenizer)}"
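For reference, a minimal end-to-end sketch of how the new configuration pieces compose (not part of the diff; it mirrors train.py above, with the model IDs and loss weights taken from tinyllama_whisper.yaml):

```python
from ultravox.model import ultravox_config, ultravox_model

config = ultravox_config.UltravoxConfig(
    audio_model_id="openai/whisper-small",
    text_model_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    adapter_type=ultravox_config.AdapterType.CFORMER,
    adapter_config=ultravox_config.UltravoxCFormerAdapterConfig(
        num_pre_cif_layers=2, num_post_cif_layers=2
    ),
)

loss_config = ultravox_config.LossConfig(
    loss_weights={
        ultravox_config.LossFunction.Response_KL: 1.0,
        ultravox_config.LossFunction.Input_KL: 1.0,
    }
)
loss_config.add_adapter_losses(config.adapter_type)  # adds CIF_L1 (weight 1.0) for the CFormer adapter

model = ultravox_model.UltravoxModel(config)
model.set_loss_config(loss_config)  # raises if Input_KL is requested with a non-CFormer adapter
```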