From a0c319cd54cba05f4d22475929099b8b63f51934 Mon Sep 17 00:00:00 2001 From: Dongri Jin Date: Tue, 25 Jul 2023 11:48:52 +0900 Subject: [PATCH 1/2] Remove Option --- examples/chat_completion.rs | 2 +- examples/completion.rs | 2 +- examples/function_call.rs | 6 +++--- examples/function_call_role.rs | 10 +++++----- src/v1/audio.rs | 4 ++-- src/v1/chat_completion.rs | 12 +++++------- src/v1/completion.rs | 3 +-- src/v1/moderation.rs | 2 ++ 8 files changed, 20 insertions(+), 21 deletions(-) diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index af2deab..afa4997 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -9,7 +9,7 @@ async fn main() -> Result<(), Box> { model: chat_completion::GPT4.to_string(), messages: vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, - content: Some(String::from("What is Bitcoin?")), + content: String::from("What is Bitcoin?"), name: None, function_call: None, }], diff --git a/examples/completion.rs b/examples/completion.rs index 6b7447d..653c326 100644 --- a/examples/completion.rs +++ b/examples/completion.rs @@ -7,7 +7,7 @@ async fn main() -> Result<(), Box> { let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let req = CompletionRequest { model: completion::GPT3_TEXT_DAVINCI_003.to_string(), - prompt: Some(String::from("What is Bitcoin?")), + prompt: String::from("What is Bitcoin?"), suffix: None, max_tokens: Some(3000), temperature: Some(0.9), diff --git a/examples/function_call.rs b/examples/function_call.rs index 9c76d85..9882f89 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -34,18 +34,18 @@ async fn main() -> Result<(), Box> { model: chat_completion::GPT3_5_TURBO_0613.to_string(), messages: vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, - content: Some(String::from("What is the price of Ethereum?")), + content: String::from("What is the price of Ethereum?"), name: None, function_call: None, }], functions: Some(vec![chat_completion::Function { name: String::from("get_coin_price"), description: Some(String::from("Get the price of a cryptocurrency")), - parameters: Some(chat_completion::FunctionParameters { + parameters: chat_completion::FunctionParameters { schema_type: chat_completion::JSONSchemaType::Object, properties: Some(properties), required: Some(vec![String::from("coin")]), - }), + }, }]), function_call: Some("auto".to_string()), temperature: None, diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index cc22cd2..df4ff13 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -34,18 +34,18 @@ async fn main() -> Result<(), Box> { model: chat_completion::GPT3_5_TURBO_0613.to_string(), messages: vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, - content: Some(String::from("What is the price of Ethereum?")), + content: String::from("What is the price of Ethereum?"), name: None, function_call: None, }], functions: Some(vec![chat_completion::Function { name: String::from("get_coin_price"), description: Some(String::from("Get the price of a cryptocurrency")), - parameters: Some(chat_completion::FunctionParameters { + parameters: chat_completion::FunctionParameters { schema_type: chat_completion::JSONSchemaType::Object, properties: Some(properties), required: Some(vec![String::from("coin")]), - }), + }, }]), function_call: None, temperature: None, @@ -86,13 +86,13 @@ async fn main() -> Result<(), Box> { messages: vec![ chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, - content: Some(String::from("What is the price of Ethereum?")), + content: String::from("What is the price of Ethereum?"), name: None, function_call: None, }, chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::function, - content: Some({ + content: ({ let price = get_coin_price(&coin).await; format!("{{\"price\": {}}}", price) }), diff --git a/src/v1/audio.rs b/src/v1/audio.rs index 1948dd3..50e91cb 100644 --- a/src/v1/audio.rs +++ b/src/v1/audio.rs @@ -4,8 +4,8 @@ pub const WHISPER_1: &str = "whisper-1"; #[derive(Debug, Serialize)] pub struct AudioTranscriptionRequest { - pub model: String, pub file: String, + pub model: String, #[serde(skip_serializing_if = "Option::is_none")] pub prompt: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -23,8 +23,8 @@ pub struct AudioTranscriptionResponse { #[derive(Debug, Serialize)] pub struct AudioTranslationRequest { - pub model: String, pub file: String, + pub model: String, #[serde(skip_serializing_if = "Option::is_none")] pub prompt: Option, #[serde(skip_serializing_if = "Option::is_none")] diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 7f20a84..690f3f3 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -43,7 +43,7 @@ pub struct ChatCompletionRequest { pub user: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] #[allow(non_camel_case_types)] pub enum MessageRole { user, @@ -52,11 +52,10 @@ pub enum MessageRole { function, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct ChatCompletionMessage { pub role: MessageRole, - #[serde(skip_serializing_if = "Option::is_none")] - pub content: Option, + pub content: String, #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -85,8 +84,7 @@ pub struct Function { pub name: String, #[serde(skip_serializing_if = "Option::is_none")] pub description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub parameters: Option, + pub parameters: FunctionParameters, } #[derive(Debug, Serialize, Deserialize)] @@ -136,7 +134,7 @@ pub enum FinishReason { null, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct FunctionCall { #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, diff --git a/src/v1/completion.rs b/src/v1/completion.rs index 4e22d9f..3781294 100644 --- a/src/v1/completion.rs +++ b/src/v1/completion.rs @@ -20,8 +20,7 @@ pub const GPT3_BABBAGE: &str = "babbage"; #[derive(Debug, Serialize)] pub struct CompletionRequest { pub model: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt: Option, + pub prompt: String, #[serde(skip_serializing_if = "Option::is_none")] pub suffix: Option, #[serde(skip_serializing_if = "Option::is_none")] diff --git a/src/v1/moderation.rs b/src/v1/moderation.rs index 58bc861..78b4f29 100644 --- a/src/v1/moderation.rs +++ b/src/v1/moderation.rs @@ -3,6 +3,8 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize)] pub struct CreateModerationRequest { pub input: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, } #[derive(Debug, Deserialize)] From 5ee3321ed095831cefd773d8a6924a1ab9951703 Mon Sep 17 00:00:00 2001 From: Dongri Jin Date: Tue, 25 Jul 2023 12:02:34 +0900 Subject: [PATCH 2/2] Fix message --- examples/function_call_role.rs | 4 ++-- src/v1/chat_completion.rs | 13 ++++++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index df4ff13..2df42ce 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -92,10 +92,10 @@ async fn main() -> Result<(), Box> { }, chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::function, - content: ({ + content: { let price = get_coin_price(&coin).await; format!("{{\"price\": {}}}", price) - }), + }, name: Some(String::from("get_coin_price")), function_call: None, }, diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 690f3f3..6de6381 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -62,10 +62,21 @@ pub struct ChatCompletionMessage { pub function_call: Option, } +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatCompletionMessageForResponse { + pub role: MessageRole, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, +} + #[derive(Debug, Deserialize)] pub struct ChatCompletionChoice { pub index: i64, - pub message: ChatCompletionMessage, + pub message: ChatCompletionMessageForResponse, pub finish_reason: FinishReason, }