+ ServiceTargetResolver added (support for custom endpoint)

This commit is contained in:
Jeremy Chone
2024-12-08 16:27:14 -08:00
parent 011fb40a04
commit 04aafb475a
5 changed files with 220 additions and 6 deletions

View File

@ -0,0 +1,56 @@
//! This example demonstrates how to use a custom ServiceTargetResolver which gives full control of the final
//! mapping for Endpoint, Model/AdapterKind, and Auth
//!
//! IMPORTANT - Here we are using xAI as an example of a custom ServiceTarget.
//! It works with regular chat using the basic OpenAIAdapter,
//! but for streaming, xAI does not follow OpenAI's specifications.
//! Therefore, below we use regular chat, and this crate provides an XaiAdapter.
use genai::adapter::AdapterKind;
use genai::chat::printer::{print_chat_stream, PrintChatStreamOptions};
use genai::chat::{ChatMessage, ChatRequest};
use genai::resolver::{AuthData, AuthResolver, Endpoint, ServiceTargetResolver};
use genai::{Client, ModelIden, ServiceTarget};
const MODEL: &str = "grok-beta";
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let questions = &[
// Follow-up questions
"Why is the sky blue?",
"Why is it red sometimes?",
];
// -- Build an auth_resolver and the AdapterConfig
let target_resolver = ServiceTargetResolver::from_resolver_fn(
|service_target: ServiceTarget| -> Result<ServiceTarget, genai::resolver::Error> {
let ServiceTarget { endpoint, auth, model } = service_target;
let endpoint = Endpoint::from_static("https://api.x.ai/v1/");
let auth = AuthData::from_env("XAI_API_KEY");
let model = ModelIden::new(AdapterKind::OpenAI, model.model_name);
// TODO: point to xai
Ok(ServiceTarget { endpoint, auth, model })
},
);
// -- Build the new client with this adapter_config
let client = Client::builder().with_service_target_resolver(target_resolver).build();
let mut chat_req = ChatRequest::default().with_system("Answer in one sentence");
for &question in questions {
chat_req = chat_req.append_message(ChatMessage::user(question));
println!("\n--- Question:\n{question}");
let chat_res = client.exec_chat(MODEL, chat_req.clone(), None).await?;
println!("\n--- Answer: ");
let assistant_answer = chat_res.content_text_as_str().ok_or("Should have response")?;
println!("{assistant_answer}");
chat_req = chat_req.append_message(ChatMessage::assistant(assistant_answer));
}
Ok(())
}

View File

