mirror of
https://github.com/mii443/openai-api-rs.git
synced 2025-08-22 15:15:34 +00:00
add additional json
This commit is contained in:
@ -62,6 +62,7 @@ pub struct OpenAIClientBuilder {
|
|||||||
proxy: Option<String>,
|
proxy: Option<String>,
|
||||||
timeout: Option<u64>,
|
timeout: Option<u64>,
|
||||||
headers: Option<HeaderMap>,
|
headers: Option<HeaderMap>,
|
||||||
|
additional_json: Option<Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -72,6 +73,7 @@ pub struct OpenAIClient {
|
|||||||
proxy: Option<String>,
|
proxy: Option<String>,
|
||||||
timeout: Option<u64>,
|
timeout: Option<u64>,
|
||||||
headers: Option<HeaderMap>,
|
headers: Option<HeaderMap>,
|
||||||
|
additional_json: Option<Value>,
|
||||||
pub response_headers: Option<HeaderMap>,
|
pub response_headers: Option<HeaderMap>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -114,6 +116,11 @@ impl OpenAIClientBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_additional_json(mut self, additional_json: Value) -> Self {
|
||||||
|
self.additional_json = Some(additional_json);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
|
pub fn build(self) -> Result<OpenAIClient, Box<dyn Error>> {
|
||||||
let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
|
let api_endpoint = self.api_endpoint.unwrap_or_else(|| {
|
||||||
std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
|
std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned())
|
||||||
@ -127,6 +134,7 @@ impl OpenAIClientBuilder {
|
|||||||
timeout: self.timeout,
|
timeout: self.timeout,
|
||||||
headers: self.headers,
|
headers: self.headers,
|
||||||
response_headers: None,
|
response_headers: None,
|
||||||
|
additional_json: self.additional_json,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -189,6 +197,19 @@ impl OpenAIClient {
|
|||||||
body: &impl serde::ser::Serialize,
|
body: &impl serde::ser::Serialize,
|
||||||
) -> Result<T, APIError> {
|
) -> Result<T, APIError> {
|
||||||
let request = self.build_request(Method::POST, path).await;
|
let request = self.build_request(Method::POST, path).await;
|
||||||
|
|
||||||
|
let request = if let Some(additional_json) = &self.additional_json {
|
||||||
|
let mut body = serde_json::to_value(body).unwrap();
|
||||||
|
if let Value::Object(map) = &mut body {
|
||||||
|
for (key, value) in additional_json.as_object().unwrap() {
|
||||||
|
map.insert(key.clone(), value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request.json(&body)
|
||||||
|
} else {
|
||||||
|
request.json(body)
|
||||||
|
};
|
||||||
|
|
||||||
let request = request.json(body);
|
let request = request.json(body);
|
||||||
let response = request.send().await?;
|
let response = request.send().await?;
|
||||||
self.handle_response(response).await
|
self.handle_response(response).await
|
||||||
|
Reference in New Issue
Block a user