mirror of
https://github.com/mii443/rust-genai.git
synced 2025-09-02 07:29:19 +00:00
. clippy clean + resolver fn renaming
This commit is contained in:
12
README.md
12
README.md
@ -48,22 +48,23 @@ The goal of this library is to provide a common and ergonomic single API to many
|
|||||||
[`examples/c00-readme.rs`](examples/c00-readme.rs)
|
[`examples/c00-readme.rs`](examples/c00-readme.rs)
|
||||||
|
|
||||||
```rust
|
```rust
|
||||||
mod support; // For examples support funtions
|
|
||||||
use crate::support::{has_env, print_chat_stream};
|
|
||||||
|
|
||||||
use genai::chat::{ChatMessage, ChatRequest};
|
use genai::chat::{ChatMessage, ChatRequest};
|
||||||
use genai::client::Client;
|
use genai::client::Client;
|
||||||
|
use genai::utils::print_chat_stream;
|
||||||
|
|
||||||
const MODEL_OPENAI: &str = "gpt-3.5-turbo";
|
const MODEL_OPENAI: &str = "gpt-3.5-turbo";
|
||||||
const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307";
|
const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307";
|
||||||
const MODEL_COHERE: &str = "command-light"; // see: https://docs.cohere.com/docs/models
|
const MODEL_COHERE: &str = "command-light";
|
||||||
|
const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
|
||||||
const MODEL_OLLAMA: &str = "mixtral";
|
const MODEL_OLLAMA: &str = "mixtral";
|
||||||
|
|
||||||
|
// NOTE: Those are the default env keys for each AI Provider type.
|
||||||
const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
|
const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
|
||||||
// -- de/activate models/providers
|
// -- de/activate models/providers
|
||||||
(MODEL_OPENAI, "OPENAI_API_KEY"),
|
(MODEL_OPENAI, "OPENAI_API_KEY"),
|
||||||
(MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"),
|
(MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"),
|
||||||
(MODEL_COHERE, "COHERE_API_KEY"),
|
(MODEL_COHERE, "COHERE_API_KEY"),
|
||||||
|
(MODEL_GEMINI, "GEMINI_API_KEY"),
|
||||||
(MODEL_OLLAMA, ""),
|
(MODEL_OLLAMA, ""),
|
||||||
];
|
];
|
||||||
|
|
||||||
@ -71,6 +72,7 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
|
|||||||
// - starts_with "gpt" -> OpenAI
|
// - starts_with "gpt" -> OpenAI
|
||||||
// - starts_with "claude" -> Anthropic
|
// - starts_with "claude" -> Anthropic
|
||||||
// - starts_with "command" -> Cohere
|
// - starts_with "command" -> Cohere
|
||||||
|
// - starts_with "gemini" -> Gemini
|
||||||
// - For anything else -> Ollama
|
// - For anything else -> Ollama
|
||||||
//
|
//
|
||||||
// Refined mapping rules will be added later and extended as provider support grows.
|
// Refined mapping rules will be added later and extended as provider support grows.
|
||||||
@ -89,7 +91,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST {
|
for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST {
|
||||||
// Skip if does not have the environment name set
|
// Skip if does not have the environment name set
|
||||||
if !env_name.is_empty() && !has_env(env_name) {
|
if !env_name.is_empty() && std::env::var(env_name).is_err() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ const MODEL_COHERE: &str = "command-light";
|
|||||||
const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
|
const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
|
||||||
const MODEL_OLLAMA: &str = "mixtral";
|
const MODEL_OLLAMA: &str = "mixtral";
|
||||||
|
|
||||||
|
// NOTE: Those are the default env keys for each AI Provider type.
|
||||||
const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
|
const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
|
||||||
// -- de/activate models/providers
|
// -- de/activate models/providers
|
||||||
(MODEL_OPENAI, "OPENAI_API_KEY"),
|
(MODEL_OPENAI, "OPENAI_API_KEY"),
|
||||||
@ -32,7 +33,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let chat_req = ChatRequest::new(vec![
|
let chat_req = ChatRequest::new(vec![
|
||||||
// -- Messages (de/activate to see the differences)
|
// -- Messages (de/activate to see the differences)
|
||||||
ChatMessage::system("Answer in one sentence"),
|
// ChatMessage::system("Answer in one sentence"),
|
||||||
ChatMessage::user(question),
|
ChatMessage::user(question),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
];
|
];
|
||||||
|
|
||||||
// -- Build a auth_resolver and the AdapterConfig
|
// -- Build a auth_resolver and the AdapterConfig
|
||||||
let auth_resolver = AuthResolver::from_provider_sync(
|
let auth_resolver = AuthResolver::from_sync_resolver(
|
||||||
|kind: AdapterKind, _config_set: &ConfigSet<'_>| -> Result<Option<AuthData>, genai::Error> {
|
|kind: AdapterKind, _config_set: &ConfigSet<'_>| -> Result<Option<AuthData>, genai::Error> {
|
||||||
println!("\n>> Custom auth provider for {kind} <<");
|
println!("\n>> Custom auth provider for {kind} <<");
|
||||||
let key = std::env::var("OPENAI_API_KEY").map_err(|_| genai::Error::ApiKeyEnvNotFound {
|
let key = std::env::var("OPENAI_API_KEY").map_err(|_| genai::Error::ApiKeyEnvNotFound {
|
@ -2,7 +2,7 @@
|
|||||||
//! It can take the following forms:
|
//! It can take the following forms:
|
||||||
//! - Configured with a custom environment name,
|
//! - Configured with a custom environment name,
|
||||||
//! - Contain a fixed auth value,
|
//! - Contain a fixed auth value,
|
||||||
//! - Contain an `AuthDataProvider` trait object or closure that will be called to return the AuthData.
|
//! - Contain an `AuthResolverFnSync` trait object or closure that will be called to return the AuthData.
|
||||||
//!
|
//!
|
||||||
//! Note: AuthData is typically a single value but can be Multi for future adapters (e.g., AWS Berock).
|
//! Note: AuthData is typically a single value but can be Multi for future adapters (e.g., AWS Berock).
|
||||||
|
|
||||||
@ -30,47 +30,47 @@ impl AuthResolver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_provider_sync(provider: impl IntoAuthDataProviderSync) -> Self {
|
pub fn from_sync_resolver(resolver_fn: impl IntoSyncAuthResolverFn) -> Self {
|
||||||
AuthResolver {
|
AuthResolver {
|
||||||
inner: AuthResolverInner::SyncProvider(provider.into_provider()),
|
inner: AuthResolverInner::ResolverFnSync(resolver_fn.into_sync_resolver_fn()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// region: --- AuthDataProvider & IntoAuthDataProvider
|
// region: --- AuthDataProvider & IntoAuthDataProvider
|
||||||
|
|
||||||
pub trait AuthDataProviderSync: Send + Sync {
|
pub trait SyncAuthResolverFn: Send + Sync {
|
||||||
fn provide_auth_data_sync(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>>;
|
fn exec_sync_resolver_fn(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Define a trait for types that can be converted into Arc<dyn AuthDataProviderSync>
|
// Define a trait for types that can be converted into Arc<dyn AuthDataProviderSync>
|
||||||
pub trait IntoAuthDataProviderSync {
|
pub trait IntoSyncAuthResolverFn {
|
||||||
fn into_provider(self) -> Arc<dyn AuthDataProviderSync>;
|
fn into_sync_resolver_fn(self) -> Arc<dyn SyncAuthResolverFn>;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implement IntoProvider for Arc<dyn AuthDataProviderSync>
|
// Implement IntoProvider for Arc<dyn AuthDataProviderSync>
|
||||||
impl IntoAuthDataProviderSync for Arc<dyn AuthDataProviderSync> {
|
impl IntoSyncAuthResolverFn for Arc<dyn SyncAuthResolverFn> {
|
||||||
fn into_provider(self) -> Arc<dyn AuthDataProviderSync> {
|
fn into_sync_resolver_fn(self) -> Arc<dyn SyncAuthResolverFn> {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implement IntoProvider for closures
|
// Implement IntoProvider for closures
|
||||||
impl<F> IntoAuthDataProviderSync for F
|
impl<F> IntoSyncAuthResolverFn for F
|
||||||
where
|
where
|
||||||
F: Fn(AdapterKind, &ConfigSet) -> Result<Option<AuthData>> + Send + Sync + 'static,
|
F: Fn(AdapterKind, &ConfigSet) -> Result<Option<AuthData>> + Send + Sync + 'static,
|
||||||
{
|
{
|
||||||
fn into_provider(self) -> Arc<dyn AuthDataProviderSync> {
|
fn into_sync_resolver_fn(self) -> Arc<dyn SyncAuthResolverFn> {
|
||||||
Arc::new(self)
|
Arc::new(self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Implement AuthDataProviderSync for closures
|
// Implement AuthDataProviderSync for closures
|
||||||
impl<F> AuthDataProviderSync for F
|
impl<F> SyncAuthResolverFn for F
|
||||||
where
|
where
|
||||||
F: Fn(AdapterKind, &ConfigSet) -> Result<Option<AuthData>> + Send + Sync,
|
F: Fn(AdapterKind, &ConfigSet) -> Result<Option<AuthData>> + Send + Sync,
|
||||||
{
|
{
|
||||||
fn provide_auth_data_sync(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>> {
|
fn exec_sync_resolver_fn(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>> {
|
||||||
self(adapter_kind, config_set)
|
self(adapter_kind, config_set)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -87,8 +87,8 @@ impl AuthResolver {
|
|||||||
Ok(Some(AuthData::from_single(key)))
|
Ok(Some(AuthData::from_single(key)))
|
||||||
}
|
}
|
||||||
AuthResolverInner::Fixed(auth_data) => Ok(Some(auth_data.clone())),
|
AuthResolverInner::Fixed(auth_data) => Ok(Some(auth_data.clone())),
|
||||||
AuthResolverInner::SyncProvider(sync_provider) => {
|
AuthResolverInner::ResolverFnSync(sync_provider) => {
|
||||||
sync_provider.provide_auth_data_sync(adapter_kind, config_set)
|
sync_provider.exec_sync_resolver_fn(adapter_kind, config_set)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -98,7 +98,7 @@ enum AuthResolverInner {
|
|||||||
EnvName(String),
|
EnvName(String),
|
||||||
Fixed(AuthData),
|
Fixed(AuthData),
|
||||||
#[allow(unused)] // future
|
#[allow(unused)] // future
|
||||||
SyncProvider(Arc<dyn AuthDataProviderSync>),
|
ResolverFnSync(Arc<dyn SyncAuthResolverFn>),
|
||||||
}
|
}
|
||||||
|
|
||||||
// impl debug for AuthResolverInner
|
// impl debug for AuthResolverInner
|
||||||
@ -107,7 +107,7 @@ impl std::fmt::Debug for AuthResolverInner {
|
|||||||
match self {
|
match self {
|
||||||
AuthResolverInner::EnvName(env_name) => write!(f, "AuthResolverInner::EnvName({})", env_name),
|
AuthResolverInner::EnvName(env_name) => write!(f, "AuthResolverInner::EnvName({})", env_name),
|
||||||
AuthResolverInner::Fixed(auth_data) => write!(f, "AuthResolverInner::Fixed({:?})", auth_data),
|
AuthResolverInner::Fixed(auth_data) => write!(f, "AuthResolverInner::Fixed({:?})", auth_data),
|
||||||
AuthResolverInner::SyncProvider(_) => write!(f, "AuthResolverInner::SyncFn(...)"),
|
AuthResolverInner::ResolverFnSync(_) => write!(f, "AuthResolverInner::FnSync(...)"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -117,7 +117,6 @@ impl std::fmt::Debug for AuthResolverInner {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub enum AuthData {
|
pub enum AuthData {
|
||||||
Single(String),
|
Single(String),
|
||||||
// TODO: Probable needs a HashMap
|
|
||||||
Multi(HashMap<String, String>),
|
Multi(HashMap<String, String>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user