Fix function call type

This commit is contained in:
Dongri Jin
2023-10-06 17:19:46 +09:00
parent b079334e67
commit 6431c9cbdd
2 changed files with 30 additions and 6 deletions

View File

@ -46,7 +46,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
required: Some(vec![String::from("coin")]), required: Some(vec![String::from("coin")]),
}, },
}]), }]),
function_call: Some(FunctionCallType::auto), //Some(FunctionCallType::Function { name: "test".to_string() }) function_call: Some(FunctionCallType::Auto), // Some(FunctionCallType::Function { name: "test".to_string() }),
temperature: None, temperature: None,
top_p: None, top_p: None,
n: None, n: None,
@ -59,6 +59,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
user: None, user: None,
}; };
// debug reuqest json
// let serialized = serde_json::to_string(&req).unwrap();
// println!("{}", serialized);
let result = client.chat_completion(req)?; let result = client.chat_completion(req)?;
match result.choices[0].finish_reason { match result.choices[0].finish_reason {

View File

@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize}; use serde::ser::SerializeMap;
use serde::{Deserialize, Serialize, Serializer};
use std::collections::HashMap; use std::collections::HashMap;
use crate::v1::common; use crate::v1::common;
@ -14,11 +15,10 @@ pub const GPT4_32K_0314: &str = "gpt-4-32k-0314";
pub const GPT4_0613: &str = "gpt-4-0613"; pub const GPT4_0613: &str = "gpt-4-0613";
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[allow(non_camel_case_types)]
pub enum FunctionCallType { pub enum FunctionCallType {
none, None,
auto, Auto,
function { name: String }, Function { name: String },
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -28,6 +28,7 @@ pub struct ChatCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub functions: Option<Vec<Function>>, pub functions: Option<Vec<Function>>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[serde(serialize_with = "serialize_function_call")]
pub function_call: Option<FunctionCallType>, pub function_call: Option<FunctionCallType>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>, pub temperature: Option<f64>,
@ -160,3 +161,22 @@ pub struct FunctionCall {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>, pub arguments: Option<String>,
} }
fn serialize_function_call<S>(
value: &Option<FunctionCallType>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(FunctionCallType::None) => serializer.serialize_str("none"),
Some(FunctionCallType::Auto) => serializer.serialize_str("auto"),
Some(FunctionCallType::Function { name }) => {
let mut map = serializer.serialize_map(Some(1))?;
map.serialize_entry("name", name)?;
map.end()
}
None => serializer.serialize_none(),
}
}