mirror of
https://github.com/mii443/openai-api-rs.git
synced 2025-08-22 15:15:34 +00:00
Merge pull request #144 from dongri/fix-audio-transcription
Fix audio transcription
This commit is contained in:
@ -1,16 +1,18 @@
|
|||||||
use openai_api_rs::v1::api::OpenAIClient;
|
use openai_api_rs::v1::api::OpenAIClient;
|
||||||
use openai_api_rs::v1::audio::{AudioTranscriptionRequest, WHISPER_1};
|
use openai_api_rs::v1::audio::{AudioTranscriptionRequest, WHISPER_1};
|
||||||
use std::env;
|
use std::env;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::Read;
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
|
let api_key = env::var("OPENAI_API_KEY").unwrap().to_string();
|
||||||
let client = OpenAIClient::builder().with_api_key(api_key).build()?;
|
let client = OpenAIClient::builder().with_api_key(api_key).build()?;
|
||||||
|
|
||||||
let req = AudioTranscriptionRequest::new(
|
let file_path = "examples/data/problem.mp3";
|
||||||
"examples/data/problem.mp3".to_string(),
|
|
||||||
WHISPER_1.to_string(),
|
// Test with file
|
||||||
);
|
let req = AudioTranscriptionRequest::new(file_path.to_string(), WHISPER_1.to_string());
|
||||||
|
|
||||||
let req_json = req.clone().response_format("json".to_string());
|
let req_json = req.clone().response_format("json".to_string());
|
||||||
|
|
||||||
@ -22,7 +24,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let result = client.audio_transcription_raw(req_raw).await?;
|
let result = client.audio_transcription_raw(req_raw).await?;
|
||||||
println!("{:?}", result);
|
println!("{:?}", result);
|
||||||
|
|
||||||
|
// Test with bytes
|
||||||
|
let mut file = File::open(file_path)?;
|
||||||
|
let mut buffer = Vec::new();
|
||||||
|
file.read_to_end(&mut buffer)?;
|
||||||
|
|
||||||
|
let req = AudioTranscriptionRequest::new_bytes(buffer, WHISPER_1.to_string());
|
||||||
|
|
||||||
|
let req_json = req.clone().response_format("json".to_string());
|
||||||
|
|
||||||
|
let result = client.audio_transcription(req_json).await?;
|
||||||
|
println!("{:?}", result);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_translations
|
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_transcriptions
|
||||||
|
@ -310,7 +310,7 @@ impl OpenAIClient {
|
|||||||
&self,
|
&self,
|
||||||
req: AudioTranscriptionRequest,
|
req: AudioTranscriptionRequest,
|
||||||
) -> Result<AudioTranscriptionResponse, APIError> {
|
) -> Result<AudioTranscriptionResponse, APIError> {
|
||||||
// https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
|
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
|
||||||
if let Some(response_format) = &req.response_format {
|
if let Some(response_format) = &req.response_format {
|
||||||
if response_format != "json" && response_format != "verbose_json" {
|
if response_format != "json" && response_format != "verbose_json" {
|
||||||
return Err(APIError::CustomError {
|
return Err(APIError::CustomError {
|
||||||
@ -318,7 +318,16 @@ impl OpenAIClient {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let form = Self::create_form(&req, "file")?;
|
let form: Form;
|
||||||
|
if req.clone().file.is_some() {
|
||||||
|
form = Self::create_form(&req, "file")?;
|
||||||
|
} else if let Some(bytes) = req.clone().bytes {
|
||||||
|
form = Self::create_form_from_bytes(&req, bytes)?;
|
||||||
|
} else {
|
||||||
|
return Err(APIError::CustomError {
|
||||||
|
message: "Either file or bytes must be provided".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
self.post_form("audio/transcriptions", form).await
|
self.post_form("audio/transcriptions", form).await
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -326,7 +335,7 @@ impl OpenAIClient {
|
|||||||
&self,
|
&self,
|
||||||
req: AudioTranscriptionRequest,
|
req: AudioTranscriptionRequest,
|
||||||
) -> Result<Bytes, APIError> {
|
) -> Result<Bytes, APIError> {
|
||||||
// https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format
|
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
|
||||||
if let Some(response_format) = &req.response_format {
|
if let Some(response_format) = &req.response_format {
|
||||||
if response_format != "text" && response_format != "srt" && response_format != "vtt" {
|
if response_format != "text" && response_format != "srt" && response_format != "vtt" {
|
||||||
return Err(APIError::CustomError {
|
return Err(APIError::CustomError {
|
||||||
@ -334,7 +343,16 @@ impl OpenAIClient {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let form = Self::create_form(&req, "file")?;
|
let form: Form;
|
||||||
|
if req.clone().file.is_some() {
|
||||||
|
form = Self::create_form(&req, "file")?;
|
||||||
|
} else if let Some(bytes) = req.clone().bytes {
|
||||||
|
form = Self::create_form_from_bytes(&req, bytes)?;
|
||||||
|
} else {
|
||||||
|
return Err(APIError::CustomError {
|
||||||
|
message: "Either file or bytes must be provided".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
self.post_form_raw("audio/transcriptions", form).await
|
self.post_form_raw("audio/transcriptions", form).await
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -823,4 +841,36 @@ impl OpenAIClient {
|
|||||||
|
|
||||||
Ok(form)
|
Ok(form)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn create_form_from_bytes<T>(req: &T, bytes: Vec<u8>) -> Result<Form, APIError>
|
||||||
|
where
|
||||||
|
T: Serialize,
|
||||||
|
{
|
||||||
|
let json = match serde_json::to_value(req) {
|
||||||
|
Ok(json) => json,
|
||||||
|
Err(e) => {
|
||||||
|
return Err(APIError::CustomError {
|
||||||
|
message: e.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3"));
|
||||||
|
|
||||||
|
if let Value::Object(map) = json {
|
||||||
|
for (key, value) in map.into_iter() {
|
||||||
|
match value {
|
||||||
|
Value::String(s) => {
|
||||||
|
form = form.text(key, s);
|
||||||
|
}
|
||||||
|
Value::Number(n) => {
|
||||||
|
form = form.text(key, n.to_string());
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(form)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,9 +8,11 @@ pub const WHISPER_1: &str = "whisper-1";
|
|||||||
|
|
||||||
#[derive(Debug, Serialize, Clone)]
|
#[derive(Debug, Serialize, Clone)]
|
||||||
pub struct AudioTranscriptionRequest {
|
pub struct AudioTranscriptionRequest {
|
||||||
pub file: String,
|
|
||||||
pub model: String,
|
pub model: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub file: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub bytes: Option<Vec<u8>>,
|
||||||
pub prompt: Option<String>,
|
pub prompt: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub response_format: Option<String>,
|
pub response_format: Option<String>,
|
||||||
@ -23,8 +25,21 @@ pub struct AudioTranscriptionRequest {
|
|||||||
impl AudioTranscriptionRequest {
|
impl AudioTranscriptionRequest {
|
||||||
pub fn new(file: String, model: String) -> Self {
|
pub fn new(file: String, model: String) -> Self {
|
||||||
Self {
|
Self {
|
||||||
file,
|
|
||||||
model,
|
model,
|
||||||
|
file: Some(file),
|
||||||
|
bytes: None,
|
||||||
|
prompt: None,
|
||||||
|
response_format: None,
|
||||||
|
temperature: None,
|
||||||
|
language: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_bytes(bytes: Vec<u8>, model: String) -> Self {
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
file: None,
|
||||||
|
bytes: Some(bytes),
|
||||||
prompt: None,
|
prompt: None,
|
||||||
response_format: None,
|
response_format: None,
|
||||||
temperature: None,
|
temperature: None,
|
||||||
|
Reference in New Issue
Block a user