+ tool - First pass at adding Function Calling for OpenAI and Anthropic (rel #24)

This commit is contained in:
Jeremy Chone
2024-10-22 12:30:16 -07:00
parent 4b76a5ead6
commit 001b124381
23 changed files with 740 additions and 177 deletions

View File

@@ -11,8 +11,8 @@ repository = "https://github.com/jeremychone/rust-genai"
[lints.rust] [lints.rust]
unsafe_code = "forbid" unsafe_code = "forbid"
# unused = { level = "allow", priority = -1 } # For exploratory dev. unused = { level = "allow", priority = -1 } # For exploratory dev.
missing_docs = "warn" # missing_docs = "warn"
[dependencies] [dependencies]
# -- Async # -- Async

View File

@@ -3,10 +3,11 @@ use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage, ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage,
ToolCall,
}; };
use crate::webc::WebResponse; use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden}; use crate::{ClientConfig, ModelIden};
use crate::{Error, Result};
use reqwest::RequestBuilder; use reqwest::RequestBuilder;
use reqwest_eventsource::EventSource; use reqwest_eventsource::EventSource;
use serde_json::{json, Value}; use serde_json::{json, Value};
@@ -18,9 +19,9 @@ const BASE_URL: &str = "https://api.anthropic.com/v1/";
const MAX_TOKENS: u32 = 1024; const MAX_TOKENS: u32 = 1024;
const ANTRHOPIC_VERSION: &str = "2023-06-01"; const ANTRHOPIC_VERSION: &str = "2023-06-01";
const MODELS: &[&str] = &[ const MODELS: &[&str] = &[
"claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620",
"claude-3-opus-20240229", "claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
]; ];
@@ -53,7 +54,7 @@ impl Adapter for AnthropicAdapter {
let url = Self::get_service_url(model_iden.clone(), service_type); let url = Self::get_service_url(model_iden.clone(), service_type);
// -- api_key (this Adapter requires it) // -- api_key (this Adapter requires it)
let api_key = get_api_key(model_iden, client_config)?; let api_key = get_api_key(model_iden.clone(), client_config)?;
let headers = vec![ let headers = vec![
// headers // headers
@@ -61,7 +62,11 @@ impl Adapter for AnthropicAdapter {
("anthropic-version".to_string(), ANTRHOPIC_VERSION.to_string()), ("anthropic-version".to_string(), ANTRHOPIC_VERSION.to_string()),
]; ];
let AnthropicRequestParts { system, messages } = Self::into_anthropic_request_parts(chat_req)?; let AnthropicRequestParts {
system,
messages,
tools,
} = Self::into_anthropic_request_parts(model_iden, chat_req)?;
// -- Build the basic payload // -- Build the basic payload
let mut payload = json!({ let mut payload = json!({
@@ -69,10 +74,15 @@ impl Adapter for AnthropicAdapter {
"messages": messages, "messages": messages,
"stream": stream "stream": stream
}); });
if let Some(system) = system { if let Some(system) = system {
payload.x_insert("system", system)?; payload.x_insert("system", system)?;
} }
if let Some(tools) = tools {
payload.x_insert("/tools", tools);
}
// -- Add supported ChatOptions // -- Add supported ChatOptions
if let Some(temperature) = options_set.temperature() { if let Some(temperature) = options_set.temperature() {
payload.x_insert("temperature", temperature)?; payload.x_insert("temperature", temperature)?;
@@ -90,27 +100,51 @@ impl Adapter for AnthropicAdapter {
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> { fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
let WebResponse { mut body, .. } = web_response; let WebResponse { mut body, .. } = web_response;
let json_content_items: Vec<Value> = body.x_take("content")?;
let mut content: Vec<String> = Vec::new();
// -- Capture the usage
let usage = body.x_take("usage").map(Self::into_usage).unwrap_or_default(); let usage = body.x_take("usage").map(Self::into_usage).unwrap_or_default();
for mut item in json_content_items { // -- Capture the content
let item_text: String = item.x_take("text")?; // NOTE: Anthropic support a list of content of multitypes but not the ChatResponse
content.push(item_text); // So, the strategy is to:
// - List all of the content and capture the text and tool_use
// - If there is one or more tool_use, this will take precedence and MessageContent support tool_call list
// - Otherwise, the text is concatenated
// NOTE: We need to see if the multiple content type text happens and why. If not, we can probably simplify this by just capturing the first one.
// Eventually, ChatResponse will have `content: Option<Vec<MessageContent>>` for the multi parts (with images and such)
let content_items: Vec<Value> = body.x_take("content")?;
let mut text_content: Vec<String> = Vec::new();
// Note: here tool_calls is probably the exception, so, not creating the vector if not needed
let mut tool_calls: Option<Vec<ToolCall>> = None;
for mut item in content_items {
let typ: &str = item.x_get_as("type")?;
if typ == "text" {
text_content.push(item.x_take("text")?);
} else if typ == "tool_use" {
let call_id = item.x_take::<String>("id")?;
let fn_name = item.x_take::<String>("name")?;
// if not found, will be Value::Null
let fn_arguments = item.x_take::<Value>("input").unwrap_or_default();
let tool_call = ToolCall {
call_id,
fn_name,
fn_arguments,
};
tool_calls.get_or_insert_with(Vec::new).push(tool_call);
}
} }
let content = if content.is_empty() { let content = if let Some(tool_calls) = tool_calls {
None Some(MessageContent::from(tool_calls))
} else { } else {
Some(content.join("")) Some(MessageContent::from(text_content.join("\n")))
}; };
let content = content.map(MessageContent::from);
Ok(ChatResponse { Ok(ChatResponse {
model_iden,
content, content,
model_iden,
usage, usage,
}) })
} }
@@ -153,7 +187,7 @@ impl AnthropicAdapter {
/// Takes the GenAI ChatMessages and constructs the System string and JSON Messages for Anthropic. /// Takes the GenAI ChatMessages and constructs the System string and JSON Messages for Anthropic.
/// - Will push the `ChatRequest.system` and system message to `AnthropicRequestParts.system` /// - Will push the `ChatRequest.system` and system message to `AnthropicRequestParts.system`
fn into_anthropic_request_parts(chat_req: ChatRequest) -> Result<AnthropicRequestParts> { fn into_anthropic_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result<AnthropicRequestParts> {
let mut messages: Vec<Value> = Vec::new(); let mut messages: Vec<Value> = Vec::new();
let mut systems: Vec<String> = Vec::new(); let mut systems: Vec<String> = Vec::new();
@@ -161,32 +195,115 @@ impl AnthropicAdapter {
systems.push(system); systems.push(system);
} }
// -- Process the messages
for msg in chat_req.messages { for msg in chat_req.messages {
// Note: Will handle more types later
let MessageContent::Text(content) = msg.content;
match msg.role { match msg.role {
// for now, system and tool messages go to system // for now, system and tool messages go to system
ChatRole::System | ChatRole::Tool => systems.push(content), ChatRole::System => {
ChatRole::User => messages.push(json! ({"role": "user", "content": content})), if let MessageContent::Text(content) = msg.content {
ChatRole::Assistant => messages.push(json! ({"role": "assistant", "content": content})), systems.push(content)
}
// TODO: Needs to trace/warn that other type are not supported
}
ChatRole::User => {
if let MessageContent::Text(content) = msg.content {
messages.push(json! ({"role": "user", "content": content}))
}
// TODO: Needs to trace/warn that other type are not supported
}
ChatRole::Assistant => {
//
match msg.content {
MessageContent::Text(content) => {
messages.push(json! ({"role": "assistant", "content": content}))
}
MessageContent::ToolCalls(tool_calls) => {
let tool_calls = tool_calls
.into_iter()
.map(|tool_call| {
// see: https://docs.anthropic.com/en/docs/build-with-claude/tool-use#example-of-successful-tool-result
json!({
"type": "tool_use",
"id": tool_call.call_id,
"name": tool_call.fn_name,
"input": tool_call.fn_arguments,
})
})
.collect::<Vec<Value>>();
messages.push(json! ({
"role": "assistant",
"content": tool_calls
}));
}
// TODO: Probably need to trace/warn that this will be ignored
MessageContent::ToolResponses(_) => (),
}
}
ChatRole::Tool => {
if let MessageContent::ToolResponses(tool_responses) = msg.content {
let tool_responses = tool_responses
.into_iter()
.map(|tool_response| {
json!({
"type": "tool_result",
"content": tool_response.content,
"tool_use_id": tool_response.call_id,
})
})
.collect::<Vec<Value>>();
// FIXME: MessageContent::ToolResponse should be MessageContent::ToolResponses (even if openAI does require multi Tool message)
messages.push(json!({
"role": "user",
"content": tool_responses
}));
}
// TODO: Probably need to trace/warn that this will be ignored
}
} }
} }
// -- Create the Anthropic system
// NOTE: Anthropic does not have a "role": "system", just a single optional system property
let system = if !systems.is_empty() { let system = if !systems.is_empty() {
Some(systems.join("\n")) Some(systems.join("\n"))
} else { } else {
None None
}; };
Ok(AnthropicRequestParts { system, messages }) // -- Process the tools
let tools = chat_req.tools.map(|tools| {
tools
.into_iter()
.map(|tool| {
// TODO: Need to handle the error correctly
// TODO: Needs to have a custom serializer (tool should not have to match to a provider)
// NOTE: Right now, low probability, so, we just return null if cannto to value.
let mut tool_value = json!({
"name": tool.name,
"input_schema": tool.schema,
});
if let Some(description) = tool.description {
tool_value.x_insert("description", description);
}
tool_value
})
.collect::<Vec<Value>>()
});
Ok(AnthropicRequestParts {
system,
messages,
tools,
})
} }
} }
struct AnthropicRequestParts { struct AnthropicRequestParts {
system: Option<String>, system: Option<String>,
messages: Vec<Value>, messages: Vec<Value>,
// TODO: need to add tools tools: Option<Vec<Value>>,
} }
// endregion: --- Support // endregion: --- Support

View File

@@ -1,4 +1,5 @@
//! API Documentation: https://docs.anthropic.com/en/api/messages //! API Documentation: https://docs.anthropic.com/en/api/messages
//! Tool Documentation: https://docs.anthropic.com/en/docs/build-with-claude/tool-use
//! Model Names: https://docs.anthropic.com/en/docs/models-overview //! Model Names: https://docs.anthropic.com/en/docs/models-overview
//! Pricing: https://www.anthropic.com/pricing#anthropic-api //! Pricing: https://www.anthropic.com/pricing#anthropic-api

View File

@@ -110,8 +110,8 @@ impl Adapter for CohereAdapter {
.map(MessageContent::from); .map(MessageContent::from);
Ok(ChatResponse { Ok(ChatResponse {
model_iden,
content, content,
model_iden,
usage, usage,
}) })
} }
@@ -185,13 +185,23 @@ impl CohereAdapter {
actual_role: last_chat_msg.role, actual_role: last_chat_msg.role,
}); });
} }
// Will handle more types later
let MessageContent::Text(message) = last_chat_msg.content; // TODO: Needs to implement tool_calls
let MessageContent::Text(message) = last_chat_msg.content else {
return Err(Error::MessageContentTypeNotSupported {
model_iden,
cause: "Only MessageContent::Text supported for this model (for now)",
});
};
// -- Build // -- Build
for msg in chat_req.messages { for msg in chat_req.messages {
// Note: Will handle more types later let MessageContent::Text(content) = msg.content else {
let MessageContent::Text(content) = msg.content; return Err(Error::MessageContentTypeNotSupported {
model_iden,
cause: "Only MessageContent::Text supported for this model (for now)",
});
};
match msg.role { match msg.role {
// For now, system and tool go to the system // For now, system and tool go to the system

View File

@@ -121,8 +121,8 @@ impl Adapter for GeminiAdapter {
let content = content.map(MessageContent::from); let content = content.map(MessageContent::from);
Ok(ChatResponse { Ok(ChatResponse {
model_iden,
content, content,
model_iden,
usage, usage,
}) })
} }
@@ -192,8 +192,13 @@ impl GeminiAdapter {
// -- Build // -- Build
for msg in chat_req.messages { for msg in chat_req.messages {
// Note: Will handle more types later // TODO: Needs to implement tool_calls
let MessageContent::Text(content) = msg.content; let MessageContent::Text(content) = msg.content else {
return Err(Error::MessageContentTypeNotSupported {
model_iden,
cause: "Only MessageContent::Text supported for this model (for now)",
});
};
match msg.role { match msg.role {
// For now, system goes as "user" (later, we might have adapter_config.system_to_user_impl) // For now, system goes as "user" (later, we might have adapter_config.system_to_user_impl)

View File

@@ -3,13 +3,14 @@ use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData}; use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse, ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
MessageContent, MetaUsage, MessageContent, MetaUsage, Tool, ToolCall,
}; };
use crate::webc::WebResponse; use crate::webc::WebResponse;
use crate::{ClientConfig, ModelIden}; use crate::{ClientConfig, ModelIden};
use crate::{Error, Result}; use crate::{Error, Result};
use reqwest::RequestBuilder; use reqwest::RequestBuilder;
use reqwest_eventsource::EventSource; use reqwest_eventsource::EventSource;
use serde::Deserialize;
use serde_json::{json, Value}; use serde_json::{json, Value};
use value_ext::JsonValueExt; use value_ext::JsonValueExt;
@@ -54,15 +55,31 @@ impl Adapter for OpenAIAdapter {
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> { fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
let WebResponse { mut body, .. } = web_response; let WebResponse { mut body, .. } = web_response;
// -- Capture the usage
let usage = body.x_take("usage").map(OpenAIAdapter::into_usage).unwrap_or_default(); let usage = body.x_take("usage").map(OpenAIAdapter::into_usage).unwrap_or_default();
let first_choice: Option<Value> = body.x_take("/choices/0")?; // -- Capture the content
let content: Option<String> = first_choice.map(|mut c| c.x_take("/message/content")).transpose()?; let content = if let Some(mut first_choice) = body.x_take::<Option<Value>>("/choices/0")? {
let content = content.map(MessageContent::from); if let Some(content) = first_choice
.x_take::<Option<String>>("/message/content")?
.map(MessageContent::from)
{
Some(content)
} else {
first_choice
.x_take("/message/tool_calls")
.ok()
.map(parse_tool_calls)
.transpose()?
.map(MessageContent::from_tool_calls)
}
} else {
None
};
Ok(ChatResponse { Ok(ChatResponse {
model_iden,
content, content,
model_iden,
usage, usage,
}) })
} }
@@ -117,13 +134,17 @@ impl OpenAIAdapter {
// -- Build the basic payload // -- Build the basic payload
let model_name = model_iden.model_name.to_string(); let model_name = model_iden.model_name.to_string();
let OpenAIRequestParts { messages } = Self::into_openai_request_parts(model_iden, chat_req)?; let OpenAIRequestParts { messages, tools } = Self::into_openai_request_parts(model_iden, chat_req)?;
let mut payload = json!({ let mut payload = json!({
"model": model_name, "model": model_name,
"messages": messages, "messages": messages,
"stream": stream "stream": stream
}); });
if let Some(tools) = tools {
payload.x_insert("/tools", tools);
}
// -- Add options // -- Add options
let response_format = if let Some(response_format) = options_set.response_format() { let response_format = if let Some(response_format) = options_set.response_format() {
match response_format { match response_format {
@@ -199,16 +220,17 @@ impl OpenAIAdapter {
/// Takes the genai ChatMessages and builds the OpenAIChatRequestParts /// Takes the genai ChatMessages and builds the OpenAIChatRequestParts
/// - `genai::ChatRequest.system`, if present, is added as the first message with role 'system'. /// - `genai::ChatRequest.system`, if present, is added as the first message with role 'system'.
/// - All messages get added with the corresponding roles (tools are not supported for now) /// - All messages get added with the corresponding roles (tools are not supported for now)
///
/// NOTE: Here, the last `true` is for the Ollama variant
/// It seems the Ollama compatibility layer does not work well with multiple system messages.
/// So, when `true`, it will concatenate the system message into a single one at the beginning
fn into_openai_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result<OpenAIRequestParts> { fn into_openai_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result<OpenAIRequestParts> {
let mut system_messages: Vec<String> = Vec::new();
let mut messages: Vec<Value> = Vec::new(); let mut messages: Vec<Value> = Vec::new();
/// NOTE: For now system_messages is use to fix an issue with the Ollama compatibility layer that does not support multiple system messages.
/// So, when ollama, it will concatenate the system message into a single one at the beginning
/// NOTE: This might be fixed now, so, we could remove this.
let mut system_messages: Vec<String> = Vec::new();
let ollama_variant = matches!(model_iden.adapter_kind, AdapterKind::Ollama); let ollama_variant = matches!(model_iden.adapter_kind, AdapterKind::Ollama);
// -- Process the system
if let Some(system_msg) = chat_req.system { if let Some(system_msg) = chat_req.system {
if ollama_variant { if ollama_variant {
system_messages.push(system_msg) system_messages.push(system_msg)
@@ -217,13 +239,15 @@ impl OpenAIAdapter {
} }
} }
// -- Process the messages
for msg in chat_req.messages { for msg in chat_req.messages {
// Note: Will handle more types later // Note: Will handle more types later
let MessageContent::Text(content) = msg.content;
match msg.role { match msg.role {
// For now, system and tool messages go to the system // For now, system and tool messages go to the system
ChatRole::System => { ChatRole::System => {
if let MessageContent::Text(content) = msg.content {
// NOTE: Ollama does not support multiple system messages
// See note in the function comment // See note in the function comment
if ollama_variant { if ollama_variant {
system_messages.push(content); system_messages.push(content);
@@ -231,23 +255,82 @@ impl OpenAIAdapter {
messages.push(json!({"role": "system", "content": content})) messages.push(json!({"role": "system", "content": content}))
} }
} }
ChatRole::User => messages.push(json! ({"role": "user", "content": content})), // TODO: Probably need to warn if it is a ToolCalls type of content
ChatRole::Assistant => messages.push(json! ({"role": "assistant", "content": content})), }
ChatRole::Tool => { ChatRole::User => {
return Err(Error::MessageRoleNotSupported { if let MessageContent::Text(content) = msg.content {
model_iden, messages.push(json! ({"role": "user", "content": content}));
role: ChatRole::Tool, }
// TODO: Probably need to warn if it is a ToolCalls type of content
}
ChatRole::Assistant => match msg.content {
MessageContent::Text(content) => messages.push(json! ({"role": "assistant", "content": content})),
MessageContent::ToolCalls(tool_calls) => {
let tool_calls = tool_calls
.into_iter()
.map(|tool_call| {
json!({
"type": "function",
"id": tool_call.call_id,
"function": {
"name": tool_call.fn_name,
"arguments": tool_call.fn_arguments.to_string(),
}
}) })
})
.collect::<Vec<Value>>();
messages.push(json! ({"role": "assistant", "tool_calls": tool_calls}))
}
// TODO: Probably need to trace/warn that this will be ignored
MessageContent::ToolResponses(_) => (),
},
ChatRole::Tool => {
if let MessageContent::ToolResponses(tool_responses) = msg.content {
for tool_response in tool_responses {
messages.push(json!({
"role": "tool",
"content": tool_response.content,
"tool_call_id": tool_response.call_id,
}))
}
}
// TODO: Probably need to trace/warn that this will be ignored
} }
} }
} }
// -- Finalize the system messages ollama case
if !system_messages.is_empty() { if !system_messages.is_empty() {
let system_message = system_messages.join("\n"); let system_message = system_messages.join("\n");
messages.insert(0, json!({"role": "system", "content": system_message})); messages.insert(0, json!({"role": "system", "content": system_message}));
} }
Ok(OpenAIRequestParts { messages }) // -- Process the tools
let tools = chat_req.tools.map(|tools| {
tools
.into_iter()
.map(|tool| {
// TODO: Need to handle the error correctly
// TODO: Needs to have a custom serializer (tool should not have to match to a provider)
// NOTE: Right now, low probability, so, we just return null if cannto to value.
json!({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.schema,
// TODO: If we need to support `strict: true` we need to add additionalProperties: false into the schema
// above (like structured output)
"strict": false,
}
})
})
.collect::<Vec<Value>>()
});
Ok(OpenAIRequestParts { messages, tools })
} }
} }
@@ -255,6 +338,59 @@ impl OpenAIAdapter {
struct OpenAIRequestParts { struct OpenAIRequestParts {
messages: Vec<Value>, messages: Vec<Value>,
tools: Option<Vec<Value>>,
}
fn parse_tool_calls(raw_tool_calls: Value) -> Result<Vec<ToolCall>> {
let Value::Array(raw_tool_calls) = raw_tool_calls else {
return Err(Error::InvalidJsonResponseElement {
info: "tool calls is not an array",
});
};
let tool_calls = raw_tool_calls.into_iter().map(parse_tool_call).collect::<Result<Vec<_>>>()?;
Ok(tool_calls)
}
fn parse_tool_call(raw_tool_call: Value) -> Result<ToolCall> {
// Define a helper struct to match the original JSON structure.
#[derive(Deserialize)]
struct IterimToolFnCall {
id: String,
#[serde(rename = "type")]
r#type: String,
function: IterimFunction,
}
#[derive(Deserialize)]
struct IterimFunction {
name: String,
arguments: Value,
}
let iterim = serde_json::from_value::<IterimToolFnCall>(raw_tool_call)?;
let fn_name = iterim.function.name;
// For now support Object only, and parse the eventual string as a json value.
// Eventually, we might check pricing
let fn_arguments = match iterim.function.arguments {
Value::Object(obj) => Value::Object(obj),
Value::String(txt) => serde_json::from_str(&txt)?,
_ => {
return Err(Error::InvalidJsonResponseElement {
info: "tool call arguments is not an object",
})
}
};
// Then, map the fields of the helper struct to the flat structure.
Ok(ToolCall {
call_id: iterim.id,
fn_name,
fn_arguments,
})
} }
// endregion: --- Support // endregion: --- Support

71
src/chat/chat_message.rs Normal file
View File

@@ -0,0 +1,71 @@
use crate::chat::{MessageContent, ToolCall, ToolResponse};
use serde::{Deserialize, Serialize};
/// An individual chat message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// The role of the message.
pub role: ChatRole,
/// The content of the message.
pub content: MessageContent,
}
/// Constructors
impl ChatMessage {
/// Create a new ChatMessage with the role `ChatRole::System`.
pub fn system(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::System,
content: content.into(),
}
}
/// Create a new ChatMessage with the role `ChatRole::Assistant`.
pub fn assistant(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::Assistant,
content: content.into(),
}
}
/// Create a new ChatMessage with the role `ChatRole::User`.
pub fn user(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::User,
content: content.into(),
}
}
}
/// Chat roles.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum ChatRole {
System,
User,
Assistant,
Tool,
}
// region: --- Froms
impl From<Vec<ToolCall>> for ChatMessage {
fn from(tool_calls: Vec<ToolCall>) -> Self {
Self {
role: ChatRole::Assistant,
content: MessageContent::from(tool_calls),
}
}
}
impl From<ToolResponse> for ChatMessage {
fn from(value: ToolResponse) -> Self {
Self {
role: ChatRole::Tool,
content: MessageContent::from(value),
}
}
}
// endregion: --- Froms

View File

@@ -5,7 +5,7 @@
//! Note 1: In the future, we will probably allow setting the client //! Note 1: In the future, we will probably allow setting the client
//! Note 2: Extracting it from the `ChatRequest` object allows for better reusability of each component. //! Note 2: Extracting it from the `ChatRequest` object allows for better reusability of each component.
use crate::chat::chat_response_format::ChatResponseFormat; use crate::chat::chat_req_response_format::ChatResponseFormat;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Chat Options that are taken into account for any `Client::exec...` calls. /// Chat Options that are taken into account for any `Client::exec...` calls.

View File

@@ -2,7 +2,7 @@ use derive_more::From;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
/// The chat response format to be sent back by the LLM. /// The chat response format for the ChatRequest for structured output.
/// This will be taken into consideration only if the provider supports it. /// This will be taken into consideration only if the provider supports it.
/// ///
/// > Note: Currently, the AI Providers will not report an error if not supported. It will just be ignored. /// > Note: Currently, the AI Providers will not report an error if not supported. It will just be ignored.

View File

@@ -1,6 +1,6 @@
//! This module contains all the types related to a Chat Request (except ChatOptions, which has its own file). //! This module contains all the types related to a Chat Request (except ChatOptions, which has its own file).
use crate::chat::MessageContent; use crate::chat::{ChatMessage, ChatRole, MessageContent, Tool};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
// region: --- ChatRequest // region: --- ChatRequest
@@ -13,13 +13,19 @@ pub struct ChatRequest {
/// The messages of the request. /// The messages of the request.
pub messages: Vec<ChatMessage>, pub messages: Vec<ChatMessage>,
pub tools: Option<Vec<Tool>>,
} }
/// Constructors /// Constructors
impl ChatRequest { impl ChatRequest {
/// Create a new ChatRequest with the given messages. /// Create a new ChatRequest with the given messages.
pub fn new(messages: Vec<ChatMessage>) -> Self { pub fn new(messages: Vec<ChatMessage>) -> Self {
Self { messages, system: None } Self {
messages,
system: None,
tools: None,
}
} }
/// From the `.system` property content. /// From the `.system` property content.
@@ -27,6 +33,7 @@ impl ChatRequest {
Self { Self {
system: Some(content.into()), system: Some(content.into()),
messages: Vec::new(), messages: Vec::new(),
tools: None,
} }
} }
@@ -34,7 +41,8 @@ impl ChatRequest {
pub fn from_user(content: impl Into<String>) -> Self { pub fn from_user(content: impl Into<String>) -> Self {
Self { Self {
system: None, system: None,
messages: vec![ChatMessage::user(content)], messages: vec![ChatMessage::user(content.into())],
tools: None,
} }
} }
} }
@@ -48,8 +56,18 @@ impl ChatRequest {
} }
/// Append a message to the request. /// Append a message to the request.
pub fn append_message(mut self, msg: ChatMessage) -> Self { pub fn append_message(mut self, msg: impl Into<ChatMessage>) -> Self {
self.messages.push(msg); self.messages.push(msg.into());
self
}
pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
self.tools = Some(tools);
self
}
pub fn append_tool(mut self, tool: impl Into<Tool>) -> Self {
self.tools.get_or_insert_with(Vec::new).push(tool.into());
self self
} }
} }
@@ -65,6 +83,8 @@ impl ChatRequest {
.chain(self.messages.iter().filter_map(|message| match message.role { .chain(self.messages.iter().filter_map(|message| match message.role {
ChatRole::System => match message.content { ChatRole::System => match message.content {
MessageContent::Text(ref content) => Some(content.as_str()), MessageContent::Text(ref content) => Some(content.as_str()),
/// If system content is not text, then, we do not add it for now.
_ => None,
}, },
_ => None, _ => None,
})) }))
@@ -97,74 +117,3 @@ impl ChatRequest {
} }
// endregion: --- ChatRequest // endregion: --- ChatRequest
// region: --- ChatMessage
/// An individual chat message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
/// The role of the message.
pub role: ChatRole,
/// The content of the message.
pub content: MessageContent,
/// Extra information about the message.
pub extra: Option<MessageExtra>,
}
/// Constructors
impl ChatMessage {
/// Create a new ChatMessage with the role `ChatRole::System`.
pub fn system(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::System,
content: content.into(),
extra: None,
}
}
/// Create a new ChatMessage with the role `ChatRole::Assistant`.
pub fn assistant(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::Assistant,
content: content.into(),
extra: None,
}
}
/// Create a new ChatMessage with the role `ChatRole::User`.
pub fn user(content: impl Into<MessageContent>) -> Self {
Self {
role: ChatRole::User,
content: content.into(),
extra: None,
}
}
}
/// Chat roles.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum ChatRole {
System,
User,
Assistant,
Tool,
}
/// NOTE: DO NOT USE, just a placeholder for now.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum MessageExtra {
Tool(ToolExtra),
}
/// NOTE: DO NOT USE, just a placeholder for now.
#[allow(unused)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolExtra {
tool_id: String,
}
// endregion: --- ChatMessage

