+ tool - First pass at adding Function Calling for OpenAI and Anthropic (rel #24)

This commit is contained in:
Jeremy Chone
2024-10-22 12:30:16 -07:00
parent 4b76a5ead6
commit 001b124381
23 changed files with 740 additions and 177 deletions
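For orientation, here is a minimal sketch of the full flow this commit enables, pieced together from the test helpers further down (`seed_chat_req_tool_simple`, `common_test_tool_full_flow_ok`); the model name and the tool schema are illustrative:

```rust
// Minimal sketch of the new tool-calling flow, assembled from the APIs added
// in this commit. The model name "gpt-4o-mini" and the schema are illustrative.
use genai::chat::{ChatMessage, ChatRequest, Tool, ToolResponse};
use genai::Client;
use serde_json::json;

async fn tool_flow() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::default();

    let chat_req = ChatRequest::new(vec![ChatMessage::user("What is the temperature in C, in Paris")])
        .append_tool(Tool::new("get_weather").with_schema(json!({
            "type": "object",
            "properties": {
                "city": { "type": "string", "description": "The city name" }
            },
            "required": ["city"],
        })));

    // First request: the LLM should answer with one or more tool calls.
    let chat_res = client.exec_chat("gpt-4o-mini", chat_req.clone(), None).await?;
    let tool_calls = chat_res.into_tool_calls().ok_or("should have tool calls")?;

    // Simulate executing the tool, then send the calls and the response back.
    let call_id = &tool_calls.first().ok_or("should have one tool call")?.call_id;
    let tool_response = ToolResponse::new(call_id, r#"{"weather": "Sunny", "temperature": "32C"}"#);
    let chat_req = chat_req.append_message(tool_calls).append_message(tool_response);

    // Second request: the LLM folds the tool result into a normal text answer.
    let chat_res = client.exec_chat("gpt-4o-mini", chat_req, None).await?;
    println!("{:?}", chat_res.content_text_as_str());
    Ok(())
}
```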

View File

@ -11,8 +11,8 @@ repository = "https://github.com/jeremychone/rust-genai"
[lints.rust]
unsafe_code = "forbid"
# unused = { level = "allow", priority = -1 } # For exploratory dev.
missing_docs = "warn"
unused = { level = "allow", priority = -1 } # For exploratory dev.
# missing_docs = "warn"
[dependencies]
# -- Async

View File

@ -3,10 +3,11 @@ use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage,
ToolCall,
};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use reqwest_eventsource::EventSource;
use serde_json::{json, Value};
@ -18,9 +19,9 @@ const BASE_URL: &str = "https://api.anthropic.com/v1/";
const MAX_TOKENS: u32 = 1024;
const ANTHROPIC_VERSION: &str = "2023-06-01";
const MODELS: &[&str] = &[
"claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-20240620",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
];
@ -53,7 +54,7 @@ impl Adapter for AnthropicAdapter {
let url = Self::get_service_url(model_iden.clone(), service_type);
// -- api_key (this Adapter requires it)
let api_key = get_api_key(model_iden, client_config)?;
let api_key = get_api_key(model_iden.clone(), client_config)?;
let headers = vec![
// headers
@ -61,7 +62,11 @@ impl Adapter for AnthropicAdapter {
("anthropic-version".to_string(), ANTRHOPIC_VERSION.to_string()),
];
let AnthropicRequestParts { system, messages } = Self::into_anthropic_request_parts(chat_req)?;
let AnthropicRequestParts {
system,
messages,
tools,
} = Self::into_anthropic_request_parts(model_iden, chat_req)?;
// -- Build the basic payload
let mut payload = json!({
@ -69,10 +74,15 @@ impl Adapter for AnthropicAdapter {
"messages": messages,
"stream": stream
});
if let Some(system) = system {
payload.x_insert("system", system)?;
}
if let Some(tools) = tools {
payload.x_insert("/tools", tools);
}
// -- Add supported ChatOptions
if let Some(temperature) = options_set.temperature() {
payload.x_insert("temperature", temperature)?;
@ -90,27 +100,51 @@ impl Adapter for AnthropicAdapter {
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
let WebResponse { mut body, .. } = web_response;
let json_content_items: Vec<Value> = body.x_take("content")?;
let mut content: Vec<String> = Vec::new();
// -- Capture the usage
let usage = body.x_take("usage").map(Self::into_usage).unwrap_or_default();
for mut item in json_content_items {
let item_text: String = item.x_take("text")?;
content.push(item_text);
// -- Capture the content
// NOTE: Anthropic supports a list of content items of multiple types, but ChatResponse does not (yet).
// So, the strategy is to:
// - Iterate over all of the content items and capture the text and tool_use items.
// - If there are one or more tool_use items, they take precedence, since MessageContent supports a tool_call list.
// - Otherwise, the text items are concatenated.
// NOTE: We need to see whether multiple text content items actually happen, and why. If not, we can probably simplify this by just capturing the first one.
// Eventually, ChatResponse will have `content: Option<Vec<MessageContent>>` for multi-part content (with images and such).
let content_items: Vec<Value> = body.x_take("content")?;
let mut text_content: Vec<String> = Vec::new();
// Note: tool_calls are probably the exception, so the vector is only created when needed
let mut tool_calls: Option<Vec<ToolCall>> = None;
for mut item in content_items {
let typ: &str = item.x_get_as("type")?;
if typ == "text" {
text_content.push(item.x_take("text")?);
} else if typ == "tool_use" {
let call_id = item.x_take::<String>("id")?;
let fn_name = item.x_take::<String>("name")?;
// if not found, will be Value::Null
let fn_arguments = item.x_take::<Value>("input").unwrap_or_default();
let tool_call = ToolCall {
call_id,
fn_name,
fn_arguments,
};
tool_calls.get_or_insert_with(Vec::new).push(tool_call);
}
}
let content = if content.is_empty() {
None
} else {
Some(content.join(""))
};
let content = content.map(MessageContent::from);
let content = if let Some(tool_calls) = tool_calls {
Some(MessageContent::from(tool_calls))
} else {
Some(MessageContent::from(text_content.join("\n")))
};
Ok(ChatResponse {
model_iden,
content,
model_iden,
usage,
})
}
@ -153,7 +187,7 @@ impl AnthropicAdapter {
/// Takes the GenAI ChatMessages and constructs the System string and JSON Messages for Anthropic.
/// - Will push the `ChatRequest.system` and system messages to `AnthropicRequestParts.system`
fn into_anthropic_request_parts(chat_req: ChatRequest) -> Result<AnthropicRequestParts> {
fn into_anthropic_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result<AnthropicRequestParts> {
let mut messages: Vec<Value> = Vec::new();
let mut systems: Vec<String> = Vec::new();
@ -161,32 +195,115 @@ impl AnthropicAdapter {
systems.push(system);
}
// -- Process the messages
for msg in chat_req.messages {
// Note: Will handle more types later
let MessageContent::Text(content) = msg.content;
match msg.role {
// for now, system and tool messages go to system
ChatRole::System | ChatRole::Tool => systems.push(content),
ChatRole::User => messages.push(json! ({"role": "user", "content": content})),
ChatRole::Assistant => messages.push(json! ({"role": "assistant", "content": content})),
ChatRole::System => {
if let MessageContent::Text(content) = msg.content {
systems.push(content)
}
// TODO: Needs to trace/warn that other types are not supported
}
ChatRole::User => {
if let MessageContent::Text(content) = msg.content {
messages.push(json! ({"role": "user", "content": content}))
}
// TODO: Needs to trace/warn that other types are not supported
}
ChatRole::Assistant => {
//
match msg.content {
MessageContent::Text(content) => {
messages.push(json! ({"role": "assistant", "content": content}))
}
MessageContent::ToolCalls(tool_calls) => {
let tool_calls = tool_calls
.into_iter()
.map(|tool_call| {
// see: https://docs.anthropic.com/en/docs/build-with-claude/tool-use#example-of-successful-tool-result
json!({
"type": "tool_use",
"id": tool_call.call_id,
"name": tool_call.fn_name,
"input": tool_call.fn_arguments,
})
})
.collect::<Vec<Value>>();
messages.push(json! ({
"role": "assistant",
"content": tool_calls
}));
}
// TODO: Probably need to trace/warn that this will be ignored
MessageContent::ToolResponses(_) => (),
}
}
ChatRole::Tool => {
if let MessageContent::ToolResponses(tool_responses) = msg.content {
let tool_responses = tool_responses
.into_iter()
.map(|tool_response| {
json!({
"type": "tool_result",
"content": tool_response.content,
"tool_use_id": tool_response.call_id,
})
})
.collect::<Vec<Value>>();
// FIXME: MessageContent::ToolResponse should be MessageContent::ToolResponses (even if OpenAI does require multiple Tool messages)
messages.push(json!({
"role": "user",
"content": tool_responses
}));
}
// TODO: Probably need to trace/warn that this will be ignored
}
}
}
// -- Create the Anthropic system
// NOTE: Anthropic does not have a "role": "system", just a single optional system property
let system = if !systems.is_empty() {
Some(systems.join("\n"))
} else {
None
};
Ok(AnthropicRequestParts { system, messages })
// -- Process the tools
let tools = chat_req.tools.map(|tools| {
tools
.into_iter()
.map(|tool| {
// TODO: Need to handle the error correctly
// TODO: Needs to have a custom serializer (a tool should not have to be matched to a provider)
// NOTE: Right now, failure is low probability, so we just return null if it cannot be converted to a Value.
let mut tool_value = json!({
"name": tool.name,
"input_schema": tool.schema,
});
if let Some(description) = tool.description {
tool_value.x_insert("description", description);
}
tool_value
})
.collect::<Vec<Value>>()
});
Ok(AnthropicRequestParts {
system,
messages,
tools,
})
}
}
struct AnthropicRequestParts {
system: Option<String>,
messages: Vec<Value>,
// TODO: need to add tools
tools: Option<Vec<Value>>,
}
// endregion: --- Support
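For reference, the wire shapes the conversion code above produces look roughly like this (a sketch; the `tool_use` id value is illustrative):

```rust
use serde_json::{json, Value};

// Sketch of the message shapes emitted by into_anthropic_request_parts.
fn sample_anthropic_tool_messages() -> (Value, Value) {
    // Assistant turn carrying a tool call (MessageContent::ToolCalls).
    let assistant_msg = json!({
        "role": "assistant",
        "content": [{
            "type": "tool_use",
            "id": "toolu_abc123", // illustrative id
            "name": "get_weather",
            "input": { "city": "Paris", "country": "France", "unit": "C" }
        }]
    });

    // Tool result turn (ChatRole::Tool); as in the code above, Anthropic
    // expects tool results on a "user" role message.
    let tool_result_msg = json!({
        "role": "user",
        "content": [{
            "type": "tool_result",
            "tool_use_id": "toolu_abc123",
            "content": "{\"weather\": \"Sunny\", \"temperature\": \"32C\"}"
        }]
    });

    (assistant_msg, tool_result_msg)
}
```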

View File

@ -1,4 +1,5 @@
//! API Documentation: https://docs.anthropic.com/en/api/messages
//! Tool Documentation: https://docs.anthropic.com/en/docs/build-with-claude/tool-use
//! Model Names: https://docs.anthropic.com/en/docs/models-overview
//! Pricing: https://www.anthropic.com/pricing#anthropic-api

View File

@ -110,8 +110,8 @@ impl Adapter for CohereAdapter {
.map(MessageContent::from);
Ok(ChatResponse {
model_iden,
content,
model_iden,
usage,
})
}
@ -185,13 +185,23 @@ impl CohereAdapter {
actual_role: last_chat_msg.role,
});
}
// Will handle more types later
let MessageContent::Text(message) = last_chat_msg.content;
// TODO: Needs to implement tool_calls
let MessageContent::Text(message) = last_chat_msg.content else {
return Err(Error::MessageContentTypeNotSupported {
model_iden,
cause: "Only MessageContent::Text supported for this model (for now)",
});
};
// -- Build
for msg in chat_req.messages {
// Note: Will handle more types later
let MessageContent::Text(content) = msg.content;
let MessageContent::Text(content) = msg.content else {
return Err(Error::MessageContentTypeNotSupported {
model_iden,
cause: "Only MessageContent::Text supported for this model (for now)",
});
};
match msg.role {
// For now, system and tool go to the system

View File

@ -121,8 +121,8 @@ impl Adapter for GeminiAdapter {
let content = content.map(MessageContent::from);
Ok(ChatResponse {
model_iden,
content,
model_iden,
usage,
})
}
@ -192,8 +192,13 @@ impl GeminiAdapter {
// -- Build
for msg in chat_req.messages {
// Note: Will handle more types later
let MessageContent::Text(content) = msg.content;
// TODO: Needs to implement tool_calls
let MessageContent::Text(content) = msg.content else {
return Err(Error::MessageContentTypeNotSupported {
model_iden,
cause: "Only MessageContent::Text supported for this model (for now)",
});
};
match msg.role {
// For now, system goes as "user" (later, we might have adapter_config.system_to_user_impl)

View File

@ -3,13 +3,14 @@ use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
MessageContent, MetaUsage,
MessageContent, MetaUsage, Tool, ToolCall,
};
use crate::webc::WebResponse;
use crate::{ClientConfig, ModelIden};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use reqwest_eventsource::EventSource;
use serde::Deserialize;
use serde_json::{json, Value};
use value_ext::JsonValueExt;
@ -54,15 +55,31 @@ impl Adapter for OpenAIAdapter {
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
let WebResponse { mut body, .. } = web_response;
// -- Capture the usage
let usage = body.x_take("usage").map(OpenAIAdapter::into_usage).unwrap_or_default();
let first_choice: Option<Value> = body.x_take("/choices/0")?;
let content: Option<String> = first_choice.map(|mut c| c.x_take("/message/content")).transpose()?;
let content = content.map(MessageContent::from);
// -- Capture the content
let content = if let Some(mut first_choice) = body.x_take::<Option<Value>>("/choices/0")? {
if let Some(content) = first_choice
.x_take::<Option<String>>("/message/content")?
.map(MessageContent::from)
{
Some(content)
} else {
first_choice
.x_take("/message/tool_calls")
.ok()
.map(parse_tool_calls)
.transpose()?
.map(MessageContent::from_tool_calls)
}
} else {
None
};
Ok(ChatResponse {
model_iden,
content,
model_iden,
usage,
})
}
@ -117,13 +134,17 @@ impl OpenAIAdapter {
// -- Build the basic payload
let model_name = model_iden.model_name.to_string();
let OpenAIRequestParts { messages } = Self::into_openai_request_parts(model_iden, chat_req)?;
let OpenAIRequestParts { messages, tools } = Self::into_openai_request_parts(model_iden, chat_req)?;
let mut payload = json!({
"model": model_name,
"messages": messages,
"stream": stream
});
if let Some(tools) = tools {
payload.x_insert("/tools", tools);
}
// -- Add options
let response_format = if let Some(response_format) = options_set.response_format() {
match response_format {
@ -199,16 +220,17 @@ impl OpenAIAdapter {
/// Takes the genai ChatMessages and builds the OpenAIChatRequestParts
/// - `genai::ChatRequest.system`, if present, is added as the first message with role 'system'.
/// - All messages get added with the corresponding roles (including tool calls and tool responses)
///
/// NOTE: Here, the last `true` is for the Ollama variant
/// It seems the Ollama compatibility layer does not work well with multiple system messages.
/// So, when `true`, it will concatenate the system message into a single one at the beginning
fn into_openai_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result<OpenAIRequestParts> {
let mut system_messages: Vec<String> = Vec::new();
let mut messages: Vec<Value> = Vec::new();
/// NOTE: For now, `system_messages` is used to work around an issue with the Ollama compatibility layer, which does not support multiple system messages.
/// So, for Ollama, the system messages are concatenated into a single one at the beginning.
/// NOTE: This might be fixed upstream by now, in which case we could remove this.
let mut system_messages: Vec<String> = Vec::new();
let ollama_variant = matches!(model_iden.adapter_kind, AdapterKind::Ollama);
// -- Process the system
if let Some(system_msg) = chat_req.system {
if ollama_variant {
system_messages.push(system_msg)
@ -217,37 +239,98 @@ impl OpenAIAdapter {
}
}
// -- Process the messages
for msg in chat_req.messages {
// Note: Will handle more types later
let MessageContent::Text(content) = msg.content;
match msg.role {
// For now, system and tool messages go to the system
ChatRole::System => {
// See note in the function comment
if ollama_variant {
system_messages.push(content);
} else {
messages.push(json!({"role": "system", "content": content}))
if let MessageContent::Text(content) = msg.content {
// NOTE: Ollama does not support multiple system messages
// See note in the function comment
if ollama_variant {
system_messages.push(content);
} else {
messages.push(json!({"role": "system", "content": content}))
}
}
// TODO: Probably need to warn if it is a ToolCalls type of content
}
ChatRole::User => messages.push(json! ({"role": "user", "content": content})),
ChatRole::Assistant => messages.push(json! ({"role": "assistant", "content": content})),
ChatRole::User => {
if let MessageContent::Text(content) = msg.content {
messages.push(json! ({"role": "user", "content": content}));
}
// TODO: Probably need to warn if it is a ToolCalls type of content
}
ChatRole::Assistant => match msg.content {
MessageContent::Text(content) => messages.push(json! ({"role": "assistant", "content": content})),
MessageContent::ToolCalls(tool_calls) => {
let tool_calls = tool_calls
.into_iter()
.map(|tool_call| {
json!({
"type": "function",
"id": tool_call.call_id,
"function": {
"name": tool_call.fn_name,
"arguments": tool_call.fn_arguments.to_string(),
}
})
})
.collect::<Vec<Value>>();
messages.push(json! ({"role": "assistant", "tool_calls": tool_calls}))
}
// TODO: Probably need to trace/warn that this will be ignored
MessageContent::ToolResponses(_) => (),
},
ChatRole::Tool => {
return Err(Error::MessageRoleNotSupported {
model_iden,
role: ChatRole::Tool,
})
if let MessageContent::ToolResponses(tool_responses) = msg.content {
for tool_response in tool_responses {
messages.push(json!({
"role": "tool",
"content": tool_response.content,
"tool_call_id": tool_response.call_id,
}))
}
}
// TODO: Probably need to trace/warn that this will be ignored
}
}
}
// -- Finalize the system messages ollama case
if !system_messages.is_empty() {
let system_message = system_messages.join("\n");
messages.insert(0, json!({"role": "system", "content": system_message}));
}
Ok(OpenAIRequestParts { messages })
// -- Process the tools
let tools = chat_req.tools.map(|tools| {
tools
.into_iter()
.map(|tool| {
// TODO: Need to handle the error correctly
// TODO: Needs to have a custom serializer (a tool should not have to be matched to a provider)
// NOTE: Right now, failure is low probability, so we just return null if it cannot be converted to a Value.
json!({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.schema,
// TODO: If we need to support `strict: true` we need to add additionalProperties: false into the schema
// above (like structured output)
"strict": false,
}
})
})
.collect::<Vec<Value>>()
});
Ok(OpenAIRequestParts { messages, tools })
}
}
@ -255,6 +338,59 @@ impl OpenAIAdapter {
struct OpenAIRequestParts {
messages: Vec<Value>,
tools: Option<Vec<Value>>,
}
fn parse_tool_calls(raw_tool_calls: Value) -> Result<Vec<ToolCall>> {
let Value::Array(raw_tool_calls) = raw_tool_calls else {
return Err(Error::InvalidJsonResponseElement {
info: "tool calls is not an array",
});
};
let tool_calls = raw_tool_calls.into_iter().map(parse_tool_call).collect::<Result<Vec<_>>>()?;
Ok(tool_calls)
}
fn parse_tool_call(raw_tool_call: Value) -> Result<ToolCall> {
// Define a helper struct to match the original JSON structure.
#[derive(Deserialize)]
struct InterimToolFnCall {
id: String,
#[serde(rename = "type")]
r#type: String,
function: InterimFunction,
}
#[derive(Deserialize)]
struct InterimFunction {
name: String,
arguments: Value,
}
let interim = serde_json::from_value::<InterimToolFnCall>(raw_tool_call)?;
let fn_name = interim.function.name;
// For now, only Object is supported; an eventual String is parsed as a JSON value.
// Eventually, we might check pricing
let fn_arguments = match interim.function.arguments {
Value::Object(obj) => Value::Object(obj),
Value::String(txt) => serde_json::from_str(&txt)?,
_ => {
return Err(Error::InvalidJsonResponseElement {
info: "tool call arguments is not an object",
})
}
};
// Then, map the fields of the helper struct to the flat structure.
Ok(ToolCall {
call_id: interim.id,
fn_name,
fn_arguments,
})
}
// endregion: --- Support
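To illustrate why `parse_tool_call` special-cases `Value::String`: OpenAI returns `function.arguments` as a JSON-encoded string, which gets parsed back into an object. A minimal sketch, assuming same-module access to the private `parse_tool_call` above and the crate's `Result` alias; the payload values are illustrative:

```rust
use serde_json::json;

fn demo_parse_tool_call() -> Result<()> {
    // Illustrative raw tool_call, shaped like an OpenAI response element;
    // note that "arguments" arrives as a JSON object encoded in a string.
    let raw_tool_call = json!({
        "id": "call_abc123", // illustrative id
        "type": "function",
        "function": {
            "name": "get_weather",
            "arguments": "{\"city\": \"Paris\", \"unit\": \"C\"}"
        }
    });

    let tool_call = parse_tool_call(raw_tool_call)?;
    assert_eq!(tool_call.fn_name, "get_weather");
    assert_eq!(tool_call.fn_arguments["city"], "Paris"); // now a real JSON object
    Ok(())
}
```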

src/chat/chat_message.rs (new file, 71 lines)
View File

@ -0,0 +1,71 @@
use crate::chat::{MessageContent, ToolCall, ToolResponse};
use serde::{Deserialize, Serialize};
/// An individual chat message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// The role of the message.
pub role: ChatRole,
/// The content of the message.
pub content: MessageContent,
}
/// Constructors
impl ChatMessage {
/// Create a new ChatMessage with the role `ChatRole::System`.
pub fn system(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::System,
content: content.into(),
}
}
/// Create a new ChatMessage with the role `ChatRole::Assistant`.
pub fn assistant(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::Assistant,
content: content.into(),
}
}
/// Create a new ChatMessage with the role `ChatRole::User`.
pub fn user(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::User,
content: content.into(),
}
}
}
/// Chat roles.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum ChatRole {
System,
User,
Assistant,
Tool,
}
// region: --- Froms
impl From<Vec<ToolCall>> for ChatMessage {
fn from(tool_calls: Vec<ToolCall>) -> Self {
Self {
role: ChatRole::Assistant,
content: MessageContent::from(tool_calls),
}
}
}
impl From<ToolResponse> for ChatMessage {
fn from(value: ToolResponse) -> Self {
Self {
role: ChatRole::Tool,
content: MessageContent::from(value),
}
}
}
// endregion: --- Froms
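These `From` impls are what allow tool calls and tool responses to be pushed straight into `ChatRequest::append_message`. A minimal sketch (the id and arguments are illustrative):

```rust
use genai::chat::{ChatMessage, ToolCall, ToolResponse};
use serde_json::json;

fn demo_from_impls() -> (ChatMessage, ChatMessage) {
    // A Vec<ToolCall> converts into an Assistant message...
    let tool_call = ToolCall {
        call_id: "call_abc123".to_string(), // illustrative id
        fn_name: "get_weather".to_string(),
        fn_arguments: json!({ "city": "Paris" }),
    };
    let assistant_msg: ChatMessage = vec![tool_call].into();

    // ...and a ToolResponse converts into a Tool message.
    let tool_msg: ChatMessage =
        ToolResponse::new("call_abc123", r#"{"weather": "Sunny"}"#).into();

    (assistant_msg, tool_msg)
}
```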

View File

@ -5,7 +5,7 @@
//! Note 1: In the future, we will probably allow setting the client
//! Note 2: Extracting it from the `ChatRequest` object allows for better reusability of each component.
use crate::chat::chat_response_format::ChatResponseFormat;
use crate::chat::chat_req_response_format::ChatResponseFormat;
use serde::{Deserialize, Serialize};
/// Chat Options that are taken into account for any `Client::exec...` calls.

View File

@ -2,7 +2,7 @@ use derive_more::From;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// The chat response format to be sent back by the LLM.
/// The chat response format for the ChatRequest for structured output.
/// This will be taken into consideration only if the provider supports it.
///
/// > Note: Currently, the AI Providers will not report an error if not supported. It will just be ignored.

View File

@ -1,6 +1,6 @@
//! This module contains all the types related to a Chat Request (except ChatOptions, which has its own file).
use crate::chat::MessageContent;
use crate::chat::{ChatMessage, ChatRole, MessageContent, Tool};
use serde::{Deserialize, Serialize};
// region: --- ChatRequest
@ -13,13 +13,19 @@ pub struct ChatRequest {
/// The messages of the request.
pub messages: Vec<ChatMessage>,
pub tools: Option<Vec<Tool>>,
}
/// Constructors
impl ChatRequest {
/// Create a new ChatRequest with the given messages.
pub fn new(messages: Vec<ChatMessage>) -> Self {
Self { messages, system: None }
Self {
messages,
system: None,
tools: None,
}
}
/// From the `.system` property content.
@ -27,6 +33,7 @@ impl ChatRequest {
Self {
system: Some(content.into()),
messages: Vec::new(),
tools: None,
}
}
@ -34,7 +41,8 @@ impl ChatRequest {
pub fn from_user(content: impl Into<String>) -> Self {
Self {
system: None,
messages: vec![ChatMessage::user(content)],
messages: vec![ChatMessage::user(content.into())],
tools: None,
}
}
}
@ -48,8 +56,18 @@ impl ChatRequest {
}
/// Append a message to the request.
pub fn append_message(mut self, msg: ChatMessage) -> Self {
self.messages.push(msg);
pub fn append_message(mut self, msg: impl Into<ChatMessage>) -> Self {
self.messages.push(msg.into());
self
}
pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
self.tools = Some(tools);
self
}
pub fn append_tool(mut self, tool: impl Into<Tool>) -> Self {
self.tools.get_or_insert_with(Vec::new).push(tool.into());
self
}
}
@ -65,6 +83,8 @@ impl ChatRequest {
.chain(self.messages.iter().filter_map(|message| match message.role {
ChatRole::System => match message.content {
MessageContent::Text(ref content) => Some(content.as_str()),
// If the system content is not text, we do not add it for now.
_ => None,
},
_ => None,
}))
@ -97,74 +117,3 @@ impl ChatRequest {
}
// endregion: --- ChatRequest
// region: --- ChatMessage
/// An individual chat message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// The role of the message.
pub role: ChatRole,
/// The content of the message.
pub content: MessageContent,
/// Extra information about the message.
pub extra: Option<MessageExtra>,
}
/// Constructors
impl ChatMessage {
/// Create a new ChatMessage with the role `ChatRole::System`.
pub fn system(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::System,
content: content.into(),
extra: None,
}
}
/// Create a new ChatMessage with the role `ChatRole::Assistant`.
pub fn assistant(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::Assistant,
content: content.into(),
extra: None,
}
}
/// Create a new ChatMessage with the role `ChatRole::User`.
pub fn user(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::User,
content: content.into(),
extra: None,
}
}
}
/// Chat roles.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum ChatRole {
System,
User,
Assistant,
Tool,
}
/// NOTE: DO NOT USE, just a placeholder for now.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum MessageExtra {
Tool(ToolExtra),
}
/// NOTE: DO NOT USE, just a placeholder for now.
#[allow(unused)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExtra {
tool_id: String,
}
// endregion: --- ChatMessage

View File

@ -2,7 +2,7 @@
use serde::{Deserialize, Serialize};
use crate::chat::{ChatStream, MessageContent};
use crate::chat::{ChatStream, MessageContent, ToolCall};
use crate::ModelIden;
// region: --- ChatResponse
@ -13,12 +13,12 @@ pub struct ChatResponse {
/// The eventual content of the chat response
pub content: Option<MessageContent>,
/// The eventual usage of the chat response
pub usage: MetaUsage,
/// The Model Identifier (AdapterKind/ModelName) used for this request.
/// > NOTE: This might be different from the request model if changed by the ModelMapper
pub model_iden: ModelIden,
/// The eventual usage of the chat response
pub usage: MetaUsage,
}
// Getters
@ -34,6 +34,22 @@ impl ChatResponse {
pub fn content_text_into_string(self) -> Option<String> {
self.content.and_then(MessageContent::text_into_string)
}
pub fn tool_calls(&self) -> Option<Vec<&ToolCall>> {
if let Some(MessageContent::ToolCalls(tool_calls)) = self.content.as_ref() {
Some(tool_calls.iter().collect())
} else {
None
}
}
pub fn into_tool_calls(self) -> Option<Vec<ToolCall>> {
if let Some(MessageContent::ToolCalls(tool_calls)) = self.content {
Some(tool_calls)
} else {
None
}
}
}
// endregion: --- ChatResponse
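A quick sketch of how the two new getters are meant to be used, assuming a `ChatResponse` obtained from `exec_chat`:

```rust
use genai::chat::ChatResponse;

fn inspect(chat_res: ChatResponse) {
    // Borrowing access: look at the tool calls without consuming the response.
    if let Some(tool_calls) = chat_res.tool_calls() {
        for tc in tool_calls {
            println!("{} called with {}", tc.fn_name, tc.fn_arguments);
        }
    }
    // Owning access: consume the response when only the tool calls matter.
    let _tool_calls = chat_res.into_tool_calls();
}
```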

View File

@ -1,19 +1,34 @@
use crate::chat::{ToolCall, ToolResponse};
use derive_more::derive::From;
use serde::{Deserialize, Serialize};
/// The message content. Currently supports Text, ToolCalls, and ToolResponses,
/// with the goal of supporting multi-part message content later (see below)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, From)]
pub enum MessageContent {
/// Text content
Text(String),
/// Tool calls
#[from]
ToolCalls(Vec<ToolCall>),
/// Tool call Responses
#[from]
ToolResponses(Vec<ToolResponse>),
}
/// Constructors
impl MessageContent {
/// Create a new MessageContent with the Text variant
pub fn text(content: impl Into<String>) -> Self {
pub fn from_text(content: impl Into<String>) -> Self {
MessageContent::Text(content.into())
}
/// Create a new MessageContent with the ToolCalls variant
pub fn from_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
MessageContent::ToolCalls(tool_calls)
}
}
/// Getters
@ -25,6 +40,8 @@ impl MessageContent {
pub fn text_as_str(&self) -> Option<&str> {
match self {
MessageContent::Text(content) => Some(content.as_str()),
MessageContent::ToolCalls(_) => None,
MessageContent::ToolResponses(_) => None,
}
}
@ -36,29 +53,44 @@ impl MessageContent {
pub fn text_into_string(self) -> Option<String> {
match self {
MessageContent::Text(content) => Some(content),
MessageContent::ToolCalls(_) => None,
MessageContent::ToolResponses(_) => None,
}
}
/// Checks if the text content is empty (for now)
/// Later, this will also validate each variant to check if they can be considered "empty"
/// Checks if the text content, the tool calls, or the tool responses are empty.
pub fn is_empty(&self) -> bool {
match self {
MessageContent::Text(content) => content.is_empty(),
MessageContent::ToolCalls(tool_calls) => tool_calls.is_empty(),
MessageContent::ToolResponses(tool_responses) => tool_responses.is_empty(),
}
}
}
// region: --- Froms
/// Blanket implementation for MessageContent::Text for anything that implements Into<String>
/// Note: This means that when we support base64 as images, it should not use `.into()` for MessageContent.
/// It should be acceptable but may need reassessment.
impl<T> From<T> for MessageContent
where
T: Into<String>,
{
fn from(s: T) -> Self {
MessageContent::text(s)
impl From<String> for MessageContent {
fn from(s: String) -> Self {
MessageContent::from_text(s)
}
}
impl<'a> From<&'a str> for MessageContent {
fn from(s: &'a str) -> Self {
MessageContent::from_text(s.to_string())
}
}
impl From<&String> for MessageContent {
fn from(s: &String) -> Self {
MessageContent::from_text(s.clone())
}
}
impl From<ToolResponse> for MessageContent {
fn from(tool_response: ToolResponse) -> Self {
MessageContent::ToolResponses(vec![tool_response])
}
}
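With the blanket `Into<String>` impl replaced by the three concrete impls above, text conversions keep working unchanged. A quick sketch:

```rust
use genai::chat::MessageContent;

fn demo_text_froms() {
    let _from_str: MessageContent = "hello".into(); // &str
    let _from_string: MessageContent = String::from("hello").into(); // String
    let owned = String::from("hello");
    let _from_ref: MessageContent = (&owned).into(); // &String
}
```

The blanket impl presumably had to go because, under Rust's coherence rules, it would conflict with the new `#[from]`-derived impls for `Vec<ToolCall>` and `Vec<ToolResponse>`.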

View File

@ -3,19 +3,21 @@
// region: --- Modules
mod chat_message;
mod chat_options;
mod chat_req;
mod chat_req_response_format;
mod chat_request;
mod chat_res;
mod chat_response_format;
mod chat_stream;
mod message_content;
mod tool;
// -- Flatten
pub use chat_message::*;
pub use chat_options::*;
pub use chat_req::*;
pub use chat_req_response_format::*;
pub use chat_request::*;
pub use chat_res::*;
pub use chat_response_format::*;
pub use chat_stream::*;
pub use message_content::*;
pub use tool::*;

View File

@ -1,14 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// NOT USED FOR NOW
/// > For later, it will be used for function calling
/// > It will probably use the JsonSpec type we had in the response format,
/// > or have a `From<JsonSpec>` implementation.
#[allow(unused)] // Not used yet
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
fn_name: String,
fn_description: String,
params: Value,
}

src/chat/tool/mod.rs (new file, 11 lines)
View File

@ -0,0 +1,11 @@
// region: --- Modules
mod tool_base;
mod tool_call;
mod tool_response;
pub use tool_base::*;
pub use tool_call::*;
pub use tool_response::*;
// endregion: --- Modules

View File

@ -0,0 +1,64 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
/// The tool name, which is typically the function name
/// e.g., `get_weather`
pub name: String,
/// The description of the tool which will be used by the LLM to understand the context/usage of this tool
pub description: Option<String>,
/// The json-schema for the parameters
/// e.g.,
/// ```json
/// json!({
/// "type": "object",
/// "properties": {
/// "city": {
/// "type": "string",
/// "description": "The city name"
/// },
/// "country": {
/// "type": "string",
/// "description": "The most likely country of this city name"
/// },
/// "unit": {
/// "type": "string",
/// "enum": ["C", "F"],
/// "description": "The temperature unit of the country. C for Celsius, and F for Fahrenheit"
/// }
/// },
/// "required": ["city", "country", "unit"],
/// })
/// ```
pub schema: Option<Value>,
}
/// Constructor
impl Tool {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
description: None,
schema: None,
}
}
}
// region: --- Setters
impl Tool {
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn with_schema(mut self, parameters: Value) -> Self {
self.schema = Some(parameters);
self
}
}
// endregion: --- Setters
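A small usage sketch of the builder, mirroring the `seed_chat_req_tool_simple` fixture below (the description text is illustrative):

```rust
use genai::chat::Tool;
use serde_json::json;

fn demo_tool() -> Tool {
    Tool::new("get_weather")
        .with_description("Get the current weather for a city") // optional
        .with_schema(json!({
            "type": "object",
            "properties": {
                "city": { "type": "string", "description": "The city name" }
            },
            "required": ["city"],
        }))
}
```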

View File

@ -0,0 +1,10 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// The tool call function name and arguments sent back by the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub call_id: String,
pub fn_name: String,
pub fn_arguments: Value,
}

View File

@ -0,0 +1,29 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResponse {
pub call_id: String,
// For now, just a String (would probably be serialized JSON)
pub content: String,
}
/// Constructor
impl ToolResponse {
pub fn new(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
call_id: tool_call_id.into(),
content: content.into(),
}
}
}
/// Getters
impl ToolResponse {
fn tool_call_id(&self) -> &str {
&self.call_id
}
fn content(&self) -> &str {
&self.content
}
}

View File

@ -23,12 +23,19 @@ pub enum Error {
model_iden: ModelIden,
role: ChatRole,
},
MessageContentTypeNotSupported {
model_iden: ModelIden,
cause: &'static str,
},
JsonModeWithoutInstruction,
// -- Chat Output
NoChatResponse {
model_iden: ModelIden,
},
InvalidJsonResponseElement {
info: &'static str,
},
// -- Auth
RequiresApiKey {
@ -77,14 +84,15 @@ pub enum Error {
resolver_error: resolver::Error,
},
// -- Utils
// -- Externals
#[from]
EventSourceClone(reqwest_eventsource::CannotCloneRequestError),
#[from]
JsonValueExt(JsonValueExtError),
ReqwestEventSource(reqwest_eventsource::Error),
// Note: will probably need to remove this one to give more context
#[from]
SerdeJson(serde_json::Error),
}
// region: --- Error Boilerplate

View File

@ -1,6 +1,6 @@
use crate::get_option_value;
use crate::support::{extract_stream_end, seed_chat_req_simple, Result};
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponseFormat, JsonSpec};
use crate::support::{extract_stream_end, seed_chat_req_simple, seed_chat_req_tool_simple, Result};
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponseFormat, JsonSpec, Tool, ToolResponse};
use genai::resolver::{AuthData, AuthResolver, AuthResolverFn, IntoAuthResolverFn};
use genai::{Client, ClientConfig, ModelIden};
use serde_json::{json, Value};
@ -260,6 +260,68 @@ pub async fn common_test_chat_stream_capture_all_ok(model: &str) -> Result<()> {
// endregion: --- Chat Stream Tests
// region: --- Tools
/// Just makes the tool request and checks the tool call response.
/// `complete_check` is for LLMs that are better at giving back the unit and weather.
pub async fn common_test_tool_simple_ok(model: &str, complete_check: bool) -> Result<()> {
// -- Setup & Fixtures
let client = Client::default();
let chat_req = seed_chat_req_tool_simple();
// -- Exec
let chat_res = client.exec_chat(model, chat_req, None).await?;
// -- Check
let mut tool_calls = chat_res.tool_calls().ok_or("Should have tool calls")?;
let tool_call = tool_calls.pop().ok_or("Should have at least one tool call")?;
assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("city")?, "Paris");
assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("country")?, "France");
if complete_check {
// Note: Not all LLMs will give back the unit (e.g. Anthropic Haiku)
assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("unit")?, "C");
}
Ok(())
}
/// `complete_check` is for LLMs that are better at giving back the unit and weather.
///
pub async fn common_test_tool_full_flow_ok(model: &str, complete_check: bool) -> Result<()> {
// -- Setup & Fixtures
let client = Client::default();
let mut chat_req = seed_chat_req_tool_simple();
// -- Exec first request to get the tool calls
let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
let tool_calls = chat_res.into_tool_calls().ok_or("Should have tool calls in chat_res")?;
// -- Exec the second request
// get the tool call id (first one)
let first_tool_call = tool_calls.first().ok_or("Should have at least one tool call")?;
let first_tool_call_id = &first_tool_call.call_id;
// simulate the response
let tool_response = ToolResponse::new(first_tool_call_id, r#"{"weather": "Sunny", "temperature": "32C"}"#);
// Add the tool_calls and tool_response messages
let chat_req = chat_req.append_message(tool_calls).append_message(tool_response);
let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
// -- Check
let content = chat_res.content_text_as_str().ok_or("Last response should be a message")?;
assert!(content.contains("Paris"), "Should contain 'Paris'");
assert!(content.contains("32"), "Should contain '32'");
if complete_check {
// Note: Not all LLMs will mention the weather condition (e.g. Anthropic Haiku)
assert!(content.contains("sunny"), "Should contain 'sunny'");
}
Ok(())
}
// endregion: --- Tools
// region: --- With Resolvers
pub async fn common_test_resolver_auth_ok(model: &str, auth_data: AuthData) -> Result<()> {

View File

@ -1,4 +1,5 @@
use genai::chat::{ChatMessage, ChatRequest};
use genai::chat::{ChatMessage, ChatRequest, Tool};
use serde_json::json;
pub fn seed_chat_req_simple() -> ChatRequest {
ChatRequest::new(vec![
@ -7,3 +8,29 @@ pub fn seed_chat_req_simple() -> ChatRequest {
ChatMessage::user("Why is the sky blue?"),
])
}
pub fn seed_chat_req_tool_simple() -> ChatRequest {
ChatRequest::new(vec![
// -- Messages (deactivate to see the differences)
ChatMessage::user("What is the temperature in C, in Paris"),
])
.append_tool(Tool::new("get_weather").with_schema(json!({
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name"
},
"country": {
"type": "string",
"description": "The most likely country of this city name"
},
"unit": {
"type": "string",
"enum": ["C", "F"],
"description": "The temperature unit of the country. C for Celsius, and F for Fahrenheit"
}
},
"required": ["city", "country", "unit"],
})))
}

View File

@ -45,6 +45,20 @@ async fn test_chat_stream_capture_all_ok() -> Result<()> {
// endregion: --- Chat Stream Tests
// region: --- Tool Tests
#[tokio::test]
async fn test_tool_simple_ok() -> Result<()> {
common_tests::common_test_tool_simple_ok(MODEL, false).await
}
#[tokio::test]
async fn test_tool_full_flow_ok() -> Result<()> {
common_tests::common_test_tool_full_flow_ok(MODEL, false).await
}
// endregion: --- Tool Tests
// region: --- Resolver Tests
#[tokio::test]

View File

@ -50,6 +50,19 @@ async fn test_chat_stream_capture_all_ok() -> Result<()> {
// endregion: --- Chat Stream Tests
// region: --- Tool Tests
#[tokio::test]
async fn test_tool_simple_ok() -> Result<()> {
common_tests::common_test_tool_simple_ok(MODEL, true).await
}
#[tokio::test]
async fn test_tool_full_flow_ok() -> Result<()> {
common_tests::common_test_tool_full_flow_ok(MODEL, true).await
}
// endregion: --- Tool Tests
// region: --- Resolver Tests
#[tokio::test]