feat: add chain method to create Request instance

This commit is contained in:
Night Cruising
2023-10-17 15:14:31 +08:00
parent f1f1fa7e86
commit a9be9efdfe
16 changed files with 402 additions and 135 deletions

View File

@ -33,25 +33,13 @@ let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
### Create request
```rust
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
let req = ChatCompletionRequest {
model: chat_completion::GPT4.to_string(),
messages: vec![chat_completion::ChatCompletionMessage {
let req = ChatCompletionRequest::new(
chat_completion::GPT4.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("Hello OpenAI!"),
}],
functions: None,
function_call: None,
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
);
```
### Send request
@ -68,27 +56,15 @@ use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let req = ChatCompletionRequest {
model: chat_completion::GPT4.to_string(),
messages: vec![chat_completion::ChatCompletionMessage {
let req = ChatCompletionRequest::new(
chat_completion::GPT4.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is Bitcoin?"),
name: None,
function_call: None,
}],
functions: None,
function_call: None,
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
);
let result = client.chat_completion(req)?;
println!("{:?}", result.choices[0].message.content);
Ok(())

View File

@ -4,29 +4,20 @@ use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let req = ChatCompletionRequest {
model: chat_completion::GPT4.to_string(),
messages: vec![chat_completion::ChatCompletionMessage {
let req = ChatCompletionRequest::new(
chat_completion::GPT4.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is Bitcoin?"),
name: None,
function_call: None,
}],
functions: None,
function_call: None,
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
);
let result = client.chat_completion(req)?;
println!("{:?}", result.choices[0].message.content);
Ok(())
}

View File

@ -4,24 +4,18 @@ use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let req = CompletionRequest {
model: completion::GPT3_TEXT_DAVINCI_003.to_string(),
prompt: String::from("What is Bitcoin?"),
suffix: None,
max_tokens: Some(3000),
temperature: Some(0.9),
top_p: Some(1.0),
n: None,
stream: None,
logprobs: None,
echo: None,
stop: Some(vec![String::from(" Human:"), String::from(" AI:")]),
presence_penalty: Some(0.6),
frequency_penalty: Some(0.0),
best_of: None,
logit_bias: None,
user: None,
};
let req = CompletionRequest::new(
completion::GPT3_TEXT_DAVINCI_003.to_string(),
String::from("What is Bitcoin?"),
)
.max_tokens(3000)
.temperature(0.9)
.top_p(1.0)
.stop(vec![String::from(" Human:"), String::from(" AI:")])
.presence_penalty(0.6)
.frequency_penalty(0.0);
let result = client.completion(req)?;
println!("{:}", result.choices[0].text);

View File

@ -4,11 +4,12 @@ use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
let req = EmbeddingRequest {
model: "text-embedding-ada-002".to_string(),
input: "story time".to_string(),
user: Option::None,
};
let req = EmbeddingRequest::new(
"text-embedding-ada-002".to_string(),
"story time".to_string(),
);
let result = client.embedding(req)?;
println!("{:?}", result.data);

View File

@ -29,35 +29,25 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}),
);
let req = ChatCompletionRequest {
model: chat_completion::GPT3_5_TURBO_0613.to_string(),
messages: vec![chat_completion::ChatCompletionMessage {
let req = ChatCompletionRequest::new(
chat_completion::GPT3_5_TURBO_0613.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"),
name: None,
function_call: None,
}],
functions: Some(vec![chat_completion::Function {
name: String::from("get_coin_price"),
description: Some(String::from("Get the price of a cryptocurrency")),
parameters: chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
}]),
function_call: Some(FunctionCallType::Auto), // Some(FunctionCallType::Function { name: "test".to_string() }),
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
)
.functions(vec![chat_completion::Function {
name: String::from("get_coin_price"),
description: Some(String::from("Get the price of a cryptocurrency")),
parameters: chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
}])
.function_call(FunctionCallType::Auto);
// debug reuqest json
// let serialized = serde_json::to_string(&req).unwrap();

View File

@ -29,35 +29,24 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}),
);
let req = ChatCompletionRequest {
model: chat_completion::GPT3_5_TURBO_0613.to_string(),
messages: vec![chat_completion::ChatCompletionMessage {
let req = ChatCompletionRequest::new(
chat_completion::GPT3_5_TURBO_0613.to_string(),
vec![chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"),
name: None,
function_call: None,
}],
functions: Some(vec![chat_completion::Function {
name: String::from("get_coin_price"),
description: Some(String::from("Get the price of a cryptocurrency")),
parameters: chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
}]),
function_call: None,
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
)
.functions(vec![chat_completion::Function {
name: String::from("get_coin_price"),
description: Some(String::from("Get the price of a cryptocurrency")),
parameters: chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
}]);
let result = client.chat_completion(req)?;
@ -80,9 +69,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let c: Currency = serde_json::from_str(&arguments)?;
let coin = c.coin;
let req = ChatCompletionRequest {
model: chat_completion::GPT3_5_TURBO_0613.to_string(),
messages: vec![
let req = ChatCompletionRequest::new(
chat_completion::GPT3_5_TURBO_0613.to_string(),
vec![
chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"),
@ -99,19 +88,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
function_call: None,
},
],
functions: None,
function_call: None,
temperature: None,
top_p: None,
n: None,
stream: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
};
);
let result = client.chat_completion(req)?;
println!("{:?}", result.choices[0].message.content);
}

View File

@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
use crate::impl_builder_methods;
pub const WHISPER_1: &str = "whisper-1";
#[derive(Debug, Serialize)]
@ -16,6 +18,27 @@ pub struct AudioTranscriptionRequest {
pub language: Option<String>,
}
impl AudioTranscriptionRequest {
pub fn new(file: String, model: String) -> Self {
Self {
file,
model,
prompt: None,
response_format: None,
temperature: None,
language: None,
}
}
}
impl_builder_methods!(
AudioTranscriptionRequest,
prompt: String,
response_format: String,
temperature: f32,
language: String
);
#[derive(Debug, Deserialize)]
pub struct AudioTranscriptionResponse {
pub text: String,
@ -33,6 +56,25 @@ pub struct AudioTranslationRequest {
pub temperature: Option<f32>,
}
impl AudioTranslationRequest {
pub fn new(file: String, model: String) -> Self {
Self {
file,
model,
prompt: None,
response_format: None,
temperature: None,
}
}
}
impl_builder_methods!(
AudioTranslationRequest,
prompt: String,
response_format: String,
temperature: f32
);
#[derive(Debug, Deserialize)]
pub struct AudioTranslationResponse {
pub text: String,

View File

@ -2,6 +2,7 @@ use serde::ser::SerializeMap;
use serde::{Deserialize, Serialize, Serializer};
use std::collections::HashMap;
use crate::impl_builder_methods;
use crate::v1::common;
pub const GPT3_5_TURBO: &str = "gpt-3.5-turbo";
@ -52,6 +53,43 @@ pub struct ChatCompletionRequest {
pub user: Option<String>,
}
impl ChatCompletionRequest {
pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self {
Self {
model,
messages,
functions: None,
function_call: None,
temperature: None,
top_p: None,
stream: None,
n: None,
stop: None,
max_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
}
}
}
impl_builder_methods!(
ChatCompletionRequest,
functions: Vec<Function>,
function_call: FunctionCallType,
temperature: f64,
top_p: f64,
n: i64,
stream: bool,
stop: Vec<String>,
max_tokens: i64,
presence_penalty: f64,
frequency_penalty: f64,
logit_bias: HashMap<String, i32>,
user: String
);
#[derive(Debug, Serialize, Deserialize, Clone)]
#[allow(non_camel_case_types)]
pub enum MessageRole {

View File

@ -6,3 +6,17 @@ pub struct Usage {
pub completion_tokens: i32,
pub total_tokens: i32,
}
#[macro_export]
macro_rules! impl_builder_methods {
($builder:ident, $($field:ident: $field_type:ty),*) => {
$(
impl $builder {
pub fn $field(mut self, $field: $field_type) -> Self {
self.$field = Some($field);
self
}
}
)*
};
}

View File

@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::option::Option;
use crate::impl_builder_methods;
use crate::v1::common;
pub const GPT3_TEXT_DAVINCI_003: &str = "text-davinci-003";
@ -51,6 +52,47 @@ pub struct CompletionRequest {
pub user: Option<String>,
}
impl CompletionRequest {
pub fn new(model: String, prompt: String) -> Self {
Self {
model,
prompt,
suffix: None,
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stream: None,
logprobs: None,
echo: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
}
}
}
impl_builder_methods!(
CompletionRequest,
suffix: String,
max_tokens: i32,
temperature: f32,
top_p: f32,
n: i32,
stream: bool,
logprobs: i32,
echo: bool,
stop: Vec<String>,
presence_penalty: f32,
frequency_penalty: f32,
best_of: i32,
logit_bias: HashMap<String, i32>,
user: String
);
#[derive(Debug, Deserialize)]
pub struct CompletionChoice {
pub text: String,

View File

@ -1,6 +1,7 @@
use serde::{Deserialize, Serialize};
use std::option::Option;
use crate::impl_builder_methods;
use crate::v1::common;
#[derive(Debug, Serialize)]
@ -17,6 +18,27 @@ pub struct EditRequest {
pub top_p: Option<f32>,
}
impl EditRequest {
pub fn new(model: String, instruction: String) -> Self {
Self {
model,
instruction,
input: None,
n: None,
temperature: None,
top_p: None,
}
}
}
impl_builder_methods!(
EditRequest,
input: String,
n: i32,
temperature: f32,
top_p: f32
);
#[derive(Debug, Deserialize)]
pub struct EditChoice {
pub text: String,

View File

@ -1,6 +1,8 @@
use serde::{Deserialize, Serialize};
use std::option::Option;
use crate::impl_builder_methods;
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
pub object: String,
@ -16,6 +18,21 @@ pub struct EmbeddingRequest {
pub user: Option<String>,
}
impl EmbeddingRequest {
pub fn new(model: String, input: String) -> Self {
Self {
model,
input,
user: None,
}
}
}
impl_builder_methods!(
EmbeddingRequest,
user: String
);
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
pub object: String,

View File

@ -22,6 +22,12 @@ pub struct FileUploadRequest {
pub purpose: String,
}
impl FileUploadRequest {
pub fn new(file: String, purpose: String) -> Self {
Self { file, purpose }
}
}
#[derive(Debug, Deserialize)]
pub struct FileUploadResponse {
pub id: String,
@ -37,6 +43,12 @@ pub struct FileDeleteRequest {
pub file_id: String,
}
impl FileDeleteRequest {
pub fn new(file_id: String) -> Self {
Self { file_id }
}
}
#[derive(Debug, Deserialize)]
pub struct FileDeleteResponse {
pub id: String,
@ -49,6 +61,12 @@ pub struct FileRetrieveRequest {
pub file_id: String,
}
impl FileRetrieveRequest {
pub fn new(file_id: String) -> Self {
Self { file_id }
}
}
#[derive(Debug, Deserialize)]
pub struct FileRetrieveResponse {
pub id: String,
@ -64,6 +82,12 @@ pub struct FileRetrieveContentRequest {
pub file_id: String,
}
impl FileRetrieveContentRequest {
pub fn new(file_id: String) -> Self {
Self { file_id }
}
}
#[derive(Debug, Deserialize)]
pub struct FileRetrieveContentResponse {
pub id: String,

View File

@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
use crate::impl_builder_methods;
#[derive(Debug, Serialize)]
pub struct CreateFineTuneRequest {
pub training_file: String,
@ -27,6 +29,40 @@ pub struct CreateFineTuneRequest {
pub suffix: Option<String>,
}
impl CreateFineTuneRequest {
pub fn new(training_file: String) -> Self {
Self {
training_file,
validation_file: None,
model: None,
n_epochs: None,
batch_size: None,
learning_rate_multiplier: None,
prompt_loss_weight: None,
compute_classification_metrics: None,
classification_n_classes: None,
classification_positive_class: None,
classification_betas: None,
suffix: None,
}
}
}
impl_builder_methods!(
CreateFineTuneRequest,
validation_file: String,
model: String,
n_epochs: i32,
batch_size: i32,
learning_rate_multiplier: f32,
prompt_loss_weight: f32,
compute_classification_metrics: bool,
classification_n_classes: i32,
classification_positive_class: String,
classification_betas: Vec<f32>,
suffix: String
);
#[derive(Debug, Deserialize)]
pub struct CreateFineTuneResponse {
pub id: String,
@ -134,6 +170,12 @@ pub struct RetrieveFineTuneRequest {
pub fine_tune_id: String,
}
impl RetrieveFineTuneRequest {
pub fn new(fine_tune_id: String) -> Self {
Self { fine_tune_id }
}
}
#[derive(Debug, Deserialize)]
pub struct RetrieveFineTuneResponse {
pub id: String,
@ -156,6 +198,12 @@ pub struct CancelFineTuneRequest {
pub fine_tune_id: String,
}
impl CancelFineTuneRequest {
pub fn new(fine_tune_id: String) -> Self {
Self { fine_tune_id }
}
}
#[derive(Debug, Deserialize)]
pub struct CancelFineTuneResponse {
pub id: String,
@ -178,6 +226,12 @@ pub struct ListFineTuneEventsRequest {
pub fine_tune_id: String,
}
impl ListFineTuneEventsRequest {
pub fn new(fine_tune_id: String) -> Self {
Self { fine_tune_id }
}
}
#[derive(Debug, Deserialize)]
pub struct ListFineTuneEventsResponse {
pub object: String,
@ -189,6 +243,12 @@ pub struct DeleteFineTuneModelRequest {
pub model_id: String,
}
impl DeleteFineTuneModelRequest {
pub fn new(model_id: String) -> Self {
Self { model_id }
}
}
#[derive(Debug, Deserialize)]
pub struct DeleteFineTuneModelResponse {
pub id: String,

View File

@ -1,6 +1,8 @@
use serde::{Deserialize, Serialize};
use std::option::Option;
use crate::impl_builder_methods;
#[derive(Debug, Deserialize)]
pub struct ImageData {
pub url: String,
@ -19,6 +21,26 @@ pub struct ImageGenerationRequest {
pub user: Option<String>,
}
impl ImageGenerationRequest {
pub fn new(prompt: String) -> Self {
Self {
prompt,
n: None,
size: None,
response_format: None,
user: None,
}
}
}
impl_builder_methods!(
ImageGenerationRequest,
n: i32,
size: String,
response_format: String,
user: String
);
#[derive(Debug, Deserialize)]
pub struct ImageGenerationResponse {
pub created: i64,
@ -41,6 +63,29 @@ pub struct ImageEditRequest {
pub user: Option<String>,
}
impl ImageEditRequest {
pub fn new(image: String, prompt: String) -> Self {
Self {
image,
prompt,
mask: None,
n: None,
size: None,
response_format: None,
user: None,
}
}
}
impl_builder_methods!(
ImageEditRequest,
mask: String,
n: i32,
size: String,
response_format: String,
user: String
);
#[derive(Debug, Deserialize)]
pub struct ImageEditResponse {
pub created: i64,
@ -60,6 +105,26 @@ pub struct ImageVariationRequest {
pub user: Option<String>,
}
impl ImageVariationRequest {
pub fn new(image: String) -> Self {
Self {
image,
n: None,
size: None,
response_format: None,
user: None,
}
}
}
impl_builder_methods!(
ImageVariationRequest,
n: i32,
size: String,
response_format: String,
user: String
);
#[derive(Debug, Deserialize)]
pub struct ImageVariationResponse {
pub created: i64,

View File

@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
use crate::impl_builder_methods;
#[derive(Debug, Serialize)]
pub struct CreateModerationRequest {
pub input: String,
@ -7,6 +9,17 @@ pub struct CreateModerationRequest {
pub model: Option<String>,
}
impl CreateModerationRequest {
pub fn new(input: String) -> Self {
Self { input, model: None }
}
}
impl_builder_methods!(
CreateModerationRequest,
model: String
);
#[derive(Debug, Deserialize)]
pub struct CreateModerationResponse {
pub id: String,