mirror of
https://github.com/mii443/rust-genai.git
synced 2025-08-22 16:25:27 +00:00
+ tool - First pass at adding Function Calling for OpenAI and Anthropic (rel #24)
This commit is contained in:
@ -11,8 +11,8 @@ repository = "https://github.com/jeremychone/rust-genai"
|
||||
|
||||
[lints.rust]
|
||||
unsafe_code = "forbid"
|
||||
# unused = { level = "allow", priority = -1 } # For exploratory dev.
|
||||
missing_docs = "warn"
|
||||
unused = { level = "allow", priority = -1 } # For exploratory dev.
|
||||
# missing_docs = "warn"
|
||||
|
||||
[dependencies]
|
||||
# -- Async
|
||||
|
@ -3,10 +3,11 @@ use crate::adapter::support::get_api_key;
|
||||
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
|
||||
use crate::chat::{
|
||||
ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage,
|
||||
ToolCall,
|
||||
};
|
||||
use crate::webc::WebResponse;
|
||||
use crate::Result;
|
||||
use crate::{ClientConfig, ModelIden};
|
||||
use crate::{Error, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use reqwest_eventsource::EventSource;
|
||||
use serde_json::{json, Value};
|
||||
@ -18,9 +19,9 @@ const BASE_URL: &str = "https://api.anthropic.com/v1/";
|
||||
const MAX_TOKENS: u32 = 1024;
|
||||
const ANTRHOPIC_VERSION: &str = "2023-06-01";
|
||||
const MODELS: &[&str] = &[
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
];
|
||||
|
||||
@ -53,7 +54,7 @@ impl Adapter for AnthropicAdapter {
|
||||
let url = Self::get_service_url(model_iden.clone(), service_type);
|
||||
|
||||
// -- api_key (this Adapter requires it)
|
||||
let api_key = get_api_key(model_iden, client_config)?;
|
||||
let api_key = get_api_key(model_iden.clone(), client_config)?;
|
||||
|
||||
let headers = vec![
|
||||
// headers
|
||||
@ -61,7 +62,11 @@ impl Adapter for AnthropicAdapter {
|
||||
("anthropic-version".to_string(), ANTRHOPIC_VERSION.to_string()),
|
||||
];
|
||||
|
||||
let AnthropicRequestParts { system, messages } = Self::into_anthropic_request_parts(chat_req)?;
|
||||
let AnthropicRequestParts {
|
||||
system,
|
||||
messages,
|
||||
tools,
|
||||
} = Self::into_anthropic_request_parts(model_iden, chat_req)?;
|
||||
|
||||
// -- Build the basic payload
|
||||
let mut payload = json!({
|
||||
@ -69,10 +74,15 @@ impl Adapter for AnthropicAdapter {
|
||||
"messages": messages,
|
||||
"stream": stream
|
||||
});
|
||||
|
||||
if let Some(system) = system {
|
||||
payload.x_insert("system", system)?;
|
||||
}
|
||||
|
||||
if let Some(tools) = tools {
|
||||
payload.x_insert("/tools", tools);
|
||||
}
|
||||
|
||||
// -- Add supported ChatOptions
|
||||
if let Some(temperature) = options_set.temperature() {
|
||||
payload.x_insert("temperature", temperature)?;
|
||||
@ -90,27 +100,51 @@ impl Adapter for AnthropicAdapter {
|
||||
|
||||
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
|
||||
let WebResponse { mut body, .. } = web_response;
|
||||
let json_content_items: Vec<Value> = body.x_take("content")?;
|
||||
|
||||
let mut content: Vec<String> = Vec::new();
|
||||
|
||||
// -- Capture the usage
|
||||
let usage = body.x_take("usage").map(Self::into_usage).unwrap_or_default();
|
||||
|
||||
for mut item in json_content_items {
|
||||
let item_text: String = item.x_take("text")?;
|
||||
content.push(item_text);
|
||||
// -- Capture the content
|
||||
// NOTE: Anthropic support a list of content of multitypes but not the ChatResponse
|
||||
// So, the strategy is to:
|
||||
// - List all of the content and capture the text and tool_use
|
||||
// - If there is one or more tool_use, this will take precedence and MessageContent support tool_call list
|
||||
// - Otherwise, the text is concatenated
|
||||
// NOTE: We need to see if the multiple content type text happens and why. If not, we can probably simplify this by just capturing the first one.
|
||||
// Eventually, ChatResponse will have `content: Option<Vec<MessageContent>>` for the multi parts (with images and such)
|
||||
let content_items: Vec<Value> = body.x_take("content")?;
|
||||
|
||||
let mut text_content: Vec<String> = Vec::new();
|
||||
// Note: here tool_calls is probably the exception, so, not creating the vector if not needed
|
||||
let mut tool_calls: Option<Vec<ToolCall>> = None;
|
||||
|
||||
for mut item in content_items {
|
||||
let typ: &str = item.x_get_as("type")?;
|
||||
if typ == "text" {
|
||||
text_content.push(item.x_take("text")?);
|
||||
} else if typ == "tool_use" {
|
||||
let call_id = item.x_take::<String>("id")?;
|
||||
let fn_name = item.x_take::<String>("name")?;
|
||||
// if not found, will be Value::Null
|
||||
let fn_arguments = item.x_take::<Value>("input").unwrap_or_default();
|
||||
let tool_call = ToolCall {
|
||||
call_id,
|
||||
fn_name,
|
||||
fn_arguments,
|
||||
};
|
||||
tool_calls.get_or_insert_with(Vec::new).push(tool_call);
|
||||
}
|
||||
}
|
||||
|
||||
let content = if content.is_empty() {
|
||||
None
|
||||
let content = if let Some(tool_calls) = tool_calls {
|
||||
Some(MessageContent::from(tool_calls))
|
||||
} else {
|
||||
Some(content.join(""))
|
||||
Some(MessageContent::from(text_content.join("\n")))
|
||||
};
|
||||
let content = content.map(MessageContent::from);
|
||||
|
||||
Ok(ChatResponse {
|
||||
model_iden,
|
||||
content,
|
||||
model_iden,
|
||||
usage,
|
||||
})
|
||||
}
|
||||
@ -153,7 +187,7 @@ impl AnthropicAdapter {
|
||||
|
||||
/// Takes the GenAI ChatMessages and constructs the System string and JSON Messages for Anthropic.
|
||||
/// - Will push the `ChatRequest.system` and system message to `AnthropicRequestParts.system`
|
||||
fn into_anthropic_request_parts(chat_req: ChatRequest) -> Result<AnthropicRequestParts> {
|
||||
fn into_anthropic_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result<AnthropicRequestParts> {
|
||||
let mut messages: Vec<Value> = Vec::new();
|
||||
let mut systems: Vec<String> = Vec::new();
|
||||
|
||||
@ -161,32 +195,115 @@ impl AnthropicAdapter {
|
||||
systems.push(system);
|
||||
}
|
||||
|
||||
// -- Process the messages
|
||||
for msg in chat_req.messages {
|
||||
// Note: Will handle more types later
|
||||
let MessageContent::Text(content) = msg.content;
|
||||
|
||||
match msg.role {
|
||||
// for now, system and tool messages go to system
|
||||
ChatRole::System | ChatRole::Tool => systems.push(content),
|
||||
ChatRole::User => messages.push(json! ({"role": "user", "content": content})),
|
||||
ChatRole::Assistant => messages.push(json! ({"role": "assistant", "content": content})),
|
||||
ChatRole::System => {
|
||||
if let MessageContent::Text(content) = msg.content {
|
||||
systems.push(content)
|
||||
}
|
||||
// TODO: Needs to trace/warn that other type are not supported
|
||||
}
|
||||
ChatRole::User => {
|
||||
if let MessageContent::Text(content) = msg.content {
|
||||
messages.push(json! ({"role": "user", "content": content}))
|
||||
}
|
||||
// TODO: Needs to trace/warn that other type are not supported
|
||||
}
|
||||
ChatRole::Assistant => {
|
||||
//
|
||||
match msg.content {
|
||||
MessageContent::Text(content) => {
|
||||
messages.push(json! ({"role": "assistant", "content": content}))
|
||||
}
|
||||
MessageContent::ToolCalls(tool_calls) => {
|
||||
let tool_calls = tool_calls
|
||||
.into_iter()
|
||||
.map(|tool_call| {
|
||||
// see: https://docs.anthropic.com/en/docs/build-with-claude/tool-use#example-of-successful-tool-result
|
||||
json!({
|
||||
"type": "tool_use",
|
||||
"id": tool_call.call_id,
|
||||
"name": tool_call.fn_name,
|
||||
"input": tool_call.fn_arguments,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<Value>>();
|
||||
messages.push(json! ({
|
||||
"role": "assistant",
|
||||
"content": tool_calls
|
||||
}));
|
||||
}
|
||||
// TODO: Probably need to trace/warn that this will be ignored
|
||||
MessageContent::ToolResponses(_) => (),
|
||||
}
|
||||
}
|
||||
ChatRole::Tool => {
|
||||
if let MessageContent::ToolResponses(tool_responses) = msg.content {
|
||||
let tool_responses = tool_responses
|
||||
.into_iter()
|
||||
.map(|tool_response| {
|
||||
json!({
|
||||
"type": "tool_result",
|
||||
"content": tool_response.content,
|
||||
"tool_use_id": tool_response.call_id,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<Value>>();
|
||||
|
||||
// FIXME: MessageContent::ToolResponse should be MessageContent::ToolResponses (even if openAI does require multi Tool message)
|
||||
messages.push(json!({
|
||||
"role": "user",
|
||||
"content": tool_responses
|
||||
}));
|
||||
}
|
||||
// TODO: Probably need to trace/warn that this will be ignored
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -- Create the Anthropic system
|
||||
// NOTE: Anthropic does not have a "role": "system", just a single optional system property
|
||||
let system = if !systems.is_empty() {
|
||||
Some(systems.join("\n"))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(AnthropicRequestParts { system, messages })
|
||||
// -- Process the tools
|
||||
let tools = chat_req.tools.map(|tools| {
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
// TODO: Need to handle the error correctly
|
||||
// TODO: Needs to have a custom serializer (tool should not have to match to a provider)
|
||||
// NOTE: Right now, low probability, so, we just return null if cannto to value.
|
||||
let mut tool_value = json!({
|
||||
"name": tool.name,
|
||||
"input_schema": tool.schema,
|
||||
});
|
||||
|
||||
if let Some(description) = tool.description {
|
||||
tool_value.x_insert("description", description);
|
||||
}
|
||||
tool_value
|
||||
})
|
||||
.collect::<Vec<Value>>()
|
||||
});
|
||||
|
||||
Ok(AnthropicRequestParts {
|
||||
system,
|
||||
messages,
|
||||
tools,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct AnthropicRequestParts {
|
||||
system: Option<String>,
|
||||
messages: Vec<Value>,
|
||||
// TODO: need to add tools
|
||||
tools: Option<Vec<Value>>,
|
||||
}
|
||||
|
||||
// endregion: --- Support
|
||||
|
@ -1,4 +1,5 @@
|
||||
//! API Documentation: https://docs.anthropic.com/en/api/messages
|
||||
//! Tool Documentation: https://docs.anthropic.com/en/docs/build-with-claude/tool-use
|
||||
//! Model Names: https://docs.anthropic.com/en/docs/models-overview
|
||||
//! Pricing: https://www.anthropic.com/pricing#anthropic-api
|
||||
|
||||
|
@ -110,8 +110,8 @@ impl Adapter for CohereAdapter {
|
||||
.map(MessageContent::from);
|
||||
|
||||
Ok(ChatResponse {
|
||||
model_iden,
|
||||
content,
|
||||
model_iden,
|
||||
usage,
|
||||
})
|
||||
}
|
||||
@ -185,13 +185,23 @@ impl CohereAdapter {
|
||||
actual_role: last_chat_msg.role,
|
||||
});
|
||||
}
|
||||
// Will handle more types later
|
||||
let MessageContent::Text(message) = last_chat_msg.content;
|
||||
|
||||
// TODO: Needs to implement tool_calls
|
||||
let MessageContent::Text(message) = last_chat_msg.content else {
|
||||
return Err(Error::MessageContentTypeNotSupported {
|
||||
model_iden,
|
||||
cause: "Only MessageContent::Text supported for this model (for now)",
|
||||
});
|
||||
};
|
||||
|
||||
// -- Build
|
||||
for msg in chat_req.messages {
|
||||
// Note: Will handle more types later
|
||||
let MessageContent::Text(content) = msg.content;
|
||||
let MessageContent::Text(content) = msg.content else {
|
||||
return Err(Error::MessageContentTypeNotSupported {
|
||||
model_iden,
|
||||
cause: "Only MessageContent::Text supported for this model (for now)",
|
||||
});
|
||||
};
|
||||
|
||||
match msg.role {
|
||||
// For now, system and tool go to the system
|
||||
|
@ -121,8 +121,8 @@ impl Adapter for GeminiAdapter {
|
||||
let content = content.map(MessageContent::from);
|
||||
|
||||
Ok(ChatResponse {
|
||||
model_iden,
|
||||
content,
|
||||
model_iden,
|
||||
usage,
|
||||
})
|
||||
}
|
||||
@ -192,8 +192,13 @@ impl GeminiAdapter {
|
||||
|
||||
// -- Build
|
||||
for msg in chat_req.messages {
|
||||
// Note: Will handle more types later
|
||||
let MessageContent::Text(content) = msg.content;
|
||||
// TODO: Needs to implement tool_calls
|
||||
let MessageContent::Text(content) = msg.content else {
|
||||
return Err(Error::MessageContentTypeNotSupported {
|
||||
model_iden,
|
||||
cause: "Only MessageContent::Text supported for this model (for now)",
|
||||
});
|
||||
};
|
||||
|
||||
match msg.role {
|
||||
// For now, system goes as "user" (later, we might have adapter_config.system_to_user_impl)
|
||||
|
@ -3,13 +3,14 @@ use crate::adapter::support::get_api_key;
|
||||
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
|
||||
use crate::chat::{
|
||||
ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
|
||||
MessageContent, MetaUsage,
|
||||
MessageContent, MetaUsage, Tool, ToolCall,
|
||||
};
|
||||
use crate::webc::WebResponse;
|
||||
use crate::{ClientConfig, ModelIden};
|
||||
use crate::{Error, Result};
|
||||
use reqwest::RequestBuilder;
|
||||
use reqwest_eventsource::EventSource;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use value_ext::JsonValueExt;
|
||||
|
||||
@ -54,15 +55,31 @@ impl Adapter for OpenAIAdapter {
|
||||
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
|
||||
let WebResponse { mut body, .. } = web_response;
|
||||
|
||||
// -- Capture the usage
|
||||
let usage = body.x_take("usage").map(OpenAIAdapter::into_usage).unwrap_or_default();
|
||||
|
||||
let first_choice: Option<Value> = body.x_take("/choices/0")?;
|
||||
let content: Option<String> = first_choice.map(|mut c| c.x_take("/message/content")).transpose()?;
|
||||
let content = content.map(MessageContent::from);
|
||||
// -- Capture the content
|
||||
let content = if let Some(mut first_choice) = body.x_take::<Option<Value>>("/choices/0")? {
|
||||
if let Some(content) = first_choice
|
||||
.x_take::<Option<String>>("/message/content")?
|
||||
.map(MessageContent::from)
|
||||
{
|
||||
Some(content)
|
||||
} else {
|
||||
first_choice
|
||||
.x_take("/message/tool_calls")
|
||||
.ok()
|
||||
.map(parse_tool_calls)
|
||||
.transpose()?
|
||||
.map(MessageContent::from_tool_calls)
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(ChatResponse {
|
||||
model_iden,
|
||||
content,
|
||||
model_iden,
|
||||
usage,
|
||||
})
|
||||
}
|
||||
@ -117,13 +134,17 @@ impl OpenAIAdapter {
|
||||
|
||||
// -- Build the basic payload
|
||||
let model_name = model_iden.model_name.to_string();
|
||||
let OpenAIRequestParts { messages } = Self::into_openai_request_parts(model_iden, chat_req)?;
|
||||
let OpenAIRequestParts { messages, tools } = Self::into_openai_request_parts(model_iden, chat_req)?;
|
||||
let mut payload = json!({
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"stream": stream
|
||||
});
|
||||
|
||||
if let Some(tools) = tools {
|
||||
payload.x_insert("/tools", tools);
|
||||
}
|
||||
|
||||
// -- Add options
|
||||
let response_format = if let Some(response_format) = options_set.response_format() {
|
||||
match response_format {
|
||||
@ -199,16 +220,17 @@ impl OpenAIAdapter {
|
||||
/// Takes the genai ChatMessages and builds the OpenAIChatRequestParts
|
||||
/// - `genai::ChatRequest.system`, if present, is added as the first message with role 'system'.
|
||||
/// - All messages get added with the corresponding roles (tools are not supported for now)
|
||||
///
|
||||
/// NOTE: Here, the last `true` is for the Ollama variant
|
||||
/// It seems the Ollama compatibility layer does not work well with multiple system messages.
|
||||
/// So, when `true`, it will concatenate the system message into a single one at the beginning
|
||||
fn into_openai_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result<OpenAIRequestParts> {
|
||||
let mut system_messages: Vec<String> = Vec::new();
|
||||
let mut messages: Vec<Value> = Vec::new();
|
||||
|
||||
/// NOTE: For now system_messages is use to fix an issue with the Ollama compatibility layer that does not support multiple system messages.
|
||||
/// So, when ollama, it will concatenate the system message into a single one at the beginning
|
||||
/// NOTE: This might be fixed now, so, we could remove this.
|
||||
let mut system_messages: Vec<String> = Vec::new();
|
||||
|
||||
let ollama_variant = matches!(model_iden.adapter_kind, AdapterKind::Ollama);
|
||||
|
||||
// -- Process the system
|
||||
if let Some(system_msg) = chat_req.system {
|
||||
if ollama_variant {
|
||||
system_messages.push(system_msg)
|
||||
@ -217,37 +239,98 @@ impl OpenAIAdapter {
|
||||
}
|
||||
}
|
||||
|
||||
// -- Process the messages
|
||||
for msg in chat_req.messages {
|
||||
// Note: Will handle more types later
|
||||
let MessageContent::Text(content) = msg.content;
|
||||
|
||||
match msg.role {
|
||||
// For now, system and tool messages go to the system
|
||||
ChatRole::System => {
|
||||
// See note in the function comment
|
||||
if ollama_variant {
|
||||
system_messages.push(content);
|
||||
} else {
|
||||
messages.push(json!({"role": "system", "content": content}))
|
||||
if let MessageContent::Text(content) = msg.content {
|
||||
// NOTE: Ollama does not support multiple system messages
|
||||
|
||||
// See note in the function comment
|
||||
if ollama_variant {
|
||||
system_messages.push(content);
|
||||
} else {
|
||||
messages.push(json!({"role": "system", "content": content}))
|
||||
}
|
||||
}
|
||||
// TODO: Probably need to warn if it is a ToolCalls type of content
|
||||
}
|
||||
ChatRole::User => messages.push(json! ({"role": "user", "content": content})),
|
||||
ChatRole::Assistant => messages.push(json! ({"role": "assistant", "content": content})),
|
||||
ChatRole::User => {
|
||||
if let MessageContent::Text(content) = msg.content {
|
||||
messages.push(json! ({"role": "user", "content": content}));
|
||||
}
|
||||
// TODO: Probably need to warn if it is a ToolCalls type of content
|
||||
}
|
||||
|
||||
ChatRole::Assistant => match msg.content {
|
||||
MessageContent::Text(content) => messages.push(json! ({"role": "assistant", "content": content})),
|
||||
MessageContent::ToolCalls(tool_calls) => {
|
||||
let tool_calls = tool_calls
|
||||
.into_iter()
|
||||
.map(|tool_call| {
|
||||
json!({
|
||||
"type": "function",
|
||||
"id": tool_call.call_id,
|
||||
"function": {
|
||||
"name": tool_call.fn_name,
|
||||
"arguments": tool_call.fn_arguments.to_string(),
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<Value>>();
|
||||
messages.push(json! ({"role": "assistant", "tool_calls": tool_calls}))
|
||||
}
|
||||
// TODO: Probably need to trace/warn that this will be ignored
|
||||
MessageContent::ToolResponses(_) => (),
|
||||
},
|
||||
|
||||
ChatRole::Tool => {
|
||||
return Err(Error::MessageRoleNotSupported {
|
||||
model_iden,
|
||||
role: ChatRole::Tool,
|
||||
})
|
||||
if let MessageContent::ToolResponses(tool_responses) = msg.content {
|
||||
for tool_response in tool_responses {
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"content": tool_response.content,
|
||||
"tool_call_id": tool_response.call_id,
|
||||
}))
|
||||
}
|
||||
}
|
||||
// TODO: Probably need to trace/warn that this will be ignored
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -- Finalize the system messages ollama case
|
||||
if !system_messages.is_empty() {
|
||||
let system_message = system_messages.join("\n");
|
||||
messages.insert(0, json!({"role": "system", "content": system_message}));
|
||||
}
|
||||
|
||||
Ok(OpenAIRequestParts { messages })
|
||||
// -- Process the tools
|
||||
let tools = chat_req.tools.map(|tools| {
|
||||
tools
|
||||
.into_iter()
|
||||
.map(|tool| {
|
||||
// TODO: Need to handle the error correctly
|
||||
// TODO: Needs to have a custom serializer (tool should not have to match to a provider)
|
||||
// NOTE: Right now, low probability, so, we just return null if cannto to value.
|
||||
json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.schema,
|
||||
// TODO: If we need to support `strict: true` we need to add additionalProperties: false into the schema
|
||||
// above (like structured output)
|
||||
"strict": false,
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect::<Vec<Value>>()
|
||||
});
|
||||
|
||||
Ok(OpenAIRequestParts { messages, tools })
|
||||
}
|
||||
}
|
||||
|
||||
@ -255,6 +338,59 @@ impl OpenAIAdapter {
|
||||
|
||||
struct OpenAIRequestParts {
|
||||
messages: Vec<Value>,
|
||||
tools: Option<Vec<Value>>,
|
||||
}
|
||||
|
||||
fn parse_tool_calls(raw_tool_calls: Value) -> Result<Vec<ToolCall>> {
|
||||
let Value::Array(raw_tool_calls) = raw_tool_calls else {
|
||||
return Err(Error::InvalidJsonResponseElement {
|
||||
info: "tool calls is not an array",
|
||||
});
|
||||
};
|
||||
|
||||
let tool_calls = raw_tool_calls.into_iter().map(parse_tool_call).collect::<Result<Vec<_>>>()?;
|
||||
|
||||
Ok(tool_calls)
|
||||
}
|
||||
|
||||
fn parse_tool_call(raw_tool_call: Value) -> Result<ToolCall> {
|
||||
// Define a helper struct to match the original JSON structure.
|
||||
#[derive(Deserialize)]
|
||||
struct IterimToolFnCall {
|
||||
id: String,
|
||||
#[serde(rename = "type")]
|
||||
r#type: String,
|
||||
function: IterimFunction,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct IterimFunction {
|
||||
name: String,
|
||||
arguments: Value,
|
||||
}
|
||||
|
||||
let iterim = serde_json::from_value::<IterimToolFnCall>(raw_tool_call)?;
|
||||
|
||||
let fn_name = iterim.function.name;
|
||||
|
||||
// For now support Object only, and parse the eventual string as a json value.
|
||||
// Eventually, we might check pricing
|
||||
let fn_arguments = match iterim.function.arguments {
|
||||
Value::Object(obj) => Value::Object(obj),
|
||||
Value::String(txt) => serde_json::from_str(&txt)?,
|
||||
_ => {
|
||||
return Err(Error::InvalidJsonResponseElement {
|
||||
info: "tool call arguments is not an object",
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
// Then, map the fields of the helper struct to the flat structure.
|
||||
Ok(ToolCall {
|
||||
call_id: iterim.id,
|
||||
fn_name,
|
||||
fn_arguments,
|
||||
})
|
||||
}
|
||||
|
||||
// endregion: --- Support
|
||||
|
71
src/chat/chat_message.rs
Normal file
71
src/chat/chat_message.rs
Normal file
@ -0,0 +1,71 @@
|
||||
use crate::chat::{MessageContent, ToolCall, ToolResponse};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// An individual chat message.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
/// The role of the message.
|
||||
pub role: ChatRole,
|
||||
|
||||
/// The content of the message.
|
||||
pub content: MessageContent,
|
||||
}
|
||||
|
||||
/// Constructors
|
||||
impl ChatMessage {
|
||||
/// Create a new ChatMessage with the role `ChatRole::System`.
|
||||
pub fn system(content: impl Into<MessageContent>) -> Self {
|
||||
Self {
|
||||
role: ChatRole::System,
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new ChatMessage with the role `ChatRole::Assistant`.
|
||||
pub fn assistant(content: impl Into<MessageContent>) -> Self {
|
||||
Self {
|
||||
role: ChatRole::Assistant,
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new ChatMessage with the role `ChatRole::User`.
|
||||
pub fn user(content: impl Into<MessageContent>) -> Self {
|
||||
Self {
|
||||
role: ChatRole::User,
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Chat roles.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum ChatRole {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
Tool,
|
||||
}
|
||||
|
||||
// region: --- Froms
|
||||
|
||||
impl From<Vec<ToolCall>> for ChatMessage {
|
||||
fn from(tool_calls: Vec<ToolCall>) -> Self {
|
||||
Self {
|
||||
role: ChatRole::Assistant,
|
||||
content: MessageContent::from(tool_calls),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ToolResponse> for ChatMessage {
|
||||
fn from(value: ToolResponse) -> Self {
|
||||
Self {
|
||||
role: ChatRole::Tool,
|
||||
content: MessageContent::from(value),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// endregion: --- Froms
|
@ -5,7 +5,7 @@
|
||||
//! Note 1: In the future, we will probably allow setting the client
|
||||
//! Note 2: Extracting it from the `ChatRequest` object allows for better reusability of each component.
|
||||
|
||||
use crate::chat::chat_response_format::ChatResponseFormat;
|
||||
use crate::chat::chat_req_response_format::ChatResponseFormat;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Chat Options that are taken into account for any `Client::exec...` calls.
|
||||
|
@ -2,7 +2,7 @@ use derive_more::From;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// The chat response format to be sent back by the LLM.
|
||||
/// The chat response format for the ChatRequest for structured output.
|
||||
/// This will be taken into consideration only if the provider supports it.
|
||||
///
|
||||
/// > Note: Currently, the AI Providers will not report an error if not supported. It will just be ignored.
|
@ -1,6 +1,6 @@
|
||||
//! This module contains all the types related to a Chat Request (except ChatOptions, which has its own file).
|
||||
|
||||
use crate::chat::MessageContent;
|
||||
use crate::chat::{ChatMessage, ChatRole, MessageContent, Tool};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// region: --- ChatRequest
|
||||
@ -13,13 +13,19 @@ pub struct ChatRequest {
|
||||
|
||||
/// The messages of the request.
|
||||
pub messages: Vec<ChatMessage>,
|
||||
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
}
|
||||
|
||||
/// Constructors
|
||||
impl ChatRequest {
|
||||
/// Create a new ChatRequest with the given messages.
|
||||
pub fn new(messages: Vec<ChatMessage>) -> Self {
|
||||
Self { messages, system: None }
|
||||
Self {
|
||||
messages,
|
||||
system: None,
|
||||
tools: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// From the `.system` property content.
|
||||
@ -27,6 +33,7 @@ impl ChatRequest {
|
||||
Self {
|
||||
system: Some(content.into()),
|
||||
messages: Vec::new(),
|
||||
tools: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -34,7 +41,8 @@ impl ChatRequest {
|
||||
pub fn from_user(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
system: None,
|
||||
messages: vec![ChatMessage::user(content)],
|
||||
messages: vec![ChatMessage::user(content.into())],
|
||||
tools: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -48,8 +56,18 @@ impl ChatRequest {
|
||||
}
|
||||
|
||||
/// Append a message to the request.
|
||||
pub fn append_message(mut self, msg: ChatMessage) -> Self {
|
||||
self.messages.push(msg);
|
||||
pub fn append_message(mut self, msg: impl Into<ChatMessage>) -> Self {
|
||||
self.messages.push(msg.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
|
||||
self.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn append_tool(mut self, tool: impl Into<Tool>) -> Self {
|
||||
self.tools.get_or_insert_with(Vec::new).push(tool.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
@ -65,6 +83,8 @@ impl ChatRequest {
|
||||
.chain(self.messages.iter().filter_map(|message| match message.role {
|
||||
ChatRole::System => match message.content {
|
||||
MessageContent::Text(ref content) => Some(content.as_str()),
|
||||
/// If system content is not text, then, we do not add it for now.
|
||||
_ => None,
|
||||
},
|
||||
_ => None,
|
||||
}))
|
||||
@ -97,74 +117,3 @@ impl ChatRequest {
|
||||
}
|
||||
|
||||
// endregion: --- ChatRequest
|
||||
|
||||
// region: --- ChatMessage
|
||||
|
||||
/// An individual chat message.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
/// The role of the message.
|
||||
pub role: ChatRole,
|
||||
|
||||
/// The content of the message.
|
||||
pub content: MessageContent,
|
||||
|
||||
/// Extra information about the message.
|
||||
pub extra: Option<MessageExtra>,
|
||||
}
|
||||
|
||||
/// Constructors
|
||||
impl ChatMessage {
|
||||
/// Create a new ChatMessage with the role `ChatRole::System`.
|
||||
pub fn system(content: impl Into<MessageContent>) -> Self {
|
||||
Self {
|
||||
role: ChatRole::System,
|
||||
content: content.into(),
|
||||
extra: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new ChatMessage with the role `ChatRole::Assistant`.
|
||||
pub fn assistant(content: impl Into<MessageContent>) -> Self {
|
||||
Self {
|
||||
role: ChatRole::Assistant,
|
||||
content: content.into(),
|
||||
extra: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new ChatMessage with the role `ChatRole::User`.
|
||||
pub fn user(content: impl Into<MessageContent>) -> Self {
|
||||
Self {
|
||||
role: ChatRole::User,
|
||||
content: content.into(),
|
||||
extra: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Chat roles.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum ChatRole {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
Tool,
|
||||
}
|
||||
|
||||
/// NOTE: DO NOT USE, just a placeholder for now.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub enum MessageExtra {
|
||||
Tool(ToolExtra),
|
||||
}
|
||||
|
||||
/// NOTE: DO NOT USE, just a placeholder for now.
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolExtra {
|
||||
tool_id: String,
|
||||
}
|
||||
|
||||
// endregion: --- ChatMessage
|
@ -2,7 +2,7 @@
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::chat::{ChatStream, MessageContent};
|
||||
use crate::chat::{ChatStream, MessageContent, ToolCall};
|
||||
use crate::ModelIden;
|
||||
|
||||
// region: --- ChatResponse
|
||||
@ -13,12 +13,12 @@ pub struct ChatResponse {
|
||||
/// The eventual content of the chat response
|
||||
pub content: Option<MessageContent>,
|
||||
|
||||
/// The eventual usage of the chat response
|
||||
pub usage: MetaUsage,
|
||||
|
||||
/// The Model Identifier (AdapterKind/ModelName) used for this request.
|
||||
/// > NOTE: This might be different from the request model if changed by the ModelMapper
|
||||
pub model_iden: ModelIden,
|
||||
|
||||
/// The eventual usage of the chat response
|
||||
pub usage: MetaUsage,
|
||||
}
|
||||
|
||||
// Getters
|
||||
@ -34,6 +34,22 @@ impl ChatResponse {
|
||||
pub fn content_text_into_string(self) -> Option<String> {
|
||||
self.content.and_then(MessageContent::text_into_string)
|
||||
}
|
||||
|
||||
pub fn tool_calls(&self) -> Option<Vec<&ToolCall>> {
|
||||
if let Some(MessageContent::ToolCalls(tool_calls)) = self.content.as_ref() {
|
||||
Some(tool_calls.iter().collect())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_tool_calls(self) -> Option<Vec<ToolCall>> {
|
||||
if let Some(MessageContent::ToolCalls(tool_calls)) = self.content {
|
||||
Some(tool_calls)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// endregion: --- ChatResponse
|
||||
|
@ -1,19 +1,34 @@
|
||||
use crate::chat::{ToolCall, ToolResponse};
|
||||
use derive_more::derive::From;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Currently, it only supports Text,
|
||||
/// but the goal is to support multi-part message content (see below)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, From)]
|
||||
pub enum MessageContent {
|
||||
/// Text content
|
||||
Text(String),
|
||||
|
||||
/// Tool calls
|
||||
#[from]
|
||||
ToolCalls(Vec<ToolCall>),
|
||||
|
||||
/// Tool call Responses
|
||||
#[from]
|
||||
ToolResponses(Vec<ToolResponse>),
|
||||
}
|
||||
|
||||
/// Constructors
|
||||
impl MessageContent {
|
||||
/// Create a new MessageContent with the Text variant
|
||||
pub fn text(content: impl Into<String>) -> Self {
|
||||
pub fn from_text(content: impl Into<String>) -> Self {
|
||||
MessageContent::Text(content.into())
|
||||
}
|
||||
|
||||
/// Create a new MessageContent with the ToolCalls variant
|
||||
pub fn from_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
|
||||
MessageContent::ToolCalls(tool_calls)
|
||||
}
|
||||
}
|
||||
|
||||
/// Getters
|
||||
@ -25,6 +40,8 @@ impl MessageContent {
|
||||
pub fn text_as_str(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessageContent::Text(content) => Some(content.as_str()),
|
||||
MessageContent::ToolCalls(_) => None,
|
||||
MessageContent::ToolResponses(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -36,29 +53,44 @@ impl MessageContent {
|
||||
pub fn text_into_string(self) -> Option<String> {
|
||||
match self {
|
||||
MessageContent::Text(content) => Some(content),
|
||||
MessageContent::ToolCalls(_) => None,
|
||||
MessageContent::ToolResponses(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if the text content is empty (for now)
|
||||
/// Later, this will also validate each variant to check if they can be considered "empty"
|
||||
/// Checks if the text content or the tools calls is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
MessageContent::Text(content) => content.is_empty(),
|
||||
MessageContent::ToolCalls(tool_calls) => tool_calls.is_empty(),
|
||||
MessageContent::ToolResponses(tool_responses) => tool_responses.is_empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// region: --- Froms
|
||||
|
||||
/// Blanket implementation for MessageContent::Text for anything that implements Into<String>
|
||||
/// Note: This means that when we support base64 as images, it should not use `.into()` for MessageContent.
|
||||
/// It should be acceptable but may need reassessment.
|
||||
impl<T> From<T> for MessageContent
|
||||
where
|
||||
T: Into<String>,
|
||||
{
|
||||
fn from(s: T) -> Self {
|
||||
MessageContent::text(s)
|
||||
impl From<String> for MessageContent {
|
||||
fn from(s: String) -> Self {
|
||||
MessageContent::from_text(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a str> for MessageContent {
|
||||
fn from(s: &'a str) -> Self {
|
||||
MessageContent::from_text(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&String> for MessageContent {
|
||||
fn from(s: &String) -> Self {
|
||||
MessageContent::from_text(s.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ToolResponse> for MessageContent {
|
||||
fn from(tool_response: ToolResponse) -> Self {
|
||||
MessageContent::ToolResponses(vec![tool_response])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,19 +3,21 @@
|
||||
|
||||
// region: --- Modules
|
||||
|
||||
mod chat_message;
|
||||
mod chat_options;
|
||||
mod chat_req;
|
||||
mod chat_req_response_format;
|
||||
mod chat_request;
|
||||
mod chat_res;
|
||||
mod chat_response_format;
|
||||
mod chat_stream;
|
||||
mod message_content;
|
||||
mod tool;
|
||||
|
||||
// -- Flatten
|
||||
pub use chat_message::*;
|
||||
pub use chat_options::*;
|
||||
pub use chat_req::*;
|
||||
pub use chat_req_response_format::*;
|
||||
pub use chat_request::*;
|
||||
pub use chat_res::*;
|
||||
pub use chat_response_format::*;
|
||||
pub use chat_stream::*;
|
||||
pub use message_content::*;
|
||||
pub use tool::*;
|
||||
|
@ -1,14 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// NOT USED FOR NOW
|
||||
/// > For later, it will be used for function calling
|
||||
/// > It will probably use the JsonSpec type we had in the response format,
|
||||
/// > or have a `From<JsonSpec>` implementation.
|
||||
#[allow(unused)] // Not used yet
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
fn_name: String,
|
||||
fn_description: String,
|
||||
params: Value,
|
||||
}
|
11
src/chat/tool/mod.rs
Normal file
11
src/chat/tool/mod.rs
Normal file
@ -0,0 +1,11 @@
|
||||
// region: --- Modules
|
||||
|
||||
mod tool_base;
|
||||
mod tool_call;
|
||||
mod tool_response;
|
||||
|
||||
pub use tool_base::*;
|
||||
pub use tool_call::*;
|
||||
pub use tool_response::*;
|
||||
|
||||
// endregion: --- Modules
|
64
src/chat/tool/tool_base.rs
Normal file
64
src/chat/tool/tool_base.rs
Normal file
@ -0,0 +1,64 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
/// The tool name, which is typically the function name
|
||||
/// e.g., `get_weather`
|
||||
pub name: String,
|
||||
|
||||
/// The description of the tool which will be used by the LLM to understand the context/usage of this tool
|
||||
pub description: Option<String>,
|
||||
|
||||
/// The json-schema for the parameters
|
||||
/// e.g.,
|
||||
/// ```json
|
||||
/// json!({
|
||||
/// "type": "object",
|
||||
/// "properties": {
|
||||
/// "city": {
|
||||
/// "type": "string",
|
||||
/// "description": "The city name"
|
||||
/// },
|
||||
/// "country": {
|
||||
/// "type": "string",
|
||||
/// "description": "The most likely country of this city name"
|
||||
/// },
|
||||
/// "unit": {
|
||||
/// "type": "string",
|
||||
/// "enum": ["C", "F"],
|
||||
/// "description": "The temperature unit of the country. C for Celsius, and F for Fahrenheit"
|
||||
/// }
|
||||
/// },
|
||||
/// "required": ["city", "country", "unit"],
|
||||
/// })
|
||||
/// ```
|
||||
pub schema: Option<Value>,
|
||||
}
|
||||
|
||||
/// Constructor
|
||||
impl Tool {
|
||||
pub fn new(name: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name: name.into(),
|
||||
description: None,
|
||||
schema: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// region: --- Setters
|
||||
|
||||
impl Tool {
|
||||
pub fn with_description(mut self, description: impl Into<String>) -> Self {
|
||||
self.description = Some(description.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_schema(mut self, parameters: Value) -> Self {
|
||||
self.schema = Some(parameters);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// endregion: --- Setters
|
10
src/chat/tool/tool_call.rs
Normal file
10
src/chat/tool/tool_call.rs
Normal file
@ -0,0 +1,10 @@
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// The tool call function name and arguments send back by the LLM.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub call_id: String,
|
||||
pub fn_name: String,
|
||||
pub fn_arguments: Value,
|
||||
}
|
29
src/chat/tool/tool_response.rs
Normal file
29
src/chat/tool/tool_response.rs
Normal file
@ -0,0 +1,29 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolResponse {
|
||||
pub call_id: String,
|
||||
// for now, just string (would probably be serialized json)
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// constructor
|
||||
impl ToolResponse {
|
||||
pub fn new(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
call_id: tool_call_id.into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Getters
|
||||
impl ToolResponse {
|
||||
fn tool_call_id(&self) -> &str {
|
||||
&self.call_id
|
||||
}
|
||||
|
||||
fn content(&self) -> &str {
|
||||
&self.content
|
||||
}
|
||||
}
|
12
src/error.rs
12
src/error.rs
@ -23,12 +23,19 @@ pub enum Error {
|
||||
model_iden: ModelIden,
|
||||
role: ChatRole,
|
||||
},
|
||||
MessageContentTypeNotSupported {
|
||||
model_iden: ModelIden,
|
||||
cause: &'static str,
|
||||
},
|
||||
JsonModeWithoutInstruction,
|
||||
|
||||
// -- Chat Output
|
||||
NoChatResponse {
|
||||
model_iden: ModelIden,
|
||||
},
|
||||
InvalidJsonResponseElement {
|
||||
info: &'static str,
|
||||
},
|
||||
|
||||
// -- Auth
|
||||
RequiresApiKey {
|
||||
@ -77,14 +84,15 @@ pub enum Error {
|
||||
resolver_error: resolver::Error,
|
||||
},
|
||||
|
||||
// -- Utils
|
||||
|
||||
// -- Externals
|
||||
#[from]
|
||||
EventSourceClone(reqwest_eventsource::CannotCloneRequestError),
|
||||
#[from]
|
||||
JsonValueExt(JsonValueExtError),
|
||||
ReqwestEventSource(reqwest_eventsource::Error),
|
||||
// Note: will probably need to remvoe this one to give more context
|
||||
#[from]
|
||||
SerdeJson(serde_json::Error),
|
||||
}
|
||||
|
||||
// region: --- Error Boilerplate
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::get_option_value;
|
||||
use crate::support::{extract_stream_end, seed_chat_req_simple, Result};
|
||||
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponseFormat, JsonSpec};
|
||||
use crate::support::{extract_stream_end, seed_chat_req_simple, seed_chat_req_tool_simple, Result};
|
||||
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponseFormat, JsonSpec, Tool, ToolResponse};
|
||||
use genai::resolver::{AuthData, AuthResolver, AuthResolverFn, IntoAuthResolverFn};
|
||||
use genai::{Client, ClientConfig, ModelIden};
|
||||
use serde_json::{json, Value};
|
||||
@ -260,6 +260,68 @@ pub async fn common_test_chat_stream_capture_all_ok(model: &str) -> Result<()> {
|
||||
|
||||
// endregion: --- Chat Stream Tests
|
||||
|
||||
// region: --- Tools
|
||||
|
||||
/// Just making the tool request, and checking the tool call response
|
||||
/// `complete_check` if for LLMs that are better at giving back the unit and weather.
|
||||
pub async fn common_test_tool_simple_ok(model: &str, complete_check: bool) -> Result<()> {
|
||||
// -- Setup & Fixtures
|
||||
let client = Client::default();
|
||||
let chat_req = seed_chat_req_tool_simple();
|
||||
|
||||
// -- Exec
|
||||
let chat_res = client.exec_chat(model, chat_req, None).await?;
|
||||
|
||||
// -- Check
|
||||
let mut tool_calls = chat_res.tool_calls().ok_or("Should have tool calls")?;
|
||||
let tool_call = tool_calls.pop().ok_or("Should have at least one tool call")?;
|
||||
assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("city")?, "Paris");
|
||||
assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("country")?, "France");
|
||||
if complete_check {
|
||||
// Note: Not all LLM will output the weather (e.g. Anthropic Haiku)
|
||||
assert_eq!(tool_call.fn_arguments.x_get_as::<&str>("unit")?, "C");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// `complete_check` if for LLMs that are better at giving back the unit and weather.
|
||||
///
|
||||
pub async fn common_test_tool_full_flow_ok(model: &str, complete_check: bool) -> Result<()> {
|
||||
// -- Setup & Fixtures
|
||||
let client = Client::default();
|
||||
let mut chat_req = seed_chat_req_tool_simple();
|
||||
|
||||
// -- Exec first request to get the tool calls
|
||||
let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
|
||||
let tool_calls = chat_res.into_tool_calls().ok_or("Should have tool calls in chat_res")?;
|
||||
|
||||
// -- Exec the second request
|
||||
// get the tool call id (first one)
|
||||
let first_tool_call = tool_calls.first().ok_or("Should have at least one tool call")?;
|
||||
let first_tool_call_id = &first_tool_call.call_id;
|
||||
// simulate the response
|
||||
let tool_response = ToolResponse::new(first_tool_call_id, r#"{"weather": "Sunny", "temperature": "32C"}"#);
|
||||
|
||||
// Add the tool_calls, tool_response
|
||||
let chat_req = chat_req.append_message(tool_calls).append_message(tool_response);
|
||||
|
||||
let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
|
||||
|
||||
// -- Check
|
||||
let content = chat_res.content_text_as_str().ok_or("Last response should be message")?;
|
||||
assert!(content.contains("Paris"), "Should contain 'Paris'");
|
||||
assert!(content.contains("32"), "Should contain '32'");
|
||||
if complete_check {
|
||||
// Note: Not all LLM will output the weather (e.g. Anthropic Haiku)
|
||||
assert!(content.contains("sunny"), "Should contain 'sunny'");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// endregion: --- Tools
|
||||
|
||||
// region: --- With Resolvers
|
||||
|
||||
pub async fn common_test_resolver_auth_ok(model: &str, auth_data: AuthData) -> Result<()> {
|
||||
|
@ -1,4 +1,5 @@
|
||||
use genai::chat::{ChatMessage, ChatRequest};
|
||||
use genai::chat::{ChatMessage, ChatRequest, Tool};
|
||||
use serde_json::json;
|
||||
|
||||
pub fn seed_chat_req_simple() -> ChatRequest {
|
||||
ChatRequest::new(vec![
|
||||
@ -7,3 +8,29 @@ pub fn seed_chat_req_simple() -> ChatRequest {
|
||||
ChatMessage::user("Why is the sky blue?"),
|
||||
])
|
||||
}
|
||||
|
||||
pub fn seed_chat_req_tool_simple() -> ChatRequest {
|
||||
ChatRequest::new(vec![
|
||||
// -- Messages (deactivate to see the differences)
|
||||
ChatMessage::user("What is the temperature in C, in Paris"),
|
||||
])
|
||||
.append_tool(Tool::new("get_weather").with_schema(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
},
|
||||
"country": {
|
||||
"type": "string",
|
||||
"description": "The most likely country of this city name"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["C", "F"],
|
||||
"description": "The temperature unit of the country. C for Celsius, and F for Fahrenheit"
|
||||
}
|
||||
},
|
||||
"required": ["city", "country", "unit"],
|
||||
})))
|
||||
}
|
||||
|
@ -45,6 +45,20 @@ async fn test_chat_stream_capture_all_ok() -> Result<()> {
|
||||
|
||||
// endregion: --- Chat Stream Tests
|
||||
|
||||
// region: --- Tool Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_simple_ok() -> Result<()> {
|
||||
common_tests::common_test_tool_simple_ok(MODEL, false).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_full_flow_ok() -> Result<()> {
|
||||
common_tests::common_test_tool_full_flow_ok(MODEL, false).await
|
||||
}
|
||||
|
||||
// endregion: --- Tool Tests
|
||||
|
||||
// region: --- Resolver Tests
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -50,6 +50,19 @@ async fn test_chat_stream_capture_all_ok() -> Result<()> {
|
||||
|
||||
// endregion: --- Chat Stream Tests
|
||||
|
||||
// region: --- Tool Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_simple_ok() -> Result<()> {
|
||||
common_tests::common_test_tool_simple_ok(MODEL, true).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_full_flow_ok() -> Result<()> {
|
||||
common_tests::common_test_tool_full_flow_ok(MODEL, true).await
|
||||
}
|
||||
// endregion: --- Tool Tests
|
||||
|
||||
// region: --- Resolver Tests
|
||||
|
||||
#[tokio::test]
|
||||
|
Reference in New Issue
Block a user