mirror of
https://github.com/mii443/openai-api-rs.git
synced 2025-08-22 15:15:34 +00:00
Merge pull request #158 from hiteshjoshi/main
Optional auth header and with api keys | works with azure
This commit is contained in:
@ -40,3 +40,6 @@ features = ["connect"]
|
||||
[dependencies.futures-util]
|
||||
version = "0.3.31"
|
||||
features = ["sink", "std"]
|
||||
|
||||
[dependencies.url]
|
||||
version = "2.5.4"
|
||||
|
@ -43,6 +43,7 @@ use reqwest::multipart::{Form, Part};
|
||||
use reqwest::{Client, Method, Response};
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use url::Url;
|
||||
|
||||
use std::error::Error;
|
||||
use std::fs::{create_dir_all, File};
|
||||
@ -62,9 +63,10 @@ pub struct OpenAIClientBuilder {
|
||||
headers: Option<HeaderMap>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OpenAIClient {
|
||||
api_endpoint: String,
|
||||
api_key: String,
|
||||
api_key: Option<String>,
|
||||
organization: Option<String>,
|
||||
proxy: Option<String>,
|
||||
timeout: Option<u64>,
|
||||
@ -111,14 +113,13 @@ impl OpenAIClientBuilder {
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
|
||||
let api_key = self.api_key.ok_or("API key is required")?;
|
||||
let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
|
||||
std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
|
||||
});
|
||||
|
||||
Ok(OpenAIClient {
|
||||
api_endpoint,
|
||||
api_key,
|
||||
api_key: self.api_key,
|
||||
organization: self.organization,
|
||||
proxy: self.proxy,
|
||||
timeout: self.timeout,
|
||||
@ -133,7 +134,10 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
|
||||
let url = format!("{}/{}", self.api_endpoint, path);
|
||||
let url = self
|
||||
.build_url_with_preserved_query(path)
|
||||
.unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
|
||||
|
||||
let client = Client::builder();
|
||||
|
||||
#[cfg(feature = "rustls")]
|
||||
@ -153,9 +157,11 @@ impl OpenAIClient {
|
||||
|
||||
let client = client.build().unwrap();
|
||||
|
||||
let mut request = client
|
||||
.request(method, url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key));
|
||||
let mut request = client.request(method, url);
|
||||
|
||||
if let Some(api_key) = &self.api_key {
|
||||
request = request.header("Authorization", format!("Bearer {}", api_key));
|
||||
}
|
||||
|
||||
if let Some(organization) = &self.organization {
|
||||
request = request.header("openai-organization", organization);
|
||||
@ -775,7 +781,22 @@ impl OpenAIClient {
|
||||
let url = Self::query_params(limit, None, after, None, "batches".to_string());
|
||||
self.get(&url).await
|
||||
}
|
||||
fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
|
||||
let (base, query_opt) = match self.api_endpoint.split_once('?') {
|
||||
Some((b, q)) => (b.trim_end_matches('/'), Some(q)),
|
||||
None => (self.api_endpoint.trim_end_matches('/'), None),
|
||||
};
|
||||
|
||||
let full_path = format!("{}/{}", base, path.trim_start_matches('/'));
|
||||
let mut url = Url::parse(&full_path)?;
|
||||
|
||||
if let Some(query) = query_opt {
|
||||
for (k, v) in url::form_urlencoded::parse(query.as_bytes()) {
|
||||
url.query_pairs_mut().append_pair(&k, &v);
|
||||
}
|
||||
}
|
||||
Ok(url.to_string())
|
||||
}
|
||||
fn query_params(
|
||||
limit: Option<i64>,
|
||||
order: Option<String>,
|
||||
|
Reference in New Issue
Block a user