From 5d7d335c74d59a256c53eb09fe4fca701b1e33b3 Mon Sep 17 00:00:00 2001 From: Dongri Jin Date: Sun, 7 Jan 2024 07:39:36 +0900 Subject: [PATCH] Fix Deprecated --- examples/chat_completion.rs | 1 - examples/function_call.rs | 44 +++++++++++--------- examples/function_call_role.rs | 76 ++++++++++++++++++---------------- src/v1/chat_completion.rs | 58 +++++--------------------- 4 files changed, 75 insertions(+), 104 deletions(-) diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index 3135fc0..b80a1b9 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -12,7 +12,6 @@ fn main() -> Result<(), Box> { role: chat_completion::MessageRole::user, content: String::from("What is Bitcoin?"), name: None, - function_call: None, }], ); diff --git a/examples/function_call.rs b/examples/function_call.rs index 367fb1e..12e6a2d 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -1,5 +1,5 @@ use openai_api_rs::v1::api::Client; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest, FunctionCallType}; +use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; use openai_api_rs::v1::common::GPT3_5_TURBO_0613; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -33,19 +33,21 @@ fn main() -> Result<(), Box> { role: chat_completion::MessageRole::user, content: String::from("What is the price of Ethereum?"), name: None, - function_call: None, }], ) - .functions(vec![chat_completion::Function { - name: String::from("get_coin_price"), - description: Some(String::from("Get the price of a cryptocurrency")), - parameters: chat_completion::FunctionParameters { - schema_type: chat_completion::JSONSchemaType::Object, - properties: Some(properties), - required: Some(vec![String::from("coin")]), + .tools(vec![chat_completion::Tool { + r#type: chat_completion::ToolType::Function, + function: chat_completion::Function { + name: String::from("get_coin_price"), + description: Some(String::from("Get the price of a cryptocurrency")), + parameters: chat_completion::FunctionParameters { + schema_type: chat_completion::JSONSchemaType::Object, + properties: Some(properties), + required: Some(vec![String::from("coin")]), + }, }, }]) - .function_call(FunctionCallType::Auto); + .tool_choice(chat_completion::ToolChoiceType::Auto); // debug request json // let serialized = serde_json::to_string(&req).unwrap(); @@ -65,20 +67,22 @@ fn main() -> Result<(), Box> { Some(chat_completion::FinishReason::length) => { println!("Length"); } - Some(chat_completion::FinishReason::function_call) => { - println!("FunctionCall"); + Some(chat_completion::FinishReason::tool_calls) => { + println!("ToolCalls"); #[derive(Serialize, Deserialize)] struct Currency { coin: String, } - let function_call = result.choices[0].message.function_call.as_ref().unwrap(); - let name = function_call.name.clone().unwrap(); - let arguments = function_call.arguments.clone().unwrap(); - let c: Currency = serde_json::from_str(&arguments)?; - let coin = c.coin; - if name == "get_coin_price" { - let price = get_coin_price(&coin); - println!("{} price: {}", coin, price); + let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap(); + for tool_call in tool_calls { + let name = tool_call.function.name.clone().unwrap(); + let arguments = tool_call.function.arguments.clone().unwrap(); + let c: Currency = serde_json::from_str(&arguments)?; + let coin = c.coin; + if name == "get_coin_price" { + let price = get_coin_price(&coin); + println!("{} price: {}", coin, price); + } } } Some(chat_completion::FinishReason::content_filter) => { diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs index 0e486c1..9091da2 100644 --- a/examples/function_call_role.rs +++ b/examples/function_call_role.rs @@ -33,16 +33,18 @@ fn main() -> Result<(), Box> { role: chat_completion::MessageRole::user, content: String::from("What is the price of Ethereum?"), name: None, - function_call: None, }], ) - .functions(vec![chat_completion::Function { - name: String::from("get_coin_price"), - description: Some(String::from("Get the price of a cryptocurrency")), - parameters: chat_completion::FunctionParameters { - schema_type: chat_completion::JSONSchemaType::Object, - properties: Some(properties), - required: Some(vec![String::from("coin")]), + .tools(vec![chat_completion::Tool { + r#type: chat_completion::ToolType::Function, + function: chat_completion::Function { + name: String::from("get_coin_price"), + description: Some(String::from("Get the price of a cryptocurrency")), + parameters: chat_completion::FunctionParameters { + schema_type: chat_completion::JSONSchemaType::Object, + properties: Some(properties), + required: Some(vec![String::from("coin")]), + }, }, }]); @@ -60,40 +62,44 @@ fn main() -> Result<(), Box> { Some(chat_completion::FinishReason::length) => { println!("Length"); } - Some(chat_completion::FinishReason::function_call) => { - println!("FunctionCall"); + Some(chat_completion::FinishReason::tool_calls) => { + println!("ToolCalls"); #[derive(Serialize, Deserialize)] struct Currency { coin: String, } - let function_call = result.choices[0].message.function_call.as_ref().unwrap(); - let arguments = function_call.arguments.clone().unwrap(); - let c: Currency = serde_json::from_str(&arguments)?; - let coin = c.coin; + let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap(); + for tool_call in tool_calls { + let function_call = &tool_call.function; + let arguments = function_call.arguments.clone().unwrap(); + let c: Currency = serde_json::from_str(&arguments)?; + let coin = c.coin; + println!("coin: {}", coin); + let price = get_coin_price(&coin); + println!("price: {}", price); - let req = ChatCompletionRequest::new( - GPT3_5_TURBO_0613.to_string(), - vec![ - chat_completion::ChatCompletionMessage { - role: chat_completion::MessageRole::user, - content: String::from("What is the price of Ethereum?"), - name: None, - function_call: None, - }, - chat_completion::ChatCompletionMessage { - role: chat_completion::MessageRole::function, - content: { - let price = get_coin_price(&coin); - format!("{{\"price\": {}}}", price) + let req = ChatCompletionRequest::new( + GPT3_5_TURBO_0613.to_string(), + vec![ + chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: String::from("What is the price of Ethereum?"), + name: None, }, - name: Some(String::from("get_coin_price")), - function_call: None, - }, - ], - ); + chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::function, + content: { + let price = get_coin_price(&coin); + format!("{{\"price\": {}}}", price) + }, + name: Some(String::from("get_coin_price")), + }, + ], + ); - let result = client.chat_completion(req)?; - println!("{:?}", result.choices[0].message.content); + let result = client.chat_completion(req)?; + println!("{:?}", result.choices[0].message.content); + } } Some(chat_completion::FinishReason::content_filter) => { println!("ContentFilter"); diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index ce23c08..16715c7 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -6,13 +6,6 @@ use std::collections::HashMap; use crate::impl_builder_methods; use crate::v1::common; -#[derive(Debug, Serialize, Clone)] -pub enum FunctionCallType { - None, - Auto, - Function { name: String }, -} - #[derive(Debug, Serialize, Clone)] pub enum ToolChoiceType { None, @@ -25,19 +18,6 @@ pub struct ChatCompletionRequest { pub model: String, pub messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] - #[deprecated( - since = "2.1.5", - note = "This field is deprecated. Use `tools` instead." - )] - pub functions: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(serialize_with = "serialize_function_call")] - #[deprecated( - since = "2.1.5", - note = "This field is deprecated. Use `tool_choice` instead." - )] - pub function_call: Option, - #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] pub top_p: Option, @@ -73,8 +53,6 @@ impl ChatCompletionRequest { Self { model, messages, - functions: None, - function_call: None, temperature: None, top_p: None, stream: None, @@ -95,8 +73,6 @@ impl ChatCompletionRequest { impl_builder_methods!( ChatCompletionRequest, - functions: Vec, - function_call: FunctionCallType, temperature: f64, top_p: f64, n: i64, @@ -128,8 +104,6 @@ pub struct ChatCompletionMessage { pub content: String, #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -140,7 +114,7 @@ pub struct ChatCompletionMessageForResponse { #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, + pub tool_calls: Option>, } #[derive(Debug, Deserialize)] @@ -212,8 +186,8 @@ pub struct FunctionParameters { pub enum FinishReason { stop, length, - function_call, content_filter, + tool_calls, null, } @@ -225,32 +199,20 @@ pub struct FinishDetails { } #[derive(Debug, Serialize, Deserialize, Clone)] -pub struct FunctionCall { +pub struct ToolCall { + pub id: String, + pub r#type: String, + pub function: ToolCallFunction, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ToolCallFunction { #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option, } -fn serialize_function_call( - value: &Option, - serializer: S, -) -> Result -where - S: Serializer, -{ - match value { - Some(FunctionCallType::None) => serializer.serialize_str("none"), - Some(FunctionCallType::Auto) => serializer.serialize_str("auto"), - Some(FunctionCallType::Function { name }) => { - let mut map = serializer.serialize_map(Some(1))?; - map.serialize_entry("name", name)?; - map.end() - } - None => serializer.serialize_none(), - } -} - fn serialize_tool_choice( value: &Option, serializer: S,