From 49298933bf9ec5a1bb017a54cab10509d43a755b Mon Sep 17 00:00:00 2001
From: Dongri Jin
Date: Thu, 29 Aug 2024 00:18:26 +0900
Subject: [PATCH] Add batch api

---
 Cargo.toml                       |  3 ++
 examples/batch.rs                | 60 ++++++++++++++++++++++++++++
 examples/data/batch_request.json |  1 +
 examples/data/batch_result.json  | 33 ++++++++++++++++
 src/v1/api.rs                    | 60 ++++++++++++++++++++--------
 src/v1/batch.rs                  | 67 ++++++++++++++++++++++++++++++++
 src/v1/common.rs                 |  3 ++
 src/v1/file.rs                   | 41 ++-----------------
 src/v1/mod.rs                    |  1 +
 9 files changed, 216 insertions(+), 53 deletions(-)
 create mode 100644 examples/batch.rs
 create mode 100644 examples/data/batch_request.json
 create mode 100644 examples/data/batch_result.json
 create mode 100644 src/v1/batch.rs

diff --git a/Cargo.toml b/Cargo.toml
index d494a68..729c137 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,3 +23,6 @@ features = ["derive"]
 
 [dependencies.serde_json]
 version = "1"
+
+[dependencies.bytes]
+version = "1.7.1"
diff --git a/examples/batch.rs b/examples/batch.rs
new file mode 100644
index 0000000..0924d4f
--- /dev/null
+++ b/examples/batch.rs
@@ -0,0 +1,60 @@
+use openai_api_rs::v1::api::OpenAIClient;
+use openai_api_rs::v1::batch::CreateBatchRequest;
+use openai_api_rs::v1::file::FileUploadRequest;
+use serde_json::{from_str, to_string_pretty, Value};
+use std::env;
+use std::fs::File;
+use std::io::Write;
+use std::str;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let client = OpenAIClient::new(env::var("OPENAI_API_KEY").unwrap().to_string());
+
+    let req = FileUploadRequest::new(
+        "examples/data/batch_request.json".to_string(),
+        "batch".to_string(),
+    );
+
+    let result = client.upload_file(req).await?;
+    println!("File id: {:?}", result.id);
+
+    let input_file_id = result.id;
+    let req = CreateBatchRequest::new(
+        input_file_id.clone(),
+        "/v1/chat/completions".to_string(),
+        "24h".to_string(),
+    );
+
+    let result = client.create_batch(req).await?;
+    println!("Batch id: {:?}", result.id);
+
+    let batch_id = result.id;
+    let result = client.retrieve_batch(batch_id.to_string()).await?;
+    println!("Batch status: {:?}", result.status);
+
+    // sleep 30 seconds
+    println!("Sleeping for 30 seconds...");
+    tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
+
+    let result = client.retrieve_batch(batch_id.to_string()).await?;
+
+    let file_id = result.output_file_id.unwrap();
+    let result = client.retrieve_file_content(file_id).await?;
+    let s = match str::from_utf8(&result) {
+        Ok(v) => v.to_string(),
+        Err(e) => panic!("Invalid UTF-8 sequence: {}", e),
+    };
+    let json_value: Value = from_str(&s)?;
+    let result_json = to_string_pretty(&json_value)?;
+
+    let output_file_path = "examples/data/batch_result.json";
+    let mut file = File::create(output_file_path)?;
+    file.write_all(result_json.as_bytes())?;
+
+    println!("File written to {:?}", output_file_path);
+
+    Ok(())
+}
+
+// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example batch
diff --git a/examples/data/batch_request.json b/examples/data/batch_request.json
new file mode 100644
index 0000000..07fc46d
--- /dev/null
+++ b/examples/data/batch_request.json
@@ -0,0 +1 @@
+{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is 2+2?"}]}}
diff --git a/examples/data/batch_result.json b/examples/data/batch_result.json
new file mode 100644
index 0000000..443eb27
--- /dev/null
+++ b/examples/data/batch_result.json
@@ -0,0 +1,33 @@
+{
+  "custom_id": "request-1",
+  "error": null,
+  "id": "batch_req_403hYy7nMxrxXFWXiwvoLG1q",
+  "response": {
+    "body": {
+      "choices": [
+        {
+          "finish_reason": "stop",
+          "index": 0,
+          "logprobs": null,
+          "message": {
+            "content": "2 + 2 equals 4.",
+            "refusal": null,
+            "role": "assistant"
+          }
+        }
+      ],
+      "created": 1724858089,
+      "id": "chatcmpl-A1Efhv97EZNQeHKSLPnTmZex20gf2",
+      "model": "gpt-4o-mini-2024-07-18",
+      "object": "chat.completion",
+      "system_fingerprint": "fp_f33667828e",
+      "usage": {
+        "completion_tokens": 8,
+        "prompt_tokens": 24,
+        "total_tokens": 32
+      }
+    },
+    "request_id": "af0bac0d82530234e09bd6b5d9fbf5cf",
+    "status_code": 200
+  }
+}
\ No newline at end of file
diff --git a/src/v1/api.rs b/src/v1/api.rs
index 1f3cf66..3cfe421 100644
--- a/src/v1/api.rs
+++ b/src/v1/api.rs
@@ -6,15 +6,16 @@ use crate::v1::audio::{
     AudioSpeechRequest, AudioSpeechResponse, AudioTranscriptionRequest, AudioTranscriptionResponse,
     AudioTranslationRequest, AudioTranslationResponse,
 };
+use crate::v1::batch::{BatchResponse, CreateBatchRequest, ListBatchResponse};
 use crate::v1::chat_completion::{ChatCompletionRequest, ChatCompletionResponse};
+use crate::v1::common;
 use crate::v1::completion::{CompletionRequest, CompletionResponse};
 use crate::v1::edit::{EditRequest, EditResponse};
 use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse};
 use crate::v1::error::APIError;
 use crate::v1::file::{
-    FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveContentRequest,
-    FileRetrieveContentResponse, FileRetrieveRequest, FileRetrieveResponse, FileUploadRequest,
-    FileUploadResponse,
+    FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveResponse,
+    FileUploadRequest, FileUploadResponse,
 };
 use crate::v1::fine_tuning::{
     CancelFineTuningJobRequest, CreateFineTuningJobRequest, FineTuningJobEvent,
@@ -36,6 +37,7 @@ use crate::v1::run::{
 };
 use crate::v1::thread::{CreateThreadRequest, ModifyThreadRequest, ThreadObject};
 
+use bytes::Bytes;
 use reqwest::multipart::{Form, Part};
 use reqwest::{Client, Method, Response};
 use serde::Serialize;
@@ -156,6 +158,12 @@ impl OpenAIClient {
         self.handle_response(response).await
     }
 
+    async fn get_raw(&self, path: &str) -> Result<Bytes, APIError> {
+        let request = self.build_request(Method::GET, path).await;
+        let response = request.send().await?;
+        Ok(response.bytes().await?)
+    }
+
     async fn delete<T: serde::de::DeserializeOwned>(&self, path: &str) -> Result<T, APIError> {
         let request = self.build_request(Method::DELETE, path).await;
         let response = request.send().await?;
@@ -226,32 +234,27 @@ impl OpenAIClient {
         self.get("files").await
     }
 
-    pub async fn file_upload(
+    pub async fn upload_file(
         &self,
         req: FileUploadRequest,
     ) -> Result<FileUploadResponse, APIError> {
-        self.post("files", &req).await
+        let form = Self::create_form(&req, "file")?;
+        self.post_form("files", form).await
     }
 
-    pub async fn file_delete(
+    pub async fn delete_file(
         &self,
         req: FileDeleteRequest,
     ) -> Result<FileDeleteResponse, APIError> {
         self.delete(&format!("files/{}", req.file_id)).await
     }
 
-    pub async fn file_retrieve(
-        &self,
-        req: FileRetrieveRequest,
-    ) -> Result<FileRetrieveResponse, APIError> {
-        self.get(&format!("files/{}", req.file_id)).await
+    pub async fn retrieve_file(&self, file_id: String) -> Result<FileRetrieveResponse, APIError> {
+        self.get(&format!("files/{}", file_id)).await
     }
 
-    pub async fn file_retrieve_content(
-        &self,
-        req: FileRetrieveContentRequest,
-    ) -> Result<FileRetrieveContentResponse, APIError> {
-        self.get(&format!("files/{}/content", req.file_id)).await
+    pub async fn retrieve_file_content(&self, file_id: String) -> Result<Bytes, APIError> {
+        self.get_raw(&format!("files/{}/content", file_id)).await
     }
 
     pub async fn chat_completion(
@@ -636,6 +639,31 @@ impl OpenAIClient {
         self.get(&url).await
     }
 
+    pub async fn create_batch(&self, req: CreateBatchRequest) -> Result<BatchResponse, APIError> {
+        self.post("batches", &req).await
+    }
+
+    pub async fn retrieve_batch(&self, batch_id: String) -> Result<BatchResponse, APIError> {
+        self.get(&format!("batches/{}", batch_id)).await
+    }
+
+    pub async fn cancel_batch(&self, batch_id: String) -> Result<BatchResponse, APIError> {
+        self.post(
+            &format!("batches/{}/cancel", batch_id),
+            &common::EmptyRequestBody {},
+        )
+        .await
+    }
+
+    pub async fn list_batch(
+        &self,
+        after: Option<String>,
+        limit: Option<i64>,
+    ) -> Result<ListBatchResponse, APIError> {
+        let url = Self::query_params(limit, None, after, None, "batches".to_string());
+        self.get(&url).await
+    }
+
     fn query_params(
         limit: Option<i64>,
         order: Option<String>,
diff --git a/src/v1/batch.rs b/src/v1/batch.rs
new file mode 100644
index 0000000..05e889d
--- /dev/null
+++ b/src/v1/batch.rs
@@ -0,0 +1,67 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct CreateBatchRequest {
+    pub input_file_id: String,
+    pub endpoint: String,
+    pub completion_window: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<Metadata>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Metadata {
+    pub customer_id: String,
+    pub batch_description: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct RequestCounts {
+    pub total: u32,
+    pub completed: u32,
+    pub failed: u32,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct BatchResponse {
+    pub id: String,
+    pub object: String,
+    pub endpoint: String,
+    pub errors: Option<Vec<String>>,
+    pub input_file_id: String,
+    pub completion_window: String,
+    pub status: String,
+    pub output_file_id: Option<String>,
+    pub error_file_id: Option<String>,
+    pub created_at: u64,
+    pub in_progress_at: Option<u64>,
+    pub expires_at: Option<u64>,
+    pub finalizing_at: Option<u64>,
+    pub completed_at: Option<u64>,
+    pub failed_at: Option<u64>,
+    pub expired_at: Option<u64>,
+    pub cancelling_at: Option<u64>,
+    pub cancelled_at: Option<u64>,
+    pub request_counts: RequestCounts,
+    pub metadata: Option<Metadata>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ListBatchResponse {
+    pub object: String,
+    pub data: Vec<BatchResponse>,
+    pub first_id: String,
+    pub last_id: String,
+    pub has_more: bool,
+}
+
+impl CreateBatchRequest {
+    pub fn new(input_file_id: String, endpoint: String, completion_window: String) -> Self {
+        Self {
+            input_file_id,
+            endpoint,
+            completion_window,
+            metadata: None,
+        }
+    }
+}
diff --git a/src/v1/common.rs b/src/v1/common.rs
index ab3fbc0..1b077e0 100644
--- a/src/v1/common.rs
+++ b/src/v1/common.rs
@@ -21,6 +21,9 @@ macro_rules! impl_builder_methods {
     };
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+pub struct EmptyRequestBody {}
+
 // https://platform.openai.com/docs/models/gpt-4o-mini
 pub const GPT4_O_MINI: &str = "gpt-4o-mini";
 
diff --git a/src/v1/file.rs b/src/v1/file.rs
index 4bf1c7e..d475a31 100644
--- a/src/v1/file.rs
+++ b/src/v1/file.rs
@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
 #[derive(Debug, Deserialize, Serialize)]
 pub struct FileData {
     pub id: String,
-    pub oejct: String,
+    pub object: String,
     pub bytes: i32,
     pub created_at: i64,
     pub filename: String,
@@ -34,7 +34,7 @@ impl FileUploadRequest {
 #[derive(Debug, Deserialize, Serialize)]
 pub struct FileUploadResponse {
     pub id: String,
-    pub oejct: String,
+    pub object: String,
     pub bytes: i32,
     pub created_at: i64,
     pub filename: String,
@@ -56,48 +56,15 @@ impl FileDeleteRequest {
 #[derive(Debug, Deserialize, Serialize)]
 pub struct FileDeleteResponse {
     pub id: String,
-    pub oejct: String,
+    pub object: String,
     pub delete: bool,
     pub headers: Option<HashMap<String, String>>,
 }
 
-#[derive(Debug, Serialize)]
-pub struct FileRetrieveRequest {
-    pub file_id: String,
-}
-
-impl FileRetrieveRequest {
-    pub fn new(file_id: String) -> Self {
-        Self { file_id }
-    }
-}
-
 #[derive(Debug, Deserialize, Serialize)]
 pub struct FileRetrieveResponse {
     pub id: String,
-    pub oejct: String,
-    pub bytes: i32,
-    pub created_at: i64,
-    pub filename: String,
-    pub purpose: String,
-    pub headers: Option<HashMap<String, String>>,
-}
-
-#[derive(Debug, Serialize)]
-pub struct FileRetrieveContentRequest {
-    pub file_id: String,
-}
-
-impl FileRetrieveContentRequest {
-    pub fn new(file_id: String) -> Self {
-        Self { file_id }
-    }
-}
-
-#[derive(Debug, Deserialize, Serialize)]
-pub struct FileRetrieveContentResponse {
-    pub id: String,
-    pub oejct: String,
+    pub object: String,
     pub bytes: i32,
     pub created_at: i64,
     pub filename: String,
diff --git a/src/v1/mod.rs b/src/v1/mod.rs
index 856848e..6eaf80c 100644
--- a/src/v1/mod.rs
+++ b/src/v1/mod.rs
@@ -2,6 +2,7 @@ pub mod common;
 pub mod error;
 
 pub mod audio;
+pub mod batch;
 pub mod chat_completion;
 pub mod completion;
 pub mod edit;