diff --git a/tests/support/asserts.rs b/tests/support/asserts.rs new file mode 100644 index 0000000..2b75da8 --- /dev/null +++ b/tests/support/asserts.rs @@ -0,0 +1,115 @@ +use std::convert::Into; +use std::fmt::Formatter; + +pub fn assert_contains<'a, T>(data: T, val: &str) +where + T: Into>, +{ + let container: DataContainer = data.into(); + assert!( + container.contains(val), + "Should contain: {}\nBut was: {:?}", + val, + container + ); +} + +pub fn assert_not_contains<'a, T>(data: T, val: &str) +where + T: Into>, +{ + let container: DataContainer = data.into(); + assert!( + !container.contains(val), + "Should not contain: {}\nBut was: {:?}", + val, + container + ); +} + +// region: --- Support Types + +pub enum DataContainer<'a> { + Owned(Vec<&'a str>), + Slice(&'a [&'a str]), + Str(&'a str), +} + +impl std::fmt::Debug for DataContainer<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DataContainer::Owned(vec) => write!(f, "{:?}", vec), + DataContainer::Slice(slice) => write!(f, "{:?}", slice), + DataContainer::Str(s) => { + write!(f, "{s}") + } + } + } +} + +impl<'a> From<&'a [&'a str]> for DataContainer<'a> { + fn from(slice: &'a [&'a str]) -> Self { + DataContainer::Slice(slice) + } +} + +impl<'a> From<&'a Vec<&'a str>> for DataContainer<'a> { + fn from(vec: &'a Vec<&'a str>) -> Self { + DataContainer::Slice(&vec[..]) + } +} + +impl<'a> From<&'a Vec> for DataContainer<'a> { + fn from(vec: &'a Vec) -> Self { + DataContainer::Owned(vec.iter().map(|s| s.as_str()).collect()) + } +} + +impl<'a> From<&'a str> for DataContainer<'a> { + fn from(string: &'a str) -> Self { + DataContainer::Str(string) + } +} + +impl<'a> From<&'a String> for DataContainer<'a> { + fn from(string: &'a String) -> Self { + DataContainer::Str(string) + } +} + +impl<'a> DataContainer<'a> { + fn contains(&self, val: &str) -> bool { + match self { + DataContainer::Owned(vec) => vec.contains(&val), + DataContainer::Slice(slice) => slice.contains(&val), + DataContainer::Str(string) => string.contains(val), + } + } +} + +// endregion: --- Support Types + +// region: --- Tests + +#[cfg(test)] +mod tests { + type Result = core::result::Result>; // For tests. + + use super::*; + + #[test] + fn test_assert_contains() -> Result<()> { + let data_vec = vec!["apple", "banana", "cherry"]; + assert_contains(&data_vec, "banana"); + + let data_slice: &[&str] = &["dog", "cat", "mouse"]; + assert_contains(data_slice, "cat"); + + let data_str = "This is a test string"; + assert_contains(data_str, "test"); + + Ok(()) + } +} + +// endregion: --- Tests diff --git a/tests/support/common_tests.rs b/tests/support/common_tests.rs index 4ae165d..858f932 100644 --- a/tests/support/common_tests.rs +++ b/tests/support/common_tests.rs @@ -1,5 +1,6 @@ use crate::get_option_value; -use crate::support::{extract_stream_end, seed_chat_req_simple, seed_chat_req_tool_simple, Result}; +use crate::support::{assert_contains, extract_stream_end, seed_chat_req_simple, seed_chat_req_tool_simple, Result}; +use genai::adapter::AdapterKind; use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponseFormat, JsonSpec, Tool, ToolResponse}; use genai::resolver::{AuthData, AuthResolver, AuthResolverFn, IntoAuthResolverFn}; use genai::{Client, ClientConfig, ModelIden}; @@ -222,6 +223,7 @@ pub async fn common_test_chat_stop_sequences_ok(model: &str) -> Result<()> { Ok(()) } + // endregion: --- Chat // region: --- Chat Stream Tests @@ -401,3 +403,19 @@ pub async fn common_test_resolver_auth_ok(model: &str, auth_data: AuthData) -> R } // endregion: --- With Resolvers + +// region: --- List + +pub async fn common_test_list_models(adapter_kind: AdapterKind, contains: &str) -> Result<()> { + let client = Client::default(); + + // -- Exec + let models = client.all_model_names(adapter_kind).await?; + + // -- Check + assert_contains(&models, contains); + + Ok(()) +} + +// endregion: --- List diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 863ec36..bf63c66 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -5,9 +5,11 @@ // region: --- Modules +mod asserts; mod helpers; mod seeders; +pub use asserts::*; pub use helpers::*; pub use seeders::*; diff --git a/tests/tests_p_anthropic.rs b/tests/tests_p_anthropic.rs index 333a4ee..f718fdc 100644 --- a/tests/tests_p_anthropic.rs +++ b/tests/tests_p_anthropic.rs @@ -1,6 +1,7 @@ mod support; use crate::support::common_tests; +use genai::adapter::AdapterKind; use genai::resolver::AuthData; use serial_test::serial; @@ -90,3 +91,12 @@ async fn test_resolver_auth_ok() -> Result<()> { } // endregion: --- Resolver Tests + +// region: --- List + +#[tokio::test] +async fn test_list_models() -> Result<()> { + common_tests::common_test_list_models(AdapterKind::Anthropic, "claude-3-5-sonnet-20241022").await +} + +// endregion: --- List diff --git a/tests/tests_p_cohere.rs b/tests/tests_p_cohere.rs index b1722f2..6b33154 100644 --- a/tests/tests_p_cohere.rs +++ b/tests/tests_p_cohere.rs @@ -1,6 +1,7 @@ mod support; use crate::support::common_tests; +use genai::adapter::AdapterKind; use genai::resolver::AuthData; type Result = core::result::Result>; // For tests. @@ -58,3 +59,12 @@ async fn test_resolver_auth_ok() -> Result<()> { } // endregion: --- Resolver Tests + +// region: --- List + +#[tokio::test] +async fn test_list_models() -> Result<()> { + common_tests::common_test_list_models(AdapterKind::Cohere, "command-r-plus").await +} + +// endregion: --- List diff --git a/tests/tests_p_gemini.rs b/tests/tests_p_gemini.rs index 3aa7665..fcf81c0 100644 --- a/tests/tests_p_gemini.rs +++ b/tests/tests_p_gemini.rs @@ -1,6 +1,7 @@ mod support; use crate::support::common_tests; +use genai::adapter::AdapterKind; use genai::resolver::AuthData; type Result = core::result::Result>; // For tests. @@ -63,3 +64,12 @@ async fn test_resolver_auth_ok() -> Result<()> { } // endregion: --- Resolver Tests + +// region: --- List + +#[tokio::test] +async fn test_list_models() -> Result<()> { + common_tests::common_test_list_models(AdapterKind::Gemini, "gemini-1.5-pro").await +} + +// endregion: --- List diff --git a/tests/tests_p_groq.rs b/tests/tests_p_groq.rs index 131b525..8855cbf 100644 --- a/tests/tests_p_groq.rs +++ b/tests/tests_p_groq.rs @@ -1,6 +1,7 @@ mod support; use crate::support::common_tests; +use genai::adapter::AdapterKind; use genai::resolver::AuthData; type Result = core::result::Result>; // For tests. @@ -65,3 +66,12 @@ async fn test_resolver_auth_ok() -> Result<()> { } // endregion: --- Resolver Tests + +// region: --- List + +#[tokio::test] +async fn test_list_models() -> Result<()> { + common_tests::common_test_list_models(AdapterKind::Groq, "llama-3.1-70b-versatile").await +} + +// endregion: --- List diff --git a/tests/tests_p_ollama.rs b/tests/tests_p_ollama.rs index c13010f..9b15370 100644 --- a/tests/tests_p_ollama.rs +++ b/tests/tests_p_ollama.rs @@ -1,6 +1,7 @@ mod support; use crate::support::common_tests; +use genai::adapter::AdapterKind; use genai::resolver::AuthData; type Result = core::result::Result>; // For tests. @@ -65,3 +66,12 @@ async fn test_resolver_auth_ok() -> Result<()> { } // endregion: --- Resolver Tests + +// region: --- List + +#[tokio::test] +async fn test_list_models() -> Result<()> { + common_tests::common_test_list_models(AdapterKind::Ollama, "llama3.1:8b").await +} + +// endregion: --- List diff --git a/tests/tests_p_openai.rs b/tests/tests_p_openai.rs index d7bd159..bb67eb7 100644 --- a/tests/tests_p_openai.rs +++ b/tests/tests_p_openai.rs @@ -1,6 +1,7 @@ mod support; use crate::support::common_tests; +use genai::adapter::AdapterKind; use genai::resolver::AuthData; type Result = core::result::Result>; // For tests. @@ -81,3 +82,12 @@ async fn test_resolver_auth_ok() -> Result<()> { } // endregion: --- Resolver Tests + +// region: --- List + +#[tokio::test] +async fn test_list_models() -> Result<()> { + common_tests::common_test_list_models(AdapterKind::OpenAI, "gpt-4o").await +} + +// endregion: --- List