From d3223befbdcba15ed9238fa9981474e8c4a6098f Mon Sep 17 00:00:00 2001 From: Dongri Jin Date: Sun, 5 Mar 2023 07:37:39 +0900 Subject: [PATCH] Add custom error --- Cargo.toml | 2 +- README.md | 2 +- src/v1/api.rs | 65 ++++++++++++++++++++++++++----------------------- src/v1/error.rs | 15 ++++++++++++ src/v1/mod.rs | 1 + 5 files changed, 53 insertions(+), 32 deletions(-) create mode 100644 src/v1/error.rs diff --git a/Cargo.toml b/Cargo.toml index 6899636..555d55d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "openai-api-rs" -version = "0.1.3" +version = "0.1.4" edition = "2021" authors = ["Dongri Jin "] license = "MIT" diff --git a/README.md b/README.md index cb899c1..fbf4176 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Cargo.toml ```toml [dependencies] -openai-api-rs = "0.1.3" +openai-api-rs = "0.1.4" ``` ## Example: diff --git a/src/v1/api.rs b/src/v1/api.rs index c6a4deb..7463756 100644 --- a/src/v1/api.rs +++ b/src/v1/api.rs @@ -2,6 +2,7 @@ use crate::v1::chat_completion::{ChatCompletionRequest, ChatCompletionResponse}; use crate::v1::completion::{CompletionRequest, CompletionResponse}; use crate::v1::edit::{EditRequest, EditResponse}; use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse}; +use crate::v1::error::APIError; use crate::v1::file::{ FileDeleteRequest, FileDeleteResponse, FileListResponse, FileRetrieveContentRequest, FileRetrieveContentResponse, FileRetrieveRequest, FileRetrieveResponse, FileUploadRequest, @@ -12,7 +13,6 @@ use crate::v1::image::{ ImageVariationRequest, ImageVariationResponse, }; use reqwest::Response; -use std::io::Error; const APU_URL_V1: &str = "https://api.openai.com/v1"; @@ -29,7 +29,7 @@ impl Client { &self, path: &str, params: &T, - ) -> Result { + ) -> Result { let client = reqwest::Client::new(); let url = format!("{APU_URL_V1}{path}"); let res = client @@ -45,16 +45,15 @@ impl Client { match res { Ok(res) => match res.status().is_success() { true => Ok(res), - false => Err(Error::new( - std::io::ErrorKind::Other, - format!("{}: {}", res.status(), res.text().await.unwrap()), - )), + false => Err(APIError { + message: format!("{}: {}", res.status(), res.text().await.unwrap()), + }), }, Err(e) => Err(self.new_error(e)), } } - pub async fn get(&self, path: &str) -> Result { + pub async fn get(&self, path: &str) -> Result { let client = reqwest::Client::new(); let url = format!("{APU_URL_V1}{path}"); let res = client @@ -69,16 +68,15 @@ impl Client { match res { Ok(res) => match res.status().is_success() { true => Ok(res), - false => Err(Error::new( - std::io::ErrorKind::Other, - format!("{}: {}", res.status(), res.text().await.unwrap()), - )), + false => Err(APIError { + message: format!("{}: {}", res.status(), res.text().await.unwrap()), + }), }, Err(e) => Err(self.new_error(e)), } } - pub async fn delete(&self, path: &str) -> Result { + pub async fn delete(&self, path: &str) -> Result { let client = reqwest::Client::new(); let url = format!("{APU_URL_V1}{path}"); let res = client @@ -93,16 +91,15 @@ impl Client { match res { Ok(res) => match res.status().is_success() { true => Ok(res), - false => Err(Error::new( - std::io::ErrorKind::Other, - format!("{}: {}", res.status(), res.text().await.unwrap()), - )), + false => Err(APIError { + message: format!("{}: {}", res.status(), res.text().await.unwrap()), + }), }, Err(e) => Err(self.new_error(e)), } } - pub async fn completion(&self, req: CompletionRequest) -> Result { + pub async fn completion(&self, req: CompletionRequest) -> Result { let res = self.post("/completions", &req).await?; let r = res.json::().await; match r { @@ -111,7 +108,7 @@ impl Client { } } - pub async fn edit(&self, req: EditRequest) -> Result { + pub async fn edit(&self, req: EditRequest) -> Result { let res = self.post("/edits", &req).await?; let r = res.json::().await; match r { @@ -123,7 +120,7 @@ impl Client { pub async fn image_generation( &self, req: ImageGenerationRequest, - ) -> Result { + ) -> Result { let res = self.post("/images/generations", &req).await?; let r = res.json::().await; match r { @@ -132,7 +129,7 @@ impl Client { } } - pub async fn image_edit(&self, req: ImageEditRequest) -> Result { + pub async fn image_edit(&self, req: ImageEditRequest) -> Result { let res = self.post("/images/edits", &req).await?; let r = res.json::().await; match r { @@ -144,7 +141,7 @@ impl Client { pub async fn image_variation( &self, req: ImageVariationRequest, - ) -> Result { + ) -> Result { let res = self.post("/images/variations", &req).await?; let r = res.json::().await; match r { @@ -153,7 +150,7 @@ impl Client { } } - pub async fn embedding(&self, req: EmbeddingRequest) -> Result { + pub async fn embedding(&self, req: EmbeddingRequest) -> Result { let res = self.post("/embeddings", &req).await?; let r = res.json::().await; match r { @@ -162,7 +159,7 @@ impl Client { } } - pub async fn file_list(&self) -> Result { + pub async fn file_list(&self) -> Result { let res = self.get("/files").await?; let r = res.json::().await; match r { @@ -171,7 +168,10 @@ impl Client { } } - pub async fn file_upload(&self, req: FileUploadRequest) -> Result { + pub async fn file_upload( + &self, + req: FileUploadRequest, + ) -> Result { let res = self.post("/files", &req).await?; let r = res.json::().await; match r { @@ -180,7 +180,10 @@ impl Client { } } - pub async fn file_delete(&self, req: FileDeleteRequest) -> Result { + pub async fn file_delete( + &self, + req: FileDeleteRequest, + ) -> Result { let res = self .delete(&format!("{}/{}", "/files", req.file_id)) .await?; @@ -194,7 +197,7 @@ impl Client { pub async fn file_retrieve( &self, req: FileRetrieveRequest, - ) -> Result { + ) -> Result { let res = self.get(&format!("{}/{}", "/files", req.file_id)).await?; let r = res.json::().await; match r { @@ -206,7 +209,7 @@ impl Client { pub async fn file_retrieve_content( &self, req: FileRetrieveContentRequest, - ) -> Result { + ) -> Result { let res = self .get(&format!("{}/{}/content", "/files", req.file_id)) .await?; @@ -220,7 +223,7 @@ impl Client { pub async fn chat_completion( &self, req: ChatCompletionRequest, - ) -> Result { + ) -> Result { let res = self.post("/chat/completions", &req).await?; let r = res.json::().await; match r { @@ -229,7 +232,9 @@ impl Client { } } - fn new_error(&self, err: reqwest::Error) -> Error { - Error::new(std::io::ErrorKind::Other, err) + fn new_error(&self, err: reqwest::Error) -> APIError { + APIError { + message: err.to_string(), + } } } diff --git a/src/v1/error.rs b/src/v1/error.rs new file mode 100644 index 0000000..b2d2bf0 --- /dev/null +++ b/src/v1/error.rs @@ -0,0 +1,15 @@ +use std::error::Error; +use std::fmt; + +#[derive(Debug)] +pub struct APIError { + pub message: String, +} + +impl fmt::Display for APIError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "APIError: {}", self.message) + } +} + +impl Error for APIError {} diff --git a/src/v1/mod.rs b/src/v1/mod.rs index 8c84b8a..c8ca2ad 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -1,4 +1,5 @@ pub mod common; +pub mod error; pub mod chat_completion; pub mod completion;