diff --git a/examples/function_call.rs b/examples/function_call.rs index a656de3..1961778 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -1,5 +1,5 @@ use openai_api_rs::v1::api::Client; -use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest}; +use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest, FunctionCallType}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::{env, vec}; @@ -46,7 +46,7 @@ fn main() -> Result<(), Box> { required: Some(vec![String::from("coin")]), }, }]), - function_call: Some("auto".to_string()), + function_call: Some(FunctionCallType::auto), //Some(FunctionCallType::Function { name: "test".to_string() }) temperature: None, top_p: None, n: None, diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 6de6381..5354bd1 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -13,6 +13,13 @@ pub const GPT4_32K: &str = "gpt-4-32k"; pub const GPT4_32K_0314: &str = "gpt-4-32k-0314"; pub const GPT4_0613: &str = "gpt-4-0613"; +#[derive(Debug, Serialize)] +#[allow(non_camel_case_types)] +pub enum FunctionCallType { + auto, + function { name: String }, +} + #[derive(Debug, Serialize)] pub struct ChatCompletionRequest { pub model: String, @@ -20,7 +27,7 @@ pub struct ChatCompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] pub functions: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub function_call: Option, + pub function_call: Option, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")]