Fix Deprecated

This commit is contained in:
Dongri Jin
2024-01-07 07:39:36 +09:00
parent f0973eaf5b
commit 5d7d335c74
4 changed files with 75 additions and 104 deletions

View File

@ -12,7 +12,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
role: chat_completion::MessageRole::user, role: chat_completion::MessageRole::user,
content: String::from("What is Bitcoin?"), content: String::from("What is Bitcoin?"),
name: None, name: None,
function_call: None,
}], }],
); );

View File

@ -1,5 +1,5 @@
use openai_api_rs::v1::api::Client; use openai_api_rs::v1::api::Client;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest, FunctionCallType}; use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
use openai_api_rs::v1::common::GPT3_5_TURBO_0613; use openai_api_rs::v1::common::GPT3_5_TURBO_0613;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
@ -33,19 +33,21 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
role: chat_completion::MessageRole::user, role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"), content: String::from("What is the price of Ethereum?"),
name: None, name: None,
function_call: None,
}], }],
) )
.functions(vec![chat_completion::Function { .tools(vec![chat_completion::Tool {
name: String::from("get_coin_price"), r#type: chat_completion::ToolType::Function,
description: Some(String::from("Get the price of a cryptocurrency")), function: chat_completion::Function {
parameters: chat_completion::FunctionParameters { name: String::from("get_coin_price"),
schema_type: chat_completion::JSONSchemaType::Object, description: Some(String::from("Get the price of a cryptocurrency")),
properties: Some(properties), parameters: chat_completion::FunctionParameters {
required: Some(vec![String::from("coin")]), schema_type: chat_completion::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
}, },
}]) }])
.function_call(FunctionCallType::Auto); .tool_choice(chat_completion::ToolChoiceType::Auto);
// debug request json // debug request json
// let serialized = serde_json::to_string(&req).unwrap(); // let serialized = serde_json::to_string(&req).unwrap();
@ -65,20 +67,22 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Some(chat_completion::FinishReason::length) => { Some(chat_completion::FinishReason::length) => {
println!("Length"); println!("Length");
} }
Some(chat_completion::FinishReason::function_call) => { Some(chat_completion::FinishReason::tool_calls) => {
println!("FunctionCall"); println!("ToolCalls");
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct Currency { struct Currency {
coin: String, coin: String,
} }
let function_call = result.choices[0].message.function_call.as_ref().unwrap(); let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();
let name = function_call.name.clone().unwrap(); for tool_call in tool_calls {
let arguments = function_call.arguments.clone().unwrap(); let name = tool_call.function.name.clone().unwrap();
let c: Currency = serde_json::from_str(&arguments)?; let arguments = tool_call.function.arguments.clone().unwrap();
let coin = c.coin; let c: Currency = serde_json::from_str(&arguments)?;
if name == "get_coin_price" { let coin = c.coin;
let price = get_coin_price(&coin); if name == "get_coin_price" {
println!("{} price: {}", coin, price); let price = get_coin_price(&coin);
println!("{} price: {}", coin, price);
}
} }
} }
Some(chat_completion::FinishReason::content_filter) => { Some(chat_completion::FinishReason::content_filter) => {

View File

@ -33,16 +33,18 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
role: chat_completion::MessageRole::user, role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"), content: String::from("What is the price of Ethereum?"),
name: None, name: None,
function_call: None,
}], }],
) )
.functions(vec![chat_completion::Function { .tools(vec![chat_completion::Tool {
name: String::from("get_coin_price"), r#type: chat_completion::ToolType::Function,
description: Some(String::from("Get the price of a cryptocurrency")), function: chat_completion::Function {
parameters: chat_completion::FunctionParameters { name: String::from("get_coin_price"),
schema_type: chat_completion::JSONSchemaType::Object, description: Some(String::from("Get the price of a cryptocurrency")),
properties: Some(properties), parameters: chat_completion::FunctionParameters {
required: Some(vec![String::from("coin")]), schema_type: chat_completion::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
}, },
}]); }]);
@ -60,40 +62,44 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Some(chat_completion::FinishReason::length) => { Some(chat_completion::FinishReason::length) => {
println!("Length"); println!("Length");
} }
Some(chat_completion::FinishReason::function_call) => { Some(chat_completion::FinishReason::tool_calls) => {
println!("FunctionCall"); println!("ToolCalls");
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct Currency { struct Currency {
coin: String, coin: String,
} }
let function_call = result.choices[0].message.function_call.as_ref().unwrap(); let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();
let arguments = function_call.arguments.clone().unwrap(); for tool_call in tool_calls {
let c: Currency = serde_json::from_str(&arguments)?; let function_call = &tool_call.function;
let coin = c.coin; let arguments = function_call.arguments.clone().unwrap();
let c: Currency = serde_json::from_str(&arguments)?;
let coin = c.coin;
println!("coin: {}", coin);
let price = get_coin_price(&coin);
println!("price: {}", price);
let req = ChatCompletionRequest::new( let req = ChatCompletionRequest::new(
GPT3_5_TURBO_0613.to_string(), GPT3_5_TURBO_0613.to_string(),
vec![ vec![
chat_completion::ChatCompletionMessage { chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::user, role: chat_completion::MessageRole::user,
content: String::from("What is the price of Ethereum?"), content: String::from("What is the price of Ethereum?"),
name: None, name: None,
function_call: None,
},
chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::function,
content: {
let price = get_coin_price(&coin);
format!("{{\"price\": {}}}", price)
}, },
name: Some(String::from("get_coin_price")), chat_completion::ChatCompletionMessage {
function_call: None, role: chat_completion::MessageRole::function,
}, content: {
], let price = get_coin_price(&coin);
); format!("{{\"price\": {}}}", price)
},
name: Some(String::from("get_coin_price")),
},
],
);
let result = client.chat_completion(req)?; let result = client.chat_completion(req)?;
println!("{:?}", result.choices[0].message.content); println!("{:?}", result.choices[0].message.content);
}
} }
Some(chat_completion::FinishReason::content_filter) => { Some(chat_completion::FinishReason::content_filter) => {
println!("ContentFilter"); println!("ContentFilter");

View File

@ -6,13 +6,6 @@ use std::collections::HashMap;
use crate::impl_builder_methods; use crate::impl_builder_methods;
use crate::v1::common; use crate::v1::common;
#[derive(Debug, Serialize, Clone)]
pub enum FunctionCallType {
None,
Auto,
Function { name: String },
}
#[derive(Debug, Serialize, Clone)] #[derive(Debug, Serialize, Clone)]
pub enum ToolChoiceType { pub enum ToolChoiceType {
None, None,
@ -25,19 +18,6 @@ pub struct ChatCompletionRequest {
pub model: String, pub model: String,
pub messages: Vec<ChatCompletionMessage>, pub messages: Vec<ChatCompletionMessage>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[deprecated(
since = "2.1.5",
note = "This field is deprecated. Use `tools` instead."
)]
pub functions: Option<Vec<Function>>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(serialize_with = "serialize_function_call")]
#[deprecated(
since = "2.1.5",
note = "This field is deprecated. Use `tool_choice` instead."
)]
pub function_call: Option<FunctionCallType>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>, pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>, pub top_p: Option<f64>,
@ -73,8 +53,6 @@ impl ChatCompletionRequest {
Self { Self {
model, model,
messages, messages,
functions: None,
function_call: None,
temperature: None, temperature: None,
top_p: None, top_p: None,
stream: None, stream: None,
@ -95,8 +73,6 @@ impl ChatCompletionRequest {
impl_builder_methods!( impl_builder_methods!(
ChatCompletionRequest, ChatCompletionRequest,
functions: Vec<Function>,
function_call: FunctionCallType,
temperature: f64, temperature: f64,
top_p: f64, top_p: f64,
n: i64, n: i64,
@ -128,8 +104,6 @@ pub struct ChatCompletionMessage {
pub content: String, pub content: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>, pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@ -140,7 +114,7 @@ pub struct ChatCompletionMessageForResponse {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>, pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>, pub tool_calls: Option<Vec<ToolCall>>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@ -212,8 +186,8 @@ pub struct FunctionParameters {
pub enum FinishReason { pub enum FinishReason {
stop, stop,
length, length,
function_call,
content_filter, content_filter,
tool_calls,
null, null,
} }
@ -225,32 +199,20 @@ pub struct FinishDetails {
} }
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionCall { pub struct ToolCall {
pub id: String,
pub r#type: String,
pub function: ToolCallFunction,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCallFunction {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>, pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>, pub arguments: Option<String>,
} }
fn serialize_function_call<S>(
value: &Option<FunctionCallType>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(FunctionCallType::None) => serializer.serialize_str("none"),
Some(FunctionCallType::Auto) => serializer.serialize_str("auto"),
Some(FunctionCallType::Function { name }) => {
let mut map = serializer.serialize_map(Some(1))?;
map.serialize_entry("name", name)?;
map.end()
}
None => serializer.serialize_none(),
}
}
fn serialize_tool_choice<S>( fn serialize_tool_choice<S>(
value: &Option<ToolChoiceType>, value: &Option<ToolChoiceType>,
serializer: S, serializer: S,