diff --git a/examples/audio_transcriptions.rs b/examples/audio_transcriptions.rs
index 49f5b88..f74fc11 100644
--- a/examples/audio_transcriptions.rs
+++ b/examples/audio_transcriptions.rs
@@ -1,16 +1,18 @@
 use openai_api_rs::v1::api::OpenAIClient;
 use openai_api_rs::v1::audio::{AudioTranscriptionRequest, WHISPER_1};
 use std::env;
+use std::fs::File;
+use std::io::Read;
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
     let client = OpenAIClient::builder().with_api_key(api_key).build()?;
 
-    let req = AudioTranscriptionRequest::new(
-        "examples/data/problem.mp3".to_string(),
-        WHISPER_1.to_string(),
-    );
+    let file_path = "examples/data/problem.mp3";
+
+    // Test with file
+    let req = AudioTranscriptionRequest::new(file_path.to_string(), WHISPER_1.to_string());
 
     let req_json = req.clone().response_format("json".to_string());
 
@@ -22,7 +24,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let result = client.audio_transcription_raw(req_raw).await?;
     println!("{:?}", result);
 
+    // Test with bytes
+    let mut file = File::open(file_path)?;
+    let mut buffer = Vec::new();
+    file.read_to_end(&mut buffer)?;
+
+    let req = AudioTranscriptionRequest::new_bytes(buffer, WHISPER_1.to_string());
+
+    let req_json = req.clone().response_format("json".to_string());
+
+    let result = client.audio_transcription(req_json).await?;
+    println!("{:?}", result);
+
     Ok(())
 }
 
-// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_translations
+// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_transcriptions
diff --git a/src/v1/api.rs b/src/v1/api.rs
index f669cca..4ee19e4 100644
--- a/src/v1/api.rs
+++ b/src/v1/api.rs
@@ -310,7 +310,7 @@ impl OpenAIClient {
         &self,
         req: AudioTranscriptionRequest,
     ) -> Result<AudioTranscriptionResponse, APIError> {
-        // https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
+        // https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
         if let Some(response_format) = &req.response_format {
             if response_format != "json" && response_format != "verbose_json" {
                 return Err(APIError::CustomError {
@@ -318,7 +318,16 @@ impl OpenAIClient {
                 });
             }
         }
-        let form = Self::create_form(&req, "file")?;
+        let form: Form;
+        if req.clone().file.is_some() {
+            form = Self::create_form(&req, "file")?;
+        } else if let Some(bytes) = req.clone().bytes {
+            form = Self::create_form_from_bytes(&req, bytes)?;
+        } else {
+            return Err(APIError::CustomError {
+                message: "Either file or bytes must be provided".to_string(),
+            });
+        }
         self.post_form("audio/transcriptions", form).await
     }
 
@@ -326,7 +335,7 @@ impl OpenAIClient {
         &self,
         req: AudioTranscriptionRequest,
     ) -> Result<Bytes, APIError> {
-        // https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
+        // https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
         if let Some(response_format) = &req.response_format {
             if response_format != "text" && response_format != "srt" && response_format != "vtt" {
                 return Err(APIError::CustomError {
@@ -334,7 +343,16 @@ impl OpenAIClient {
                 });
             }
         }
-        let form = Self::create_form(&req, "file")?;
+        let form: Form;
+        if req.clone().file.is_some() {
+            form = Self::create_form(&req, "file")?;
+        } else if let Some(bytes) = req.clone().bytes {
+            form = Self::create_form_from_bytes(&req, bytes)?;
+        } else {
+            return Err(APIError::CustomError {
+                message: "Either file or bytes must be provided".to_string(),
+            });
+        }
         self.post_form_raw("audio/transcriptions", form).await
     }
 
@@ -823,4 +841,36 @@ impl OpenAIClient {
 
         Ok(form)
     }
+
+    fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
+    where
+        T: Serialize,
+    {
+        let json = match serde_json::to_value(req) {
+            Ok(json) => json,
+            Err(e) => {
+                return Err(APIError::CustomError {
+                    message: e.to_string(),
+                })
+            }
+        };
+
+        let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));
+
+        if let Value::Object(map) = json {
+            for (key, value) in map.into_iter() {
+                match value {
+                    Value::String(s) => {
+                        form = form.text(key, s);
+                    }
+                    Value::Number(n) => {
+                        form = form.text(key, n.to_string());
+                    }
+                    _ => {}
+                }
+            }
+        }
+
+        Ok(form)
+    }
 }
diff --git a/src/v1/audio.rs b/src/v1/audio.rs
index b2c87f8..130a042 100644
--- a/src/v1/audio.rs
+++ b/src/v1/audio.rs
@@ -8,9 +8,11 @@ pub const WHISPER_1: &str = "whisper-1";
 
 #[derive(Debug, Serialize, Clone)]
 pub struct AudioTranscriptionRequest {
-    pub file: String,
     pub model: String,
     #[serde(skip_serializing_if = "Option::is_none")]
+    pub file: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub bytes: Option<Vec<u8>>,
     pub prompt: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub response_format: Option<String>,
@@ -23,8 +25,21 @@
 impl AudioTranscriptionRequest {
     pub fn new(file: String, model: String) -> Self {
         Self {
-            file,
             model,
+            file: Some(file),
+            bytes: None,
+            prompt: None,
+            response_format: None,
+            temperature: None,
+            language: None,
+        }
+    }
+
+    pub fn new_bytes(bytes: Vec<u8>, model: String) -> Self {
+        Self {
+            model,
+            file: None,
+            bytes: Some(bytes),
             prompt: None,
             response_format: None,
             temperature: None,