mirror of
https://github.com/mii443/openai-api-rs.git
synced 2025-08-22 15:15:34 +00:00
preserve query parameters from the url
This commit is contained in:
@ -40,3 +40,6 @@ features = ["connect"]
|
||||
[dependencies.futures-util]
|
||||
version = "0.3.31"
|
||||
features = ["sink", "std"]
|
||||
|
||||
[dependencies.url]
|
||||
version = "2.5.4"
|
||||
|
@ -43,6 +43,7 @@ use reqwest::multipart::{Form, Part};
|
||||
use reqwest::{Client, Method, Response};
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use url::Url;
|
||||
|
||||
use std::error::Error;
|
||||
use std::fs::{create_dir_all, File};
|
||||
@ -56,7 +57,6 @@ const API_URL_V1: &str = "https://api.openai.com/v1";
|
||||
pub struct OpenAIClientBuilder {
|
||||
api_endpoint: Option<String>,
|
||||
api_key: Option<String>,
|
||||
api_version: Option<String>,
|
||||
organization: Option<String>,
|
||||
proxy: Option<String>,
|
||||
timeout: Option<u64>,
|
||||
@ -67,7 +67,6 @@ pub struct OpenAIClientBuilder {
|
||||
pub struct OpenAIClient {
|
||||
api_endpoint: String,
|
||||
api_key: Option<String>,
|
||||
api_version: Option<String>,
|
||||
organization: Option<String>,
|
||||
proxy: Option<String>,
|
||||
timeout: Option<u64>,
|
||||
@ -84,11 +83,6 @@ impl OpenAIClientBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_api_version(mut self, api_version: impl Into<String>) -> Self {
|
||||
self.api_version = Some(api_version.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
|
||||
self.api_endpoint = Some(endpoint.into());
|
||||
self
|
||||
@ -125,7 +119,6 @@ impl OpenAIClientBuilder {
|
||||
|
||||
Ok(OpenAIClient {
|
||||
api_endpoint,
|
||||
api_version: self.api_version,
|
||||
api_key: self.api_key,
|
||||
organization: self.organization,
|
||||
proxy: self.proxy,
|
||||
@ -141,11 +134,9 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
|
||||
let mut url = format!("{}/{}", self.api_endpoint, path);
|
||||
|
||||
if let Some(api_version) = &self.api_version {
|
||||
url = format!("{}?api-version={}", url, api_version);
|
||||
}
|
||||
let url = self
|
||||
.build_url_with_preserved_query(path)
|
||||
.unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
|
||||
|
||||
let client = Client::builder();
|
||||
|
||||
@ -790,7 +781,18 @@ impl OpenAIClient {
|
||||
let url = Self::query_params(limit, None, after, None, "batches".to_string());
|
||||
self.get(&url).await
|
||||
}
|
||||
fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
|
||||
let base = Url::parse(&self.api_endpoint)?;
|
||||
let mut url = base.join(path)?;
|
||||
|
||||
if let Some(q) = base.query() {
|
||||
for (k, v) in url::form_urlencoded::parse(q.as_bytes()) {
|
||||
url.query_pairs_mut().append_pair(&k, &v);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(url.to_string())
|
||||
}
|
||||
fn query_params(
|
||||
limit: Option<i64>,
|
||||
order: Option<String>,
|
||||
|
Reference in New Issue
Block a user