diff --git a/Cargo.toml b/Cargo.toml index 9ae5a37..38979df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "openai-api-rs" -version = "5.2.4" +version = "6.0.0" edition = "2021" authors = ["Dongri Jin "] license = "MIT" diff --git a/README.md b/README.md index f2f2d27..03dfb95 100644 --- a/README.md +++ b/README.md @@ -7,21 +7,32 @@ Check out the [docs.rs](https://docs.rs/openai-api-rs/). Cargo.toml ```toml [dependencies] -openai-api-rs = "5.2.4" +openai-api-rs = "6.0.0" ``` ## Usage The library needs to be configured with your account's secret key, which is available on the [website](https://platform.openai.com/account/api-keys). We recommend setting it as an environment variable. Here's an example of initializing the library with the API key loaded from an environment variable and creating a completion: -### Set OPENAI_API_KEY to environment variable +### Set OPENAI_API_KEY or OPENROUTER_API_KEY to environment variable ```bash $ export OPENAI_API_KEY=sk-xxxxxxx +or +$ export OPENROUTER_API_KEY=sk-xxxxxxx ``` -### Create client +### Create OpenAI client ```rust let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); -let client = OpenAIClient::builder().with_api_key(api_key).build()?; +let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; +``` + +### Create OpenRouter client +```rust +let api_key = env::var("OPENROUTER_API_KEY").unwrap().to_string(); +let mut client = OpenAIClient::builder() + .with_endpoint("https://openrouter.ai/api/v1") + .with_api_key(api_key) + .build()?; ``` ### Create request @@ -42,6 +53,10 @@ let req = ChatCompletionRequest::new( ```rust let result = client.chat_completion(req)?; println!("Content: {:?}", result.choices[0].message.content); + +for (key, value) in client.headers.unwrap().iter() { + println!("{}: {:?}", key, value); +} ``` ### Set OPENAI_API_BASE to environment variable (optional) @@ -59,7 +74,7 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = ChatCompletionRequest::new( GPT4_O.to_string(), @@ -74,11 +89,52 @@ async fn main() -> Result<(), Box> { let result = client.chat_completion(req).await?; println!("Content: {:?}", result.choices[0].message.content); - println!("Response Headers: {:?}", result.headers); + + for (key, value) in client.headers.unwrap().iter() { + println!("{}: {:?}", key, value); + } Ok(()) } ``` + +## Example for OpenRouter +```rust +use openai_api_rs::v1::api::OpenAIClient; +use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::common::GPT4_O_MINI; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let api_key = env::var("OPENROUTER_API_KEY").unwrap().to_string(); + let mut client = OpenAIClient::builder() + .with_endpoint("https://openrouter.ai/api/v1") + .with_api_key(api_key) + .build()?; + + let req = ChatCompletionRequest::new( + GPT4_O_MINI.to_string(), + vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: chat_completion::Content::Text(String::from("What is bitcoin?")), + name: None, + tool_calls: None, + tool_call_id: None, + }], + ); + + let result = client.chat_completion(req).await?; + println!("Content: {:?}", result.choices[0].message.content); + + for (key, value) in client.headers.unwrap().iter() { + println!("{}: {:?}", key, value); + } + + Ok(()) +} +``` + More Examples: [examples](https://github.com/dongri/openai-api-rs/tree/main/examples) Check out the [full API documentation](https://platform.openai.com/docs/api-reference/completions) for examples of all the available functions. diff --git a/examples/assistant.rs b/examples/assistant.rs index 2e212b8..b7573e6 100644 --- a/examples/assistant.rs +++ b/examples/assistant.rs @@ -10,7 +10,7 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let mut tools = HashMap::new(); tools.insert("type".to_string(), "code_interpreter".to_string()); diff --git a/examples/audio_speech.rs b/examples/audio_speech.rs index 9e3af08..7df168f 100644 --- a/examples/audio_speech.rs +++ b/examples/audio_speech.rs @@ -5,7 +5,7 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = AudioSpeechRequest::new( TTS_1.to_string(), diff --git a/examples/audio_transcriptions.rs b/examples/audio_transcriptions.rs index 49f5b88..2c341de 100644 --- a/examples/audio_transcriptions.rs +++ b/examples/audio_transcriptions.rs @@ -1,16 +1,18 @@ use openai_api_rs::v1::api::OpenAIClient; use openai_api_rs::v1::audio::{AudioTranscriptionRequest, WHISPER_1}; use std::env; +use std::fs::File; +use std::io::Read; #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; - let req = AudioTranscriptionRequest::new( - "examples/data/problem.mp3".to_string(), - WHISPER_1.to_string(), - ); + let file_path = "examples/data/problem.mp3"; + + // Test with file + let req = AudioTranscriptionRequest::new(file_path.to_string(), WHISPER_1.to_string()); let req_json = req.clone().response_format("json".to_string()); @@ -22,7 +24,19 @@ async fn main() -> Result<(), Box> { let result = client.audio_transcription_raw(req_raw).await?; println!("{:?}", result); + // Test with bytes + let mut file = File::open(file_path)?; + let mut buffer = Vec::new(); + file.read_to_end(&mut buffer)?; + + let req = AudioTranscriptionRequest::new_bytes(buffer, WHISPER_1.to_string()); + + let req_json = req.clone().response_format("json".to_string()); + + let result = client.audio_transcription(req_json).await?; + println!("{:?}", result); + Ok(()) } -// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_translations +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example audio_transcriptions diff --git a/examples/audio_translations.rs b/examples/audio_translations.rs index 13b57e0..d352b65 100644 --- a/examples/audio_translations.rs +++ b/examples/audio_translations.rs @@ -5,7 +5,7 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = AudioTranslationRequest::new( "examples/data/problem_cn.mp3".to_string(), diff --git a/examples/batch.rs b/examples/batch.rs index f14617e..aca5784 100644 --- a/examples/batch.rs +++ b/examples/batch.rs @@ -10,7 +10,7 @@ use std::str; #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = FileUploadRequest::new( "examples/data/batch_request.json".to_string(), diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index 7a5791c..64f564e 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -6,7 +6,7 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = ChatCompletionRequest::new( GPT4_O_MINI.to_string(), @@ -21,7 +21,11 @@ async fn main() -> Result<(), Box> { let result = client.chat_completion(req).await?; println!("Content: {:?}", result.choices[0].message.content); - println!("Response Headers: {:?}", result.headers); + + // print response headers + for (key, value) in client.headers.unwrap().iter() { + println!("{}: {:?}", key, value); + } Ok(()) } diff --git a/examples/completion.rs b/examples/completion.rs index 138c1fe..95cbc23 100644 --- a/examples/completion.rs +++ b/examples/completion.rs @@ -5,7 +5,7 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = CompletionRequest::new( completion::GPT3_TEXT_DAVINCI_003.to_string(), diff --git a/examples/embedding.rs b/examples/embedding.rs index 8615bdb..23ca6db 100644 --- a/examples/embedding.rs +++ b/examples/embedding.rs @@ -6,7 +6,7 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let mut req = EmbeddingRequest::new( TEXT_EMBEDDING_3_SMALL.to_string(), diff --git a/examples/function_call.rs b/examples/function_call.rs index 1858465..06bd922 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -18,7 +18,7 @@ fn get_coin_price(coin: &str) -> f64 { #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let mut properties = HashMap::new(); properties.insert( diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index 1afc91a..46148f4 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -18,7 +18,7 @@ fn get_coin_price(coin: &str) -> f64 { #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let mut properties = HashMap::new(); properties.insert( diff --git a/examples/openrouter.rs b/examples/openrouter.rs new file mode 100644 index 0000000..dd04e0a --- /dev/null +++ b/examples/openrouter.rs @@ -0,0 +1,32 @@ +use openai_api_rs::v1::api::OpenAIClient; +use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::common::GPT4_O_MINI; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let api_key = env::var("OPENROUTER_API_KEY").unwrap().to_string(); + let mut client = OpenAIClient::builder() + .with_endpoint("https://openrouter.ai/api/v1") + .with_api_key(api_key) + .build()?; + + let req = ChatCompletionRequest::new( + GPT4_O_MINI.to_string(), + vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: chat_completion::Content::Text(String::from("What is bitcoin?")), + name: None, + tool_calls: None, + tool_call_id: None, + }], + ); + + let result = client.chat_completion(req).await?; + println!("Content: {:?}", result.choices[0].message.content); + println!("Response Headers: {:?}", client.headers); + + Ok(()) +} + +// OPENROUTER_API_KEY=xxxx cargo run --package openai-api-rs --example openrouter diff --git a/examples/vision.rs b/examples/vision.rs index 6a92c43..7bad362 100644 --- a/examples/vision.rs +++ b/examples/vision.rs @@ -6,7 +6,7 @@ use std::env; #[tokio::main] async fn main() -> Result<(), Box> { let api_key = env::var("OPENAI_API_KEY").unwrap().to_string(); - let client = OpenAIClient::builder().with_api_key(api_key).build()?; + let mut client = OpenAIClient::builder().with_api_key(api_key).build()?; let req = ChatCompletionRequest::new( GPT4_O.to_string(), diff --git a/src/realtime/types.rs b/src/realtime/types.rs index 0cc70d9..a9df8bd 100644 --- a/src/realtime/types.rs +++ b/src/realtime/types.rs @@ -30,17 +30,22 @@ pub struct Session { #[serde(rename_all = "lowercase")] pub enum RealtimeVoice { Alloy, - Shimmer, + Ash, + Ballad, + Coral, Echo, + Sage, + Shimmer, + Verse, } #[derive(Debug, Serialize, Deserialize, Clone)] pub enum AudioFormat { #[serde(rename = "pcm16")] PCM16, - #[serde(rename = "g711-ulaw")] + #[serde(rename = "g711_ulaw")] G711ULAW, - #[serde(rename = "g711-alaw")] + #[serde(rename = "g711_alaw")] G711ALAW, } diff --git a/src/v1/api.rs b/src/v1/api.rs index f669cca..c53f989 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -68,7 +68,7 @@ pub struct OpenAIClient { organization: Option, proxy: Option, timeout: Option, - headers: Option, + pub headers: Option, } impl OpenAIClientBuilder { @@ -136,6 +136,9 @@ impl OpenAIClient { let url = format!("{}/{}", self.api_endpoint, path); let client = Client::builder(); + #[cfg(feature = "rustls")] + let client = client.use_rustls_tls(); + let client = if let Some(timeout) = self.timeout { client.timeout(std::time::Duration::from_secs(timeout)) } else { @@ -172,7 +175,7 @@ impl OpenAIClient { } async fn post( - &self, + &mut self, path: &str, body: &impl serde::ser::Serialize, ) -> Result { @@ -182,7 +185,7 @@ impl OpenAIClient { self.handle_response(response).await } - async fn get(&self, path: &str) -> Result { + async fn get(&mut self, path: &str) -> Result { let request = self.build_request(Method::GET, path).await; let response = request.send().await?; self.handle_response(response).await @@ -194,14 +197,14 @@ impl OpenAIClient { Ok(response.bytes().await?) } - async fn delete(&self, path: &str) -> Result { + async fn delete(&mut self, path: &str) -> Result { let request = self.build_request(Method::DELETE, path).await; let response = request.send().await?; self.handle_response(response).await } async fn post_form( - &self, + &mut self, path: &str, form: Form, ) -> Result { @@ -219,14 +222,18 @@ impl OpenAIClient { } async fn handle_response( - &self, + &mut self, response: Response, ) -> Result { let status = response.status(); + let headers = response.headers().clone(); if status.is_success() { let text = response.text().await.unwrap_or_else(|_| "".to_string()); match serde_json::from_str::(&text) { - Ok(parsed) => Ok(parsed), + Ok(parsed) => { + self.headers = Some(headers); + Ok(parsed) + } Err(e) => Err(APIError::CustomError { message: format!("Failed to parse JSON: {} / response {}", e, text), }), @@ -242,42 +249,51 @@ impl OpenAIClient { } } - pub async fn completion(&self, req: CompletionRequest) -> Result { + pub async fn completion( + &mut self, + req: CompletionRequest, + ) -> Result { self.post("completions", &req).await } - pub async fn edit(&self, req: EditRequest) -> Result { + pub async fn edit(&mut self, req: EditRequest) -> Result { self.post("edits", &req).await } pub async fn image_generation( - &self, + &mut self, req: ImageGenerationRequest, ) -> Result { self.post("images/generations", &req).await } - pub async fn image_edit(&self, req: ImageEditRequest) -> Result { + pub async fn image_edit( + &mut self, + req: ImageEditRequest, + ) -> Result { self.post("images/edits", &req).await } pub async fn image_variation( - &self, + &mut self, req: ImageVariationRequest, ) -> Result { self.post("images/variations", &req).await } - pub async fn embedding(&self, req: EmbeddingRequest) -> Result { + pub async fn embedding( + &mut self, + req: EmbeddingRequest, + ) -> Result { self.post("embeddings", &req).await } - pub async fn file_list(&self) -> Result { + pub async fn file_list(&mut self) -> Result { self.get("files").await } pub async fn upload_file( - &self, + &mut self, req: FileUploadRequest, ) -> Result { let form = Self::create_form(&req, "file")?; @@ -285,13 +301,16 @@ impl OpenAIClient { } pub async fn delete_file( - &self, + &mut self, req: FileDeleteRequest, ) -> Result { self.delete(&format!("files/{}", req.file_id)).await } - pub async fn retrieve_file(&self, file_id: String) -> Result { + pub async fn retrieve_file( + &mut self, + file_id: String, + ) -> Result { self.get(&format!("files/{}", file_id)).await } @@ -300,17 +319,17 @@ impl OpenAIClient { } pub async fn chat_completion( - &self, + &mut self, req: ChatCompletionRequest, ) -> Result { self.post("chat/completions", &req).await } pub async fn audio_transcription( - &self, + &mut self, req: AudioTranscriptionRequest, ) -> Result { - // https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format + // https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format if let Some(response_format) = &req.response_format { if response_format != "json" && response_format != "verbose_json" { return Err(APIError::CustomError { @@ -318,15 +337,24 @@ impl OpenAIClient { }); } } - let form = Self::create_form(&req, "file")?; + let form: Form; + if req.clone().file.is_some() { + form = Self::create_form(&req, "file")?; + } else if let Some(bytes) = req.clone().bytes { + form = Self::create_form_from_bytes(&req, bytes)?; + } else { + return Err(APIError::CustomError { + message: "Either file or bytes must be provided".to_string(), + }); + } self.post_form("audio/transcriptions", form).await } pub async fn audio_transcription_raw( - &self, + &mut self, req: AudioTranscriptionRequest, ) -> Result { - // https://platform.openai.com/docs/api-reference/audio/createTranslation#audio-createtranslation-response_format + // https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format if let Some(response_format) = &req.response_format { if response_format != "text" && response_format != "srt" && response_format != "vtt" { return Err(APIError::CustomError { @@ -334,12 +362,21 @@ impl OpenAIClient { }); } } - let form = Self::create_form(&req, "file")?; + let form: Form; + if req.clone().file.is_some() { + form = Self::create_form(&req, "file")?; + } else if let Some(bytes) = req.clone().bytes { + form = Self::create_form_from_bytes(&req, bytes)?; + } else { + return Err(APIError::CustomError { + message: "Either file or bytes must be provided".to_string(), + }); + } self.post_form_raw("audio/transcriptions", form).await } pub async fn audio_translation( - &self, + &mut self, req: AudioTranslationRequest, ) -> Result { let form = Self::create_form(&req, "file")?; @@ -347,7 +384,7 @@ impl OpenAIClient { } pub async fn audio_speech( - &self, + &mut self, req: AudioSpeechRequest, ) -> Result { let request = self.build_request(Method::POST, "audio/speech").await; @@ -389,20 +426,20 @@ impl OpenAIClient { } pub async fn create_fine_tuning_job( - &self, + &mut self, req: CreateFineTuningJobRequest, ) -> Result { self.post("fine_tuning/jobs", &req).await } pub async fn list_fine_tuning_jobs( - &self, + &mut self, ) -> Result, APIError> { self.get("fine_tuning/jobs").await } pub async fn list_fine_tuning_job_events( - &self, + &mut self, req: ListFineTuningJobEventsRequest, ) -> Result, APIError> { self.get(&format!( @@ -413,7 +450,7 @@ impl OpenAIClient { } pub async fn retrieve_fine_tuning_job( - &self, + &mut self, req: RetrieveFineTuningJobRequest, ) -> Result { self.get(&format!("fine_tuning/jobs/{}", req.fine_tuning_job_id)) @@ -421,7 +458,7 @@ impl OpenAIClient { } pub async fn cancel_fine_tuning_job( - &self, + &mut self, req: CancelFineTuningJobRequest, ) -> Result { self.post( @@ -432,28 +469,28 @@ impl OpenAIClient { } pub async fn create_moderation( - &self, + &mut self, req: CreateModerationRequest, ) -> Result { self.post("moderations", &req).await } pub async fn create_assistant( - &self, + &mut self, req: AssistantRequest, ) -> Result { self.post("assistants", &req).await } pub async fn retrieve_assistant( - &self, + &mut self, assistant_id: String, ) -> Result { self.get(&format!("assistants/{}", assistant_id)).await } pub async fn modify_assistant( - &self, + &mut self, assistant_id: String, req: AssistantRequest, ) -> Result { @@ -461,12 +498,15 @@ impl OpenAIClient { .await } - pub async fn delete_assistant(&self, assistant_id: String) -> Result { + pub async fn delete_assistant( + &mut self, + assistant_id: String, + ) -> Result { self.delete(&format!("assistants/{}", assistant_id)).await } pub async fn list_assistant( - &self, + &mut self, limit: Option, order: Option, after: Option, @@ -477,7 +517,7 @@ impl OpenAIClient { } pub async fn create_assistant_file( - &self, + &mut self, assistant_id: String, req: AssistantFileRequest, ) -> Result { @@ -486,7 +526,7 @@ impl OpenAIClient { } pub async fn retrieve_assistant_file( - &self, + &mut self, assistant_id: String, file_id: String, ) -> Result { @@ -495,7 +535,7 @@ impl OpenAIClient { } pub async fn delete_assistant_file( - &self, + &mut self, assistant_id: String, file_id: String, ) -> Result { @@ -504,7 +544,7 @@ impl OpenAIClient { } pub async fn list_assistant_file( - &self, + &mut self, assistant_id: String, limit: Option, order: Option, @@ -521,28 +561,31 @@ impl OpenAIClient { self.get(&url).await } - pub async fn create_thread(&self, req: CreateThreadRequest) -> Result { + pub async fn create_thread( + &mut self, + req: CreateThreadRequest, + ) -> Result { self.post("threads", &req).await } - pub async fn retrieve_thread(&self, thread_id: String) -> Result { + pub async fn retrieve_thread(&mut self, thread_id: String) -> Result { self.get(&format!("threads/{}", thread_id)).await } pub async fn modify_thread( - &self, + &mut self, thread_id: String, req: ModifyThreadRequest, ) -> Result { self.post(&format!("threads/{}", thread_id), &req).await } - pub async fn delete_thread(&self, thread_id: String) -> Result { + pub async fn delete_thread(&mut self, thread_id: String) -> Result { self.delete(&format!("threads/{}", thread_id)).await } pub async fn create_message( - &self, + &mut self, thread_id: String, req: CreateMessageRequest, ) -> Result { @@ -551,7 +594,7 @@ impl OpenAIClient { } pub async fn retrieve_message( - &self, + &mut self, thread_id: String, message_id: String, ) -> Result { @@ -560,7 +603,7 @@ impl OpenAIClient { } pub async fn modify_message( - &self, + &mut self, thread_id: String, message_id: String, req: ModifyMessageRequest, @@ -572,12 +615,12 @@ impl OpenAIClient { .await } - pub async fn list_messages(&self, thread_id: String) -> Result { + pub async fn list_messages(&mut self, thread_id: String) -> Result { self.get(&format!("threads/{}/messages", thread_id)).await } pub async fn retrieve_message_file( - &self, + &mut self, thread_id: String, message_id: String, file_id: String, @@ -590,7 +633,7 @@ impl OpenAIClient { } pub async fn list_message_file( - &self, + &mut self, thread_id: String, message_id: String, limit: Option, @@ -609,7 +652,7 @@ impl OpenAIClient { } pub async fn create_run( - &self, + &mut self, thread_id: String, req: CreateRunRequest, ) -> Result { @@ -618,7 +661,7 @@ impl OpenAIClient { } pub async fn retrieve_run( - &self, + &mut self, thread_id: String, run_id: String, ) -> Result { @@ -627,7 +670,7 @@ impl OpenAIClient { } pub async fn modify_run( - &self, + &mut self, thread_id: String, run_id: String, req: ModifyRunRequest, @@ -637,7 +680,7 @@ impl OpenAIClient { } pub async fn list_run( - &self, + &mut self, thread_id: String, limit: Option, order: Option, @@ -655,7 +698,7 @@ impl OpenAIClient { } pub async fn cancel_run( - &self, + &mut self, thread_id: String, run_id: String, ) -> Result { @@ -667,14 +710,14 @@ impl OpenAIClient { } pub async fn create_thread_and_run( - &self, + &mut self, req: CreateThreadAndRunRequest, ) -> Result { self.post("threads/runs", &req).await } pub async fn retrieve_run_step( - &self, + &mut self, thread_id: String, run_id: String, step_id: String, @@ -687,7 +730,7 @@ impl OpenAIClient { } pub async fn list_run_step( - &self, + &mut self, thread_id: String, run_id: String, limit: Option, @@ -705,15 +748,18 @@ impl OpenAIClient { self.get(&url).await } - pub async fn create_batch(&self, req: CreateBatchRequest) -> Result { + pub async fn create_batch( + &mut self, + req: CreateBatchRequest, + ) -> Result { self.post("batches", &req).await } - pub async fn retrieve_batch(&self, batch_id: String) -> Result { + pub async fn retrieve_batch(&mut self, batch_id: String) -> Result { self.get(&format!("batches/{}", batch_id)).await } - pub async fn cancel_batch(&self, batch_id: String) -> Result { + pub async fn cancel_batch(&mut self, batch_id: String) -> Result { self.post( &format!("batches/{}/cancel", batch_id), &common::EmptyRequestBody {}, @@ -722,7 +768,7 @@ impl OpenAIClient { } pub async fn list_batch( - &self, + &mut self, after: Option, limit: Option, ) -> Result { @@ -823,4 +869,36 @@ impl OpenAIClient { Ok(form) } + + fn create_form_from_bytes(req: &T, bytes: Vec) -> Result + where + T: Serialize, + { + let json = match serde_json::to_value(req) { + Ok(json) => json, + Err(e) => { + return Err(APIError::CustomError { + message: e.to_string(), + }) + } + }; + + let mut form = Form::new().part("file", Part::bytes(bytes.clone()).file_name("file.mp3")); + + if let Value::Object(map) = json { + for (key, value) in map.into_iter() { + match value { + Value::String(s) => { + form = form.text(key, s); + } + Value::Number(n) => { + form = form.text(key, n.to_string()); + } + _ => {} + } + } + } + + Ok(form) + } } diff --git a/src/v1/assistant.rs b/src/v1/assistant.rs index 0c18ffe..319073e 100644 --- a/src/v1/assistant.rs +++ b/src/v1/assistant.rs @@ -61,7 +61,6 @@ pub struct AssistantObject { #[serde(skip_serializing_if = "Option::is_none")] pub tool_resources: Option, pub metadata: Option>, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -101,14 +100,12 @@ pub struct DeletionStatus { pub id: String, pub object: String, pub deleted: bool, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] pub struct ListAssistant { pub object: String, pub data: Vec, - pub headers: Option>, } #[derive(Debug, Serialize, Clone)] @@ -122,12 +119,10 @@ pub struct AssistantFileObject { pub object: String, pub created_at: i64, pub assistant_id: String, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] pub struct ListAssistantFile { pub object: String, pub data: Vec, - pub headers: Option>, } diff --git a/src/v1/audio.rs b/src/v1/audio.rs index b2c87f8..4ab93f4 100644 --- a/src/v1/audio.rs +++ b/src/v1/audio.rs @@ -1,6 +1,5 @@ use reqwest::header::HeaderMap; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use crate::impl_builder_methods; @@ -8,9 +7,11 @@ pub const WHISPER_1: &str = "whisper-1"; #[derive(Debug, Serialize, Clone)] pub struct AudioTranscriptionRequest { - pub file: String, pub model: String, #[serde(skip_serializing_if = "Option::is_none")] + pub file: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub bytes: Option>, pub prompt: Option, #[serde(skip_serializing_if = "Option::is_none")] pub response_format: Option, @@ -23,8 +24,21 @@ pub struct AudioTranscriptionRequest { impl AudioTranscriptionRequest { pub fn new(file: String, model: String) -> Self { Self { - file, model, + file: Some(file), + bytes: None, + prompt: None, + response_format: None, + temperature: None, + language: None, + } + } + + pub fn new_bytes(bytes: Vec, model: String) -> Self { + Self { + model, + file: None, + bytes: Some(bytes), prompt: None, response_format: None, temperature: None, @@ -44,7 +58,6 @@ impl_builder_methods!( #[derive(Debug, Deserialize, Serialize)] pub struct AudioTranscriptionResponse { pub text: String, - pub headers: Option>, } #[derive(Debug, Serialize, Clone)] @@ -81,7 +94,6 @@ impl_builder_methods!( #[derive(Debug, Deserialize, Serialize)] pub struct AudioTranslationResponse { pub text: String, - pub headers: Option>, } pub const TTS_1: &str = "tts-1"; diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index aac6159..3849870 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -243,14 +243,13 @@ pub struct ChatCompletionChoice { #[derive(Debug, Deserialize, Serialize)] pub struct ChatCompletionResponse { - pub id: String, + pub id: Option, pub object: String, pub created: i64, pub model: String, pub choices: Vec, pub usage: common::Usage, pub system_fingerprint: Option, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)] diff --git a/src/v1/common.rs b/src/v1/common.rs index d641115..3fd5cf8 100644 --- a/src/v1/common.rs +++ b/src/v1/common.rs @@ -24,6 +24,10 @@ macro_rules! impl_builder_methods { #[derive(Debug, Serialize, Deserialize)] pub struct EmptyRequestBody {} +// https://platform.openai.com/docs/models#gpt-4-5 +pub const GPT4_5_PREVIEW: &str = "gpt-4.5-preview"; +pub const GPT4_5_PREVIEW_2025_02_27: &str = "gpt-4.5-preview-2025-02-27"; + // https://platform.openai.com/docs/models/o1 pub const O1_PREVIEW: &str = "o1-preview"; pub const O1_PREVIEW_2024_09_12: &str = "o1-preview-2024-09-12"; diff --git a/src/v1/completion.rs b/src/v1/completion.rs index c4d347f..a5adeaf 100644 --- a/src/v1/completion.rs +++ b/src/v1/completion.rs @@ -117,5 +117,4 @@ pub struct CompletionResponse { pub model: String, pub choices: Vec, pub usage: common::Usage, - pub headers: Option>, } diff --git a/src/v1/edit.rs b/src/v1/edit.rs index c47ab01..18cf383 100644 --- a/src/v1/edit.rs +++ b/src/v1/edit.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::option::Option; use crate::impl_builder_methods; @@ -52,5 +51,4 @@ pub struct EditResponse { pub created: i64, pub usage: common::Usage, pub choices: Vec, - pub headers: Option>, } diff --git a/src/v1/embedding.rs b/src/v1/embedding.rs index a7305b0..3f68054 100644 --- a/src/v1/embedding.rs +++ b/src/v1/embedding.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::option::Option; use crate::impl_builder_methods; @@ -52,7 +51,6 @@ pub struct EmbeddingResponse { pub data: Vec, pub model: String, pub usage: Usage, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] diff --git a/src/v1/file.rs b/src/v1/file.rs index d475a31..fceb4f8 100644 --- a/src/v1/file.rs +++ b/src/v1/file.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize)] @@ -16,7 +14,6 @@ pub struct FileData { pub struct FileListResponse { pub object: String, pub data: Vec, - pub headers: Option>, } #[derive(Debug, Serialize)] @@ -39,7 +36,6 @@ pub struct FileUploadResponse { pub created_at: i64, pub filename: String, pub purpose: String, - pub headers: Option>, } #[derive(Debug, Serialize)] @@ -58,7 +54,6 @@ pub struct FileDeleteResponse { pub id: String, pub object: String, pub delete: bool, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] @@ -69,5 +64,4 @@ pub struct FileRetrieveResponse { pub created_at: i64, pub filename: String, pub purpose: String, - pub headers: Option>, } diff --git a/src/v1/fine_tuning.rs b/src/v1/fine_tuning.rs index 8408a9d..8d696f9 100644 --- a/src/v1/fine_tuning.rs +++ b/src/v1/fine_tuning.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use crate::impl_builder_methods; @@ -99,7 +98,6 @@ pub struct FineTuningPagination { pub object: String, pub data: Vec, pub has_more: bool, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] @@ -118,7 +116,6 @@ pub struct FineTuningJobObject { pub trained_tokens: Option, pub training_file: String, pub validation_file: Option, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] diff --git a/src/v1/image.rs b/src/v1/image.rs index 1963146..37038b0 100644 --- a/src/v1/image.rs +++ b/src/v1/image.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::option::Option; use crate::impl_builder_methods; @@ -50,7 +49,6 @@ impl_builder_methods!( pub struct ImageGenerationResponse { pub created: i64, pub data: Vec, - pub headers: Option>, } #[derive(Debug, Serialize, Clone)] @@ -100,7 +98,6 @@ impl_builder_methods!( pub struct ImageEditResponse { pub created: i64, pub data: Vec, - pub headers: Option>, } #[derive(Debug, Serialize, Clone)] @@ -144,5 +141,4 @@ impl_builder_methods!( pub struct ImageVariationResponse { pub created: i64, pub data: Vec, - pub headers: Option>, } diff --git a/src/v1/message.rs b/src/v1/message.rs index 094689c..8ac7b66 100644 --- a/src/v1/message.rs +++ b/src/v1/message.rs @@ -68,7 +68,6 @@ pub struct MessageObject { #[serde(skip_serializing_if = "Option::is_none")] pub attachments: Option>, pub metadata: Option>, - pub headers: Option>, } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -111,7 +110,6 @@ pub struct ListMessage { pub first_id: String, pub last_id: String, pub has_more: bool, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] @@ -120,7 +118,6 @@ pub struct MessageFileObject { pub object: String, pub created_at: i64, pub message_id: String, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] @@ -130,7 +127,6 @@ pub struct ListMessageFile { pub first_id: String, pub last_id: String, pub has_more: bool, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] diff --git a/src/v1/moderation.rs b/src/v1/moderation.rs index b4989eb..69acece 100644 --- a/src/v1/moderation.rs +++ b/src/v1/moderation.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use crate::impl_builder_methods; @@ -26,7 +25,6 @@ pub struct CreateModerationResponse { pub id: String, pub model: String, pub results: Vec, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] diff --git a/src/v1/run.rs b/src/v1/run.rs index 645104f..d084b62 100644 --- a/src/v1/run.rs +++ b/src/v1/run.rs @@ -98,7 +98,6 @@ pub struct RunObject { pub instructions: Option, pub tools: Vec, pub metadata: HashMap, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize)] @@ -108,7 +107,6 @@ pub struct ListRun { pub first_id: String, pub last_id: String, pub has_more: bool, - pub headers: Option>, } #[derive(Debug, Serialize, Clone)] @@ -151,7 +149,6 @@ pub struct RunStepObject { #[serde(skip_serializing_if = "Option::is_none")] pub completed_at: Option, pub metadata: HashMap, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -161,5 +158,4 @@ pub struct ListRunStep { pub first_id: String, pub last_id: String, pub has_more: bool, - pub headers: Option>, } diff --git a/src/v1/thread.rs b/src/v1/thread.rs index 57e0b5c..73b1f5c 100644 --- a/src/v1/thread.rs +++ b/src/v1/thread.rs @@ -67,8 +67,6 @@ pub struct ThreadObject { pub metadata: HashMap, #[serde(skip_serializing_if = "Option::is_none")] pub tool_resources: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -86,7 +84,6 @@ pub struct Message { #[serde(skip_serializing_if = "Option::is_none")] pub attachments: Option>, pub metadata: Option>, - pub headers: Option>, } #[derive(Debug, Deserialize, Serialize, Clone)]