Skip to content

Commit 40bdf25

Browse files
authored
Merge pull request #144 from dongri/fix-audio-transcription
Fix audio transcription
2 parents b763a66 + 3b296c8 commit 40bdf25

File tree

3 files changed

+90
-11
lines changed

3 files changed

+90
-11
lines changed

examples/audio_transcriptions.rs

+19-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
use openai_api_rs::v1::api::OpenAIClient;
22
use openai_api_rs::v1::audio::{AudioTranscriptionRequest, WHISPER_1};
33
use std::env;
4+
use std::fs::File;
5+
use std::io::Read;
46

57
#[tokio::main]
68
async fn main() -> Result<(), Box<dyn std::error::Error>> {
79
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
810
let client = OpenAIClient::builder().with_api_key(api_key).build()?;
911

10-
let req = AudioTranscriptionRequest::new(
11-
"examples/data/problem.mp3".to_string(),
12-
WHISPER_1.to_string(),
13-
);
12+
let file_path = "examples/data/problem.mp3";
13+
14+
// Test with file
15+
let req = AudioTranscriptionRequest::new(file_path.to_string(), WHISPER_1.to_string());
1416

1517
let req_json = req.clone().response_format("json".to_string());
1618

@@ -22,7 +24,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
2224
let result = client.audio_transcription_raw(req_raw).await?;
2325
println!("{:?}", result);
2426

27+
// Test with bytes
28+
let mut file = File::open(file_path)?;
29+
let mut buffer = Vec::new();
30+
file.read_to_end(&mut buffer)?;
31+
32+
let req = AudioTranscriptionRequest::new_bytes(buffer, WHISPER_1.to_string());
33+
34+
let req_json = req.clone().response_format("json".to_string());
35+
36+
let result = client.audio_transcription(req_json).await?;
37+
println!("{:?}", result);
38+
2539
Ok(())
2640
}
2741

28-
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_translations
42+
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_transcriptions

src/v1/api.rs

+54-4
Original file line numberDiff line numberDiff line change
@@ -310,31 +310,49 @@ impl OpenAIClient {
310310
&self,
311311
req: AudioTranscriptionRequest,
312312
) -> Result<AudioTranscriptionResponse, APIError> {
313-
// https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
313+
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
314314
if let Some(response_format) = &req.response_format {
315315
if response_format != "json" && response_format != "verbose_json" {
316316
return Err(APIError::CustomError {
317317
message: "response_format must be either 'json' or 'verbose_json' please use audio_transcription_raw".to_string(),
318318
});
319319
}
320320
}
321-
let form = Self::create_form(&req, "file")?;
321+
let form: Form;
322+
if req.clone().file.is_some() {
323+
form = Self::create_form(&req, "file")?;
324+
} else if let Some(bytes) = req.clone().bytes {
325+
form = Self::create_form_from_bytes(&req, bytes)?;
326+
} else {
327+
return Err(APIError::CustomError {
328+
message: "Either file or bytes must be provided".to_string(),
329+
});
330+
}
322331
self.post_form("audio/transcriptions", form).await
323332
}
324333

325334
pub async fn audio_transcription_raw(
326335
&self,
327336
req: AudioTranscriptionRequest,
328337
) -> Result<Bytes, APIError> {
329-
// https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
338+
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
330339
if let Some(response_format) = &req.response_format {
331340
if response_format != "text" && response_format != "srt" && response_format != "vtt" {
332341
return Err(APIError::CustomError {
333342
message: "response_format must be either 'text', 'srt' or 'vtt', please use audio_transcription".to_string(),
334343
});
335344
}
336345
}
337-
let form = Self::create_form(&req, "file")?;
346+
let form: Form;
347+
if req.clone().file.is_some() {
348+
form = Self::create_form(&req, "file")?;
349+
} else if let Some(bytes) = req.clone().bytes {
350+
form = Self::create_form_from_bytes(&req, bytes)?;
351+
} else {
352+
return Err(APIError::CustomError {
353+
message: "Either file or bytes must be provided".to_string(),
354+
});
355+
}
338356
self.post_form_raw("audio/transcriptions", form).await
339357
}
340358

@@ -823,4 +841,36 @@ impl OpenAIClient {
823841

824842
Ok(form)
825843
}
844+
845+
fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
846+
where
847+
T: Serialize,
848+
{
849+
let json = match serde_json::to_value(req) {
850+
Ok(json) => json,
851+
Err(e) => {
852+
return Err(APIError::CustomError {
853+
message: e.to_string(),
854+
})
855+
}
856+
};
857+
858+
let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));
859+
860+
if let Value::Object(map) = json {
861+
for (key, value) in map.into_iter() {
862+
match value {
863+
Value::String(s) => {
864+
form = form.text(key, s);
865+
}
866+
Value::Number(n) => {
867+
form = form.text(key, n.to_string());
868+
}
869+
_ => {}
870+
}
871+
}
872+
}
873+
874+
Ok(form)
875+
}
826876
}

src/v1/audio.rs

+17-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@ pub const WHISPER_1: &str = "whisper-1";
88

99
#[derive(Debug, Serialize, Clone)]
1010
pub struct AudioTranscriptionRequest {
11-
pub file: String,
1211
pub model: String,
1312
#[serde(skip_serializing_if = "Option::is_none")]
13+
pub file: Option<String>,
14+
#[serde(skip_serializing_if = "Option::is_none")]
15+
pub bytes: Option<Vec<u8>>,
1416
pub prompt: Option<String>,
1517
#[serde(skip_serializing_if = "Option::is_none")]
1618
pub response_format: Option<String>,
@@ -23,8 +25,21 @@ pub struct AudioTranscriptionRequest {
2325
impl AudioTranscriptionRequest {
2426
pub fn new(file: String, model: String) -> Self {
2527
Self {
26-
file,
2728
model,
29+
file: Some(file),
30+
bytes: None,
31+
prompt: None,
32+
response_format: None,
33+
temperature: None,
34+
language: None,
35+
}
36+
}
37+
38+
pub fn new_bytes(bytes: Vec<u8>, model: String) -> Self {
39+
Self {
40+
model,
41+
file: None,
42+
bytes: Some(bytes),
2843
prompt: None,
2944
response_format: None,
3045
temperature: None,

0 commit comments

Comments
 (0)