diff --git a/examples/function_call_role.rs b/examples/function_call_role.rs new file mode 100644 index 0000000..94a35c2 --- /dev/null +++ b/examples/function_call_role.rs @@ -0,0 +1,126 @@ +use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::{env, vec}; + +async fn get_coin_price(coin: &str) -> f64 { + let coin = coin.to_lowercase(); + match coin.as_str() { + "btc" | "bitcoin" => 10000.0, + "eth" | "ethereum" => 1000.0, + _ => 0.0, + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + + let mut properties = HashMap::new(); + properties.insert( + "coin".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::String), + description: Some("The cryptocurrency to get the price of".to_string()), + enum_values: None, + properties: None, + required: None, + items: None, + }), + ); + + let req = ChatCompletionRequest { + model: chat_completion::GPT3_5_TURBO_0613.to_string(), + messages: vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: Some(String::from("What is the price of Ethereum?")), + name: None, + function_call: None, + }], + functions: Some(vec![chat_completion::Function { + name: String::from("get_coin_price"), + description: Some(String::from("Get the price of a cryptocurrency")), + parameters: Some(chat_completion::FunctionParameters { + schema_type: chat_completion::JSONSchemaType::Object, + properties: Some(properties), + required: Some(vec![String::from("coin")]), + }), + }]), + function_call: None, + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + }; + + let result = client.chat_completion(req).await?; + + match result.choices[0].finish_reason { + chat_completion::FinishReason::stop => { + println!("Stop"); + println!("{:?}", result.choices[0].message.content); + } + chat_completion::FinishReason::length => { + println!("Length"); + } + chat_completion::FinishReason::function_call => { + println!("FunctionCall"); + #[derive(Serialize, Deserialize)] + struct Currency { + coin: String, + } + let function_call = result.choices[0].message.function_call.as_ref().unwrap(); + let arguments = function_call.arguments.clone().unwrap(); + let c: Currency = serde_json::from_str(&arguments)?; + let coin = c.coin; + + let req = ChatCompletionRequest { + model: chat_completion::GPT3_5_TURBO_0613.to_string(), + messages: vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: Some(String::from("What is the price of Ethereum?")), + name: None, + function_call: None, + }, chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::function, + content: Some({ + let price = get_coin_price(&coin).await; + format!("{{\"price\": {}}}", price) + }), + name: Some(String::from("get_coin_price")), + function_call: None, + }], + functions: None, + function_call: None, + temperature: None, + top_p: None, + n: None, + stream: None, + stop: None, + max_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + }; + let result = client.chat_completion(req).await?; + println!("{:?}", result.choices[0].message.content); + } + chat_completion::FinishReason::content_filter => { + println!("ContentFilter"); + } + chat_completion::FinishReason::null => { + println!("Null"); + } + } + Ok(()) +} + +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example function_call_role diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 94ac7a9..7f20a84 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -49,6 +49,7 @@ pub enum MessageRole { user, system, assistant, + function, } #[derive(Debug, Serialize, Deserialize)]