test - anthropic - add test_chat_stop_sequences_ok

Jeremy Chone
2024-12-07 12:36:23 -08:00
parent 9eac22b64a
commit 8f00d665ba
3 changed files with 38 additions and 5 deletions

View File

@@ -75,6 +75,11 @@ impl ChatOptions {
 		self
 	}

+	pub fn with_stop_sequences(mut self, values: Vec<String>) -> Self {
+		self.stop_sequences = values;
+		self
+	}
+
 	/// Set the `json_mode` for this request.
 	///
 	/// IMPORTANT: This is deprecated now; use `with_response_format(ChatResponseFormat::JsonMode)`
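The new `with_stop_sequences` builder above follows the same consuming-`self` pattern as the other `ChatOptions` setters, so it chains cleanly. A minimal usage sketch, assuming the `genai::Client` / `genai::chat` paths used elsewhere in this repo (inside an async context):

	use genai::chat::{ChatOptions, ChatRequest};
	use genai::Client;

	let client = Client::default();
	let req = ChatRequest::from_user("Name the capital of England.");
	// Generation should stop right before "London" would be emitted.
	let opts = ChatOptions::default().with_stop_sequences(vec!["London".to_string()]);
	let res = client.exec_chat("claude-3-haiku-20240307", req, Some(&opts)).await?;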

View File

@@ -167,6 +167,29 @@ pub async fn common_test_chat_temperature_ok(model: &str) -> Result<()> {
 	Ok(())
 }

+pub async fn common_test_chat_stop_sequences_ok(model: &str) -> Result<()> {
+	// -- Setup & Fixtures
+	let client = Client::default();
+	let chat_req = ChatRequest::from_user("What is the capital of England?");
+	let chat_options = ChatOptions::default().with_stop_sequences(vec!["London".to_string()]);
+
+	// -- Exec
+	let chat_res = client.exec_chat(model, chat_req, Some(&chat_options)).await?;
+	let ai_content_lower = chat_res
+		.content_text_as_str()
+		.ok_or("Should have an AI response")?
+		.to_lowercase();
+
+	// -- Check
+	assert!(!ai_content_lower.is_empty(), "Content should not be empty");
+	assert!(
+		!ai_content_lower.contains("london"),
+		"Content should not contain 'London'"
+	);
+
+	Ok(())
+}
+
 // endregion: --- Chat

 // region: --- Chat Stream Tests
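For context on why the assertion works: Anthropic's Messages API takes a `stop_sequences` array of strings, and the model stops before any of them would appear, with the matched sequence excluded from the returned text. A rough sketch of the kind of request body the provider adapter has to produce (the adapter internals are not part of this diff; the field layout follows Anthropic's public API):

	let body = serde_json::json!({
		"model": "claude-3-haiku-20240307",
		"max_tokens": 1024,
		"messages": [{ "role": "user", "content": "What is the capital of England?" }],
		// The answer is cut off before "London", and the sequence itself
		// is not included in the response content.
		"stop_sequences": ["London"]
	});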

View File

@@ -6,7 +6,7 @@ use serial_test::serial;
 type Result<T> = core::result::Result<T, Box<dyn std::error::Error>>; // For tests.

-// 4k
+// 4k (cheaper)
 const MODEL: &str = "claude-3-haiku-20240307";

 // 8k output context
 // const MODEL: &str = "claude-3-5-haiku-20241022";
@@ -14,19 +14,25 @@ const MODEL: &str = "claude-3-haiku-20240307";
 // region: --- Chat

 #[tokio::test]
-// #[serial(anthropic)]
+#[serial(anthropic)]
 async fn test_chat_simple_ok() -> Result<()> {
 	common_tests::common_test_chat_simple_ok(MODEL).await
 }

 #[tokio::test]
-// #[serial(anthropic)]
+#[serial(anthropic)]
 async fn test_chat_temperature_ok() -> Result<()> {
 	common_tests::common_test_chat_temperature_ok(MODEL).await
 }

 #[tokio::test]
-// #[serial(anthropic)]
+#[serial(anthropic)]
+async fn test_chat_stop_sequences_ok() -> Result<()> {
+	common_tests::common_test_chat_stop_sequences_ok(MODEL).await
+}
+
+#[tokio::test]
+#[serial(anthropic)]
 async fn test_chat_json_mode_ok() -> Result<()> {
 	common_tests::common_test_chat_json_mode_ok(MODEL, true).await
 }
@@ -52,7 +58,6 @@ async fn test_chat_stream_capture_content_ok() -> Result<()> {
 async fn test_chat_stream_capture_all_ok() -> Result<()> {
 	common_tests::common_test_chat_stream_capture_all_ok(MODEL).await
 }
-
 // endregion: --- Chat Stream Tests

 // region: --- Tool Tests
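Note that this change also flips `#[serial(anthropic)]` on for the chat tests, so they run sequentially against the live API instead of in parallel, which helps avoid rate-limit flakiness. To exercise just the new case (assuming, as with the other tests here, a valid ANTHROPIC_API_KEY in the environment):

	cargo test test_chat_stop_sequences_ok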