From 6292d785074abfbb0aaa724fdb437b666fb3a771 Mon Sep 17 00:00:00 2001 From: Dongri Jin Date: Mon, 19 Jun 2023 18:24:27 +0900 Subject: [PATCH] Support function calling --- Cargo.toml | 7 +-- README.md | 13 ++++-- examples/chat_completion.rs | 7 ++- examples/completion.rs | 2 +- examples/function_call.rs | 89 +++++++++++++++++++++++++++++++++++++ src/v1/chat_completion.rs | 81 ++++++++++++++++++++++++++++++++- 6 files changed, 188 insertions(+), 11 deletions(-) create mode 100644 examples/function_call.rs diff --git a/Cargo.toml b/Cargo.toml index 8b2c3d5..1f43e1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,14 +1,15 @@ [package] name = "openai-api-rs" -version = "0.1.7" +version = "0.1.8" edition = "2021" authors = ["Dongri Jin "] license = "MIT" -description = "OpenAI API wrapper for Rust" +description = "OpenAI API client library for Rust (unofficial)" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] reqwest = { version = "0.11", features = ["json"] } tokio = { version = "1", features = ["full"] } -serde = { version = "1", features = ["derive"] } \ No newline at end of file +serde = { version = "1", features = ["derive"] } +serde_json = "1.0.97" diff --git a/README.md b/README.md index 40ba7b9..e169f1d 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ # OpenAI API client library for Rust (unofficial) The OpenAI API client Rust library provides convenient access to the OpenAI API from Rust applications. -Check out the [docs.rs](https://docs.rs/openai-api-rs/0.1.7/openai_api_rs/v1/index.html). +Check out the [docs.rs](https://docs.rs/openai-api-rs/0.1.8/openai_api_rs/v1/index.html). ## Installation: Cargo.toml ```toml [dependencies] -openai-api-rs = "0.1.7" +openai-api-rs = "0.1.8" ``` ## Usage @@ -56,14 +56,20 @@ async fn main() -> Result<(), Box> { model: chat_completion::GPT4.to_string(), messages: vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, - content: String::from("Hello OpenAI!"), + content: Some(String::from("What is Bitcoin?")), + name: None, + function_call: None, }], + functions: None, + function_call: None, }; let result = client.chat_completion(req).await?; println!("{:?}", result.choices[0].message.content); Ok(()) } ``` +More Examples: [examples](https://github.com/dongri/openai-api-rs/tree/main/examples) + Check out the [full API documentation](https://platform.openai.com/docs/api-reference/completions) for examples of all the available functions. ## Supported APIs @@ -76,6 +82,7 @@ Check out the [full API documentation](https://platform.openai.com/docs/api-refe - [x] [Files](https://platform.openai.com/docs/api-reference/files) - [x] [Fine-tunes](https://platform.openai.com/docs/api-reference/fine-tunes) - [x] [Moderations](https://platform.openai.com/docs/api-reference/moderations) +- [x] [Function calling](https://platform.openai.com/docs/guides/gpt/function-calling) ## License This project is licensed under [MIT license](https://github.com/dongri/openai-api-rs/blob/main/LICENSE). diff --git a/examples/chat_completion.rs b/examples/chat_completion.rs index b87d63d..9326f1c 100644 --- a/examples/chat_completion.rs +++ b/examples/chat_completion.rs @@ -9,12 +9,15 @@ async fn main() -> Result<(), Box> { model: chat_completion::GPT4.to_string(), messages: vec![chat_completion::ChatCompletionMessage { role: chat_completion::MessageRole::user, - content: String::from("NFTとは?"), + content: Some(String::from("What is Bitcoin?")), + name: None, + function_call: None, }], + functions: None, + function_call: None, }; let result = client.chat_completion(req).await?; println!("{:?}", result.choices[0].message.content); - Ok(()) } diff --git a/examples/completion.rs b/examples/completion.rs index ecb20af..6b7447d 100644 --- a/examples/completion.rs +++ b/examples/completion.rs @@ -7,7 +7,7 @@ async fn main() -> Result<(), Box> { let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); let req = CompletionRequest { model: completion::GPT3_TEXT_DAVINCI_003.to_string(), - prompt: Some(String::from("NFTとは?")), + prompt: Some(String::from("What is Bitcoin?")), suffix: None, max_tokens: Some(3000), temperature: Some(0.9), diff --git a/examples/function_call.rs b/examples/function_call.rs new file mode 100644 index 0000000..de91b89 --- /dev/null +++ b/examples/function_call.rs @@ -0,0 +1,89 @@ +use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::{env, vec}; + +async fn get_coin_price(coin: &str) -> f64 { + let coin = coin.to_lowercase(); + match coin.as_str() { + "btc" | "bitcoin" => 10000.0, + "eth" | "ethereum" => 1000.0, + _ => 0.0, + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); + + let mut properties = HashMap::new(); + properties.insert( + "coin".to_string(), + Box::new(chat_completion::JSONSchemaDefine { + schema_type: Some(chat_completion::JSONSchemaType::String), + description: Some("The cryptocurrency to get the price of".to_string()), + enum_values: None, + properties: None, + required: None, + items: None, + }), + ); + + let req = ChatCompletionRequest { + model: chat_completion::GPT3_5_TURBO_0613.to_string(), + messages: vec![chat_completion::ChatCompletionMessage { + role: chat_completion::MessageRole::user, + content: Some(String::from("What is the price of Ethereum?")), + name: None, + function_call: None, + }], + functions: Some(vec![chat_completion::Function { + name: String::from("get_coin_price"), + description: Some(String::from("Get the price of a cryptocurrency")), + parameters: Some(chat_completion::FunctionParameters { + schema_type: chat_completion::JSONSchemaType::Object, + properties: Some(properties), + required: Some(vec![String::from("coin")]), + }), + }]), + function_call: Some("auto".to_string()), + }; + + let result = client.chat_completion(req).await?; + + match result.choices[0].finish_reason { + chat_completion::FinishReason::stop => { + println!("Stop"); + println!("{:?}", result.choices[0].message.content); + } + chat_completion::FinishReason::length => { + println!("Length"); + } + chat_completion::FinishReason::function_call => { + println!("FunctionCall"); + #[derive(Serialize, Deserialize)] + struct Currency { + coin: String, + } + let function_call = result.choices[0].message.function_call.as_ref().unwrap(); + let name = function_call.name.clone().unwrap(); + let arguments = function_call.arguments.clone().unwrap(); + let c: Currency = serde_json::from_str(&arguments)?; + let coin = c.coin; + if name == "get_coin_price" { + let price = get_coin_price(&coin).await; + println!("{} price: {}", coin, price); + } + } + chat_completion::FinishReason::content_filter => { + println!("ContentFilter"); + } + chat_completion::FinishReason::null => { + println!("Null"); + } + } + Ok(()) +} + +// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example function_call diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 22330ed..9233f0f 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -1,18 +1,26 @@ use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use crate::v1::common; pub const GPT3_5_TURBO: &str = "gpt-3.5-turbo"; pub const GPT3_5_TURBO_0301: &str = "gpt-3.5-turbo-0301"; +pub const GPT3_5_TURBO_0613: &str = "gpt-3.5-turbo-0613"; + pub const GPT4: &str = "gpt-4"; pub const GPT4_0314: &str = "gpt-4-0314"; pub const GPT4_32K: &str = "gpt-4-32k"; pub const GPT4_32K_0314: &str = "gpt-4-32k-0314"; +pub const GPT4_0613: &str = "gpt-4-0613"; #[derive(Debug, Serialize)] pub struct ChatCompletionRequest { pub model: String, pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub functions: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -26,14 +34,19 @@ pub enum MessageRole { #[derive(Debug, Serialize, Deserialize)] pub struct ChatCompletionMessage { pub role: MessageRole, - pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, } #[derive(Debug, Deserialize)] pub struct ChatCompletionChoice { pub index: i64, pub message: ChatCompletionMessage, - pub finish_reason: String, + pub finish_reason: FinishReason, } #[derive(Debug, Deserialize)] @@ -45,3 +58,67 @@ pub struct ChatCompletionResponse { pub choices: Vec, pub usage: common::Usage, } + +#[derive(Debug, Serialize, Deserialize)] +pub struct Function { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum JSONSchemaType { + Object, + Number, + String, + Array, + Null, + Boolean, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct JSONSchemaDefine { + #[serde(rename = "type")] + pub schema_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub enum_values: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub items: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FunctionParameters { + #[serde(rename = "type")] + pub schema_type: JSONSchemaType, + #[serde(skip_serializing_if = "Option::is_none")] + pub properties: Option>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +#[allow(non_camel_case_types)] +pub enum FinishReason { + stop, + length, + function_call, + content_filter, + null, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FunctionCall { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, +}