View File

@@ -2,7 +2,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::chat::{ChatStream, MessageContent}; use crate::chat::{ChatStream, MessageContent, ToolCall};
use crate::ModelIden; use crate::ModelIden;
// region: --- ChatResponse // region: --- ChatResponse
@@ -13,12 +13,12 @@ pub struct ChatResponse {
/// The eventual content of the chat response /// The eventual content of the chat response
pub content: Option<MessageContent>, pub content: Option<MessageContent>,
/// The eventual usage of the chat response
pub usage: MetaUsage,
/// The Model Identifier (AdapterKind/ModelName) used for this request. /// The Model Identifier (AdapterKind/ModelName) used for this request.
/// > NOTE: This might be different from the request model if changed by the ModelMapper /// > NOTE: This might be different from the request model if changed by the ModelMapper
pub model_iden: ModelIden, pub model_iden: ModelIden,
/// The eventual usage of the chat response
pub usage: MetaUsage,
} }
// Getters // Getters
@@ -34,6 +34,22 @@ impl ChatResponse {
pub fn content_text_into_string(self) -> Option<String> { pub fn content_text_into_string(self) -> Option<String> {
self.content.and_then(MessageContent::text_into_string) self.content.and_then(MessageContent::text_into_string)
} }
pub fn tool_calls(&self) -> Option<Vec<&ToolCall>> {
if let Some(MessageContent::ToolCalls(tool_calls)) = self.content.as_ref() {
Some(tool_calls.iter().collect())
} else {
None
}
}
pub fn into_tool_calls(self) -> Option<Vec<ToolCall>> {
if let Some(MessageContent::ToolCalls(tool_calls)) = self.content {
Some(tool_calls)
} else {
None
}
}
} }
// endregion: --- ChatResponse // endregion: --- ChatResponse

