Merge pull request #158 from hiteshjoshi/main

Optional auth header and with api keys | works with azure
This commit is contained in:
Dongri Jin
2025-05-14 15:04:07 +09:00
committed by GitHub
2 changed files with 31 additions and 7 deletions

View File

@ -40,3 +40,6 @@ features = ["connect"]
[dependencies.futures-util]
version = "0.3.31"
features = ["sink", "std"]
[dependencies.url]
version = "2.5.4"

View File

@ -43,6 +43,7 @@ use reqwest::multipart::{Form, Part};
use reqwest::{Client, Method, Response};
use serde::Serialize;
use serde_json::Value;
use url::Url;
use std::error::Error;
use std::fs::{create_dir_all, File};
@ -62,9 +63,10 @@ pub struct OpenAIClientBuilder {
headers: Option<HeaderMap>,
}
#[derive(Debug)]
pub struct OpenAIClient {
api_endpoint: String,
api_key: String,
api_key: Option<String>,
organization: Option<String>,
proxy: Option<String>,
timeout: Option<u64>,
@ -111,14 +113,13 @@ impl OpenAIClientBuilder {
}
pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
let api_key = self.api_key.ok_or("API key is required")?;
let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
});
Ok(OpenAIClient {
api_endpoint,
api_key,
api_key: self.api_key,
organization: self.organization,
proxy: self.proxy,
timeout: self.timeout,
@ -133,7 +134,10 @@ impl OpenAIClient {
}
async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.api_endpoint, path);
let url = self
.build_url_with_preserved_query(path)
.unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
let client = Client::builder();
#[cfg(feature = "rustls")]
@ -153,9 +157,11 @@ impl OpenAIClient {
let client = client.build().unwrap();
let mut request = client
.request(method, url)
.header("Authorization", format!("Bearer {}", self.api_key));
let mut request = client.request(method, url);
if let Some(api_key) = &self.api_key {
request = request.header("Authorization", format!("Bearer {}", api_key));
}
if let Some(organization) = &self.organization {
request = request.header("openai-organization", organization);
@ -775,7 +781,22 @@ impl OpenAIClient {
let url = Self::query_params(limit, None, after, None, "batches".to_string());
self.get(&url).await
}
fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
let (base, query_opt) = match self.api_endpoint.split_once('?') {
Some((b, q)) => (b.trim_end_matches('/'), Some(q)),
None => (self.api_endpoint.trim_end_matches('/'), None),
};
let full_path = format!("{}/{}", base, path.trim_start_matches('/'));
let mut url = Url::parse(&full_path)?;
if let Some(query) = query_opt {
for (k, v) in url::form_urlencoded::parse(query.as_bytes()) {
url.query_pairs_mut().append_pair(&k, &v);
}
}
Ok(url.to_string())
}
fn query_params(
limit: Option<i64>,
order: Option<String>,