+ add DeepSeek support

This commit is contained in:
Jeremy Chone
2025-01-06 08:30:33 -08:00
parent 8c6ee76702
commit ed74d44c7a
10 changed files with 222 additions and 44 deletions

View File

@ -25,6 +25,14 @@ Provides a common and ergonomic single API to many generative AI Providers, such
Check out [devai.run](https://devai.run), the **Iterate to Automate** command-line application that leverages **GenAI** for multi-AI capabilities.
## Key Features
- Native Multi-AI Provider/Model: OpenAI, Anthropic, Gemini, Ollama, Groq, xAI, DeepSeek (Direct chat and stream) (see [examples/c00-readme.rs](examples/c00-readme.rs))
- Image Analysis (for OpenAI, Gemini flash-2, Anthropic) (see [examples/c07-image.rs](examples/c07-image.rs))
- Custom Auth/API Key (see [examples/c02-auth.rs](examples/c02-auth.rs))
- Model Alias (see [examples/c05-model-names.rs](examples/c05-model-names.rs))
- Custom Endpoint, Auth, and Model Identifier (see [examples/c06-target-resolver.rs](examples/c06-target-resolver.rs))
[Examples](#examples) | [Thanks](#thanks) | [Library Focus](#library-focus) | [Changelog](CHANGELOG.md) | Provider Mapping: [ChatOptions](#chatoptions) | [MetaUsage](#metausage)
## Examples
@ -32,29 +40,33 @@ Check out [devai.run](https://devai.run), the **Iterate to Automate** command-li
[examples/c00-readme.rs](examples/c00-readme.rs)
```rust
//! Base examples demonstrating the core capabilities of genai
use genai::chat::printer::{print_chat_stream, PrintChatStreamOptions};
use genai::chat::{ChatMessage, ChatRequest};
use genai::Client;
const MODEL_OPENAI: &str = "gpt-4o-mini";
const MODEL_OPENAI: &str = "gpt-4o-mini"; // o1-mini, gpt-4o-mini
const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307";
const MODEL_COHERE: &str = "command-light";
const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
const MODEL_GROQ: &str = "gemma-7b-it";
const MODEL_GROQ: &str = "llama3-8b-8192";
const MODEL_OLLAMA: &str = "gemma:2b"; // sh: `ollama pull gemma:2b`
const MODEL_XAI: &str = "grok-beta";
const MODEL_DEEPSEEK: &str = "deepseek-chat";
// NOTE: Those are the default environment keys for each AI Adapter Type.
// Can be customized, see `examples/c02-auth.rs`
// NOTE: These are the default environment keys for each AI Adapter Type.
// They can be customized; see `examples/c02-auth.rs`
const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
// -- de/activate models/providers
(MODEL_OPENAI, "OPENAI_API_KEY"),
(MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"),
(MODEL_COHERE, "COHERE_API_KEY"),
(MODEL_GEMINI, "GEMINI_API_KEY"),
(MODEL_GROQ, "GROQ_API_KEY"),
(MODEL_XAI, "XAI_API_KEY"),
(MODEL_OLLAMA, ""),
// -- De/activate models/providers
(MODEL_OPENAI, "OPENAI_API_KEY"),
(MODEL_ANTHROPIC, "ANTHROPIC_API_KEY"),
(MODEL_COHERE, "COHERE_API_KEY"),
(MODEL_GEMINI, "GEMINI_API_KEY"),
(MODEL_GROQ, "GROQ_API_KEY"),
(MODEL_XAI, "XAI_API_KEY"),
(MODEL_DEEPSEEK, "DEEPSEEK_API_KEY"),
(MODEL_OLLAMA, ""),
];
// NOTE: Model to AdapterKind (AI Provider) type mapping rule
@ -65,47 +77,47 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
// - model in Groq models -> Groq
// - For anything else -> Ollama
//
// Can be customized, see `examples/c03-kind.rs`
// This can be customized; see `examples/c03-kind.rs`
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let question = "Why is the sky red?";
let question = "Why is the sky red?";
let chat_req = ChatRequest::new(vec![
// -- Messages (de/activate to see the differences)
ChatMessage::system("Answer in one sentence"),
ChatMessage::user(question),
]);
let chat_req = ChatRequest::new(vec![
// -- Messages (de/activate to see the differences)
ChatMessage::system("Answer in one sentence"),
ChatMessage::user(question),
]);
let client = Client::default();
let client = Client::default();
let print_options = PrintChatStreamOptions::from_print_events(false);
let print_options = PrintChatStreamOptions::from_print_events(false);
for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST {
// Skip if does not have the environment name set
if !env_name.is_empty() && std::env::var(env_name).is_err() {
println!("===== Skipping model: {model} (env var not set: {env_name})");
continue;
}
for (model, env_name) in MODEL_AND_KEY_ENV_NAME_LIST {
// Skip if the environment name is not set
if !env_name.is_empty() && std::env::var(env_name).is_err() {
println!("===== Skipping model: {model} (env var not set: {env_name})");
continue;
}
let adapter_kind = client.resolve_model_iden(model)?.adapter_kind;
let adapter_kind = client.resolve_service_target(model)?.model.adapter_kind;
println!("\n===== MODEL: {model} ({adapter_kind}) =====");
println!("\n===== MODEL: {model} ({adapter_kind}) =====");
println!("\n--- Question:\n{question}");
println!("\n--- Question:\n{question}");
println!("\n--- Answer:");
let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
println!("{}", chat_res.content_text_as_str().unwrap_or("NO ANSWER"));
println!("\n--- Answer:");
let chat_res = client.exec_chat(model, chat_req.clone(), None).await?;
println!("{}", chat_res.content_text_as_str().unwrap_or("NO ANSWER"));
println!("\n--- Answer: (streaming)");
let chat_res = client.exec_chat_stream(model, chat_req.clone(), None).await?;
print_chat_stream(chat_res, Some(&print_options)).await?;
println!("\n--- Answer: (streaming)");
let chat_res = client.exec_chat_stream(model, chat_req.clone(), None).await?;
print_chat_stream(chat_res, Some(&print_options)).await?;
println!();
}
println!();
}
Ok(())
Ok(())
}
```
@ -117,6 +129,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
- [examples/c03-kind.rs](examples/c03-kind.rs) - Demonstrates how to provide a custom `AdapterKindResolver` to customize the "model name" to "adapter kind" mapping.
- [examples/c04-chat-options.rs](examples/c04-chat-options.rs) - Demonstrates how to set chat generation options such as `temperature` and `max_tokens` at the client level (for all requests) and per request level.
- [examples/c05-model-names.rs](examples/c05-model-names.rs) - Shows how to get model names per AdapterKind.
- [examples/c06-target-resolver.rs](examples/c06-target-resolver.rs) - For custom Auth, EndPoint, and Model.
- [examples/c07-image.rs](examples/c07-image.rs) - Image Analysis support
<br />
<a href="https://www.youtube.com/playlist?list=PL7r-PXl6ZPcBcLsBdBABOFUuLziNyigqj"><img alt="Static Badge" src="https://img.shields.io/badge/YouTube_JC_AI_Playlist-Video?style=flat&logo=youtube&color=%23ff0000"></a>

View File

@ -8,9 +8,10 @@ const MODEL_OPENAI: &str = "gpt-4o-mini"; // o1-mini, gpt-4o-mini
const MODEL_ANTHROPIC: &str = "claude-3-haiku-20240307";
const MODEL_COHERE: &str = "command-light";
const MODEL_GEMINI: &str = "gemini-1.5-flash-latest";
const MODEL_GROQ: &str = "gemma-7b-it";
const MODEL_GROQ: &str = "llama3-8b-8192";
const MODEL_OLLAMA: &str = "gemma:2b"; // sh: `ollama pull gemma:2b`
const MODEL_XAI: &str = "grok-beta";
const MODEL_DEEPSEEK: &str = "deepseek-chat";
// NOTE: These are the default environment keys for each AI Adapter Type.
// They can be customized; see `examples/c02-auth.rs`
@ -22,6 +23,7 @@ const MODEL_AND_KEY_ENV_NAME_LIST: &[(&str, &str)] = &[
(MODEL_GEMINI, "GEMINI_API_KEY"),
(MODEL_GROQ, "GROQ_API_KEY"),
(MODEL_XAI, "XAI_API_KEY"),
(MODEL_DEEPSEEK, "DEEPSEEK_API_KEY"),
(MODEL_OLLAMA, ""),
];

View File

@ -1,6 +1,7 @@
use super::groq::MODELS as GROQ_MODELS;
use crate::adapter::anthropic::AnthropicAdapter;
use crate::adapter::cohere::CohereAdapter;
use crate::adapter::deepseek::DeepSeekAdapter;
use crate::adapter::gemini::GeminiAdapter;
use crate::adapter::groq::GroqAdapter;
use crate::adapter::openai::OpenAIAdapter;
@ -26,6 +27,8 @@ pub enum AdapterKind {
Groq,
/// For xAI
Xai,
/// For DeepSeek
DeepSeek,
// Note: Variants will probably be suffixed
// AnthropicBedrock,
}
@ -42,6 +45,7 @@ impl AdapterKind {
AdapterKind::Gemini => "Gemini",
AdapterKind::Groq => "Groq",
AdapterKind::Xai => "xAi",
AdapterKind::DeepSeek => "DeepSeek",
}
}
@ -55,6 +59,7 @@ impl AdapterKind {
AdapterKind::Gemini => "gemini",
AdapterKind::Groq => "groq",
AdapterKind::Xai => "xai",
AdapterKind::DeepSeek => "deepseek",
}
}
}
@ -70,6 +75,7 @@ impl AdapterKind {
AdapterKind::Gemini => Some(GeminiAdapter::API_KEY_DEFAULT_ENV_NAME),
AdapterKind::Groq => Some(GroqAdapter::API_KEY_DEFAULT_ENV_NAME),
AdapterKind::Xai => Some(XaiAdapter::API_KEY_DEFAULT_ENV_NAME),
AdapterKind::DeepSeek => Some(DeepSeekAdapter::API_KEY_DEFAULT_ENV_NAME),
AdapterKind::Ollama => None,
}
}
@ -98,6 +104,8 @@ impl AdapterKind {
Ok(Self::Gemini)
} else if model.starts_with("grok") {
Ok(Self::Xai)
} else if model.starts_with("deepseek") {
Ok(Self::DeepSeek)
} else if GROQ_MODELS.contains(&model) {
return Ok(Self::Groq);
}

