* prep refactor for custom endpoint

This commit is contained in:
Jeremy Chone
2024-12-08 09:56:25 -08:00
parent 175ed484cf
commit 011fb40a04
21 changed files with 353 additions and 228 deletions

View File

@ -11,7 +11,7 @@ repository = "https://github.com/jeremychone/rust-genai"
[lints.rust]
unsafe_code = "forbid"
# unused = { level = "allow", priority = -1 } # For exploratory dev.
unused = { level = "allow", priority = -1 } # For exploratory dev.
# missing_docs = "warn"
[dependencies]

View File

@ -54,7 +54,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
continue;
}
let adapter_kind = client.resolve_model_iden(model)?.adapter_kind;
let adapter_kind = client.resolve_service_target(model)?.model.adapter_kind;
println!("\n===== MODEL: {model} ({adapter_kind}) =====");

View File

@ -36,7 +36,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("\n--- Question:\n{question}");
let chat_res = client.exec_chat_stream(MODEL, chat_req.clone(), Some(&options)).await?;
let adapter_kind = client.resolve_model_iden(MODEL)?.adapter_kind;
let adapter_kind = client.resolve_service_target(MODEL)?.model.adapter_kind;
println!("\n--- Answer: ({MODEL} - {adapter_kind})");
print_chat_stream(chat_res, None).await?;

View File

