Mirror of https://github.com/mii443/rust-genai.git (synced 2025-08-22 16:25:27 +00:00)
* prep refactor for custom endpoint
@@ -11,7 +11,7 @@ repository = "https://github.com/jeremychone/rust-genai"

[lints.rust]
unsafe_code = "forbid"
# unused = { level = "allow", priority = -1 } # For exploratory dev.
unused = { level = "allow", priority = -1 } # For exploratory dev.
# missing_docs = "warn"

[dependencies]
@@ -54,7 +54,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
continue;
}

let adapter_kind = client.resolve_model_iden(model)?.adapter_kind;
let adapter_kind = client.resolve_service_target(model)?.model.adapter_kind;

println!("\n===== MODEL: {model} ({adapter_kind}) =====");
@@ -36,7 +36,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("\n--- Question:\n{question}");
let chat_res = client.exec_chat_stream(MODEL, chat_req.clone(), Some(&options)).await?;

let adapter_kind = client.resolve_model_iden(MODEL)?.adapter_kind;
let adapter_kind = client.resolve_service_target(MODEL)?.model.adapter_kind;

println!("\n--- Answer: ({MODEL} - {adapter_kind})");
print_chat_stream(chat_res, None).await?;
@@ -51,12 +51,12 @@ impl AdapterKind {
}

/// Utilities
impl AdapterKind {
/// Get the default key environment variable name for the adapter kind.
pub fn default_key_env_name(&self) -> Option<&'static str> {
AdapterDispatcher::default_key_env_name(*self)
}
}
// impl AdapterKind {
// /// Get the default key environment variable name for the adapter kind.
// pub fn default_key_env_name(&self) -> Option<&'static str> {
// AdapterDispatcher::default_key_env_name(*self)
// }
// }

/// From Model implementations
impl AdapterKind {
@@ -1,24 +1,30 @@
use crate::adapter::AdapterKind;
use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;
use serde_json::Value;

pub trait Adapter {
fn default_key_env_name(kind: AdapterKind) -> Option<&'static str>;
// #[deprecated(note = "use default_auth")]
// fn default_key_env_name(kind: AdapterKind) -> Option<&'static str>;

fn default_auth(kind: AdapterKind) -> AuthData;

fn default_endpoint(kind: AdapterKind) -> Endpoint;

// NOTE: Adapter is a crate trait, so it is acceptable to use async fn here.
async fn all_model_names(kind: AdapterKind) -> Result<Vec<String>>;

/// The base service URL for this AdapterKind for the given service type.
/// NOTE: For some services, the URL will be further updated in the to_web_request_data method.
fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String;
fn get_service_url(model_iden: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String;

/// To be implemented by Adapters.
fn to_web_request_data(
model_iden: ModelIden,
service_target: ServiceTarget,
config_set: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
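For orientation, the sketch below mirrors the shape of the refactored `get_service_url`: the endpoint becomes a parameter instead of a per-adapter `BASE_URL` constant, so a custom endpoint can be threaded down to URL construction. The types here are simplified local stand-ins for the example, not the crate's `Endpoint`/`ServiceType`.

// Simplified stand-ins for the example, not the crate's types.
#[derive(Clone, Copy)]
enum ServiceType {
    Chat,
    ChatStream,
}

struct Endpoint {
    base_url: &'static str,
}

// New shape: the endpoint is an argument rather than a hard-coded constant,
// so a custom endpoint can flow all the way to URL construction.
fn get_service_url(service_type: ServiceType, endpoint: &Endpoint) -> String {
    let base_url = endpoint.base_url;
    match service_type {
        ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}chat/completions"),
    }
}

fn main() {
    let default_ep = Endpoint { base_url: "https://api.openai.com/v1/" };
    let custom_ep = Endpoint { base_url: "http://localhost:8080/v1/" };
    println!("{}", get_service_url(ServiceType::Chat, &default_ep));
    println!("{}", get_service_url(ServiceType::ChatStream, &custom_ep));
}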
@@ -1,13 +1,14 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::anthropic::AnthropicStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage,
ToolCall,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;
use reqwest_eventsource::EventSource;
use serde_json::{json, Value};
@@ -15,8 +16,6 @@ use value_ext::JsonValueExt;

pub struct AnthropicAdapter;

const BASE_URL: &str = "https://api.anthropic.com/v1/";

// NOTE: For Anthropic, the max_tokens must be specified.
// To avoid surprises, the default value for genai is the maximum for a given model.
// The 3-5 models have an 8k max token limit, while the 3 models have a 4k limit.
@@ -32,8 +31,13 @@ const MODELS: &[&str] = &[
];

impl Adapter for AnthropicAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("ANTHROPIC_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://api.anthropic.com/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("ANTHROPIC_API_KEY")
}

/// Note: For now, it returns the common models (see above)
@@ -41,40 +45,47 @@ impl Adapter for AnthropicAdapter {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}

fn get_service_url(_model_iden: ModelIden, service_type: ServiceType) -> String {
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
let base_url = endpoint.base_url();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => format!("{BASE_URL}messages"),
ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}messages"),
}
}

fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let model_name = model_iden.model_name.clone();
let ServiceTarget { endpoint, auth, model } = target;

let stream = matches!(service_type, ServiceType::ChatStream);
let url = Self::get_service_url(model_iden.clone(), service_type);
// -- api_key
let api_key = get_api_key(auth, &model)?;

// -- api_key (this Adapter requires it)
let api_key = get_api_key(model_iden.clone(), client_config)?;
// -- url
let url = Self::get_service_url(&model, service_type, endpoint);

// -- headers
let headers = vec![
// headers
("x-api-key".to_string(), api_key.to_string()),
("x-api-key".to_string(), api_key),
("anthropic-version".to_string(), ANTRHOPIC_VERSION.to_string()),
];

let model_name = model.model_name.clone();

// -- Parts
let AnthropicRequestParts {
system,
messages,
tools,
} = Self::into_anthropic_request_parts(model_iden.clone(), chat_req)?;
} = Self::into_anthropic_request_parts(model, chat_req)?;

// -- Build the basic payload

let stream = matches!(service_type, ServiceType::ChatStream);
let mut payload = json!({
"model": model_name.to_string(),
"messages": messages,
@@ -99,7 +110,7 @@ impl Adapter for AnthropicAdapter {
}

let max_tokens = options_set.max_tokens().unwrap_or_else(|| {
if model_iden.model_name.contains("3-5") {
if model_name.contains("3-5") {
MAX_TOKENS_8K
} else {
MAX_TOKENS_4K
@@ -1,11 +1,12 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::cohere::CohereStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, MetaUsage,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::{WebResponse, WebStream};
use crate::{ClientConfig, ModelIden};
use crate::{ClientConfig, ModelIden, ServiceTarget};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use serde_json::{json, Value};
@@ -13,7 +14,6 @@ use value_ext::JsonValueExt;

pub struct CohereAdapter;

const BASE_URL: &str = "https://api.cohere.com/v1/";
const MODELS: &[&str] = &[
"command-r-plus",
"command-r",
@@ -24,8 +24,13 @@ const MODELS: &[&str] = &[
];

impl Adapter for CohereAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("COHERE_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://api.cohere.com/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("COHERE_API_KEY")
}

/// Note: For now, it returns the common ones (see above)
@@ -33,40 +38,45 @@ impl Adapter for CohereAdapter {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}

fn get_service_url(_model_iden: ModelIden, service_type: ServiceType) -> String {
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
let base_url = endpoint.base_url();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => format!("{BASE_URL}chat"),
ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}chat"),
}
}

fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let model_name = model_iden.model_name.clone();

let stream = matches!(service_type, ServiceType::ChatStream);

let url = Self::get_service_url(model_iden.clone(), service_type);
let ServiceTarget { endpoint, auth, model } = target;

// -- api_key (this Adapter requires it)
let api_key = get_api_key(model_iden.clone(), client_config)?;
let api_key = get_api_key(auth, &model)?;

// -- url
let url = Self::get_service_url(&model, service_type, endpoint);

// -- headers
let headers = vec![
// headers
("Authorization".to_string(), format!("Bearer {api_key}")),
];

let model_name = model.model_name.clone();

// -- parts
let CohereChatRequestParts {
preamble,
message,
chat_history,
} = Self::into_cohere_request_parts(model_iden, chat_req)?;
} = Self::into_cohere_request_parts(model, chat_req)?;

// -- Build the basic payload
let stream = matches!(service_type, ServiceType::ChatStream);
let mut payload = json!({
"model": model_name.to_string(),
"message": message,
@@ -1,12 +1,13 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::gemini::GeminiStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
MessageContent, MetaUsage,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::{WebResponse, WebStream};
use crate::{ClientConfig, ModelIden};
use crate::{ClientConfig, ModelIden, ServiceTarget};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use serde_json::{json, Value};
@@ -14,7 +15,6 @@ use value_ext::JsonValueExt;

pub struct GeminiAdapter;

const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
const MODELS: &[&str] = &[
"gemini-1.5-pro",
"gemini-1.5-flash",
@@ -29,8 +29,13 @@ const MODELS: &[&str] = &[
// -X POST 'https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'

impl Adapter for GeminiAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("GEMINI_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("GEMINI_API_KEY")
}

/// Note: For now, this returns the common models (see above)
@@ -38,40 +43,48 @@ impl Adapter for GeminiAdapter {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}

fn get_service_url(_model_iden: ModelIden, service_type: ServiceType) -> String {
/// NOTE: As Google Gemini has decided to put their API_KEY in the URL,
/// this will return the URL without the API_KEY in it. The API_KEY will need to be added by the caller.
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
let base_url = endpoint.base_url();
let model_name = model.model_name.clone();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => BASE_URL.to_string(),
ServiceType::Chat => format!("{base_url}models/{model_name}:generateContent"),
ServiceType::ChatStream => format!("{base_url}models/{model_name}:streamGenerateContent"),
}
}

fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let api_key = get_api_key(model_iden.clone(), client_config)?;
let ServiceTarget { endpoint, auth, model } = target;

// For Gemini, the service URL returned is just the base URL
// since the model and API key are part of the URL (see below)
let url = Self::get_service_url(model_iden.clone(), service_type);
// -- api_key
let api_key = get_api_key(auth, &model)?;

// -- url
// NOTE: Somehow, Google decided to put the API key in the URL.
// This should be considered an antipattern from a security point of view
// even if it is done by the well respected Google. Everybody can make mistake once in a while.
// e.g., '...models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'
let model_name = &*model_iden.model_name;
let url = match service_type {
ServiceType::Chat => format!("{url}models/{model_name}:generateContent?key={api_key}"),
ServiceType::ChatStream => format!("{url}models/{model_name}:streamGenerateContent?key={api_key}"),
};
let url = Self::get_service_url(&model, service_type, endpoint);
let url = format!("{url}?key={api_key}");

let headers = vec![];

let GeminiChatRequestParts { system, contents } = Self::into_gemini_request_parts(model_iden, chat_req)?;
// -- parts
let GeminiChatRequestParts { system, contents } = Self::into_gemini_request_parts(model, chat_req)?;

// -- Playload
let mut payload = json!({
"contents": contents,
});

// -- headers (empty for gemini, since API_KEY is in url)
let headers = vec![];

// Note: It's unclear from the spec if the content of systemInstruction should have a role.
// Right now, it is omitted (since the spec states it can only be "user" or "model")
// It seems to work. https://ai.google.dev/api/rest/v1beta/models/generateContent
@@ -1,14 +1,14 @@
use crate::adapter::openai::OpenAIAdapter;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;

pub struct GroqAdapter;

const BASE_URL: &str = "https://api.groq.com/openai/v1/";
pub(in crate::adapter) const MODELS: &[&str] = &[
"llama-3.2-90b-vision-preview",
"llama-3.2-11b-vision-preview",
@@ -29,28 +29,31 @@ pub(in crate::adapter) const MODELS: &[&str] = &[

// The Groq API adapter is modeled after the OpenAI adapter, as the Groq API is compatible with the OpenAI API.
impl Adapter for GroqAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("GROQ_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://api.groq.com/openai/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("GROQ_API_KEY")
}

async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}

fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String {
OpenAIAdapter::util_get_service_url(model_iden, service_type, BASE_URL)
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
OpenAIAdapter::util_get_service_url(model, service_type, endpoint)
}

fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
chat_options: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let url = Self::get_service_url(model_iden.clone(), service_type);

OpenAIAdapter::util_to_web_request_data(model_iden, client_config, chat_req, service_type, options_set, url)
OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options)
}

fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
@@ -3,8 +3,9 @@
use crate::adapter::openai::OpenAIAdapter;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::{ClientConfig, ModelIden};
use crate::{ClientConfig, ModelIden, ServiceTarget};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use serde_json::Value;
@@ -12,22 +13,29 @@ use value_ext::JsonValueExt;

pub struct OllamaAdapter;

// The OpenAI Compatibility base URL
const BASE_URL: &str = "http://localhost:11434/v1/";
// const OLLAMA_BASE_URL: &str = "http://localhost:11434/api/";

/// Note: For now, it uses the OpenAI compatibility layer
/// (https://github.com/ollama/ollama/blob/main/docs/openai.md)
/// Since the base Ollama API supports `application/x-ndjson` for streaming, whereas others support `text/event-stream`
impl Adapter for OllamaAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
None
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "http://localhost:11434/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_single("ollama")
}

/// Note 1: For now, this adapter is the only one making a full request to the ollama server
/// Note 2: Will the OpenAI API (https://platform.openai.com/docs/api-reference/models/list)
/// Note 2: Will the OpenAI API to talk to Ollam server (https://platform.openai.com/docs/api-reference/models/list)
///
/// TODO: This will use the default endpoint.
/// Later, we might add another function with a endpoint, so the the user can give an custom endpoint.
async fn all_model_names(adapter_kind: AdapterKind) -> Result<Vec<String>> {
let url = format!("{BASE_URL}models");
// FIXME: This is harcoded to the default endpoint, should take endpoint as Argument
let endpoint = Self::default_endpoint(adapter_kind);
let base_url = endpoint.base_url();
let url = format!("{base_url}models");

// TODO: Need to get the WebClient from the client.
let web_c = crate::webc::WebClient::default();
@@ -51,20 +59,18 @@ impl Adapter for OllamaAdapter {
Ok(models)
}

fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String {
OpenAIAdapter::util_get_service_url(model_iden, service_type, BASE_URL)
fn get_service_url(model_iden: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
OpenAIAdapter::util_get_service_url(model_iden, service_type, endpoint)
}

fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
chat_options: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let url = Self::get_service_url(model_iden.clone(), service_type);

OpenAIAdapter::util_to_web_request_data(model_iden, client_config, chat_req, service_type, options_set, url)
OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options)
}

fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
@@ -1,12 +1,13 @@
use crate::adapter::adapters::support::get_api_key;
use crate::adapter::openai::OpenAIStreamer;
use crate::adapter::support::get_api_key;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
MessageContent, MetaUsage, ToolCall,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::{ClientConfig, ModelIden};
use crate::{ClientConfig, ModelIden, ServiceTarget};
use crate::{Error, Result};
use reqwest::RequestBuilder;
use reqwest_eventsource::EventSource;
@@ -16,7 +17,6 @@ use value_ext::JsonValueExt;

pub struct OpenAIAdapter;

const BASE_URL: &str = "https://api.openai.com/v1/";
// Latest models
const MODELS: &[&str] = &[
//
@@ -27,8 +27,13 @@ const MODELS: &[&str] = &[
];

impl Adapter for OpenAIAdapter {
fn default_key_env_name(_kind: AdapterKind) -> Option<&'static str> {
Some("OPENAI_API_KEY")
fn default_endpoint(kind: AdapterKind) -> Endpoint {
const BASE_URL: &str = "https://api.openai.com/v1/";
Endpoint::from_static(BASE_URL)
}

fn default_auth(kind: AdapterKind) -> AuthData {
AuthData::from_env("OPENAI_API_KEY")
}

/// Note: Currently returns the common models (see above)
@@ -36,20 +41,18 @@ impl Adapter for OpenAIAdapter {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}

fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String {
Self::util_get_service_url(model_iden, service_type, BASE_URL)
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
Self::util_get_service_url(model, service_type, endpoint)
}

fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
chat_options: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
let url = Self::get_service_url(model_iden.clone(), service_type);

OpenAIAdapter::util_to_web_request_data(model_iden, client_config, chat_req, service_type, chat_options, url)
OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options)
}

fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
@@ -103,38 +106,42 @@ impl Adapter for OpenAIAdapter {
/// Support functions for other adapters that share OpenAI APIs
impl OpenAIAdapter {
pub(in crate::adapter::adapters) fn util_get_service_url(
_model_iden: ModelIden,
_model: &ModelIden,
service_type: ServiceType,
// -- utility arguments
base_url: &str,
default_endpoint: Endpoint,
) -> String {
let base_url = default_endpoint.base_url();
match service_type {
ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}chat/completions"),
}
}

pub(in crate::adapter::adapters) fn util_to_web_request_data(
model_iden: ModelIden,
client_config: &ClientConfig,
chat_req: ChatRequest,
target: ServiceTarget,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
base_url: String,
) -> Result<WebRequestData> {
let stream = matches!(service_type, ServiceType::ChatStream);
let ServiceTarget { model, auth, endpoint } = target;

// -- Get the key
let api_key = get_api_key(model_iden.clone(), client_config)?;
// -- api_key
let api_key = get_api_key(auth, &model)?;

// -- Build the header
// -- url
let url = AdapterDispatcher::get_service_url(&model, service_type, endpoint);

// -- headers
let headers = vec![
// headers
("Authorization".to_string(), format!("Bearer {api_key}")),
];

let stream = matches!(service_type, ServiceType::ChatStream);

// -- Build the basic payload
let model_name = model_iden.model_name.to_string();
let OpenAIRequestParts { messages, tools } = Self::into_openai_request_parts(model_iden, chat_req)?;
let model_name = model.model_name.to_string();
let OpenAIRequestParts { messages, tools } = Self::into_openai_request_parts(model, chat_req)?;
let mut payload = json!({
"model": model_name,
"messages": messages,
@@ -202,11 +209,7 @@ impl OpenAIAdapter {
payload.x_insert("top_p", top_p)?;
}

Ok(WebRequestData {
url: base_url,
headers,
payload,
})
Ok(WebRequestData { url, headers, payload })
}

/// Note: Needs to be called from super::streamer as well
@@ -2,7 +2,16 @@
//! It should be private to the `crate::adapter::adapters` module.

use crate::chat::{ChatOptionsSet, MetaUsage};
use crate::resolver::AuthData;
use crate::ModelIden;
use crate::{Error, Result};

pub fn get_api_key(auth: AuthData, model: &ModelIden) -> Result<String> {
auth.single_key_value().map_err(|resolver_error| Error::Resolver {
model_iden: model.clone(),
resolver_error,
})
}

// region: --- StreamerChatOptions
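With the refactor, key lookup shrinks to turning an already-resolved `AuthData` into a string. A minimal self-contained sketch of that contract, using simplified stand-in types rather than the crate's `AuthData`/`Error` (the env var name below is a placeholder):

use std::env;

// Simplified stand-in for the crate's AuthData (env-based and single-value variants only).
enum AuthData {
    FromEnv(String),
    Single(String),
}

// Mirrors the shape of the refactored get_api_key: the caller has already resolved
// an AuthData, so this only extracts a single key value or reports an error.
fn single_key_value(auth: &AuthData) -> Result<String, String> {
    match auth {
        AuthData::FromEnv(env_name) => {
            env::var(env_name).map_err(|_| format!("missing env var: {env_name}"))
        }
        AuthData::Single(value) => Ok(value.clone()),
    }
}

fn main() {
    let single = AuthData::Single("ollama".to_string());
    // "MY_SERVICE_API_KEY" is a placeholder variable name for the example.
    let from_env = AuthData::FromEnv("MY_SERVICE_API_KEY".to_string());
    println!("{:?}", single_key_value(&single));
    println!("{:?}", single_key_value(&from_env));
}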
@@ -6,23 +6,35 @@ use crate::adapter::openai::OpenAIAdapter;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::webc::WebResponse;
use crate::Result;
use crate::{ClientConfig, ModelIden};
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;

use super::groq::GroqAdapter;
use crate::resolver::{AuthData, Endpoint};

pub struct AdapterDispatcher;

impl Adapter for AdapterDispatcher {
fn default_key_env_name(kind: AdapterKind) -> Option<&'static str> {
fn default_endpoint(kind: AdapterKind) -> Endpoint {
match kind {
AdapterKind::OpenAI => OpenAIAdapter::default_key_env_name(kind),
AdapterKind::Anthropic => AnthropicAdapter::default_key_env_name(kind),
AdapterKind::Cohere => CohereAdapter::default_key_env_name(kind),
AdapterKind::Ollama => OllamaAdapter::default_key_env_name(kind),
AdapterKind::Gemini => GeminiAdapter::default_key_env_name(kind),
AdapterKind::Groq => GroqAdapter::default_key_env_name(kind),
AdapterKind::OpenAI => OpenAIAdapter::default_endpoint(kind),
AdapterKind::Anthropic => AnthropicAdapter::default_endpoint(kind),
AdapterKind::Cohere => CohereAdapter::default_endpoint(kind),
AdapterKind::Ollama => OllamaAdapter::default_endpoint(kind),
AdapterKind::Gemini => GeminiAdapter::default_endpoint(kind),
AdapterKind::Groq => GroqAdapter::default_endpoint(kind),
}
}

fn default_auth(kind: AdapterKind) -> AuthData {
match kind {
AdapterKind::OpenAI => OpenAIAdapter::default_auth(kind),
AdapterKind::Anthropic => AnthropicAdapter::default_auth(kind),
AdapterKind::Cohere => CohereAdapter::default_auth(kind),
AdapterKind::Ollama => OllamaAdapter::default_auth(kind),
AdapterKind::Gemini => GeminiAdapter::default_auth(kind),
AdapterKind::Groq => GroqAdapter::default_auth(kind),
}
}

@@ -37,42 +49,43 @@ impl Adapter for AdapterDispatcher {
}
}

fn get_service_url(model_iden: ModelIden, service_type: ServiceType) -> String {
match model_iden.adapter_kind {
AdapterKind::OpenAI => OpenAIAdapter::get_service_url(model_iden, service_type),
AdapterKind::Anthropic => AnthropicAdapter::get_service_url(model_iden, service_type),
AdapterKind::Cohere => CohereAdapter::get_service_url(model_iden, service_type),
AdapterKind::Ollama => OllamaAdapter::get_service_url(model_iden, service_type),
AdapterKind::Gemini => GeminiAdapter::get_service_url(model_iden, service_type),
AdapterKind::Groq => GroqAdapter::get_service_url(model_iden, service_type),
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
match model.adapter_kind {
AdapterKind::OpenAI => OpenAIAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Anthropic => AnthropicAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Cohere => CohereAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Ollama => OllamaAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Gemini => GeminiAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Groq => GroqAdapter::get_service_url(model, service_type, endpoint),
}
}

fn to_web_request_data(
model_iden: ModelIden,
target: ServiceTarget,
client_config: &ClientConfig,
service_type: ServiceType,
chat_req: ChatRequest,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
match model_iden.adapter_kind {
let adapter_kind = &target.model.adapter_kind;
match adapter_kind {
AdapterKind::OpenAI => {
OpenAIAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
OpenAIAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
AdapterKind::Anthropic => {
AnthropicAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
AnthropicAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
AdapterKind::Cohere => {
CohereAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
CohereAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
AdapterKind::Ollama => {
OllamaAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
OllamaAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
AdapterKind::Gemini => {
GeminiAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
GeminiAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
AdapterKind::Groq => {
GroqAdapter::to_web_request_data(model_iden, client_config, service_type, chat_req, options_set)
GroqAdapter::to_web_request_data(target, client_config, service_type, chat_req, options_set)
}
}
}
@@ -2,47 +2,3 @@ use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind};
use crate::resolver::AuthData;
use crate::{ClientConfig, ModelIden};
use crate::{Error, Result};

/// Returns the `api_key` value from the config_set auth_resolver
/// This function should be called if the adapter requires an api_key
/// Fails if there is no auth_resolver or no auth_data
pub fn get_api_key(model_iden: ModelIden, client_config: &ClientConfig) -> Result<String> {
// -- Try to get it from the optional auth_resolver
let auth_data = client_config
.auth_resolver()
.map(|auth_resolver| {
auth_resolver
.resolve(model_iden.clone())
.map_err(|resolver_error| Error::Resolver {
model_iden: model_iden.clone(),
resolver_error,
})
})
.transpose()? // return error if there is an error on auth resolver
.flatten(); // flatten the two options

// -- If there is no auth resolver, get it from the environment name (or 'ollama' for Ollama)
let auth_data = auth_data.or_else(|| {
if let AdapterKind::Ollama = model_iden.adapter_kind {
Some(AuthData::from_single("ollama".to_string()))
} else {
AdapterDispatcher::default_key_env_name(model_iden.adapter_kind)
.map(|env_name| AuthData::FromEnv(env_name.to_string()))
}
});

let Some(auth_data) = auth_data else {
return Err(Error::NoAuthData { model_iden });
};

// TODO: Needs to support multiple values
let key = auth_data
.single_value()
.map_err(|resolver_error| Error::Resolver {
model_iden,
resolver_error,
})?
.to_string();

Ok(key)
}
@@ -1,11 +1,16 @@
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptions, ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::client::Client;
use crate::{Error, ModelIden, Result};
use crate::{Client, Error, ModelIden, Result, ServiceTarget};

/// Public AI Functions
impl Client {
/// Returns all the model names for a given adapter kind.
///
/// IMPORTANT:
/// - Besides the Ollama adapter, this will only look at a hardcoded static list of names for now.
/// - For Ollama, it will currently make a live request to the default host/port (http://localhost:11434/v1/).
/// - This function will eventually change to either take an endpoint or have another function to allow a custom endpoint.
///
/// Notes:
/// - Since genai only supports Chat for now, the adapter implementation should attempt to remove the non-chat models.
/// - Later, as genai adds more capabilities, we will have a `model_names(adapter_kind, Option<&[Skill]>)`
@@ -15,31 +20,27 @@ impl Client {
Ok(models)
}

/// Resolves the adapter kind for a given model name.
/// Note: This does not use the `all_model_names` function to find a match, but instead relies on hardcoded matching rules.
/// This strategy makes the library more flexible as it does not require updates
/// when the AI Provider adds new models (assuming they follow a consistent naming pattern).
///
/// See [AdapterKind::from_model]
///
/// [AdapterKind::from_model]: crate::adapter::AdapterKind::from_model
pub fn resolve_model_iden(&self, model_name: &str) -> Result<ModelIden> {
/// Return the default model for a model_name str.
/// This is used before
pub fn default_model(&self, model_name: &str) -> Result<ModelIden> {
// -- First get the default ModelInfo
let adapter_kind = AdapterKind::from_model(model_name)?;
let model_iden = ModelIden::new(adapter_kind, model_name);

// -- Execute the optional model_mapper
let model_iden = if let Some(model_mapper) = self.config().model_mapper() {
model_mapper
.map_model(model_iden.clone())
.map_err(|cause| Error::ModelMapperFailed { model_iden, cause })?
} else {
model_iden
};

Ok(model_iden)
}

#[deprecated(note = "use `client.resolve_service_target(model_name)")]
pub fn resolve_model_iden(&self, model_name: &str) -> Result<ModelIden> {
let model = self.default_model(model_name)?;
let target = self.config().resolve_service_target(model)?;
Ok(target.model)
}

pub fn resolve_service_target(&self, model_name: &str) -> Result<ServiceTarget> {
let model = self.default_model(model_name)?;
self.config().resolve_service_target(model)
}

/// Executes a chat.
pub async fn exec_chat(
&self,
@@ -48,30 +49,27 @@ impl Client {
// options not implemented yet
options: Option<&ChatOptions>,
) -> Result<ChatResponse> {
let model_iden = self.resolve_model_iden(model)?;

let options_set = ChatOptionsSet::default()
.with_chat_options(options)
.with_client_options(self.config().chat_options());

let WebRequestData { headers, payload, url } = AdapterDispatcher::to_web_request_data(
model_iden.clone(),
self.config(),
ServiceType::Chat,
chat_req,
options_set,
)?;
let model = self.default_model(model)?;
let target = self.config().resolve_service_target(model)?;
let model = target.model.clone();

let WebRequestData { headers, payload, url } =
AdapterDispatcher::to_web_request_data(target, self.config(), ServiceType::Chat, chat_req, options_set)?;

let web_res =
self.web_client()
.do_post(&url, &headers, payload)
.await
.map_err(|webc_error| Error::WebModelCall {
model_iden: model_iden.clone(),
model_iden: model.clone(),
webc_error,
})?;

let chat_res = AdapterDispatcher::to_chat_response(model_iden, web_res)?;
let chat_res = AdapterDispatcher::to_chat_response(model, web_res)?;

Ok(chat_res)
}
@@ -83,14 +81,16 @@ impl Client {
chat_req: ChatRequest, // options not implemented yet
options: Option<&ChatOptions>,
) -> Result<ChatStreamResponse> {
let model_iden = self.resolve_model_iden(model)?;

let options_set = ChatOptionsSet::default()
.with_chat_options(options)
.with_client_options(self.config().chat_options());

let model = self.default_model(model)?;
let target = self.config().resolve_service_target(model)?;
let model = target.model.clone();

let WebRequestData { url, headers, payload } = AdapterDispatcher::to_web_request_data(
model_iden.clone(),
target,
self.config(),
ServiceType::ChatStream,
chat_req,
@@ -101,11 +101,11 @@ impl Client {
.web_client()
.new_req_builder(&url, &headers, payload)
.map_err(|webc_error| Error::WebModelCall {
model_iden: model_iden.clone(),
model_iden: model.clone(),
webc_error,
})?;

let res = AdapterDispatcher::to_chat_stream(model_iden, reqwest_builder, options_set)?;
let res = AdapterDispatcher::to_chat_stream(model, reqwest_builder, options_set)?;

Ok(res)
}
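A short usage sketch of the new public surface above: resolve a ServiceTarget from a model name and inspect it. This assumes the genai crate at this commit; the model name is only an example.

use genai::Client;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::default();

    // resolve_service_target replaces resolve_model_iden (now deprecated):
    // it returns the model, the auth data, and the endpoint a call would use.
    let target = client.resolve_service_target("gpt-4o-mini")?;

    println!("adapter:  {}", target.model.adapter_kind);
    println!("base url: {}", target.endpoint.base_url());

    Ok(())
}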
@@ -1,5 +1,8 @@
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind};
use crate::chat::ChatOptions;
use crate::resolver::{AuthResolver, ModelMapper};
use crate::client::ServiceTarget;
use crate::resolver::{AuthResolver, Endpoint, ModelMapper};
use crate::{Error, ModelIden, Result};

/// The Client configuration used in the configuration builder stage.
#[derive(Debug, Default, Clone)]
@@ -47,3 +50,37 @@ impl ClientConfig {
self.chat_options.as_ref()
}
}

/// Resolvers
impl ClientConfig {
pub fn resolve_service_target(&self, model: ModelIden) -> Result<ServiceTarget> {
// -- Resolve the Model first
let model = match self.model_mapper() {
Some(model_mapper) => model_mapper.map_model(model.clone()),
None => Ok(model.clone()),
}
.map_err(|resolver_error| Error::Resolver {
model_iden: model.clone(),
resolver_error,
})?;

// -- Get the auth
let auth = self
.auth_resolver()
.map(|auth_resolver| {
auth_resolver.resolve(model.clone()).map_err(|resolver_error| Error::Resolver {
model_iden: model.clone(),
resolver_error,
})
})
.transpose()? // return error if there is an error on auth resolver
.flatten()
.unwrap_or_else(|| AdapterDispatcher::default_auth(model.adapter_kind)); // flatten the two options

// -- Get the default endpoint
// For now, just get the default endpoint, the `resolve_target` will allow to override it
let endpoint = AdapterDispatcher::default_endpoint(model.adapter_kind);

Ok(ServiceTarget { model, auth, endpoint })
}
}
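The auth step above is an "optional resolver, then default" chain: the resolver may be absent, may fail, or may decline to answer, and only then does the adapter default apply. A small self-contained illustration of that map/transpose/flatten/unwrap_or_else chaining, with generic stand-ins rather than the crate's resolver types:

// Stand-ins: an optional "resolver" that may fail or may decline to answer.
fn resolve_auth(
    resolver: Option<&dyn Fn(&str) -> Result<Option<String>, String>>,
    model: &str,
    default_auth: impl Fn() -> String,
) -> Result<String, String> {
    let auth = resolver
        .map(|r| r(model)) // Option<Result<Option<String>, String>>
        .transpose()? // surface a resolver error, keep Option<Option<String>>
        .flatten() // collapse the two Options
        .unwrap_or_else(default_auth); // fall back to the adapter default
    Ok(auth)
}

fn main() -> Result<(), String> {
    // No resolver configured: the default applies.
    let a = resolve_auth(None, "gpt-4o-mini", || "env:OPENAI_API_KEY".to_string())?;

    // A resolver that answers for this model: it takes precedence over the default.
    let custom = |_model: &str| -> Result<Option<String>, String> { Ok(Some("single:my-key".to_string())) };
    let b = resolve_auth(Some(&custom), "gpt-4o-mini", || "env:OPENAI_API_KEY".to_string())?;

    println!("{a} / {b}");
    Ok(())
}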
@@ -4,9 +4,11 @@ mod builder;
mod client_impl;
mod client_types;
mod config;
mod service_target;

pub use builder::*;
pub use client_types::*;
pub use config::*;
pub use service_target::*;

// endregion: --- Modules
src/client/service_target.rs (new file, 14 lines)
@@ -0,0 +1,14 @@
use crate::resolver::{AuthData, Endpoint};
use crate::ModelIden;

/// A ServiceTarget represents the destination and necessary details for making a service call.
///
/// This structure contains:
/// - `endpoint`: The specific service endpoint to be contacted.
/// - `auth`: The authentication data required to access the service.
/// - `model`: The identifier of the model or resource associated with the service call.
pub struct ServiceTarget {
pub endpoint: Endpoint,
pub auth: AuthData,
pub model: ModelIden,
}
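For illustration, a ServiceTarget can be assembled by hand from the types this commit introduces. A hedged sketch (the proxy URL and env var name are placeholders; this commit only preps the plumbing, the client does not yet accept a caller-supplied target):

use genai::adapter::AdapterKind;
use genai::resolver::{AuthData, Endpoint};
use genai::{ModelIden, ServiceTarget};

fn main() {
    // Hypothetical OpenAI-compatible endpoint and key variable, for illustration only.
    let target = ServiceTarget {
        endpoint: Endpoint::from_owned("https://my-proxy.example.com/v1/".to_string()),
        auth: AuthData::from_env("MY_PROXY_API_KEY"),
        model: ModelIden::new(AdapterKind::OpenAI, "gpt-4o-mini"),
    };

    println!("{} -> {}", target.model.model_name, target.endpoint.base_url());
}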
@@ -36,7 +36,7 @@ impl AuthData {
/// Getters
impl AuthData {
/// Get the single value from the `AuthData`.
pub fn single_value(&self) -> Result<String> {
pub fn single_key_value(&self) -> Result<String> {
match self {
AuthData::FromEnv(env_name) => {
// Get value from the environment name.
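A small usage sketch of the renamed getter, assuming the named environment variable is set (the variable name is only an example):

use genai::resolver::AuthData;

fn main() {
    // FromEnv: the value is read from the named environment variable at call time.
    let auth = AuthData::from_env("OPENAI_API_KEY");
    // Was `single_value()` before this commit.
    let key = auth.single_key_value().expect("OPENAI_API_KEY should be set");
    println!("key length: {}", key.len());
}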
src/resolver/endpoint.rs (new file, 40 lines)
@@ -0,0 +1,40 @@
use std::sync::Arc;

/// A construct to store the endpoint of a service.
/// It is designed to be efficiently clonable.
/// For now, it just supports `base_url` but later might have other URLs per "service name".
#[derive(Debug, Clone)]
pub struct Endpoint {
inner: EndpointInner,
}

#[derive(Debug, Clone)]
enum EndpointInner {
Static(&'static str),
Owned(Arc<str>),
}

/// Constructors
impl Endpoint {
pub fn from_static(url: &'static str) -> Self {
Endpoint {
inner: EndpointInner::Static(url),
}
}

pub fn from_owned(url: impl Into<Arc<str>>) -> Self {
Endpoint {
inner: EndpointInner::Owned(url.into()),
}
}
}

/// Getters
impl Endpoint {
pub fn base_url(&self) -> &str {
match &self.inner {
EndpointInner::Static(url) => url,
EndpointInner::Owned(url) => url,
}
}
}
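A brief usage sketch of the two constructors (the second URL is a placeholder for a local OpenAI-compatible server):

use genai::resolver::Endpoint;

fn main() {
    // Static: no allocation, suitable for the per-adapter defaults.
    let default_ep = Endpoint::from_static("https://api.openai.com/v1/");
    // Owned: backed by an Arc<str>, cheap to clone, suitable for user-provided URLs.
    let custom_ep = Endpoint::from_owned("http://localhost:8080/v1/".to_string());

    println!("{} / {}", default_ep.base_url(), custom_ep.base_url());
}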
@@ -7,11 +7,13 @@

mod auth_data;
mod auth_resolver;
mod endpoint;
mod error;
mod model_mapper;

pub use auth_data::*;
pub use auth_resolver::*;
pub use endpoint::*;
pub use error::{Error, Result};
pub use model_mapper::*;