mirror of
https://github.com/mii443/openai-api-rs.git
synced 2025-08-22 23:25:39 +00:00
Support function calling
This commit is contained in:
@ -1,10 +1,10 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "openai-api-rs"
|
name = "openai-api-rs"
|
||||||
version = "0.1.7"
|
version = "0.1.8"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["Dongri Jin <dongrify@gmail.com>"]
|
authors = ["Dongri Jin <dongrify@gmail.com>"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
description = "OpenAI API wrapper for Rust"
|
description = "OpenAI API client library for Rust (unofficial)"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
@ -12,3 +12,4 @@ description = "OpenAI API wrapper for Rust"
|
|||||||
reqwest = { version = "0.11", features = ["json"] }
|
reqwest = { version = "0.11", features = ["json"] }
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1.0.97"
|
||||||
|
13
README.md
13
README.md
@ -1,13 +1,13 @@
|
|||||||
# OpenAI API client library for Rust (unofficial)
|
# OpenAI API client library for Rust (unofficial)
|
||||||
The OpenAI API client Rust library provides convenient access to the OpenAI API from Rust applications.
|
The OpenAI API client Rust library provides convenient access to the OpenAI API from Rust applications.
|
||||||
|
|
||||||
Check out the [docs.rs](https://docs.rs/openai-api-rs/0.1.7/openai_api_rs/v1/index.html).
|
Check out the [docs.rs](https://docs.rs/openai-api-rs/0.1.8/openai_api_rs/v1/index.html).
|
||||||
|
|
||||||
## Installation:
|
## Installation:
|
||||||
Cargo.toml
|
Cargo.toml
|
||||||
```toml
|
```toml
|
||||||
[dependencies]
|
[dependencies]
|
||||||
openai-api-rs = "0.1.7"
|
openai-api-rs = "0.1.8"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
@ -56,14 +56,20 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
model: chat_completion::GPT4.to_string(),
|
model: chat_completion::GPT4.to_string(),
|
||||||
messages: vec![chat_completion::ChatCompletionMessage {
|
messages: vec![chat_completion::ChatCompletionMessage {
|
||||||
role: chat_completion::MessageRole::user,
|
role: chat_completion::MessageRole::user,
|
||||||
content: String::from("Hello OpenAI!"),
|
content: Some(String::from("What is Bitcoin?")),
|
||||||
|
name: None,
|
||||||
|
function_call: None,
|
||||||
}],
|
}],
|
||||||
|
functions: None,
|
||||||
|
function_call: None,
|
||||||
};
|
};
|
||||||
let result = client.chat_completion(req).await?;
|
let result = client.chat_completion(req).await?;
|
||||||
println!("{:?}", result.choices[0].message.content);
|
println!("{:?}", result.choices[0].message.content);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
More Examples: [examples](https://github.com/dongri/openai-api-rs/tree/main/examples)
|
||||||
|
|
||||||
Check out the [full API documentation](https://platform.openai.com/docs/api-reference/completions) for examples of all the available functions.
|
Check out the [full API documentation](https://platform.openai.com/docs/api-reference/completions) for examples of all the available functions.
|
||||||
|
|
||||||
## Supported APIs
|
## Supported APIs
|
||||||
@ -76,6 +82,7 @@ Check out the [full API documentation](https://platform.openai.com/docs/api-refe
|
|||||||
- [x] [Files](https://platform.openai.com/docs/api-reference/files)
|
- [x] [Files](https://platform.openai.com/docs/api-reference/files)
|
||||||
- [x] [Fine-tunes](https://platform.openai.com/docs/api-reference/fine-tunes)
|
- [x] [Fine-tunes](https://platform.openai.com/docs/api-reference/fine-tunes)
|
||||||
- [x] [Moderations](https://platform.openai.com/docs/api-reference/moderations)
|
- [x] [Moderations](https://platform.openai.com/docs/api-reference/moderations)
|
||||||
|
- [x] [Function calling](https://platform.openai.com/docs/guides/gpt/function-calling)
|
||||||
|
|
||||||
## License
|
## License
|
||||||
This project is licensed under [MIT license](https://github.com/dongri/openai-api-rs/blob/main/LICENSE).
|
This project is licensed under [MIT license](https://github.com/dongri/openai-api-rs/blob/main/LICENSE).
|
||||||
|
@ -9,12 +9,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
model: chat_completion::GPT4.to_string(),
|
model: chat_completion::GPT4.to_string(),
|
||||||
messages: vec![chat_completion::ChatCompletionMessage {
|
messages: vec![chat_completion::ChatCompletionMessage {
|
||||||
role: chat_completion::MessageRole::user,
|
role: chat_completion::MessageRole::user,
|
||||||
content: String::from("NFTとは?"),
|
content: Some(String::from("What is Bitcoin?")),
|
||||||
|
name: None,
|
||||||
|
function_call: None,
|
||||||
}],
|
}],
|
||||||
|
functions: None,
|
||||||
|
function_call: None,
|
||||||
};
|
};
|
||||||
let result = client.chat_completion(req).await?;
|
let result = client.chat_completion(req).await?;
|
||||||
println!("{:?}", result.choices[0].message.content);
|
println!("{:?}", result.choices[0].message.content);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
|
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
|
||||||
let req = CompletionRequest {
|
let req = CompletionRequest {
|
||||||
model: completion::GPT3_TEXT_DAVINCI_003.to_string(),
|
model: completion::GPT3_TEXT_DAVINCI_003.to_string(),
|
||||||
prompt: Some(String::from("NFTとは?")),
|
prompt: Some(String::from("What is Bitcoin?")),
|
||||||
suffix: None,
|
suffix: None,
|
||||||
max_tokens: Some(3000),
|
max_tokens: Some(3000),
|
||||||
temperature: Some(0.9),
|
temperature: Some(0.9),
|
||||||
|
89
examples/function_call.rs
Normal file
89
examples/function_call.rs
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
use openai_api_rs::v1::api::Client;
|
||||||
|
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::{env, vec};
|
||||||
|
|
||||||
|
async fn get_coin_price(coin: &str) -> f64 {
|
||||||
|
let coin = coin.to_lowercase();
|
||||||
|
match coin.as_str() {
|
||||||
|
"btc" | "bitcoin" => 10000.0,
|
||||||
|
"eth" | "ethereum" => 1000.0,
|
||||||
|
_ => 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string());
|
||||||
|
|
||||||
|
let mut properties = HashMap::new();
|
||||||
|
properties.insert(
|
||||||
|
"coin".to_string(),
|
||||||
|
Box::new(chat_completion::JSONSchemaDefine {
|
||||||
|
schema_type: Some(chat_completion::JSONSchemaType::String),
|
||||||
|
description: Some("The cryptocurrency to get the price of".to_string()),
|
||||||
|
enum_values: None,
|
||||||
|
properties: None,
|
||||||
|
required: None,
|
||||||
|
items: None,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
let req = ChatCompletionRequest {
|
||||||
|
model: chat_completion::GPT3_5_TURBO_0613.to_string(),
|
||||||
|
messages: vec![chat_completion::ChatCompletionMessage {
|
||||||
|
role: chat_completion::MessageRole::user,
|
||||||
|
content: Some(String::from("What is the price of Ethereum?")),
|
||||||
|
name: None,
|
||||||
|
function_call: None,
|
||||||
|
}],
|
||||||
|
functions: Some(vec![chat_completion::Function {
|
||||||
|
name: String::from("get_coin_price"),
|
||||||
|
description: Some(String::from("Get the price of a cryptocurrency")),
|
||||||
|
parameters: Some(chat_completion::FunctionParameters {
|
||||||
|
schema_type: chat_completion::JSONSchemaType::Object,
|
||||||
|
properties: Some(properties),
|
||||||
|
required: Some(vec![String::from("coin")]),
|
||||||
|
}),
|
||||||
|
}]),
|
||||||
|
function_call: Some("auto".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = client.chat_completion(req).await?;
|
||||||
|
|
||||||
|
match result.choices[0].finish_reason {
|
||||||
|
chat_completion::FinishReason::stop => {
|
||||||
|
println!("Stop");
|
||||||
|
println!("{:?}", result.choices[0].message.content);
|
||||||
|
}
|
||||||
|
chat_completion::FinishReason::length => {
|
||||||
|
println!("Length");
|
||||||
|
}
|
||||||
|
chat_completion::FinishReason::function_call => {
|
||||||
|
println!("FunctionCall");
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct Currency {
|
||||||
|
coin: String,
|
||||||
|
}
|
||||||
|
let function_call = result.choices[0].message.function_call.as_ref().unwrap();
|
||||||
|
let name = function_call.name.clone().unwrap();
|
||||||
|
let arguments = function_call.arguments.clone().unwrap();
|
||||||
|
let c: Currency = serde_json::from_str(&arguments)?;
|
||||||
|
let coin = c.coin;
|
||||||
|
if name == "get_coin_price" {
|
||||||
|
let price = get_coin_price(&coin).await;
|
||||||
|
println!("{} price: {}", coin, price);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chat_completion::FinishReason::content_filter => {
|
||||||
|
println!("ContentFilter");
|
||||||
|
}
|
||||||
|
chat_completion::FinishReason::null => {
|
||||||
|
println!("Null");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// OPENAI_API_KEY=xxxx cargo run --package openai-api-rs --example function_call
|
@ -1,18 +1,26 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use crate::v1::common;
|
use crate::v1::common;
|
||||||
|
|
||||||
pub const GPT3_5_TURBO: &str = "gpt-3.5-turbo";
|
pub const GPT3_5_TURBO: &str = "gpt-3.5-turbo";
|
||||||
pub const GPT3_5_TURBO_0301: &str = "gpt-3.5-turbo-0301";
|
pub const GPT3_5_TURBO_0301: &str = "gpt-3.5-turbo-0301";
|
||||||
|
pub const GPT3_5_TURBO_0613: &str = "gpt-3.5-turbo-0613";
|
||||||
|
|
||||||
pub const GPT4: &str = "gpt-4";
|
pub const GPT4: &str = "gpt-4";
|
||||||
pub const GPT4_0314: &str = "gpt-4-0314";
|
pub const GPT4_0314: &str = "gpt-4-0314";
|
||||||
pub const GPT4_32K: &str = "gpt-4-32k";
|
pub const GPT4_32K: &str = "gpt-4-32k";
|
||||||
pub const GPT4_32K_0314: &str = "gpt-4-32k-0314";
|
pub const GPT4_32K_0314: &str = "gpt-4-32k-0314";
|
||||||
|
pub const GPT4_0613: &str = "gpt-4-0613";
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct ChatCompletionRequest {
|
pub struct ChatCompletionRequest {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub messages: Vec<ChatCompletionMessage>,
|
pub messages: Vec<ChatCompletionMessage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub functions: Option<Vec<Function>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub function_call: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
@ -26,14 +34,19 @@ pub enum MessageRole {
|
|||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct ChatCompletionMessage {
|
pub struct ChatCompletionMessage {
|
||||||
pub role: MessageRole,
|
pub role: MessageRole,
|
||||||
pub content: String,
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub function_call: Option<FunctionCall>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct ChatCompletionChoice {
|
pub struct ChatCompletionChoice {
|
||||||
pub index: i64,
|
pub index: i64,
|
||||||
pub message: ChatCompletionMessage,
|
pub message: ChatCompletionMessage,
|
||||||
pub finish_reason: String,
|
pub finish_reason: FinishReason,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
@ -45,3 +58,67 @@ pub struct ChatCompletionResponse {
|
|||||||
pub choices: Vec<ChatCompletionChoice>,
|
pub choices: Vec<ChatCompletionChoice>,
|
||||||
pub usage: common::Usage,
|
pub usage: common::Usage,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Function {
|
||||||
|
pub name: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parameters: Option<FunctionParameters>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum JSONSchemaType {
|
||||||
|
Object,
|
||||||
|
Number,
|
||||||
|
String,
|
||||||
|
Array,
|
||||||
|
Null,
|
||||||
|
Boolean,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct JSONSchemaDefine {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub schema_type: Option<JSONSchemaType>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub enum_values: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub required: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub items: Option<Box<JSONSchemaDefine>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct FunctionParameters {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub schema_type: JSONSchemaType,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub required: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[allow(non_camel_case_types)]
|
||||||
|
pub enum FinishReason {
|
||||||
|
stop,
|
||||||
|
length,
|
||||||
|
function_call,
|
||||||
|
content_filter,
|
||||||
|
null,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct FunctionCall {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub arguments: Option<String>,
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user