View File

@@ -1,19 +1,34 @@
use crate::chat::{ToolCall, ToolResponse};
use derive_more::derive::From;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// Currently, it only supports Text, /// Currently, it only supports Text,
/// but the goal is to support multi-part message content (see below) /// but the goal is to support multi-part message content (see below)
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, From)]
pub enum MessageContent { pub enum MessageContent {
/// Text content /// Text content
Text(String), Text(String),
/// Tool calls
#[from]
ToolCalls(Vec<ToolCall>),
/// Tool call Responses
#[from]
ToolResponses(Vec<ToolResponse>),
} }
/// Constructors /// Constructors
impl MessageContent { impl MessageContent {
/// Create a new MessageContent with the Text variant /// Create a new MessageContent with the Text variant
pub fn text(content: impl Into<String>) -> Self { pub fn from_text(content: impl Into<String>) -> Self {
MessageContent::Text(content.into()) MessageContent::Text(content.into())
} }
/// Create a new MessageContent with the ToolCalls variant
pub fn from_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
MessageContent::ToolCalls(tool_calls)
}
} }
/// Getters /// Getters
@@ -25,6 +40,8 @@ impl MessageContent {
pub fn text_as_str(&self) -> Option<&str> { pub fn text_as_str(&self) -> Option<&str> {
match self { match self {
MessageContent::Text(content) => Some(content.as_str()), MessageContent::Text(content) => Some(content.as_str()),
MessageContent::ToolCalls(_) => None,
MessageContent::ToolResponses(_) => None,
} }
} }
@@ -36,29 +53,44 @@ impl MessageContent {
pub fn text_into_string(self) -> Option<String> { pub fn text_into_string(self) -> Option<String> {
match self { match self {
MessageContent::Text(content) => Some(content), MessageContent::Text(content) => Some(content),
MessageContent::ToolCalls(_) => None,
MessageContent::ToolResponses(_) => None,
} }
} }
/// Checks if the text content is empty (for now) /// Checks if the text content or the tools calls is empty.
/// Later, this will also validate each variant to check if they can be considered "empty"
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
match self { match self {
MessageContent::Text(content) => content.is_empty(), MessageContent::Text(content) => content.is_empty(),
MessageContent::ToolCalls(tool_calls) => tool_calls.is_empty(),
MessageContent::ToolResponses(tool_responses) => tool_responses.is_empty(),
} }
} }
} }
// region: --- Froms // region: --- Froms
/// Blanket implementation for MessageContent::Text for anything that implements Into<String> impl From<String> for MessageContent {
/// Note: This means that when we support base64 as images, it should not use `.into()` for MessageContent. fn from(s: String) -> Self {
/// It should be acceptable but may need reassessment. MessageContent::from_text(s)
impl<T> From<T> for MessageContent }
where }
T: Into<String>,
{ impl<'a> From<&'a str> for MessageContent {
fn from(s: T) -> Self { fn from(s: &'a str) -> Self {
MessageContent::text(s) MessageContent::from_text(s.to_string())
}
}
impl From<&String> for MessageContent {
fn from(s: &String) -> Self {
MessageContent::from_text(s.clone())
}
}
impl From<ToolResponse> for MessageContent {
fn from(tool_response: ToolResponse) -> Self {
MessageContent::ToolResponses(vec![tool_response])
} }
} }

