mirror of
https://github.com/mii443/openai-api-rs.git
synced 2025-08-22 23:25:39 +00:00
Fix function call type
This commit is contained in:
@ -46,7 +46,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
required: Some(vec![String::from("coin")]),
|
required: Some(vec![String::from("coin")]),
|
||||||
},
|
},
|
||||||
}]),
|
}]),
|
||||||
function_call: Some(FunctionCallType::auto), //Some(FunctionCallType::Function { name: "test".to_string() })
|
function_call: Some(FunctionCallType::Auto), // Some(FunctionCallType::Function { name: "test".to_string() }),
|
||||||
temperature: None,
|
temperature: None,
|
||||||
top_p: None,
|
top_p: None,
|
||||||
n: None,
|
n: None,
|
||||||
@ -59,6 +59,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
user: None,
|
user: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// debug reuqest json
|
||||||
|
// let serialized = serde_json::to_string(&req).unwrap();
|
||||||
|
// println!("{}", serialized);
|
||||||
|
|
||||||
let result = client.chat_completion(req)?;
|
let result = client.chat_completion(req)?;
|
||||||
|
|
||||||
match result.choices[0].finish_reason {
|
match result.choices[0].finish_reason {
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::ser::SerializeMap;
|
||||||
|
use serde::{Deserialize, Serialize, Serializer};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use crate::v1::common;
|
use crate::v1::common;
|
||||||
@ -14,11 +15,10 @@ pub const GPT4_32K_0314: &str = "gpt-4-32k-0314";
|
|||||||
pub const GPT4_0613: &str = "gpt-4-0613";
|
pub const GPT4_0613: &str = "gpt-4-0613";
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
#[allow(non_camel_case_types)]
|
|
||||||
pub enum FunctionCallType {
|
pub enum FunctionCallType {
|
||||||
none,
|
None,
|
||||||
auto,
|
Auto,
|
||||||
function { name: String },
|
Function { name: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
@ -28,6 +28,7 @@ pub struct ChatCompletionRequest {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub functions: Option<Vec<Function>>,
|
pub functions: Option<Vec<Function>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[serde(serialize_with = "serialize_function_call")]
|
||||||
pub function_call: Option<FunctionCallType>,
|
pub function_call: Option<FunctionCallType>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub temperature: Option<f64>,
|
pub temperature: Option<f64>,
|
||||||
@ -160,3 +161,22 @@ pub struct FunctionCall {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub arguments: Option<String>,
|
pub arguments: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn serialize_function_call<S>(
|
||||||
|
value: &Option<FunctionCallType>,
|
||||||
|
serializer: S,
|
||||||
|
) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
match value {
|
||||||
|
Some(FunctionCallType::None) => serializer.serialize_str("none"),
|
||||||
|
Some(FunctionCallType::Auto) => serializer.serialize_str("auto"),
|
||||||
|
Some(FunctionCallType::Function { name }) => {
|
||||||
|
let mut map = serializer.serialize_map(Some(1))?;
|
||||||
|
map.serialize_entry("name", name)?;
|
||||||
|
map.end()
|
||||||
|
}
|
||||||
|
None => serializer.serialize_none(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user