Add tools and tool_choice fields

This commit is contained in:
Sharif Haason
2023-12-27 01:14:41 -05:00
parent 589871c055
commit 9d6733f981

View File

@ -13,6 +13,13 @@ pub enum FunctionCallType {
Function { name: String },
}
#[derive(Debug, Serialize, Clone)]
pub enum ToolChoiceType {
None,
Auto,
ToolChoice { tool: Tool },
}
#[derive(Debug, Serialize, Clone)]
pub struct ChatCompletionRequest {
pub model: String,
@ -54,6 +61,11 @@ pub struct ChatCompletionRequest {
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Tool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(serialize_with = "serialize_tool_choice")]
pub tool_choice: Option<ToolChoiceType>,
}
impl ChatCompletionRequest {
@ -75,6 +87,8 @@ impl ChatCompletionRequest {
logit_bias: None,
user: None,
seed: None,
tools: None,
tool_choice: None,
}
}
}
@ -94,7 +108,9 @@ impl_builder_methods!(
frequency_penalty: f64,
logit_bias: HashMap<String, i32>,
user: String,
seed: i64
seed: i64,
tools: Tool,
tool_choice: ToolChoiceType
);
#[derive(Debug, Serialize, Deserialize, Clone)]
@ -233,3 +249,30 @@ where
None => serializer.serialize_none(),
}
}
fn serialize_tool_choice<S>(
value: &Option<ToolChoiceType>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(ToolChoiceType::None) => serializer.serialize_str("none"),
Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"),
Some(ToolChoiceType::ToolChoice { tool }) => {
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("type", &tool.tool_type)?;
map.serialize_entry("function", &tool.function)?;
map.end()
}
None => serializer.serialize_none(),
}
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
#[serde(rename = "type")]
tool_type: String,
function: Function,
}