mirror of
https://github.com/mii443/openai-api-rs.git
synced 2025-08-22 15:15:34 +00:00
Merge pull request #144 from dongri/fix-audio-transcription
Fix audio transcription
This commit is contained in:
@ -1,16 +1,18 @@
|
||||
use openai_api_rs::v1::api::OpenAIClient;
|
||||
use openai_api_rs::v1::audio::{AudioTranscriptionRequest, WHISPER_1};
|
||||
use std::env;
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
|
||||
let client = OpenAIClient::builder().with_api_key(api_key).build()?;
|
||||
|
||||
let req = AudioTranscriptionRequest::new(
|
||||
"examples/data/problem.mp3".to_string(),
|
||||
WHISPER_1.to_string(),
|
||||
);
|
||||
let file_path = "examples/data/problem.mp3";
|
||||
|
||||
// Test with file
|
||||
let req = AudioTranscriptionRequest::new(file_path.to_string(), WHISPER_1.to_string());
|
||||
|
||||
let req_json = req.clone().response_format("json".to_string());
|
||||
|
||||
@ -22,7 +24,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let result = client.audio_transcription_raw(req_raw).await?;
|
||||
println!("{:?}", result);
|
||||
|
||||
// Test with bytes
|
||||
let mut file = File::open(file_path)?;
|
||||
let mut buffer = Vec::new();
|
||||
file.read_to_end(&mut buffer)?;
|
||||
|
||||
let req = AudioTranscriptionRequest::new_bytes(buffer, WHISPER_1.to_string());
|
||||
|
||||
let req_json = req.clone().response_format("json".to_string());
|
||||
|
||||
let result = client.audio_transcription(req_json).await?;
|
||||
println!("{:?}", result);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_translations
|
||||
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_transcriptions
|
||||
|
@ -310,7 +310,7 @@ impl OpenAIClient {
|
||||
&self,
|
||||
req: AudioTranscriptionRequest,
|
||||
) -> Result<AudioTranscriptionResponse, APIError> {
|
||||
// https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
|
||||
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
|
||||
if let Some(response_format) = &req.response_format {
|
||||
if response_format != "json" && response_format != "verbose_json" {
|
||||
return Err(APIError::CustomError {
|
||||
@ -318,7 +318,16 @@ impl OpenAIClient {
|
||||
});
|
||||
}
|
||||
}
|
||||
let form = Self::create_form(&req, "file")?;
|
||||
let form: Form;
|
||||
if req.clone().file.is_some() {
|
||||
form = Self::create_form(&req, "file")?;
|
||||
} else if let Some(bytes) = req.clone().bytes {
|
||||
form = Self::create_form_from_bytes(&req, bytes)?;
|
||||
} else {
|
||||
return Err(APIError::CustomError {
|
||||
message: "Either file or bytes must be provided".to_string(),
|
||||
});
|
||||
}
|
||||
self.post_form("audio/transcriptions", form).await
|
||||
}
|
||||
|
||||
@ -326,7 +335,7 @@ impl OpenAIClient {
|
||||
&self,
|
||||
req: AudioTranscriptionRequest,
|
||||
) -> Result<Bytes, APIError> {
|
||||
// https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
|
||||
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
|
||||
if let Some(response_format) = &req.response_format {
|
||||
if response_format != "text" && response_format != "srt" && response_format != "vtt" {
|
||||
return Err(APIError::CustomError {
|
||||
@ -334,7 +343,16 @@ impl OpenAIClient {
|
||||
});
|
||||
}
|
||||
}
|
||||
let form = Self::create_form(&req, "file")?;
|
||||
let form: Form;
|
||||
if req.clone().file.is_some() {
|
||||
form = Self::create_form(&req, "file")?;
|
||||
} else if let Some(bytes) = req.clone().bytes {
|
||||
form = Self::create_form_from_bytes(&req, bytes)?;
|
||||
} else {
|
||||
return Err(APIError::CustomError {
|
||||
message: "Either file or bytes must be provided".to_string(),
|
||||
});
|
||||
}
|
||||
self.post_form_raw("audio/transcriptions", form).await
|
||||
}
|
||||
|
||||
@ -823,4 +841,36 @@ impl OpenAIClient {
|
||||
|
||||
Ok(form)
|
||||
}
|
||||
|
||||
fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let json = match serde_json::to_value(req) {
|
||||
Ok(json) => json,
|
||||
Err(e) => {
|
||||
return Err(APIError::CustomError {
|
||||
message: e.to_string(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));
|
||||
|
||||
if let Value::Object(map) = json {
|
||||
for (key, value) in map.into_iter() {
|
||||
match value {
|
||||
Value::String(s) => {
|
||||
form = form.text(key, s);
|
||||
}
|
||||
Value::Number(n) => {
|
||||
form = form.text(key, n.to_string());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(form)
|
||||
}
|
||||
}
|
||||
|
@ -8,9 +8,11 @@ pub const WHISPER_1: &str = "whisper-1";
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub struct AudioTranscriptionRequest {
|
||||
pub file: String,
|
||||
pub model: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub file: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bytes: Option<Vec<u8>>,
|
||||
pub prompt: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_format: Option<String>,
|
||||
@ -23,8 +25,21 @@ pub struct AudioTranscriptionRequest {
|
||||
impl AudioTranscriptionRequest {
|
||||
pub fn new(file: String, model: String) -> Self {
|
||||
Self {
|
||||
file,
|
||||
model,
|
||||
file: Some(file),
|
||||
bytes: None,
|
||||
prompt: None,
|
||||
response_format: None,
|
||||
temperature: None,
|
||||
language: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_bytes(bytes: Vec<u8>, model: String) -> Self {
|
||||
Self {
|
||||
model,
|
||||
file: None,
|
||||
bytes: Some(bytes),
|
||||
prompt: None,
|
||||
response_format: None,
|
||||
temperature: None,
|
||||
|
Reference in New Issue
Block a user