Prevent using from_pretrained on invalid ids (better error message). (#1153)
@@ -102,9 +102,32 @@ pub fn from_pretrained<S: AsRef<str>>(
     identifier: S,
     params: Option<FromPretrainedParameters>,
 ) -> Result<PathBuf> {
+    let identifier: &str = identifier.as_ref();
+
+    let is_valid_char =
+        |x: char| x.is_alphanumeric() || x == '-' || x == '_' || x == '.' || x == '/';
+
+    let valid = identifier.chars().all(is_valid_char);
+    if !valid {
+        return Err(format!(
+            "Model \"{}\" contains invalid characters, expected only alphanumeric or '/', '-', '_', '.'",
+            identifier
+        )
+        .into());
+    }
     let params = params.unwrap_or_default();
     let cache_dir = ensure_cache_dir()?;
 
+    let revision = &params.revision;
+    let valid_revision = revision.chars().all(is_valid_char);
+    if !valid_revision {
+        return Err(format!(
+            "Revision \"{}\" contains invalid characters, expected only alphanumeric or '/', '-', '_', '.'",
+            revision
+        )
+        .into());
+    }
+
     // Build a custom HTTP Client using our user-agent and custom headers
     let mut headers = header::HeaderMap::new();
     if let Some(ref token) = params.auth_token {
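For illustration, here is a minimal standalone sketch of the character rule the hunk above introduces. The free function `is_valid_char` below is written only for this example and is not an API exported by the tokenizers crate, which implements the rule as a local closure inside from_pretrained:

// Sketch of the validation rule from the hunk above (illustrative only).
fn is_valid_char(x: char) -> bool {
    x.is_alphanumeric() || x == '-' || x == '_' || x == '.' || x == '/'
}

fn main() {
    // Typical Hub identifiers only use the allowed characters and pass.
    assert!("bert-base-cased".chars().all(is_valid_char));
    assert!("EleutherAI/gpt-neo-1.3B".chars().all(is_valid_char));
    // A stray character such as '?' causes the identifier to be rejected
    // before any network request is attempted.
    assert!(!"docs?".chars().all(is_valid_char));
}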
@@ -124,14 +147,13 @@ pub fn from_pretrained<S: AsRef<str>>(
 
     let url_to_download = format!(
         "https://huggingface.co/{}/resolve/{}/tokenizer.json",
-        identifier.as_ref(),
-        params.revision,
+        identifier, revision,
     );
 
     match cache.cached_path(&url_to_download) {
         Err(_) => Err(format!(
             "Model \"{}\" on the Hub doesn't have a tokenizer",
-            identifier.as_ref()
+            identifier
         )
         .into()),
         Ok(path) => Ok(path),
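As a quick sketch of the download URL the function builds from the already-validated identifier and revision: the format string comes from the hunk above, while the helper name tokenizer_url and the "main" default revision are assumptions made for this example, not part of the crate's API.

// Illustrative helper mirroring the format! call in the hunk above.
fn tokenizer_url(identifier: &str, revision: &str) -> String {
    format!(
        "https://huggingface.co/{}/resolve/{}/tokenizer.json",
        identifier, revision
    )
}

fn main() {
    // Assuming the default revision is "main", a valid identifier resolves
    // to the tokenizer.json file hosted on the Hub.
    assert_eq!(
        tokenizer_url("bert-base-cased", "main"),
        "https://huggingface.co/bert-base-cased/resolve/main/tokenizer.json"
    );
}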
@@ -36,3 +36,21 @@ fn test_from_pretrained_revision() -> Result<()> {
 
     Ok(())
 }
+
+#[test]
+fn test_from_pretrained_invalid_model() {
+    let tokenizer = Tokenizer::from_pretrained("docs?", None);
+    assert!(tokenizer.is_err());
+}
+
+#[test]
+fn test_from_pretrained_invalid_revision() {
+    let tokenizer = Tokenizer::from_pretrained(
+        "bert-base-cased",
+        Some(FromPretrainedParameters {
+            revision: "gpt?".to_string(),
+            ..Default::default()
+        }),
+    );
+    assert!(tokenizer.is_err());
+}