
Extend audio ds_tool #113

Open · liPatrick wants to merge 12 commits into main
ultravox/tools/ds_tool/ds_tool.py (105 additions, 7 deletions)
@@ -6,6 +6,7 @@

import datasets
import jinja2
import numpy as np
import openai
import simple_parsing
import yaml
@@ -175,6 +176,97 @@ def _map_sample(self, sample, exclude_fields):
        return sample


@dataclasses.dataclass
class AudioExtensionTask:
    audio_column_name: str = simple_parsing.field(default="audio", alias="-a")
    asr_column_name: str = simple_parsing.field(default="sentence", alias="-A")
    translation_column_name: str = simple_parsing.field(
        default="translation", alias="-T"
    )
    id_column_name: str = simple_parsing.field(default="id", alias="-i")
    extend_type: str = simple_parsing.field(
        default="repeat", alias="-e", choices=["repeat", "combine"]
    )
    multiplier: int = simple_parsing.field(default=2, alias="-m")

    def map_split(
        self,
        ds_split: datasets.Dataset,
        num_proc: int,
        writer_batch_size: int,
        exclude_fields: List[str],
    ) -> datasets.Dataset:
        print(
            f'Extending audio using "{self.extend_type}" method with multiplier {self.multiplier}'
        )

        if self.extend_type == "repeat":
            # One output row per input row: each sample's audio/text is tiled in place.
            return ds_split.map(
                function=self._map_sample_repeat,
                num_proc=num_proc,
                writer_batch_size=writer_batch_size,
                remove_columns=ds_split.column_names,
            )
        elif self.extend_type == "combine":
            # batch_size=multiplier groups that many consecutive rows per batch;
            # _map_batch_combine collapses each batch into a single row.
            return ds_split.map(
                function=self._map_batch_combine,
                batched=True,
                batch_size=self.multiplier,
                num_proc=num_proc,
                writer_batch_size=writer_batch_size,
                remove_columns=ds_split.column_names,
            )
        else:
            raise ValueError(f"Unknown extend_type: {self.extend_type}")

    def _map_sample_repeat(self, sample):
        audio = sample[self.audio_column_name]
        sentence = sample[self.asr_column_name]
        translation = sample[self.translation_column_name]

        if not isinstance(audio, dict) or "array" not in audio:
[Review comment from a Contributor]
Might be able to handle this automatically by using ds_split.cast_column to Audio. (Note also that this doesn't exist in the combine operation below.)

[Reply from liPatrick (author), Sep 19, 2024]
Hm, actually, I think if "array" isn't in audio, it'll just throw a KeyError (without the check), which should be fine in this case. I'm hesitant to stack more map operations than necessary because it takes a lot of time to process large datasets.
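For context, a minimal sketch of the reviewer's cast_column suggestion (the column name and 16 kHz sampling rate are illustrative assumptions, not part of this PR):

# Sketch only: casting up front guarantees decoded {"array": ..., "sampling_rate": ...}
# dicts, which would make the isinstance/"array" check above redundant.
ds_split = ds_split.cast_column("audio", datasets.Audio(sampling_rate=16_000))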

raise ValueError(f"Unsupported audio format: {type(audio)}")

audio_data = audio["array"]
repeated_audio = np.tile(audio_data, self.multiplier)
repeated_sentence = " ".join([sentence] * self.multiplier)
repeated_translation = " ".join([translation] * self.multiplier)

new_audio = {key: value for key, value in audio.items() if key != "path"}
liPatrick marked this conversation as resolved.
Show resolved Hide resolved
new_audio["array"] = repeated_audio

new_sample = {
self.audio_column_name: new_audio,
self.asr_column_name: repeated_sentence,
self.translation_column_name: repeated_translation,
self.id_column_name: sample[self.id_column_name],
}

return new_sample

    def _map_batch_combine(self, batch):
        audios = batch[self.audio_column_name]
        sentences = batch[self.asr_column_name]
        translations = batch[self.translation_column_name]
        ids = batch[self.id_column_name]

        # Concatenate every audio array in the batch; assumes a uniform sampling rate.
        combined_audio = {
            "sampling_rate": audios[0]["sampling_rate"],
            "array": np.concatenate([audio["array"] for audio in audios]),
        }
        combined_sentences = " ".join(sentences)
        combined_translations = " ".join(translations)
        combined_ids = "+".join(ids)

        # Single-element lists collapse the whole batch into one output row.
        new_batch = {
            self.audio_column_name: [combined_audio],
            self.asr_column_name: [combined_sentences],
            self.translation_column_name: [combined_translations],
            self.id_column_name: [combined_ids],
        }
        return new_batch
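For intuition, a toy sketch (not part of the diff) of what the two modes do to raw audio arrays:

import numpy as np

chunk = np.array([0.1, -0.2, 0.3])
print(np.tile(chunk, 2))        # repeat:  [ 0.1 -0.2  0.3  0.1 -0.2  0.3]

a, b = np.array([0.1, -0.2]), np.array([0.3])
print(np.concatenate([a, b]))   # combine: [ 0.1 -0.2  0.3]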


# This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model.
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -T {{question}} -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -T {{explanation}} -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
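# A hypothetical audioext invocation based on the aliases above (dataset and upload names are placeholders, not from this PR):
# just ds_tool audioext -d mozilla-foundation/common_voice_16_1 -u fixie-ai/cv-extended -e combine -m 3 --token $HF_WRITE_TOKEN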
@@ -218,10 +310,12 @@ class DatasetToolArgs:
        default_factory=lambda: ["audio"]
    )

-    task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(
-        {"tts": TtsTask, "textgen": TextGenerationTask},  # type: ignore
-        default_factory=TtsTask,
-        positional=True,
+    task: Union[TtsTask, TextGenerationTask, AudioExtensionTask] = (
+        simple_parsing.subgroups(
+            {"tts": TtsTask, "textgen": TextGenerationTask, "audioext": AudioExtensionTask},  # type: ignore
+            default_factory=TtsTask,
+            positional=True,
+        )
    )

    def __post_init__(self):
@@ -308,9 +402,9 @@ def process_and_upload_split_rescursive(
                    self.chunks_not_uploaded.append((start_index, end_index))
                    return None
                failed_chunk_ranges.append((chunk_start, chunk_end))
-        successful_chunks = self.args.num_chunks - len(failed_chunk_ranges)
+        successful_chunks = total_chunks - len(failed_chunk_ranges)
        print(
-            f"Finished processing and uploading {successful_chunks}/{self.args.num_chunks} chunks for range [{start_index}, {end_index})"
+            f"Finished processing and uploading {successful_chunks}/{total_chunks} chunks for range [{start_index}, {end_index})"
        )
        if len(failed_chunk_ranges) > 0:
            for start, end in failed_chunk_ranges:
@@ -358,7 +452,11 @@ def _upload(self, ds_chunk_processed: datasets.Dataset, data_dir: str, split_nam
"split": split_name,
}
assert isinstance(self.args.upload_name, str)
ds_split_chunked.push_to_hub(self.args.upload_name, **hub_args)
try:
ds_split_chunked.push_to_hub(self.args.upload_name, **hub_args)
except Exception as e:
print(f"Failed to upload chunk to hub: {e}")
raise e


def main(args: DatasetToolArgs):