mirror of
https://github.com/mii443/rust-genai.git
synced 2025-08-23 16:49:20 +00:00
+ openai - add stop_sequences support
This commit is contained in:
@ -182,15 +182,19 @@ impl OpenAIAdapter {
|
|||||||
payload["response_format"] = response_format;
|
payload["response_format"] = response_format;
|
||||||
}
|
}
|
||||||
|
|
||||||
// --
|
// -- Add supported ChatOptions
|
||||||
if stream & options_set.capture_usage().unwrap_or(false) {
|
if stream & options_set.capture_usage().unwrap_or(false) {
|
||||||
payload.x_insert("stream_options", json!({"include_usage": true}))?;
|
payload.x_insert("stream_options", json!({"include_usage": true}))?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- Add supported ChatOptions
|
|
||||||
if let Some(temperature) = options_set.temperature() {
|
if let Some(temperature) = options_set.temperature() {
|
||||||
payload.x_insert("temperature", temperature)?;
|
payload.x_insert("temperature", temperature)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !options_set.stop_sequences().is_empty() {
|
||||||
|
payload.x_insert("stop", options_set.stop_sequences())?;
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(max_tokens) = options_set.max_tokens() {
|
if let Some(max_tokens) = options_set.max_tokens() {
|
||||||
payload.x_insert("max_tokens", max_tokens)?;
|
payload.x_insert("max_tokens", max_tokens)?;
|
||||||
}
|
}
|
||||||
|
@ -163,6 +163,13 @@ impl ChatOptionsSet<'_, '_> {
|
|||||||
.or_else(|| self.client.and_then(|client| client.response_format.as_ref()))
|
.or_else(|| self.client.and_then(|client| client.response_format.as_ref()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn stop_sequences(&self) -> &[String] {
|
||||||
|
self.chat
|
||||||
|
.map(|chat| chat.stop_sequences.deref())
|
||||||
|
.or_else(|| self.client.map(|client| client.stop_sequences.deref()))
|
||||||
|
.unwrap_or(&[])
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns true only if there is a ChatResponseFormat::JsonMode
|
/// Returns true only if there is a ChatResponseFormat::JsonMode
|
||||||
#[deprecated(note = "Use .response_format()")]
|
#[deprecated(note = "Use .response_format()")]
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
@ -173,13 +180,6 @@ impl ChatOptionsSet<'_, '_> {
|
|||||||
_ => Some(false),
|
_ => Some(false),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn stop_sequences(&self) -> &[String] {
|
|
||||||
self.chat
|
|
||||||
.map(|chat| chat.stop_sequences.deref())
|
|
||||||
.or_else(|| self.client.map(|client| client.stop_sequences.deref()))
|
|
||||||
.unwrap_or(&[])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// endregion: --- ChatOptionsSet
|
// endregion: --- ChatOptionsSet
|
||||||
|
@ -26,6 +26,11 @@ async fn test_chat_temperature_ok() -> Result<()> {
|
|||||||
common_tests::common_test_chat_temperature_ok(MODEL).await
|
common_tests::common_test_chat_temperature_ok(MODEL).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_stop_sequences_ok() -> Result<()> {
|
||||||
|
common_tests::common_test_chat_stop_sequences_ok(MODEL).await
|
||||||
|
}
|
||||||
|
|
||||||
// endregion: --- Chat
|
// endregion: --- Chat
|
||||||
|
|
||||||
// region: --- Chat Stream Tests
|
// region: --- Chat Stream Tests
|
||||||
|
@ -24,6 +24,11 @@ async fn test_chat_temperature_ok() -> Result<()> {
|
|||||||
common_tests::common_test_chat_temperature_ok(MODEL).await
|
common_tests::common_test_chat_temperature_ok(MODEL).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_stop_sequences_ok() -> Result<()> {
|
||||||
|
common_tests::common_test_chat_stop_sequences_ok(MODEL).await
|
||||||
|
}
|
||||||
|
|
||||||
// endregion: --- Chat
|
// endregion: --- Chat
|
||||||
|
|
||||||
// region: --- Chat Stream Tests
|
// region: --- Chat Stream Tests
|
||||||
|
@ -29,6 +29,11 @@ async fn test_chat_temperature_ok() -> Result<()> {
|
|||||||
common_tests::common_test_chat_temperature_ok(MODEL).await
|
common_tests::common_test_chat_temperature_ok(MODEL).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_chat_stop_sequences_ok() -> Result<()> {
|
||||||
|
common_tests::common_test_chat_stop_sequences_ok(MODEL).await
|
||||||
|
}
|
||||||
|
|
||||||
// endregion: --- Chat
|
// endregion: --- Chat
|
||||||
|
|
||||||
// region: --- Chat Stream Tests
|
// region: --- Chat Stream Tests
|
||||||
|
Reference in New Issue
Block a user