. ollama - removed workaround for multi-system lack of support (for old ollama)

Jeremy Chone
2024-12-07 16:33:53 -08:00
parent a32f69dfee
commit e439080f16
9 changed files with 68 additions and 28 deletions

View File

@@ -7,7 +7,8 @@
## 2024-12-07 - `0.1.13`
- `+` add stop_sequences support cohere -
- `.` ollama - removed workaround for multi-system lack of support (for old ollama)
- `+` add stop_sequences support cohere
- `+` stop_sequences - for openai, ollama, groq, gemini, cohere
- `+` stop_sequences - for anthropic (thanks [@semtexzv](https://github.com/semtexzv))
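With the workaround removed, multiple system messages are no longer concatenated for Ollama; they are passed through to the adapter untouched. A minimal sketch of what that looks like from the caller's side, reusing the request built in the new common test below (the model name and import paths are assumptions, not part of this commit):

```rust
use genai::chat::{ChatMessage, ChatRequest};
use genai::Client;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::default();

    // Two message-level system entries plus a request-level system.
    // Each one is now sent as its own `system` message, even for Ollama.
    let chat_req = ChatRequest::new(vec![
        ChatMessage::system("Be very concise"),
        ChatMessage::system("Explain with bullet points"),
        ChatMessage::user("Why is the sky blue?"),
    ])
    .with_system("And end with 'Thank you'");

    // "llama3.2" is a placeholder model name.
    let chat_res = client.exec_chat("llama3.2", chat_req, None).await?;
    println!("{:?}", chat_res.content);

    Ok(())
}
```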

View File

@@ -224,23 +224,12 @@ impl OpenAIAdapter {
/// Takes the genai ChatMessages and builds the OpenAIChatRequestParts
/// - `genai::ChatRequest.system`, if present, is added as the first message with role 'system'.
/// - All messages get added with the corresponding roles (tools are not supported for now)
fn into_openai_request_parts(model_iden: ModelIden, chat_req: ChatRequest) -> Result<OpenAIRequestParts> {
fn into_openai_request_parts(_model_iden: ModelIden, chat_req: ChatRequest) -> Result<OpenAIRequestParts> {
let mut messages: Vec<Value> = Vec::new();
// NOTE: For now, system_messages is used to fix an issue with the Ollama compatibility layer, which does not support multiple system messages.
// So, for Ollama, the system messages are concatenated into a single one at the beginning.
// NOTE: This might be fixed now, so we could remove this.
let mut system_messages: Vec<String> = Vec::new();
let ollama_variant = matches!(model_iden.adapter_kind, AdapterKind::Ollama);
// -- Process the system
if let Some(system_msg) = chat_req.system {
if ollama_variant {
system_messages.push(system_msg)
} else {
messages.push(json!({"role": "system", "content": system_msg}));
}
messages.push(json!({"role": "system", "content": system_msg}));
}
// -- Process the messages
@@ -250,14 +239,7 @@ impl OpenAIAdapter {
// For now, system and tool messages go to the system
ChatRole::System => {
if let MessageContent::Text(content) = msg.content {
// NOTE: Ollama does not support multiple system messages
// See note in the function comment
if ollama_variant {
system_messages.push(content);
} else {
messages.push(json!({"role": "system", "content": content}))
}
messages.push(json!({"role": "system", "content": content}))
}
// TODO: Probably need to warn if it is a ToolCalls type of content
}
@@ -305,12 +287,6 @@ impl OpenAIAdapter {
}
}
// -- Finalize the system messages ollama case
if !system_messages.is_empty() {
let system_message = system_messages.join("\n");
messages.insert(0, json!({"role": "system", "content": system_message}));
}
// -- Process the tools
let tools = chat_req.tools.map(|tools| {
tools
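For reference, with the Ollama-specific branch gone, into_openai_request_parts builds the same messages array for every OpenAI-compatible adapter: the request-level system (if any) comes first, followed by each chat message in order. A sketch of the resulting shape for a request with two system messages plus a request-level system, as in the new common test below (illustrative values, not captured output):

```rust
use serde_json::{json, Value};

fn main() {
    // Shape of the `messages` array after this change: one entry per system
    // message, no longer merged into a single system message for Ollama.
    let messages: Vec<Value> = vec![
        json!({"role": "system", "content": "And end with 'Thank you'"}),
        json!({"role": "system", "content": "Be very concise"}),
        json!({"role": "system", "content": "Explain with bullet points"}),
        json!({"role": "user", "content": "Why is the sky blue?"}),
    ];
    assert_eq!(messages.len(), 4);
}
```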

View File

@@ -36,6 +36,39 @@ pub async fn common_test_chat_simple_ok(model: &str) -> Result<()> {
Ok(())
}
pub async fn common_test_chat_multi_system_ok(model: &str) -> Result<()> {
// -- Setup & Fixtures
let client = Client::default();
let chat_req = ChatRequest::new(vec![
// -- Messages (deactivate to see the differences)
ChatMessage::system("Be very concise"),
ChatMessage::system("Explain with bullet points"),
ChatMessage::user("Why is the sky blue?"),
])
.with_system("And end with 'Thank you'");
// -- Exec
let chat_res = client.exec_chat(model, chat_req, None).await?;
// -- Check
assert!(
!get_option_value!(chat_res.content).is_empty(),
"Content should not be empty"
);
let usage = chat_res.usage;
let input_tokens = get_option_value!(usage.input_tokens);
let output_tokens = get_option_value!(usage.output_tokens);
let total_tokens = get_option_value!(usage.total_tokens);
assert!(total_tokens > 0, "total_tokens should be > 0");
assert!(
total_tokens == input_tokens + output_tokens,
"total_tokens should be equal to input_tokens + output_tokens"
);
Ok(())
}
/// Test with JSON mode enabled. This is not a structured output test.
/// - `test_token`: Whether to check the token usage (set to false for Ollama, which has a bug where no token usage is returned in JSON mode)
pub async fn common_test_chat_json_mode_ok(model: &str, test_token: bool) -> Result<()> {

View File

@@ -19,6 +19,11 @@ async fn test_chat_simple_ok() -> Result<()> {
common_tests::common_test_chat_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_multi_system_ok() -> Result<()> {
common_tests::common_test_chat_multi_system_ok(MODEL).await
}
#[tokio::test]
#[serial(anthropic)]
async fn test_chat_temperature_ok() -> Result<()> {

View File

@@ -14,6 +14,11 @@ async fn test_chat_simple_ok() -> Result<()> {
common_tests::common_test_chat_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_multi_system_ok() -> Result<()> {
common_tests::common_test_chat_multi_system_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stop_sequences_ok() -> Result<()> {
common_tests::common_test_chat_stop_sequences_ok(MODEL).await

View File

@@ -14,6 +14,11 @@ async fn test_chat_simple_ok() -> Result<()> {
common_tests::common_test_chat_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_multi_system_ok() -> Result<()> {
common_tests::common_test_chat_multi_system_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_json_structured_ok() -> Result<()> {
common_tests::common_test_chat_json_structured_ok(MODEL, true).await

View File

@@ -16,6 +16,11 @@ async fn test_chat_simple_ok() -> Result<()> {
common_tests::common_test_chat_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_multi_system_ok() -> Result<()> {
common_tests::common_test_chat_multi_system_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_json_mode_ok() -> Result<()> {
common_tests::common_test_chat_json_mode_ok(MODEL, true).await

View File

@@ -14,6 +14,11 @@ async fn test_chat_simple_ok() -> Result<()> {
common_tests::common_test_chat_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_multi_system_ok() -> Result<()> {
common_tests::common_test_chat_multi_system_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_json_mode_ok() -> Result<()> {
common_tests::common_test_chat_json_mode_ok(MODEL, false).await

View File

@@ -14,6 +14,11 @@ async fn test_chat_simple_ok() -> Result<()> {
common_tests::common_test_chat_simple_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_multi_system_ok() -> Result<()> {
common_tests::common_test_chat_multi_system_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_json_mode_ok() -> Result<()> {
common_tests::common_test_chat_json_mode_ok(MODEL, true).await