mirror of
https://github.com/mii443/rust-genai.git
synced 2025-08-22 16:25:27 +00:00
+ ServiceTargetResolver added (support for custom endpoint)
This commit is contained in:
56
examples/c06-target-resolver.rs
Normal file
56
examples/c06-target-resolver.rs
Normal file
@ -0,0 +1,56 @@
|
||||
//! This example demonstrates how to use a custom ServiceTargetResolver which gives full control of the final
|
||||
//! mapping for Endpoint, Model/AdapterKind, and Auth
|
||||
//!
|
||||
//! IMPORTANT - Here we are using xAI as an example of a custom ServiceTarget.
|
||||
//! It works with regular chat using the basic OpenAIAdapter,
|
||||
//! but for streaming, xAI does not follow OpenAI's specifications.
|
||||
//! Therefore, below we use regular chat, and this crate provides an XaiAdapter.
|
||||
|
||||
use genai::adapter::AdapterKind;
|
||||
use genai::chat::printer::{print_chat_stream, PrintChatStreamOptions};
|
||||
use genai::chat::{ChatMessage, ChatRequest};
|
||||
use genai::resolver::{AuthData, AuthResolver, Endpoint, ServiceTargetResolver};
|
||||
use genai::{Client, ModelIden, ServiceTarget};
|
||||
|
||||
const MODEL: &str = "grok-beta";
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let questions = &[
|
||||
// Follow-up questions
|
||||
"Why is the sky blue?",
|
||||
"Why is it red sometimes?",
|
||||
];
|
||||
|
||||
// -- Build an auth_resolver and the AdapterConfig
|
||||
let target_resolver = ServiceTargetResolver::from_resolver_fn(
|
||||
|service_target: ServiceTarget| -> Result<ServiceTarget, genai::resolver::Error> {
|
||||
let ServiceTarget { endpoint, auth, model } = service_target;
|
||||
let endpoint = Endpoint::from_static("https://api.x.ai/v1/");
|
||||
let auth = AuthData::from_env("XAI_API_KEY");
|
||||
let model = ModelIden::new(AdapterKind::OpenAI, model.model_name);
|
||||
// TODO: point to xai
|
||||
Ok(ServiceTarget { endpoint, auth, model })
|
||||
},
|
||||
);
|
||||
|
||||
// -- Build the new client with this adapter_config
|
||||
let client = Client::builder().with_service_target_resolver(target_resolver).build();
|
||||
|
||||
let mut chat_req = ChatRequest::default().with_system("Answer in one sentence");
|
||||
|
||||
for &question in questions {
|
||||
chat_req = chat_req.append_message(ChatMessage::user(question));
|
||||
|
||||
println!("\n--- Question:\n{question}");
|
||||
let chat_res = client.exec_chat(MODEL, chat_req.clone(), None).await?;
|
||||
|
||||
println!("\n--- Answer: ");
|
||||
let assistant_answer = chat_res.content_text_as_str().ok_or("Should have response")?;
|
||||
println!("{assistant_answer}");
|
||||
|
||||
chat_req = chat_req.append_message(ChatMessage::assistant(assistant_answer));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
@ -1,5 +1,8 @@
|
||||
use crate::chat::ChatOptions;
|
||||
use crate::resolver::{AuthResolver, IntoAuthResolverFn, IntoModelMapperFn, ModelMapper};
|
||||
use crate::resolver::{
|
||||
AuthResolver, IntoAuthResolverFn, IntoModelMapperFn, IntoServiceTargetResolverFn, ModelMapper,
|
||||
ServiceTargetResolver,
|
||||
};
|
||||
use crate::webc::WebClient;
|
||||
use crate::{Client, ClientConfig};
|
||||
use std::sync::Arc;
|
||||
@ -56,6 +59,19 @@ impl ClientBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_service_target_resolver(mut self, target_resolver: ServiceTargetResolver) -> Self {
|
||||
let client_config = self.config.get_or_insert_with(ClientConfig::default);
|
||||
client_config.service_target_resolver = Some(target_resolver);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_service_target_resolver_fn(mut self, target_resolver_fn: impl IntoServiceTargetResolverFn) -> Self {
|
||||
let client_config = self.config.get_or_insert_with(ClientConfig::default);
|
||||
let target_resolver = ServiceTargetResolver::from_resolver_fn(target_resolver_fn);
|
||||
client_config.service_target_resolver = Some(target_resolver);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the model mapper for the ClientConfig of this ClientBuilder.
|
||||
pub fn with_model_mapper(mut self, model_mapper: ModelMapper) -> Self {
|
||||
let client_config = self.config.get_or_insert_with(ClientConfig::default);
|
||||
|
@ -1,31 +1,45 @@
|
||||
use crate::adapter::{Adapter, AdapterDispatcher, AdapterKind};
|
||||
use crate::chat::ChatOptions;
|
||||
use crate::client::ServiceTarget;
|
||||
use crate::resolver::{AuthResolver, Endpoint, ModelMapper};
|
||||
use crate::resolver::{AuthResolver, Endpoint, ModelMapper, ServiceTargetResolver};
|
||||
use crate::{Error, ModelIden, Result};
|
||||
|
||||
/// The Client configuration used in the configuration builder stage.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct ClientConfig {
|
||||
pub(in crate::client) auth_resolver: Option<AuthResolver>,
|
||||
pub(in crate::client) model_mapper: Option<ModelMapper>,
|
||||
pub(in crate::client) chat_options: Option<ChatOptions>,
|
||||
pub(super) auth_resolver: Option<AuthResolver>,
|
||||
pub(super) service_target_resolver: Option<ServiceTargetResolver>,
|
||||
pub(super) model_mapper: Option<ModelMapper>,
|
||||
pub(super) chat_options: Option<ChatOptions>,
|
||||
}
|
||||
|
||||
/// Chainable setters related to the ClientConfig.
|
||||
impl ClientConfig {
|
||||
/// Set the AuthResolver for the ClientConfig.
|
||||
/// Note: This will be called before the `service_target_resolver`, and if registered
|
||||
/// the `service_target_resolver` will get this new value.
|
||||
pub fn with_auth_resolver(mut self, auth_resolver: AuthResolver) -> Self {
|
||||
self.auth_resolver = Some(auth_resolver);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the ModelMapper for the ClientConfig.
|
||||
/// Note: This will be called before the `service_target_resolver`, and if registered
|
||||
/// the `service_target_resolver` will get this new value.
|
||||
pub fn with_model_mapper(mut self, model_mapper: ModelMapper) -> Self {
|
||||
self.model_mapper = Some(model_mapper);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the ServiceTargetResolver for this client config.
|
||||
///
|
||||
/// A ServiceTargetResolver is the last step before execution allowing the users full
|
||||
/// control of the resolved Endpoint, AuthData, and ModelIden
|
||||
pub fn with_service_target_resolver(mut self, service_target_resolver: ServiceTargetResolver) -> Self {
|
||||
self.service_target_resolver = Some(service_target_resolver);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the default chat request options for the ClientConfig.
|
||||
pub fn with_chat_options(mut self, options: ChatOptions) -> Self {
|
||||
self.chat_options = Some(options);
|
||||
@ -40,6 +54,10 @@ impl ClientConfig {
|
||||
self.auth_resolver.as_ref()
|
||||
}
|
||||
|
||||
pub fn service_target_resolver(&self) -> Option<&ServiceTargetResolver> {
|
||||
self.service_target_resolver.as_ref()
|
||||
}
|
||||
|
||||
/// Get a reference to the ModelMapper, if it exists.
|
||||
pub fn model_mapper(&self) -> Option<&ModelMapper> {
|
||||
self.model_mapper.as_ref()
|
||||
@ -81,6 +99,24 @@ impl ClientConfig {
|
||||
// For now, just get the default endpoint, the `resolve_target` will allow to override it
|
||||
let endpoint = AdapterDispatcher::default_endpoint(model.adapter_kind);
|
||||
|
||||
Ok(ServiceTarget { model, auth, endpoint })
|
||||
// -- Resolve the service_target
|
||||
let service_target = ServiceTarget {
|
||||
model: model.clone(),
|
||||
auth,
|
||||
endpoint,
|
||||
};
|
||||
let service_target = match self.service_target_resolver() {
|
||||
Some(service_target_resolver) => {
|
||||
service_target_resolver
|
||||
.resolve(service_target)
|
||||
.map_err(|resolver_error| Error::Resolver {
|
||||
model_iden: model,
|
||||
resolver_error,
|
||||
})?
|
||||
}
|
||||
None => service_target,
|
||||
};
|
||||
|
||||
Ok(service_target)
|
||||
}
|
||||
}
|
||||
|
@ -10,11 +10,13 @@ mod auth_resolver;
|
||||
mod endpoint;
|
||||
mod error;
|
||||
mod model_mapper;
|
||||
mod service_target_resolver;
|
||||
|
||||
pub use auth_data::*;
|
||||
pub use auth_resolver::*;
|
||||
pub use endpoint::*;
|
||||
pub use error::{Error, Result};
|
||||
pub use model_mapper::*;
|
||||
pub use service_target_resolver::*;
|
||||
|
||||
// endregion: --- Modules
|
||||
|
104
src/resolver/service_target_resolver.rs
Normal file
104
src/resolver/service_target_resolver.rs
Normal file
@ -0,0 +1,104 @@
|
||||
//! A `ServiceTargetResolver` is responsible for returning the `ServiceTarget`.
|
||||
//! It allows users to customize/override the service target properties.
|
||||
//!
|
||||
//! It can take the following forms:
|
||||
//! - Contains a fixed service target value,
|
||||
//! - Contains a `ServiceTargetResolverFn` trait object or closure that will be called to return the `ServiceTarget`.
|
||||
|
||||
use crate::ServiceTarget;
|
||||
use crate::resolver::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
// region: --- ServiceTargetResolver
|
||||
|
||||
/// Holder for the `ServiceTargetResolver` function.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ServiceTargetResolver {
|
||||
/// The `ServiceTargetResolverFn` trait object.
|
||||
ResolverFn(Arc<Box<dyn ServiceTargetResolverFn>>),
|
||||
}
|
||||
|
||||
impl ServiceTargetResolver {
|
||||
/// Create a new `ServiceTargetResolver` from a resolver function.
|
||||
pub fn from_resolver_fn(resolver_fn: impl IntoServiceTargetResolverFn) -> Self {
|
||||
ServiceTargetResolver::ResolverFn(resolver_fn.into_resolver_fn())
|
||||
}
|
||||
}
|
||||
|
||||
impl ServiceTargetResolver {
|
||||
pub(crate) fn resolve(&self, service_target: ServiceTarget) -> Result<ServiceTarget> {
|
||||
match self {
|
||||
ServiceTargetResolver::ResolverFn(resolver_fn) => {
|
||||
resolver_fn.clone().exec_fn(service_target)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// endregion: --- ServiceTargetResolver
|
||||
|
||||
// region: --- ServiceTargetResolverFn
|
||||
|
||||
/// The `ServiceTargetResolverFn` trait object.
|
||||
pub trait ServiceTargetResolverFn: Send + Sync {
|
||||
/// Execute the `ServiceTargetResolverFn` to get the `ServiceTarget`.
|
||||
fn exec_fn(&self, service_target: ServiceTarget) -> Result<ServiceTarget>;
|
||||
|
||||
/// Clone the trait object.
|
||||
fn clone_box(&self) -> Box<dyn ServiceTargetResolverFn>;
|
||||
}
|
||||
|
||||
/// `ServiceTargetResolverFn` blanket implementation for any function that matches the resolver function signature.
|
||||
impl<F> ServiceTargetResolverFn for F
|
||||
where
|
||||
F: FnOnce(ServiceTarget) -> Result<ServiceTarget> + Send + Sync + Clone + 'static,
|
||||
{
|
||||
fn exec_fn(&self, service_target: ServiceTarget) -> Result<ServiceTarget> {
|
||||
(self.clone())(service_target)
|
||||
}
|
||||
|
||||
fn clone_box(&self) -> Box<dyn ServiceTargetResolverFn> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
// Implement Clone for Box<dyn ServiceTargetResolverFn>
|
||||
impl Clone for Box<dyn ServiceTargetResolverFn> {
|
||||
fn clone(&self) -> Box<dyn ServiceTargetResolverFn> {
|
||||
self.clone_box()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn ServiceTargetResolverFn {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "ServiceTargetResolverFn")
|
||||
}
|
||||
}
|
||||
|
||||
// endregion: --- ServiceTargetResolverFn
|
||||
|
||||
// region: --- IntoServiceTargetResolverFn
|
||||
|
||||
/// Custom and convenient trait used in the `ServiceTargetResolver::from_resolver_fn` argument.
|
||||
pub trait IntoServiceTargetResolverFn {
|
||||
/// Convert the argument into a `ServiceTargetResolverFn` trait object.
|
||||
fn into_resolver_fn(self) -> Arc<Box<dyn ServiceTargetResolverFn>>;
|
||||
}
|
||||
|
||||
impl IntoServiceTargetResolverFn for Arc<Box<dyn ServiceTargetResolverFn>> {
|
||||
fn into_resolver_fn(self) -> Arc<Box<dyn ServiceTargetResolverFn>> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// Implement `IntoServiceTargetResolverFn` for closures.
|
||||
impl<F> IntoServiceTargetResolverFn for F
|
||||
where
|
||||
F: FnOnce(ServiceTarget) -> Result<ServiceTarget> + Send + Sync + Clone + 'static,
|
||||
{
|
||||
fn into_resolver_fn(self) -> Arc<Box<dyn ServiceTargetResolverFn>> {
|
||||
Arc::new(Box::new(self))
|
||||
}
|
||||
}
|
||||
|
||||
// endregion: --- IntoServiceTargetResolverFn
|
Reference in New Issue
Block a user