Prevent using from_pretrained on invalid ids (better error message). (#1153)

Nicolas Patry
2023-01-23 15:38:14 +01:00
committed by GitHub
parent b861d48b06
commit d09241fba1
2 changed files with 43 additions and 3 deletions
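
For context: the fix whitelists the characters of the model id (and revision) before any network request is made. A minimal standalone sketch of that check (the closure mirrors the one added in the diff below; the `main` wrapper and sample ids are illustrative only):

fn main() {
    // Same whitelist as the diff: alphanumerics plus '-', '_', '.', '/'.
    let is_valid_char =
        |x: char| x.is_alphanumeric() || x == '-' || x == '_' || x == '.' || x == '/';

    for id in ["bert-base-cased", "user/model", "docs?", "gpt?"] {
        let valid = id.chars().all(is_valid_char);
        println!("{:?} -> {}", id, if valid { "accepted" } else { "rejected" });
    }
}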


@@ -102,9 +102,32 @@ pub fn from_pretrained<S: AsRef<str>>(
     identifier: S,
     params: Option<FromPretrainedParameters>,
 ) -> Result<PathBuf> {
+    let identifier: &str = identifier.as_ref();
+    let is_valid_char =
+        |x: char| x.is_alphanumeric() || x == '-' || x == '_' || x == '.' || x == '/';
+    let valid = identifier.chars().all(is_valid_char);
+    if !valid {
+        return Err(format!(
+            "Model \"{}\" contains invalid characters, expected only alphanumeric or '/', '-', '_', '.'",
+            identifier
+        )
+        .into());
+    }
     let params = params.unwrap_or_default();
     let cache_dir = ensure_cache_dir()?;
+    let revision = &params.revision;
+    let valid_revision = revision.chars().all(is_valid_char);
+    if !valid_revision {
+        return Err(format!(
+            "Revision \"{}\" contains invalid characters, expected only alphanumeric or '/', '-', '_', '.'",
+            revision
+        )
+        .into());
+    }
     // Build a custom HTTP Client using our user-agent and custom headers
     let mut headers = header::HeaderMap::new();
     if let Some(ref token) = params.auth_token {
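
An aside on the `format!(...).into()` pattern used above: the function returns the crate's `Result` alias, whose error side is a boxed error, so a plain `String` converts into it via std's blanket `From<String> for Box<dyn Error + Send + Sync>` impl. A hedged sketch of the mechanism (the `Result` alias here is an assumption written to resemble the crate's, not copied from it):

use std::error::Error;

// Assumption: the crate aliases Result over a boxed error roughly like this.
type Result<T> = std::result::Result<T, Box<dyn Error + Send + Sync>>;

fn check(identifier: &str) -> Result<()> {
    if identifier.contains('?') {
        // String -> Box<dyn Error + Send + Sync> is a std blanket impl;
        // that conversion is what `.into()` performs in the diff.
        return Err(format!("Model \"{}\" contains invalid characters", identifier).into());
    }
    Ok(())
}

fn main() {
    assert!(check("docs?").is_err());
    assert!(check("bert-base-cased").is_ok());
}
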
@@ -124,14 +147,13 @@ pub fn from_pretrained<S: AsRef<str>>(
     let url_to_download = format!(
         "https://huggingface.co/{}/resolve/{}/tokenizer.json",
-        identifier.as_ref(),
-        params.revision,
+        identifier, revision,
     );
     match cache.cached_path(&url_to_download) {
         Err(_) => Err(format!(
             "Model \"{}\" on the Hub doesn't have a tokenizer",
-            identifier.as_ref()
+            identifier
         )
         .into()),
         Ok(path) => Ok(path),
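
Why the early check improves the error message: an unvalidated id is interpolated straight into the resolve URL above, so an id like "docs?" injects a query string mid-path and the download can only fail with the vaguer "doesn't have a tokenizer" error. A quick illustration using the same format string as the diff (pure string formatting, no network involved):

fn main() {
    let identifier = "docs?";
    let revision = "main";
    // Same format string as the diff; the '?' starts a query string mid-URL.
    let url = format!(
        "https://huggingface.co/{}/resolve/{}/tokenizer.json",
        identifier, revision,
    );
    println!("{}", url); // https://huggingface.co/docs?/resolve/main/tokenizer.json
}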


@@ -36,3 +36,21 @@ fn test_from_pretrained_revision() -> Result<()> {
     Ok(())
 }
+
+#[test]
+fn test_from_pretrained_invalid_model() {
+    let tokenizer = Tokenizer::from_pretrained("docs?", None);
+    assert!(tokenizer.is_err());
+}
+
+#[test]
+fn test_from_pretrained_invalid_revision() {
+    let tokenizer = Tokenizer::from_pretrained(
+        "bert-base-cased",
+        Some(FromPretrainedParameters {
+            revision: "gpt?".to_string(),
+            ..Default::default()
+        }),
+    );
+    assert!(tokenizer.is_err());
+}
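
From the caller's side the rejection now happens up front, with no HTTP round trip. A usage sketch (assumes the `tokenizers` crate built with `from_pretrained` support, which I believe sits behind its `http` feature):

use tokenizers::Tokenizer;

fn main() {
    match Tokenizer::from_pretrained("docs?", None) {
        Ok(_) => println!("loaded"),
        // Prints something like: Model "docs?" contains invalid characters,
        // expected only alphanumeric or '/', '-', '_', '.'
        Err(e) => eprintln!("{}", e),
    }
}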