mirror of
https://github.com/mii443/rust-genai.git
synced 2025-08-22 16:25:27 +00:00
+ ServiceTargetResolver added (support for custom endpoint)
This commit is contained in:
56
examples/c06-target-resolver.rs
Normal file
56
examples/c06-target-resolver.rs
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
//! This example demonstrates how to use a custom ServiceTargetResolver which gives full control of the final
|
||||||
|
//! mapping for Endpoint, Model/AdapterKind, and Auth
|
||||||
|
//!
|
||||||
|
//! IMPORTANT - Here we are using xAI as an example of a custom ServiceTarget.
|
||||||
|
//! It works with regular chat using the basic OpenAIAdapter,
|
||||||
|
//! but for streaming, xAI does not follow OpenAI's specifications.
|
||||||
|
//! Therefore, below we use regular chat, and this crate provides an XaiAdapter.
|
||||||
|
|
||||||
|
use genai::adapter::AdapterKind;
|
||||||
|
use genai::chat::printer::{print_chat_stream, PrintChatStreamOptions};
|
||||||
|
use genai::chat::{ChatMessage, ChatRequest};
|
||||||
|
use genai::resolver::{AuthData, AuthResolver, Endpoint, ServiceTargetResolver};
|
||||||
|
use genai::{Client, ModelIden, ServiceTarget};
|
||||||
|
|
||||||
|
const MODEL: &str = "grok-beta";
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
let questions = &[
|
||||||
|
// Follow-up questions
|
||||||
|
"Why is the sky blue?",
|
||||||
|
"Why is it red sometimes?",
|
||||||
|
];
|
||||||
|
|
||||||
|
// -- Build an auth_resolver and the AdapterConfig
|
||||||
|
let target_resolver = ServiceTargetResolver::from_resolver_fn(
|
||||||
|
|service_target: ServiceTarget| -> Result<ServiceTarget, genai::resolver::Error> {
|
||||||
|
let ServiceTarget { endpoint, auth, model } = service_target;
|
||||||
|
let endpoint = Endpoint::from_static("https://api.x.ai/v1/");
|
||||||
|
let auth = AuthData::from_env("XAI_API_KEY");
|
||||||
|
let model = ModelIden::new(AdapterKind::OpenAI, model.model_name);
|
||||||
|
// TODO: point to xai
|
||||||
|
Ok(ServiceTarget { endpoint, auth, model })
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// -- Build the new client with this adapter_config
|
||||||
|
let client = Client::builder().with_service_target_resolver(target_resolver).build();
|
||||||
|
|
||||||
|
let mut chat_req = ChatRequest::default().with_system("Answer in one sentence");
|
||||||
|
|
||||||
|
for &question in questions {
|
||||||
|
chat_req = chat_req.append_message(ChatMessage::user(question));
|
||||||
|
|
||||||
|
println!("\n--- Question:\n{question}");
|
||||||
|
let chat_res = client.exec_chat(MODEL, chat_req.clone(), None).await?;
|
||||||
|
|
||||||
|
println!("\n--- Answer: ");
|
||||||
|
let assistant_answer = chat_res.content_text_as_str().ok_or("Should have response")?;
|
||||||
|
println!("{assistant_answer}");
|
||||||
|
|
||||||
|
chat_req = chat_req.append_message(ChatMessage::assistant(assistant_answer));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -1,5 +1,8 @@
|
|||||||
use crate::chat::ChatOptions;
|
use crate::chat::ChatOptions;
|
||||||
use crate::resolver::{AuthResolver, IntoAuthResolverFn, IntoModelMapperFn, ModelMapper};
|
use crate::resolver::{
|
||||||
|
AuthResolver, IntoAuthResolverFn, IntoModelMapperFn, IntoServiceTargetResolverFn, ModelMapper,
|
||||||
|
ServiceTargetResolver,
|
||||||
|
};
|
||||||
use crate::webc::WebClient;
|
use crate::webc::WebClient;
|
||||||
use crate::{Client, ClientConfig};
|
use crate::{Client, ClientConfig};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@ -56,6 +59,19 @@ impl ClientBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_service_target_resolver(mut self, target_resolver: ServiceTargetResolver) -> Self {
|
||||||
|
let client_config = self.config.get_or_insert_with(ClientConfig::default);
|
||||||
|
client_config.service_target_resolver = Some(target_resolver);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_service_target_resolver_fn(mut self, target_resolver_fn: impl IntoServiceTargetResolverFn) -> Self {
|
||||||
|
let client_config = self.config.get_or_insert_with(ClientConfig::default);
|
||||||
|
let target_resolver = ServiceTargetResolver::from_resolver_fn(target_resolver_fn);
|
||||||
|
client_config.service_target_resolver = Some(target_resolver);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the model mapper for the ClientConfig of this ClientBuilder.
|
/// Set the model mapper for the ClientConfig of this ClientBuilder.
|
||||||
pub fn with_model_mapper(mut self, model_mapper: ModelMapper) -> Self {
|
pub fn with_model_mapper(mut self, model_mapper: ModelMapper) -> Self {
|
||||||
let client_config = self.config.get_or_insert_with(ClientConfig::default);
|
let client_config = self.config.get_or_insert_with(ClientConfig::default);
|
||||||
|
@ -1,31 +1,45 @@
|
|||||||
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind};
|
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind};
|
||||||
use crate::chat::ChatOptions;
|
use crate::chat::ChatOptions;
|
||||||
use crate::client::ServiceTarget;
|
use crate::client::ServiceTarget;
|
||||||
use crate::resolver::{AuthResolver, Endpoint, ModelMapper};
|
use crate::resolver::{AuthResolver, Endpoint, ModelMapper, ServiceTargetResolver};
|
||||||
use crate::{Error, ModelIden, Result};
|
use crate::{Error, ModelIden, Result};
|
||||||
|
|
||||||
/// The Client configuration used in the configuration builder stage.
|
/// The Client configuration used in the configuration builder stage.
|
||||||
#[derive(Debug, Default, Clone)]
|
#[derive(Debug, Default, Clone)]
|
||||||
pub struct ClientConfig {
|
pub struct ClientConfig {
|
||||||
pub(in crate::client) auth_resolver: Option<AuthResolver>,
|
pub(super) auth_resolver: Option<AuthResolver>,
|
||||||
pub(in crate::client) model_mapper: Option<ModelMapper>,
|
pub(super) service_target_resolver: Option<ServiceTargetResolver>,
|
||||||
pub(in crate::client) chat_options: Option<ChatOptions>,
|
pub(super) model_mapper: Option<ModelMapper>,
|
||||||
|
pub(super) chat_options: Option<ChatOptions>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Chainable setters related to the ClientConfig.
|
/// Chainable setters related to the ClientConfig.
|
||||||
impl ClientConfig {
|
impl ClientConfig {
|
||||||
/// Set the AuthResolver for the ClientConfig.
|
/// Set the AuthResolver for the ClientConfig.
|
||||||
|
/// Note: This will be called before the `service_target_resolver`, and if registered
|
||||||
|
/// the `service_target_resolver` will get this new value.
|
||||||
pub fn with_auth_resolver(mut self, auth_resolver: AuthResolver) -> Self {
|
pub fn with_auth_resolver(mut self, auth_resolver: AuthResolver) -> Self {
|
||||||
self.auth_resolver = Some(auth_resolver);
|
self.auth_resolver = Some(auth_resolver);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the ModelMapper for the ClientConfig.
|
/// Set the ModelMapper for the ClientConfig.
|
||||||
|
/// Note: This will be called before the `service_target_resolver`, and if registered
|
||||||
|
/// the `service_target_resolver` will get this new value.
|
||||||
pub fn with_model_mapper(mut self, model_mapper: ModelMapper) -> Self {
|
pub fn with_model_mapper(mut self, model_mapper: ModelMapper) -> Self {
|
||||||
self.model_mapper = Some(model_mapper);
|
self.model_mapper = Some(model_mapper);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the ServiceTargetResolver for this client config.
|
||||||
|
///
|
||||||
|
/// A ServiceTargetResolver is the last step before execution allowing the users full
|
||||||
|
/// control of the resolved Endpoint, AuthData, and ModelIden
|
||||||
|
pub fn with_service_target_resolver(mut self, service_target_resolver: ServiceTargetResolver) -> Self {
|
||||||
|
self.service_target_resolver = Some(service_target_resolver);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the default chat request options for the ClientConfig.
|
/// Set the default chat request options for the ClientConfig.
|
||||||
pub fn with_chat_options(mut self, options: ChatOptions) -> Self {
|
pub fn with_chat_options(mut self, options: ChatOptions) -> Self {
|
||||||
self.chat_options = Some(options);
|
self.chat_options = Some(options);
|
||||||
@ -40,6 +54,10 @@ impl ClientConfig {
|
|||||||
self.auth_resolver.as_ref()
|
self.auth_resolver.as_ref()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn service_target_resolver(&self) -> Option<&ServiceTargetResolver> {
|
||||||
|
self.service_target_resolver.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
/// Get a reference to the ModelMapper, if it exists.
|
/// Get a reference to the ModelMapper, if it exists.
|
||||||
pub fn model_mapper(&self) -> Option<&ModelMapper> {
|
pub fn model_mapper(&self) -> Option<&ModelMapper> {
|
||||||
self.model_mapper.as_ref()
|
self.model_mapper.as_ref()
|
||||||
@ -81,6 +99,24 @@ impl ClientConfig {
|
|||||||
// For now, just get the default endpoint, the `resolve_target` will allow to override it
|
// For now, just get the default endpoint, the `resolve_target` will allow to override it
|
||||||
let endpoint = AdapterDispatcher::default_endpoint(model.adapter_kind);
|
let endpoint = AdapterDispatcher::default_endpoint(model.adapter_kind);
|
||||||
|
|
||||||
Ok(ServiceTarget { model, auth, endpoint })
|
// -- Resolve the service_target
|
||||||
|
let service_target = ServiceTarget {
|
||||||
|
model: model.clone(),
|
||||||
|
auth,
|
||||||
|
endpoint,
|
||||||
|
};
|
||||||
|
let service_target = match self.service_target_resolver() {
|
||||||
|
Some(service_target_resolver) => {
|
||||||
|
service_target_resolver
|
||||||
|
.resolve(service_target)
|
||||||
|
.map_err(|resolver_error| Error::Resolver {
|
||||||
|
model_iden: model,
|
||||||
|
resolver_error,
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
None => service_target,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(service_target)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,11 +10,13 @@ mod auth_resolver;
|
|||||||
mod endpoint;
|
mod endpoint;
|
||||||
mod error;
|
mod error;
|
||||||
mod model_mapper;
|
mod model_mapper;
|
||||||
|
mod service_target_resolver;
|
||||||
|
|
||||||
pub use auth_data::*;
|
pub use auth_data::*;
|
||||||
pub use auth_resolver::*;
|
pub use auth_resolver::*;
|
||||||
pub use endpoint::*;
|
pub use endpoint::*;
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
pub use model_mapper::*;
|
pub use model_mapper::*;
|
||||||
|
pub use service_target_resolver::*;
|
||||||
|
|
||||||
// endregion: --- Modules
|
// endregion: --- Modules
|
||||||
|
104
src/resolver/service_target_resolver.rs
Normal file
104
src/resolver/service_target_resolver.rs
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
//! A `ServiceTargetResolver` is responsible for returning the `ServiceTarget`.
|
||||||
|
//! It allows users to customize/override the service target properties.
|
||||||
|
//!
|
||||||
|
//! It can take the following forms:
|
||||||
|
//! - Contains a fixed service target value,
|
||||||
|
//! - Contains a `ServiceTargetResolverFn` trait object or closure that will be called to return the `ServiceTarget`.
|
||||||
|
|
||||||
|
use crate::ServiceTarget;
|
||||||
|
use crate::resolver::Result;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
// region: --- ServiceTargetResolver
|
||||||
|
|
||||||
|
/// Holder for the `ServiceTargetResolver` function.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ServiceTargetResolver {
|
||||||
|
/// The `ServiceTargetResolverFn` trait object.
|
||||||
|
ResolverFn(Arc<Box<dyn ServiceTargetResolverFn>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServiceTargetResolver {
|
||||||
|
/// Create a new `ServiceTargetResolver` from a resolver function.
|
||||||
|
pub fn from_resolver_fn(resolver_fn: impl IntoServiceTargetResolverFn) -> Self {
|
||||||
|
ServiceTargetResolver::ResolverFn(resolver_fn.into_resolver_fn())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServiceTargetResolver {
|
||||||
|
pub(crate) fn resolve(&self, service_target: ServiceTarget) -> Result<ServiceTarget> {
|
||||||
|
match self {
|
||||||
|
ServiceTargetResolver::ResolverFn(resolver_fn) => {
|
||||||
|
resolver_fn.clone().exec_fn(service_target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion: --- ServiceTargetResolver
|
||||||
|
|
||||||
|
// region: --- ServiceTargetResolverFn
|
||||||
|
|
||||||
|
/// The `ServiceTargetResolverFn` trait object.
|
||||||
|
pub trait ServiceTargetResolverFn: Send + Sync {
|
||||||
|
/// Execute the `ServiceTargetResolverFn` to get the `ServiceTarget`.
|
||||||
|
fn exec_fn(&self, service_target: ServiceTarget) -> Result<ServiceTarget>;
|
||||||
|
|
||||||
|
/// Clone the trait object.
|
||||||
|
fn clone_box(&self) -> Box<dyn ServiceTargetResolverFn>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// `ServiceTargetResolverFn` blanket implementation for any function that matches the resolver function signature.
|
||||||
|
impl<F> ServiceTargetResolverFn for F
|
||||||
|
where
|
||||||
|
F: FnOnce(ServiceTarget) -> Result<ServiceTarget> + Send + Sync + Clone + 'static,
|
||||||
|
{
|
||||||
|
fn exec_fn(&self, service_target: ServiceTarget) -> Result<ServiceTarget> {
|
||||||
|
(self.clone())(service_target)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clone_box(&self) -> Box<dyn ServiceTargetResolverFn> {
|
||||||
|
Box::new(self.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement Clone for Box<dyn ServiceTargetResolverFn>
|
||||||
|
impl Clone for Box<dyn ServiceTargetResolverFn> {
|
||||||
|
fn clone(&self) -> Box<dyn ServiceTargetResolverFn> {
|
||||||
|
self.clone_box()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for dyn ServiceTargetResolverFn {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "ServiceTargetResolverFn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion: --- ServiceTargetResolverFn
|
||||||
|
|
||||||
|
// region: --- IntoServiceTargetResolverFn
|
||||||
|
|
||||||
|
/// Custom and convenient trait used in the `ServiceTargetResolver::from_resolver_fn` argument.
|
||||||
|
pub trait IntoServiceTargetResolverFn {
|
||||||
|
/// Convert the argument into a `ServiceTargetResolverFn` trait object.
|
||||||
|
fn into_resolver_fn(self) -> Arc<Box<dyn ServiceTargetResolverFn>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoServiceTargetResolverFn for Arc<Box<dyn ServiceTargetResolverFn>> {
|
||||||
|
fn into_resolver_fn(self) -> Arc<Box<dyn ServiceTargetResolverFn>> {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement `IntoServiceTargetResolverFn` for closures.
|
||||||
|
impl<F> IntoServiceTargetResolverFn for F
|
||||||
|
where
|
||||||
|
F: FnOnce(ServiceTarget) -> Result<ServiceTarget> + Send + Sync + Clone + 'static,
|
||||||
|
{
|
||||||
|
fn into_resolver_fn(self) -> Arc<Box<dyn ServiceTargetResolverFn>> {
|
||||||
|
Arc::new(Box::new(self))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion: --- IntoServiceTargetResolverFn
|
Reference in New Issue
Block a user