View File

@@ -3,19 +3,21 @@
// region: --- Modules // region: --- Modules
mod chat_message;
mod chat_options; mod chat_options;
mod chat_req; mod chat_req_response_format;
mod chat_request;
mod chat_res; mod chat_res;
mod chat_response_format;
mod chat_stream; mod chat_stream;
mod message_content; mod message_content;
mod tool; mod tool;
// -- Flatten // -- Flatten
pub use chat_message::*;
pub use chat_options::*; pub use chat_options::*;
pub use chat_req::*; pub use chat_req_response_format::*;
pub use chat_request::*;
pub use chat_res::*; pub use chat_res::*;
pub use chat_response_format::*;
pub use chat_stream::*; pub use chat_stream::*;
pub use message_content::*; pub use message_content::*;
pub use tool::*; pub use tool::*;

View File

@@ -1,14 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// NOT USED FOR NOW
/// > For later, it will be used for function calling
/// > It will probably use the JsonSpec type we had in the response format,
/// > or have a `From<JsonSpec>` implementation.
#[allow(unused)] // Not used yet
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
fn_name: String,
fn_description: String,
params: Value,
}

11
src/chat/tool/mod.rs Normal file
View File

@@ -0,0 +1,11 @@
// region: --- Modules

// Sub-modules of the tool/function-calling support:
// - `tool_base`: the `Tool` definition sent to the LLM (name, description, json-schema)
// - `tool_call`: the `ToolCall` the LLM sends back (call id, fn name, arguments)
// - `tool_response`: the `ToolResponse` the caller appends after executing a tool
mod tool_base;
mod tool_call;
mod tool_response;

