From a9be9efdfea68bfea5859e49f193722d55e31475 Mon Sep 17 00:00:00 2001 From: Night Cruising <2586447362@qq.com> Date: Tue, 17 Oct 2023 15:14:31 +0800 Subject: [PATCH] feat: add chain method to create Request instance --- README.md | 40 +++++---------------- examples/chat_completion.rs | 23 ++++-------- examples/completion.rs | 30 +++++++--------- examples/embedding.rs | 11 +++--- examples/function_call.rs | 38 ++++++++------------ examples/function_call_role.rs | 58 ++++++++++-------------------- src/v1/audio.rs | 42 ++++++++++++++++++++++ src/v1/chat_completion.rs | 38 ++++++++++++++++++++ src/v1/common.rs | 14 ++++++++ src/v1/completion.rs | 42 ++++++++++++++++++++++ src/v1/edit.rs | 22 ++++++++++++ src/v1/embedding.rs | 17 +++++++++ src/v1/file.rs | 24 +++++++++++++ src/v1/fine_tune.rs | 60 +++++++++++++++++++++++++++++++ src/v1/image.rs | 65 ++++++++++++++++++++++++++++++++++ src/v1/moderation.rs | 13 +++++++ 16 files changed, 402 insertions(+), 135 deletions(-) diff --git a/README.md b/README.md index 218e67b..c51f1e3 100644 --- a/README.md +++ b/README.md @@ -33,25 +33,13 @@ let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); ### Create request ```rust use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; -let req = ChatCompletionRequest { - model: chat_completion::GPT4.to_string(), - messages: vec![chat_completion::ChatCompletionMessage { +let req = ChatCompletionRequest::new( + chat_completion::GPT4.to_string(), + vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: String::from("Hello OpenAI!"), }], - functions: None, - function_call: None, - temperature: None, - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, -}; +); ``` ### Send request @@ -68,27 +56,15 @@ use std::env; fn main() -> Result<(), Box> { let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); - let req = ChatCompletionRequest { - model: chat_completion::GPT4.to_string(), - messages: vec![chat_completion::ChatCompletionMessage { + let req = ChatCompletionRequest::new( + chat_completion::GPT4.to_string(), + vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: String::from("What is Bitcoin?"), name: None, function_call: None, }], - functions: None, - function_call: None, - temperature: None, - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - }; + ); let result = client.chat_completion(req)?; println!("{:?}", result.choices[0].message.content); Ok(()) diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index ae64146..542e4d7 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -4,29 +4,20 @@ use std::env; fn main() -> Result<(), Box> { let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); - let req = ChatCompletionRequest { - model: chat_completion::GPT4.to_string(), - messages: vec![chat_completion::ChatCompletionMessage { + + let req = ChatCompletionRequest::new( + chat_completion::GPT4.to_string(), + vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: String::from("What is Bitcoin?"), name: None, function_call: None, }], - functions: None, - function_call: None, - temperature: None, - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - }; + ); + let result = client.chat_completion(req)?; println!("{:?}", result.choices[0].message.content); + Ok(()) } diff --git a/examples/completion.rs b/examples/completion.rs index 0a62a3a..362d9b7 100644 --- a/examples/completion.rs +++ b/examples/completion.rs @@ -4,24 +4,18 @@ use std::env; fn main() -> Result<(), Box> { let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); - let req = CompletionRequest { - model: completion::GPT3_TEXT_DAVINCI_003.to_string(), - prompt: String::from("What is Bitcoin?"), - suffix: None, - max_tokens: Some(3000), - temperature: Some(0.9), - top_p: Some(1.0), - n: None, - stream: None, - logprobs: None, - echo: None, - stop: Some(vec![String::from(" Human:"), String::from(" AI:")]), - presence_penalty: Some(0.6), - frequency_penalty: Some(0.0), - best_of: None, - logit_bias: None, - user: None, - }; + + let req = CompletionRequest::new( + completion::GPT3_TEXT_DAVINCI_003.to_string(), + String::from("What is Bitcoin?"), + ) + .max_tokens(3000) + .temperature(0.9) + .top_p(1.0) + .stop(vec![String::from(" Human:"), String::from(" AI:")]) + .presence_penalty(0.6) + .frequency_penalty(0.0); + let result = client.completion(req)?; println!("{:}", result.choices[0].text); diff --git a/examples/embedding.rs b/examples/embedding.rs index 7b72fc0..70403c9 100644 --- a/examples/embedding.rs +++ b/examples/embedding.rs @@ -4,11 +4,12 @@ use std::env; fn main() -> Result<(), Box> { let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); - let req = EmbeddingRequest { - model: "text-embedding-ada-002".to_string(), - input: "story time".to_string(), - user: Option::None, - }; + + let req = EmbeddingRequest::new( + "text-embedding-ada-002".to_string(), + "story time".to_string(), + ); + let result = client.embedding(req)?; println!("{:?}", result.data); diff --git a/examples/function_call.rs b/examples/function_call.rs index 7bff433..3a32144 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -29,35 +29,25 @@ fn main() -> Result<(), Box> { }), ); - let req = ChatCompletionRequest { - model: chat_completion::GPT3_5_TURBO_0613.to_string(), - messages: vec![chat_completion::ChatCompletionMessage { + let req = ChatCompletionRequest::new( + chat_completion::GPT3_5_TURBO_0613.to_string(), + vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: String::from("What is the price of Ethereum?"), name: None, function_call: None, }], - functions: Some(vec![chat_completion::Function { - name: String::from("get_coin_price"), - description: Some(String::from("Get the price of a cryptocurrency")), - parameters: chat_completion::FunctionParameters { - schema_type: chat_completion::JSONSchemaType::Object, - properties: Some(properties), - required: Some(vec![String::from("coin")]), - }, - }]), - function_call: Some(FunctionCallType::Auto), // Some(FunctionCallType::Function { name: "test".to_string() }), - temperature: None, - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - }; + ) + .functions(vec![chat_completion::Function { + name: String::from("get_coin_price"), + description: Some(String::from("Get the price of a cryptocurrency")), + parameters: chat_completion::FunctionParameters { + schema_type: chat_completion::JSONSchemaType::Object, + properties: Some(properties), + required: Some(vec![String::from("coin")]), + }, + }]) + .function_call(FunctionCallType::Auto); // debug reuqest json // let serialized = serde_json::to_string(&req).unwrap(); diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index c011668..56ffc87 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -29,35 +29,24 @@ fn main() -> Result<(), Box> { }), ); - let req = ChatCompletionRequest { - model: chat_completion::GPT3_5_TURBO_0613.to_string(), - messages: vec![chat_completion::ChatCompletionMessage { + let req = ChatCompletionRequest::new( + chat_completion::GPT3_5_TURBO_0613.to_string(), + vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: String::from("What is the price of Ethereum?"), name: None, function_call: None, }], - functions: Some(vec![chat_completion::Function { - name: String::from("get_coin_price"), - description: Some(String::from("Get the price of a cryptocurrency")), - parameters: chat_completion::FunctionParameters { - schema_type: chat_completion::JSONSchemaType::Object, - properties: Some(properties), - required: Some(vec![String::from("coin")]), - }, - }]), - function_call: None, - temperature: None, - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - }; + ) + .functions(vec![chat_completion::Function { + name: String::from("get_coin_price"), + description: Some(String::from("Get the price of a cryptocurrency")), + parameters: chat_completion::FunctionParameters { + schema_type: chat_completion::JSONSchemaType::Object, + properties: Some(properties), + required: Some(vec![String::from("coin")]), + }, + }]); let result = client.chat_completion(req)?; @@ -80,9 +69,9 @@ fn main() -> Result<(), Box> { let c: Currency = serde_json::from_str(&arguments)?; let coin = c.coin; - let req = ChatCompletionRequest { - model: chat_completion::GPT3_5_TURBO_0613.to_string(), - messages: vec![ + let req = ChatCompletionRequest::new( + chat_completion::GPT3_5_TURBO_0613.to_string(), + vec![ chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, content: String::from("What is the price of Ethereum?"), @@ -99,19 +88,8 @@ fn main() -> Result<(), Box> { function_call: None, }, ], - functions: None, - function_call: None, - temperature: None, - top_p: None, - n: None, - stream: None, - stop: None, - max_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - }; + ); + let result = client.chat_completion(req)?; println!("{:?}", result.choices[0].message.content); } diff --git a/src/v1/audio.rs b/src/v1/audio.rs index 50e91cb..1a5a954 100644 --- a/src/v1/audio.rs +++ b/src/v1/audio.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use crate::impl_builder_methods; + pub const WHISPER_1: &str = "whisper-1"; #[derive(Debug, Serialize)] @@ -16,6 +18,27 @@ pub struct AudioTranscriptionRequest { pub language: Option, } +impl AudioTranscriptionRequest { + pub fn new(file: String, model: String) -> Self { + Self { + file, + model, + prompt: None, + response_format: None, + temperature: None, + language: None, + } + } +} + +impl_builder_methods!( + AudioTranscriptionRequest, + prompt: String, + response_format: String, + temperature: f32, + language: String +); + #[derive(Debug, Deserialize)] pub struct AudioTranscriptionResponse { pub text: String, @@ -33,6 +56,25 @@ pub struct AudioTranslationRequest { pub temperature: Option, } +impl AudioTranslationRequest { + pub fn new(file: String, model: String) -> Self { + Self { + file, + model, + prompt: None, + response_format: None, + temperature: None, + } + } +} + +impl_builder_methods!( + AudioTranslationRequest, + prompt: String, + response_format: String, + temperature: f32 +); + #[derive(Debug, Deserialize)] pub struct AudioTranslationResponse { pub text: String, diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 03d7d98..2e64167 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -2,6 +2,7 @@ use serde::ser::SerializeMap; use serde::{Deserialize, Serialize, Serializer}; use std::collections::HashMap; +use crate::impl_builder_methods; use crate::v1::common; pub const GPT3_5_TURBO: &str = "gpt-3.5-turbo"; @@ -52,6 +53,43 @@ pub struct ChatCompletionRequest { pub user: Option, } +impl ChatCompletionRequest { + pub fn new(model: String, messages: Vec) -> Self { + Self { + model, + messages, + functions: None, + function_call: None, + temperature: None, + top_p: None, + stream: None, + n: None, + stop: None, + max_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + } + } +} + +impl_builder_methods!( + ChatCompletionRequest, + functions: Vec, + function_call: FunctionCallType, + temperature: f64, + top_p: f64, + n: i64, + stream: bool, + stop: Vec, + max_tokens: i64, + presence_penalty: f64, + frequency_penalty: f64, + logit_bias: HashMap, + user: String +); + #[derive(Debug, Serialize, Deserialize, Clone)] #[allow(non_camel_case_types)] pub enum MessageRole { diff --git a/src/v1/common.rs b/src/v1/common.rs index 05f500b..7ab7cd4 100644 --- a/src/v1/common.rs +++ b/src/v1/common.rs @@ -6,3 +6,17 @@ pub struct Usage { pub completion_tokens: i32, pub total_tokens: i32, } + +#[macro_export] +macro_rules! impl_builder_methods { + ($builder:ident, $($field:ident: $field_type:ty),*) => { + $( + impl $builder { + pub fn $field(mut self, $field: $field_type) -> Self { + self.$field = Some($field); + self + } + } + )* + }; +} diff --git a/src/v1/completion.rs b/src/v1/completion.rs index 3781294..aeb5550 100644 --- a/src/v1/completion.rs +++ b/src/v1/completion.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::option::Option; +use crate::impl_builder_methods; use crate::v1::common; pub const GPT3_TEXT_DAVINCI_003: &str = "text-davinci-003"; @@ -51,6 +52,47 @@ pub struct CompletionRequest { pub user: Option, } +impl CompletionRequest { + pub fn new(model: String, prompt: String) -> Self { + Self { + model, + prompt, + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + n: None, + stream: None, + logprobs: None, + echo: None, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + } + } +} + +impl_builder_methods!( + CompletionRequest, + suffix: String, + max_tokens: i32, + temperature: f32, + top_p: f32, + n: i32, + stream: bool, + logprobs: i32, + echo: bool, + stop: Vec, + presence_penalty: f32, + frequency_penalty: f32, + best_of: i32, + logit_bias: HashMap, + user: String +); + #[derive(Debug, Deserialize)] pub struct CompletionChoice { pub text: String, diff --git a/src/v1/edit.rs b/src/v1/edit.rs index 82870cf..0bf2c03 100644 --- a/src/v1/edit.rs +++ b/src/v1/edit.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use std::option::Option; +use crate::impl_builder_methods; use crate::v1::common; #[derive(Debug, Serialize)] @@ -17,6 +18,27 @@ pub struct EditRequest { pub top_p: Option, } +impl EditRequest { + pub fn new(model: String, instruction: String) -> Self { + Self { + model, + instruction, + input: None, + n: None, + temperature: None, + top_p: None, + } + } +} + +impl_builder_methods!( + EditRequest, + input: String, + n: i32, + temperature: f32, + top_p: f32 +); + #[derive(Debug, Deserialize)] pub struct EditChoice { pub text: String, diff --git a/src/v1/embedding.rs b/src/v1/embedding.rs index 0baf189..a299841 100644 --- a/src/v1/embedding.rs +++ b/src/v1/embedding.rs @@ -1,6 +1,8 @@ use serde::{Deserialize, Serialize}; use std::option::Option; +use crate::impl_builder_methods; + #[derive(Debug, Deserialize)] pub struct EmbeddingData { pub object: String, @@ -16,6 +18,21 @@ pub struct EmbeddingRequest { pub user: Option, } +impl EmbeddingRequest { + pub fn new(model: String, input: String) -> Self { + Self { + model, + input, + user: None, + } + } +} + +impl_builder_methods!( + EmbeddingRequest, + user: String +); + #[derive(Debug, Deserialize)] pub struct EmbeddingResponse { pub object: String, diff --git a/src/v1/file.rs b/src/v1/file.rs index 5eff6cc..71fb2ad 100644 --- a/src/v1/file.rs +++ b/src/v1/file.rs @@ -22,6 +22,12 @@ pub struct FileUploadRequest { pub purpose: String, } +impl FileUploadRequest { + pub fn new(file: String, purpose: String) -> Self { + Self { file, purpose } + } +} + #[derive(Debug, Deserialize)] pub struct FileUploadResponse { pub id: String, @@ -37,6 +43,12 @@ pub struct FileDeleteRequest { pub file_id: String, } +impl FileDeleteRequest { + pub fn new(file_id: String) -> Self { + Self { file_id } + } +} + #[derive(Debug, Deserialize)] pub struct FileDeleteResponse { pub id: String, @@ -49,6 +61,12 @@ pub struct FileRetrieveRequest { pub file_id: String, } +impl FileRetrieveRequest { + pub fn new(file_id: String) -> Self { + Self { file_id } + } +} + #[derive(Debug, Deserialize)] pub struct FileRetrieveResponse { pub id: String, @@ -64,6 +82,12 @@ pub struct FileRetrieveContentRequest { pub file_id: String, } +impl FileRetrieveContentRequest { + pub fn new(file_id: String) -> Self { + Self { file_id } + } +} + #[derive(Debug, Deserialize)] pub struct FileRetrieveContentResponse { pub id: String, diff --git a/src/v1/fine_tune.rs b/src/v1/fine_tune.rs index 25d9c7e..a46e3af 100644 --- a/src/v1/fine_tune.rs +++ b/src/v1/fine_tune.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use crate::impl_builder_methods; + #[derive(Debug, Serialize)] pub struct CreateFineTuneRequest { pub training_file: String, @@ -27,6 +29,40 @@ pub struct CreateFineTuneRequest { pub suffix: Option, } +impl CreateFineTuneRequest { + pub fn new(training_file: String) -> Self { + Self { + training_file, + validation_file: None, + model: None, + n_epochs: None, + batch_size: None, + learning_rate_multiplier: None, + prompt_loss_weight: None, + compute_classification_metrics: None, + classification_n_classes: None, + classification_positive_class: None, + classification_betas: None, + suffix: None, + } + } +} + +impl_builder_methods!( + CreateFineTuneRequest, + validation_file: String, + model: String, + n_epochs: i32, + batch_size: i32, + learning_rate_multiplier: f32, + prompt_loss_weight: f32, + compute_classification_metrics: bool, + classification_n_classes: i32, + classification_positive_class: String, + classification_betas: Vec, + suffix: String +); + #[derive(Debug, Deserialize)] pub struct CreateFineTuneResponse { pub id: String, @@ -134,6 +170,12 @@ pub struct RetrieveFineTuneRequest { pub fine_tune_id: String, } +impl RetrieveFineTuneRequest { + pub fn new(fine_tune_id: String) -> Self { + Self { fine_tune_id } + } +} + #[derive(Debug, Deserialize)] pub struct RetrieveFineTuneResponse { pub id: String, @@ -156,6 +198,12 @@ pub struct CancelFineTuneRequest { pub fine_tune_id: String, } +impl CancelFineTuneRequest { + pub fn new(fine_tune_id: String) -> Self { + Self { fine_tune_id } + } +} + #[derive(Debug, Deserialize)] pub struct CancelFineTuneResponse { pub id: String, @@ -178,6 +226,12 @@ pub struct ListFineTuneEventsRequest { pub fine_tune_id: String, } +impl ListFineTuneEventsRequest { + pub fn new(fine_tune_id: String) -> Self { + Self { fine_tune_id } + } +} + #[derive(Debug, Deserialize)] pub struct ListFineTuneEventsResponse { pub object: String, @@ -189,6 +243,12 @@ pub struct DeleteFineTuneModelRequest { pub model_id: String, } +impl DeleteFineTuneModelRequest { + pub fn new(model_id: String) -> Self { + Self { model_id } + } +} + #[derive(Debug, Deserialize)] pub struct DeleteFineTuneModelResponse { pub id: String, diff --git a/src/v1/image.rs b/src/v1/image.rs index 9e03a43..6d92be9 100644 --- a/src/v1/image.rs +++ b/src/v1/image.rs @@ -1,6 +1,8 @@ use serde::{Deserialize, Serialize}; use std::option::Option; +use crate::impl_builder_methods; + #[derive(Debug, Deserialize)] pub struct ImageData { pub url: String, @@ -19,6 +21,26 @@ pub struct ImageGenerationRequest { pub user: Option, } +impl ImageGenerationRequest { + pub fn new(prompt: String) -> Self { + Self { + prompt, + n: None, + size: None, + response_format: None, + user: None, + } + } +} + +impl_builder_methods!( + ImageGenerationRequest, + n: i32, + size: String, + response_format: String, + user: String +); + #[derive(Debug, Deserialize)] pub struct ImageGenerationResponse { pub created: i64, @@ -41,6 +63,29 @@ pub struct ImageEditRequest { pub user: Option, } +impl ImageEditRequest { + pub fn new(image: String, prompt: String) -> Self { + Self { + image, + prompt, + mask: None, + n: None, + size: None, + response_format: None, + user: None, + } + } +} + +impl_builder_methods!( + ImageEditRequest, + mask: String, + n: i32, + size: String, + response_format: String, + user: String +); + #[derive(Debug, Deserialize)] pub struct ImageEditResponse { pub created: i64, @@ -60,6 +105,26 @@ pub struct ImageVariationRequest { pub user: Option, } +impl ImageVariationRequest { + pub fn new(image: String) -> Self { + Self { + image, + n: None, + size: None, + response_format: None, + user: None, + } + } +} + +impl_builder_methods!( + ImageVariationRequest, + n: i32, + size: String, + response_format: String, + user: String +); + #[derive(Debug, Deserialize)] pub struct ImageVariationResponse { pub created: i64, diff --git a/src/v1/moderation.rs b/src/v1/moderation.rs index 78b4f29..ee920d0 100644 --- a/src/v1/moderation.rs +++ b/src/v1/moderation.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use crate::impl_builder_methods; + #[derive(Debug, Serialize)] pub struct CreateModerationRequest { pub input: String, @@ -7,6 +9,17 @@ pub struct CreateModerationRequest { pub model: Option, } +impl CreateModerationRequest { + pub fn new(input: String) -> Self { + Self { input, model: None } + } +} + +impl_builder_methods!( + CreateModerationRequest, + model: String +); + #[derive(Debug, Deserialize)] pub struct CreateModerationResponse { pub id: String,