Merge pull request #96 from hanselke/main

Support tool_calls in message history as well as tool role
This commit is contained in:
Dongri Jin
2024-07-17 17:39:41 +09:00
committed by GitHub
6 changed files with 30 additions and 1 deletions

View File

@ -31,6 +31,8 @@ let req = ChatCompletionRequest::new(
role: chat_completion::MessageRole::user, role: chat_completion::MessageRole::user,
content: chat_completion::Content::Text(String::from("What is bitcoin?")), content: chat_completion::Content::Text(String::from("What is bitcoin?")),
name: None, name: None,
tool_calls: None,
tool_call_id: None,
}], }],
); );
``` ```

View File

@ -13,6 +13,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
role: chat_completion::MessageRole::user, role: chat_completion::MessageRole::user,
content: chat_completion::Content::Text(String::from("What is bitcoin?")), content: chat_completion::Content::Text(String::from("What is bitcoin?")),
name: None, name: None,
tool_calls: None,
tool_call_id: None,
}], }],
); );

View File

@ -34,6 +34,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
role: chat_completion::MessageRole::user, role: chat_completion::MessageRole::user,
content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")),
name: None, name: None,
tool_calls: None,
tool_call_id: None,
}], }],
) )
.tools(vec![chat_completion::Tool { .tools(vec![chat_completion::Tool {

View File

@ -34,6 +34,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
role: chat_completion::MessageRole::user, role: chat_completion::MessageRole::user,
content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")), content: chat_completion::Content::Text(String::from("What is the price of Ethereum?")),
name: None, name: None,
tool_calls: None,
tool_call_id: None,
}], }],
) )
.tools(vec![chat_completion::Tool { .tools(vec![chat_completion::Tool {
@ -88,6 +90,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
"What is the price of Ethereum?", "What is the price of Ethereum?",
)), )),
name: None, name: None,
tool_calls: None,
tool_call_id: None,
}, },
chat_completion::ChatCompletionMessage { chat_completion::ChatCompletionMessage {
role: chat_completion::MessageRole::function, role: chat_completion::MessageRole::function,
@ -96,6 +100,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
format!("{{\"price\": {}}}", price) format!("{{\"price\": {}}}", price)
}), }),
name: Some(String::from("get_coin_price")), name: Some(String::from("get_coin_price")),
tool_calls: None,
tool_call_id: None,
}, },
], ],
); );

View File

@ -28,6 +28,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}, },
]), ]),
name: None, name: None,
tool_calls: None,
tool_call_id: None,
}], }],
); );

View File

@ -45,6 +45,8 @@ pub struct ChatCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>, pub tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(serialize_with = "serialize_tool_choice")] #[serde(serialize_with = "serialize_tool_choice")]
pub tool_choice: Option<ToolChoiceType>, pub tool_choice: Option<ToolChoiceType>,
} }
@ -67,6 +69,7 @@ impl ChatCompletionRequest {
user: None, user: None,
seed: None, seed: None,
tools: None, tools: None,
parallel_tool_calls: None,
tool_choice: None, tool_choice: None,
} }
} }
@ -87,6 +90,7 @@ impl_builder_methods!(
user: String, user: String,
seed: i64, seed: i64,
tools: Vec<Tool>, tools: Vec<Tool>,
parallel_tool_calls: bool,
tool_choice: ToolChoiceType tool_choice: ToolChoiceType
); );
@ -97,6 +101,7 @@ pub enum MessageRole {
system, system,
assistant, assistant,
function, function,
tool,
} }
#[derive(Debug, Deserialize, Clone, PartialEq, Eq)] #[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
@ -111,7 +116,13 @@ impl serde::Serialize for Content {
S: serde::Serializer, S: serde::Serializer,
{ {
match *self { match *self {
Content::Text(ref text) => serializer.serialize_str(text), Content::Text(ref text) => {
if text.is_empty() {
serializer.serialize_none()
} else {
serializer.serialize_str(text)
}
}
Content::ImageUrl(ref image_url) => image_url.serialize(serializer), Content::ImageUrl(ref image_url) => image_url.serialize(serializer),
} }
} }
@ -146,6 +157,10 @@ pub struct ChatCompletionMessage {
pub content: Content, pub content: Content,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>, pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
} }
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]