. clippy clean + resolver fn renaming

This commit is contained in:
Jeremy Chone
2024-06-08 19:15:05 -07:00
parent 9ac3eebf8c
commit 9ef4b614a7
4 changed files with 27 additions and 25 deletions

View File

@ -48,22 +48,23 @@ The goal of this library is to provide a common and ergonomic single API to many
[`examples/c00-readme.rs`](examples/c00-readme.rs) [`examples/c00-readme.rs`](examples/c00-readme.rs)
```rust ```rust
mod support; // For examples support funtions
use crate::support::{has_env, print_chat_stream};
use genai::chat::{ChatMessage, ChatRequest}; use genai::chat::{ChatMessage, ChatRequest};
use genai::client::Client; use genai::client::Client;
use genai::utils::print_chat_stream;
const MODEL_OPENAI: &str = "gpt-3.5-turbo"; const MODEL_OPENAI: &str = "gpt-3.5-turbo";
const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307"; const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307";
const MODEL_COHERE: &str = "command-light"; // see: https://docs.cohere.com/docs/models const MODEL_COHERE: &str = "command-light";
const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
const MODEL_OLLAMA: &str = "mixtral"; const MODEL_OLLAMA: &str = "mixtral";
// NOTE: Those are the default env keys for each AI Provider type.
const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
// -- de/activate models/providers // -- de/activate models/providers
(MODEL_OPENAI, "OPENAI_API_KEY"), (MODEL_OPENAI, "OPENAI_API_KEY"),
(MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"), (MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"),
(MODEL_COHERE, "COHERE_API_KEY"), (MODEL_COHERE, "COHERE_API_KEY"),
(MODEL_GEMINI, "GEMINI_API_KEY"),
(MODEL_OLLAMA, ""), (MODEL_OLLAMA, ""),
]; ];
@ -71,6 +72,7 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
// - starts_with "gpt" -> OpenAI // - starts_with "gpt" -> OpenAI
// - starts_with "claude" -> Anthropic // - starts_with "claude" -> Anthropic
// - starts_with "command" -> Cohere // - starts_with "command" -> Cohere
// - starts_with "gemini" -> Gemini
// - For anything else -> Ollama // - For anything else -> Ollama
// //
// Refined mapping rules will be added later and extended as provider support grows. // Refined mapping rules will be added later and extended as provider support grows.
@ -89,7 +91,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST { for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST {
// Skip if does not have the environment name set // Skip if does not have the environment name set
if !env_name.is_empty() && !has_env(env_name) { if !env_name.is_empty() && std::env::var(env_name).is_err() {
continue; continue;
} }

View File

@ -8,6 +8,7 @@ const MODEL_COHERE: &str = "command-light";
const MODEL_GEMINI: &str = "gemini-1.5-flash-latest"; const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
const MODEL_OLLAMA: &str = "mixtral"; const MODEL_OLLAMA: &str = "mixtral";
// NOTE: Those are the default env keys for each AI Provider type.
const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
// -- de/activate models/providers // -- de/activate models/providers
(MODEL_OPENAI, "OPENAI_API_KEY"), (MODEL_OPENAI, "OPENAI_API_KEY"),
@ -32,7 +33,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let chat_req = ChatRequest::new(vec![ let chat_req = ChatRequest::new(vec![
// -- Messages (de/activate to see the differences) // -- Messages (de/activate to see the differences)
ChatMessage::system("Answer in one sentence"), // ChatMessage::system("Answer in one sentence"),
ChatMessage::user(question), ChatMessage::user(question),
]); ]);

View File

@ -16,7 +16,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
]; ];
// -- Build a auth_resolver and the AdapterConfig // -- Build a auth_resolver and the AdapterConfig
let auth_resolver = AuthResolver::from_provider_sync( let auth_resolver = AuthResolver::from_sync_resolver(
|kind: AdapterKind, _config_set: &ConfigSet<'_>| -> Result<Option<AuthData>, genai::Error> { |kind: AdapterKind, _config_set: &ConfigSet<'_>| -> Result<Option<AuthData>, genai::Error> {
println!("\n>> Custom auth provider for {kind} <<"); println!("\n>> Custom auth provider for {kind} <<");
let key = std::env::var("OPENAI_API_KEY").map_err(|_| genai::Error::ApiKeyEnvNotFound { let key = std::env::var("OPENAI_API_KEY").map_err(|_| genai::Error::ApiKeyEnvNotFound {

View File

@ -2,7 +2,7 @@
//! It can take the following forms: //! It can take the following forms:
//! - Configured with a custom environment name, //! - Configured with a custom environment name,
//! - Contain a fixed auth value, //! - Contain a fixed auth value,
//! - Contain an `AuthDataProvider` trait object or closure that will be called to return the AuthData. //! - Contain an `AuthResolverFnSync` trait object or closure that will be called to return the AuthData.
//! //!
//! Note: AuthData is typically a single value but can be Multi for future adapters (e.g., AWS Berock). //! Note: AuthData is typically a single value but can be Multi for future adapters (e.g., AWS Berock).
@ -30,47 +30,47 @@ impl AuthResolver {
} }
} }
pub fn from_provider_sync(provider: impl IntoAuthDataProviderSync) -> Self { pub fn from_sync_resolver(resolver_fn: impl IntoSyncAuthResolverFn) -> Self {
AuthResolver { AuthResolver {
inner: AuthResolverInner::SyncProvider(provider.into_provider()), inner: AuthResolverInner::ResolverFnSync(resolver_fn.into_sync_resolver_fn()),
} }
} }
} }
// region: --- AuthDataProvider & IntoAuthDataProvider // region: --- AuthDataProvider & IntoAuthDataProvider
pub trait AuthDataProviderSync: Send + Sync { pub trait SyncAuthResolverFn: Send + Sync {
fn provide_auth_data_sync(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>>; fn exec_sync_resolver_fn(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>>;
} }
// Define a trait for types that can be converted into Arc<dyn AuthDataProviderSync> // Define a trait for types that can be converted into Arc<dyn AuthDataProviderSync>
pub trait IntoAuthDataProviderSync { pub trait IntoSyncAuthResolverFn {
fn into_provider(self) -> Arc<dyn AuthDataProviderSync>; fn into_sync_resolver_fn(self) -> Arc<dyn SyncAuthResolverFn>;
} }
// Implement IntoProvider for Arc<dyn AuthDataProviderSync> // Implement IntoProvider for Arc<dyn AuthDataProviderSync>
impl IntoAuthDataProviderSync for Arc<dyn AuthDataProviderSync> { impl IntoSyncAuthResolverFn for Arc<dyn SyncAuthResolverFn> {
fn into_provider(self) -> Arc<dyn AuthDataProviderSync> { fn into_sync_resolver_fn(self) -> Arc<dyn SyncAuthResolverFn> {
self self
} }
} }
// Implement IntoProvider for closures // Implement IntoProvider for closures
impl<F> IntoAuthDataProviderSync for F impl<F> IntoSyncAuthResolverFn for F
where where
F: Fn(AdapterKind, &ConfigSet) -> Result<Option<AuthData>> + Send + Sync + 'static, F: Fn(AdapterKind, &ConfigSet) -> Result<Option<AuthData>> + Send + Sync + 'static,
{ {
fn into_provider(self) -> Arc<dyn AuthDataProviderSync> { fn into_sync_resolver_fn(self) -> Arc<dyn SyncAuthResolverFn> {
Arc::new(self) Arc::new(self)
} }
} }
// Implement AuthDataProviderSync for closures // Implement AuthDataProviderSync for closures
impl<F> AuthDataProviderSync for F impl<F> SyncAuthResolverFn for F
where where
F: Fn(AdapterKind, &ConfigSet) -> Result<Option<AuthData>> + Send + Sync, F: Fn(AdapterKind, &ConfigSet) -> Result<Option<AuthData>> + Send + Sync,
{ {
fn provide_auth_data_sync(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>> { fn exec_sync_resolver_fn(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>> {
self(adapter_kind, config_set) self(adapter_kind, config_set)
} }
} }
@ -87,8 +87,8 @@ impl AuthResolver {
Ok(Some(AuthData::from_single(key))) Ok(Some(AuthData::from_single(key)))
} }
AuthResolverInner::Fixed(auth_data) => Ok(Some(auth_data.clone())), AuthResolverInner::Fixed(auth_data) => Ok(Some(auth_data.clone())),
AuthResolverInner::SyncProvider(sync_provider) => { AuthResolverInner::ResolverFnSync(sync_provider) => {
sync_provider.provide_auth_data_sync(adapter_kind, config_set) sync_provider.exec_sync_resolver_fn(adapter_kind, config_set)
} }
} }
} }
@ -98,7 +98,7 @@ enum AuthResolverInner {
EnvName(String), EnvName(String),
Fixed(AuthData), Fixed(AuthData),
#[allow(unused)] // future #[allow(unused)] // future
SyncProvider(Arc<dyn AuthDataProviderSync>), ResolverFnSync(Arc<dyn SyncAuthResolverFn>),
} }
// impl debug for AuthResolverInner // impl debug for AuthResolverInner
@ -107,7 +107,7 @@ impl std::fmt::Debug for AuthResolverInner {
match self { match self {
AuthResolverInner::EnvName(env_name) => write!(f, "AuthResolverInner::EnvName({})", env_name), AuthResolverInner::EnvName(env_name) => write!(f, "AuthResolverInner::EnvName({})", env_name),
AuthResolverInner::Fixed(auth_data) => write!(f, "AuthResolverInner::Fixed({:?})", auth_data), AuthResolverInner::Fixed(auth_data) => write!(f, "AuthResolverInner::Fixed({:?})", auth_data),
AuthResolverInner::SyncProvider(_) => write!(f, "AuthResolverInner::SyncFn(...)"), AuthResolverInner::ResolverFnSync(_) => write!(f, "AuthResolverInner::FnSync(...)"),
} }
} }
} }
@ -117,7 +117,6 @@ impl std::fmt::Debug for AuthResolverInner {
#[derive(Clone)] #[derive(Clone)]
pub enum AuthData { pub enum AuthData {
Single(String), Single(String),
// TODO: Probable needs a HashMap
Multi(HashMap<String, String>), Multi(HashMap<String, String>),
} }