pub use tool_base::*;
pub use tool_call::*;
pub use tool_response::*;

// endregion: --- Modules

View File

@@ -0,0 +1,64 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// A tool (typically a function) that can be given to the LLM so it may
/// answer with a tool call instead of plain text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
	/// The tool name, which is typically the function name
	/// e.g., `get_weather`
	pub name: String,

	/// The description of the tool which will be used by the LLM to understand the context/usage of this tool
	pub description: Option<String>,

	/// The json-schema for the parameters
	/// e.g.,
	/// ```json
	/// json!({
	///   "type": "object",
	///   "properties": {
	///     "city": {
	///       "type": "string",
	///       "description": "The city name"
	///     },
	///     "country": {
	///       "type": "string",
	///       "description": "The most likely country of this city name"
	///     },
	///     "unit": {
	///       "type": "string",
	///       "enum": ["C", "F"],
	///       "description": "The temperature unit of the country. C for Celsius, and F for Fahrenheit"
	///     }
	///   },
	///   "required": ["city", "country", "unit"],
	/// })
	/// ```
	pub schema: Option<Value>,
}
/// Constructor
impl Tool {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
description: None,
schema: None,
}
}
}
// region: --- Setters
impl Tool {
	/// Sets the tool description (used by the LLM to understand the context/usage of this tool).
	pub fn with_description(mut self, description: impl Into<String>) -> Self {
		self.description = Some(description.into());
		self
	}

