-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add adapter for HiSanta data #47
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,17 @@ | ||
import abc | ||
import base64 | ||
import dataclasses | ||
import datetime | ||
import enum | ||
import io | ||
import json | ||
import logging | ||
import os | ||
import tempfile | ||
from typing import Any, Callable, Dict, List, Optional, Sequence | ||
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence | ||
|
||
import datasets | ||
import google.cloud.storage as gcs | ||
import librosa | ||
import numpy as np | ||
import requests | ||
|
@@ -60,7 +63,7 @@ | |
|
||
# TODO(juberti): set these in the environment so they don't need to be hard-coded here. | ||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "service_account.json" | ||
os.environ["GOOGLE_CLOUD_PROJECT"] = "fixie-training" | ||
os.environ["GOOGLE_CLOUD_PROJECT"] = "fixie-frame" | ||
|
||
|
||
# Silence the spurious warnings coming from the MosaicML streaming library. | ||
|
@@ -244,7 +247,7 @@ def _load_audio_dataset( | |
gcs_path += f"/{name}" | ||
if split: | ||
gcs_path += f"/{split}" | ||
url = f"gs://fixie-datasets/mds/{gcs_path}" | ||
url = f"gs://fixie-training-datasets/mds/{gcs_path}" | ||
temp_dir = os.path.join( | ||
tempfile.gettempdir(), f"mds_{gcs_path.replace('/', '_')}" | ||
) | ||
|
@@ -681,6 +684,80 @@ def _get_sample(self, idx, row) -> VoiceSample: | |
return self._get_transcribe_sample(idx, row, tcol="text") | ||
|
||
|
||
class HiSantaDataset(data.IterableDataset):
    """
    A proprietary dataset from post-processed conversations with Santa and
    friends between 12/18 and 12/31 2023.

    The data is read from the "hisanta-dataset" GCS bucket. Iterating the
    dataset yields one VoiceSample per non-assistant message of each
    conversation; conversations whose agent has no known system prompt are
    skipped entirely.
    """

    class Subset(str, enum.Enum):
        ALL = "all"  # All recoverable data
        BEST = "best"  # Only the samples expected to have the best audio alignment

    def __init__(self, args: VoiceDatasetArgs, subset: Subset = Subset.BEST) -> None:
        """Connect to the bucket and load the conversation index and system prompts.

        Args:
            args: Dataset arguments; stored but not otherwise used yet.
            subset: Which conversation index ("{subset value}.json") to load.
                Either a Subset member or its string value.

        Raises:
            ValueError: If `subset` is not a valid Subset value.
            FileNotFoundError: If the index or prompts blob is missing.
        """
        super().__init__()
        self._args = args  # TODO(mdepinet): Respect whatever args we need to.
        self._bucket = gcs.Client().get_bucket("hisanta-dataset")

        # List of references to conversation metadata JSON files in the bucket.
        # These all look like {conversation_id}/metadata.json.
        self._conversations: list[str] = json.loads(
            self._download(self._index_blob_name(subset))
        )["conversations"]

        # List of system prompts for each agent, with the time at which each
        # became effective. These all look like
        # {"agent_id": AGENT_ID, "prompts": [{"start": ISO_TIMESTAMP, "prompt": PROMPT}]}
        # where the prompts are in chronological order.
        self._system_prompts: list[dict] = json.loads(
            self._download("hisanta_prompts.json")
        )

    @staticmethod
    def _index_blob_name(subset: "HiSantaDataset.Subset | str") -> str:
        """Return the name of the index blob for `subset`, e.g. "best.json".

        The member's `.value` is used explicitly: formatting a str-mixin Enum
        member in an f-string yields "Subset.BEST" on Python 3.11+, which would
        produce the wrong blob name.
        """
        return f"{HiSantaDataset.Subset(subset).value}.json"

    def _download(self, path: str) -> bytes:
        """Download the blob at `path` from the bucket.

        Raises:
            FileNotFoundError: If the bucket has no blob at `path`
                (`get_blob` returns None in that case).
        """
        blob = self._bucket.get_blob(path)
        if blob is None:
            raise FileNotFoundError(f"Blob {path!r} not found in the HiSanta bucket")
        return blob.download_as_bytes()

    def __iter__(self) -> Generator[VoiceSample, Any, None]:
        """Yield the samples of every conversation assigned to this worker.

        Conversations are split across DataLoader workers by striding: worker
        `id` handles indices id, id + num_workers, id + 2 * num_workers, ...
        Without worker info, all conversations are visited in order.
        """
        worker_info = data.get_worker_info()
        start = worker_info.id if worker_info else 0
        increment = worker_info.num_workers if worker_info else 1
        # NOTE(review): We'll probably have to experiment with how to form our
        # samples here.
        for i in range(start, len(self._conversations), increment):
            yield from self._from_conversation(i)

    def _from_conversation(self, idx: int) -> Generator[VoiceSample, Any, None]:
        """Yield one VoiceSample per non-assistant message of conversation `idx`.

        Each sample's messages are the system prompt, all earlier turns as
        text, and a final user turn holding the "<|audio|>" placeholder. The
        audio is downloaded from "{conversation_id}/{message['speech']}".
        Yields nothing when no system prompt applies to the conversation.
        """
        # Metadata references look like {conversation_id}/metadata.json.
        conversation_id = self._conversations[idx].split("/")[0]
        conversation = json.loads(self._download(self._conversations[idx]))
        system_prompt = self._get_system_prompt(
            conversation["agent_id"],
            datetime.datetime.fromisoformat(conversation["call_time"]),
        )
        if not system_prompt:
            return
        history = [{"role": "system", "content": system_prompt}]
        for message in conversation["messages"]:
            if message["role"] == "assistant":
                history.append({"role": "assistant", "content": message["message"]})
            else:
                audio = self._download(f"{conversation_id}/{message['speech']}")
                # NOTE(review): Current assumption is that the last message
                # should be the assistant message.
                yield VoiceSample(
                    messages=[*history, {"role": "user", "content": "<|audio|>"}],
                    audio=audio_from_buf(audio),
                    sample_rate=SAMPLE_RATE,
                    audio_transcript=message["message"],
                )
                # Later samples see this turn as its text transcript.
                history.append({"role": "user", "content": message["message"]})

    def _get_system_prompt(
        self, agent_id: str, call_time: datetime.datetime
    ) -> str | None:
        """Return the prompt in effect for `agent_id` at `call_time`.

        Returns None when the agent is not listed or when none of its prompts
        had started by `call_time`.
        """
        # NOTE(review): assumes "start" timestamps and call_time are both naive
        # or both timezone-aware; mixing them raises TypeError - confirm.
        for agent_prompts in self._system_prompts:
            if agent_prompts["agent_id"] != agent_id:
                continue
            # Prompts are chronological, so the first match scanning backwards
            # is the most recent one that had started by call_time.
            for prompt in reversed(agent_prompts["prompts"]):
                if datetime.datetime.fromisoformat(prompt["start"]) <= call_time:
                    return prompt["prompt"]
        return None  # UGC agent
|
||
|
||
def create_dataset(name: str, args: VoiceDatasetArgs) -> data.IterableDataset: | ||
DATASET_MAP: Dict[str, Any] = { | ||
"anyinstruct": AnyInstructAnswerDataset, | ||
|
@@ -690,6 +767,7 @@ def create_dataset(name: str, args: VoiceDatasetArgs) -> data.IterableDataset: | |
"boolq_in": BoolQInputDataset, | ||
"boolq_extended": BoolQWithExtendedAnswerDataset, | ||
"gigaspeech": GigaSpeechDataset, | ||
"hisanta": HiSantaDataset, | ||
"librispeech": LibriSpeechDataset, | ||
"voxpopuli": VoxPopuliDataset, | ||
"commonvoice": CommonVoiceDataset, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Why use strings for comments?