Support function calling

Author: Dongri Jin
Date:   2023-06-19 18:24:27 +09:00
Parent: 7ce3efe04f
Commit: 6292d78507
6 changed files with 188 additions and 11 deletions

Cargo.toml

@@ -1,10 +1,10 @@
 [package]
 name = "openai-api-rs"
-version = "0.1.7"
+version = "0.1.8"
 edition = "2021"
 authors = ["Dongri Jin <dongrify@gmail.com>"]
 license = "MIT"
-description = "OpenAI API wrapper for Rust"
+description = "OpenAI API client library for Rust (unofficial)"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -12,3 +12,4 @@ description = "OpenAI API wrapper for Rust"
 reqwest = { version = "0.11", features = ["json"] }
 tokio = { version = "1", features = ["full"] }
 serde = { version = "1", features = ["derive"] }
+serde_json = "1.0.97"

README.md

@@ -1,13 +1,13 @@
 # OpenAI API client library for Rust (unofficial)
 The OpenAI API client Rust library provides convenient access to the OpenAI API from Rust applications.
-Check out the [docs.rs](https://docs.rs/openai-api-rs/0.1.7/openai_api_rs/v1/index.html).
+Check out the [docs.rs](https://docs.rs/openai-api-rs/0.1.8/openai_api_rs/v1/index.html).
 ## Installation:
 Cargo.toml
 ```toml
 [dependencies]
-openai-api-rs = "0.1.7"
+openai-api-rs = "0.1.8"
 ```
 ## Usage
@@ -56,14 +56,20 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         model: chat_completion::GPT4.to_string(),
         messages: vec![chat_completion::ChatCompletionMessage {
             role: chat_completion::MessageRole::user,
-            content: String::from("Hello OpenAI!"),
+            content: Some(String::from("What is Bitcoin?")),
+            name: None,
+            function_call: None,
         }],
+        functions: None,
+        function_call: None,
     };
     let result = client.chat_completion(req).await?;
     println!("{:?}", result.choices[0].message.content);
     Ok(())
 }
 ```
+More Examples: [examples](https://github.com/dongri/openai-api-rs/tree/main/examples)
 Check out the [full API documentation](https://platform.openai.com/docs/api-reference/completions) for examples of all the available functions.
 ## Supported APIs
@@ -76,6 +82,7 @@ Check out the [full API documentation](https://platform.openai.com/docs/api-refe
 - [x] [Files](https://platform.openai.com/docs/api-reference/files)
 - [x] [Fine-tunes](https://platform.openai.com/docs/api-reference/fine-tunes)
 - [x] [Moderations](https://platform.openai.com/docs/api-reference/moderations)
+- [x] [Function calling](https://platform.openai.com/docs/guides/gpt/function-calling)
 ## License
 This project is licensed under [MIT license](https://github.com/dongri/openai-api-rs/blob/main/LICENSE).

examples/chat_completion.rs

@@ -9,12 +9,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
         model: chat_completion::GPT4.to_string(),
         messages: vec![chat_completion::ChatCompletionMessage {
             role: chat_completion::MessageRole::user,
-            content: String::from("NFTとは"),
+            content: Some(String::from("What is Bitcoin?")),
+            name: None,
+            function_call: None,
         }],
+        functions: None,
+        function_call: None,
     };
     let result = client.chat_completion(req).await?;
     println!("{:?}", result.choices[0].message.content);
     Ok(())
 }

examples/completion.rs

@@ -7,7 +7,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
     let req = CompletionRequest {
         model: completion::GPT3_TEXT_DAVINCI_003.to_string(),
-        prompt: Some(String::from("NFTとは")),
+        prompt: Some(String::from("What is Bitcoin?")),
         suffix: None,
         max_tokens: Some(3000),
         temperature: Some(0.9),

examples/function_call.rs (new file)

@@ -0,0 +1,89 @@
use openai_api_rs::v1::api::Client;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::{env, vec};

async fn get_coin_price(coin: &str) -> f64 {
    let coin = coin.to_lowercase();
    match coin.as_str() {
        "btc" | "bitcoin" => 10000.0,
        "eth" | "ethereum" => 1000.0,
        _ => 0.0,
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());

    let mut properties = HashMap::new();
    properties.insert(
        "coin".to_string(),
        Box::new(chat_completion::JSONSchemaDefine {
            schema_type: Some(chat_completion::JSONSchemaType::String),
            description: Some("The cryptocurrency to get the price of".to_string()),
            enum_values: None,
            properties: None,
            required: None,
            items: None,
        }),
    );

    let req = ChatCompletionRequest {
        model: chat_completion::GPT3_5_TURBO_0613.to_string(),
        messages: vec![chat_completion::ChatCompletionMessage {
            role: chat_completion::MessageRole::user,
            content: Some(String::from("What is the price of Ethereum?")),
            name: None,
            function_call: None,
        }],
        functions: Some(vec![chat_completion::Function {
            name: String::from("get_coin_price"),
            description: Some(String::from("Get the price of a cryptocurrency")),
            parameters: Some(chat_completion::FunctionParameters {
                schema_type: chat_completion::JSONSchemaType::Object,
                properties: Some(properties),
                required: Some(vec![String::from("coin")]),
            }),
        }]),
        function_call: Some("auto".to_string()),
    };

    let result = client.chat_completion(req).await?;

    match result.choices[0].finish_reason {
        chat_completion::FinishReason::stop => {
            println!("Stop");
            println!("{:?}", result.choices[0].message.content);
        }
        chat_completion::FinishReason::length => {
            println!("Length");
        }
        chat_completion::FinishReason::function_call => {
            println!("FunctionCall");
            #[derive(Serialize, Deserialize)]
            struct Currency {
                coin: String,
            }
            let function_call = result.choices[0].message.function_call.as_ref().unwrap();
            let name = function_call.name.clone().unwrap();
            let arguments = function_call.arguments.clone().unwrap();
            let c: Currency = serde_json::from_str(&arguments)?;
            let coin = c.coin;
            if name == "get_coin_price" {
                let price = get_coin_price(&coin).await;
                println!("{} price: {}", coin, price);
            }
        }
        chat_completion::FinishReason::content_filter => {
            println!("ContentFilter");
        }
        chat_completion::FinishReason::null => {
            println!("Null");
        }
    }

    Ok(())
}
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example function_call
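
For reference, the new `Function` / `FunctionParameters` / `JSONSchemaDefine` types all derive `Serialize`, and `serde_json` is now a dependency, so the JSON schema a function declaration produces can be inspected directly. The snippet below is a minimal sketch, not part of this commit; it only assumes the type definitions shown in the `src/v1/chat_completion.rs` diff further down and that `serde_json` is available in the consuming project.

```rust
use openai_api_rs::v1::chat_completion::{
    Function, FunctionParameters, JSONSchemaDefine, JSONSchemaType,
};
use std::collections::HashMap;

fn main() -> serde_json::Result<()> {
    // Same function declaration as in examples/function_call.rs above.
    let mut properties = HashMap::new();
    properties.insert(
        "coin".to_string(),
        Box::new(JSONSchemaDefine {
            schema_type: Some(JSONSchemaType::String),
            description: Some("The cryptocurrency to get the price of".to_string()),
            enum_values: None,
            properties: None,
            required: None,
            items: None,
        }),
    );
    let func = Function {
        name: "get_coin_price".to_string(),
        description: Some("Get the price of a cryptocurrency".to_string()),
        parameters: Some(FunctionParameters {
            schema_type: JSONSchemaType::Object,
            properties: Some(properties),
            required: Some(vec!["coin".to_string()]),
        }),
    };
    // Expected shape (field names follow the serde attributes in src/v1/chat_completion.rs):
    // {"name":"get_coin_price","description":"...","parameters":{"type":"object",
    //   "properties":{"coin":{"type":"string","description":"..."}},"required":["coin"]}}
    println!("{}", serde_json::to_string_pretty(&func)?);
    Ok(())
}
```

The `skip_serializing_if = "Option::is_none"` attributes mean unset optional fields are omitted from the payload rather than being sent as `null`.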

src/v1/chat_completion.rs

@@ -1,18 +1,26 @@
 use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 use crate::v1::common;
 pub const GPT3_5_TURBO: &str = "gpt-3.5-turbo";
 pub const GPT3_5_TURBO_0301: &str = "gpt-3.5-turbo-0301";
+pub const GPT3_5_TURBO_0613: &str = "gpt-3.5-turbo-0613";
 pub const GPT4: &str = "gpt-4";
 pub const GPT4_0314: &str = "gpt-4-0314";
 pub const GPT4_32K: &str = "gpt-4-32k";
 pub const GPT4_32K_0314: &str = "gpt-4-32k-0314";
+pub const GPT4_0613: &str = "gpt-4-0613";
 #[derive(Debug, Serialize)]
 pub struct ChatCompletionRequest {
     pub model: String,
     pub messages: Vec<ChatCompletionMessage>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub functions: Option<Vec<Function>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub function_call: Option<String>,
 }
 #[derive(Debug, Serialize, Deserialize)]
@@ -26,14 +34,19 @@ pub enum MessageRole
 #[derive(Debug, Serialize, Deserialize)]
 pub struct ChatCompletionMessage {
     pub role: MessageRole,
-    pub content: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub function_call: Option<FunctionCall>,
 }
 #[derive(Debug, Deserialize)]
 pub struct ChatCompletionChoice {
     pub index: i64,
     pub message: ChatCompletionMessage,
-    pub finish_reason: String,
+    pub finish_reason: FinishReason,
 }
 #[derive(Debug, Deserialize)]
@@ -45,3 +58,67 @@ pub struct ChatCompletionResponse {
     pub choices: Vec<ChatCompletionChoice>,
     pub usage: common::Usage,
 }
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Function {
+    pub name: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parameters: Option<FunctionParameters>,
+}
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum JSONSchemaType {
+    Object,
+    Number,
+    String,
+    Array,
+    Null,
+    Boolean,
+}
+#[derive(Debug, Serialize, Deserialize)]
+pub struct JSONSchemaDefine {
+    #[serde(rename = "type")]
+    pub schema_type: Option<JSONSchemaType>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub enum_values: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub required: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub items: Option<Box<JSONSchemaDefine>>,
+}
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionParameters {
+    #[serde(rename = "type")]
+    pub schema_type: JSONSchemaType,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub required: Option<Vec<String>>,
+}
+#[derive(Debug, Serialize, Deserialize)]
+#[allow(non_camel_case_types)]
+pub enum FinishReason {
+    stop,
+    length,
+    function_call,
+    content_filter,
+    null,
+}
+#[derive(Debug, Serialize, Deserialize)]
+pub struct FunctionCall {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub arguments: Option<String>,
+}