From d09241fba1d481e40aa8fd351a71966c942aed7b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 23 Jan 2023 15:38:14 +0100 Subject: [PATCH] Prevent using `from_pretrained` on invalid ids (better error message). (#1153) --- tokenizers/src/utils/from_pretrained.rs | 28 ++++++++++++++++++++++--- tokenizers/tests/from_pretrained.rs | 18 ++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/tokenizers/src/utils/from_pretrained.rs b/tokenizers/src/utils/from_pretrained.rs index 4df53dab..30ce41d8 100644 --- a/tokenizers/src/utils/from_pretrained.rs +++ b/tokenizers/src/utils/from_pretrained.rs @@ -102,9 +102,32 @@ pub fn from_pretrained>( identifier: S, params: Option, ) -> Result { + let identifier: &str = identifier.as_ref(); + + let is_valid_char = + |x: char| x.is_alphanumeric() || x == '-' || x == '_' || x == '.' || x == '/'; + + let valid = identifier.chars().all(is_valid_char); + if !valid { + return Err(format!( + "Model \"{}\" contains invalid characters, expected only alphanumeric or '/', '-', '_', '.'", + identifier + ) + .into()); + } let params = params.unwrap_or_default(); let cache_dir = ensure_cache_dir()?; + let revision = ¶ms.revision; + let valid_revision = revision.chars().all(is_valid_char); + if !valid_revision { + return Err(format!( + "Revision \"{}\" contains invalid characters, expected only alphanumeric or '/', '-', '_', '.'", + revision + ) + .into()); + } + // Build a custom HTTP Client using our user-agent and custom headers let mut headers = header::HeaderMap::new(); if let Some(ref token) = params.auth_token { @@ -124,14 +147,13 @@ pub fn from_pretrained>( let url_to_download = format!( "https://huggingface.co/{}/resolve/{}/tokenizer.json", - identifier.as_ref(), - params.revision, + identifier, revision, ); match cache.cached_path(&url_to_download) { Err(_) => Err(format!( "Model \"{}\" on the Hub doesn't have a tokenizer", - identifier.as_ref() + identifier ) .into()), Ok(path) => Ok(path), diff --git a/tokenizers/tests/from_pretrained.rs b/tokenizers/tests/from_pretrained.rs index a2e954e4..3036f35d 100644 --- a/tokenizers/tests/from_pretrained.rs +++ b/tokenizers/tests/from_pretrained.rs @@ -36,3 +36,21 @@ fn test_from_pretrained_revision() -> Result<()> { Ok(()) } + +#[test] +fn test_from_pretrained_invalid_model() { + let tokenizer = Tokenizer::from_pretrained("docs?", None); + assert!(tokenizer.is_err()); +} + +#[test] +fn test_from_pretrained_invalid_revision() { + let tokenizer = Tokenizer::from_pretrained( + "bert-base-cased", + Some(FromPretrainedParameters { + revision: "gpt?".to_string(), + ..Default::default() + }), + ); + assert!(tokenizer.is_err()); +}