. clippy clean + resolver fn renaming

This commit is contained in:
Jeremy Chone
2024-06-08 19:15:05 -07:00
parent 9ac3eebf8c
commit 9ef4b614a7
4 changed files with 27 additions and 25 deletions

View File

@ -48,22 +48,23 @@ The goal of this library is to provide a common and ergonomic single API to many
[`examples/c00-readme.rs`](examples/c00-readme.rs)
```rust
mod support; // For examples support funtions
use crate::support::{has_env, print_chat_stream};
use genai::chat::{ChatMessage, ChatRequest};
use genai::client::Client;
use genai::utils::print_chat_stream;
const MODEL_OPENAI: &str = "gpt-3.5-turbo";
const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307";
const MODEL_COHERE: &str = "command-light"; // see: https://docs.cohere.com/docs/models
const MODEL_COHERE: &str = "command-light";
const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
const MODEL_OLLAMA: &str = "mixtral";
// NOTE: Those are the default env keys for each AI Provider type.
const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
// -- de/activate models/providers
(MODEL_OPENAI, "OPENAI_API_KEY"),
(MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"),
(MODEL_COHERE, "COHERE_API_KEY"),
(MODEL_GEMINI, "GEMINI_API_KEY"),
(MODEL_OLLAMA, ""),
];
@ -71,6 +72,7 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
// - starts_with "gpt" -> OpenAI
// - starts_with "claude" -> Anthropic
// - starts_with "command" -> Cohere
// - starts_with "gemini" -> Gemini
// - For anything else -> Ollama
//
// Refined mapping rules will be added later and extended as provider support grows.
@ -89,7 +91,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST {
// Skip if does not have the environment name set
if !env_name.is_empty() && !has_env(env_name) {
if !env_name.is_empty() && std::env::var(env_name).is_err() {
continue;
}

View File

@ -8,6 +8,7 @@ const MODEL_COHERE: &str = "command-light";
const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
const MODEL_OLLAMA: &str = "mixtral";
// NOTE: Those are the default env keys for each AI Provider type.
const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
// -- de/activate models/providers
(MODEL_OPENAI, "OPENAI_API_KEY"),
@ -32,7 +33,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let chat_req = ChatRequest::new(vec![
// -- Messages (de/activate to see the differences)
ChatMessage::system("Answer in one sentence"),
// ChatMessage::system("Answer in one sentence"),
ChatMessage::user(question),
]);

View File

@ -16,7 +16,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
];
// -- Build a auth_resolver and the AdapterConfig
let auth_resolver = AuthResolver::from_provider_sync(
let auth_resolver = AuthResolver::from_sync_resolver(
|kind: AdapterKind, _config_set: &ConfigSet<'_>| -> Result<Option<AuthData>, genai::Error> {
println!("\n>> Custom auth provider for {kind} <<");
let key = std::env::var("OPENAI_API_KEY").map_err(|_| genai::Error::ApiKeyEnvNotFound {

View File

@ -2,7 +2,7 @@
//! It can take the following forms:
//! - Configured with a custom environment name,
//! - Contain a fixed auth value,
//! - Contain an `AuthDataProvider` trait object or closure that will be called to return the AuthData.
//! - Contain an `AuthResolverFnSync` trait object or closure that will be called to return the AuthData.
//!
//! Note: AuthData is typically a single value but can be Multi for future adapters (e.g., AWS Berock).
@ -30,47 +30,47 @@ impl AuthResolver {
}
}
pub fn from_provider_sync(provider: impl IntoAuthDataProviderSync) -> Self {
pub fn from_sync_resolver(resolver_fn: impl IntoSyncAuthResolverFn) -> Self {
AuthResolver {
inner: AuthResolverInner::SyncProvider(provider.into_provider()),
inner: AuthResolverInner::ResolverFnSync(resolver_fn.into_sync_resolver_fn()),
}
}
}
// region: --- AuthDataProvider & IntoAuthDataProvider
pub trait AuthDataProviderSync: Send + Sync {
fn provide_auth_data_sync(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>>;
pub trait SyncAuthResolverFn: Send + Sync {
fn exec_sync_resolver_fn(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>>;
}
// Define a trait for types that can be converted into Arc<dyn AuthDataProviderSync>
pub trait IntoAuthDataProviderSync {
fn into_provider(self) -> Arc<dyn AuthDataProviderSync>;
pub trait IntoSyncAuthResolverFn {
fn into_sync_resolver_fn(self) -> Arc<dyn SyncAuthResolverFn>;
}
// Implement IntoProvider for Arc<dyn AuthDataProviderSync>
impl IntoAuthDataProviderSync for Arc<dyn AuthDataProviderSync> {
fn into_provider(self) -> Arc<dyn AuthDataProviderSync> {
impl IntoSyncAuthResolverFn for Arc<dyn SyncAuthResolverFn> {
fn into_sync_resolver_fn(self) -> Arc<dyn SyncAuthResolverFn> {
self
}
}
// Implement IntoProvider for closures
impl<F> IntoAuthDataProviderSync for F
impl<F> IntoSyncAuthResolverFn for F
where
F: Fn(AdapterKind, &ConfigSet) -> Result<Option<AuthData>> + Send + Sync + 'static,
{
fn into_provider(self) -> Arc<dyn AuthDataProviderSync> {
fn into_sync_resolver_fn(self) -> Arc<dyn SyncAuthResolverFn> {
Arc::new(self)
}
}
// Implement AuthDataProviderSync for closures
impl<F> AuthDataProviderSync for F
impl<F> SyncAuthResolverFn for F
where
F: Fn(AdapterKind, &ConfigSet) -> Result<Option<AuthData>> + Send + Sync,
{
fn provide_auth_data_sync(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>> {
fn exec_sync_resolver_fn(&self, adapter_kind: AdapterKind, config_set: &ConfigSet) -> Result<Option<AuthData>> {
self(adapter_kind, config_set)
}
}
@ -87,8 +87,8 @@ impl AuthResolver {
Ok(Some(AuthData::from_single(key)))
}
AuthResolverInner::Fixed(auth_data) => Ok(Some(auth_data.clone())),
AuthResolverInner::SyncProvider(sync_provider) => {
sync_provider.provide_auth_data_sync(adapter_kind, config_set)
AuthResolverInner::ResolverFnSync(sync_provider) => {
sync_provider.exec_sync_resolver_fn(adapter_kind, config_set)
}
}
}
@ -98,7 +98,7 @@ enum AuthResolverInner {
EnvName(String),
Fixed(AuthData),
#[allow(unused)] // future
SyncProvider(Arc<dyn AuthDataProviderSync>),
ResolverFnSync(Arc<dyn SyncAuthResolverFn>),
}
// impl debug for AuthResolverInner
@ -107,7 +107,7 @@ impl std::fmt::Debug for AuthResolverInner {
match self {
AuthResolverInner::EnvName(env_name) => write!(f, "AuthResolverInner::EnvName({})", env_name),
AuthResolverInner::Fixed(auth_data) => write!(f, "AuthResolverInner::Fixed({:?})", auth_data),
AuthResolverInner::SyncProvider(_) => write!(f, "AuthResolverInner::SyncFn(...)"),
AuthResolverInner::ResolverFnSync(_) => write!(f, "AuthResolverInner::FnSync(...)"),
}
}
}
@ -117,7 +117,6 @@ impl std::fmt::Debug for AuthResolverInner {
#[derive(Clone)]
pub enum AuthData {
Single(String),
// TODO: Probable needs a HashMap
Multi(HashMap<String, String>),
}