From 59f0b149b05f21c0ae0cc4b5923a784ee702c2df Mon Sep 17 00:00:00 2001
From: Adam Strojek
Date: Mon, 9 Dec 2024 13:14:25 +0100
Subject: [PATCH] Initial image support

Fix image structure for OpenAI API
Implement base64 image support for OpenAI
Image Support: Some APIs (Gemini) require mime type for URL and Base64 format
Image Support: Update OpenAI and Anthropic API to support new image structure
Image Support: Add Gemini 2.0 Flash Experimental support and implement Image support
Image Support: Create example with Image support
Image Support: Fix rebase issue
Image Support: Fix example and make it runnable from cargo
---
 Cargo.toml                                   |  4 +
 examples/c07-image.rs                        | 35 +++++++++
 .../adapters/anthropic/adapter_impl.rs       | 36 +++++++--
 src/adapter/adapters/gemini/adapter_impl.rs  | 67 +++++++++++++---
 src/adapter/adapters/openai/adapter_impl.rs  | 32 ++++++--
 src/chat/message_content.rs                  | 77 +++++++++++++------
 tests/support/seeders.rs                     | 17 +++-
 7 files changed, 221 insertions(+), 47 deletions(-)
 create mode 100644 examples/c07-image.rs

diff --git a/Cargo.toml b/Cargo.toml
index 0bb934b..5ce0394 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,6 +9,10 @@ keywords = ["generative-ai","openai","chatgpt","gemini","ollama"]
 homepage = "https://github.com/jeremychone/rust-genai"
 repository = "https://github.com/jeremychone/rust-genai"
 
+[[example]]
+name = "images"
+path = "examples/c07-image.rs"
+
 [lints.rust]
 unsafe_code = "forbid"
 # unused = { level = "allow", priority = -1 } # For exploratory dev.
diff --git a/examples/c07-image.rs b/examples/c07-image.rs
new file mode 100644
index 0000000..deb9e46
--- /dev/null
+++ b/examples/c07-image.rs
@@ -0,0 +1,35 @@
+//! This example demonstrates how to attach an image to a conversation.
+
+use genai::chat::printer::print_chat_stream;
+use genai::chat::{ChatMessage, ChatRequest, ContentPart, ImageSource};
+use genai::Client;
+
+const MODEL: &str = "gpt-4o-mini";
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let client = Client::default();
+
+    let question = "What is in this picture?";
+
+    let mut chat_req = ChatRequest::default().with_system("Answer in one sentence");
+    // Build a user message from multiple content parts: the question text and the image URL.
+    chat_req = chat_req.append_message(ChatMessage::user(
+        vec![
+            ContentPart::Text(question.to_string()),
+            ContentPart::Image {
+                content: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
+                content_type: "image/jpeg".to_string(),
+                source: ImageSource::Url,
+            }
+        ]
+    ));
+
+    println!("\n--- Question:\n{question}");
+    let chat_res = client.exec_chat_stream(MODEL, chat_req.clone(), None).await?;
+
+    println!("\n--- Answer: (streaming)");
+    let assistant_answer = print_chat_stream(chat_res, None).await?;
+
+    Ok(())
+}
diff --git a/src/adapter/adapters/anthropic/adapter_impl.rs b/src/adapter/adapters/anthropic/adapter_impl.rs
index 3db1b0a..8027613 100644
--- a/src/adapter/adapters/anthropic/adapter_impl.rs
+++ b/src/adapter/adapters/anthropic/adapter_impl.rs
@@ -3,7 +3,7 @@ use crate::adapter::anthropic::AnthropicStreamer;
 use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
 use crate::chat::{
     ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage,
-    ToolCall,
+    ToolCall, ContentPart, ImageSource,
 };
 use crate::resolver::{AuthData, Endpoint};
 use crate::webc::WebResponse;
@@ -236,10 +236,35 @@ impl AnthropicAdapter {
                     // TODO: Needs to trace/warn that other type are not supported
                 }
                 ChatRole::User => {
-                    if let MessageContent::Text(content) = msg.content {
-                        messages.push(json! ({"role": "user", "content": content}))
-                    }
-                    // TODO: Needs to trace/warn that other type are not supported
+                    let content = match msg.content {
+                        MessageContent::Text(content) => json!(content),
+                        MessageContent::Parts(parts) => {
+                            json!(parts.iter().map(|part| match part {
+                                ContentPart::Text(text) => json!({"type": "text", "text": text.clone()}),
+                                ContentPart::Image { content, content_type, source } => {
+                                    match source {
+                                        ImageSource::Url => todo!("Anthropic doesn't support images from URL, need to handle it gracefully"),
+                                        ImageSource::Base64 => json!({
+                                            "type": "image",
+                                            "source": {
+                                                "type": "base64",
+                                                "media_type": content_type,
+                                                "data": content,
+                                            },
+                                        }),
+                                    }
+                                },
+                            }).collect::<Vec<_>>())
+                        },
+                        // Use `match` instead of `if let` to future-proof this implementation:
+                        // if a new message content variant is added, the library will no longer
+                        // compile until this match is updated.
+                        // `continue` gracefully skips messages that cannot be serialized here.
+                        // TODO: Probably need to warn if it is a ToolCalls type of content
+                        MessageContent::ToolCalls(_) => continue,
+                        MessageContent::ToolResponses(_) => continue,
+                    };
+                    messages.push(json! ({"role": "user", "content": content}));
                 }
                 ChatRole::Assistant => {
                     //
@@ -266,6 +291,7 @@ impl AnthropicAdapter {
                         }));
                     }
                     // TODO: Probably need to trace/warn that this will be ignored
+                    MessageContent::Parts(_) => (),
                     MessageContent::ToolResponses(_) => (),
                 }
             }
diff --git a/src/adapter/adapters/gemini/adapter_impl.rs b/src/adapter/adapters/gemini/adapter_impl.rs
index a763277..ac3525e 100644
--- a/src/adapter/adapters/gemini/adapter_impl.rs
+++ b/src/adapter/adapters/gemini/adapter_impl.rs
@@ -3,7 +3,7 @@ use crate::adapter::gemini::GeminiStreamer;
 use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
 use crate::chat::{
     ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
-    MessageContent, MetaUsage,
+    MessageContent, MetaUsage, ContentPart, ImageSource,
 };
 use crate::resolver::{AuthData, Endpoint};
 use crate::webc::{WebResponse, WebStream};
@@ -21,6 +21,7 @@ const MODELS: &[&str] = &[
     "gemini-1.5-flash-8b",
     "gemini-1.0-pro",
     "gemini-1.5-flash-latest",
+    "gemini-2.0-flash-exp",
 ];
 
 // curl \
@@ -214,19 +215,61 @@ impl GeminiAdapter {
 
         // -- Build
         for msg in chat_req.messages {
-            // TODO: Needs to implement tool_calls
-            let MessageContent::Text(content) = msg.content else {
-                return Err(Error::MessageContentTypeNotSupported {
-                    model_iden,
-                    cause: "Only MessageContent::Text supported for this model (for now)",
-                });
-            };
-
             match msg.role {
                 // For now, system goes as "user" (later, we might have adapter_config.system_to_user_impl)
-                ChatRole::System => systems.push(content),
-                ChatRole::User => contents.push(json! ({"role": "user", "parts": [{"text": content}]})),
-                ChatRole::Assistant => contents.push(json! ({"role": "model", "parts": [{"text": content}]})),
+                ChatRole::System => {
+                    let MessageContent::Text(content) = msg.content else {
+                        return Err(Error::MessageContentTypeNotSupported {
+                            model_iden,
+                            cause: "Only MessageContent::Text supported for this model (for now)",
+                        });
+                    };
+                    systems.push(content)
+                },
+                ChatRole::User => {
+                    let content = match msg.content {
+                        MessageContent::Text(content) => json!([{"text": content}]),
+                        MessageContent::Parts(parts) => {
+                            json!(parts.iter().map(|part| match part {
+                                ContentPart::Text(text) => json!({"text": text.clone()}),
+                                ContentPart::Image { content, content_type, source } => {
+                                    match source {
+                                        ImageSource::Url => json!({
+                                            "file_data": {
+                                                "mime_type": content_type,
+                                                "file_uri": content
+                                            }
+                                        }),
+                                        ImageSource::Base64 => json!({
+                                            "inline_data": {
+                                                "mime_type": content_type,
+                                                "data": content
+                                            }
+                                        }),
+                                    }
+                                },
+                            }).collect::<Vec<_>>())
+                        },
+                        // Use `match` instead of `if let` to future-proof this implementation:
+                        // if a new message content variant is added, the library will no longer
+                        // compile until this match is updated.
+                        // `continue` gracefully skips messages that cannot be serialized here.
+                        // TODO: Probably need to warn if it is a ToolCalls type of content
+                        MessageContent::ToolCalls(_) => continue,
+                        MessageContent::ToolResponses(_) => continue,
+                    };
+
+                    contents.push(json!({"role": "user", "parts": content}));
+                },
+                ChatRole::Assistant => {
+                    let MessageContent::Text(content) = msg.content else {
+                        return Err(Error::MessageContentTypeNotSupported {
+                            model_iden,
+                            cause: "Only MessageContent::Text supported for this model (for now)",
+                        });
+                    };
+                    contents.push(json!({"role": "model", "parts": [{"text": content}]}))
+                },
                 ChatRole::Tool => {
                     return Err(Error::MessageRoleNotSupported {
                         model_iden,
diff --git a/src/adapter/adapters/openai/adapter_impl.rs b/src/adapter/adapters/openai/adapter_impl.rs
index 3c305c5..3b99293 100644
--- a/src/adapter/adapters/openai/adapter_impl.rs
+++ b/src/adapter/adapters/openai/adapter_impl.rs
@@ -3,7 +3,7 @@ use crate::adapter::openai::OpenAIStreamer;
 use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind, ServiceType, WebRequestData};
 use crate::chat::{
     ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
-    MessageContent, MetaUsage, ToolCall,
+    MessageContent, MetaUsage, ToolCall, ContentPart, ImageSource,
 };
 use crate::resolver::{AuthData, Endpoint};
 use crate::webc::WebResponse;
@@ -250,10 +250,31 @@ impl OpenAIAdapter {
                     // TODO: Probably need to warn if it is a ToolCalls type of content
                 }
                 ChatRole::User => {
-                    if let MessageContent::Text(content) = msg.content {
-                        messages.push(json! ({"role": "user", "content": content}));
-                    }
-                    // TODO: Probably need to warn if it is a ToolCalls type of content
+                    let content = match msg.content {
+                        MessageContent::Text(content) => json!(content),
+                        MessageContent::Parts(parts) => {
+                            json!(parts.iter().map(|part| match part {
+                                ContentPart::Text(text) => json!({"type": "text", "text": text.clone()}),
+                                ContentPart::Image { content, content_type, source } => {
+                                    match source {
+                                        ImageSource::Url => json!({"type": "image_url", "image_url": {"url": content}}),
+                                        ImageSource::Base64 => {
+                                            let image_url = format!("data:{content_type};base64,{content}");
+                                            json!({"type": "image_url", "image_url": {"url": image_url}})
+                                        },
+                                    }
+                                },
+                            }).collect::<Vec<_>>())
+                        },
+                        // Use `match` instead of `if let` to future-proof this implementation:
+                        // if a new message content variant is added, the library will no longer
+                        // compile until this match is updated.
+                        // `continue` gracefully skips messages that cannot be serialized here.
+                        // TODO: Probably need to warn if it is a ToolCalls type of content
+                        MessageContent::ToolCalls(_) => continue,
+                        MessageContent::ToolResponses(_) => continue,
+                    };
+                    messages.push(json! ({"role": "user", "content": content}));
                 }
 
                 ChatRole::Assistant => match msg.content {
@@ -275,6 +296,7 @@ impl OpenAIAdapter {
                         messages.push(json! ({"role": "assistant", "tool_calls": tool_calls}))
                     }
                     // TODO: Probably need to trace/warn that this will be ignored
+                    MessageContent::Parts(_) => (),
                     MessageContent::ToolResponses(_) => (),
                 },
 
diff --git a/src/chat/message_content.rs b/src/chat/message_content.rs
index 481f66b..e86e41c 100644
--- a/src/chat/message_content.rs
+++ b/src/chat/message_content.rs
@@ -2,13 +2,14 @@ use crate::chat::{ToolCall, ToolResponse};
 use derive_more::derive::From;
 use serde::{Deserialize, Serialize};
 
-/// Currently, it only supports Text,
-/// but the goal is to support multi-part message content (see below)
 #[derive(Debug, Clone, Serialize, Deserialize, From)]
 pub enum MessageContent {
     /// Text content
     Text(String),
 
+    /// Content parts
+    Parts(Vec<ContentPart>),
+
     /// Tool calls
     #[from]
     ToolCalls(Vec<ToolCall>),
@@ -25,6 +26,9 @@ impl MessageContent {
         MessageContent::Text(content.into())
     }
 
+    /// Create a new MessageContent from provided content parts
+    pub fn from_parts(parts: impl Into<Vec<ContentPart>>) -> Self { MessageContent::Parts(parts.into()) }
+
     /// Create a new MessageContent with the ToolCalls variant
     pub fn from_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
         MessageContent::ToolCalls(tool_calls)
@@ -40,6 +44,12 @@ impl MessageContent {
     pub fn text_as_str(&self) -> Option<&str> {
         match self {
             MessageContent::Text(content) => Some(content.as_str()),
+            MessageContent::Parts(parts) => {
+                Some(parts.iter().filter_map(|part| match part {
+                    ContentPart::Text(content) => Some(content.clone()),
+                    _ => None,
+                }).collect::<Vec<_>>().join("\n").leak()) // TODO: revisit this, should we leak a &str?
+            },
             MessageContent::ToolCalls(_) => None,
             MessageContent::ToolResponses(_) => None,
         }
@@ -53,6 +63,12 @@ impl MessageContent {
     pub fn text_into_string(self) -> Option<String> {
         match self {
             MessageContent::Text(content) => Some(content),
+            MessageContent::Parts(parts) => {
+                Some(parts.into_iter().filter_map(|part| match part {
+                    ContentPart::Text(content) => Some(content),
+                    _ => None,
+                }).collect::<Vec<_>>().join("\n"))
+            },
             MessageContent::ToolCalls(_) => None,
             MessageContent::ToolResponses(_) => None,
         }
@@ -62,6 +78,7 @@ impl MessageContent {
     pub fn is_empty(&self) -> bool {
         match self {
             MessageContent::Text(content) => content.is_empty(),
+            MessageContent::Parts(parts) => parts.is_empty(),
             MessageContent::ToolCalls(tool_calls) => tool_calls.is_empty(),
             MessageContent::ToolResponses(tool_responses) => tool_responses.is_empty(),
         }
@@ -94,27 +111,39 @@ impl From<ToolResponse> for MessageContent {
     }
 }
 
+impl From<Vec<ContentPart>> for MessageContent {
+    fn from(parts: Vec<ContentPart>) -> Self { MessageContent::Parts(parts) }
+}
+
 // endregion: --- Froms
 
-// NOTE: The goal is to add a Parts variant with ContentPart for multipart support
-//
-// ```
-// pub enum MessageContent {
-//     Text(String),
-//     Parts(Vec<ContentPart>)
-// }
-// ```
-//
-// With something like this:
-// ```
-// pub enum ContentPart {
-//     Text(String),
-//     Image(ImagePart)
-// }
-//
-// pub enum ImagePart {
-//     Local(PathBuf),
-//     Remote(Url),
-//     Base64(String)
-// }
-// ```
+#[derive(Debug, Clone, Serialize, Deserialize, From)]
+pub enum ContentPart {
+    Text(String),
+    Image {
+        content: String,
+        content_type: String,
+        source: ImageSource,
+    },
+}
+
+// region:    --- Froms
+
+impl<'a> From<&'a str> for ContentPart {
+    fn from(s: &'a str) -> Self {
+        ContentPart::Text(s.to_string())
+    }
+}
+
+// endregion: --- Froms
+
+#[derive(Debug, Clone, Serialize, Deserialize, From)]
+pub enum ImageSource {
+    Url,
+    Base64,
+
+    // No `Local` variant: it would require handling errors such as "file not found".
+    // A local file can easily be provided by the user as Base64, and a convenient
+    // TryFrom into the Base64 variant can be added later. All LLMs accept local images only as Base64.
+}
diff --git a/tests/support/seeders.rs b/tests/support/seeders.rs
index 2fd1051..96e1c05 100644
--- a/tests/support/seeders.rs
+++ b/tests/support/seeders.rs
@@ -1,4 +1,4 @@
-use genai::chat::{ChatMessage, ChatRequest, Tool};
+use genai::chat::{ChatMessage, ChatRequest, ContentPart, ImageSource, Tool};
 use serde_json::json;
 
 pub fn seed_chat_req_simple() -> ChatRequest {
@@ -9,6 +9,21 @@ pub fn seed_chat_req_simple() -> ChatRequest {
     ])
 }
 
+pub fn seed_chat_req_with_image() -> ChatRequest {
+    ChatRequest::new(vec![
+        // -- Messages (deactivate to see the differences)
+        ChatMessage::system("Answer in one sentence"),
+        ChatMessage::user(vec![
+            ContentPart::from("What is in this image?"),
+            ContentPart::Image {
+                content: "BASE64 ENCODED IMAGE".to_string(),
+                content_type: "image/png".to_string(),
+                source: ImageSource::Base64,
+            },
+        ]),
+    ])
+}
+
 pub fn seed_chat_req_tool_simple() -> ChatRequest {
     ChatRequest::new(vec![
         // -- Messages (deactivate to see the differences)
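
Not part of the patch: a minimal usage sketch of the new multipart API introduced above, sending a locally stored image as Base64. `ContentPart`, `ImageSource`, and `ChatMessage::user(Vec<ContentPart>)` come from the changes in this patch; the file path, the model name, and the use of the `base64` crate (which this patch does not add as a dependency) are illustrative assumptions.

use base64::Engine as _; // assumes the `base64` crate is available
use genai::chat::{ChatMessage, ChatRequest, ContentPart, ImageSource};
use genai::Client;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical local file; providers only accept local images as Base64.
    let bytes = std::fs::read("tests/data/sample.png")?;
    let b64 = base64::engine::general_purpose::STANDARD.encode(bytes);

    let chat_req = ChatRequest::default().append_message(ChatMessage::user(vec![
        ContentPart::from("Describe this image in one sentence."),
        ContentPart::Image {
            content: b64,                          // raw Base64 payload, no "data:" prefix
            content_type: "image/png".to_string(), // mime type required by Gemini and Anthropic
            source: ImageSource::Base64,
        },
    ]));

    let client = Client::default();
    let chat_res = client.exec_chat("gpt-4o-mini", chat_req, None).await?;
    println!("{}", chat_res.content_text_as_str().unwrap_or_default());
    Ok(())
}

For the OpenAI adapter this Base64 payload is wrapped into a data URL ("data:image/png;base64,..."), while Gemini receives it as inline_data and Anthropic as a base64 image source, as implemented in the adapter changes above.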