diff --git a/src/adapter/adapters/openai/adapter_impl.rs b/src/adapter/adapters/openai/adapter_impl.rs index 98f1c2f..52ad26a 100644 --- a/src/adapter/adapters/openai/adapter_impl.rs +++ b/src/adapter/adapters/openai/adapter_impl.rs @@ -182,15 +182,19 @@ impl OpenAIAdapter { payload["response_format"] = response_format; } - // -- + // -- Add supported ChatOptions if stream & options_set.capture_usage().unwrap_or(false) { payload.x_insert("stream_options", json!({"include_usage": true}))?; } - // -- Add supported ChatOptions if let Some(temperature) = options_set.temperature() { payload.x_insert("temperature", temperature)?; } + + if !options_set.stop_sequences().is_empty() { + payload.x_insert("stop", options_set.stop_sequences())?; + } + if let Some(max_tokens) = options_set.max_tokens() { payload.x_insert("max_tokens", max_tokens)?; } diff --git a/src/chat/chat_options.rs b/src/chat/chat_options.rs index 1832a38..eaf7222 100644 --- a/src/chat/chat_options.rs +++ b/src/chat/chat_options.rs @@ -163,6 +163,13 @@ impl ChatOptionsSet<'_, '_> { .or_else(|| self.client.and_then(|client| client.response_format.as_ref())) } + pub fn stop_sequences(&self) -> &[String] { + self.chat + .map(|chat| chat.stop_sequences.deref()) + .or_else(|| self.client.map(|client| client.stop_sequences.deref())) + .unwrap_or(&[]) + } + /// Returns true only if there is a ChatResponseFormat::JsonMode #[deprecated(note = "Use .response_format()")] #[allow(unused)] @@ -173,13 +180,6 @@ impl ChatOptionsSet<'_, '_> { _ => Some(false), } } - - pub fn stop_sequences(&self) -> &[String] { - self.chat - .map(|chat| chat.stop_sequences.deref()) - .or_else(|| self.client.map(|client| client.stop_sequences.deref())) - .unwrap_or(&[]) - } } // endregion: --- ChatOptionsSet diff --git a/tests/tests_p_groq.rs b/tests/tests_p_groq.rs index e8e1cde..0af015d 100644 --- a/tests/tests_p_groq.rs +++ b/tests/tests_p_groq.rs @@ -26,6 +26,11 @@ async fn test_chat_temperature_ok() -> Result<()> { common_tests::common_test_chat_temperature_ok(MODEL).await } +#[tokio::test] +async fn test_chat_stop_sequences_ok() -> Result<()> { + common_tests::common_test_chat_stop_sequences_ok(MODEL).await +} + // endregion: --- Chat // region: --- Chat Stream Tests diff --git a/tests/tests_p_ollama.rs b/tests/tests_p_ollama.rs index 91a7f2e..d8c78d8 100644 --- a/tests/tests_p_ollama.rs +++ b/tests/tests_p_ollama.rs @@ -24,6 +24,11 @@ async fn test_chat_temperature_ok() -> Result<()> { common_tests::common_test_chat_temperature_ok(MODEL).await } +#[tokio::test] +async fn test_chat_stop_sequences_ok() -> Result<()> { + common_tests::common_test_chat_stop_sequences_ok(MODEL).await +} + // endregion: --- Chat // region: --- Chat Stream Tests diff --git a/tests/tests_p_openai.rs b/tests/tests_p_openai.rs index 3c2850e..abba9af 100644 --- a/tests/tests_p_openai.rs +++ b/tests/tests_p_openai.rs @@ -29,6 +29,11 @@ async fn test_chat_temperature_ok() -> Result<()> { common_tests::common_test_chat_temperature_ok(MODEL).await } +#[tokio::test] +async fn test_chat_stop_sequences_ok() -> Result<()> { + common_tests::common_test_chat_stop_sequences_ok(MODEL).await +} + // endregion: --- Chat // region: --- Chat Stream Tests