From ed74d44c7a99f25fe4e5c3605d11208b43f677d7 Mon Sep 17 00:00:00 2001
From: Jeremy Chone
Date: Mon, 6 Jan 2025 08:30:33 -0800
Subject: [PATCH] + add DeepSeek support

---
 README.md                                     | 92 +++++++++++--------
 examples/c00-readme.rs                        |  4 +-
 src/adapter/adapter_kind.rs                   |  8 ++
 src/adapter/adapters/deepseek/adapter_impl.rs | 57 ++++++++++++
 src/adapter/adapters/deepseek/mod.rs          | 11 +++
 src/adapter/adapters/groq/adapter_impl.rs     |  2 +-
 src/adapter/adapters/mod.rs                   |  1 +
 src/adapter/adapters/openai/streamer.rs       |  8 +-
 src/adapter/dispatcher.rs                     |  8 ++
 tests/tests_p_deepseek.rs                     | 75 +++++++++++++++
 10 files changed, 222 insertions(+), 44 deletions(-)
 create mode 100644 src/adapter/adapters/deepseek/adapter_impl.rs
 create mode 100644 src/adapter/adapters/deepseek/mod.rs
 create mode 100644 tests/tests_p_deepseek.rs

diff --git a/README.md b/README.md
index 5d9097f..5a77db1 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,14 @@ Provides a common and ergonomic single API to many generative AI Providers, such
 
 Check out [devai.run](https://devai.run), the **Iterate to Automate** command-line application that leverages **GenAI** for multi-AI capabilities.
 
+## Key Features
+
+- Native Multi-AI Provider/Model: OpenAI, Anthropic, Gemini, Ollama, Groq, xAI, DeepSeek (direct chat and stream) (see [examples/c00-readme.rs](examples/c00-readme.rs))
+- Image Analysis (for OpenAI, Gemini flash-2, Anthropic) (see [examples/c07-image.rs](examples/c07-image.rs))
+- Custom Auth/API Key (see [examples/c02-auth.rs](examples/c02-auth.rs))
+- Model Alias (see [examples/c05-model-names.rs](examples/c05-model-names.rs))
+- Custom Endpoint, Auth, and Model Identifier (see [examples/c06-target-resolver.rs](examples/c06-target-resolver.rs))
+
 [Examples](#examples) | [Thanks](#thanks) | [Library Focus](#library-focus) | [Changelog](CHANGELOG.md) | Provider Mapping: [ChatOptions](#chatoptions) | [MetaUsage](#metausage)
 
 ## Examples
@@ -32,29 +40,33 @@ Check out [devai.run](https://devai.run), the **Iterate to Automate** command-li
 [examples/c00-readme.rs](examples/c00-readme.rs)
 
 ```rust
+//! Base examples demonstrating the core capabilities of genai
+
 use genai::chat::printer::{print_chat_stream, PrintChatStreamOptions};
 use genai::chat::{ChatMessage, ChatRequest};
 use genai::Client;
 
-const MODEL_OPENAI: &str = "gpt-4o-mini";
+const MODEL_OPENAI: &str = "gpt-4o-mini"; // o1-mini, gpt-4o-mini
 const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307";
 const MODEL_COHERE: &str = "command-light";
 const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
-const MODEL_GROQ: &str = "gemma-7b-it";
+const MODEL_GROQ: &str = "llama3-8b-8192";
 const MODEL_OLLAMA: &str = "gemma:2b"; // sh: `ollama pull gemma:2b`
 const MODEL_XAI: &str = "grok-beta";
+const MODEL_DEEPSEEK: &str = "deepseek-chat";
 
-// NOTE: Those are the default environment keys for each AI Adapter Type.
-//       Can be customized, see `examples/c02-auth.rs`
+// NOTE: These are the default environment keys for each AI Adapter Type.
+//       They can be customized; see `examples/c02-auth.rs`
 const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
-    // -- de/activate models/providers
-    (MODEL_OPENAI, "OPENAI_API_KEY"),
-    (MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"),
-    (MODEL_COHERE, "COHERE_API_KEY"),
-    (MODEL_GEMINI, "GEMINI_API_KEY"),
-    (MODEL_GROQ, "GROQ_API_KEY"),
-    (MODEL_XAI, "XAI_API_KEY"),
-    (MODEL_OLLAMA, ""),
+	// -- De/activate models/providers
+	(MODEL_OPENAI, "OPENAI_API_KEY"),
+	(MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"),
+	(MODEL_COHERE, "COHERE_API_KEY"),
+	(MODEL_GEMINI, "GEMINI_API_KEY"),
+	(MODEL_GROQ, "GROQ_API_KEY"),
+	(MODEL_XAI, "XAI_API_KEY"),
+	(MODEL_DEEPSEEK, "DEEPSEEK_API_KEY"),
+	(MODEL_OLLAMA, ""),
 ];
 
 // NOTE: Model to AdapterKind (AI Provider) type mapping rule
@@ -65,47 +77,47 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
 //  - model in Groq models -> Groq
 //  - For anything else -> Ollama
 //
-// Can be customized, see `examples/c03-kind.rs`
+// This can be customized; see `examples/c03-kind.rs`
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let question = "Why is the sky red?";
+	let question = "Why is the sky red?";
 
-    let chat_req = ChatRequest::new(vec![
-        // -- Messages (de/activate to see the differences)
-        ChatMessage::system("Answer in one sentence"),
-        ChatMessage::user(question),
-    ]);
+	let chat_req = ChatRequest::new(vec![
+		// -- Messages (de/activate to see the differences)
+		ChatMessage::system("Answer in one sentence"),
+		ChatMessage::user(question),
+	]);
 
-    let client = Client::default();
+	let client = Client::default();
 
-    let print_options = PrintChatStreamOptions::from_print_events(false);
+	let print_options = PrintChatStreamOptions::from_print_events(false);
 
-    for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST {
-        // Skip if does not have the environment name set
-        if !env_name.is_empty() && std::env::var(env_name).is_err() {
-            println!("===== Skipping model: {model} (env var not set: {env_name})");
-            continue;
-        }
+	for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST {
+		// Skip if the environment name is not set
+		if !env_name.is_empty() && std::env::var(env_name).is_err() {
+			println!("===== Skipping model: {model} (env var not set: {env_name})");
+			continue;
+		}
 
-        let adapter_kind = client.resolve_model_iden(model)?.adapter_kind;
+		let adapter_kind = client.resolve_service_target(model)?.model.adapter_kind;
 
-        println!("\n===== MODEL: {model} ({adapter_kind}) =====");
+		println!("\n===== MODEL: {model} ({adapter_kind}) =====");
 
-        println!("\n--- Question:\n{question}");
+		println!("\n--- Question:\n{question}");
 
-        println!("\n--- Answer:");
-        let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
-        println!("{}", chat_res.content_text_as_str().unwrap_or("NO ANSWER"));
+		println!("\n--- Answer:");
+		let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
+		println!("{}", chat_res.content_text_as_str().unwrap_or("NO ANSWER"));
 
-        println!("\n--- Answer: (streaming)");
-        let chat_res = client.exec_chat_stream(model, chat_req.clone(), None).await?;
-        print_chat_stream(chat_res, Some(&print_options)).await?;
+		println!("\n--- Answer: (streaming)");
+		let chat_res = client.exec_chat_stream(model, chat_req.clone(), None).await?;
+		print_chat_stream(chat_res, Some(&print_options)).await?;
 
-        println!();
-    }
+		println!();
+	}
 
-    Ok(())
+	Ok(())
 }
 ```
@@ -117,6 +129,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 - [examples/c03-kind.rs](examples/c03-kind.rs) - Demonstrates how to provide a custom `AdapterKindResolver` to customize the "model name" to "adapter kind" mapping.
name" to "adapter kind" mapping. - [examples/c04-chat-options.rs](examples/c04-chat-options.rs) - Demonstrates how to set chat generation options such as `temperature` and `max_tokens` at the client level (for all requests) and per request level. - [examples/c05-model-names.rs](examples/c05-model-names.rs) - Show how to get model names per AdapterKind. +- [examples/c06-target-resolver.rs](examples/c06-target-resolver.rs) - For custom Auth, EndPoint, and Model. +- [examples/c07-image.rs](examples/c07-image.rs) - Image Analysis support
diff --git a/examples/c00-readme.rs b/examples/c00-readme.rs
index 4cb3439..433352b 100644
--- a/examples/c00-readme.rs
+++ b/examples/c00-readme.rs
@@ -8,9 +8,10 @@ const MODEL_OPENAI: &str = "gpt-4o-mini"; // o1-mini, gpt-4o-mini
 const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307";
 const MODEL_COHERE: &str = "command-light";
 const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
-const MODEL_GROQ: &str = "gemma-7b-it";
+const MODEL_GROQ: &str = "llama3-8b-8192";
 const MODEL_OLLAMA: &str = "gemma:2b"; // sh: `ollama pull gemma:2b`
 const MODEL_XAI: &str = "grok-beta";
+const MODEL_DEEPSEEK: &str = "deepseek-chat";
 
 // NOTE: These are the default environment keys for each AI Adapter Type.
 //       They can be customized; see `examples/c02-auth.rs`
@@ -22,6 +23,7 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
 	(MODEL_GEMINI, "GEMINI_API_KEY"),
 	(MODEL_GROQ, "GROQ_API_KEY"),
 	(MODEL_XAI, "XAI_API_KEY"),
+	(MODEL_DEEPSEEK, "DEEPSEEK_API_KEY"),
 	(MODEL_OLLAMA, ""),
 ];
 
diff --git a/src/adapter/adapter_kind.rs b/src/adapter/adapter_kind.rs
index af7c8f7..9bae7e9 100644
--- a/src/adapter/adapter_kind.rs
+++ b/src/adapter/adapter_kind.rs
@@ -1,6 +1,7 @@
 use super::groq::MODELS as GROQ_MODELS;
 use crate::adapter::anthropic::AnthropicAdapter;
 use crate::adapter::cohere::CohereAdapter;
+use crate::adapter::deepseek::DeepSeekAdapter;
 use crate::adapter::gemini::GeminiAdapter;
 use crate::adapter::groq::GroqAdapter;
 use crate::adapter::openai::OpenAIAdapter;
@@ -26,6 +27,8 @@ pub enum AdapterKind {
 	Groq,
 	/// For xAI
 	Xai,
+	/// For DeepSeek
+	DeepSeek,
 	// Note: Variants will probably be suffixed
 	// AnthropicBedrock,
 }
@@ -42,6 +45,7 @@ impl AdapterKind {
 			AdapterKind::Gemini => "Gemini",
 			AdapterKind::Groq => "Groq",
 			AdapterKind::Xai => "xAi",
+			AdapterKind::DeepSeek => "DeepSeek",
 		}
 	}
 
@@ -55,6 +59,7 @@ impl AdapterKind {
 			AdapterKind::Gemini => "gemini",
 			AdapterKind::Groq => "groq",
 			AdapterKind::Xai => "xai",
+			AdapterKind::DeepSeek => "deepseek",
 		}
 	}
 }
@@ -70,6 +75,7 @@ impl AdapterKind {
 			AdapterKind::Gemini => Some(GeminiAdapter::API_KEY_DEFAULT_ENV_NAME),
 			AdapterKind::Groq => Some(GroqAdapter::API_KEY_DEFAULT_ENV_NAME),
 			AdapterKind::Xai => Some(XaiAdapter::API_KEY_DEFAULT_ENV_NAME),
+			AdapterKind::DeepSeek => Some(DeepSeekAdapter::API_KEY_DEFAULT_ENV_NAME),
 			AdapterKind::Ollama => None,
 		}
 	}
@@ -98,6 +104,8 @@ impl AdapterKind {
 			Ok(Self::Gemini)
 		} else if model.starts_with("grok") {
 			Ok(Self::Xai)
+		} else if model.starts_with("deepseek") {
+			Ok(Self::DeepSeek)
 		} else if GROQ_MODELS.contains(&model) {
 			return Ok(Self::Groq);
 		}
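The mapping added above can be sanity-checked without sending a request, using the same `resolve_service_target` call the README example relies on (sketch; no network involved, assuming auth resolution stays lazy):

```rust
use genai::Client;

fn main() -> Result<(), Box<dyn std::error::Error>> {
	let client = Client::default();
	// Applies the starts_with("deepseek") rule from the hunk above.
	let adapter_kind = client.resolve_service_target("deepseek-chat")?.model.adapter_kind;
	println!("deepseek-chat -> {adapter_kind}"); // expected: DeepSeek
	Ok(())
}
```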
diff --git a/src/adapter/adapters/deepseek/adapter_impl.rs b/src/adapter/adapters/deepseek/adapter_impl.rs
new file mode 100644
index 0000000..cb1853d
--- /dev/null
+++ b/src/adapter/adapters/deepseek/adapter_impl.rs
@@ -0,0 +1,57 @@
+use crate::adapter::openai::OpenAIAdapter;
+use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
+use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
+use crate::resolver::{AuthData, Endpoint};
+use crate::webc::WebResponse;
+use crate::ModelIden;
+use crate::{Result, ServiceTarget};
+use reqwest::RequestBuilder;
+
+pub struct DeepSeekAdapter;
+
+pub(in crate::adapter) const MODELS: &[&str] = &["deepseek-chat"];
+
+impl DeepSeekAdapter {
+	pub const API_KEY_DEFAULT_ENV_NAME: &str = "DEEPSEEK_API_KEY";
+}
+
+// The DeepSeek API adapter is modeled after the OpenAI adapter, as the DeepSeek API is compatible with the OpenAI API.
+impl Adapter for DeepSeekAdapter {
+	fn default_endpoint() -> Endpoint {
+		const BASE_URL: &str = "https://api.deepseek.com/v1/";
+		Endpoint::from_static(BASE_URL)
+	}
+
+	fn default_auth() -> AuthData {
+		AuthData::from_env(Self::API_KEY_DEFAULT_ENV_NAME)
+	}
+
+	async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> {
+		Ok(MODELS.iter().map(|s| s.to_string()).collect())
+	}
+
+	fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
+		OpenAIAdapter::util_get_service_url(model, service_type, endpoint)
+	}
+
+	fn to_web_request_data(
+		target: ServiceTarget,
+		service_type: ServiceType,
+		chat_req: ChatRequest,
+		chat_options: ChatOptionsSet<'_, '_>,
+	) -> Result<WebRequestData> {
+		OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options)
+	}
+
+	fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
+		OpenAIAdapter::to_chat_response(model_iden, web_response)
+	}
+
+	fn to_chat_stream(
+		model_iden: ModelIden,
+		reqwest_builder: RequestBuilder,
+		options_set: ChatOptionsSet<'_, '_>,
+	) -> Result<ChatStreamResponse> {
+		OpenAIAdapter::to_chat_stream(model_iden, reqwest_builder, options_set)
+	}
+}
diff --git a/src/adapter/adapters/deepseek/mod.rs b/src/adapter/adapters/deepseek/mod.rs
new file mode 100644
index 0000000..2d2c105
--- /dev/null
+++ b/src/adapter/adapters/deepseek/mod.rs
@@ -0,0 +1,11 @@
+//! API Documentation: https://api-docs.deepseek.com/
+//! Model Names: https://api-docs.deepseek.com/quick_start/pricing
+//! Pricing: https://api-docs.deepseek.com/quick_start/pricing
+
+// region: --- Modules
+
+mod adapter_impl;
+
+pub use adapter_impl::*;
+
+// endregion: --- Modules
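Design note: every `Adapter` method on `DeepSeekAdapter` is a one-line delegation to `OpenAIAdapter`, so the provider-specific surface is reduced to three values. The sketch below restates them; the final `chat/completions` path is an assumption based on the OpenAI-compatible wire format, since the body of `util_get_service_url` is not part of this patch:

```rust
// DeepSeek-specific configuration, per the adapter above.
const BASE_URL: &str = "https://api.deepseek.com/v1/"; // default_endpoint()
const ENV_KEY: &str = "DEEPSEEK_API_KEY"; // default_auth() reads this lazily
const MODELS: &[&str] = &["deepseek-chat"]; // all_model_names()

fn main() {
	// Assumed final chat URL (OpenAI-style path appended by util_get_service_url).
	println!("{BASE_URL}chat/completions, auth via {ENV_KEY}, models {MODELS:?}");
}
```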
diff --git a/src/adapter/adapters/groq/adapter_impl.rs b/src/adapter/adapters/groq/adapter_impl.rs
index ca70177..fdb5b6d 100644
--- a/src/adapter/adapters/groq/adapter_impl.rs
+++ b/src/adapter/adapters/groq/adapter_impl.rs
@@ -18,8 +18,8 @@ pub(in crate::adapter) const MODELS: &[&str] = &[
 	"llama-3.1-70b-versatile",
 	"llama-3.1-8b-instant",
 	"mixtral-8x7b-32768",
-	"gemma-7b-it",
 	"gemma2-9b-it",
+	"gemma-7b-it", // deprecated
 	"llama3-groq-70b-8192-tool-use-preview",
 	"llama3-groq-8b-8192-tool-use-preview",
 	"llama3-8b-8192",
diff --git a/src/adapter/adapters/mod.rs b/src/adapter/adapters/mod.rs
index 1828c53..2f87c31 100644
--- a/src/adapter/adapters/mod.rs
+++ b/src/adapter/adapters/mod.rs
@@ -2,6 +2,7 @@ mod support;
 
 pub(super) mod anthropic;
 pub(super) mod cohere;
+pub(super) mod deepseek;
 pub(super) mod gemini;
 pub(super) mod groq;
 pub(super) mod ollama;
diff --git a/src/adapter/adapters/openai/streamer.rs b/src/adapter/adapters/openai/streamer.rs
index f23bcd9..8315575 100644
--- a/src/adapter/adapters/openai/streamer.rs
+++ b/src/adapter/adapters/openai/streamer.rs
@@ -66,7 +66,6 @@ impl futures::Stream for OpenAIStreamer {
 				}
 
 				// -- Other Content Messages
-				let adapter_kind = self.options.model_iden.adapter_kind;
 				// Parse to get the choice
 				let mut message_data: Value =
 					serde_json::from_str(&message.data).map_err(|serde_error| Error::StreamParse {
@@ -76,6 +75,8 @@
 
 				let first_choice: Option<Value> = message_data.x_take("/choices/0").ok();
 
+				let adapter_kind = self.options.model_iden.adapter_kind;
+
 				// If we have a first choice, then it's a normal message
 				if let Some(mut first_choice) = first_choice {
 					// If finish_reason exists, it's the end of this choice.
@@ -93,7 +94,7 @@
 								.unwrap_or_default(); // permissive for now
 							self.captured_data.usage = Some(usage)
 						}
-						AdapterKind::Xai => {
+						AdapterKind::Xai | AdapterKind::DeepSeek => {
 							let usage = message_data
 								.x_take("usage")
 								.map(OpenAIAdapter::into_usage)
@@ -127,9 +128,10 @@
 				}
 				// -- Usage message
 				else {
-					// If it's not Groq or xAI, then the usage is captured at the end when choices are empty or null
+					// If it's not Groq, xAI, or DeepSeek, then the usage is captured at the end when choices are empty or null
 					if !matches!(adapter_kind, AdapterKind::Groq)
 						&& !matches!(adapter_kind, AdapterKind::Xai)
+						&& !matches!(adapter_kind, AdapterKind::DeepSeek)
 						&& self.captured_data.usage.is_none() // this might be redundant
 						&& self.options.capture_usage
 					{
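For callers, the practical effect of the streamer change is usage capture: with Groq, xAI, and now DeepSeek, the `usage` object arrives on the final chunk that carries `finish_reason`, not on a trailing empty-choices message. Streaming itself is unchanged (sketch, same APIs as the README example):

```rust
use genai::chat::printer::{print_chat_stream, PrintChatStreamOptions};
use genai::chat::{ChatMessage, ChatRequest};
use genai::Client;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
	let client = Client::default();
	let chat_req = ChatRequest::new(vec![ChatMessage::user("Why is the sky red?")]);

	// For DeepSeek, any captured usage comes from the finish_reason chunk,
	// per the `AdapterKind::Xai | AdapterKind::DeepSeek` arm above.
	let chat_stream = client.exec_chat_stream("deepseek-chat", chat_req, None).await?;
	let print_options = PrintChatStreamOptions::from_print_events(false);
	print_chat_stream(chat_stream, Some(&print_options)).await?;

	Ok(())
}
```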
diff --git a/src/adapter/dispatcher.rs b/src/adapter/dispatcher.rs
index 0015756..6b4df03 100644
--- a/src/adapter/dispatcher.rs
+++ b/src/adapter/dispatcher.rs
@@ -11,6 +11,7 @@ use crate::{Result, ServiceTarget};
 use reqwest::RequestBuilder;
 
 use super::groq::GroqAdapter;
+use crate::adapter::deepseek::DeepSeekAdapter;
 use crate::adapter::xai::XaiAdapter;
 use crate::resolver::{AuthData, Endpoint};
 
@@ -31,6 +32,7 @@ impl AdapterDispatcher {
 			AdapterKind::Gemini => GeminiAdapter::default_endpoint(),
 			AdapterKind::Groq => GroqAdapter::default_endpoint(),
 			AdapterKind::Xai => XaiAdapter::default_endpoint(),
+			AdapterKind::DeepSeek => DeepSeekAdapter::default_endpoint(),
 		}
 	}
 
@@ -43,6 +45,7 @@ impl AdapterDispatcher {
 			AdapterKind::Gemini => GeminiAdapter::default_auth(),
 			AdapterKind::Groq => GroqAdapter::default_auth(),
 			AdapterKind::Xai => XaiAdapter::default_auth(),
+			AdapterKind::DeepSeek => DeepSeekAdapter::default_auth(),
 		}
 	}
 
@@ -55,6 +58,7 @@ impl AdapterDispatcher {
 			AdapterKind::Gemini => GeminiAdapter::all_model_names(kind).await,
 			AdapterKind::Groq => GroqAdapter::all_model_names(kind).await,
 			AdapterKind::Xai => XaiAdapter::all_model_names(kind).await,
+			AdapterKind::DeepSeek => DeepSeekAdapter::all_model_names(kind).await,
 		}
 	}
 
@@ -67,6 +71,7 @@ impl AdapterDispatcher {
 			AdapterKind::Gemini => GeminiAdapter::get_service_url(model, service_type, endpoint),
 			AdapterKind::Groq => GroqAdapter::get_service_url(model, service_type, endpoint),
 			AdapterKind::Xai => XaiAdapter::get_service_url(model, service_type, endpoint),
+			AdapterKind::DeepSeek => DeepSeekAdapter::get_service_url(model, service_type, endpoint),
 		}
 	}
 
@@ -87,6 +92,7 @@ impl AdapterDispatcher {
 			AdapterKind::Gemini => GeminiAdapter::to_web_request_data(target, service_type, chat_req, options_set),
 			AdapterKind::Groq => GroqAdapter::to_web_request_data(target, service_type, chat_req, options_set),
 			AdapterKind::Xai => XaiAdapter::to_web_request_data(target, service_type, chat_req, options_set),
+			AdapterKind::DeepSeek => DeepSeekAdapter::to_web_request_data(target, service_type, chat_req, options_set),
 		}
 	}
 
@@ -99,6 +105,7 @@ impl AdapterDispatcher {
 			AdapterKind::Gemini => GeminiAdapter::to_chat_response(model_iden, web_response),
 			AdapterKind::Groq => GroqAdapter::to_chat_response(model_iden, web_response),
 			AdapterKind::Xai => XaiAdapter::to_chat_response(model_iden, web_response),
+			AdapterKind::DeepSeek => DeepSeekAdapter::to_chat_response(model_iden, web_response),
 		}
 	}
 
@@ -115,6 +122,7 @@ impl AdapterDispatcher {
 			AdapterKind::Gemini => GeminiAdapter::to_chat_stream(model_iden, reqwest_builder, options_set),
 			AdapterKind::Groq => GroqAdapter::to_chat_stream(model_iden, reqwest_builder, options_set),
 			AdapterKind::Xai => XaiAdapter::to_chat_stream(model_iden, reqwest_builder, options_set),
+			AdapterKind::DeepSeek => DeepSeekAdapter::to_chat_stream(model_iden, reqwest_builder, options_set),
 		}
 	}
 }
diff --git a/tests/tests_p_deepseek.rs b/tests/tests_p_deepseek.rs
new file mode 100644
index 0000000..f0486bf
--- /dev/null
+++ b/tests/tests_p_deepseek.rs
@@ -0,0 +1,75 @@
+mod support;
+
+use crate::support::common_tests;
+use genai::adapter::AdapterKind;
+use genai::resolver::AuthData;
+
+type Result<T> = core::result::Result<T, Box<dyn std::error::Error>>; // For tests.
+
+const MODEL: &str = "deepseek-chat";
+
+// region: --- Chat
+
+#[tokio::test]
+async fn test_chat_simple_ok() -> Result<()> {
+	common_tests::common_test_chat_simple_ok(MODEL).await
+}
+
+#[tokio::test]
+async fn test_chat_multi_system_ok() -> Result<()> {
+	common_tests::common_test_chat_multi_system_ok(MODEL).await
+}
+
+#[tokio::test]
+async fn test_chat_json_mode_ok() -> Result<()> {
+	common_tests::common_test_chat_json_mode_ok(MODEL, true).await
+}
+
+#[tokio::test]
+async fn test_chat_temperature_ok() -> Result<()> {
+	common_tests::common_test_chat_temperature_ok(MODEL).await
+}
+
+#[tokio::test]
+async fn test_chat_stop_sequences_ok() -> Result<()> {
+	common_tests::common_test_chat_stop_sequences_ok(MODEL).await
+}
+
+// endregion: --- Chat
+
+// region: --- Chat Stream Tests
+
+#[tokio::test]
+async fn test_chat_stream_simple_ok() -> Result<()> {
+	common_tests::common_test_chat_stream_simple_ok(MODEL).await
+}
+
+#[tokio::test]
+async fn test_chat_stream_capture_content_ok() -> Result<()> {
+	common_tests::common_test_chat_stream_capture_content_ok(MODEL).await
+}
+
+#[tokio::test]
+async fn test_chat_stream_capture_all_ok() -> Result<()> {
+	common_tests::common_test_chat_stream_capture_all_ok(MODEL).await
+}
+
+// endregion: --- Chat Stream Tests
+
+// region: --- Resolver Tests
+
+#[tokio::test]
+async fn test_resolver_auth_ok() -> Result<()> {
+	common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("DEEPSEEK_API_KEY")).await
+}
+
+// endregion: --- Resolver Tests
+
+// region: --- List
+
+#[tokio::test]
+async fn test_list_models() -> Result<()> {
+	common_tests::common_test_list_models(AdapterKind::DeepSeek, "deepseek-chat").await
+}
+
+// endregion: --- List
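The `test_list_models` test above goes through the dispatcher to `DeepSeekAdapter::all_model_names`, which simply returns the static `MODELS` slice. A hedged sketch of the equivalent direct call (this assumes the `Client::all_model_names(AdapterKind)` API that `examples/c05-model-names.rs` demonstrates; it is not shown in this patch):

```rust
use genai::adapter::AdapterKind;
use genai::Client;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
	let client = Client::default();
	// For AdapterKind::DeepSeek this is the static list: ["deepseek-chat"].
	let models = client.all_model_names(AdapterKind::DeepSeek).await?;
	println!("deepseek models: {models:?}");
	Ok(())
}
```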