View File

@ -0,0 +1,57 @@
use crate::adapter::openai::OpenAIAdapter;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{ChatOptionsSet, ChatRequest, ChatResponse, ChatStreamResponse};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::WebResponse;
use crate::ModelIden;
use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;
/// Adapter for the DeepSeek provider (OpenAI-compatible API).
pub struct DeepSeekAdapter;
/// Known DeepSeek model names, returned by `all_model_names` (the provider's
/// models endpoint is not queried).
pub(in crate::adapter) const MODELS: &[&str] = &["deepseek-chat"];
impl DeepSeekAdapter {
/// Default environment variable read for the API key (see `default_auth`).
pub const API_KEY_DEFAULT_ENV_NAME: &str = "DEEPSEEK_API_KEY";
}
// The DeepSeek API adapter is modeled after the OpenAI adapter, as the DeepSeek API is compatible with the OpenAI API.
impl Adapter for DeepSeekAdapter {
/// Default base URL for the DeepSeek OpenAI-compatible API.
fn default_endpoint() -> Endpoint {
const BASE_URL: &str = "https://api.deepseek.com/v1/";
Endpoint::from_static(BASE_URL)
}
/// By default, the API key is read from the `DEEPSEEK_API_KEY` environment variable.
fn default_auth() -> AuthData {
AuthData::from_env(Self::API_KEY_DEFAULT_ENV_NAME)
}
/// Returns the static `MODELS` list; no network call is made.
async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> {
Ok(MODELS.iter().map(|s| s.to_string()).collect())
}
// NOTE: All remaining methods delegate to `OpenAIAdapter`, since the DeepSeek
// wire format (request body, response, and stream events) is OpenAI-compatible.
fn get_service_url(model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> String {
OpenAIAdapter::util_get_service_url(model, service_type, endpoint)
}
fn to_web_request_data(
target: ServiceTarget,
service_type: ServiceType,
chat_req: ChatRequest,
chat_options: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
OpenAIAdapter::util_to_web_request_data(target, service_type, chat_req, chat_options)
}
fn to_chat_response(model_iden: ModelIden, web_response: WebResponse) -> Result<ChatResponse> {
OpenAIAdapter::to_chat_response(model_iden, web_response)
}
fn to_chat_stream(
model_iden: ModelIden,
reqwest_builder: RequestBuilder,
options_set: ChatOptionsSet<'_, '_>,
) -> Result<ChatStreamResponse> {
OpenAIAdapter::to_chat_stream(model_iden, reqwest_builder, options_set)
}
}

View File

@ -0,0 +1,11 @@
//! API Documentation: https://api-docs.deepseek.com/
//! Model Names: https://api-docs.deepseek.com/quick_start/pricing
//! Pricing: https://api-docs.deepseek.com/quick_start/pricing
//!
//! The DeepSeek adapter delegates most request/response handling to the
//! OpenAI adapter, since the DeepSeek API is OpenAI-compatible (see `adapter_impl.rs`).
// region:    --- Modules
mod adapter_impl;
pub use adapter_impl::*;
// endregion: --- Modules

View File

@ -18,8 +18,8 @@ pub(in crate::adapter) const MODELS: &[&str] = &[
"llama-3.1-70b-versatile",
"llama-3.1-8b-instant",
"mixtral-8x7b-32768",
"gemma-7b-it",
"gemma2-9b-it",
"gemma-7b-it", // deprecated
"llama3-groq-70b-8192-tool-use-preview",
"llama3-groq-8b-8192-tool-use-preview",
"llama3-8b-8192",

View File

@ -2,6 +2,7 @@ mod support;
pub(super) mod anthropic;
pub(super) mod cohere;
pub(super) mod deepseek;
pub(super) mod gemini;
pub(super) mod groq;
pub(super) mod ollama;

View File

@ -66,7 +66,6 @@ impl futures::Stream for OpenAIStreamer {
}
// -- Other Content Messages
let adapter_kind = self.options.model_iden.adapter_kind;
// Parse to get the choice
let mut message_data: Value =
serde_json::from_str(&message.data).map_err(|serde_error| Error::StreamParse {
@ -76,6 +75,8 @@ impl futures::Stream for OpenAIStreamer {
let first_choice: Option<Value> = message_data.x_take("/choices/0").ok();
let adapter_kind = self.options.model_iden.adapter_kind;
// If we have a first choice, then it's a normal message
if let Some(mut first_choice) = first_choice {
// If finish_reason exists, it's the end of this choice.
@ -93,7 +94,7 @@ impl futures::Stream for OpenAIStreamer {
.unwrap_or_default(); // permissive for now
self.captured_data.usage = Some(usage)
}
AdapterKind::Xai => {
AdapterKind::Xai | AdapterKind::DeepSeek => {
let usage = message_data
.x_take("usage")
.map(OpenAIAdapter::into_usage)
@ -127,9 +128,10 @@ impl futures::Stream for OpenAIStreamer {
}
// -- Usage message
else {
// If it's not Groq or xAI, then the usage is captured at the end when choices are empty or null
// If it's not Groq, xAI, or DeepSeek, then the usage is captured at the end when choices are empty or null
if !matches!(adapter_kind, AdapterKind::Groq)
&& !matches!(adapter_kind, AdapterKind::Xai)
&& !matches!(adapter_kind, AdapterKind::DeepSeek)
&& self.captured_data.usage.is_none() // this might be redundant
&& self.options.capture_usage
{

View File

@ -11,6 +11,7 @@ use crate::{Result, ServiceTarget};
use reqwest::RequestBuilder;
use super::groq::GroqAdapter;
use crate::adapter::deepseek::DeepSeekAdapter;
use crate::adapter::xai::XaiAdapter;
use crate::resolver::{AuthData, Endpoint};
@ -31,6 +32,7 @@ impl AdapterDispatcher {
AdapterKind::Gemini => GeminiAdapter::default_endpoint(),
AdapterKind::Groq => GroqAdapter::default_endpoint(),
AdapterKind::Xai => XaiAdapter::default_endpoint(),
AdapterKind::DeepSeek => DeepSeekAdapter::default_endpoint(),
}
}
@ -43,6 +45,7 @@ impl AdapterDispatcher {
AdapterKind::Gemini => GeminiAdapter::default_auth(),
AdapterKind::Groq => GroqAdapter::default_auth(),
AdapterKind::Xai => XaiAdapter::default_auth(),
AdapterKind::DeepSeek => DeepSeekAdapter::default_auth(),
}
}
@ -55,6 +58,7 @@ impl AdapterDispatcher {
AdapterKind::Gemini => GeminiAdapter::all_model_names(kind).await,
AdapterKind::Groq => GroqAdapter::all_model_names(kind).await,
AdapterKind::Xai => XaiAdapter::all_model_names(kind).await,
AdapterKind::DeepSeek => DeepSeekAdapter::all_model_names(kind).await,
}
}
@ -67,6 +71,7 @@ impl AdapterDispatcher {
AdapterKind::Gemini => GeminiAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Groq => GroqAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::Xai => XaiAdapter::get_service_url(model, service_type, endpoint),
AdapterKind::DeepSeek => DeepSeekAdapter::get_service_url(model, service_type, endpoint),
}
}
@ -87,6 +92,7 @@ impl AdapterDispatcher {
AdapterKind::Gemini => GeminiAdapter::to_web_request_data(target, service_type, chat_req, options_set),
AdapterKind::Groq => GroqAdapter::to_web_request_data(target, service_type, chat_req, options_set),
AdapterKind::Xai => XaiAdapter::to_web_request_data(target, service_type, chat_req, options_set),
AdapterKind::DeepSeek => DeepSeekAdapter::to_web_request_data(target, service_type, chat_req, options_set),
}
}
@ -99,6 +105,7 @@ impl AdapterDispatcher {
AdapterKind::Gemini => GeminiAdapter::to_chat_response(model_iden, web_response),
AdapterKind::Groq => GroqAdapter::to_chat_response(model_iden, web_response),
AdapterKind::Xai => XaiAdapter::to_chat_response(model_iden, web_response),
AdapterKind::DeepSeek => DeepSeekAdapter::to_chat_response(model_iden, web_response),
}
}
@ -115,6 +122,7 @@ impl AdapterDispatcher {
AdapterKind::Gemini => GeminiAdapter::to_chat_stream(model_iden, reqwest_builder, options_set),
AdapterKind::Groq => GroqAdapter::to_chat_stream(model_iden, reqwest_builder, options_set),
AdapterKind::Xai => XaiAdapter::to_chat_stream(model_iden, reqwest_builder, options_set),
AdapterKind::DeepSeek => DeepSeekAdapter::to_chat_stream(model_iden, reqwest_builder, options_set),
}
}
}

75
tests/tests_p_deepseek.rs Normal file
View File

@ -0,0 +1,75 @@
//! Provider integration tests for the DeepSeek adapter, driven by the shared
//! `common_tests` helpers.
//!
//! NOTE(review): these appear to call the live DeepSeek API via `common_tests`,
//! so they presumably require `DEEPSEEK_API_KEY` to be set — confirm against
//! `tests/support`.
mod support;
use crate::support::common_tests;
use genai::adapter::AdapterKind;
use genai::resolver::AuthData;
type Result<T> = core::result::Result<T, Box<dyn std::error::Error>>; // For tests.
// Default DeepSeek chat model used by every test in this file.
const MODEL: &str = "deepseek-chat";
// region:    --- Chat
#[tokio::test]
async fn test_chat_simple_ok() -> Result<()> {
common_tests::common_test_chat_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_multi_system_ok() -> Result<()> {
common_tests::common_test_chat_multi_system_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_json_mode_ok() -> Result<()> {
common_tests::common_test_chat_json_mode_ok(MODEL, true).await
}
#[tokio::test]
async fn test_chat_temperature_ok() -> Result<()> {
common_tests::common_test_chat_temperature_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stop_sequences_ok() -> Result<()> {
common_tests::common_test_chat_stop_sequences_ok(MODEL).await
}
// endregion: --- Chat
// region:    --- Chat Stream Tests
#[tokio::test]
async fn test_chat_stream_simple_ok() -> Result<()> {
common_tests::common_test_chat_stream_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stream_capture_content_ok() -> Result<()> {
common_tests::common_test_chat_stream_capture_content_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stream_capture_all_ok() -> Result<()> {
common_tests::common_test_chat_stream_capture_all_ok(MODEL).await
}
// endregion: --- Chat Stream Tests
// region:    --- Resolver Tests
#[tokio::test]
async fn test_resolver_auth_ok() -> Result<()> {
common_tests::common_test_resolver_auth_ok(MODEL, AuthData::from_env("DEEPSEEK_API_KEY")).await
}
// endregion: --- Resolver Tests
// region:    --- List
#[tokio::test]
async fn test_list_models() -> Result<()> {
common_tests::common_test_list_models(AdapterKind::DeepSeek, "deepseek-chat").await
}
// endregion: --- List