@ -1,5 +1,8 @@
use crate::chat::ChatOptions;
use crate::resolver::{AuthResolver, IntoAuthResolverFn, IntoModelMapperFn, ModelMapper};
use crate::resolver::{
AuthResolver, IntoAuthResolverFn, IntoModelMapperFn, IntoServiceTargetResolverFn, ModelMapper,
ServiceTargetResolver,
};
use crate::webc::WebClient;
use crate::{Client, ClientConfig};
use std::sync::Arc;
@ -56,6 +59,19 @@ impl ClientBuilder {
self
}
pub fn with_service_target_resolver(mut self, target_resolver: ServiceTargetResolver) -> Self {
let client_config = self.config.get_or_insert_with(ClientConfig::default);
client_config.service_target_resolver = Some(target_resolver);
self
}
pub fn with_service_target_resolver_fn(mut self, target_resolver_fn: impl IntoServiceTargetResolverFn) -> Self {
let client_config = self.config.get_or_insert_with(ClientConfig::default);
let target_resolver = ServiceTargetResolver::from_resolver_fn(target_resolver_fn);
client_config.service_target_resolver = Some(target_resolver);
self
}
/// Set the model mapper for the ClientConfig of this ClientBuilder.
pub fn with_model_mapper(mut self, model_mapper: ModelMapper) -> Self {
let client_config = self.config.get_or_insert_with(ClientConfig::default);

View File

@ -1,31 +1,45 @@
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind};
use crate::chat::ChatOptions;
use crate::client::ServiceTarget;
use crate::resolver::{AuthResolver, Endpoint, ModelMapper};
use crate::resolver::{AuthResolver, Endpoint, ModelMapper, ServiceTargetResolver};
use crate::{Error, ModelIden, Result};
/// The Client configuration used in the configuration builder stage.
#[derive(Debug, Default, Clone)]
pub struct ClientConfig {
pub(in crate::client) auth_resolver: Option<AuthResolver>,
pub(in crate::client) model_mapper: Option<ModelMapper>,
pub(in crate::client) chat_options: Option<ChatOptions>,
pub(super) auth_resolver: Option<AuthResolver>,
pub(super) service_target_resolver: Option<ServiceTargetResolver>,
pub(super) model_mapper: Option<ModelMapper>,
pub(super) chat_options: Option<ChatOptions>,
}
/// Chainable setters related to the ClientConfig.
impl ClientConfig {
/// Set the AuthResolver for the ClientConfig.
/// Note: This will be called before the `service_target_resolver`, and if registered
/// the `service_target_resolver` will get this new value.
pub fn with_auth_resolver(mut self, auth_resolver: AuthResolver) -> Self {
self.auth_resolver = Some(auth_resolver);
self
}
/// Set the ModelMapper for the ClientConfig.
/// Note: This will be called before the `service_target_resolver`, and if registered
/// the `service_target_resolver` will get this new value.
pub fn with_model_mapper(mut self, model_mapper: ModelMapper) -> Self {
self.model_mapper = Some(model_mapper);
self
}
/// Set the ServiceTargetResolver for this client config.
///
/// A ServiceTargetResolver is the last step before execution allowing the users full
/// control of the resolved Endpoint, AuthData, and ModelIden
pub fn with_service_target_resolver(mut self, service_target_resolver: ServiceTargetResolver) -> Self {
self.service_target_resolver = Some(service_target_resolver);
self
}
/// Set the default chat request options for the ClientConfig.
pub fn with_chat_options(mut self, options: ChatOptions) -> Self {
self.chat_options = Some(options);
@ -40,6 +54,10 @@ impl ClientConfig {
self.auth_resolver.as_ref()
}
pub fn service_target_resolver(&self) -> Option<&ServiceTargetResolver> {
self.service_target_resolver.as_ref()
}
/// Get a reference to the ModelMapper, if it exists.
pub fn model_mapper(&self) -> Option<&ModelMapper> {
self.model_mapper.as_ref()
@ -81,6 +99,24 @@ impl ClientConfig {
// For now, just get the default endpoint, the `resolve_target` will allow to override it
let endpoint = AdapterDispatcher::default_endpoint(model.adapter_kind);
Ok(ServiceTarget { model, auth, endpoint })
// -- Resolve the service_target
let service_target = ServiceTarget {
model: model.clone(),
auth,
endpoint,
};
let service_target = match self.service_target_resolver() {
Some(service_target_resolver) => {
service_target_resolver
.resolve(service_target)
.map_err(|resolver_error| Error::Resolver {
model_iden: model,
resolver_error,
})?
}
None => service_target,
};
Ok(service_target)
}
}

View File

@ -10,11 +10,13 @@ mod auth_resolver;
mod endpoint;
mod error;
mod model_mapper;
mod service_target_resolver;
pub use auth_data::*;
pub use auth_resolver::*;
pub use endpoint::*;
pub use error::{Error, Result};
pub use model_mapper::*;
pub use service_target_resolver::*;
// endregion: --- Modules

View File

@ -0,0 +1,104 @@
//! A `ServiceTargetResolver` is responsible for returning the `ServiceTarget`.
//! It allows users to customize/override the service target properties.
//!
//! It can take the following forms:
//! - Contains a fixed service target value,
//! - Contains a `ServiceTargetResolverFn` trait object or closure that will be called to return the `ServiceTarget`.
use crate::ServiceTarget;
use crate::resolver::Result;
use std::sync::Arc;
// region: --- ServiceTargetResolver
/// Holder for the `ServiceTargetResolver` function.
#[derive(Debug, Clone)]
pub enum ServiceTargetResolver {
/// The `ServiceTargetResolverFn` trait object.
ResolverFn(Arc<Box<dyn ServiceTargetResolverFn>>),
}
impl ServiceTargetResolver {
/// Create a new `ServiceTargetResolver` from a resolver function.
pub fn from_resolver_fn(resolver_fn: impl IntoServiceTargetResolverFn) -> Self {
ServiceTargetResolver::ResolverFn(resolver_fn.into_resolver_fn())
}
}
impl ServiceTargetResolver {
pub(crate) fn resolve(&self, service_target: ServiceTarget) -> Result<ServiceTarget> {
match self {
ServiceTargetResolver::ResolverFn(resolver_fn) => {
resolver_fn.clone().exec_fn(service_target)
}
}
}
}
// endregion: --- ServiceTargetResolver
// region: --- ServiceTargetResolverFn
/// The `ServiceTargetResolverFn` trait object.
pub trait ServiceTargetResolverFn: Send + Sync {
/// Execute the `ServiceTargetResolverFn` to get the `ServiceTarget`.
fn exec_fn(&self, service_target: ServiceTarget) -> Result<ServiceTarget>;
/// Clone the trait object.
fn clone_box(&self) -> Box<dyn ServiceTargetResolverFn>;
}
/// `ServiceTargetResolverFn` blanket implementation for any function that matches the resolver function signature.
impl<F> ServiceTargetResolverFn for F
where
F: FnOnce(ServiceTarget) -> Result<ServiceTarget> + Send + Sync + Clone + 'static,
{
fn exec_fn(&self, service_target: ServiceTarget) -> Result<ServiceTarget> {
(self.clone())(service_target)
}
fn clone_box(&self) -> Box<dyn ServiceTargetResolverFn> {
Box::new(self.clone())
}
}
// Implement Clone for Box<dyn ServiceTargetResolverFn>
impl Clone for Box<dyn ServiceTargetResolverFn> {
fn clone(&self) -> Box<dyn ServiceTargetResolverFn> {
self.clone_box()
}
}
impl std::fmt::Debug for dyn ServiceTargetResolverFn {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "ServiceTargetResolverFn")
}
}
// endregion: --- ServiceTargetResolverFn
// region: --- IntoServiceTargetResolverFn
/// Custom and convenient trait used in the `ServiceTargetResolver::from_resolver_fn` argument.
pub trait IntoServiceTargetResolverFn {
/// Convert the argument into a `ServiceTargetResolverFn` trait object.
fn into_resolver_fn(self) -> Arc<Box<dyn ServiceTargetResolverFn>>;
}
impl IntoServiceTargetResolverFn for Arc<Box<dyn ServiceTargetResolverFn>> {
fn into_resolver_fn(self) -> Arc<Box<dyn ServiceTargetResolverFn>> {
self
}
}
// Implement `IntoServiceTargetResolverFn` for closures.
impl<F> IntoServiceTargetResolverFn for F
where
F: FnOnce(ServiceTarget) -> Result<ServiceTarget> + Send + Sync + Clone + 'static,
{
fn into_resolver_fn(self) -> Arc<Box<dyn ServiceTargetResolverFn>> {
Arc::new(Box::new(self))
}
}
// endregion: --- IntoServiceTargetResolverFn