diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 750f62e4..29c70e5f 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -6,6 +6,7 @@
 
 import datasets
 import jinja2
+import numpy as np
 import openai
 import simple_parsing
 import yaml
@@ -175,6 +176,103 @@ def _map_sample(self, sample, exclude_fields):
         return sample
 
 
+@dataclasses.dataclass
+class AudioExtensionTask:
+    audio_column_name: str = simple_parsing.field(default="audio", alias="-a")
+    text_column_name: str = simple_parsing.field(default="sentence", alias="-A")
+    translation_column_name: Optional[str] = simple_parsing.field(
+        default="translation", alias="-T"
+    )
+    id_column_name: str = simple_parsing.field(default="id", alias="-i")
+    extend_type: str = simple_parsing.field(
+        default="repeat", alias="-o", choices=["repeat", "combine"]
+    )
+    multiplier: int = simple_parsing.field(default=2, alias="-m")
+
+    def map_split(
+        self,
+        ds_split: datasets.Dataset,
+        num_proc: int,
+        writer_batch_size: int,
+        exclude_fields: List[str],
+    ) -> datasets.Dataset:
+        print(
+            f'Extending audio using "{self.extend_type}" method with multiplier {self.multiplier}'
+        )
+
+        if self.extend_type == "repeat":
+            return ds_split.map(
+                function=self._map_sample_repeat,
+                num_proc=num_proc,
+                writer_batch_size=writer_batch_size,
+                remove_columns=ds_split.column_names,
+            )
+        elif self.extend_type == "combine":
+            return ds_split.map(
+                function=self._map_batch_combine,
+                batched=True,
+                batch_size=self.multiplier,
+                num_proc=num_proc,
+                writer_batch_size=writer_batch_size,
+                remove_columns=ds_split.column_names,
+            )
+        else:
+            raise ValueError(f"Unknown extend_type: {self.extend_type}")
+
+    def _map_sample_repeat(self, sample):
+        audio = sample[self.audio_column_name]
+
+        new_audio = {
+            "sampling_rate": audio["sampling_rate"],
+            "array": np.tile(audio["array"], self.multiplier),
+        }
+
+        new_sample = {
+            self.audio_column_name: new_audio,
+            self.text_column_name: " ".join(
+                [sample[self.text_column_name]] * self.multiplier
+            ),
+            self.id_column_name: sample[self.id_column_name],
+        }
+
+        if self.translation_column_name is not None:
+            translation = sample.get(self.translation_column_name)
+            if translation is not None:
+                new_sample[self.translation_column_name] = " ".join(
+                    [translation] * self.multiplier
+                )
+
+        return new_sample
+
+    def _map_batch_combine(self, batch):
+        audios = batch[self.audio_column_name]
+        sentences = batch[self.text_column_name]
+        ids = batch[self.id_column_name]
+
+        combined_audio = {
+            "sampling_rate": audios[0]["sampling_rate"],
+            "array": np.concatenate([audio["array"] for audio in audios]),
+        }
+        combined_sentences = " ".join(sentences)
+        combined_ids = "+".join(ids)
+
+        new_batch = {
+            self.audio_column_name: [combined_audio],
+            self.text_column_name: [combined_sentences],
+            self.id_column_name: [combined_ids],
+        }
+
+        if self.translation_column_name in batch:
+            translations = batch[self.translation_column_name]
+            if translations is not None and all(
+                translation is not None for translation in translations
+            ):
+                combined_translations = " ".join(translations)
+                new_batch[self.translation_column_name] = [combined_translations]
+
+        return new_batch
+
+
 # This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model.
 # just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -T {{question}} -a audio --token $HF_WRITE_TOKEN
 # just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -T {{explanation}} -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
@@ -218,10 +316,12 @@ class DatasetToolArgs:
         default_factory=lambda: ["audio"]
     )
 
-    task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(
-        {"tts": TtsTask, "textgen": TextGenerationTask},  # type: ignore
-        default_factory=TtsTask,
-        positional=True,
+    task: Union[TtsTask, TextGenerationTask, AudioExtensionTask] = (
+        simple_parsing.subgroups(
+            {"tts": TtsTask, "textgen": TextGenerationTask, "audioext": AudioExtensionTask},  # type: ignore
+            default_factory=TtsTask,
+            positional=True,
+        )
     )
 
     def __post_init__(self):
@@ -308,9 +408,9 @@ def process_and_upload_split_rescursive(
                     self.chunks_not_uploaded.append((start_index, end_index))
                     return None
                 failed_chunk_ranges.append((chunk_start, chunk_end))
-        successful_chunks = self.args.num_chunks - len(failed_chunk_ranges)
+        successful_chunks = total_chunks - len(failed_chunk_ranges)
         print(
-            f"Finished processing and uploading {successful_chunks}/{self.args.num_chunks} chunks for range [{start_index}, {end_index})"
+            f"Finished processing and uploading {successful_chunks}/{total_chunks} chunks for range [{start_index}, {end_index})"
         )
         if len(failed_chunk_ranges) > 0:
             for start, end in failed_chunk_ranges:
@@ -358,7 +458,11 @@ def _upload(self, ds_chunk_processed: datasets.Dataset, data_dir: str, split_nam
             "split": split_name,
         }
         assert isinstance(self.args.upload_name, str)
-        ds_split_chunked.push_to_hub(self.args.upload_name, **hub_args)
+        try:
+            ds_split_chunked.push_to_hub(self.args.upload_name, **hub_args)
+        except Exception as e:
+            print(f"Failed to upload chunk to hub: {e}")
+            raise e
 
 
 def main(args: DatasetToolArgs):
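By analogy with the "# just ds_tool tts ..." and "# just ds_tool textgen ..." usage comments in the file header, the new task would presumably be invoked as "ds_tool audioext". The invocations below are an illustrative sketch only: the dataset and upload names are placeholders, not taken from this change, while the -a/-A/-T/-i/-o/-m flags come from the AudioExtensionTask field aliases added above.

    # hypothetical: repeat each clip twice and duplicate its transcript
    just ds_tool audioext -d some-org/some-asr-dataset -u some-org/some-asr-dataset-repeated -a audio -A sentence -i id -o repeat -m 2
    # hypothetical: concatenate every 3 consecutive rows into one longer sample
    just ds_tool audioext -d some-org/some-asr-dataset -u some-org/some-asr-dataset-combined -o combine -m 3

The "combine" path relies on datasets.Dataset.map(batched=True, batch_size=multiplier) handing _map_batch_combine that many consecutive rows at a time (fewer in the final batch) and letting it return a single-row batch. A minimal, self-contained sketch of that behavior on toy text data, assuming the standard Hugging Face datasets API and not reusing code from this diff:

    import datasets

    # Toy dataset standing in for an ASR split with id/sentence columns.
    ds = datasets.Dataset.from_dict(
        {"id": ["a", "b", "c", "d"], "sentence": ["one", "two", "three", "four"]}
    )

    def combine(batch):
        # Collapse the whole incoming batch (here 2 rows) into a single output row.
        return {"id": ["+".join(batch["id"])], "sentence": [" ".join(batch["sentence"])]}

    combined = ds.map(combine, batched=True, batch_size=2, remove_columns=ds.column_names)
    print(combined["id"])  # ['a+b', 'c+d']

One behavior worth keeping in mind: with num_proc > 1, map batches within each contiguous shard, and the last batch of a shard (or of the split) can hold fewer than multiplier rows, so some combined samples may come out shorter than expected.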