mirror of
https://github.com/mii443/openai-api-rs.git
synced 2025-08-22 15:15:34 +00:00
@ -12,7 +12,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
role: chat_completion::MessageRole::user,
|
||||
content: String::from("What is Bitcoin?"),
|
||||
name: None,
|
||||
function_call: None,
|
||||
}],
|
||||
);
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
use openai_api_rs::v1::api::Client;
|
||||
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest, FunctionCallType};
|
||||
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
|
||||
use openai_api_rs::v1::common::GPT3_5_TURBO_0613;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
@ -33,19 +33,21 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
role: chat_completion::MessageRole::user,
|
||||
content: String::from("What is the price of Ethereum?"),
|
||||
name: None,
|
||||
function_call: None,
|
||||
}],
|
||||
)
|
||||
.functions(vec![chat_completion::Function {
|
||||
name: String::from("get_coin_price"),
|
||||
description: Some(String::from("Get the price of a cryptocurrency")),
|
||||
parameters: chat_completion::FunctionParameters {
|
||||
schema_type: chat_completion::JSONSchemaType::Object,
|
||||
properties: Some(properties),
|
||||
required: Some(vec![String::from("coin")]),
|
||||
.tools(vec![chat_completion::Tool {
|
||||
r#type: chat_completion::ToolType::Function,
|
||||
function: chat_completion::Function {
|
||||
name: String::from("get_coin_price"),
|
||||
description: Some(String::from("Get the price of a cryptocurrency")),
|
||||
parameters: chat_completion::FunctionParameters {
|
||||
schema_type: chat_completion::JSONSchemaType::Object,
|
||||
properties: Some(properties),
|
||||
required: Some(vec![String::from("coin")]),
|
||||
},
|
||||
},
|
||||
}])
|
||||
.function_call(FunctionCallType::Auto);
|
||||
.tool_choice(chat_completion::ToolChoiceType::Auto);
|
||||
|
||||
// debug request json
|
||||
// let serialized = serde_json::to_string(&req).unwrap();
|
||||
@ -65,20 +67,22 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
Some(chat_completion::FinishReason::length) => {
|
||||
println!("Length");
|
||||
}
|
||||
Some(chat_completion::FinishReason::function_call) => {
|
||||
println!("FunctionCall");
|
||||
Some(chat_completion::FinishReason::tool_calls) => {
|
||||
println!("ToolCalls");
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Currency {
|
||||
coin: String,
|
||||
}
|
||||
let function_call = result.choices[0].message.function_call.as_ref().unwrap();
|
||||
let name = function_call.name.clone().unwrap();
|
||||
let arguments = function_call.arguments.clone().unwrap();
|
||||
let c: Currency = serde_json::from_str(&arguments)?;
|
||||
let coin = c.coin;
|
||||
if name == "get_coin_price" {
|
||||
let price = get_coin_price(&coin);
|
||||
println!("{} price: {}", coin, price);
|
||||
let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();
|
||||
for tool_call in tool_calls {
|
||||
let name = tool_call.function.name.clone().unwrap();
|
||||
let arguments = tool_call.function.arguments.clone().unwrap();
|
||||
let c: Currency = serde_json::from_str(&arguments)?;
|
||||
let coin = c.coin;
|
||||
if name == "get_coin_price" {
|
||||
let price = get_coin_price(&coin);
|
||||
println!("{} price: {}", coin, price);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(chat_completion::FinishReason::content_filter) => {
|
||||
|
@ -33,16 +33,18 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
role: chat_completion::MessageRole::user,
|
||||
content: String::from("What is the price of Ethereum?"),
|
||||
name: None,
|
||||
function_call: None,
|
||||
}],
|
||||
)
|
||||
.functions(vec![chat_completion::Function {
|
||||
name: String::from("get_coin_price"),
|
||||
description: Some(String::from("Get the price of a cryptocurrency")),
|
||||
parameters: chat_completion::FunctionParameters {
|
||||
schema_type: chat_completion::JSONSchemaType::Object,
|
||||
properties: Some(properties),
|
||||
required: Some(vec![String::from("coin")]),
|
||||
.tools(vec![chat_completion::Tool {
|
||||
r#type: chat_completion::ToolType::Function,
|
||||
function: chat_completion::Function {
|
||||
name: String::from("get_coin_price"),
|
||||
description: Some(String::from("Get the price of a cryptocurrency")),
|
||||
parameters: chat_completion::FunctionParameters {
|
||||
schema_type: chat_completion::JSONSchemaType::Object,
|
||||
properties: Some(properties),
|
||||
required: Some(vec![String::from("coin")]),
|
||||
},
|
||||
},
|
||||
}]);
|
||||
|
||||
@ -60,40 +62,44 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
Some(chat_completion::FinishReason::length) => {
|
||||
println!("Length");
|
||||
}
|
||||
Some(chat_completion::FinishReason::function_call) => {
|
||||
println!("FunctionCall");
|
||||
Some(chat_completion::FinishReason::tool_calls) => {
|
||||
println!("ToolCalls");
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct Currency {
|
||||
coin: String,
|
||||
}
|
||||
let function_call = result.choices[0].message.function_call.as_ref().unwrap();
|
||||
let arguments = function_call.arguments.clone().unwrap();
|
||||
let c: Currency = serde_json::from_str(&arguments)?;
|
||||
let coin = c.coin;
|
||||
let tool_calls = result.choices[0].message.tool_calls.as_ref().unwrap();
|
||||
for tool_call in tool_calls {
|
||||
let function_call = &tool_call.function;
|
||||
let arguments = function_call.arguments.clone().unwrap();
|
||||
let c: Currency = serde_json::from_str(&arguments)?;
|
||||
let coin = c.coin;
|
||||
println!("coin: {}", coin);
|
||||
let price = get_coin_price(&coin);
|
||||
println!("price: {}", price);
|
||||
|
||||
let req = ChatCompletionRequest::new(
|
||||
GPT3_5_TURBO_0613.to_string(),
|
||||
vec![
|
||||
chat_completion::ChatCompletionMessage {
|
||||
role: chat_completion::MessageRole::user,
|
||||
content: String::from("What is the price of Ethereum?"),
|
||||
name: None,
|
||||
function_call: None,
|
||||
},
|
||||
chat_completion::ChatCompletionMessage {
|
||||
role: chat_completion::MessageRole::function,
|
||||
content: {
|
||||
let price = get_coin_price(&coin);
|
||||
format!("{{\"price\": {}}}", price)
|
||||
let req = ChatCompletionRequest::new(
|
||||
GPT3_5_TURBO_0613.to_string(),
|
||||
vec![
|
||||
chat_completion::ChatCompletionMessage {
|
||||
role: chat_completion::MessageRole::user,
|
||||
content: String::from("What is the price of Ethereum?"),
|
||||
name: None,
|
||||
},
|
||||
name: Some(String::from("get_coin_price")),
|
||||
function_call: None,
|
||||
},
|
||||
],
|
||||
);
|
||||
chat_completion::ChatCompletionMessage {
|
||||
role: chat_completion::MessageRole::function,
|
||||
content: {
|
||||
let price = get_coin_price(&coin);
|
||||
format!("{{\"price\": {}}}", price)
|
||||
},
|
||||
name: Some(String::from("get_coin_price")),
|
||||
},
|
||||
],
|
||||
);
|
||||
|
||||
let result = client.chat_completion(req)?;
|
||||
println!("{:?}", result.choices[0].message.content);
|
||||
let result = client.chat_completion(req)?;
|
||||
println!("{:?}", result.choices[0].message.content);
|
||||
}
|
||||
}
|
||||
Some(chat_completion::FinishReason::content_filter) => {
|
||||
println!("ContentFilter");
|
||||
|
@ -6,13 +6,6 @@ use std::collections::HashMap;
|
||||
use crate::impl_builder_methods;
|
||||
use crate::v1::common;
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub enum FunctionCallType {
|
||||
None,
|
||||
Auto,
|
||||
Function { name: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
pub enum ToolChoiceType {
|
||||
None,
|
||||
@ -25,19 +18,6 @@ pub struct ChatCompletionRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatCompletionMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[deprecated(
|
||||
since = "2.1.5",
|
||||
note = "This field is deprecated. Use `tools` instead."
|
||||
)]
|
||||
pub functions: Option<Vec<Function>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(serialize_with = "serialize_function_call")]
|
||||
#[deprecated(
|
||||
since = "2.1.5",
|
||||
note = "This field is deprecated. Use `tool_choice` instead."
|
||||
)]
|
||||
pub function_call: Option<FunctionCallType>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f64>,
|
||||
@ -73,8 +53,6 @@ impl ChatCompletionRequest {
|
||||
Self {
|
||||
model,
|
||||
messages,
|
||||
functions: None,
|
||||
function_call: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
stream: None,
|
||||
@ -95,8 +73,6 @@ impl ChatCompletionRequest {
|
||||
|
||||
impl_builder_methods!(
|
||||
ChatCompletionRequest,
|
||||
functions: Vec<Function>,
|
||||
function_call: FunctionCallType,
|
||||
temperature: f64,
|
||||
top_p: f64,
|
||||
n: i64,
|
||||
@ -128,8 +104,6 @@ pub struct ChatCompletionMessage {
|
||||
pub content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_call: Option<FunctionCall>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@ -140,7 +114,7 @@ pub struct ChatCompletionMessageForResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_call: Option<FunctionCall>,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@ -212,8 +186,8 @@ pub struct FunctionParameters {
|
||||
pub enum FinishReason {
|
||||
stop,
|
||||
length,
|
||||
function_call,
|
||||
content_filter,
|
||||
tool_calls,
|
||||
null,
|
||||
}
|
||||
|
||||
@ -225,32 +199,20 @@ pub struct FinishDetails {
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FunctionCall {
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
pub r#type: String,
|
||||
pub function: ToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ToolCallFunction {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
fn serialize_function_call<S>(
|
||||
value: &Option<FunctionCallType>,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
match value {
|
||||
Some(FunctionCallType::None) => serializer.serialize_str("none"),
|
||||
Some(FunctionCallType::Auto) => serializer.serialize_str("auto"),
|
||||
Some(FunctionCallType::Function { name }) => {
|
||||
let mut map = serializer.serialize_map(Some(1))?;
|
||||
map.serialize_entry("name", name)?;
|
||||
map.end()
|
||||
}
|
||||
None => serializer.serialize_none(),
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_tool_choice<S>(
|
||||
value: &Option<ToolChoiceType>,
|
||||
serializer: S,
|
||||
|
Reference in New Issue
Block a user