Files
rust-genai/src/adapter/adapters/gemini/adapter_impl.rs
Jeremy Chone c47ae855ae . clippy clean
2024-12-08 19:23:39 -08:00

259 lines
8.4 KiB
Rust

use crate::adapter::adapters::support::get_api_key;
use crate::adapter::gemini::GeminiStreamer;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatResponseFormat, ChatRole, ChatStream, ChatStreamResponse,
MessageContent, MetaUsage,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::{WebResponse, WebStream};
use crate::{Error, Result};
use crate::{ModelIden, ServiceTarget};
use reqwest::RequestBuilder;
use serde_json::{json, Value};
use value_ext::JsonValueExt;
/// Adapter for the Google Gemini API.
///
/// This is a stateless unit struct; all behavior is provided by its `Adapter`
/// trait implementation and the support functions in its inherent `impl`.
pub struct GeminiAdapter;
/// Common Gemini model names.
///
/// This static list is what `all_model_names` returns (as owned `String`s).
const MODELS: &[&str] = &[
"gemini-1.5-pro",
"gemini-1.5-flash",
"gemini-1.5-flash-8b",
"gemini-1.0-pro",
"gemini-1.5-flash-latest",
];
// curl \
// -H 'Content-Type: application/json' \
// -d '{"contents":[{"parts":[{"text":"Explain how AI works"}]}]}' \
// -X POST 'https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'
impl Adapter for GeminiAdapter {
/// Returns the default Gemini endpoint (the `v1beta` base URL, with a trailing `/`).
fn default_endpoint(_kind: AdapterKind) -> Endpoint {
    Endpoint::from_static("https://generativelanguage.googleapis.com/v1beta/")
}
/// Returns the default auth data, which reads the API key from the
/// `GEMINI_API_KEY` environment variable.
fn default_auth(_kind: AdapterKind) -> AuthData {
    const API_KEY_ENV_NAME: &str = "GEMINI_API_KEY";
    AuthData::from_env(API_KEY_ENV_NAME)
}
/// Returns the known Gemini model names.
///
/// Note: For now, this is the static `MODELS` list of common models; it is not
/// fetched from the service.
async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> {
    let names = MODELS.iter().copied().map(String::from).collect();
    Ok(names)
}
/// Builds the service URL for the given model and service type.
///
/// The result has the form `{base_url}models/{model_name}:{action}`, where `action` is
/// `generateContent` for `ServiceType::Chat` and `streamGenerateContent` for
/// `ServiceType::ChatStream`.
///
/// NOTE: As Google Gemini has decided to put their API_KEY in the URL,
/// this will return the URL without the API_KEY in it. The API_KEY will need to be added by the caller.
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
    let base_url = endpoint.base_url();
    // Borrow the model name; formatting does not need an owned copy.
    let model_name = &model.model_name;
    let action = match service_type {
        ServiceType::Chat => "generateContent",
        ServiceType::ChatStream => "streamGenerateContent",
    };
    format!("{base_url}models/{model_name}:{action}")
}
/// Builds the web request data (URL, headers, JSON payload) for a Gemini chat request.
///
/// - The URL is the service URL with the API key appended as the `key` query parameter.
/// - The headers are empty (the API key travels in the URL).
/// - The payload contains `contents`, an optional `systemInstruction`, and a
///   `generationConfig` populated from the supported chat options
///   (JSON-spec response format, temperature, stop sequences, max tokens, top_p).
///
/// Returns an error if the API key cannot be resolved, if the chat request cannot be
/// converted to Gemini parts, or if a payload insertion fails.
fn to_web_request_data(
    target: ServiceTarget,
    service_type: ServiceType,
    chat_req: ChatRequest,
    options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
    let ServiceTarget { endpoint, auth, model } = target;

    // -- URL (service URL + API key)
    // NOTE: Google decided to put the API key in the URL.
    //       This should be considered an antipattern from a security point of view,
    //       even if it is done by the well-respected Google. Everybody can make a mistake once in a while.
    //       e.g., '...models/gemini-1.5-flash-latest:generateContent?key=YOUR_API_KEY'
    let api_key = get_api_key(auth, &model)?;
    let service_url = Self::get_service_url(&model, service_type, endpoint);
    let url = format!("{service_url}?key={api_key}");

    // -- Payload (contents first, then the optional system instruction)
    let GeminiChatRequestParts { system, contents } = Self::into_gemini_request_parts(model, chat_req)?;
    let mut payload = json!({ "contents": contents });

    // Note: It's unclear from the spec if the content of systemInstruction should have a role.
    //       Right now, it is omitted (since the spec states it can only be "user" or "model").
    //       It seems to work. https://ai.google.dev/api/rest/v1beta/models/generateContent
    if let Some(system_text) = system {
        let system_instruction = json!({
            "parts": [ { "text": system_text }]
        });
        payload.x_insert("systemInstruction", system_instruction)?;
    }

    // -- Response format (sets `responseMimeType` and `responseSchema`)
    if let Some(ChatResponseFormat::JsonSpec(json_spec)) = options_set.response_format() {
        payload.x_insert("/generationConfig/responseMimeType", "application/json")?;

        // Work on a copy of the schema, with every `additionalProperties` entry removed.
        let mut response_schema = json_spec.schema.clone();
        response_schema.x_walk(|parent_map, name| {
            if name == "additionalProperties" {
                parent_map.remove("additionalProperties");
            }
            true
        });
        payload.x_insert("/generationConfig/responseSchema", response_schema)?;
    }

    // -- Supported ChatOptions
    if let Some(temperature) = options_set.temperature() {
        payload.x_insert("/generationConfig/temperature", temperature)?;
    }
    let stop_sequences = options_set.stop_sequences();
    if !stop_sequences.is_empty() {
        payload.x_insert("/generationConfig/stopSequences", stop_sequences)?;
    }
    if let Some(max_tokens) = options_set.max_tokens() {
        payload.x_insert("/generationConfig/maxOutputTokens", max_tokens)?;
    }
    if let Some(top_p) = options_set.top_p() {
        payload.x_insert("/generationConfig/topP", top_p)?;
    }

    // -- Headers are empty for Gemini, since the API key is in the URL.
    Ok(WebRequestData {
        url,
        headers: Vec::new(),
        payload,
    })
}
/// Converts a Gemini web response into a `ChatResponse`.
///
/// The body is parsed by `body_to_gemini_chat_response`; its optional text content is
/// wrapped into a `MessageContent`, and the usage is passed through as is.
///
/// Returns an error if the body cannot be interpreted as a Gemini chat response.
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
    let WebResponse { body, .. } = web_response;

    // Borrow `model_iden` directly; cloning it just to take a reference is unnecessary.
    let GeminiChatResponse { content, usage } = Self::body_to_gemini_chat_response(&model_iden, body)?;

    Ok(ChatResponse {
        content: content.map(MessageContent::from),
        model_iden,
        usage,
    })
}
/// Builds the chat stream response for a Gemini streaming request.
///
/// The request builder is wrapped in a `WebStream` (pretty JSON array mode), which feeds a
/// `GeminiStreamer`, which in turn backs the returned `ChatStream`.
fn to_chat_stream(
    model_iden: ModelIden,
    reqwest_builder: RequestBuilder,
    options_set: ChatOptionsSet<'_, '_>,
) -> Result<ChatStreamResponse> {
    let streamer = GeminiStreamer::new(
        WebStream::new_with_pretty_json_array(reqwest_builder),
        model_iden.clone(),
        options_set,
    );
    let stream = ChatStream::from_inter_stream(streamer);

    Ok(ChatStreamResponse { model_iden, stream })
}
}
// region: --- Support
/// Support functions for GeminiAdapter
impl GeminiAdapter {
pub(super) fn body_to_gemini_chat_response(model_iden: &ModelIden, mut body: Value) -> Result<GeminiChatResponse> {
// If the body has an `error` property, then it is assumed to be an error.
if body.get("error").is_some() {
return Err(Error::StreamEventError {
model_iden: model_iden.clone(),
body,
});
}
let content = body.x_take::<Value>("/candidates/0/content/parts/0/text")?;
let usage = body.x_take::<Value>("usageMetadata").map(Self::into_usage).unwrap_or_default();
Ok(GeminiChatResponse {
content: content.as_str().map(String::from),
usage,
})
}
/// Builds a `MetaUsage` from a Gemini `usageMetadata` value.
///
/// Each count is best effort: a property that cannot be taken as an `i32` becomes `None`.
/// - `promptTokenCount` -> `input_tokens`
/// - `candidatesTokenCount` -> `output_tokens`
/// - `totalTokenCount` -> `total_tokens`
pub(super) fn into_usage(mut usage_value: Value) -> MetaUsage {
    let mut take_count = |name: &str| -> Option<i32> { usage_value.x_take::<i32>(name).ok() };

    MetaUsage {
        input_tokens: take_count("promptTokenCount"),
        output_tokens: take_count("candidatesTokenCount"),
        total_tokens: take_count("totalTokenCount"),
    }
}
/// Takes the genai ChatMessages and builds the System string and JSON Messages for Gemini.
/// - Role mapping `ChatRole::User -> role: "user"`, `ChatRole::Assistant -> role: "model"`
/// - `ChatRole::System` messages are collected and joined with a single newline (`\n`)
///   into one `system` string for the system instruction.
/// - This adapter uses version v1beta, which supports `systemInstruction`
/// - The eventual `chat_req.system` is pushed first into the "systemInstruction"
///
/// Returns an error if a message content is not `MessageContent::Text`
/// (`Error::MessageContentTypeNotSupported`) or if a message has the `ChatRole::Tool`
/// role (`Error::MessageRoleNotSupported`). The content type is checked before the role.
fn into_gemini_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result<GeminiChatRequestParts> {
    let mut contents: Vec<Value> = Vec::new();
    let mut systems: Vec<String> = Vec::new();

    if let Some(system) = chat_req.system {
        systems.push(system);
    }

    // -- Build
    for msg in chat_req.messages {
        // TODO: Needs to implement tool_calls
        let MessageContent::Text(text) = msg.content else {
            return Err(Error::MessageContentTypeNotSupported {
                model_iden,
                cause: "Only MessageContent::Text supported for this model (for now)",
            });
        };

        let role = match msg.role {
            // System messages are not part of `contents`; they go to the system instruction.
            ChatRole::System => {
                systems.push(text);
                continue;
            }
            ChatRole::User => "user",
            ChatRole::Assistant => "model",
            ChatRole::Tool => {
                return Err(Error::MessageRoleNotSupported {
                    model_iden,
                    role: ChatRole::Tool,
                })
            }
        };
        contents.push(json!({"role": role, "parts": [{"text": text}]}));
    }

    // `None` when there is no system text at all (no `systemInstruction` will be sent).
    let system = (!systems.is_empty()).then(|| systems.join("\n"));

    Ok(GeminiChatRequestParts { system, contents })
}
}
// -- Gemini-specific intermediate types

/// Content and usage extracted from a Gemini chat response body
/// (see `GeminiAdapter::body_to_gemini_chat_response`).
pub(super) struct GeminiChatResponse {
/// The response text, taken from `/candidates/0/content/parts/0/text`
/// (`None` when that value is not a string).
pub content: Option<String>,
/// The token usage built from `usageMetadata` (default usage when it cannot be taken).
pub usage: MetaUsage,
}
/// The parts of a genai `ChatRequest` converted for the Gemini payload
/// (see `GeminiAdapter::into_gemini_request_parts`).
struct GeminiChatRequestParts {
/// The system text: `chat_req.system` followed by all `ChatRole::System` messages,
/// joined with `\n`. `None` when there is no system text.
system: Option<String>,
/// The user and assistant messages, in order, as Gemini content objects
/// (`{"role": "user" | "model", "parts": [{"text": ...}]}`).
contents: Vec<Value>,
}
// endregion: --- Support