	/// Sets the json-schema describing the tool parameters.
	// Parameter renamed from `parameters` to `schema` to match the field it sets.
	pub fn with_schema(mut self, schema: Value) -> Self {
		self.schema = Some(schema);
		self
	}
}
// endregion: --- Setters

View File

@@ -0,0 +1,10 @@
use serde::{Deserialize, Deserializer, Serialize};
use serde_json::Value;
/// The tool call function name and arguments sent back by the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
	/// The provider-issued id of this call; it is passed back in the matching `ToolResponse`.
	pub call_id: String,
	/// The name of the function the LLM wants invoked.
	pub fn_name: String,
	/// The function arguments as a JSON value.
	// NOTE(review): presumably follows the tool's declared json-schema — providers may not always honor it.
	pub fn_arguments: Value,
}

View File

@@ -0,0 +1,29 @@
use serde::{Deserialize, Serialize};
/// The response the caller sends back to the LLM after executing a tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResponse {
	/// The id of the `ToolCall` this response answers.
	pub call_id: String,
	// for now, just string (would probably be serialized json)
	pub content: String,
}
/// constructor
impl ToolResponse {
pub fn new(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
call_id: tool_call_id.into(),
content: content.into(),
}
}
}
/// Getters
impl ToolResponse {
	/// The id of the tool call this response answers.
	// Made `pub`: private getters on a public type were unreachable by users (dead code).
	pub fn tool_call_id(&self) -> &str {
		&self.call_id
	}

