Merge pull request #144 from dongri/fix-audio-transcription

Fix audio transcription
This commit is contained in:
Dongri Jin
2025-02-01 11:53:16 +09:00
committed by GitHub
3 changed files with 90 additions and 11 deletions

View File

@ -1,16 +1,18 @@
use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::audio::{AudioTranscriptionRequest, WHISPER_1};
use std::env;
use std::fs::File;
use std::io::Read;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
let client = OpenAIClient::builder().with_api_key(api_key).build()?;
let req = AudioTranscriptionRequest::new(
"examples/data/problem.mp3".to_string(),
WHISPER_1.to_string(),
);
let file_path = "examples/data/problem.mp3";
// Test with file
let req = AudioTranscriptionRequest::new(file_path.to_string(), WHISPER_1.to_string());
let req_json = req.clone().response_format("json".to_string());
@ -22,7 +24,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let result = client.audio_transcription_raw(req_raw).await?;
println!("{:?}", result);
// Test with bytes
let mut file = File::open(file_path)?;
let mut buffer = Vec::new();
file.read_to_end(&mut buffer)?;
let req = AudioTranscriptionRequest::new_bytes(buffer, WHISPER_1.to_string());
let req_json = req.clone().response_format("json".to_string());
let result = client.audio_transcription(req_json).await?;
println!("{:?}", result);
Ok(())
}
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_translations
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_transcriptions

View File

@ -310,7 +310,7 @@ impl OpenAIClient {
&self,
req: AudioTranscriptionRequest,
) -> Result<AudioTranscriptionResponse, APIError> {
// https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
if let Some(response_format) = &req.response_format {
if response_format != "json" && response_format != "verbose_json" {
return Err(APIError::CustomError {
@ -318,7 +318,16 @@ impl OpenAIClient {
});
}
}
let form = Self::create_form(&req, "file")?;
let form: Form;
if req.clone().file.is_some() {
form = Self::create_form(&req, "file")?;
} else if let Some(bytes) = req.clone().bytes {
form = Self::create_form_from_bytes(&req, bytes)?;
} else {
return Err(APIError::CustomError {
message: "Either file or bytes must be provided".to_string(),
});
}
self.post_form("audio/transcriptions", form).await
}
@ -326,7 +335,7 @@ impl OpenAIClient {
&self,
req: AudioTranscriptionRequest,
) -> Result<Bytes, APIError> {
// https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
if let Some(response_format) = &req.response_format {
if response_format != "text" && response_format != "srt" && response_format != "vtt" {
return Err(APIError::CustomError {
@ -334,7 +343,16 @@ impl OpenAIClient {
});
}
}
let form = Self::create_form(&req, "file")?;
let form: Form;
if req.clone().file.is_some() {
form = Self::create_form(&req, "file")?;
} else if let Some(bytes) = req.clone().bytes {
form = Self::create_form_from_bytes(&req, bytes)?;
} else {
return Err(APIError::CustomError {
message: "Either file or bytes must be provided".to_string(),
});
}
self.post_form_raw("audio/transcriptions", form).await
}
@ -823,4 +841,36 @@ impl OpenAIClient {
Ok(form)
}
/// Builds a multipart form for a request whose audio payload is held in memory.
///
/// The audio bytes are attached as the `file` part and uploaded under the
/// fixed file name `file.mp3`. `req` is serialized to a JSON value and, when
/// that value is an object, every top-level string or number field is added
/// as a text part. Fields of any other JSON type are not added to the form.
///
/// # Errors
///
/// Returns `APIError::CustomError` carrying the serializer's message when
/// `req` cannot be converted to a JSON value.
fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
where
    T: Serialize,
{
    let json = serde_json::to_value(req).map_err(|e| APIError::CustomError {
        message: e.to_string(),
    })?;

    // `bytes` is owned by this function, so it is moved into the part rather
    // than cloned; this avoids a second full copy of the audio data.
    let mut form = Form::new().part("file", Part::bytes(bytes).file_name("file.mp3"));

    if let Value::Object(map) = json {
        for (key, value) in map {
            match value {
                Value::String(s) => form = form.text(key, s),
                Value::Number(n) => form = form.text(key, n.to_string()),
                // Null, bool, array and object values are skipped.
                _ => {}
            }
        }
    }

    Ok(form)
}
}

View File

@ -8,9 +8,11 @@ pub const WHISPER_1: &str = "whisper-1";
#[derive(Debug, Serialize, Clone)]
pub struct AudioTranscriptionRequest {
pub file: String,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub file: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub bytes: Option<Vec<u8>>,
pub prompt: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<String>,
@ -23,8 +25,21 @@ pub struct AudioTranscriptionRequest {
impl AudioTranscriptionRequest {
pub fn new(file: String, model: String) -> Self {
Self {
file,
model,
file: Some(file),
bytes: None,
prompt: None,
response_format: None,
temperature: None,
language: None,
}
}
pub fn new_bytes(bytes: Vec<u8>, model: String) -> Self {
Self {
model,
file: None,
bytes: Some(bytes),
prompt: None,
response_format: None,
temperature: None,