Defining block size in UltravoxConfig, and solving assertions #157

Merged
6 commits merged on Nov 25, 2024
Changes from 5 commits
4 changes: 4 additions & 0 deletions ultravox/model/ultravox_config.py
@@ -63,6 +63,8 @@ class UltravoxConfig(transformers.PretrainedConfig):
The LoRA configuration for finetuning the text model.
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
The LoRA configuration for finetuning the audio model.
+ audio_latency_block_size (`int`, *optional*, defaults to `None`):
+ The latency block size for simulating audio streaming.


Example:
@@ -105,6 +107,7 @@ def __init__(
projector_act: str = "swiglu",
text_model_lora_config: Optional[LoraConfigSimplified] = None,
audio_model_lora_config: Optional[LoraConfigSimplified] = None,
+ audio_latency_block_size: Optional[int] = None,
**kwargs,
):
self.ignore_index = ignore_index
@@ -147,6 +150,7 @@ def __init__(
if isinstance(audio_model_lora_config, dict)
else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified())
)
+ self.audio_latency_block_size = audio_latency_block_size

self.vocab_size = self.text_config.vocab_size

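A minimal sketch of the new field in use, assuming the remaining UltravoxConfig arguments can be left at their defaults; the block size of 100 is only an illustrative value, not something prescribed by this PR:

from ultravox.model import ultravox_config

# Passing the new argument stores it on the config.
config = ultravox_config.UltravoxConfig(audio_latency_block_size=100)
assert config.audio_latency_block_size == 100

# Omitting it keeps the previous behaviour: no latency masking.
assert ultravox_config.UltravoxConfig().audio_latency_block_size is None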
20 changes: 16 additions & 4 deletions ultravox/model/ultravox_config_test.py
@@ -1,9 +1,16 @@
+ from typing import Dict, Tuple

import pytest
import transformers

from ultravox.model import ultravox_config


+ def exclude_key(d: Dict, key_to_exclude: Tuple) -> Dict:
+ """Exclude a specific key from a dictionary."""
+ return {k: v for k, v in d.items() if k not in key_to_exclude}


@pytest.mark.parametrize(
"model_id",
["fixie-ai/ultravox-v0_2", "fixie-ai/ultravox-v0_3", "fixie-ai/ultravox-v0_4"],
@@ -14,9 +21,14 @@ def test_can_load_release(model_id: str):
)
config_from_dict = ultravox_config.UltravoxConfig(**orig_config.to_dict())
config_from_diff_dict = ultravox_config.UltravoxConfig(**orig_config.to_diff_dict())
+ keys_to_ignore = ("audio_latency_block_size",)
+ orig_values = {
+ **{k: None for k in keys_to_ignore},
+ **orig_config.to_dict(),
+ }

- assert config_from_dict.to_dict() == orig_config.to_dict()
- assert config_from_diff_dict.to_dict() == orig_config.to_dict()
+ assert config_from_dict.to_dict() == orig_values
+ assert config_from_diff_dict.to_dict() == orig_values

assert config_from_dict.text_config.to_dict() == orig_config.text_config.to_dict()
assert config_from_dict.audio_config.to_dict() == orig_config.audio_config.to_dict()
@@ -25,8 +37,8 @@ def test_can_load_release(model_id: str):
config_reloaded_diff = ultravox_config.UltravoxConfig(
**config_from_dict.to_diff_dict()
)
- assert config_reloaded.to_dict() == orig_config.to_dict()
- assert config_reloaded_diff.to_dict() == orig_config.to_dict()
+ assert config_reloaded.to_dict() == orig_values
+ assert config_reloaded_diff.to_dict() == orig_values


def test_no_config_when_id_present():
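The orig_values merge is what lets older released configs, which never serialized audio_latency_block_size, compare equal to configs rebuilt through the new __init__. A small sketch with placeholder values, reusing the exclude_key helper defined above:

from typing import Dict, Tuple

def exclude_key(d: Dict, key_to_exclude: Tuple) -> Dict:
    """Exclude a specific key from a dictionary (as in the diff above)."""
    return {k: v for k, v in d.items() if k not in key_to_exclude}

# Placeholder stand-in for orig_config.to_dict(); real release configs hold many more fields.
orig = {"projector_act": "swiglu", "vocab_size": 128256}
keys_to_ignore = ("audio_latency_block_size",)

# The unpacking adds the missing key with the default None; any value already
# present in orig would win because it is unpacked last.
orig_values = {**{k: None for k in keys_to_ignore}, **orig}
assert orig_values["audio_latency_block_size"] is None
assert exclude_key(orig_values, keys_to_ignore) == orig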
4 changes: 2 additions & 2 deletions ultravox/model/ultravox_model.py
@@ -291,7 +291,7 @@ def _create_audio_tower(
config.audio_latency_block_size, dtype=config.torch_dtype
)
else:
- assert config.audio_latency_block_size not in (
+ assert config.audio_latency_block_size in (
None,
0,
), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
@@ -305,7 +305,7 @@ def _create_audio_tower(
config.audio_latency_block_size, dtype=config.torch_dtype
)
else:
- assert config.audio_latency_block_size not in (
+ assert config.audio_latency_block_size in (
None,
0,
), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
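Flipping "not in" to "in" is what makes the guard usable: before the fix, the else branch raised for every non-whisper audio tower, even with the default audio_latency_block_size of None. A standalone sketch of the corrected logic, with a hypothetical helper name (the real check lives inline in _create_audio_tower):

from typing import Optional

def check_latency_masking(is_whisper_tower: bool, audio_latency_block_size: Optional[int]) -> None:
    """Hypothetical standalone version of the guard fixed in this PR."""
    if is_whisper_tower:
        # Whisper towers apply the latency mask themselves, so any value is accepted.
        return
    assert audio_latency_block_size in (
        None,
        0,
    ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"

check_latency_masking(True, 100)    # ok: whisper tower supports masking
check_latency_masking(False, None)  # ok: masking not requested (the default)
check_latency_masking(False, 0)     # ok: masking explicitly disabled
# check_latency_masking(False, 100) # would raise AssertionError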