	/// The response content (for now a plain string, typically serialized JSON).
	pub fn content(&self) -> &str {
		&self.content
	}
}

View File

@@ -23,12 +23,19 @@ pub enum Error {
model_iden: ModelIden, model_iden: ModelIden,
role: ChatRole, role: ChatRole,
}, },
MessageContentTypeNotSupported {
model_iden: ModelIden,
cause: &'static str,
},
JsonModeWithoutInstruction, JsonModeWithoutInstruction,
// -- Chat Output // -- Chat Output
NoChatResponse { NoChatResponse {
model_iden: ModelIden, model_iden: ModelIden,
}, },
InvalidJsonResponseElement {
info: &'static str,
},
// -- Auth // -- Auth
RequiresApiKey { RequiresApiKey {
@@ -77,14 +84,15 @@ pub enum Error {
resolver_error: resolver::Error, resolver_error: resolver::Error,
}, },
// -- Utils
// -- Externals // -- Externals
#[from] #[from]
EventSourceClone(reqwest_eventsource::CannotCloneRequestError), EventSourceClone(reqwest_eventsource::CannotCloneRequestError),
#[from] #[from]
JsonValueExt(JsonValueExtError), JsonValueExt(JsonValueExtError),
ReqwestEventSource(reqwest_eventsource::Error), ReqwestEventSource(reqwest_eventsource::Error),
// Note: will probably need to remvoe this one to give more context
#[from]
SerdeJson(serde_json::Error),
} }
// region: --- Error Boilerplate // region: --- Error Boilerplate

View File

