diff --git a/Cargo.toml b/Cargo.toml index 778ae16..2ed2562 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.0.12-alpha.1" edition = "2021" rust-version = "1.78" license = "MIT OR Apache-2.0" -description = "Multiprovider generative AI client (Ollama, OpenAI, Gemini, Anthropic, Cohere, ...)" +description = "Multi-Provider Generative AI Rust Library. (Ollama, OpenAI, Gemini, Anthropic, Cohere, ...)" keywords = [ "generative-ai", "library", diff --git a/examples/c00-readme.rs b/examples/c00-readme.rs index ce3c75d..c91fca8 100644 --- a/examples/c00-readme.rs +++ b/examples/c00-readme.rs @@ -50,11 +50,11 @@ async fn main() -> Result<(), Box> { println!("\n--- Question:\n{question}"); println!("\n--- Answer: (oneshot response)"); - let chat_res = client.exec_chat(model, chat_req.clone()).await?; + let chat_res = client.exec_chat(model, chat_req.clone(), None).await?; println!("{}", chat_res.content.as_deref().unwrap_or("NO ANSWER")); println!("\n--- Answer: (streaming)"); - let chat_res = client.exec_chat_stream(model, chat_req.clone()).await?; + let chat_res = client.exec_chat_stream(model, chat_req.clone(), None).await?; print_chat_stream(chat_res).await?; println!(); diff --git a/examples/c01-conv.rs b/examples/c01-conv.rs index e0d91bc..5323613 100644 --- a/examples/c01-conv.rs +++ b/examples/c01-conv.rs @@ -21,7 +21,7 @@ async fn main() -> Result<(), Box> { chat_req = chat_req.append_message(ChatMessage::user(question)); println!("\n--- Question:\n{question}"); - let chat_res = client.exec_chat_stream(MODEL, chat_req.clone()).await?; + let chat_res = client.exec_chat_stream(MODEL, chat_req.clone(), None).await?; println!("\n--- Answer: (streaming)"); let assistant_answer = print_chat_stream(chat_res).await?; diff --git a/examples/c02-auth.rs b/examples/c02-auth.rs index d70b820..48e78f9 100644 --- a/examples/c02-auth.rs +++ b/examples/c02-auth.rs @@ -38,7 +38,7 @@ async fn main() -> Result<(), Box> { chat_req = 
chat_req.append_message(ChatMessage::user(question)); println!("\n--- Question:\n{question}"); - let chat_res = client.exec_chat_stream(MODEL, chat_req.clone()).await?; + let chat_res = client.exec_chat_stream(MODEL, chat_req.clone(), None).await?; println!("\n--- Answer: (streaming)"); let assistant_answer = print_chat_stream(chat_res).await?; diff --git a/src/adapter/adapter_types.rs b/src/adapter/adapter_types.rs index 2c521af..25a0a46 100644 --- a/src/adapter/adapter_types.rs +++ b/src/adapter/adapter_types.rs @@ -1,6 +1,6 @@ use crate::adapter::support::get_api_key_resolver; use crate::adapter::AdapterConfig; -use crate::chat::{ChatRequest, ChatResponse, ChatStreamResponse}; +use crate::chat::{ChatRequest, ChatRequestOptions, ChatResponse, ChatStreamResponse}; use crate::webc::WebResponse; use crate::{ConfigSet, Result}; use derive_more::Display; @@ -19,7 +19,7 @@ pub enum AdapterKind { } impl AdapterKind { - /// Very simplistic getter for now. + /// Very simplistic mapper for now. pub fn from_model(model: &str) -> Result { if model.starts_with("gpt") { Ok(AdapterKind::OpenAI) @@ -42,6 +42,8 @@ pub trait Adapter { /// Note: Implementation typically using OnceLock fn default_adapter_config(kind: AdapterKind) -> &'static AdapterConfig; + /// The base service url for this AdapterKind for this given service type. + /// NOTE: For some services, the url will be further updated in the to_web_request_data fn get_service_url(kind: AdapterKind, service_type: ServiceType) -> String; /// Get the api_key, with default implementation. 
@@ -53,9 +55,10 @@ pub trait Adapter { fn to_web_request_data( kind: AdapterKind, config_set: &ConfigSet<'_>, + service_type: ServiceType, model: &str, chat_req: ChatRequest, - service_type: ServiceType, + chat_req_options: Option<&ChatRequestOptions>, ) -> Result; /// To be implemented by Adapters diff --git a/src/adapter/adapters/anthropic/adapter_impl.rs b/src/adapter/adapters/anthropic/adapter_impl.rs index 3d5ecad..1073709 100644 --- a/src/adapter/adapters/anthropic/adapter_impl.rs +++ b/src/adapter/adapters/anthropic/adapter_impl.rs @@ -1,7 +1,7 @@ use crate::adapter::anthropic::AnthropicMessagesStream; use crate::adapter::support::get_api_key_resolver; use crate::adapter::{Adapter, AdapterConfig, AdapterKind, ServiceType, WebRequestData}; -use crate::chat::{ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse}; +use crate::chat::{ChatRequest, ChatRequestOptions, ChatResponse, ChatRole, ChatStream, ChatStreamResponse}; use crate::utils::x_value::XValue; use crate::webc::WebResponse; use crate::{ConfigSet, Result}; @@ -31,9 +31,10 @@ impl Adapter for AnthropicAdapter { fn to_web_request_data( kind: AdapterKind, config_set: &ConfigSet<'_>, + service_type: ServiceType, model: &str, chat_req: ChatRequest, - service_type: ServiceType, + _chat_req_options: Option<&ChatRequestOptions>, ) -> Result { let stream = matches!(service_type, ServiceType::ChatStream); let url = Self::get_service_url(kind, service_type); @@ -78,7 +79,10 @@ impl Adapter for AnthropicAdapter { Some(content.join("")) }; - Ok(ChatResponse { content }) + Ok(ChatResponse { + content, + ..Default::default() + }) } fn to_chat_stream(_kind: AdapterKind, reqwest_builder: RequestBuilder) -> Result { diff --git a/src/adapter/adapters/cohere/adapter_impl.rs b/src/adapter/adapters/cohere/adapter_impl.rs index 053db31..20bb624 100644 --- a/src/adapter/adapters/cohere/adapter_impl.rs +++ b/src/adapter/adapters/cohere/adapter_impl.rs @@ -1,7 +1,7 @@ use crate::adapter::cohere::CohereStream; 
use crate::adapter::support::get_api_key_resolver; use crate::adapter::{Adapter, AdapterConfig, AdapterKind, ServiceType, WebRequestData}; -use crate::chat::{ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse}; +use crate::chat::{ChatRequest, ChatRequestOptions, ChatResponse, ChatRole, ChatStream, ChatStreamResponse}; use crate::utils::x_value::XValue; use crate::webc::{WebResponse, WebStream}; use crate::{ConfigSet, Error, Result}; @@ -29,9 +29,10 @@ impl Adapter for CohereAdapter { fn to_web_request_data( kind: AdapterKind, config_set: &ConfigSet<'_>, + service_type: ServiceType, model: &str, chat_req: ChatRequest, - service_type: ServiceType, + _chat_req_options: Option<&ChatRequestOptions>, ) -> Result { let stream = matches!(service_type, ServiceType::ChatStream); @@ -79,7 +80,10 @@ impl Adapter for CohereAdapter { let content: Option = last_chat_history_item.x_take("message")?; - Ok(ChatResponse { content }) + Ok(ChatResponse { + content, + ..Default::default() + }) } fn to_chat_stream(_kind: AdapterKind, reqwest_builder: RequestBuilder) -> Result { diff --git a/src/adapter/adapters/gemini/adapter_impl.rs b/src/adapter/adapters/gemini/adapter_impl.rs index db02d5e..b0ad80b 100644 --- a/src/adapter/adapters/gemini/adapter_impl.rs +++ b/src/adapter/adapters/gemini/adapter_impl.rs @@ -1,7 +1,7 @@ use crate::adapter::gemini::GeminiStream; use crate::adapter::support::get_api_key_resolver; use crate::adapter::{Adapter, AdapterConfig, AdapterKind, ServiceType, WebRequestData}; -use crate::chat::{ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse}; +use crate::chat::{ChatRequest, ChatRequestOptions, ChatResponse, ChatRole, ChatStream, ChatStreamResponse}; use crate::utils::x_value::XValue; use crate::webc::{WebResponse, WebStream}; use crate::{ConfigSet, Error, Result}; @@ -33,12 +33,15 @@ impl Adapter for GeminiAdapter { fn to_web_request_data( kind: AdapterKind, config_set: &ConfigSet<'_>, + service_type: ServiceType, model: &str, 
chat_req: ChatRequest, - service_type: ServiceType, + _chat_req_options: Option<&ChatRequestOptions>, ) -> Result { let api_key = get_api_key_resolver(kind, config_set)?; + // For gemini, the service url returned is just the base url + // since model and API key is part of the url (see below) let url = Self::get_service_url(kind, service_type); // e.g., '...models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY' @@ -65,6 +68,7 @@ impl Adapter for GeminiAdapter { Ok(ChatResponse { content: gemini_response.content, + ..Default::default() }) } diff --git a/src/adapter/adapters/ollama/adapter_impl.rs b/src/adapter/adapters/ollama/adapter_impl.rs index e4d0e2c..564cbc7 100644 --- a/src/adapter/adapters/ollama/adapter_impl.rs +++ b/src/adapter/adapters/ollama/adapter_impl.rs @@ -2,7 +2,7 @@ use crate::adapter::openai::OpenAIAdapter; use crate::adapter::{Adapter, AdapterConfig, AdapterKind, ServiceType, WebRequestData}; -use crate::chat::{ChatRequest, ChatResponse, ChatStreamResponse}; +use crate::chat::{ChatRequest, ChatRequestOptions, ChatResponse, ChatStreamResponse}; use crate::webc::WebResponse; use crate::{ConfigSet, Result}; use reqwest::RequestBuilder; @@ -28,9 +28,10 @@ impl Adapter for OllamaAdapter { fn to_web_request_data( kind: AdapterKind, _config_set: &ConfigSet<'_>, + service_type: ServiceType, model: &str, chat_req: ChatRequest, - service_type: ServiceType, + _chat_req_options: Option<&ChatRequestOptions>, ) -> Result { let url = Self::get_service_url(kind, service_type); diff --git a/src/adapter/adapters/openai/adapter_impl.rs b/src/adapter/adapters/openai/adapter_impl.rs index 13f671b..fa139d3 100644 --- a/src/adapter/adapters/openai/adapter_impl.rs +++ b/src/adapter/adapters/openai/adapter_impl.rs @@ -1,7 +1,7 @@ use crate::adapter::openai::OpenAIMessagesStream; use crate::adapter::support::get_api_key_resolver; use crate::adapter::{Adapter, AdapterConfig, AdapterKind, ServiceType, WebRequestData}; -use crate::chat::{ChatRequest, ChatResponse, 
ChatRole, ChatStream, ChatStreamResponse}; +use crate::chat::{ChatRequest, ChatRequestOptions, ChatResponse, ChatRole, ChatStream, ChatStreamResponse}; use crate::utils::x_value::XValue; use crate::webc::WebResponse; use crate::{ConfigSet, Error, Result}; @@ -27,9 +27,10 @@ impl Adapter for OpenAIAdapter { fn to_web_request_data( kind: AdapterKind, config_set: &ConfigSet<'_>, + service_type: ServiceType, model: &str, chat_req: ChatRequest, - service_type: ServiceType, + _chat_req_options: Option<&ChatRequestOptions>, ) -> Result { // -- api_key (this Adapter requires it) let api_key = get_api_key_resolver(kind, config_set)?; @@ -42,7 +43,10 @@ impl Adapter for OpenAIAdapter { let WebResponse { mut body, .. } = web_response; let first_choice: Option = body.x_take("/choices/0")?; let content: Option = first_choice.map(|mut c| c.x_take("/message/content")).transpose()?; - Ok(ChatResponse { content }) + Ok(ChatResponse { + content, + ..Default::default() + }) } fn to_chat_stream(_kind: AdapterKind, reqwest_builder: RequestBuilder) -> Result { diff --git a/src/adapter/dispatcher.rs b/src/adapter/dispatcher.rs index 04300fa..6da3308 100644 --- a/src/adapter/dispatcher.rs +++ b/src/adapter/dispatcher.rs @@ -4,7 +4,7 @@ use crate::adapter::gemini::GeminiAdapter; use crate::adapter::ollama::OllamaAdapter; use crate::adapter::openai::OpenAIAdapter; use crate::adapter::{Adapter, AdapterConfig, AdapterKind, ServiceType, WebRequestData}; -use crate::chat::{ChatRequest, ChatResponse, ChatStreamResponse}; +use crate::chat::{ChatRequest, ChatRequestOptions, ChatResponse, ChatStreamResponse}; use crate::webc::WebResponse; use crate::{ConfigSet, Result}; use reqwest::RequestBuilder; @@ -35,18 +35,27 @@ impl Adapter for AdapterDispatcher { fn to_web_request_data( kind: AdapterKind, config_set: &ConfigSet<'_>, + service_type: ServiceType, model: &str, chat_req: ChatRequest, - service_type: ServiceType, + chat_req_options: Option<&ChatRequestOptions>, ) -> Result { match kind { - 
AdapterKind::OpenAI => OpenAIAdapter::to_web_request_data(kind, config_set, model, chat_req, service_type), - AdapterKind::Anthropic => { - AnthropicAdapter::to_web_request_data(kind, config_set, model, chat_req, service_type) + AdapterKind::OpenAI => { + OpenAIAdapter::to_web_request_data(kind, config_set, service_type, model, chat_req, chat_req_options) + } + AdapterKind::Anthropic => { + AnthropicAdapter::to_web_request_data(kind, config_set, service_type, model, chat_req, chat_req_options) + } + AdapterKind::Cohere => { + CohereAdapter::to_web_request_data(kind, config_set, service_type, model, chat_req, chat_req_options) + } + AdapterKind::Ollama => { + OllamaAdapter::to_web_request_data(kind, config_set, service_type, model, chat_req, chat_req_options) + } + AdapterKind::Gemini => { + GeminiAdapter::to_web_request_data(kind, config_set, service_type, model, chat_req, chat_req_options) } - AdapterKind::Cohere => CohereAdapter::to_web_request_data(kind, config_set, model, chat_req, service_type), - AdapterKind::Ollama => OllamaAdapter::to_web_request_data(kind, config_set, model, chat_req, service_type), - AdapterKind::Gemini => GeminiAdapter::to_web_request_data(kind, config_set, model, chat_req, service_type), } } diff --git a/src/chat/chat_options.rs b/src/chat/chat_options.rs new file mode 100644 index 0000000..9868943 --- /dev/null +++ b/src/chat/chat_options.rs @@ -0,0 +1,17 @@ +//! ChatRequestOptions is a struct that can be passed into the `client::exec_chat...` as the last argument +//! to customize the request behavior per call. +//! Note: Splitting it out of the `ChatRequest` object allows for better reusability of each component. +//! +//! IMPORTANT: These are not implemented yet, but here to show some of the directions and start having them part of the client APIs. 
+ +pub struct ChatRequestOptions { + /// Will capture the `MetaUsage` + /// - In the `ChatResponse` for `exec_chat` + /// - In the `StreamEnd` of `StreamEvent::End(StreamEnd)` for `exec_chat_stream` + pub capture_usage: Option, + + // -- For Stream only (for now, we flat them out) + /// Tell the chat stream executor to capture and concatenate all of the text chunks + /// to the last `StreamEvent::End(StreamEnd)` event as `StreamEnd.captured_content` (so, will be `Some(concatenated_chunks)`) + pub capture_content: Option, +} diff --git a/src/chat/chat_res.rs b/src/chat/chat_res.rs index e6cf4af..c0070db 100644 --- a/src/chat/chat_res.rs +++ b/src/chat/chat_res.rs @@ -2,9 +2,12 @@ use crate::chat::ChatStream; // region: --- ChatResponse -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct ChatResponse { pub content: Option, + /// NOT SUPPORTED + #[allow(unused)] + pub meta_usage: Option, } // endregion: --- ChatResponse @@ -16,3 +19,16 @@ pub struct ChatStreamResponse { } // endregion: --- ChatStreamResponse + +// region: --- MetaUsage + +// IMPORTANT: This is NOT used for now. To show the API direction.
+ +#[derive(Debug, Clone)] +pub struct MetaUsage { + pub input_token: Option, + pub output_token: Option, + pub total_token: Option, +} + +// endregion: --- MetaUsage diff --git a/src/chat/chat_stream.rs b/src/chat/chat_stream.rs index 8f57a79..9765127 100644 --- a/src/chat/chat_stream.rs +++ b/src/chat/chat_stream.rs @@ -1,4 +1,5 @@ use crate::adapter::inter_stream::InterStreamEvent; +use crate::chat::MetaUsage; use crate::Result; use derive_more::From; use futures::Stream; @@ -67,6 +68,9 @@ pub struct StreamChunk { #[derive(Debug, Default)] pub struct StreamEnd { + /// The eventually captured MetaUsage + pub meta_usage: Option, + /// The optional captured full content /// NOTE: NOT SUPPORTED YET (always None for now) /// Probably allow to toggle this on at the client_config, adapter_config diff --git a/src/chat/mod.rs b/src/chat/mod.rs index 186a627..db77c1d 100644 --- a/src/chat/mod.rs +++ b/src/chat/mod.rs @@ -3,12 +3,14 @@ // region: --- Modules +mod chat_options; mod chat_req; mod chat_res; mod chat_stream; mod tool; // -- Flatten +pub use chat_options::*; pub use chat_req::*; pub use chat_res::*; pub use chat_stream::*; diff --git a/src/client/client_impl.rs b/src/client/client_impl.rs index f05549c..1c56073 100644 --- a/src/client/client_impl.rs +++ b/src/client/client_impl.rs @@ -1,5 +1,5 @@ use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind, ServiceType, WebRequestData}; -use crate::chat::{ChatRequest, ChatResponse, ChatStreamResponse}; +use crate::chat::{ChatRequest, ChatRequestOptions, ChatResponse, ChatStreamResponse}; use crate::client::Client; use crate::{ConfigSet, Result}; @@ -9,7 +9,13 @@ impl Client { todo!() } - pub async fn exec_chat(&self, model: &str, chat_req: ChatRequest) -> Result { + pub async fn exec_chat( + &self, + model: &str, + chat_req: ChatRequest, + // options not implemented yet + options: Option<&ChatRequestOptions>, + ) -> Result { let adapter_kind = AdapterKind::from_model(model)?; let adapter_config = self @@ -18,8
+24,14 @@ impl Client { let config_set = ConfigSet::new(self.config(), adapter_config); - let WebRequestData { headers, payload, url } = - AdapterDispatcher::to_web_request_data(adapter_kind, &config_set, model, chat_req, ServiceType::Chat)?; + let WebRequestData { headers, payload, url } = AdapterDispatcher::to_web_request_data( + adapter_kind, + &config_set, + ServiceType::Chat, + model, + chat_req, + options, + )?; let web_res = self.web_client().do_post(&url, &headers, payload).await?; @@ -28,7 +40,12 @@ impl Client { Ok(chat_res) } - pub async fn exec_chat_stream(&self, model: &str, chat_req: ChatRequest) -> Result { + pub async fn exec_chat_stream( + &self, + model: &str, + chat_req: ChatRequest, // options not implemented yet + options: Option<&ChatRequestOptions>, + ) -> Result { let adapter_kind = AdapterKind::from_model(model)?; let adapter_config = self @@ -40,9 +57,10 @@ impl Client { let WebRequestData { url, headers, payload } = AdapterDispatcher::to_web_request_data( adapter_kind, &config_set, + ServiceType::ChatStream, model, chat_req, - ServiceType::ChatStream, + options, )?; let reqwest_builder = self.web_client().new_req_builder(&url, &headers, payload)?;