diff --git a/src/v1/api.rs b/src/v1/api.rs index ce0b1c4..a26ebab 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -62,6 +62,7 @@ pub struct OpenAIClientBuilder { proxy: Option, timeout: Option, headers: Option, + additional_json: Option, } #[derive(Debug)] @@ -72,6 +73,7 @@ pub struct OpenAIClient { proxy: Option, timeout: Option, headers: Option, + additional_json: Option, pub response_headers: Option, } @@ -114,6 +116,11 @@ impl OpenAIClientBuilder { self } + pub fn with_additional_json(mut self, additional_json: Value) -> Self { + self.additional_json = Some(additional_json); + self + } + pub fn build(self) -> Result> { let api_endpoint = self.api_endpoint.unwrap_or_else(|| { std::env::var("OPENAI_API_BASE").unwrap_or_else(|_| API_URL_V1.to_owned()) @@ -127,6 +134,7 @@ impl OpenAIClientBuilder { timeout: self.timeout, headers: self.headers, response_headers: None, + additional_json: self.additional_json, }) } } @@ -189,6 +197,19 @@ impl OpenAIClient { body: &impl serde::ser::Serialize, ) -> Result { let request = self.build_request(Method::POST, path).await; + + let request = if let Some(additional_json) = &self.additional_json { + let mut body = serde_json::to_value(body).unwrap(); + if let Value::Object(map) = &mut body { + for (key, value) in additional_json.as_object().unwrap() { + map.insert(key.clone(), value.clone()); + } + } + request.json(&body) + } else { + request.json(body) + }; + let request = request.json(body); let response = request.send().await?; self.handle_response(response).await