commit c2ce8b4133e1a00fdf67b86e695210b693d1d7d5 Author: Dongri Jin Date: Mon Dec 12 11:41:43 2022 +0900 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..0fb0b11 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "openai-rs" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +reqwest = { version = "0.11", features = ["json"] } +tokio = { version = "1", features = ["full"] } +serde = { version = "1", features = ["derive"] } \ No newline at end of file diff --git a/examples/completion.rs b/examples/completion.rs new file mode 100644 index 0000000..2511927 --- /dev/null +++ b/examples/completion.rs @@ -0,0 +1,32 @@ +use openai_rs::v1::completion::{self, CompletionRequest}; +use openai_rs::v1::api::Client; +use std::env; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + let req = CompletionRequest { + model: completion::GPT3_TEXT_DAVINCI_003.to_string(), + prompt: Some(String::from("NFTとは何か?")), + suffix: None, + max_tokens: Some(3000), + temperature: Some(0.9), + top_p: Some(1.0), + n: None, + stream: None, + logprobs: None, + echo: None, + stop: Some(vec![String::from(" Human:"), String::from(" AI:")]), + presence_penalty: Some(0.6), + frequency_penalty: Some(0.0), + best_of: None, + logit_bias: None, + user: None, + }; + let completion_response = client.completion(req).await?; + println!("{:?}", completion_response.choices[0].text); + + Ok(()) +} + +// cargo run --package openai-rs --example completion diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..a3a6d96 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1 @@ +pub mod v1; diff --git a/src/v1/api.rs b/src/v1/api.rs new file mode 100644 index 0000000..1fb6b37 --- /dev/null +++ b/src/v1/api.rs @@ -0,0 +1,250 @@ + +use crate::v1::completion::{CompletionRequest, CompletionResponse}; +use crate::v1::edit::{EditRequest, EditResponse}; +use crate::v1::image::{ + ImageGenerationRequest, + ImageGenerationResponse, + ImageEditRequest, + ImageEditResponse, + ImageVariationRequest, + ImageVariationResponse, +}; +use crate::v1::embedding::{EmbeddingRequest, EmbeddingResponse}; +use crate::v1::file::{ + FileListResponse, + FileUploadRequest, + FileUploadResponse, + FileDeleteRequest, + FileDeleteResponse, + FileRetrieveRequest, + FileRetrieveResponse, + FileRetrieveContentRequest, + FileRetrieveContentResponse, +}; +use reqwest::Response; + +const APU_URL_V1: &str = "https://api.openai.com/v1"; + +pub struct Client { + pub api_key: String, +} + +impl Client { + pub fn new(api_key: String) -> Self { + Self { api_key } + } + + pub async fn post(&self, path: &str, params: &T) -> Result> { + let client = reqwest::Client::new(); + let url = format!("{}{}", APU_URL_V1, path); + let res = client + .post(&url) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .header(reqwest::header::AUTHORIZATION, "Bearer ".to_owned() + &self.api_key) + .json(¶ms) + .send() + .await; + match res { + Ok(res) => match res.status().is_success() { + true => Ok(res), + false => { + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("{}: {}", res.status(), res.text().await.unwrap()) + ))) + }, + }, + Err(e) => Err(Box::new(e)), + } + } + + pub async fn get(&self, path: &str) -> Result> { + let client = reqwest::Client::new(); + let url = format!("{}{}", APU_URL_V1, path); + let res = client + .get(&url) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .header(reqwest::header::AUTHORIZATION, "Bearer ".to_owned() + &self.api_key) + .send() + .await; + match res { + Ok(res) => match res.status().is_success() { + true => Ok(res), + false => { + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("{}: {}", res.status(), res.text().await.unwrap()) + ))) + }, + }, + Err(e) => Err(Box::new(e)), + } + } + + pub async fn delete(&self, path: &str) -> Result> { + let client = reqwest::Client::new(); + let url = format!("{}{}", APU_URL_V1, path); + let res = client + .delete(&url) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .header(reqwest::header::AUTHORIZATION, "Bearer ".to_owned() + &self.api_key) + .send() + .await; + match res { + Ok(res) => match res.status().is_success() { + true => Ok(res), + false => { + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + format!("{}: {}", res.status(), res.text().await.unwrap()) + ))) + }, + }, + Err(e) => Err(Box::new(e)), + } + } + + pub async fn completion(&self, req: CompletionRequest) -> Result> { + let res = self.post("/completions", &req).await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + + pub async fn edit(&self, req: EditRequest) -> Result> { + let res = self.post("/edits", &req).await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + + pub async fn image_generation(&self, req: ImageGenerationRequest) -> Result> { + let res = self.post("/images/generations", &req).await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + + pub async fn image_edit(&self, req: ImageEditRequest) -> Result> { + let res = self.post("/images/edits", &req).await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + + pub async fn image_variation(&self, req: ImageVariationRequest) -> Result> { + let res = self.post("/images/variations", &req).await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + + pub async fn embedding(&self, req: EmbeddingRequest) -> Result> { + let res = self.post("/embeddings", &req).await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + + pub async fn file_list(&self) -> Result> { + let res = self.get("/files").await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + + pub async fn file_upload(&self, req: FileUploadRequest) -> Result> { + let res = self.post("/files", &req).await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + + pub async fn file_delete(&self, req: FileDeleteRequest) -> Result> { + let res = self.delete(&format!("{}/{}", "/files", req.file_id)).await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + + pub async fn file_retrieve(&self, req: FileRetrieveRequest) -> Result> { + let res = self.get(&format!("{}/{}", "/files", req.file_id)).await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + + pub async fn file_retrieve_content(&self, req: FileRetrieveContentRequest) -> Result> { + let res = self.get(&format!("{}/{}/content", "/files", req.file_id)).await; + match res { + Ok(res) => { + let r = res.json::().await?; + return Ok(r); + }, + Err(e) => { + return Err(e); + }, + } + } + +} diff --git a/src/v1/common.rs b/src/v1/common.rs new file mode 100644 index 0000000..7a87acc --- /dev/null +++ b/src/v1/common.rs @@ -0,0 +1,10 @@ + +use serde::{Deserialize}; + + +#[derive(Debug, Deserialize)] +pub struct Usage { + pub prompt_tokens: i32, + pub completion_tokens: i32, + pub total_tokens: i32, +} diff --git a/src/v1/completion.rs b/src/v1/completion.rs new file mode 100644 index 0000000..592cdd1 --- /dev/null +++ b/src/v1/completion.rs @@ -0,0 +1,79 @@ +use serde::{Serialize, Deserialize}; +use std::option::Option; +use std::collections::HashMap; + +use crate::v1::common; + +pub const GPT3_TEXT_DAVINCI_003: &str = "text-davinci-003"; +pub const GPT3_TEXT_DAVINCI_002: &str = "text-davinci-002"; +pub const GPT3_TEXT_CURIE_001: &str = "text-curie-001"; +pub const GPT3_TEXT_BABBAGE_001: &str = "text-babbage-001"; +pub const GPT3_TEXT_ADA_001: &str = "text-ada-001"; +pub const GPT3_TEXT_DAVINCI_001: &str = "text-davinci-001"; +pub const GPT3_DAVINCI_INSTRUCT_BETA: &str = "davinci-instruct-beta"; +pub const GPT3_DAVINCI: &str = "davinci"; +pub const GPT3_CURIE_INSTRUCT_BETA: &str = "curie-instruct-beta"; +pub const GPT3_CURIE: &str = "curie"; +pub const GPT3_ADA: &str = "ada"; +pub const GPT3_BABBAGE: &str = "babbage"; + +#[derive(Debug, Serialize)] +pub struct CompletionRequest { + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub echo: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Debug, Deserialize)] +pub struct CompletionChoice { + pub text: String, + pub index: i64, + pub finish_reason: String, + pub logprobs: Option, +} + +#[derive(Debug, Deserialize)] +pub struct LogprobResult { + pub tokens: Vec, + pub token_logprobs: Vec, + pub top_logprobs: Vec>, + pub text_offset: Vec, +} + +#[derive(Debug, Deserialize)] +pub struct CompletionResponse { + pub id: String, + pub object: String, + pub created: i64, + pub model: String, + pub choices: Vec, + pub usage: common::Usage, +} diff --git a/src/v1/edit.rs b/src/v1/edit.rs new file mode 100644 index 0000000..c562c52 --- /dev/null +++ b/src/v1/edit.rs @@ -0,0 +1,32 @@ +use serde::{Serialize, Deserialize}; +use std::option::Option; + +use crate::v1::common; + +#[derive(Debug, Serialize)] +pub struct EditRequest { + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, + pub instruction: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, +} + +#[derive(Debug, Deserialize)] +pub struct EditChoice{ + pub text: String, + pub index: i32, +} + +#[derive(Debug, Deserialize)] +pub struct EditResponse { + pub object: String, + pub created: i64, + pub usage: common::Usage, + pub choices: Vec, +} diff --git a/src/v1/embedding.rs b/src/v1/embedding.rs new file mode 100644 index 0000000..c49b06f --- /dev/null +++ b/src/v1/embedding.rs @@ -0,0 +1,27 @@ +use serde::{Serialize, Deserialize}; +use std::option::Option; + +use crate::v1::common; + +#[derive(Debug, Deserialize)] +pub struct EmbeddingData{ + pub object: String, + pub embedding: Vec, + pub index: i32, + pub usage: common::Usage, +} + +#[derive(Debug, Serialize)] +pub struct EmbeddingRequest { + pub model: String, + pub input: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + + +#[derive(Debug, Deserialize)] +pub struct EmbeddingResponse { + pub object: String, + pub data: Vec, +} diff --git a/src/v1/file.rs b/src/v1/file.rs new file mode 100644 index 0000000..9437700 --- /dev/null +++ b/src/v1/file.rs @@ -0,0 +1,78 @@ +use serde::{Serialize, Deserialize}; + +#[derive(Debug, Deserialize)] +pub struct FileData{ + pub id: String, + pub oejct: String, + pub bytes: i32, + pub created_at: i64, + pub filename: String, + pub purpose: String, +} + +#[derive(Debug, Deserialize)] +pub struct FileListResponse { + pub object: String, + pub data: Vec, +} + + +#[derive(Debug, Serialize)] +pub struct FileUploadRequest { + pub file: String, + pub purpose: String, +} + +#[derive(Debug, Deserialize)] +pub struct FileUploadResponse { + pub id: String, + pub oejct: String, + pub bytes: i32, + pub created_at: i64, + pub filename: String, + pub purpose: String, +} + + +#[derive(Debug, Serialize)] +pub struct FileDeleteRequest { + pub file_id: String, +} + +#[derive(Debug, Deserialize)] +pub struct FileDeleteResponse { + pub id: String, + pub oejct: String, + pub delete: bool, +} + +#[derive(Debug, Serialize)] +pub struct FileRetrieveRequest { + pub file_id: String, +} + +#[derive(Debug, Deserialize)] +pub struct FileRetrieveResponse { + pub id: String, + pub oejct: String, + pub bytes: i32, + pub created_at: i64, + pub filename: String, + pub purpose: String, +} + + +#[derive(Debug, Serialize)] +pub struct FileRetrieveContentRequest { + pub file_id: String, +} + +#[derive(Debug, Deserialize)] +pub struct FileRetrieveContentResponse { + pub id: String, + pub oejct: String, + pub bytes: i32, + pub created_at: i64, + pub filename: String, + pub purpose: String, +} diff --git a/src/v1/image.rs b/src/v1/image.rs new file mode 100644 index 0000000..ac7cc4e --- /dev/null +++ b/src/v1/image.rs @@ -0,0 +1,68 @@ +use serde::{Serialize, Deserialize}; +use std::option::Option; + +#[derive(Debug, Deserialize)] +pub struct ImageData{ + pub url: String, +} + +#[derive(Debug, Serialize)] +pub struct ImageGenerationRequest { + pub prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + + +#[derive(Debug, Deserialize)] +pub struct ImageGenerationResponse { + pub created: i64, + pub data: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ImageEditRequest { + pub image: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub mask: Option, + pub prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ImageEditResponse { + pub created: i64, + pub data: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ImageVariationRequest { + pub image: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, +} + +#[derive(Debug, Deserialize)] +pub struct ImageVariationResponse { + pub created: i64, + pub data: Vec, +} diff --git a/src/v1/mod.rs b/src/v1/mod.rs new file mode 100644 index 0000000..0c0728d --- /dev/null +++ b/src/v1/mod.rs @@ -0,0 +1,9 @@ +pub mod common; + +pub mod completion; +pub mod edit; +pub mod image; +pub mod embedding; +pub mod file; + +pub mod api;