diff --git a/Cargo.toml b/Cargo.toml index 8ecfc4c..ffb9a1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,3 +40,6 @@ features = ["connect"] [dependencies.futures-util] version = "0.3.31" features = ["sink", "std"] + +[dependencies.url] +version = "2.5.4" diff --git a/src/v1/api.rs b/src/v1/api.rs index c53f989..458dc34 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -43,6 +43,7 @@ use reqwest::multipart::{Form, Part}; use reqwest::{Client, Method, Response}; use serde::Serialize; use serde_json::Value; +use url::Url; use std::error::Error; use std::fs::{create_dir_all, File}; @@ -62,9 +63,10 @@ pub struct OpenAIClientBuilder { headers: Option, } +#[derive(Debug)] pub struct OpenAIClient { api_endpoint: String, - api_key: String, + api_key: Option, organization: Option, proxy: Option, timeout: Option, @@ -111,14 +113,13 @@ impl OpenAIClientBuilder { } pub fn build(self) -> Result> { - let api_key = self.api_key.ok_or("API key is required")?; let api_endpoint = self.api_endpoint.unwrap_or_else(|| { std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned()) }); Ok(OpenAIClient { api_endpoint, - api_key, + api_key: self.api_key, organization: self.organization, proxy: self.proxy, timeout: self.timeout, @@ -133,7 +134,10 @@ impl OpenAIClient { } async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.api_endpoint, path); + let url = self + .build_url_with_preserved_query(path) + .unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path)); + let client = Client::builder(); #[cfg(feature = "rustls")] @@ -153,9 +157,11 @@ impl OpenAIClient { let client = client.build().unwrap(); - let mut request = client - .request(method, url) - .header("Authorization", format!("Bearer {}", self.api_key)); + let mut request = client.request(method, url); + + if let Some(api_key) = &self.api_key { + request = request.header("Authorization", format!("Bearer {}", api_key)); + } if let Some(organization) = &self.organization { request = request.header("openai-organization", organization); @@ -775,7 +781,22 @@ impl OpenAIClient { let url = Self::query_params(limit, None, after, None, "batches".to_string()); self.get(&url).await } + fn build_url_with_preserved_query(&self, path: &str) -> Result { + let (base, query_opt) = match self.api_endpoint.split_once('?') { + Some((b, q)) => (b.trim_end_matches('/'), Some(q)), + None => (self.api_endpoint.trim_end_matches('/'), None), + }; + let full_path = format!("{}/{}", base, path.trim_start_matches('/')); + let mut url = Url::parse(&full_path)?; + + if let Some(query) = query_opt { + for (k, v) in url::form_urlencoded::parse(query.as_bytes()) { + url.query_pairs_mut().append_pair(&k, &v); + } + } + Ok(url.to_string()) + } fn query_params( limit: Option, order: Option,