use crate::adapter::adapters::support::get_api_key;
use crate::adapter::gemini::GeminiStreamer;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
	ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
	MessageContent, MetaUsage,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::{WebResponse, WebStream};
use crate::{Error, Result};
use crate::{ModelIden, ServiceTarget};
use reqwest::RequestBuilder;
use serde_json::{json, Value};
use value_ext::JsonValueExt;

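/// The adapter for the Google Gemini API (v1beta `generateContent` / `streamGenerateContent`).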
pub struct GeminiAdapter;

const MODELS: &[&str] = &[
	"gemini-1.5-pro",
	"gemini-1.5-flash",
	"gemini-1.5-flash-8b",
	"gemini-1.0-pro",
	"gemini-1.5-flash-latest",
];

// curl \
// -H 'Content-Type: application/json' \
// -d '{"contents":[{"parts":[{"text":"Explain how AI works"}]}]}' \
// -X POST 'https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'

impl Adapter for GeminiAdapter {
	fn default_endpoint(_kind: AdapterKind) -> Endpoint {
		const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/";
		Endpoint::from_static(BASE_URL)
	}

	fn default_auth(_kind: AdapterKind) -> AuthData {
		AuthData::from_env("GEMINI_API_KEY")
	}

	/// Note: For now, this returns the common models (see `MODELS` above).
	async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> {
		Ok(MODELS.iter().map(|s| s.to_string()).collect())
	}

	/// NOTE: Since Google Gemini puts the API_KEY in the URL,
	/// this returns the URL without the API_KEY; the caller is responsible for appending it.
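	/// e.g., `{base_url}models/gemini-1.5-flash-latest:generateContent` (the `?key=API_KEY` gets appended in `to_web_request_data`)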
	fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
		let base_url = endpoint.base_url();
		let model_name = model.model_name.clone();
		match service_type {
			ServiceType::Chat => format!("{base_url}models/{model_name}:generateContent"),
			ServiceType::ChatStream => format!("{base_url}models/{model_name}:streamGenerateContent"),
		}
	}

	fn to_web_request_data(
		target: ServiceTarget,
		service_type: ServiceType,
		chat_req: ChatRequest,
		options_set: ChatOptionsSet<'_, '_>,
	) -> Result<WebRequestData> {
		let ServiceTarget { endpoint, auth, model } = target;

		// -- api_key
		let api_key = get_api_key(auth, &model)?;

		// -- url
		// NOTE: Somehow, Google decided to put the API key in the URL.
		// This should be considered an antipattern from a security point of view,
		// even if it is done by the well-respected Google. Everybody can make a mistake once in a while.
		// e.g., '...models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'
		let url = Self::get_service_url(&model, service_type, endpoint);
		let url = format!("{url}?key={api_key}");

		// -- parts
		let GeminiChatRequestParts { system, contents } = Self::into_gemini_request_parts(model, chat_req)?;

		// -- Payload
		let mut payload = json!({
			"contents": contents,
		});

		// -- headers (empty for gemini, since API_KEY is in url)
		let headers = vec![];

		// Note: It's unclear from the spec whether the content of systemInstruction should have a role.
		// Right now, it is omitted (since the spec states it can only be "user" or "model"),
		// and it seems to work. https://ai.google.dev/api/rest/v1beta/models/generateContent
		if let Some(system) = system {
			payload.x_insert(
				"systemInstruction",
				json!({
					"parts": [ { "text": system }]
				}),
			)?;
		}

		// -- Response Format
		if let Some(ChatResponseFormat::JsonSpec(st_json)) = options_set.response_format() {
			// Configure generationConfig for structured JSON output
			// (responseMimeType + responseSchema below).
			payload.x_insert("/generationConfig/responseMimeType", "application/json")?;
			let mut schema = st_json.schema.clone();
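			// NOTE: Gemini's responseSchema does not support the `additionalProperties`
			// JSON Schema keyword, so strip it recursively before sending.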
			schema.x_walk(|parent_map, name| {
				if name == "additionalProperties" {
					parent_map.remove("additionalProperties");
				}
				true
			});
			payload.x_insert("/generationConfig/responseSchema", schema)?;
		}

		// -- Add supported ChatOptions
		if let Some(temperature) = options_set.temperature() {
			payload.x_insert("/generationConfig/temperature", temperature)?;
		}

		if !options_set.stop_sequences().is_empty() {
			payload.x_insert("/generationConfig/stopSequences", options_set.stop_sequences())?;
		}

		if let Some(max_tokens) = options_set.max_tokens() {
			payload.x_insert("/generationConfig/maxOutputTokens", max_tokens)?;
		}
		if let Some(top_p) = options_set.top_p() {
			payload.x_insert("/generationConfig/topP", top_p)?;
		}

		Ok(WebRequestData { url, headers, payload })
	}

	fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
		let WebResponse { body, .. } = web_response;

		let gemini_response = Self::body_to_gemini_chat_response(&model_iden, body)?;
		let GeminiChatResponse { content, usage } = gemini_response;
		let content = content.map(MessageContent::from);

		Ok(ChatResponse {
			content,
			model_iden,
			usage,
		})
	}

	fn to_chat_stream(
		model_iden: ModelIden,
		reqwest_builder: RequestBuilder,
		options_set: ChatOptionsSet<'_, '_>,
	) -> Result<ChatStreamResponse> {
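		// NOTE: By default, Gemini's streamGenerateContent returns a pretty-printed
		// JSON array of chunks (rather than SSE), hence this dedicated WebStream constructor.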
		let web_stream = WebStream::new_with_pretty_json_array(reqwest_builder);

		let gemini_stream = GeminiStreamer::new(web_stream, model_iden.clone(), options_set);
		let chat_stream = ChatStream::from_inter_stream(gemini_stream);

		Ok(ChatStreamResponse {
			model_iden,
			stream: chat_stream,
		})
	}
}

// region: --- Support

/// Support functions for GeminiAdapter
impl GeminiAdapter {
	pub(super) fn body_to_gemini_chat_response(model_iden: &ModelIden, mut body: Value) -> Result<GeminiChatResponse> {
		// If the body has an `error` property, then it is assumed to be an error.
		if body.get("error").is_some() {
			return Err(Error::StreamEventError {
				model_iden: model_iden.clone(),
				body,
			});
		}

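		// Extract the text of the first part of the first candidate.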
		let content = body.x_take::<Value>("/candidates/0/content/parts/0/text")?;
		let usage = body.x_take::<Value>("usageMetadata").map(Self::into_usage).unwrap_or_default();

		Ok(GeminiChatResponse {
			content: content.as_str().map(String::from),
			usage,
		})
	}

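	/// Maps Gemini's `usageMetadata` into the genai `MetaUsage`,
	/// e.g., `{"promptTokenCount": 10, "candidatesTokenCount": 200, "totalTokenCount": 210}`.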
	pub(super) fn into_usage(mut usage_value: Value) -> MetaUsage {
		let input_tokens: Option<i32> = usage_value.x_take("promptTokenCount").ok();
		let output_tokens: Option<i32> = usage_value.x_take("candidatesTokenCount").ok();
		let total_tokens: Option<i32> = usage_value.x_take("totalTokenCount").ok();
		MetaUsage {
			input_tokens,
			output_tokens,
			total_tokens,
		}
	}

	/// Takes the genai ChatMessages and builds the System string and JSON Messages for Gemini.
	/// - Role mapping: `ChatRole::User -> role: "user"`, `ChatRole::Assistant -> role: "model"`
	/// - `ChatRole::System` messages are concatenated (newline-separated) into a single `system` string for the system instruction.
	/// - This adapter uses API version v1beta, which supports `systemInstruction`.
	/// - The eventual `chat_req.system` is pushed first into the "systemInstruction"
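	/// - Each `contents` entry has the shape `{"role": "user" | "model", "parts": [{"text": "..."}]}`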
	fn into_gemini_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result<GeminiChatRequestParts> {
		let mut contents: Vec<Value> = Vec::new();
		let mut systems: Vec<String> = Vec::new();

		if let Some(system) = chat_req.system {
			systems.push(system);
		}

		// -- Build
		for msg in chat_req.messages {
			// TODO: Needs to implement tool_calls
			let MessageContent::Text(content) = msg.content else {
				return Err(Error::MessageContentTypeNotSupported {
					model_iden,
					cause: "Only MessageContent::Text supported for this model (for now)",
				});
			};

			match msg.role {
				// System messages are accumulated into the single systemInstruction (see above).
				ChatRole::System => systems.push(content),
				ChatRole::User => contents.push(json!({"role": "user", "parts": [{"text": content}]})),
				ChatRole::Assistant => contents.push(json!({"role": "model", "parts": [{"text": content}]})),
				ChatRole::Tool => {
					return Err(Error::MessageRoleNotSupported {
						model_iden,
						role: ChatRole::Tool,
					})
				}
			}
		}

		let system = if !systems.is_empty() {
			Some(systems.join("\n"))
		} else {
			None
		};

		Ok(GeminiChatRequestParts { system, contents })
	}
}

// -- Gemini support structs

pub(super) struct GeminiChatResponse {
	pub content: Option<String>,
	pub usage: MetaUsage,
}

struct GeminiChatRequestParts {
	system: Option<String>,
	/// The chat history (user and assistant messages, typically ending with the latest user message)
	contents: Vec<Value>,
}

// endregion: --- Support