mirror of
https://github.com/mii443/openai-api-rs.git
synced 2025-08-23 07:35:37 +00:00
Add tools
and tool_choice
fields
This commit is contained in:
@ -13,6 +13,13 @@ pub enum FunctionCallType {
|
|||||||
Function { name: String },
|
Function { name: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Clone)]
|
||||||
|
pub enum ToolChoiceType {
|
||||||
|
None,
|
||||||
|
Auto,
|
||||||
|
ToolChoice { tool: Tool },
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Clone)]
|
#[derive(Debug, Serialize, Clone)]
|
||||||
pub struct ChatCompletionRequest {
|
pub struct ChatCompletionRequest {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
@ -54,6 +61,11 @@ pub struct ChatCompletionRequest {
|
|||||||
pub user: Option<String>,
|
pub user: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub seed: Option<i64>,
|
pub seed: Option<i64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Tool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[serde(serialize_with = "serialize_tool_choice")]
|
||||||
|
pub tool_choice: Option<ToolChoiceType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatCompletionRequest {
|
impl ChatCompletionRequest {
|
||||||
@ -75,6 +87,8 @@ impl ChatCompletionRequest {
|
|||||||
logit_bias: None,
|
logit_bias: None,
|
||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -94,7 +108,9 @@ impl_builder_methods!(
|
|||||||
frequency_penalty: f64,
|
frequency_penalty: f64,
|
||||||
logit_bias: HashMap<String, i32>,
|
logit_bias: HashMap<String, i32>,
|
||||||
user: String,
|
user: String,
|
||||||
seed: i64
|
seed: i64,
|
||||||
|
tools: Tool,
|
||||||
|
tool_choice: ToolChoiceType
|
||||||
);
|
);
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
@ -233,3 +249,30 @@ where
|
|||||||
None => serializer.serialize_none(),
|
None => serializer.serialize_none(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn serialize_tool_choice<S>(
|
||||||
|
value: &Option<ToolChoiceType>,
|
||||||
|
serializer: S,
|
||||||
|
) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
match value {
|
||||||
|
Some(ToolChoiceType::None) => serializer.serialize_str("none"),
|
||||||
|
Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"),
|
||||||
|
Some(ToolChoiceType::ToolChoice { tool }) => {
|
||||||
|
let mut map = serializer.serialize_map(Some(2))?;
|
||||||
|
map.serialize_entry("type", &tool.tool_type)?;
|
||||||
|
map.serialize_entry("function", &tool.function)?;
|
||||||
|
map.end()
|
||||||
|
}
|
||||||
|
None => serializer.serialize_none(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
|
pub struct Tool {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
tool_type: String,
|
||||||
|
function: Function,
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user