preserve query parameters from the url

This commit is contained in:
Hitesh Joshi
2025-04-30 17:54:22 +05:30
parent 6b2577dc8f
commit 6ea6beb4df
2 changed files with 18 additions and 13 deletions

View File

@ -40,3 +40,6 @@ features = ["connect"]
[dependencies.futures-util]
version = "0.3.31"
features = ["sink", "std"]
[dependencies.url]
version = "2.5.4"

View File

@ -43,6 +43,7 @@ use reqwest::multipart::{Form, Part};
use reqwest::{Client, Method, Response};
use serde::Serialize;
use serde_json::Value;
use url::Url;
use std::error::Error;
use std::fs::{create_dir_all, File};
@ -56,7 +57,6 @@ const API_URL_V1: &str = "https://api.openai.com/v1";
pub struct OpenAIClientBuilder {
api_endpoint: Option<String>,
api_key: Option<String>,
api_version: Option<String>,
organization: Option<String>,
proxy: Option<String>,
timeout: Option<u64>,
@ -67,7 +67,6 @@ pub struct OpenAIClientBuilder {
pub struct OpenAIClient {
api_endpoint: String,
api_key: Option<String>,
api_version: Option<String>,
organization: Option<String>,
proxy: Option<String>,
timeout: Option<u64>,
@ -84,11 +83,6 @@ impl OpenAIClientBuilder {
self
}
pub fn with_api_version(mut self, api_version: impl Into<String>) -> Self {
self.api_version = Some(api_version.into());
self
}
pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
self.api_endpoint = Some(endpoint.into());
self
@ -125,7 +119,6 @@ impl OpenAIClientBuilder {
Ok(OpenAIClient {
api_endpoint,
api_version: self.api_version,
api_key: self.api_key,
organization: self.organization,
proxy: self.proxy,
@ -141,11 +134,9 @@ impl OpenAIClient {
}
async fn build_request(&self, method: Method, path: &str) -> reqwest::RequestBuilder {
let mut url = format!("{}/{}", self.api_endpoint, path);
if let Some(api_version) = &self.api_version {
url = format!("{}?api-version={}", url, api_version);
}
let url = self
.build_url_with_preserved_query(path)
.unwrap_or_else(|_| format!("{}/{}", self.api_endpoint, path));
let client = Client::builder();
@ -790,7 +781,18 @@ impl OpenAIClient {
let url = Self::query_params(limit, None, after, None, "batches".to_string());
self.get(&url).await
}
fn build_url_with_preserved_query(&self, path: &str) -> Result<String, url::ParseError> {
let base = Url::parse(&self.api_endpoint)?;
let mut url = base.join(path)?;
if let Some(q) = base.query() {
for (k, v) in url::form_urlencoded::parse(q.as_bytes()) {
url.query_pairs_mut().append_pair(&k, &v);
}
}
Ok(url.to_string())
}
fn query_params(
limit: Option<i64>,
order: Option<String>,