Add function call type

This commit is contained in:
Dongri Jin
2023-10-04 16:57:13 +09:00
parent 036b850035
commit 3a37c625e0
2 changed files with 10 additions and 3 deletions

View File

@ -1,5 +1,5 @@
use openai_api_rs::v1::api::Client; use openai_api_rs::v1::api::Client;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest, FunctionCallType};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::{env, vec}; use std::{env, vec};
@ -46,7 +46,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
required: Some(vec![String::from("coin")]), required: Some(vec![String::from("coin")]),
}, },
}]), }]),
function_call: Some("auto".to_string()), function_call: Some(FunctionCallType::auto), //Some(FunctionCallType::Function { name: "test".to_string() })
temperature: None, temperature: None,
top_p: None, top_p: None,
n: None, n: None,

View File

@ -13,6 +13,13 @@ pub const GPT4_32K: &str = "gpt-4-32k";
pub const GPT4_32K_0314: &str = "gpt-4-32k-0314"; pub const GPT4_32K_0314: &str = "gpt-4-32k-0314";
pub const GPT4_0613: &str = "gpt-4-0613"; pub const GPT4_0613: &str = "gpt-4-0613";
#[derive(Debug, Serialize)]
#[allow(non_camel_case_types)]
pub enum FunctionCallType {
auto,
function { name: String },
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct ChatCompletionRequest { pub struct ChatCompletionRequest {
pub model: String, pub model: String,
@ -20,7 +27,7 @@ pub struct ChatCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub functions: Option<Vec<Function>>, pub functions: Option<Vec<Function>>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<String>, pub function_call: Option<FunctionCallType>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>, pub temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]