mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Prevent using from_pretrained
on invalid ids (better error message). (#1153)
This commit is contained in:
@ -102,9 +102,32 @@ pub fn from_pretrained<S: AsRef<str>>(
|
||||
identifier: S,
|
||||
params: Option<FromPretrainedParameters>,
|
||||
) -> Result<PathBuf> {
|
||||
let identifier: &str = identifier.as_ref();
|
||||
|
||||
let is_valid_char =
|
||||
|x: char| x.is_alphanumeric() || x == '-' || x == '_' || x == '.' || x == '/';
|
||||
|
||||
let valid = identifier.chars().all(is_valid_char);
|
||||
if !valid {
|
||||
return Err(format!(
|
||||
"Model \"{}\" contains invalid characters, expected only alphanumeric or '/', '-', '_', '.'",
|
||||
identifier
|
||||
)
|
||||
.into());
|
||||
}
|
||||
let params = params.unwrap_or_default();
|
||||
let cache_dir = ensure_cache_dir()?;
|
||||
|
||||
let revision = &params.revision;
|
||||
let valid_revision = revision.chars().all(is_valid_char);
|
||||
if !valid_revision {
|
||||
return Err(format!(
|
||||
"Revision \"{}\" contains invalid characters, expected only alphanumeric or '/', '-', '_', '.'",
|
||||
revision
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
// Build a custom HTTP Client using our user-agent and custom headers
|
||||
let mut headers = header::HeaderMap::new();
|
||||
if let Some(ref token) = params.auth_token {
|
||||
@ -124,14 +147,13 @@ pub fn from_pretrained<S: AsRef<str>>(
|
||||
|
||||
let url_to_download = format!(
|
||||
"https://huggingface.co/{}/resolve/{}/tokenizer.json",
|
||||
identifier.as_ref(),
|
||||
params.revision,
|
||||
identifier, revision,
|
||||
);
|
||||
|
||||
match cache.cached_path(&url_to_download) {
|
||||
Err(_) => Err(format!(
|
||||
"Model \"{}\" on the Hub doesn't have a tokenizer",
|
||||
identifier.as_ref()
|
||||
identifier
|
||||
)
|
||||
.into()),
|
||||
Ok(path) => Ok(path),
|
||||
|
@ -36,3 +36,21 @@ fn test_from_pretrained_revision() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_pretrained_invalid_model() {
|
||||
let tokenizer = Tokenizer::from_pretrained("docs?", None);
|
||||
assert!(tokenizer.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_pretrained_invalid_revision() {
|
||||
let tokenizer = Tokenizer::from_pretrained(
|
||||
"bert-base-cased",
|
||||
Some(FromPretrainedParameters {
|
||||
revision: "gpt?".to_string(),
|
||||
..Default::default()
|
||||
}),
|
||||
);
|
||||
assert!(tokenizer.is_err());
|
||||
}
|
||||
|
Reference in New Issue
Block a user