Add batch api

This commit is contained in:
Dongri Jin
2024-08-29 00:18:26 +09:00
parent cd13f29727
commit 49298933bf
9 changed files with 216 additions and 53 deletions

View File

@ -6,15 +6,16 @@ use crate::v1::audio::{
AudioSpeechRequest, AudioSpeechResponse, AudioTranscriptionRequest, AudioTranscriptionResponse,
AudioTranslationRequest, AudioTranslationResponse,
};
use crate::v1::batch::{BatchResponse, CreateBatchRequest, ListBatchResponse};
use crate::v1::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
use crate::v1::common;
use crate::v1::completion::{CompletionRequest, CompletionResponse};
use crate::v1::edit::{EditRequest, EditResponse};
use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse};
use crate::v1::error::APIError;
use crate::v1::file::{
FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveContentRequest,
FileRetrieveContentResponse, FileRetrieveRequest, FileRetrieveResponse, FileUploadRequest,
FileUploadResponse,
FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveResponse,
FileUploadRequest, FileUploadResponse,
};
use crate::v1::fine_tuning::{
CancelFineTuningJobRequest, CreateFineTuningJobRequest, FineTuningJobEvent,
@ -36,6 +37,7 @@ use crate::v1::run::{
};
use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject};
use bytes::Bytes;
use reqwest::multipart::{Form, Part};
use reqwest::{Client, Method, Response};
use serde::Serialize;
@ -156,6 +158,12 @@ impl OpenAIClient {
self.handle_response(response).await
}
async fn get_raw(&self, path: &str) -> Result<Bytes, APIError> {
let request = self.build_request(Method::GET, path).await;
let response = request.send().await?;
Ok(response.bytes().await?)
}
async fn delete<T: serde::de::DeserializeOwned>(&self, path: &str) -> Result<T, APIError> {
let request = self.build_request(Method::DELETE, path).await;
let response = request.send().await?;
@ -226,32 +234,27 @@ impl OpenAIClient {
self.get("files").await
}
pub async fn file_upload(
pub async fn upload_file(
&self,
req: FileUploadRequest,
) -> Result<FileUploadResponse, APIError> {
self.post("files", &req).await
let form = Self::create_form(&req, "file")?;
self.post_form("files", form).await
}
pub async fn file_delete(
pub async fn delete_file(
&self,
req: FileDeleteRequest,
) -> Result<FileDeleteResponse, APIError> {
self.delete(&format!("files/{}", req.file_id)).await
}
pub async fn file_retrieve(
&self,
req: FileRetrieveRequest,
) -> Result<FileRetrieveResponse, APIError> {
self.get(&format!("files/{}", req.file_id)).await
pub async fn retrieve_file(&self, file_id: String) -> Result<FileRetrieveResponse, APIError> {
self.get(&format!("files/{}", file_id)).await
}
pub async fn file_retrieve_content(
&self,
req: FileRetrieveContentRequest,
) -> Result<FileRetrieveContentResponse, APIError> {
self.get(&format!("files/{}/content", req.file_id)).await
pub async fn retrieve_file_content(&self, file_id: String) -> Result<Bytes, APIError> {
self.get_raw(&format!("files/{}/content", file_id)).await
}
pub async fn chat_completion(
@ -636,6 +639,31 @@ impl OpenAIClient {
self.get(&url).await
}
pub async fn create_batch(&self, req: CreateBatchRequest) -> Result<BatchResponse, APIError> {
self.post("batches", &req).await
}
pub async fn retrieve_batch(&self, batch_id: String) -> Result<BatchResponse, APIError> {
self.get(&format!("batches/{}", batch_id)).await
}
pub async fn cancel_batch(&self, batch_id: String) -> Result<BatchResponse, APIError> {
self.post(
&format!("batches/{}/cancel", batch_id),
&common::EmptyRequestBody {},
)
.await
}
pub async fn list_batch(
&self,
after: Option<String>,
limit: Option<i64>,
) -> Result<ListBatchResponse, APIError> {
let url = Self::query_params(limit, None, after, None, "batches".to_string());
self.get(&url).await
}
fn query_params(
limit: Option<i64>,
order: Option<String>,