Skip to content

Commit 9806cf1

Browse files
rajasbansal and zainhas authored
Add changes for audio speech and audio transcriptions (#388)
* Add changes for audio speech and audio transcriptions * Remove testing word stuff * Black formatting --------- Co-authored-by: Zain Hasan <[email protected]>
1 parent 2b1338f commit 9806cf1

File tree

3 files changed

+28
-17
lines changed

3 files changed

+28
-17
lines changed

src/together/resources/audio/speech.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def create(
3030
response_format: str = "wav",
3131
language: str = "en",
3232
response_encoding: str = "pcm_f32le",
33-
sample_rate: int = 44100,
33+
sample_rate: int | None = None,
3434
stream: bool = False,
3535
**kwargs: Any,
3636
) -> AudioSpeechStreamResponse:
@@ -49,14 +49,20 @@ def create(
4949
response_encoding (str, optional): Audio encoding of response.
5050
Defaults to "pcm_f32le".
5151
sample_rate (int, optional): Sampling rate to use for the output audio.
52-
Defaults to 44100.
52+
Defaults to None. If not provided, 44100 is used when the model name contains "cartesia" and 24000 is used for all other models.
5353
stream (bool, optional): If true, output is streamed for several characters at a time.
5454
Defaults to False.
5555
5656
Returns:
5757
Union[bytes, Iterator[AudioSpeechStreamChunk]]: The generated audio as bytes or an iterator over audio stream chunks.
5858
"""
5959

60+
if sample_rate is None:
61+
if "cartesia" in model:
62+
sample_rate = 44100
63+
else:
64+
sample_rate = 24000
65+
6066
requestor = api_requestor.APIRequestor(
6167
client=self._client,
6268
)

src/together/resources/audio/transcriptions.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def create(
3030
timestamp_granularities: Optional[
3131
Union[str, AudioTimestampGranularities]
3232
] = None,
33+
diarize: bool = False,
3334
**kwargs: Any,
3435
) -> Union[AudioTranscriptionResponse, AudioTranscriptionVerboseResponse]:
3536
"""
@@ -52,7 +53,11 @@ def create(
5253
timestamp_granularities: The timestamp granularities to populate for this
5354
transcription. response_format must be set to verbose_json to use timestamp
5455
granularities. Either or both of these options are supported: word, or segment.
55-
56+
diarize: Whether to enable speaker diarization. When enabled, you will get the speaker id for each word in the transcription.
57+
In the response, in the words array, you will get the speaker id for each word.
58+
In addition, we also return the speaker_segments array, which contains the speaker id for each speaker segment, the start and end time of the segment, and all the words in the segment.
59+
You can use the speaker_id to group the words by speaker.
60+
You can use the speaker_segments to get the start and end time of each speaker segment.
5661
Returns:
5762
The transcribed text in the requested format.
5863
"""
@@ -103,6 +108,9 @@ def create(
103108
else timestamp_granularities
104109
)
105110

111+
if diarize:
112+
params_data["diarize"] = diarize
113+
106114
# Add any additional kwargs
107115
# Convert boolean values to lowercase strings for proper form encoding
108116
for key, value in kwargs.items():
@@ -135,6 +143,7 @@ def create(
135143
if (
136144
response_format == "verbose_json"
137145
or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
146+
or diarize
138147
):
139148
# Create response with model validation that preserves extra fields
140149
return AudioTranscriptionVerboseResponse.model_validate(response.data)
@@ -158,6 +167,7 @@ async def create(
158167
timestamp_granularities: Optional[
159168
Union[str, AudioTimestampGranularities]
160169
] = None,
170+
diarize: bool = False,
161171
**kwargs: Any,
162172
) -> Union[AudioTranscriptionResponse, AudioTranscriptionVerboseResponse]:
163173
"""
@@ -180,7 +190,11 @@ async def create(
180190
timestamp_granularities: The timestamp granularities to populate for this
181191
transcription. response_format must be set to verbose_json to use timestamp
182192
granularities. Either or both of these options are supported: word, or segment.
183-
193+
diarize: Whether to enable speaker diarization. When enabled, you will get the speaker id for each word in the transcription.
194+
In the response, in the words array, you will get the speaker id for each word.
195+
In addition, we also return the speaker_segments array, which contains the speaker id for each speaker segment, the start and end time of the segment, and all the words in the segment.
196+
You can use the speaker_id to group the words by speaker.
197+
You can use the speaker_segments to get the start and end time of each speaker segment.
184198
Returns:
185199
The transcribed text in the requested format.
186200
"""
@@ -239,6 +253,9 @@ async def create(
239253
)
240254
)
241255

256+
if diarize:
257+
params_data["diarize"] = diarize
258+
242259
# Add any additional kwargs
243260
# Convert boolean values to lowercase strings for proper form encoding
244261
for key, value in kwargs.items():
@@ -271,6 +288,7 @@ async def create(
271288
if (
272289
response_format == "verbose_json"
273290
or response_format == AudioTranscriptionResponseFormat.VERBOSE_JSON
291+
or diarize
274292
):
275293
# Create response with model validation that preserves extra fields
276294
return AudioTranscriptionVerboseResponse.model_validate(response.data)

tests/integration/resources/test_transcriptions.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,6 @@ def validate_diarization_response(response_dict):
3636
assert "end" in word
3737
assert "speaker_id" in word
3838

39-
# Validate top-level words field
40-
assert "words" in response_dict
41-
assert isinstance(response_dict["words"], list)
42-
assert len(response_dict["words"]) > 0
43-
44-
# Validate each word in top-level words
45-
for word in response_dict["words"]:
46-
assert "id" in word
47-
assert "word" in word
48-
assert "start" in word
49-
assert "end" in word
50-
assert "speaker_id" in word
51-
5239

5340
class TestTogetherTranscriptions:
5441
@pytest.fixture

0 commit comments

Comments
 (0)