@@ -1,6 +1,6 @@
use crate::get_option_value; use crate::get_option_value;
use crate::support::{extract_stream_end, seed_chat_req_simple, Result}; use crate::support::{extract_stream_end, seed_chat_req_simple, seed_chat_req_tool_simple, Result};
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponseFormat, JsonSpec}; use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponseFormat, JsonSpec, Tool, ToolResponse};
use genai::resolver::{AuthData, AuthResolver, AuthResolverFn, IntoAuthResolverFn}; use genai::resolver::{AuthData, AuthResolver, AuthResolverFn, IntoAuthResolverFn};
use genai::{Client, ClientConfig, ModelIden}; use genai::{Client, ClientConfig, ModelIden};
use serde_json::{json, Value}; use serde_json::{json, Value};
@@ -260,6 +260,68 @@ pub async fn common_test_chat_stream_capture_all_ok(model: &str) -> Result<()> {
// endregion: --- Chat Stream Tests // endregion: --- Chat Stream Tests
// region: --- Tools
/// Just making the tool request, and checking the tool call response.
/// `complete_check` is for LLMs that are better at giving back the unit and weather.
pub async fn common_test_tool_simple_ok(model: &str, complete_check: bool) -> Result<()> {
	// -- Setup & Fixtures
	let client = Client::default();
	let chat_req = seed_chat_req_tool_simple();

	// -- Exec
	let chat_res = client.exec_chat(model, chat_req, None).await?;

	// -- Check
	// The seed request asks for the weather in Paris, so the LLM should answer
	// with a `get_weather` tool call carrying city/country (and ideally unit).
	let mut tool_calls = chat_res.tool_calls().ok_or("Should have tool calls")?;
	let tool_call = tool_calls.pop().ok_or("Should have at least one tool call")?;
	assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("city")?, "Paris");
	assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("country")?, "France");
	if complete_check {
		// Note: Not all LLM will output the weather (e.g. Anthropic Haiku)
		assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("unit")?, "C");
	}

	Ok(())
}
/// `complete_check` is for LLMs that are better at giving back the unit and weather.
///
/// Full tool flow: the first request returns the tool calls; the simulated tool
/// response is then appended, and a second request produces the final answer.
pub async fn common_test_tool_full_flow_ok(model: &str, complete_check: bool) -> Result<()> {
	// -- Setup & Fixtures
	let client = Client::default();
	// No `mut` needed: the request is only shadowed below, never mutated in place.
	let chat_req = seed_chat_req_tool_simple();

	// -- Exec first request to get the tool calls
	let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
	let tool_calls = chat_res.into_tool_calls().ok_or("Should have tool calls in chat_res")?;

	// -- Exec the second request
	// get the tool call id (first one)
	let first_tool_call = tool_calls.first().ok_or("Should have at least one tool call")?;
	let first_tool_call_id = &first_tool_call.call_id;
	// simulate the response
	let tool_response = ToolResponse::new(first_tool_call_id, r#"{"weather": "Sunny", "temperature": "32C"}"#);
	// Add the tool_calls, tool_response
	let chat_req = chat_req.append_message(tool_calls).append_message(tool_response);
	// Last use of chat_req, so no clone needed here.
	let chat_res = client.exec_chat(model, chat_req, None).await?;

	// -- Check
	let content = chat_res.content_text_as_str().ok_or("Last response should be message")?;
	assert!(content.contains("Paris"), "Should contain 'Paris'");
	assert!(content.contains("32"), "Should contain '32'");
	if complete_check {
		// Note: Not all LLM will output the weather (e.g. Anthropic Haiku)
		// Case-insensitive: the simulated payload says "Sunny" and models may echo either case.
		assert!(content.to_lowercase().contains("sunny"), "Should contain 'sunny'");
	}

	Ok(())
}
// endregion: --- Tools
// region: --- With Resolvers // region: --- With Resolvers
pub async fn common_test_resolver_auth_ok(model: &str, auth_data: AuthData) -> Result<()> { pub async fn common_test_resolver_auth_ok(model: &str, auth_data: AuthData) -> Result<()> {

View File

@@ -1,4 +1,5 @@
use genai::chat::{ChatMessage, ChatRequest}; use genai::chat::{ChatMessage, ChatRequest, Tool};
use serde_json::json;
pub fn seed_chat_req_simple() -> ChatRequest { pub fn seed_chat_req_simple() -> ChatRequest {
ChatRequest::new(vec![ ChatRequest::new(vec![
@@ -7,3 +8,29 @@ pub fn seed_chat_req_simple() -> ChatRequest {
ChatMessage::user("Why is the sky blue?"), ChatMessage::user("Why is the sky blue?"),
]) ])
} }
/// Seeds a chat request asking for the weather in Paris, with the
/// `get_weather` tool attached so the LLM can answer with a tool call.
pub fn seed_chat_req_tool_simple() -> ChatRequest {
	// json-schema for the get_weather tool parameters
	let weather_schema = json!({
		"type": "object",
		"properties": {
			"city": {
				"type": "string",
				"description": "The city name"
			},
			"country": {
				"type": "string",
				"description": "The most likely country of this city name"
			},
			"unit": {
				"type": "string",
				"enum": ["C", "F"],
				"description": "The temperature unit of the country. C for Celsius, and F for Fahrenheit"
			}
		},
		"required": ["city", "country", "unit"],
	});

	// -- Messages (deactivate to see the differences)
	let messages = vec![ChatMessage::user("What is the temperature in C, in Paris")];

	ChatRequest::new(messages).append_tool(Tool::new("get_weather").with_schema(weather_schema))
}

View File

@@ -45,6 +45,20 @@ async fn test_chat_stream_capture_all_ok() -> Result<()> {
// endregion: --- Chat Stream Tests // endregion: --- Chat Stream Tests
// region: --- Tool Tests
#[tokio::test]
async fn test_tool_simple_ok() -> Result<()> {
	// complete_check = false: this model is not expected to reliably include the unit/weather
	// details (see `common_test_tool_simple_ok`).
	common_tests::common_test_tool_simple_ok(MODEL, false).await
}
#[tokio::test]
async fn test_tool_full_flow_ok() -> Result<()> {
	// complete_check = false: this model is not expected to reliably echo the weather
	// details in its final answer (see `common_test_tool_full_flow_ok`).
	common_tests::common_test_tool_full_flow_ok(MODEL, false).await
}
// endregion: --- Tool Tests
// region: --- Resolver Tests // region: --- Resolver Tests
#[tokio::test] #[tokio::test]

View File

@@ -50,6 +50,19 @@ async fn test_chat_stream_capture_all_ok() -> Result<()> {
// endregion: --- Chat Stream Tests // endregion: --- Chat Stream Tests
// region: --- Tool Tests
#[tokio::test]
async fn test_tool_simple_ok() -> Result<()> {
	// complete_check = true: this model is expected to give back the unit as well
	// (see `common_test_tool_simple_ok`).
	common_tests::common_test_tool_simple_ok(MODEL, true).await
}
#[tokio::test]
async fn test_tool_full_flow_ok() -> Result<()> {
	// complete_check = true: this model is expected to echo the weather details
	// in its final answer (see `common_test_tool_full_flow_ok`).
	common_tests::common_test_tool_full_flow_ok(MODEL, true).await
}
// endregion: --- Tool Tests
// region: --- Resolver Tests // region: --- Resolver Tests
#[tokio::test] #[tokio::test]