@ -51,12 +51,12 @@ impl AdapterKind {
}
/// Utilities
impl AdapterKind {
/// Get the default key environment variable name for the adapter kind.
pub fn default_key_env_name(&self) -> Option<&'static str> {
AdapterDispatcher::default_key_env_name(*self)
}
}
// impl AdapterKind {
// /// Get the default key environment variable name for the adapter kind.
// pub fn default_key_env_name(&self) -> Option<&'static str> {
// AdapterDispatcher::default_key_env_name(*self)
// }
// }
/// From Model implementations
impl AdapterKind {

View File

@ -1,24 +1,30 @@
use crate::adapter::AdapterKind;
use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;
use serde_json::Value;
pub trait Adapter {
fn default_key_env_name(kind: AdapterKind) -> Option<&'static str>;
// #[deprecated(note = "use default_auth")]
// fn default_key_env_name(kind: AdapterKind) -> Option<&'static str>;
fn default_auth(kind: AdapterKind) -> AuthData;
fn default_endpoint(kind: AdapterKind) -> Endpoint;
// NOTE: Adapter is a crate trait, so it is acceptable to use async fn here.
async fn all_model_names(kind: AdapterKind) -> Result<Vec<String>>;
/// The base service URL for this AdapterKind for the given service type.
/// NOTE: For some services, the URL will be further updated in the to_web_request_data method.
fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String;
fn get_service_url(model_iden: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String;
/// To be implemented by Adapters.
fn to_web_request_data(
model_iden: ModelIden,
service_target: ServiceTarget,
config_set: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,

View File

@ -1,13 +1,14 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::anthropic::AnthropicStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage,
ToolCall,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;
use reqwest_eventsource::EventSource;
use serde_json::{json, Value};
@ -15,8 +16,6 @@ use value_ext::JsonValueExt;
pub struct AnthropicAdapter;
const BASE_URL: &str = "https://api.anthropic.com/v1/";
// NOTE: For Anthropic, the max_tokens must be specified.
// To avoid surprises, the default value for genai is the maximum for a given model.
// The 3-5 models have an 8k max token limit, while the 3 models have a 4k limit.
@ -32,8 +31,13 @@ const MODELS: &[&str] = &[
];
impl Adapter for AnthropicAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("ANTHROPIC_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://api.anthropic.com/v1/";
Endpoint::from_static(BASE_URL)
}
fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("ANTHROPIC_API_KEY")
}
/// Note: For now, it returns the common models (see above)
@ -41,40 +45,47 @@ impl Adapter for AnthropicAdapter {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}
fn get_service_url(_model_iden: ModelIden, service_type: ServiceType) -> String {
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
let base_url = endpoint.base_url();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => format!("{BASE_URL}messages"),
ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}messages"),
}
}
fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let model_name = model_iden.model_name.clone();
let ServiceTarget { endpoint, auth, model } = target;
let stream = matches!(service_type, ServiceType::ChatStream);
let url = Self::get_service_url(model_iden.clone(), service_type);
// -- api_key
let api_key = get_api_key(auth, &model)?;
// -- api_key (this Adapter requires it)
let api_key = get_api_key(model_iden.clone(), client_config)?;
// -- url
let url = Self::get_service_url(&model, service_type, endpoint);
// -- headers
let headers = vec![
// headers
("x-api-key".to_string(), api_key.to_string()),
("x-api-key".to_string(), api_key),
("anthropic-version".to_string(), ANTRHOPIC_VERSION.to_string()),
];
let model_name = model.model_name.clone();
// -- Parts
let AnthropicRequestParts {
system,
messages,
tools,
} = Self::into_anthropic_request_parts(model_iden.clone(), chat_req)?;
} = Self::into_anthropic_request_parts(model, chat_req)?;
// -- Build the basic payload
let stream = matches!(service_type, ServiceType::ChatStream);
let mut payload = json!({
"model": model_name.to_string(),
"messages": messages,
@ -99,7 +110,7 @@ impl Adapter for AnthropicAdapter {
}
let max_tokens = options_set.max_tokens().unwrap_or_else(|| {
if model_iden.model_name.contains("3-5") {
if model_name.contains("3-5") {
MAX_TOKENS_8K
} else {
MAX_TOKENS_4K

View File

@ -1,11 +1,12 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::cohere::CohereStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::{WebResponse, WebStream};
use crate::{ClientConfig, ModelIden};
use crate::{ClientConfig, ModelIden, ServiceTarget};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use serde_json::{json, Value};
@ -13,7 +14,6 @@ use value_ext::JsonValueExt;
pub struct CohereAdapter;
const BASE_URL: &str = "https://api.cohere.com/v1/";
const MODELS: &[&str] = &[
"command-r-plus",
"command-r",
@ -24,8 +24,13 @@ const MODELS: &[&str] = &[
];
impl Adapter for CohereAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("COHERE_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://api.cohere.com/v1/";
Endpoint::from_static(BASE_URL)
}
fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("COHERE_API_KEY")
}
/// Note: For now, it returns the common ones (see above)
@ -33,40 +38,45 @@ impl Adapter for CohereAdapter {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}
fn get_service_url(_model_iden: ModelIden, service_type: ServiceType) -> String {
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
let base_url = endpoint.base_url();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => format!("{BASE_URL}chat"),
ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}chat"),
}
}
fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let model_name = model_iden.model_name.clone();
let stream = matches!(service_type, ServiceType::ChatStream);
let url = Self::get_service_url(model_iden.clone(), service_type);
let ServiceTarget { endpoint, auth, model } = target;
// -- api_key (this Adapter requires it)
let api_key = get_api_key(model_iden.clone(), client_config)?;
let api_key = get_api_key(auth, &model)?;
// -- url
let url = Self::get_service_url(&model, service_type, endpoint);
// -- headers
let headers = vec![
// headers
("Authorization".to_string(), format!("Bearer {api_key}")),
];
let model_name = model.model_name.clone();
// -- parts
let CohereChatRequestParts {
preamble,
message,
chat_history,
} = Self::into_cohere_request_parts(model_iden, chat_req)?;
} = Self::into_cohere_request_parts(model, chat_req)?;
// -- Build the basic payload
let stream = matches!(service_type, ServiceType::ChatStream);
let mut payload = json!({
"model": model_name.to_string(),
"message": message,

View File

@ -1,12 +1,13 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::gemini::GeminiStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
MessageContent, MetaUsage,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::{WebResponse, WebStream};
use crate::{ClientConfig, ModelIden};
use crate::{ClientConfig, ModelIden, ServiceTarget};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use serde_json::{json, Value};
@ -14,7 +15,6 @@ use value_ext::JsonValueExt;
pub struct GeminiAdapter;
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
const MODELS: &[&str] = &[
"gemini-1.5-pro",
"gemini-1.5-flash",
@ -29,8 +29,13 @@ const MODELS: &[&str] = &[
// -X POST 'https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'
impl Adapter for GeminiAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("GEMINI_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
Endpoint::from_static(BASE_URL)
}
fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("GEMINI_API_KEY")
}
/// Note: For now, this returns the common models (see above)
@ -38,40 +43,48 @@ impl Adapter for GeminiAdapter {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}
fn get_service_url(_model_iden: ModelIden, service_type: ServiceType) -> String {
/// NOTE: As Google Gemini has decided to put their API_KEY in the URL,
/// this will return the URL without the API_KEY in it. The API_KEY will need to be added by the caller.
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
let base_url = endpoint.base_url();
let model_name = model.model_name.clone();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => BASE_URL.to_string(),
ServiceType::Chat => format!("{base_url}models/{model_name}:generateContent"),
ServiceType::ChatStream => format!("{base_url}models/{model_name}:streamGenerateContent"),
}
}
fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let api_key = get_api_key(model_iden.clone(), client_config)?;
let ServiceTarget { endpoint, auth, model } = target;
// For Gemini, the service URL returned is just the base URL
// since the model and API key are part of the URL (see below)
let url = Self::get_service_url(model_iden.clone(), service_type);
// -- api_key
let api_key = get_api_key(auth, &model)?;
// -- url
// NOTE: Somehow, Google decided to put the API key in the URL.
// This should be considered an antipattern from a security point of view
// even if it is done by the well respected Google. Everybody can make mistake once in a while.
// e.g., '...models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'
let model_name = &*model_iden.model_name;
let url = match service_type {
ServiceType::Chat => format!("{url}models/{model_name}:generateContent?key={api_key}"),
ServiceType::ChatStream => format!("{url}models/{model_name}:streamGenerateContent?key={api_key}"),
};
let url = Self::get_service_url(&model, service_type, endpoint);
let url = format!("{url}?key={api_key}");
let headers = vec![];
let GeminiChatRequestParts { system, contents } = Self::into_gemini_request_parts(model_iden, chat_req)?;
// -- parts
let GeminiChatRequestParts { system, contents } = Self::into_gemini_request_parts(model, chat_req)?;
// -- Payload
let mut payload = json!({
"contents": contents,
});
// -- headers (empty for gemini, since API_KEY is in url)
let headers = vec![];
// Note: It's unclear from the spec if the content of systemInstruction should have a role.
// Right now, it is omitted (since the spec states it can only be "user" or "model")
// It seems to work. https://ai.google.dev/api/rest/v1beta/models/generateContent

View File

@ -1,14 +1,14 @@
use crate::adapter::openai::OpenAIAdapter;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;
pub struct GroqAdapter;
const BASE_URL: &str = "https://api.groq.com/openai/v1/";
pub(in crate::adapter) const MODELS: &[&str] = &[
"llama-3.2-90b-vision-preview",
"llama-3.2-11b-vision-preview",
@ -29,28 +29,31 @@ pub(in crate::adapter) const MODELS: &[&str] = &[
// The Groq API adapter is modeled after the OpenAI adapter, as the Groq API is compatible with the OpenAI API.
impl Adapter for GroqAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("GROQ_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://api.groq.com/openai/v1/";
Endpoint::from_static(BASE_URL)
}
fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("GROQ_API_KEY")
}
async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}
fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String {
OpenAIAdapter::util_get_service_url(model_iden, service_type, BASE_URL)
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
OpenAIAdapter::util_get_service_url(model, service_type, endpoint)
}
fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
chat_options: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let url = Self::get_service_url(model_iden.clone(), service_type);
OpenAIAdapter::util_to_web_request_data(model_iden, client_config, chat_req, service_type, options_set, url)
OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options)
}
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {

View File

@ -3,8 +3,9 @@
use crate::adapter::openai::OpenAIAdapter;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::{ClientConfig, ModelIden};
use crate::{ClientConfig, ModelIden, ServiceTarget};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use serde_json::Value;
@ -12,22 +13,29 @@ use value_ext::JsonValueExt;
pub struct OllamaAdapter;
// The OpenAI Compatibility base URL
const BASE_URL: &str = "http://localhost:11434/v1/";
// const OLLAMA_BASE_URL: &str = "http://localhost:11434/api/";
/// Note: For now, it uses the OpenAI compatibility layer
/// (https://github.com/ollama/ollama/blob/main/docs/openai.md)
/// Since the base Ollama API supports `application/x-ndjson` for streaming, whereas others support `text/event-stream`
impl Adapter for OllamaAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
None
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "http://localhost:11434/v1/";
Endpoint::from_static(BASE_URL)
}
fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_single("ollama")
}
/// Note 1: For now, this adapter is the only one making a full request to the ollama server
/// Note 2: Will the OpenAI API (https://platform.openai.com/docs/api-reference/models/list)
/// Note 2: Uses the OpenAI API to talk to the Ollama server (https://platform.openai.com/docs/api-reference/models/list)
///
/// TODO: This will use the default endpoint.
/// Later, we might add another function that takes an endpoint, so that the user can give a custom endpoint.
async fn all_model_names(adapter_kind: AdapterKind) -> Result<Vec<String>> {
let url = format!("{BASE_URL}models");
// FIXME: This is hardcoded to the default endpoint; should take the endpoint as an argument
let endpoint = Self::default_endpoint(adapter_kind);
let base_url = endpoint.base_url();
let url = format!("{base_url}models");
// TODO: Need to get the WebClient from the client.
let web_c = crate::webc::WebClient::default();
@ -51,20 +59,18 @@ impl Adapter for OllamaAdapter {
Ok(models)
}
fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String {
OpenAIAdapter::util_get_service_url(model_iden, service_type, BASE_URL)
fn get_service_url(model_iden: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
OpenAIAdapter::util_get_service_url(model_iden, service_type, endpoint)
}
fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
chat_options: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let url = Self::get_service_url(model_iden.clone(), service_type);
OpenAIAdapter::util_to_web_request_data(model_iden, client_config, chat_req, service_type, options_set, url)
OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options)
}
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {

View File

@ -1,12 +1,13 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::openai::OpenAIStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
MessageContent, MetaUsage, ToolCall,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::{ClientConfig, ModelIden};
use crate::{ClientConfig, ModelIden, ServiceTarget};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use reqwest_eventsource::EventSource;
@ -16,7 +17,6 @@ use value_ext::JsonValueExt;
pub struct OpenAIAdapter;
const BASE_URL: &str = "https://api.openai.com/v1/";
// Latest models
const MODELS: &[&str] = &[
//
@ -27,8 +27,13 @@ const MODELS: &[&str] = &[
];
impl Adapter for OpenAIAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("OPENAI_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://api.openai.com/v1/";
Endpoint::from_static(BASE_URL)
}
fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("OPENAI_API_KEY")
}
/// Note: Currently returns the common models (see above)
@ -36,20 +41,18 @@ impl Adapter for OpenAIAdapter {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}
fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String {
Self::util_get_service_url(model_iden, service_type, BASE_URL)
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
Self::util_get_service_url(model, service_type, endpoint)
}
fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
chat_options: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let url = Self::get_service_url(model_iden.clone(), service_type);
OpenAIAdapter::util_to_web_request_data(model_iden, client_config, chat_req, service_type, chat_options, url)
OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options)
}
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
@ -103,38 +106,42 @@ impl Adapter for OpenAIAdapter {
/// Support functions for other adapters that share OpenAI APIs
impl OpenAIAdapter {
pub(in crate::adapter::adapters) fn util_get_service_url(
_model_iden: ModelIden,
_model: &ModelIden,
service_type: ServiceType,
// -- utility arguments
base_url: &str,
default_endpoint: Endpoint,
) -> String {
let base_url = default_endpoint.base_url();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}chat/completions"),
}
}
pub(in crate::adapter::adapters) fn util_to_web_request_data(
model_iden: ModelIden,
client_config: &ClientConfig,
chat_req: ChatRequest,
target: ServiceTarget,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
base_url: String,
) -> Result<WebRequestData> {
let stream = matches!(service_type, ServiceType::ChatStream);
let ServiceTarget { model, auth, endpoint } = target;
// -- Get the key
let api_key = get_api_key(model_iden.clone(), client_config)?;
// -- api_key
let api_key = get_api_key(auth, &model)?;
// -- Build the header
// -- url
let url = AdapterDispatcher::get_service_url(&model, service_type, endpoint);
// -- headers
let headers = vec![
// headers
("Authorization".to_string(), format!("Bearer {api_key}")),
];
let stream = matches!(service_type, ServiceType::ChatStream);
// -- Build the basic payload
let model_name = model_iden.model_name.to_string();
let OpenAIRequestParts { messages, tools } = Self::into_openai_request_parts(model_iden, chat_req)?;
let model_name = model.model_name.to_string();
let OpenAIRequestParts { messages, tools } = Self::into_openai_request_parts(model, chat_req)?;
let mut payload = json!({
"model": model_name,
"messages": messages,
@ -202,11 +209,7 @@ impl OpenAIAdapter {
payload.x_insert("top_p", top_p)?;
}
Ok(WebRequestData {
url: base_url,
headers,
payload,
})
Ok(WebRequestData { url, headers, payload })
}
/// Note: Needs to be called from super::streamer as well

View File

@ -2,7 +2,16 @@
//! It should be private to the `crate::adapter::adapters` module.
use crate::chat::{ChatOptionsSet, MetaUsage};
use crate::resolver::AuthData;
use crate::ModelIden;
use crate::{Error, Result};
pub fn get_api_key(auth: AuthData, model: &ModelIden) -> Result<String> {
auth.single_key_value().map_err(|resolver_error| Error::Resolver {
model_iden: model.clone(),
resolver_error,
})
}
// region: --- StreamerChatOptions

View File

@ -6,23 +6,35 @@ use crate::adapter::openai::OpenAIAdapter;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;
use super::groq::GroqAdapter;
use crate::resolver::{AuthData, Endpoint};
pub struct AdapterDispatcher;
impl Adapter for AdapterDispatcher {
fn default_key_env_name(kind: AdapterKind) -> Option<&'static str> {
fn default_endpoint(kind: AdapterKind) -> Endpoint {
match kind {
AdapterKind::OpenAI => OpenAIAdapter::default_key_env_name(kind),
AdapterKind::Anthropic => AnthropicAdapter::default_key_env_name(kind),
AdapterKind::Cohere => CohereAdapter::default_key_env_name(kind),
AdapterKind::Ollama => OllamaAdapter::default_key_env_name(kind),
AdapterKind::Gemini => GeminiAdapter::default_key_env_name(kind),
AdapterKind::Groq => GroqAdapter::default_key_env_name(kind),
AdapterKind::OpenAI => OpenAIAdapter::default_endpoint(kind),
AdapterKind::Anthropic => AnthropicAdapter::default_endpoint(kind),
AdapterKind::Cohere => CohereAdapter::default_endpoint(kind),
AdapterKind::Ollama => OllamaAdapter::default_endpoint(kind),
AdapterKind::Gemini => GeminiAdapter::default_endpoint(kind),
AdapterKind::Groq => GroqAdapter::default_endpoint(kind),
}
}
fn default_auth(kind: AdapterKind) -> AuthData {
match kind {
AdapterKind::OpenAI => OpenAIAdapter::default_auth(kind),
AdapterKind::Anthropic => AnthropicAdapter::default_auth(kind),
AdapterKind::Cohere => CohereAdapter::default_auth(kind),
AdapterKind::Ollama => OllamaAdapter::default_auth(kind),
AdapterKind::Gemini => GeminiAdapter::default_auth(kind),
AdapterKind::Groq => GroqAdapter::default_auth(kind),
}
}
@ -37,42 +49,43 @@ impl Adapter for AdapterDispatcher {
}
}
fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String {
match model_iden.adapter_kind {
AdapterKind::OpenAI => OpenAIAdapter::get_service_url(model_iden, service_type),
AdapterKind::Anthropic => AnthropicAdapter::get_service_url(model_iden, service_type),
AdapterKind::Cohere => CohereAdapter::get_service_url(model_iden, service_type),
AdapterKind::Ollama => OllamaAdapter::get_service_url(model_iden, service_type),
AdapterKind::Gemini => GeminiAdapter::get_service_url(model_iden, service_type),
AdapterKind::Groq => GroqAdapter::get_service_url(model_iden, service_type),
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
match model.adapter_kind {
AdapterKind::OpenAI => OpenAIAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Anthropic => AnthropicAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Cohere => CohereAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Ollama => OllamaAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Gemini => GeminiAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Groq => GroqAdapter::get_service_url(model, service_type, endpoint),
}
}
fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
match model_iden.adapter_kind {
let adapter_kind = &target.model.adapter_kind;
match adapter_kind {
AdapterKind::OpenAI => {
OpenAIAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
OpenAIAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
AdapterKind::Anthropic => {
AnthropicAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
AnthropicAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
AdapterKind::Cohere => {
CohereAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
CohereAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
AdapterKind::Ollama => {
OllamaAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
OllamaAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
AdapterKind::Gemini => {
GeminiAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
GeminiAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
AdapterKind::Groq => {
GroqAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
GroqAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
}
}

View File

@ -2,47 +2,3 @@ use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind};
use crate::resolver::AuthData;
use crate::{ClientConfig, ModelIden};
use crate::{Error, Result};
/// Returns the `api_key` value from the config_set auth_resolver
/// This function should be called if the adapter requires an api_key
/// Fails if there is no auth_resolver or no auth_data
pub fn get_api_key(model_iden: ModelIden, client_config: &ClientConfig) -> Result<String> {
// -- Try to get it from the optional auth_resolver
let auth_data = client_config
.auth_resolver()
.map(|auth_resolver| {
auth_resolver
.resolve(model_iden.clone())
.map_err(|resolver_error| Error::Resolver {
model_iden: model_iden.clone(),
resolver_error,
})
})
.transpose()? // return error if there is an error on auth resolver
.flatten(); // flatten the two options
// -- If there is no auth resolver, get it from the environment name (or 'ollama' for Ollama)
let auth_data = auth_data.or_else(|| {
if let AdapterKind::Ollama = model_iden.adapter_kind {
Some(AuthData::from_single("ollama".to_string()))
} else {
AdapterDispatcher::default_key_env_name(model_iden.adapter_kind)
.map(|env_name| AuthData::FromEnv(env_name.to_string()))
}
});
let Some(auth_data) = auth_data else {
return Err(Error::NoAuthData { model_iden });
};
// TODO: Needs to support multiple values
let key = auth_data
.single_value()
.map_err(|resolver_error| Error::Resolver {
model_iden,
resolver_error,
})?
.to_string();
Ok(key)
}

View File

@ -1,11 +1,16 @@
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptions, ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::client::Client;
use crate::{Error, ModelIden, Result};
use crate::{Client, Error, ModelIden, Result, ServiceTarget};
/// Public AI Functions
impl Client {
/// Returns all the model names for a given adapter kind.
///
/// IMPORTANT:
/// - Besides the Ollama adapter, this will only look at a hardcoded static list of names for now.
/// - For Ollama, it will currently make a live request to the default host/port (http://localhost:11434/v1/).
/// - This function will eventually change to either take an endpoint or have another function to allow a custom endpoint.
///
/// Notes:
/// - Since genai only supports Chat for now, the adapter implementation should attempt to remove the non-chat models.
/// - Later, as genai adds more capabilities, we will have a `model_names(adapter_kind, Option<&[Skill]>)`
@ -15,31 +20,27 @@ impl Client {
Ok(models)
}
/// Resolves the adapter kind for a given model name.
/// Note: This does not use the `all_model_names` function to find a match, but instead relies on hardcoded matching rules.
/// This strategy makes the library more flexible as it does not require updates
/// when the AI Provider adds new models (assuming they follow a consistent naming pattern).
///
/// See [AdapterKind::from_model]
///
/// [AdapterKind::from_model]: crate::adapter::AdapterKind::from_model
pub fn resolve_model_iden(&self, model_name: &str) -> Result<ModelIden> {
/// Return the default `ModelIden` (adapter kind + model name) for a model name string.
///
/// This is used before service-target resolution.
/// NOTE(review): the original doc sentence was truncated ("This is used before") —
/// confirm the intended wording.
pub fn default_model(&self, model_name: &str) -> Result<ModelIden> {
// -- First get the default ModelInfo
// The adapter kind is inferred from hardcoded model-name matching rules
// (see `AdapterKind::from_model`), not from a live model listing.
let adapter_kind = AdapterKind::from_model(model_name)?;
let model_iden = ModelIden::new(adapter_kind, model_name);
// -- Execute the optional model_mapper
// A configured mapper may rewrite the ModelIden (e.g. aliasing). A mapper
// failure is wrapped as ModelMapperFailed carrying the pre-map ModelIden.
let model_iden = if let Some(model_mapper) = self.config().model_mapper() {
model_mapper
.map_model(model_iden.clone())
.map_err(|cause| Error::ModelMapperFailed { model_iden, cause })?
} else {
model_iden
};
Ok(model_iden)
}
/// Resolves the `ModelIden` (adapter kind + model name) for a given model name.
///
/// Deprecated: use [`Client::resolve_service_target`] and read `target.model`.
// FIX: the deprecation note had an unbalanced backtick
// ("use `client.resolve_service_target(model_name)") — closed it.
#[deprecated(note = "use `client.resolve_service_target(model_name)`")]
pub fn resolve_model_iden(&self, model_name: &str) -> Result<ModelIden> {
let model = self.default_model(model_name)?;
let target = self.config().resolve_service_target(model)?;
Ok(target.model)
}
/// Resolve the full `ServiceTarget` (model, auth, endpoint) for a model name.
///
/// Applies `default_model` (adapter-kind inference plus the optional model
/// mapper), then delegates to the client config's service-target resolution.
pub fn resolve_service_target(&self, model_name: &str) -> Result<ServiceTarget> {
let model = self.default_model(model_name)?;
self.config().resolve_service_target(model)
}
/// Executes a chat.
pub async fn exec_chat(
&self,
@ -48,30 +49,27 @@ impl Client {
// options not implemented yet
options: Option<&ChatOptions>,
) -> Result<ChatResponse> {
let model_iden = self.resolve_model_iden(model)?;
let options_set = ChatOptionsSet::default()
.with_chat_options(options)
.with_client_options(self.config().chat_options());
let WebRequestData { headers, payload, url } = AdapterDispatcher::to_web_request_data(
model_iden.clone(),
self.config(),
ServiceType::Chat,
chat_req,
options_set,
)?;
let model = self.default_model(model)?;
let target = self.config().resolve_service_target(model)?;
let model = target.model.clone();
let WebRequestData { headers, payload, url } =
AdapterDispatcher::to_web_request_data(target, self.config(), ServiceType::Chat, chat_req, options_set)?;
let web_res =
self.web_client()
.do_post(&url, &headers, payload)
.await
.map_err(|webc_error| Error::WebModelCall {
model_iden: model_iden.clone(),
model_iden: model.clone(),
webc_error,
})?;
let chat_res = AdapterDispatcher::to_chat_response(model_iden, web_res)?;
let chat_res = AdapterDispatcher::to_chat_response(model, web_res)?;
Ok(chat_res)
}
@ -83,14 +81,16 @@ impl Client {
chat_req: ChatRequest, // options not implemented yet
options: Option<&ChatOptions>,
) -> Result<ChatStreamResponse> {
let model_iden = self.resolve_model_iden(model)?;
let options_set = ChatOptionsSet::default()
.with_chat_options(options)
.with_client_options(self.config().chat_options());
let model = self.default_model(model)?;
let target = self.config().resolve_service_target(model)?;
let model = target.model.clone();
let WebRequestData { url, headers, payload } = AdapterDispatcher::to_web_request_data(
model_iden.clone(),
target,
self.config(),
ServiceType::ChatStream,
chat_req,
@ -101,11 +101,11 @@ impl Client {
.web_client()
.new_req_builder(&url, &headers, payload)
.map_err(|webc_error| Error::WebModelCall {
model_iden: model_iden.clone(),
model_iden: model.clone(),
webc_error,
})?;
let res = AdapterDispatcher::to_chat_stream(model_iden, reqwest_builder, options_set)?;
let res = AdapterDispatcher::to_chat_stream(model, reqwest_builder, options_set)?;
Ok(res)
}

View File

@ -1,5 +1,8 @@
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind};
use crate::chat::ChatOptions;
use crate::resolver::{AuthResolver, ModelMapper};
use crate::client::ServiceTarget;
use crate::resolver::{AuthResolver, Endpoint, ModelMapper};
use crate::{Error, ModelIden, Result};
/// The Client configuration used in the configuration builder stage.
#[derive(Debug, Default, Clone)]
@ -47,3 +50,37 @@ impl ClientConfig {
self.chat_options.as_ref()
}
}
/// Resolvers
impl ClientConfig {
/// Build the `ServiceTarget` (model, auth, endpoint) for a given `ModelIden`.
///
/// Steps:
/// 1. Apply the optional `ModelMapper` to the model.
/// 2. Resolve auth via the optional `AuthResolver`, falling back to the
///    adapter's default auth when the resolver is absent or returns `None`.
/// 3. Use the adapter's default endpoint (custom endpoint override to come).
///
/// NOTE(review): a mapper failure here is wrapped as `Error::Resolver`, while
/// the client-side `default_model` path wraps it as `ModelMapperFailed` —
/// confirm this divergence is intentional.
pub fn resolve_service_target(&self, model: ModelIden) -> Result<ServiceTarget> {
// -- Resolve the Model first
// The error closure captures the pre-map `model` (shadowing happens only
// after the whole expression), so the original identity is reported.
let model = match self.model_mapper() {
Some(model_mapper) => model_mapper.map_model(model.clone()),
None => Ok(model.clone()),
}
.map_err(|resolver_error| Error::Resolver {
model_iden: model.clone(),
resolver_error,
})?;
// -- Get the auth
// The resolver may legitimately return Ok(None) ("no opinion"); `transpose`
// surfaces a resolver error, and the nested Options are flattened before
// falling back to the adapter's default auth.
let auth = self
.auth_resolver()
.map(|auth_resolver| {
auth_resolver.resolve(model.clone()).map_err(|resolver_error| Error::Resolver {
model_iden: model.clone(),
resolver_error,
})
})
.transpose()? // return error if there is an error on auth resolver
.flatten()
.unwrap_or_else(|| AdapterDispatcher::default_auth(model.adapter_kind)); // flatten the two options
// -- Get the default endpoint
// For now, just get the default endpoint, the `resolve_target` will allow to override it
let endpoint = AdapterDispatcher::default_endpoint(model.adapter_kind);
Ok(ServiceTarget { model, auth, endpoint })
}
}

View File

@ -4,9 +4,11 @@ mod builder;
mod client_impl;
mod client_types;
mod config;
mod service_target;
pub use builder::*;
pub use client_types::*;
pub use config::*;
pub use service_target::*;
// endregion: --- Modules

View File

@ -0,0 +1,14 @@
use crate::resolver::{AuthData, Endpoint};
use crate::ModelIden;
/// A ServiceTarget represents the destination and necessary details for making a service call.
///
/// This structure contains:
/// - `endpoint`: The specific service endpoint (base URL) to be contacted.
/// - `auth`: The authentication data required to access the service.
/// - `model`: The identifier of the model or resource associated with the service call.
///
/// NOTE(review): consider deriving `Debug`/`Clone` if `AuthData` and
/// `ModelIden` support them — TODO confirm.
pub struct ServiceTarget {
pub endpoint: Endpoint,
pub auth: AuthData,
pub model: ModelIden,
}

View File

@ -36,7 +36,7 @@ impl AuthData {
/// Getters
impl AuthData {
/// Get the single value from the `AuthData`.
pub fn single_value(&self) -> Result<String> {
pub fn single_key_value(&self) -> Result<String> {
match self {
AuthData::FromEnv(env_name) => {
// Get value from the environment name.

40
src/resolver/endpoint.rs Normal file
View File

@ -0,0 +1,40 @@
use std::sync::Arc;
/// A construct to store the endpoint of a service.
/// It is designed to be efficiently clonable.
/// For now, it just supports `base_url` but later might have other URLs per "service name".
#[derive(Debug, Clone)]
pub struct Endpoint {
    inner: EndpointInner,
}

/// Internal representation: either a borrowed `'static` URL (zero-cost) or a
/// shared owned URL (`Arc<str>`, O(1) clone).
#[derive(Debug, Clone)]
enum EndpointInner {
    Static(&'static str),
    Owned(Arc<str>),
}

/// Constructors
impl Endpoint {
    /// Create an endpoint from a `'static` URL.
    ///
    /// Made `const` so adapter default endpoints can be declared as
    /// `const`/`static` items at no runtime cost.
    pub const fn from_static(url: &'static str) -> Self {
        Endpoint {
            inner: EndpointInner::Static(url),
        }
    }

    /// Create an endpoint from an owned (or shared) URL.
    /// Accepts anything convertible into `Arc<str>` (e.g. `String`, `&str`, `Arc<str>`).
    pub fn from_owned(url: impl Into<Arc<str>>) -> Self {
        Endpoint {
            inner: EndpointInner::Owned(url.into()),
        }
    }
}

/// Getters
impl Endpoint {
    /// The base URL of this endpoint.
    pub fn base_url(&self) -> &str {
        match &self.inner {
            EndpointInner::Static(url) => url,
            EndpointInner::Owned(url) => url,
        }
    }
}

View File

@ -7,11 +7,13 @@
mod auth_data;
mod auth_resolver;
mod endpoint;
mod error;
mod model_mapper;
pub use auth_data::*;
pub use auth_resolver::*;
pub use endpoint::*;
pub use error::{Error, Result};
pub use model_mapper::*;