diff --git a/Cargo.toml b/Cargo.toml index a9a13de..996b824 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "openai-api-rs" -version = "4.1.1" +version = "5.0.0" edition = "2021" authors = ["Dongri Jin "] license = "MIT" @@ -9,16 +9,17 @@ repository = "https://github.com/dongri/openai-api-rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[dependencies.reqwest] +version = "0.12" +features = ["json", "multipart"] + +[dependencies.tokio] +version = "1" +features = ["full"] + [dependencies.serde] version = "1" features = ["derive"] -default-features = false [dependencies.serde_json] version = "1" -default-features = false - -[dependencies.minreq] -version = "2" -default-features = false -features = ["https-rustls", "json-using-serde", "proxy"] diff --git a/README.md b/README.md index c64b63c..7f28b04 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Check out the [docs.rs](https://docs.rs/openai-api-rs/). Cargo.toml ```toml [dependencies] -openai-api-rs = "4.1.1" +openai-api-rs = "5.0.0" ``` ## Usage @@ -48,13 +48,14 @@ $ export OPENAI_API_BASE=https://api.openai.com/v1 ## Example of chat completion ```rust -use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; use openai_api_rs::v1::common::GPT4_O; use std::env; -fn main() -> Result<(), Box> { - let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let req = ChatCompletionRequest::new( GPT4_O.to_string(), @@ -65,7 +66,7 @@ fn main() -> Result<(), Box> { }], ); - let result = client.chat_completion(req)?; + let result = client.chat_completion(req).await?; println!("Content: {:?}", result.choices[0].message.content); println!("Response Headers: {:?}", result.headers); diff --git a/examples/assistant.rs b/examples/assistant.rs index c5acb83..133922e 100644 --- a/examples/assistant.rs +++ b/examples/assistant.rs @@ -1,4 +1,4 @@ -use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::assistant::AssistantRequest; use openai_api_rs::v1::common::GPT4_O; use openai_api_rs::v1::message::{CreateMessageRequest, MessageRole}; @@ -7,8 +7,9 @@ use openai_api_rs::v1::thread::CreateThreadRequest; use std::collections::HashMap; use std::env; -fn main() -> Result<(), Box> { - let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let mut tools = HashMap::new(); tools.insert("type".to_string(), "code_interpreter".to_string()); @@ -21,11 +22,11 @@ fn main() -> Result<(), Box> { let req = req.clone().tools(vec![tools]); println!("AssistantRequest: {:?}", req); - let result = client.create_assistant(req)?; + let result = client.create_assistant(req).await?; println!("Create Assistant Result ID: {:?}", result.id); let thread_req = CreateThreadRequest::new(); - let thread_result = client.create_thread(thread_req)?; + let thread_result = client.create_thread(thread_req).await?; println!("Create Thread Result ID: {:?}", thread_result.id.clone()); let message_req = CreateMessageRequest::new( @@ -33,16 +34,19 @@ fn main() -> Result<(), Box> { "`I need to solve the equation 3x + 11 = 14. Can you help me?".to_string(), ); - let message_result = client.create_message(thread_result.id.clone(), message_req)?; + let message_result = client + .create_message(thread_result.id.clone(), message_req) + .await?; println!("Create Message Result ID: {:?}", message_result.id.clone()); let run_req = CreateRunRequest::new(result.id); - let run_result = client.create_run(thread_result.id.clone(), run_req)?; + let run_result = client.create_run(thread_result.id.clone(), run_req).await?; println!("Create Run Result ID: {:?}", run_result.id.clone()); loop { let run_result = client .retrieve_run(thread_result.id.clone(), run_result.id.clone()) + .await .unwrap(); if run_result.status == "completed" { break; @@ -52,7 +56,10 @@ fn main() -> Result<(), Box> { } } - let list_message_result = client.list_messages(thread_result.id.clone()).unwrap(); + let list_message_result = client + .list_messages(thread_result.id.clone()) + .await + .unwrap(); for data in list_message_result.data { for content in data.content { println!( diff --git a/examples/audio_speech.rs b/examples/audio_speech.rs new file mode 100644 index 0000000..c6541d1 --- /dev/null +++ b/examples/audio_speech.rs @@ -0,0 +1,22 @@ +use openai_api_rs::v1::api::OpenAIClient; +use openai_api_rs::v1::audio::{self, AudioSpeechRequest, TTS_1}; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + + let req = AudioSpeechRequest::new( + TTS_1.to_string(), + String::from("Money is not the problem, the problem is no money."), + audio::VOICE_ALLOY.to_string(), + String::from("examples/data/problem.mp3"), + ); + + let result = client.audio_speech(req).await?; + println!("{:?}", result); + + Ok(()) +} + +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_speech diff --git a/examples/audio_transcriptions.rs b/examples/audio_transcriptions.rs new file mode 100644 index 0000000..5a495c8 --- /dev/null +++ b/examples/audio_transcriptions.rs @@ -0,0 +1,20 @@ +use openai_api_rs::v1::api::OpenAIClient; +use openai_api_rs::v1::audio::{AudioTranscriptionRequest, WHISPER_1}; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + + let req = AudioTranscriptionRequest::new( + "examples/data/problem.mp3".to_string(), + WHISPER_1.to_string(), + ); + + let result = client.audio_transcription(req).await?; + println!("{:?}", result); + + Ok(()) +} + +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_translations diff --git a/examples/audio_translations.rs b/examples/audio_translations.rs new file mode 100644 index 0000000..89bf87c --- /dev/null +++ b/examples/audio_translations.rs @@ -0,0 +1,20 @@ +use openai_api_rs::v1::api::OpenAIClient; +use openai_api_rs::v1::audio::{AudioTranslationRequest, WHISPER_1}; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + + let req = AudioTranslationRequest::new( + "examples/data/problem_cn.mp3".to_string(), + WHISPER_1.to_string(), + ); + + let result = client.audio_translation(req).await?; + println!("{:?}", result); + + Ok(()) +} + +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_transcriptions diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index dfdabb4..d53134f 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -1,10 +1,11 @@ -use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; use openai_api_rs::v1::common::GPT4_O; use std::env; -fn main() -> Result<(), Box> { - let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let req = ChatCompletionRequest::new( GPT4_O.to_string(), @@ -15,7 +16,7 @@ fn main() -> Result<(), Box> { }], ); - let result = client.chat_completion(req)?; + let result = client.chat_completion(req).await?; println!("Content: {:?}", result.choices[0].message.content); println!("Response Headers: {:?}", result.headers); diff --git a/examples/completion.rs b/examples/completion.rs index 362d9b7..e0fab80 100644 --- a/examples/completion.rs +++ b/examples/completion.rs @@ -1,9 +1,10 @@ -use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::completion::{self, CompletionRequest}; use std::env; -fn main() -> Result<(), Box> { - let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let req = CompletionRequest::new( completion::GPT3_TEXT_DAVINCI_003.to_string(), @@ -16,7 +17,7 @@ fn main() -> Result<(), Box> { .presence_penalty(0.6) .frequency_penalty(0.0); - let result = client.completion(req)?; + let result = client.completion(req).await?; println!("{:}", result.choices[0].text); Ok(()) diff --git a/examples/data/problem.mp3 b/examples/data/problem.mp3 new file mode 100644 index 0000000..9fce89a Binary files /dev/null and b/examples/data/problem.mp3 differ diff --git a/examples/data/problem_cn.mp3 b/examples/data/problem_cn.mp3 new file mode 100644 index 0000000..2a993a0 Binary files /dev/null and b/examples/data/problem_cn.mp3 differ diff --git a/examples/embedding.rs b/examples/embedding.rs index f415763..ef2e61d 100644 --- a/examples/embedding.rs +++ b/examples/embedding.rs @@ -1,16 +1,17 @@ -use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::common::TEXT_EMBEDDING_3_SMALL; use openai_api_rs::v1::embedding::EmbeddingRequest; use std::env; -fn main() -> Result<(), Box> { - let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let mut req = EmbeddingRequest::new(TEXT_EMBEDDING_3_SMALL.to_string(), "story time".to_string()); req.dimensions = Some(10); - let result = client.embedding(req)?; + let result = client.embedding(req).await?; println!("{:?}", result.data); Ok(()) diff --git a/examples/function_call.rs b/examples/function_call.rs index 335be7b..27c6375 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -1,6 +1,6 @@ -use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; -use openai_api_rs::v1::common::GPT3_5_TURBO_0613; +use openai_api_rs::v1::common::GPT4_O; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::{env, vec}; @@ -14,8 +14,9 @@ fn get_coin_price(coin: &str) -> f64 { } } -fn main() -> Result<(), Box> { - let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let mut properties = HashMap::new(); properties.insert( @@ -28,7 +29,7 @@ fn main() -> Result<(), Box> { ); let req = ChatCompletionRequest::new( - GPT3_5_TURBO_0613.to_string(), + GPT4_O.to_string(), vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), @@ -53,7 +54,7 @@ fn main() -> Result<(), Box> { // let serialized = serde_json::to_string(&req).unwrap(); // println!("{}", serialized); - let result = client.chat_completion(req)?; + let result = client.chat_completion(req).await?; match result.choices[0].finish_reason { None => { diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index eb55dd7..dcdf7ad 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -1,6 +1,6 @@ -use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; -use openai_api_rs::v1::common::GPT3_5_TURBO_0613; +use openai_api_rs::v1::common::GPT4_O; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::{env, vec}; @@ -14,8 +14,9 @@ fn get_coin_price(coin: &str) -> f64 { } } -fn main() -> Result<(), Box> { - let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let mut properties = HashMap::new(); properties.insert( @@ -28,7 +29,7 @@ fn main() -> Result<(), Box> { ); let req = ChatCompletionRequest::new( - GPT3_5_TURBO_0613.to_string(), + GPT4_O.to_string(), vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), @@ -48,7 +49,7 @@ fn main() -> Result<(), Box> { }, }]); - let result = client.chat_completion(req)?; + let result = client.chat_completion(req).await?; match result.choices[0].finish_reason { None => { @@ -79,7 +80,7 @@ fn main() -> Result<(), Box> { println!("price: {}", price); let req = ChatCompletionRequest::new( - GPT3_5_TURBO_0613.to_string(), + GPT4_O.to_string(), vec![ chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, @@ -99,7 +100,7 @@ fn main() -> Result<(), Box> { ], ); - let result = client.chat_completion(req)?; + let result = client.chat_completion(req).await?; println!("{:?}", result.choices[0].message.content); } } diff --git a/examples/text_to_speech.rs b/examples/text_to_speech.rs deleted file mode 100644 index 43b5c4c..0000000 --- a/examples/text_to_speech.rs +++ /dev/null @@ -1,21 +0,0 @@ -use openai_api_rs::v1::api::Client; -use openai_api_rs::v1::audio::{self, AudioSpeechRequest, TTS_1}; -use std::env; - -fn main() -> Result<(), Box> { - let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); - - let req = AudioSpeechRequest::new( - TTS_1.to_string(), - String::from("Money is not problem, Problem is no money"), - audio::VOICE_ALLOY.to_string(), - String::from("problem.mp3"), - ); - - let result = client.audio_speech(req)?; - println!("{:?}", result); - - Ok(()) -} - -// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example text_to_speech diff --git a/examples/vision.rs b/examples/vision.rs index 195c4c0..b62a653 100644 --- a/examples/vision.rs +++ b/examples/vision.rs @@ -1,10 +1,11 @@ -use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; use openai_api_rs::v1::common::GPT4_VISION_PREVIEW; use std::env; -fn main() -> Result<(), Box> { - let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let req = ChatCompletionRequest::new( GPT4_VISION_PREVIEW.to_string(), @@ -30,7 +31,7 @@ fn main() -> Result<(), Box> { }], ); - let result = client.chat_completion(req)?; + let result = client.chat_completion(req).await?; println!("{:?}", result.choices[0].message.content); Ok(()) diff --git a/src/v1/api.rs b/src/v1/api.rs index 7b520bc..1f3cf66 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -36,14 +36,19 @@ use crate::v1::run::{ }; use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject}; -use minreq::Response; +use reqwest::multipart::{Form, Part}; +use reqwest::{Client, Method, Response}; +use serde::Serialize; +use serde_json::Value; + use std::fs::{create_dir_all, File}; +use std::io::Read; use std::io::Write; use std::path::Path; const API_URL_V1: &str = "https://api.openai.com/v1"; -pub struct Client { +pub struct OpenAIClient { pub api_endpoint: String, pub api_key: String, pub organization: Option, @@ -51,7 +56,7 @@ pub struct Client { pub timeout: Option, } -impl Client { +impl OpenAIClient { pub fn new(api_key: String) -> Self { let endpoint = std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned()); Self::new_with_endpoint(endpoint, api_key) @@ -72,7 +77,7 @@ impl Client { Self { api_endpoint: endpoint, api_key, - organization: organization.into(), + organization: Some(organization), proxy: None, timeout: None, } @@ -100,536 +105,336 @@ impl Client { } } - pub fn build_request(&self, request: minreq::Request, is_beta: bool) -> minreq::Request { - let mut request = request - .with_header("Content-Type", "application/json") - .with_header("Authorization", format!("Bearer {}", self.api_key)); + async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.api_endpoint, path); + let client = Client::builder(); + + let client = if let Some(timeout) = self.timeout { + client.timeout(std::time::Duration::from_secs(timeout)) + } else { + client + }; + + let client = if let Some(proxy) = &self.proxy { + client.proxy(reqwest::Proxy::http(proxy).unwrap()) + } else { + client + }; + + let client = client.build().unwrap(); + + let mut request = client + .request(method, url) + // .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", self.api_key)); + if let Some(organization) = &self.organization { - request = request.with_header("openai-organization", organization); + request = request.header("openai-organization", organization); } - if is_beta { - request = request.with_header("OpenAI-Beta", "assistants=v2"); - } - if let Some(proxy) = &self.proxy { - request = request.with_proxy(minreq::Proxy::new(proxy).unwrap()); - } - if let Some(timeout) = self.timeout { - request = request.with_timeout(timeout); + + if Self::is_beta(path) { + request = request.header("OpenAI-Beta", "assistants=v2"); } + request } - pub fn post( + async fn post( &self, path: &str, - params: &T, - ) -> Result { - let url = format!( - "{api_endpoint}{path}", - api_endpoint = self.api_endpoint, - path = path - ); - let request = self.build_request(minreq::post(url), Self::is_beta(path)); - let res = request.with_json(params).unwrap().send(); - match res { - Ok(res) => { - if (200..=299).contains(&res.status_code) { - Ok(res) - } else { - Err(APIError { - message: format!("{}: {}", res.status_code, res.as_str().unwrap()), - }) - } - } - Err(e) => Err(self.new_error(e)), + body: &impl serde::ser::Serialize, + ) -> Result { + let request = self.build_request(Method::POST, path).await; + let request = request.json(body); + let response = request.send().await?; + self.handle_response(response).await + } + + async fn get(&self, path: &str) -> Result { + let request = self.build_request(Method::GET, path).await; + let response = request.send().await?; + self.handle_response(response).await + } + + async fn delete(&self, path: &str) -> Result { + let request = self.build_request(Method::DELETE, path).await; + let response = request.send().await?; + self.handle_response(response).await + } + + async fn post_form( + &self, + path: &str, + form: Form, + ) -> Result { + let request = self.build_request(Method::POST, path).await; + let request = request.multipart(form); + let response = request.send().await?; + self.handle_response(response).await + } + + async fn handle_response( + &self, + response: Response, + ) -> Result { + let status = response.status(); + if status.is_success() { + let parsed = response.json::().await?; + Ok(parsed) + } else { + let error_message = response + .text() + .await + .unwrap_or_else(|_| "Unknown error".to_string()); + Err(APIError::CustomError { + message: format!("{}: {}", status, error_message), + }) } } - pub fn get(&self, path: &str) -> Result { - let url = format!( - "{api_endpoint}{path}", - api_endpoint = self.api_endpoint, - path = path - ); - let request = self.build_request(minreq::get(url), Self::is_beta(path)); - let res = request.send(); - match res { - Ok(res) => { - if (200..=299).contains(&res.status_code) { - Ok(res) - } else { - Err(APIError { - message: format!("{}: {}", res.status_code, res.as_str().unwrap()), - }) - } - } - Err(e) => Err(self.new_error(e)), - } + pub async fn completion(&self, req: CompletionRequest) -> Result { + self.post("completions", &req).await } - pub fn delete(&self, path: &str) -> Result { - let url = format!( - "{api_endpoint}{path}", - api_endpoint = self.api_endpoint, - path = path - ); - let request = self.build_request(minreq::delete(url), Self::is_beta(path)); - let res = request.send(); - match res { - Ok(res) => { - if (200..=299).contains(&res.status_code) { - Ok(res) - } else { - Err(APIError { - message: format!("{}: {}", res.status_code, res.as_str().unwrap()), - }) - } - } - Err(e) => Err(self.new_error(e)), - } + pub async fn edit(&self, req: EditRequest) -> Result { + self.post("edits", &req).await } - pub fn completion(&self, req: CompletionRequest) -> Result { - let res = self.post("/completions", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } - } - - pub fn edit(&self, req: EditRequest) -> Result { - let res = self.post("/edits", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } - } - - pub fn image_generation( + pub async fn image_generation( &self, req: ImageGenerationRequest, ) -> Result { - let res = self.post("/images/generations", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post("images/generations", &req).await } - pub fn image_edit(&self, req: ImageEditRequest) -> Result { - let res = self.post("/images/edits", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn image_edit(&self, req: ImageEditRequest) -> Result { + self.post("images/edits", &req).await } - pub fn image_variation( + pub async fn image_variation( &self, req: ImageVariationRequest, ) -> Result { - let res = self.post("/images/variations", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post("images/variations", &req).await } - pub fn embedding(&self, req: EmbeddingRequest) -> Result { - let res = self.post("/embeddings", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn embedding(&self, req: EmbeddingRequest) -> Result { + self.post("embeddings", &req).await } - pub fn file_list(&self) -> Result { - let res = self.get("/files")?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn file_list(&self) -> Result { + self.get("files").await } - pub fn file_upload(&self, req: FileUploadRequest) -> Result { - let res = self.post("/files", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn file_upload( + &self, + req: FileUploadRequest, + ) -> Result { + self.post("files", &req).await } - pub fn file_delete(&self, req: FileDeleteRequest) -> Result { - let res = self.delete(&format!("{}/{}", "/files", req.file_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn file_delete( + &self, + req: FileDeleteRequest, + ) -> Result { + self.delete(&format!("files/{}", req.file_id)).await } - pub fn file_retrieve( + pub async fn file_retrieve( &self, req: FileRetrieveRequest, ) -> Result { - let res = self.get(&format!("{}/{}", "/files", req.file_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.get(&format!("files/{}", req.file_id)).await } - pub fn file_retrieve_content( + pub async fn file_retrieve_content( &self, req: FileRetrieveContentRequest, ) -> Result { - let res = self.get(&format!("{}/{}/content", "/files", req.file_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.get(&format!("files/{}/content", req.file_id)).await } - pub fn chat_completion( + pub async fn chat_completion( &self, req: ChatCompletionRequest, ) -> Result { - let res = self.post("/chat/completions", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post("chat/completions", &req).await } - pub fn audio_transcription( + pub async fn audio_transcription( &self, req: AudioTranscriptionRequest, ) -> Result { - let res = self.post("/audio/transcriptions", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + let form = Self::create_form(&req, "file")?; + self.post_form("audio/transcriptions", form).await } - pub fn audio_translation( + pub async fn audio_translation( &self, req: AudioTranslationRequest, ) -> Result { - let res = self.post("/audio/translations", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + let form = Self::create_form(&req, "file")?; + self.post_form("audio/translations", form).await } - pub fn audio_speech(&self, req: AudioSpeechRequest) -> Result { - let res = self.post("/audio/speech", &req)?; - let bytes = res.as_bytes(); - let path = req.output.as_str(); - let path = Path::new(path); + pub async fn audio_speech( + &self, + req: AudioSpeechRequest, + ) -> Result { + let request = self.build_request(Method::POST, "audio/speech").await; + let request = request.json(&req); + let response = request.send().await?; + let headers = response.headers().clone(); + let bytes = response.bytes().await?; + let path = Path::new(req.output.as_str()); if let Some(parent) = path.parent() { match create_dir_all(parent) { Ok(_) => {} Err(e) => { - return Err(APIError { + return Err(APIError::CustomError { message: e.to_string(), }) } } } match File::create(path) { - Ok(mut file) => match file.write_all(bytes) { + Ok(mut file) => match file.write_all(&bytes) { Ok(_) => {} Err(e) => { - return Err(APIError { + return Err(APIError::CustomError { message: e.to_string(), }) } }, Err(e) => { - return Err(APIError { + return Err(APIError::CustomError { message: e.to_string(), }) } } + Ok(AudioSpeechResponse { result: true, - headers: Some(res.headers), + headers: Some(headers), }) } - pub fn create_fine_tuning_job( + pub async fn create_fine_tuning_job( &self, req: CreateFineTuningJobRequest, ) -> Result { - let res = self.post("/fine_tuning/jobs", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post("fine_tuning/jobs", &req).await } - pub fn list_fine_tuning_jobs( + pub async fn list_fine_tuning_jobs( &self, ) -> Result, APIError> { - let res = self.get("/fine_tuning/jobs")?; - let r = res.json::>(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.get("fine_tuning/jobs").await } - pub fn list_fine_tuning_job_events( + pub async fn list_fine_tuning_job_events( &self, req: ListFineTuningJobEventsRequest, ) -> Result, APIError> { - let res = self.get(&format!( - "/fine_tuning/jobs/{}/events", + self.get(&format!( + "fine_tuning/jobs/{}/events", req.fine_tuning_job_id - ))?; - let r = res.json::>(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + )) + .await } - pub fn retrieve_fine_tuning_job( + pub async fn retrieve_fine_tuning_job( &self, req: RetrieveFineTuningJobRequest, ) -> Result { - let res = self.get(&format!("/fine_tuning/jobs/{}", req.fine_tuning_job_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.get(&format!("fine_tuning/jobs/{}", req.fine_tuning_job_id)) + .await } - pub fn cancel_fine_tuning_job( + pub async fn cancel_fine_tuning_job( &self, req: CancelFineTuningJobRequest, ) -> Result { - let res = self.post( - &format!("/fine_tuning/jobs/{}/cancel", req.fine_tuning_job_id), + self.post( + &format!("fine_tuning/jobs/{}/cancel", req.fine_tuning_job_id), &req, - )?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + ) + .await } - pub fn create_moderation( + pub async fn create_moderation( &self, req: CreateModerationRequest, ) -> Result { - let res = self.post("/moderations", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post("moderations", &req).await } - pub fn create_assistant(&self, req: AssistantRequest) -> Result { - let res = self.post("/assistants", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn create_assistant( + &self, + req: AssistantRequest, + ) -> Result { + self.post("assistants", &req).await } - pub fn retrieve_assistant(&self, assistant_id: String) -> Result { - let res = self.get(&format!("/assistants/{}", assistant_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn retrieve_assistant( + &self, + assistant_id: String, + ) -> Result { + self.get(&format!("assistants/{}", assistant_id)).await } - pub fn modify_assistant( + pub async fn modify_assistant( &self, assistant_id: String, req: AssistantRequest, ) -> Result { - let res = self.post(&format!("/assistants/{}", assistant_id), &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post(&format!("assistants/{}", assistant_id), &req) + .await } - pub fn delete_assistant(&self, assistant_id: String) -> Result { - let res = self.delete(&format!("/assistants/{}", assistant_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn delete_assistant(&self, assistant_id: String) -> Result { + self.delete(&format!("assistants/{}", assistant_id)).await } - pub fn list_assistant( + pub async fn list_assistant( &self, limit: Option, order: Option, after: Option, before: Option, ) -> Result { - let mut url = "/assistants".to_owned(); - url = Self::query_params(limit, order, after, before, url); - let res = self.get(&url)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + let url = Self::query_params(limit, order, after, before, "assistants".to_string()); + self.get(&url).await } - pub fn create_assistant_file( + pub async fn create_assistant_file( &self, assistant_id: String, req: AssistantFileRequest, ) -> Result { - let res = self.post(&format!("/assistants/{}/files", assistant_id), &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post(&format!("assistants/{}/files", assistant_id), &req) + .await } - pub fn retrieve_assistant_file( + pub async fn retrieve_assistant_file( &self, assistant_id: String, file_id: String, ) -> Result { - let res = self.get(&format!("/assistants/{}/files/{}", assistant_id, file_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.get(&format!("assistants/{}/files/{}", assistant_id, file_id)) + .await } - pub fn delete_assistant_file( + pub async fn delete_assistant_file( &self, assistant_id: String, file_id: String, ) -> Result { - let res = self.delete(&format!("/assistants/{}/files/{}", assistant_id, file_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.delete(&format!("assistants/{}/files/{}", assistant_id, file_id)) + .await } - pub fn list_assistant_file( + pub async fn list_assistant_file( &self, assistant_id: String, limit: Option, @@ -637,156 +442,85 @@ impl Client { after: Option, before: Option, ) -> Result { - let mut url = format!("/assistants/{}/files", assistant_id); - url = Self::query_params(limit, order, after, before, url); - let res = self.get(&url)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + let url = Self::query_params( + limit, + order, + after, + before, + format!("assistants/{}/files", assistant_id), + ); + self.get(&url).await } - pub fn create_thread(&self, req: CreateThreadRequest) -> Result { - let res = self.post("/threads", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn create_thread(&self, req: CreateThreadRequest) -> Result { + self.post("threads", &req).await } - pub fn retrieve_thread(&self, thread_id: String) -> Result { - let res = self.get(&format!("/threads/{}", thread_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn retrieve_thread(&self, thread_id: String) -> Result { + self.get(&format!("threads/{}", thread_id)).await } - pub fn modify_thread( + pub async fn modify_thread( &self, thread_id: String, req: ModifyThreadRequest, ) -> Result { - let res = self.post(&format!("/threads/{}", thread_id), &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post(&format!("threads/{}", thread_id), &req).await } - pub fn delete_thread(&self, thread_id: String) -> Result { - let res = self.delete(&format!("/threads/{}", thread_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn delete_thread(&self, thread_id: String) -> Result { + self.delete(&format!("threads/{}", thread_id)).await } - pub fn create_message( + pub async fn create_message( &self, thread_id: String, req: CreateMessageRequest, ) -> Result { - let res = self.post(&format!("/threads/{}/messages", thread_id), &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post(&format!("threads/{}/messages", thread_id), &req) + .await } - pub fn retrieve_message( + pub async fn retrieve_message( &self, thread_id: String, message_id: String, ) -> Result { - let res = self.get(&format!("/threads/{}/messages/{}", thread_id, message_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.get(&format!("threads/{}/messages/{}", thread_id, message_id)) + .await } - pub fn modify_message( + pub async fn modify_message( &self, thread_id: String, message_id: String, req: ModifyMessageRequest, ) -> Result { - let res = self.post( - &format!("/threads/{}/messages/{}", thread_id, message_id), + self.post( + &format!("threads/{}/messages/{}", thread_id, message_id), &req, - )?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + ) + .await } - pub fn list_messages(&self, thread_id: String) -> Result { - let res = self.get(&format!("/threads/{}/messages", thread_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn list_messages(&self, thread_id: String) -> Result { + self.get(&format!("threads/{}/messages", thread_id)).await } - pub fn retrieve_message_file( + pub async fn retrieve_message_file( &self, thread_id: String, message_id: String, file_id: String, ) -> Result { - let res = self.get(&format!( - "/threads/{}/messages/{}/files/{}", + self.get(&format!( + "threads/{}/messages/{}/files/{}", thread_id, message_id, file_id - ))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + )) + .await } - pub fn list_message_file( + pub async fn list_message_file( &self, thread_id: String, message_id: String, @@ -795,65 +529,45 @@ impl Client { after: Option, before: Option, ) -> Result { - let mut url = format!("/threads/{}/messages/{}/files", thread_id, message_id); - url = Self::query_params(limit, order, after, before, url); - let res = self.get(&url)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + let url = Self::query_params( + limit, + order, + after, + before, + format!("threads/{}/messages/{}/files", thread_id, message_id), + ); + self.get(&url).await } - pub fn create_run( + pub async fn create_run( &self, thread_id: String, req: CreateRunRequest, ) -> Result { - let res = self.post(&format!("/threads/{}/runs", thread_id), &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post(&format!("threads/{}/runs", thread_id), &req) + .await } - pub fn retrieve_run(&self, thread_id: String, run_id: String) -> Result { - let res = self.get(&format!("/threads/{}/runs/{}", thread_id, run_id))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn retrieve_run( + &self, + thread_id: String, + run_id: String, + ) -> Result { + self.get(&format!("threads/{}/runs/{}", thread_id, run_id)) + .await } - pub fn modify_run( + pub async fn modify_run( &self, thread_id: String, run_id: String, req: ModifyRunRequest, ) -> Result { - let res = self.post(&format!("/threads/{}/runs/{}", thread_id, run_id), &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post(&format!("threads/{}/runs/{}", thread_id, run_id), &req) + .await } - pub fn list_run( + pub async fn list_run( &self, thread_id: String, limit: Option, @@ -861,71 +575,49 @@ impl Client { after: Option, before: Option, ) -> Result { - let mut url = format!("/threads/{}/runs", thread_id); - url = Self::query_params(limit, order, after, before, url); - let res = self.get(&url)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + let url = Self::query_params( + limit, + order, + after, + before, + format!("threads/{}/runs", thread_id), + ); + self.get(&url).await } - pub fn cancel_run(&self, thread_id: String, run_id: String) -> Result { - let empty_req = ModifyRunRequest::new(); - let res = self.post( - &format!("/threads/{}/runs/{}/cancel", thread_id, run_id), - &empty_req, - )?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + pub async fn cancel_run( + &self, + thread_id: String, + run_id: String, + ) -> Result { + self.post( + &format!("threads/{}/runs/{}/cancel", thread_id, run_id), + &ModifyRunRequest::default(), + ) + .await } - pub fn create_thread_and_run( + pub async fn create_thread_and_run( &self, req: CreateThreadAndRunRequest, ) -> Result { - let res = self.post("/threads/runs", &req)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + self.post("threads/runs", &req).await } - pub fn retrieve_run_step( + pub async fn retrieve_run_step( &self, thread_id: String, run_id: String, step_id: String, ) -> Result { - let res = self.get(&format!( - "/threads/{}/runs/{}/steps/{}", + self.get(&format!( + "threads/{}/runs/{}/steps/{}", thread_id, run_id, step_id - ))?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } + )) + .await } - pub fn list_run_step( + pub async fn list_run_step( &self, thread_id: String, run_id: String, @@ -934,27 +626,14 @@ impl Client { after: Option, before: Option, ) -> Result { - let mut url = format!("/threads/{}/runs/{}/steps", thread_id, run_id); - url = Self::query_params(limit, order, after, before, url); - let res = self.get(&url)?; - let r = res.json::(); - match r { - Ok(mut r) => { - r.headers = Some(res.headers); - Ok(r) - } - Err(e) => Err(self.new_error(e)), - } - } - - fn new_error(&self, err: minreq::Error) -> APIError { - APIError { - message: err.to_string(), - } - } - - fn is_beta(path: &str) -> bool { - path.starts_with("/assistants") || path.starts_with("/threads") + let url = Self::query_params( + limit, + order, + after, + before, + format!("threads/{}/runs/{}/steps", thread_id, run_id), + ); + self.get(&url).await } fn query_params( @@ -982,4 +661,72 @@ impl Client { } url } + + fn is_beta(path: &str) -> bool { + path.starts_with("assistants") || path.starts_with("threads") + } + + fn create_form(req: &T, file_field: &str) -> Result + where + T: Serialize, + { + let json = match serde_json::to_value(req) { + Ok(json) => json, + Err(e) => { + return Err(APIError::CustomError { + message: e.to_string(), + }) + } + }; + let file_path = if let Value::Object(map) = &json { + map.get(file_field) + .and_then(|v| v.as_str()) + .ok_or(APIError::CustomError { + message: format!("Field '{}' not found or not a string", file_field), + })? + } else { + return Err(APIError::CustomError { + message: "Request is not a JSON object".to_string(), + }); + }; + + let mut file = match File::open(file_path) { + Ok(file) => file, + Err(e) => { + return Err(APIError::CustomError { + message: e.to_string(), + }) + } + }; + let mut buffer = Vec::new(); + match file.read_to_end(&mut buffer) { + Ok(_) => {} + Err(e) => { + return Err(APIError::CustomError { + message: e.to_string(), + }) + } + } + + let mut form = + Form::new().part("file", Part::bytes(buffer).file_name(file_path.to_string())); + + if let Value::Object(map) = json { + for (key, value) in map.into_iter() { + if key != file_field { + match value { + Value::String(s) => { + form = form.text(key, s); + } + Value::Number(n) => { + form = form.text(key, n.to_string()); + } + _ => {} + } + } + } + } + + Ok(form) + } } diff --git a/src/v1/audio.rs b/src/v1/audio.rs index 07e53bd..b2c87f8 100644 --- a/src/v1/audio.rs +++ b/src/v1/audio.rs @@ -1,6 +1,6 @@ -use std::collections::HashMap; - +use reqwest::header::HeaderMap; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use crate::impl_builder_methods; @@ -115,8 +115,8 @@ impl AudioSpeechRequest { impl_builder_methods!(AudioSpeechRequest,); -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug)] pub struct AudioSpeechResponse { pub result: bool, - pub headers: Option>, + pub headers: Option, } diff --git a/src/v1/error.rs b/src/v1/error.rs index b2d2bf0..d1625a2 100644 --- a/src/v1/error.rs +++ b/src/v1/error.rs @@ -1,15 +1,26 @@ +use reqwest::{self}; use std::error::Error; use std::fmt; #[derive(Debug)] -pub struct APIError { - pub message: String, +pub enum APIError { + ReqwestError(reqwest::Error), + CustomError { message: String }, } impl fmt::Display for APIError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "APIError: {}", self.message) + match self { + APIError::ReqwestError(err) => write!(f, "ReqwestError: {}", err), + APIError::CustomError { message } => write!(f, "APIError: {}", message), + } } } impl Error for APIError {} + +impl From for APIError { + fn from(err: reqwest::Error) -> APIError { + APIError::ReqwestError(err) + } +}