diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py
index 1ea40c95..6acf09b2 100644
--- a/ultravox/data/datasets.py
+++ b/ultravox/data/datasets.py
@@ -79,6 +79,7 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
 
     def __call__(self, features, *args, **kwargs):
         audio_values = [f.pop("audio_values", None) for f in features]
+
         if self.include_alt_fields:
             # these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
             alt_features = [
@@ -89,6 +90,7 @@ def __call__(self, features, *args, **kwargs):
                 }
                 for f in features
             ]
+
         input_ids_lens = torch.LongTensor([f["input_ids"].shape[-1] for f in features])
         batch = super().__call__(features, *args, **kwargs)
         if self.include_alt_fields:
@@ -100,7 +102,7 @@ def __call__(self, features, *args, **kwargs):
         # Pad the last dimension of all audio_values to the same length, with 0s on the right.
         if audio_values and audio_values[0] is not None:
             max_len = max([x.shape[-1] for x in audio_values])
-            batch["audio_values"] = torch.stack(
+            batch["audio_values"] = torch.cat(
                 [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
            )
             if self.tokenizer.padding_side == "left":
@@ -108,7 +110,6 @@ def __call__(self, features, *args, **kwargs):
                 batch["audio_token_start_idx"] += displacement.to(
                     batch["audio_token_start_idx"].device
                 )
-
         return batch
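
Note (editorial, not part of the patch): the collator's switch from torch.stack to torch.cat is needed because each sample's audio_values now arrives with a leading chunk dimension, so different samples can contribute different numbers of chunks and stacking would fail on mismatched shapes. A minimal runnable sketch with made-up shapes:

import torch
import torch.nn.functional as F

short_clip = torch.randn(1, 80, 1500)  # one chunk: clip fits the audio context
long_clip = torch.randn(2, 80, 3000)   # two chunks: clip was split by the processor
audio_values = [short_clip, long_clip]

# Same logic as the collator: right-pad every chunk to the longest frame count,
# then merge all chunks across the batch along dim 0.
max_len = max(x.shape[-1] for x in audio_values)
batch = torch.cat([F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values])
print(batch.shape)  # torch.Size([3, 80, 3000])

The audio tower then runs over all chunks in one pass, and audio_batch_size (see the model changes below) routes them back to their samples.
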
"Transcribe\n<|audio|>", array, 16000 + ) + output = inference.infer(sample) + assert output.input_tokens == 388 + assert output.output_tokens == 5 + assert output.text == "ers on conapub" + generate_args = inference.model.generate.call_args[1] + assert generate_args["audio_values"].shape == (2, 80, 3000) + assert generate_args["audio_token_len"].item() == torch.tensor(375) + assert generate_args["audio_token_start_idx"] == torch.tensor(8) + assert generate_args["audio_batch_size"] == torch.tensor(2) + + def test_infer_16kHz(tokenizer, audio_processor): """Ensure we handle 16kHz float32 audio properly.""" inference = FakeInference(tokenizer, audio_processor) @@ -148,8 +176,8 @@ def test_infer_text_only(tokenizer, audio_processor): sample = datasets.VoiceSample.from_prompt("Hello?") output = inference.infer(sample) assert output.input_tokens == 12 - assert output.output_tokens == 13 - assert output.text == "-./0123456789" + assert output.output_tokens == 5 + assert output.text == "-./01" generate_args = inference.model.generate.call_args[1] assert generate_args.get("audio_values") is None call_input_ids = generate_args["input_ids"] diff --git a/ultravox/inference/ultravox_infer.py b/ultravox/inference/ultravox_infer.py index 6765ece1..87911fcb 100644 --- a/ultravox/inference/ultravox_infer.py +++ b/ultravox/inference/ultravox_infer.py @@ -58,7 +58,10 @@ def __init__( ) processor = ultravox_processing.UltravoxProcessor( - audio_processor, tokenizer=tokenizer, stack_factor=model.config.stack_factor + audio_processor, + tokenizer=tokenizer, + stack_factor=model.config.stack_factor, + audio_context_size=model.audio_tower_context_length, ) super().__init__( diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py index cb525fdd..57719c1d 100644 --- a/ultravox/model/ultravox_model.py +++ b/ultravox/model/ultravox_model.py @@ -51,6 +51,10 @@ def __init__(self, config: UltravoxConfig): self.vocab_size = config.vocab_size self.audio_tower = self._create_audio_tower(config) + self.audio_tower_context_length: Optional[int] = None + if config.audio_model_id is not None and "whisper" in config.audio_model_id: + self.audio_tower_context_length = 3000 + self.multi_modal_projector = self._create_multi_modal_projector(config) self.language_model = self._create_language_model(config) @@ -155,6 +159,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, audio_token_start_idx: Optional[torch.Tensor] = None, audio_token_len: Optional[torch.Tensor] = None, + audio_batch_size: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None, # the alt_* fields are needed for KL divergence loss alt_input_ids: Optional[torch.Tensor] = None, @@ -186,27 +191,36 @@ def forward( inputs_embeds = self.get_input_embeddings().forward(input_ids) if audio_values is not None: + assert ( - audio_token_start_idx is not None and audio_token_len is not None - ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided." + audio_token_start_idx is not None + and audio_token_len is not None + and audio_batch_size is not None + ), "audio_token_start_idx and audio_token_len and audio_batch_size must be provided if audio_values are provided." assert ( - len(audio_token_start_idx) == len(audio_token_len) == len(audio_values) - ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size." 
diff --git a/ultravox/model/ultravox_pipeline.py b/ultravox/model/ultravox_pipeline.py
index c9a8aaa1..33bff932 100644
--- a/ultravox/model/ultravox_pipeline.py
+++ b/ultravox/model/ultravox_pipeline.py
@@ -37,6 +37,7 @@ def __init__(
             audio_processor=audio_processor,
             tokenizer=tokenizer,
             stack_factor=model.config.stack_factor,
+            audio_context_size=model.audio_tower_context_length,
         )
 
         super().__init__(model=model, tokenizer=tokenizer, **kwargs)
diff --git a/ultravox/model/ultravox_processing.py b/ultravox/model/ultravox_processing.py
index 211f7f0a..ab518a78 100644
--- a/ultravox/model/ultravox_processing.py
+++ b/ultravox/model/ultravox_processing.py
@@ -1,7 +1,8 @@
-from typing import Optional, Union
+from typing import Any, Dict, Optional, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 import transformers
 
 from .ultravox_config import UltravoxConfig
@@ -38,6 +39,9 @@ def __init__(
         encoder_ds_factor: int = 320,
         stack_factor: int = 8,
         audio_placeholder: str = "<|audio|>",
+        audio_context_size: Optional[
+            int
+        ] = 3000,  # Defaults to whisper encoder context size
     ):
         """
         Args:
@@ -53,6 +57,7 @@ def __init__(
         self.stack_factor = stack_factor
         self.audio_placeholder = audio_placeholder
         self.audio_token_replacement = tokenizer.eos_token
+        self.audio_context_size = audio_context_size
         assert (
             self.audio_token_replacement is not None
         ), "The tokenizer has no EOS token. Cannot recover."
@@ -132,7 +137,7 @@ def __call__(
                 - **audio_token_start_idx** -- The index in the tokenized text where the audio starts.
                   Returned when `audio` is not `None`.
         """
         # TODO: Add support for multiple audio and text inputs.
-        data = {}
+        data: Dict[str, Any] = {}
         audio_embed_frames = 0
         if audio is not None and len(audio) > 0:
             if self.audio_padding == "max_length":
@@ -141,6 +146,7 @@ def __call__(
                 audio_len = 30 * sampling_rate
             else:
                 audio_len = audio.shape[-1]
+
             # It's guaranteed that the number of frames is less than or equal to this amount.
             # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
             # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
@@ -157,9 +163,37 @@ def __call__(
                 **kwargs,
             )
             if "input_features" in x:
-                data["audio_values"] = x.input_features
+                audio_values = x.input_features
             else:
-                data["audio_values"] = x.input_values
+                audio_values = x.input_values
+
+            audio_values = torch.tensor(audio_values)
+            if (
+                self.audio_context_size
+                and audio_values.shape[-1] > self.audio_context_size
+            ):
+                audio_values_chunks = list(
+                    torch.split(
+                        audio_values,
+                        self.audio_context_size,
+                        dim=len(audio_values.shape) - 1,
+                    )
+                )
+                # Pad the last chunk to match audio_context_size
+                last_chunk = audio_values_chunks[-1]
+                pad_size = self.audio_context_size - last_chunk.shape[-1]
+                if pad_size > 0:
+                    # Pad only the last dimension (T) in B,D,T format
+                    audio_values_chunks[-1] = F.pad(
+                        last_chunk, (0, pad_size, 0, 0, 0, 0)
+                    )
+            else:
+                audio_values_chunks = [audio_values]
+
+            data["audio_values"] = torch.cat(audio_values_chunks)
+            num_audio_chunks = data["audio_values"].shape[0]
+
+            data["audio_batch_size"] = [num_audio_chunks]
 
         if text is not None:
             assert isinstance(
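
Note (editorial, not part of the patch): a rough worked example of the new chunking path, using the shapes from test_long_audio_context: 60 s of 16 kHz audio is 960,000 samples, which whisper-tiny's feature extractor turns into a (1, 80, 6000) log-mel tensor that is split into 3000-frame chunks. Padding is simplified here to the last dimension only; the real code pads a B x D x T tensor with a 6-element tuple.

import math
import torch
import torch.nn.functional as F

audio_context_size = 3000
audio_values = torch.randn(1, 80, 6000)     # stand-in for x.input_features

chunks = list(torch.split(audio_values, audio_context_size, dim=-1))
pad_size = audio_context_size - chunks[-1].shape[-1]
if pad_size > 0:
    chunks[-1] = F.pad(chunks[-1], (0, pad_size))

audio_values = torch.cat(chunks)            # -> torch.Size([2, 80, 3000])
audio_batch_size = [audio_values.shape[0]]  # -> [2]

# Text side: 960,000 samples / 320 samples per frame = 3000 encoder frames,
# stacked by 8 -> ceil(3000 / 8) = 375 audio placeholder tokens.
audio_token_len = math.ceil(round(960_000 / 320) / 8)
print(audio_values.shape, audio_batch_size, audio_token_len)

These are exactly the values the test asserts: audio_values of shape (2, 80, 3000), audio_batch_size of 2, and audio_token_len of 375.
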
diff --git a/ultravox/training/train.py b/ultravox/training/train.py
index 5dde3f62..6b337752 100644
--- a/ultravox/training/train.py
+++ b/ultravox/training/train.py
@@ -117,7 +117,6 @@ def train(args: config_base.TrainConfig):
     text_tokenizer.padding_side = "right"
     text_tokenizer.pad_token = text_tokenizer.eos_token
     audio_processor = transformers.AutoProcessor.from_pretrained(args.audio_model)
-    processor = ultravox_processing.UltravoxProcessor(audio_processor, text_tokenizer)
 
     # Instantiate the model and processor
     config = ultravox_config.UltravoxConfig(
@@ -142,6 +141,12 @@ def train(args: config_base.TrainConfig):
     with model_load_context:
         model = ultravox_model.UltravoxModel(config)
 
+    processor = ultravox_processing.UltravoxProcessor(
+        audio_processor,
+        text_tokenizer,
+        audio_context_size=model.audio_tower_context_length,
+    )
+
     assert model.get_input_embeddings().num_embeddings == len(
         text_tokenizer
     ), f"Model and tokenizer mismatch: {model.get_input_embeddings().num_embeddings} != {len(text_tokenizer)}"