mirror of
https://github.com/mii443/openai-api-rs.git
synced 2025-12-04 19:48:20 +00:00
Add headers to client
This commit is contained in:
122
src/v1/api.rs
122
src/v1/api.rs
@@ -68,7 +68,7 @@ pub struct OpenAIClient {
|
||||
organization: Option<String>,
|
||||
proxy: Option<String>,
|
||||
timeout: Option<u64>,
|
||||
headers: Option<HeaderMap>,
|
||||
pub headers: Option<HeaderMap>,
|
||||
}
|
||||
|
||||
impl OpenAIClientBuilder {
|
||||
@@ -175,7 +175,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
async fn post<T: serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
&mut self,
|
||||
path: &str,
|
||||
body: &impl serde::ser::Serialize,
|
||||
) -> Result<T, APIError> {
|
||||
@@ -185,7 +185,7 @@ impl OpenAIClient {
|
||||
self.handle_response(response).await
|
||||
}
|
||||
|
||||
async fn get<T: serde::de::DeserializeOwned>(&self, path: &str) -> Result<T, APIError> {
|
||||
async fn get<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
|
||||
let request = self.build_request(Method::GET, path).await;
|
||||
let response = request.send().await?;
|
||||
self.handle_response(response).await
|
||||
@@ -197,14 +197,14 @@ impl OpenAIClient {
|
||||
Ok(response.bytes().await?)
|
||||
}
|
||||
|
||||
async fn delete<T: serde::de::DeserializeOwned>(&self, path: &str) -> Result<T, APIError> {
|
||||
async fn delete<T: serde::de::DeserializeOwned>(&mut self, path: &str) -> Result<T, APIError> {
|
||||
let request = self.build_request(Method::DELETE, path).await;
|
||||
let response = request.send().await?;
|
||||
self.handle_response(response).await
|
||||
}
|
||||
|
||||
async fn post_form<T: serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
&mut self,
|
||||
path: &str,
|
||||
form: Form,
|
||||
) -> Result<T, APIError> {
|
||||
@@ -222,14 +222,18 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
async fn handle_response<T: serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
&mut self,
|
||||
response: Response,
|
||||
) -> Result<T, APIError> {
|
||||
let status = response.status();
|
||||
let headers = response.headers().clone();
|
||||
if status.is_success() {
|
||||
let text = response.text().await.unwrap_or_else(|_| "".to_string());
|
||||
match serde_json::from_str::<T>(&text) {
|
||||
Ok(parsed) => Ok(parsed),
|
||||
Ok(parsed) => {
|
||||
self.headers = Some(headers);
|
||||
Ok(parsed)
|
||||
},
|
||||
Err(e) => Err(APIError::CustomError {
|
||||
message: format!("Failed to parse JSON: {} / response {}", e, text),
|
||||
}),
|
||||
@@ -245,42 +249,42 @@ impl OpenAIClient {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn completion(&self, req: CompletionRequest) -> Result<CompletionResponse, APIError> {
|
||||
pub async fn completion(&mut self, req: CompletionRequest) -> Result<CompletionResponse, APIError> {
|
||||
self.post("completions", &req).await
|
||||
}
|
||||
|
||||
pub async fn edit(&self, req: EditRequest) -> Result<EditResponse, APIError> {
|
||||
pub async fn edit(&mut self, req: EditRequest) -> Result<EditResponse, APIError> {
|
||||
self.post("edits", &req).await
|
||||
}
|
||||
|
||||
pub async fn image_generation(
|
||||
&self,
|
||||
&mut self,
|
||||
req: ImageGenerationRequest,
|
||||
) -> Result<ImageGenerationResponse, APIError> {
|
||||
self.post("images/generations", &req).await
|
||||
}
|
||||
|
||||
pub async fn image_edit(&self, req: ImageEditRequest) -> Result<ImageEditResponse, APIError> {
|
||||
pub async fn image_edit(&mut self, req: ImageEditRequest) -> Result<ImageEditResponse, APIError> {
|
||||
self.post("images/edits", &req).await
|
||||
}
|
||||
|
||||
pub async fn image_variation(
|
||||
&self,
|
||||
&mut self,
|
||||
req: ImageVariationRequest,
|
||||
) -> Result<ImageVariationResponse, APIError> {
|
||||
self.post("images/variations", &req).await
|
||||
}
|
||||
|
||||
pub async fn embedding(&self, req: EmbeddingRequest) -> Result<EmbeddingResponse, APIError> {
|
||||
pub async fn embedding(&mut self, req: EmbeddingRequest) -> Result<EmbeddingResponse, APIError> {
|
||||
self.post("embeddings", &req).await
|
||||
}
|
||||
|
||||
pub async fn file_list(&self) -> Result<FileListResponse, APIError> {
|
||||
pub async fn file_list(&mut self) -> Result<FileListResponse, APIError> {
|
||||
self.get("files").await
|
||||
}
|
||||
|
||||
pub async fn upload_file(
|
||||
&self,
|
||||
&mut self,
|
||||
req: FileUploadRequest,
|
||||
) -> Result<FileUploadResponse, APIError> {
|
||||
let form = Self::create_form(&req, "file")?;
|
||||
@@ -288,13 +292,13 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn delete_file(
|
||||
&self,
|
||||
&mut self,
|
||||
req: FileDeleteRequest,
|
||||
) -> Result<FileDeleteResponse, APIError> {
|
||||
self.delete(&format!("files/{}", req.file_id)).await
|
||||
}
|
||||
|
||||
pub async fn retrieve_file(&self, file_id: String) -> Result<FileRetrieveResponse, APIError> {
|
||||
pub async fn retrieve_file(&mut self, file_id: String) -> Result<FileRetrieveResponse, APIError> {
|
||||
self.get(&format!("files/{}", file_id)).await
|
||||
}
|
||||
|
||||
@@ -303,14 +307,14 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn chat_completion(
|
||||
&self,
|
||||
&mut self,
|
||||
req: ChatCompletionRequest,
|
||||
) -> Result<ChatCompletionResponse, APIError> {
|
||||
self.post("chat/completions", &req).await
|
||||
}
|
||||
|
||||
pub async fn audio_transcription(
|
||||
&self,
|
||||
&mut self,
|
||||
req: AudioTranscriptionRequest,
|
||||
) -> Result<AudioTranscriptionResponse, APIError> {
|
||||
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
|
||||
@@ -335,7 +339,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn audio_transcription_raw(
|
||||
&self,
|
||||
&mut self,
|
||||
req: AudioTranscriptionRequest,
|
||||
) -> Result<Bytes, APIError> {
|
||||
// https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-response_format
|
||||
@@ -360,7 +364,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn audio_translation(
|
||||
&self,
|
||||
&mut self,
|
||||
req: AudioTranslationRequest,
|
||||
) -> Result<AudioTranslationResponse, APIError> {
|
||||
let form = Self::create_form(&req, "file")?;
|
||||
@@ -368,7 +372,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn audio_speech(
|
||||
&self,
|
||||
&mut self,
|
||||
req: AudioSpeechRequest,
|
||||
) -> Result<AudioSpeechResponse, APIError> {
|
||||
let request = self.build_request(Method::POST, "audio/speech").await;
|
||||
@@ -410,20 +414,20 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn create_fine_tuning_job(
|
||||
&self,
|
||||
&mut self,
|
||||
req: CreateFineTuningJobRequest,
|
||||
) -> Result<FineTuningJobObject, APIError> {
|
||||
self.post("fine_tuning/jobs", &req).await
|
||||
}
|
||||
|
||||
pub async fn list_fine_tuning_jobs(
|
||||
&self,
|
||||
&mut self,
|
||||
) -> Result<FineTuningPagination<FineTuningJobObject>, APIError> {
|
||||
self.get("fine_tuning/jobs").await
|
||||
}
|
||||
|
||||
pub async fn list_fine_tuning_job_events(
|
||||
&self,
|
||||
&mut self,
|
||||
req: ListFineTuningJobEventsRequest,
|
||||
) -> Result<FineTuningPagination<FineTuningJobEvent>, APIError> {
|
||||
self.get(&format!(
|
||||
@@ -434,7 +438,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn retrieve_fine_tuning_job(
|
||||
&self,
|
||||
&mut self,
|
||||
req: RetrieveFineTuningJobRequest,
|
||||
) -> Result<FineTuningJobObject, APIError> {
|
||||
self.get(&format!("fine_tuning/jobs/{}", req.fine_tuning_job_id))
|
||||
@@ -442,7 +446,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn cancel_fine_tuning_job(
|
||||
&self,
|
||||
&mut self,
|
||||
req: CancelFineTuningJobRequest,
|
||||
) -> Result<FineTuningJobObject, APIError> {
|
||||
self.post(
|
||||
@@ -453,28 +457,28 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn create_moderation(
|
||||
&self,
|
||||
&mut self,
|
||||
req: CreateModerationRequest,
|
||||
) -> Result<CreateModerationResponse, APIError> {
|
||||
self.post("moderations", &req).await
|
||||
}
|
||||
|
||||
pub async fn create_assistant(
|
||||
&self,
|
||||
&mut self,
|
||||
req: AssistantRequest,
|
||||
) -> Result<AssistantObject, APIError> {
|
||||
self.post("assistants", &req).await
|
||||
}
|
||||
|
||||
pub async fn retrieve_assistant(
|
||||
&self,
|
||||
&mut self,
|
||||
assistant_id: String,
|
||||
) -> Result<AssistantObject, APIError> {
|
||||
self.get(&format!("assistants/{}", assistant_id)).await
|
||||
}
|
||||
|
||||
pub async fn modify_assistant(
|
||||
&self,
|
||||
&mut self,
|
||||
assistant_id: String,
|
||||
req: AssistantRequest,
|
||||
) -> Result<AssistantObject, APIError> {
|
||||
@@ -482,12 +486,12 @@ impl OpenAIClient {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn delete_assistant(&self, assistant_id: String) -> Result<DeletionStatus, APIError> {
|
||||
pub async fn delete_assistant(&mut self, assistant_id: String) -> Result<DeletionStatus, APIError> {
|
||||
self.delete(&format!("assistants/{}", assistant_id)).await
|
||||
}
|
||||
|
||||
pub async fn list_assistant(
|
||||
&self,
|
||||
&mut self,
|
||||
limit: Option<i64>,
|
||||
order: Option<String>,
|
||||
after: Option<String>,
|
||||
@@ -498,7 +502,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn create_assistant_file(
|
||||
&self,
|
||||
&mut self,
|
||||
assistant_id: String,
|
||||
req: AssistantFileRequest,
|
||||
) -> Result<AssistantFileObject, APIError> {
|
||||
@@ -507,7 +511,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn retrieve_assistant_file(
|
||||
&self,
|
||||
&mut self,
|
||||
assistant_id: String,
|
||||
file_id: String,
|
||||
) -> Result<AssistantFileObject, APIError> {
|
||||
@@ -516,7 +520,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn delete_assistant_file(
|
||||
&self,
|
||||
&mut self,
|
||||
assistant_id: String,
|
||||
file_id: String,
|
||||
) -> Result<DeletionStatus, APIError> {
|
||||
@@ -525,7 +529,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn list_assistant_file(
|
||||
&self,
|
||||
&mut self,
|
||||
assistant_id: String,
|
||||
limit: Option<i64>,
|
||||
order: Option<String>,
|
||||
@@ -542,28 +546,28 @@ impl OpenAIClient {
|
||||
self.get(&url).await
|
||||
}
|
||||
|
||||
pub async fn create_thread(&self, req: CreateThreadRequest) -> Result<ThreadObject, APIError> {
|
||||
pub async fn create_thread(&mut self, req: CreateThreadRequest) -> Result<ThreadObject, APIError> {
|
||||
self.post("threads", &req).await
|
||||
}
|
||||
|
||||
pub async fn retrieve_thread(&self, thread_id: String) -> Result<ThreadObject, APIError> {
|
||||
pub async fn retrieve_thread(&mut self, thread_id: String) -> Result<ThreadObject, APIError> {
|
||||
self.get(&format!("threads/{}", thread_id)).await
|
||||
}
|
||||
|
||||
pub async fn modify_thread(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
req: ModifyThreadRequest,
|
||||
) -> Result<ThreadObject, APIError> {
|
||||
self.post(&format!("threads/{}", thread_id), &req).await
|
||||
}
|
||||
|
||||
pub async fn delete_thread(&self, thread_id: String) -> Result<DeletionStatus, APIError> {
|
||||
pub async fn delete_thread(&mut self, thread_id: String) -> Result<DeletionStatus, APIError> {
|
||||
self.delete(&format!("threads/{}", thread_id)).await
|
||||
}
|
||||
|
||||
pub async fn create_message(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
req: CreateMessageRequest,
|
||||
) -> Result<MessageObject, APIError> {
|
||||
@@ -572,7 +576,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn retrieve_message(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
message_id: String,
|
||||
) -> Result<MessageObject, APIError> {
|
||||
@@ -581,7 +585,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn modify_message(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
message_id: String,
|
||||
req: ModifyMessageRequest,
|
||||
@@ -593,12 +597,12 @@ impl OpenAIClient {
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn list_messages(&self, thread_id: String) -> Result<ListMessage, APIError> {
|
||||
pub async fn list_messages(&mut self, thread_id: String) -> Result<ListMessage, APIError> {
|
||||
self.get(&format!("threads/{}/messages", thread_id)).await
|
||||
}
|
||||
|
||||
pub async fn retrieve_message_file(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
message_id: String,
|
||||
file_id: String,
|
||||
@@ -611,7 +615,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn list_message_file(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
message_id: String,
|
||||
limit: Option<i64>,
|
||||
@@ -630,7 +634,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn create_run(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
req: CreateRunRequest,
|
||||
) -> Result<RunObject, APIError> {
|
||||
@@ -639,7 +643,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn retrieve_run(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
run_id: String,
|
||||
) -> Result<RunObject, APIError> {
|
||||
@@ -648,7 +652,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn modify_run(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
run_id: String,
|
||||
req: ModifyRunRequest,
|
||||
@@ -658,7 +662,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn list_run(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
limit: Option<i64>,
|
||||
order: Option<String>,
|
||||
@@ -676,7 +680,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn cancel_run(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
run_id: String,
|
||||
) -> Result<RunObject, APIError> {
|
||||
@@ -688,14 +692,14 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn create_thread_and_run(
|
||||
&self,
|
||||
&mut self,
|
||||
req: CreateThreadAndRunRequest,
|
||||
) -> Result<RunObject, APIError> {
|
||||
self.post("threads/runs", &req).await
|
||||
}
|
||||
|
||||
pub async fn retrieve_run_step(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
run_id: String,
|
||||
step_id: String,
|
||||
@@ -708,7 +712,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn list_run_step(
|
||||
&self,
|
||||
&mut self,
|
||||
thread_id: String,
|
||||
run_id: String,
|
||||
limit: Option<i64>,
|
||||
@@ -726,15 +730,15 @@ impl OpenAIClient {
|
||||
self.get(&url).await
|
||||
}
|
||||
|
||||
pub async fn create_batch(&self, req: CreateBatchRequest) -> Result<BatchResponse, APIError> {
|
||||
pub async fn create_batch(&mut self, req: CreateBatchRequest) -> Result<BatchResponse, APIError> {
|
||||
self.post("batches", &req).await
|
||||
}
|
||||
|
||||
pub async fn retrieve_batch(&self, batch_id: String) -> Result<BatchResponse, APIError> {
|
||||
pub async fn retrieve_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
|
||||
self.get(&format!("batches/{}", batch_id)).await
|
||||
}
|
||||
|
||||
pub async fn cancel_batch(&self, batch_id: String) -> Result<BatchResponse, APIError> {
|
||||
pub async fn cancel_batch(&mut self, batch_id: String) -> Result<BatchResponse, APIError> {
|
||||
self.post(
|
||||
&format!("batches/{}/cancel", batch_id),
|
||||
&common::EmptyRequestBody {},
|
||||
@@ -743,7 +747,7 @@ impl OpenAIClient {
|
||||
}
|
||||
|
||||
pub async fn list_batch(
|
||||
&self,
|
||||
&mut self,
|
||||
after: Option<String>,
|
||||
limit: Option<i64>,
|
||||
) -> Result<ListBatchResponse, APIError> {
|
||||
|
||||
Reference in New Issue
Block a user