Add headers to client

This commit is contained in:
Dongri Jin
2025-03-05 07:26:04 +09:00
parent 2e6ea3eedd
commit ddeefd256c
26 changed files with 77 additions and 113 deletions

View File

@@ -68,7 +68,7 @@ pub struct OpenAIClient {
organization: Option<String>,
proxy: Option<String>,
timeout: Option<u64>,
headers: Option<HeaderMap>,
pub headers: Option<HeaderMap>,
}
impl OpenAIClientBuilder {
@@ -175,7 +175,7 @@ impl OpenAIClient {
}
async fn post<T: serde::de::DeserializeOwned>(
&self,
&mut self,
path: &str,
body: &impl serde::ser::Serialize,
) -> Result<T, APIError> {
@@ -185,7 +185,7 @@ impl OpenAIClient {
self.handle_response(response).await
}
async fn get<T: serde::de::DeserializeOwned>(&self, path: &str) -> Result<T, APIError> {
async fn get<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
let request = self.build_request(Method::GET, path).await;
let response = request.send().await?;
self.handle_response(response).await
@@ -197,14 +197,14 @@ impl OpenAIClient {
Ok(response.bytes().await?)
}
async fn delete<T: serde::de::DeserializeOwned>(&self, path: &str) -> Result<T, APIError> {
async fn delete<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
let request = self.build_request(Method::DELETE, path).await;
let response = request.send().await?;
self.handle_response(response).await
}
async fn post_form<T: serde::de::DeserializeOwned>(
&self,
&mut self,
path: &str,
form: Form,
) -> Result<T, APIError> {
@@ -222,14 +222,18 @@ impl OpenAIClient {
}
async fn handle_response<T: serde::de::DeserializeOwned>(
&self,
&mut self,
response: Response,
) -> Result<T, APIError> {
let status = response.status();
let headers = response.headers().clone();
if status.is_success() {
let text = response.text().await.unwrap_or_else(|_| "".to_string());
match serde_json::from_str::<T>(&text) {
Ok(parsed) => Ok(parsed),
Ok(parsed) => {
self.headers = Some(headers);
Ok(parsed)
},
Err(e) => Err(APIError::CustomError {
message: format!("Failed to parse JSON: {} / response {}", e, text),
}),
@@ -245,42 +249,42 @@ impl OpenAIClient {
}
}
pub async fn completion(&self, req: CompletionRequest) -> Result<CompletionResponse, APIError> {
pub async fn completion(&mut self, req: CompletionRequest) -> Result<CompletionResponse, APIError> {
self.post("completions", &req).await
}
pub async fn edit(&self, req: EditRequest) -> Result<EditResponse, APIError> {
pub async fn edit(&mut self, req: EditRequest) -> Result<EditResponse, APIError> {
self.post("edits", &req).await
}
pub async fn image_generation(
&self,
&mut self,
req: ImageGenerationRequest,
) -> Result<ImageGenerationResponse, APIError> {
self.post("images/generations", &req).await
}
pub async fn image_edit(&self, req: ImageEditRequest) -> Result<ImageEditResponse, APIError> {
pub async fn image_edit(&mut self, req: ImageEditRequest) -> Result<ImageEditResponse, APIError> {
self.post("images/edits", &req).await
}
pub async fn image_variation(
&self,
&mut self,
req: ImageVariationRequest,
) -> Result<ImageVariationResponse, APIError> {
self.post("images/variations", &req).await
}
pub async fn embedding(&self, req: EmbeddingRequest) -> Result<EmbeddingResponse, APIError> {
pub async fn embedding(&mut self, req: EmbeddingRequest) -> Result<EmbeddingResponse, APIError> {
self.post("embeddings", &req).await
}
pub async fn file_list(&self) -> Result<FileListResponse, APIError> {
pub async fn file_list(&mut self) -> Result<FileListResponse, APIError> {
self.get("files").await
}
pub async fn upload_file(
&self,
&mut self,
req: FileUploadRequest,
) -> Result<FileUploadResponse, APIError> {
let form = Self::create_form(&req, "file")?;
@@ -288,13 +292,13 @@ impl OpenAIClient {
}
pub async fn delete_file(
&self,
&mut self,
req: FileDeleteRequest,
) -> Result<FileDeleteResponse, APIError> {
self.delete(&format!("files/{}", req.file_id)).await
}
pub async fn retrieve_file(&self, file_id: String) -> Result<FileRetrieveResponse, APIError> {
pub async fn retrieve_file(&mut self, file_id: String) -> Result<FileRetrieveResponse, APIError> {
self.get(&format!("files/{}", file_id)).await
}
@@ -303,14 +307,14 @@ impl OpenAIClient {
}
pub async fn chat_completion(
&self,
&mut self,
req: ChatCompletionRequest,
) -> Result<ChatCompletionResponse, APIError> {
self.post("chat/completions", &req).await
}
pub async fn audio_transcription(
&self,
&mut self,
req: AudioTranscriptionRequest,
) -> Result<AudioTranscriptionResponse, APIError> {
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
@@ -335,7 +339,7 @@ impl OpenAIClient {
}
pub async fn audio_transcription_raw(
&self,
&mut self,
req: AudioTranscriptionRequest,
) -> Result<Bytes, APIError> {
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
@@ -360,7 +364,7 @@ impl OpenAIClient {
}
pub async fn audio_translation(
&self,
&mut self,
req: AudioTranslationRequest,
) -> Result<AudioTranslationResponse, APIError> {
let form = Self::create_form(&req, "file")?;
@@ -368,7 +372,7 @@ impl OpenAIClient {
}
pub async fn audio_speech(
&self,
&mut self,
req: AudioSpeechRequest,
) -> Result<AudioSpeechResponse, APIError> {
let request = self.build_request(Method::POST, "audio/speech").await;
@@ -410,20 +414,20 @@ impl OpenAIClient {
}
pub async fn create_fine_tuning_job(
&self,
&mut self,
req: CreateFineTuningJobRequest,
) -> Result<FineTuningJobObject, APIError> {
self.post("fine_tuning/jobs", &req).await
}
pub async fn list_fine_tuning_jobs(
&self,
&mut self,
) -> Result<FineTuningPagination<FineTuningJobObject>, APIError> {
self.get("fine_tuning/jobs").await
}
pub async fn list_fine_tuning_job_events(
&self,
&mut self,
req: ListFineTuningJobEventsRequest,
) -> Result<FineTuningPagination<FineTuningJobEvent>, APIError> {
self.get(&format!(
@@ -434,7 +438,7 @@ impl OpenAIClient {
}
pub async fn retrieve_fine_tuning_job(
&self,
&mut self,
req: RetrieveFineTuningJobRequest,
) -> Result<FineTuningJobObject, APIError> {
self.get(&format!("fine_tuning/jobs/{}", req.fine_tuning_job_id))
@@ -442,7 +446,7 @@ impl OpenAIClient {
}
pub async fn cancel_fine_tuning_job(
&self,
&mut self,
req: CancelFineTuningJobRequest,
) -> Result<FineTuningJobObject, APIError> {
self.post(
@@ -453,28 +457,28 @@ impl OpenAIClient {
}
pub async fn create_moderation(
&self,
&mut self,
req: CreateModerationRequest,
) -> Result<CreateModerationResponse, APIError> {
self.post("moderations", &req).await
}
pub async fn create_assistant(
&self,
&mut self,
req: AssistantRequest,
) -> Result<AssistantObject, APIError> {
self.post("assistants", &req).await
}
pub async fn retrieve_assistant(
&self,
&mut self,
assistant_id: String,
) -> Result<AssistantObject, APIError> {
self.get(&format!("assistants/{}", assistant_id)).await
}
pub async fn modify_assistant(
&self,
&mut self,
assistant_id: String,
req: AssistantRequest,
) -> Result<AssistantObject, APIError> {
@@ -482,12 +486,12 @@ impl OpenAIClient {
.await
}
pub async fn delete_assistant(&self, assistant_id: String) -> Result<DeletionStatus, APIError> {
pub async fn delete_assistant(&mut self, assistant_id: String) -> Result<DeletionStatus, APIError> {
self.delete(&format!("assistants/{}", assistant_id)).await
}
pub async fn list_assistant(
&self,
&mut self,
limit: Option<i64>,
order: Option<String>,
after: Option<String>,
@@ -498,7 +502,7 @@ impl OpenAIClient {
}
pub async fn create_assistant_file(
&self,
&mut self,
assistant_id: String,
req: AssistantFileRequest,
) -> Result<AssistantFileObject, APIError> {
@@ -507,7 +511,7 @@ impl OpenAIClient {
}
pub async fn retrieve_assistant_file(
&self,
&mut self,
assistant_id: String,
file_id: String,
) -> Result<AssistantFileObject, APIError> {
@@ -516,7 +520,7 @@ impl OpenAIClient {
}
pub async fn delete_assistant_file(
&self,
&mut self,
assistant_id: String,
file_id: String,
) -> Result<DeletionStatus, APIError> {
@@ -525,7 +529,7 @@ impl OpenAIClient {
}
pub async fn list_assistant_file(
&self,
&mut self,
assistant_id: String,
limit: Option<i64>,
order: Option<String>,
@@ -542,28 +546,28 @@ impl OpenAIClient {
self.get(&url).await
}
pub async fn create_thread(&self, req: CreateThreadRequest) -> Result<ThreadObject, APIError> {
pub async fn create_thread(&mut self, req: CreateThreadRequest) -> Result<ThreadObject, APIError> {
self.post("threads", &req).await
}
pub async fn retrieve_thread(&self, thread_id: String) -> Result<ThreadObject, APIError> {
pub async fn retrieve_thread(&mut self, thread_id: String) -> Result<ThreadObject, APIError> {
self.get(&format!("threads/{}", thread_id)).await
}
pub async fn modify_thread(
&self,
&mut self,
thread_id: String,
req: ModifyThreadRequest,
) -> Result<ThreadObject, APIError> {
self.post(&format!("threads/{}", thread_id), &req).await
}
pub async fn delete_thread(&self, thread_id: String) -> Result<DeletionStatus, APIError> {
pub async fn delete_thread(&mut self, thread_id: String) -> Result<DeletionStatus, APIError> {
self.delete(&format!("threads/{}", thread_id)).await
}
pub async fn create_message(
&self,
&mut self,
thread_id: String,
req: CreateMessageRequest,
) -> Result<MessageObject, APIError> {
@@ -572,7 +576,7 @@ impl OpenAIClient {
}
pub async fn retrieve_message(
&self,
&mut self,
thread_id: String,
message_id: String,
) -> Result<MessageObject, APIError> {
@@ -581,7 +585,7 @@ impl OpenAIClient {
}
pub async fn modify_message(
&self,
&mut self,
thread_id: String,
message_id: String,
req: ModifyMessageRequest,
@@ -593,12 +597,12 @@ impl OpenAIClient {
.await
}
pub async fn list_messages(&self, thread_id: String) -> Result<ListMessage, APIError> {
pub async fn list_messages(&mut self, thread_id: String) -> Result<ListMessage, APIError> {
self.get(&format!("threads/{}/messages", thread_id)).await
}
pub async fn retrieve_message_file(
&self,
&mut self,
thread_id: String,
message_id: String,
file_id: String,
@@ -611,7 +615,7 @@ impl OpenAIClient {
}
pub async fn list_message_file(
&self,
&mut self,
thread_id: String,
message_id: String,
limit: Option<i64>,
@@ -630,7 +634,7 @@ impl OpenAIClient {
}
pub async fn create_run(
&self,
&mut self,
thread_id: String,
req: CreateRunRequest,
) -> Result<RunObject, APIError> {
@@ -639,7 +643,7 @@ impl OpenAIClient {
}
pub async fn retrieve_run(
&self,
&mut self,
thread_id: String,
run_id: String,
) -> Result<RunObject, APIError> {
@@ -648,7 +652,7 @@ impl OpenAIClient {
}
pub async fn modify_run(
&self,
&mut self,
thread_id: String,
run_id: String,
req: ModifyRunRequest,
@@ -658,7 +662,7 @@ impl OpenAIClient {
}
pub async fn list_run(
&self,
&mut self,
thread_id: String,
limit: Option<i64>,
order: Option<String>,
@@ -676,7 +680,7 @@ impl OpenAIClient {
}
pub async fn cancel_run(
&self,
&mut self,
thread_id: String,
run_id: String,
) -> Result<RunObject, APIError> {
@@ -688,14 +692,14 @@ impl OpenAIClient {
}
pub async fn create_thread_and_run(
&self,
&mut self,
req: CreateThreadAndRunRequest,
) -> Result<RunObject, APIError> {
self.post("threads/runs", &req).await
}
pub async fn retrieve_run_step(
&self,
&mut self,
thread_id: String,
run_id: String,
step_id: String,
@@ -708,7 +712,7 @@ impl OpenAIClient {
}
pub async fn list_run_step(
&self,
&mut self,
thread_id: String,
run_id: String,
limit: Option<i64>,
@@ -726,15 +730,15 @@ impl OpenAIClient {
self.get(&url).await
}
pub async fn create_batch(&self, req: CreateBatchRequest) -> Result<BatchResponse, APIError> {
pub async fn create_batch(&mut self, req: CreateBatchRequest) -> Result<BatchResponse, APIError> {
self.post("batches", &req).await
}
pub async fn retrieve_batch(&self, batch_id: String) -> Result<BatchResponse, APIError> {
pub async fn retrieve_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
self.get(&format!("batches/{}", batch_id)).await
}
pub async fn cancel_batch(&self, batch_id: String) -> Result<BatchResponse, APIError> {
pub async fn cancel_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
self.post(
&format!("batches/{}/cancel", batch_id),
&common::EmptyRequestBody {},
@@ -743,7 +747,7 @@ impl OpenAIClient {
}
pub async fn list_batch(
&self,
&mut self,
after: Option<String>,
limit: Option<i64>,
) -> Result<ListBatchResponse, APIError> {