Update gradio demo to support text/voice conversation (#75)
Run `just gradio` to start a text/voice conversation demo with model `fixie-ai/ultravox-v0_3`
zqhuang211 authored Aug 16, 2024
1 parent b4a4fc5 commit e5caca9
Showing 9 changed files with 1,816 additions and 1,810 deletions.
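
For orientation before the diffs: a minimal Gradio 4.x text/voice chat UI might be sketched as below. This is an assumption-laden sketch, not the repository's actual demo (whose diff is not fully rendered on this page); `run_inference` is a hypothetical stand-in for a call into the ultravox inference stack.

```python
# Hypothetical sketch of a Gradio 4.x text/voice chat demo; the real demo in
# the repo may differ. run_inference is a placeholder, not a real ultravox API.
import gradio as gr

def run_inference(text: str, audio_path: str | None, history: list) -> str:
    # Placeholder: forward the text and/or audio to the model, return its reply.
    return "(model reply)"

def respond(message: str, audio_path: str | None, history: list):
    reply = run_inference(message, audio_path, history)
    history = history + [(message or "(voice input)", reply)]
    # Clear the textbox and microphone input after each turn.
    return history, "", None

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    text_in = gr.Textbox(label="Type a message")
    audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Or speak")
    text_in.submit(respond, [text_in, audio_in, chatbot], [chatbot, text_in, audio_in])

demo.launch()
```
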
1 change: 1 addition & 0 deletions .gitignore
@@ -125,6 +125,7 @@ ipython_config.py
 # commonly ignored for libraries.
 # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
 #poetry.lock
+poetry.toml

 # pdm
 # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
3,377 changes: 1,601 additions & 1,776 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
@@ -9,13 +9,13 @@ readme = "README.md"

 [tool.poetry.dependencies]
 python = "^3.11"
-torch = "2.2.2"
+torch = ">=2.4"
 transformers = {version = ">=4.43.1", extras = ["torch"]}
 bitsandbytes = "~0.42.0"
 peft = "~0.11.1"
 simple-parsing = "~0.1.5"
 librosa = "~0.10.2.post1"
-requests = "~2.26.0"
+requests = "~2.31.0"
 datasets = "~2.19.1"
 mosaicml-streaming = "~0.7.6"
 nltk = "~3.8.1"
@@ -39,8 +39,8 @@ fsspec = "~2024.3.1"
 gcsfs = "~2024.3.1"
 sounddevice = "~0.4.7"
 mosaicml-cli = "~0.6.31"
-gradio-client = "~1.0.1"
-gradio = "~3.40.1"
+gradio-client = ">=0.16.1"
+gradio = ">=4.29.0"
 gpustat = "~1.1.1"
 types-requests = "~2.26.0"
 types-pyyaml = "^6.0.12.20240724"
3 changes: 3 additions & 0 deletions ultravox/data/datasets.py
@@ -193,6 +193,9 @@ def __post_init__(self):
         ), f"Unexpected audio dtype: {self.audio.dtype}"
         assert self.audio.ndim == 1, f"Unexpected audio shape: {self.audio.shape}"

+    def add_past_messages(self, past_messages: List[Dict[str, str]]):
+        self.messages = past_messages + self.messages
+
     messages: List[Dict[str, str]]
     """List of messages, each with a "role" and "content" field."""
     audio: Optional[np.typing.NDArray[np.float32]] = None
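
A small usage sketch for the new helper. The sample construction here is hypothetical; `VoiceSample` is the dataclass shown above, with `messages` as its first field.

```python
# Hypothetical usage of add_past_messages: prepend prior conversation turns
# to a fresh VoiceSample before inference.
from ultravox.data import datasets

past = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
]
sample = datasets.VoiceSample(messages=[{"role": "user", "content": "And of Germany?"}])
sample.add_past_messages(past)
# sample.messages now holds the two past turns followed by the new user turn.
```
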
69 changes: 61 additions & 8 deletions ultravox/inference/infer.py
@@ -1,5 +1,6 @@
+import copy
 import threading
-from typing import Optional
+from typing import Dict, List, Optional, Tuple, Union

 import librosa
 import numpy as np
@@ -11,7 +12,7 @@
 from ultravox.model import ultravox_processing

 SAMPLE_RATE = 16000
-MAX_TOKENS = 1024
+MAX_NEW_TOKENS = 1024
 # Without this penalty, the model tends to repeat itself.
 REPETITION_PENALTY = 1.1

@@ -24,26 +25,70 @@ def __init__(
         tokenizer: transformers.PreTrainedTokenizer,
         device: str,
         dtype: torch.dtype,
+        conversation_mode: bool = False,
     ):
         self.model = model.to(device).to(dtype).eval()
         self.tokenizer = tokenizer
         self.processor = processor
         self.dtype = dtype

+        self.conversation_mode = conversation_mode
+        self.past_messages: List[Dict[str, str]] = []
+        self.past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = (
+            None
+        )
+
+    def update_conversation(
+        self,
+        past_messages: List[Dict[str, str]] = [],
+        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
+    ):
+        self.past_messages = past_messages
+        self.past_key_values = past_key_values
+
+    def _get_sample_with_past(
+        self, sample: datasets.VoiceSample
+    ) -> datasets.VoiceSample:
+        sample = copy.copy(sample)
+        sample.add_past_messages(self.past_messages)
+        return sample
+
     def infer(
         self,
         sample: datasets.VoiceSample,
         max_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
     ) -> base.VoiceOutput:
-        inputs = self._dataproc(sample)
+        extended_sample = self._get_sample_with_past(sample)
+        inputs = self._dataproc(extended_sample)
         input_len = inputs["input_ids"].shape[1]
-        output = self._generate(inputs, max_tokens, temperature)
-        output_tokens = output[0][input_len:]
+        output = self._generate(
+            inputs, max_tokens, temperature, past_key_values=self.past_key_values
+        )
+        output_tokens = output.sequences[0][input_len:]
         output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
         output_len = len(output_tokens)

+        if self.conversation_mode:
+            past_messages = copy.deepcopy(extended_sample.messages)
+            audio_token_len = (
+                0 if "audio_token_len" not in inputs else inputs["audio_token_len"][0]
+            )
+            if audio_token_len > 0:
+                user_content = past_messages[-1]["content"]
+                if user_content.count("<|audio|>") != 1:
+                    raise ValueError(
+                        f"Expected 1 audio placeholder, found {user_content.count('<|audio|>')}"
+                    )
+                past_messages[-1]["content"] = user_content.replace(
+                    "<|audio|>", self.tokenizer.eos_token * audio_token_len
+                )
+            past_messages.append({"role": "assistant", "content": output_text})
+            self.update_conversation(past_messages, output.past_key_values)
+
         return base.VoiceOutput(output_text, input_len, output_len)

+    # streaming is not supported in conversation mode yet, to be implemented
     def infer_stream(
         self,
         sample: datasets.VoiceSample,
@@ -57,7 +102,12 @@ def infer_stream(
             self.tokenizer, skip_prompt=True, decode_kwargs=decode_kwargs
         )

-        thread_args = (inputs, max_tokens, temperature, streamer)
+        thread_args = (
+            inputs,
+            max_tokens,
+            temperature,
+            streamer,
+        )
         thread = threading.Thread(target=self._generate, args=thread_args)
         thread.start()
         output_tokens = 0
@@ -108,9 +158,10 @@ def _dataproc(self, sample: datasets.VoiceSample):
     def _generate(
         self,
         inputs: torch.Tensor,
-        max_tokens: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
         streamer: Optional[transformers.TextStreamer] = None,
+        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
     ):
         temperature = temperature or None
         do_sample = temperature is not None
@@ -122,10 +173,12 @@ def _generate(
         return self.model.generate(
             **inputs,
             do_sample=do_sample,
-            max_new_tokens=max_tokens or MAX_TOKENS,
+            max_new_tokens=max_new_tokens or MAX_NEW_TOKENS,
             temperature=temperature,
             repetition_penalty=REPETITION_PENALTY,
             pad_token_id=self.tokenizer.eos_token_id,
             eos_token_id=terminators,
             streamer=streamer,
+            past_key_values=past_key_values,
+            return_dict_in_generate=True,
         )
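
The conversation-mode bookkeeping above has one subtle step: the stored history must re-tokenize to exactly the token positions the KV cache already covers, so the single `<|audio|>` marker is swapped for `audio_token_len` copies of the eos token, one per audio embedding position. A toy illustration of that substitution, with invented values:

```python
# Toy illustration of the placeholder substitution in conversation mode.
# "</s>" stands in for tokenizer.eos_token; audio_token_len is invented.
eos_token = "</s>"
audio_token_len = 3  # the audio clip occupied 3 token positions in the prompt

user_turn = {"role": "user", "content": "Transcribe this: <|audio|>"}
assert user_turn["content"].count("<|audio|>") == 1

# One eos token per audio position keeps the text history the same length,
# in tokens, as the cached prefix that generate() already saw.
user_turn["content"] = user_turn["content"].replace(
    "<|audio|>", eos_token * audio_token_len
)
print(user_turn["content"])  # Transcribe this: </s></s></s>
```
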
6 changes: 4 additions & 2 deletions ultravox/inference/infer_test.py
@@ -35,10 +35,12 @@ def __init__(
     ):
         def fake_generate(**kwargs):
             input = kwargs.get("input_ids")
-            output = [range(25)]
+            output = transformers.generation.utils.GenerateDecoderOnlyOutput(
+                sequences=[range(25)]
+            )
             streamer = kwargs.get("streamer", None)
             if streamer:
-                for token in output[0][input.shape[1] :]:
+                for token in output.sequences[0][input.shape[1] :]:
                     streamer.on_finalized_text(tokenizer.decode(token))
                 streamer.on_finalized_text("", stream_end=True)
             return output
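
The test now mocks the structured output that `generate()` returns when called with `return_dict_in_generate=True`. A minimal real-world sketch of that return type (model choice is arbitrary; any causal LM behaves the same way):

```python
# With return_dict_in_generate=True, generate() returns an output object
# rather than a bare tensor of token ids.
import transformers

tok = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True)

# out is a GenerateDecoderOnlyOutput exposing .sequences and, when caching is
# enabled, .past_key_values — which conversation mode reuses across turns.
print(type(out).__name__, out.sequences.shape)
```
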
3 changes: 3 additions & 0 deletions ultravox/inference/ultravox_infer.py
@@ -17,6 +17,7 @@ def __init__(
         tokenizer_id: Optional[str] = None,
         device: Optional[str] = None,
         data_type: Optional[str] = None,
+        conversation_mode: bool = False,
     ):
         """
         Args:
@@ -29,6 +30,7 @@
             tokenizer_id: model_id for the tokenizer to use. If not provided, it will be inferred
             device: where to put the model and data
             data_type: data type to use for the model
+            conversation_mode: if true, keep track of past messages in a conversation
         """
         device = device or utils.default_device()
         dtype = utils.get_dtype(data_type) if data_type else utils.default_dtype()
@@ -65,4 +67,5 @@
             tokenizer=tokenizer,
             device=device,
             dtype=dtype,
+            conversation_mode=conversation_mode,
         )
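
A hypothetical multi-turn driver for the new flag. The positional model-path argument, the sample construction, and the output field name are assumptions for illustration, not confirmed API:

```python
# Hypothetical multi-turn session; argument and field names are assumed.
from ultravox.data import datasets
from ultravox.inference import ultravox_infer

inference = ultravox_infer.UltravoxInference(
    "fixie-ai/ultravox-v0_3",  # model path, assumed positional
    conversation_mode=True,
)
for prompt in ["Hi there!", "What did I just ask you?"]:
    sample = datasets.VoiceSample(messages=[{"role": "user", "content": prompt}])
    output = inference.infer(sample)
    print(output.text)  # VoiceOutput field name assumed

# Clear the stored history and KV cache before starting a new session.
inference.update_conversation()
```
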
16 changes: 13 additions & 3 deletions ultravox/model/ultravox_model.py
@@ -237,20 +237,30 @@ def prepare_inputs_for_generation(
         past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Dict[str, Any]:
         model_input = self.language_model.prepare_inputs_for_generation(
             input_ids=input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
             **kwargs,
         )

-        if is_cache_empty(past_key_values) and audio_values is not None:
-            # We only want to use audio features in the 1st generation step
+        # include audio information in model_input only when it is needed during prefilling
+        # audio_token_start_idx should always be relative to the current cache position
+        prefill_start_idx = 0 if cache_position is None else cache_position[0]
+        if (
+            audio_values is not None
+            and audio_token_start_idx is not None
+            and prefill_start_idx <= torch.max(audio_token_start_idx)
+        ):
             model_input["audio_values"] = audio_values
-            model_input["audio_token_start_idx"] = audio_token_start_idx
+            model_input["audio_token_start_idx"] = (
+                audio_token_start_idx - prefill_start_idx
+            )
             model_input["audio_token_len"] = audio_token_len

         return model_input
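
The arithmetic in the new prefill check is easiest to see with concrete numbers. This toy snippet mirrors the condition and the index rebasing above, with invented values:

```python
# Toy numbers mirroring the prefill logic above. Suppose 42 tokens are
# already cached and this forward pass covers absolute positions 42..59.
import torch

cache_position = torch.arange(42, 60)
audio_token_start_idx = torch.tensor([50])  # audio span starts at absolute 50

prefill_start_idx = 0 if cache_position is None else cache_position[0]
# The audio span has not been consumed yet, so audio features must be passed:
assert prefill_start_idx <= torch.max(audio_token_start_idx)
# Its start index is rebased to the current chunk: 50 - 42 = 8.
relative_start = audio_token_start_idx - prefill_start_idx
print(relative_start)  # tensor([8])
```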