+ openai - add stop_sequences support

Jeremy Chone
2024-12-07 12:43:33 -08:00
parent 8f00d665ba
commit 2ddc2d2023
5 changed files with 28 additions and 9 deletions
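
For context, a minimal usage sketch of what this commit enables for callers. The stop_sequences field on ChatOptions is what the diff below reads; the with_stop_sequences builder, the model name, and the exact exec_chat call shape are assumptions for illustration, not something this commit shows.

use genai::chat::{ChatMessage, ChatOptions, ChatRequest};
use genai::Client;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::default();

    let chat_req = ChatRequest::new(vec![ChatMessage::user(
        "Count from 1 to 10, one number per line.",
    )]);

    // Assumption: a `with_stop_sequences` builder exists on ChatOptions;
    // the diff only shows the `stop_sequences` field being read.
    let options = ChatOptions::default().with_stop_sequences(vec!["7".to_string()]);

    // "gpt-4o-mini" is just an illustrative model name.
    let res = client.exec_chat("gpt-4o-mini", chat_req, Some(&options)).await?;
    println!("{}", res.content_text_as_str().unwrap_or(""));

    Ok(())
}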


@@ -182,15 +182,19 @@ impl OpenAIAdapter {
        payload["response_format"] = response_format;
    }
    // --
    // -- Add supported ChatOptions
    if stream && options_set.capture_usage().unwrap_or(false) {
        payload.x_insert("stream_options", json!({"include_usage": true}))?;
    }
    // -- Add supported ChatOptions
    if let Some(temperature) = options_set.temperature() {
        payload.x_insert("temperature", temperature)?;
    }
    if !options_set.stop_sequences().is_empty() {
        payload.x_insert("stop", options_set.stop_sequences())?;
    }
    if let Some(max_tokens) = options_set.max_tokens() {
        payload.x_insert("max_tokens", max_tokens)?;
    }
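
The net effect of the new branch above is an extra "stop" array in the JSON body sent to OpenAI's chat completions endpoint (OpenAI accepts either a string or an array of strings for that parameter). A standalone sketch of the equivalent payload construction using plain serde_json, since x_insert is a crate-internal helper:

use serde_json::{json, Value};

fn main() {
    let stop_sequences = vec!["###".to_string(), "END".to_string()]; // illustrative values
    let mut payload: Value = json!({
        "model": "gpt-4o-mini", // illustrative model name
        "messages": [{ "role": "user", "content": "Say hello" }],
    });

    // Mirrors `payload.x_insert("stop", options_set.stop_sequences())?`
    // when the list is non-empty.
    if !stop_sequences.is_empty() {
        payload["stop"] = json!(stop_sequences);
    }

    println!("{}", serde_json::to_string_pretty(&payload).unwrap());
}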


@@ -163,6 +163,13 @@ impl ChatOptionsSet<'_, '_> {
            .or_else(|| self.client.and_then(|client| client.response_format.as_ref()))
    }
    pub fn stop_sequences(&self) -> &[String] {
        self.chat
            .map(|chat| chat.stop_sequences.deref())
            .or_else(|| self.client.map(|client| client.stop_sequences.deref()))
            .unwrap_or(&[])
    }
    /// Returns true only if there is a ChatResponseFormat::JsonMode
    #[deprecated(note = "Use .response_format()")]
    #[allow(unused)]
@@ -173,13 +180,6 @@ impl ChatOptionsSet<'_, '_> {
            _ => Some(false),
        }
    }
    pub fn stop_sequences(&self) -> &[String] {
        self.chat
            .map(|chat| chat.stop_sequences.deref())
            .or_else(|| self.client.map(|client| client.stop_sequences.deref()))
            .unwrap_or(&[])
    }
}
// endregion: --- ChatOptionsSet
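
The accessor follows the same precedence as the other getters in ChatOptionsSet: request-level chat options win, then the client-level defaults, then an empty slice. A self-contained sketch of that resolution with simplified stand-in structs (the real type holds references with lifetimes, and Deref on Vec<String> yields the &[String]):

use std::ops::Deref;

// Simplified stand-ins for the crate's ChatOptions / ChatOptionsSet.
struct Options {
    stop_sequences: Vec<String>,
}

struct OptionsSet<'a> {
    chat: Option<&'a Options>,
    client: Option<&'a Options>,
}

impl OptionsSet<'_> {
    // Mirrors the method added in the diff: chat-level wins, then client, then empty.
    fn stop_sequences(&self) -> &[String] {
        self.chat
            .map(|chat| chat.stop_sequences.deref())
            .or_else(|| self.client.map(|client| client.stop_sequences.deref()))
            .unwrap_or(&[])
    }
}

fn main() {
    let client_opts = Options { stop_sequences: vec!["CLIENT_STOP".to_string()] };
    let chat_opts = Options { stop_sequences: vec!["CHAT_STOP".to_string()] };

    // Request-level options take precedence over the client default.
    let set = OptionsSet { chat: Some(&chat_opts), client: Some(&client_opts) };
    println!("{:?}", set.stop_sequences()); // ["CHAT_STOP"]

    // Falls back to the client default when the request has none.
    let set = OptionsSet { chat: None, client: Some(&client_opts) };
    println!("{:?}", set.stop_sequences()); // ["CLIENT_STOP"]

    // Empty slice when neither level sets it.
    let set = OptionsSet { chat: None, client: None };
    println!("{:?}", set.stop_sequences()); // []
}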


@@ -26,6 +26,11 @@ async fn test_chat_temperature_ok() -> Result<()> {
    common_tests::common_test_chat_temperature_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stop_sequences_ok() -> Result<()> {
    common_tests::common_test_chat_stop_sequences_ok(MODEL).await
}
// endregion: --- Chat
// region: --- Chat Stream Tests
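
The same one-line test is added to three provider test files in this commit (this one and the two below); they differ only in their MODEL constant and all delegate to a shared helper. The helper's body is not part of this diff; a hypothetical sketch of what such a check could look like, reusing the API assumed in the usage sketch above (with_stop_sequences in particular):

use genai::chat::{ChatMessage, ChatOptions, ChatRequest};
use genai::Client;

// Hypothetical shape of the shared helper; the real common_tests implementation
// is not shown in this commit.
pub async fn common_test_chat_stop_sequences_ok(model: &str) -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::default();

    // Ask for output that would normally contain the stop token "3".
    let chat_req = ChatRequest::new(vec![ChatMessage::user(
        "Count from 1 to 5, one number per line. Just the numbers.",
    )]);
    let options = ChatOptions::default().with_stop_sequences(vec!["3".to_string()]);

    let res = client.exec_chat(model, chat_req, Some(&options)).await?;
    let content = res.content_text_as_str().unwrap_or_default();

    // With the stop sequence honored, generation should halt before emitting "3".
    assert!(!content.contains('3'), "stop sequence was not applied: {content}");

    Ok(())
}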


@@ -24,6 +24,11 @@ async fn test_chat_temperature_ok() -> Result<()> {
    common_tests::common_test_chat_temperature_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stop_sequences_ok() -> Result<()> {
    common_tests::common_test_chat_stop_sequences_ok(MODEL).await
}
// endregion: --- Chat
// region: --- Chat Stream Tests


@@ -29,6 +29,11 @@ async fn test_chat_temperature_ok() -> Result<()> {
    common_tests::common_test_chat_temperature_ok(MODEL).await
}
#[tokio::test]
async fn test_chat_stop_sequences_ok() -> Result<()> {
    common_tests::common_test_chat_stop_sequences_ok(MODEL).await
}
// endregion: --- Chat
// region: --- Chat Stream Tests