Fixing bad deserialization following inclusion of a default for Punctuation. (#884)

* Fixing bad deserialization following inclusion of a default for
`Punctuation`.

* don't remove the type now...

* Adding slow test to run on all the tokenizers of the hub.

* `PartialEq` everywhere.

* Forcing `type` to exist on the `pre_tokenizers`.
Nicolas Patry
2022-01-17 22:28:25 +01:00
committed by GitHub
parent c2fd765087
commit 1a84958cc8
14 changed files with 323 additions and 54 deletions
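Most of the Rust diffs below apply the same pattern to each pre-tokenizer: the derived `Deserialize` is replaced by a hand-written impl that goes through a private helper struct whose `type` field is a single-variant enum, so a payload with a missing or wrong `type` is rejected instead of silently matching. A minimal, self-contained sketch of that pattern, using a hypothetical `Foo` pre-tokenizer rather than one of the crate's types:

use serde::{Deserialize, Deserializer, Serialize};

// `Foo` is a hypothetical pre-tokenizer used only to illustrate the pattern; the real
// implementations (Digits, Punctuation, Metaspace, ...) are in the diffs below.
#[derive(Serialize, Copy, Clone, Debug, PartialEq)]
#[serde(tag = "type")]
pub struct Foo {
    pub flag: bool,
}

impl<'de> Deserialize<'de> for Foo {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // A single-variant enum makes `"type": "Foo"` mandatory: a missing field or
        // any other value is a deserialization error.
        #[derive(Deserialize)]
        enum Type {
            Foo,
        }

        #[derive(Deserialize)]
        struct FooHelper {
            #[serde(rename = "type")]
            _type: Type,
            flag: bool,
        }

        let helper = FooHelper::deserialize(deserializer)?;
        Ok(Foo { flag: helper.flag })
    }
}

fn main() {
    let ok: Foo = serde_json::from_str(r#"{"type":"Foo","flag":true}"#).unwrap();
    assert_eq!(ok, Foo { flag: true });
    // Both of these are rejected: the `type` field is now required and checked.
    assert!(serde_json::from_str::<Foo>(r#"{"flag":true}"#).is_err());
    assert!(serde_json::from_str::<Foo>(r#"{"type":"Bar","flag":true}"#).is_err());
}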

View File

@@ -1668,7 +1668,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]]
name = "tokenizers"
version = "0.11.0"
version = "0.11.1"
dependencies = [
"aho-corasick",
"cached-path",

View File

@@ -1,5 +1,10 @@
from tokenizers import Tokenizer, models, normalizers
from tokenizers import Tokenizer
import os
import unittest
from .utils import data_dir, albert_base
import json
from huggingface_hub import HfApi, hf_hub_url, cached_download
import tqdm
class TestSerialization:
@@ -8,3 +13,70 @@ class TestSerialization:
# This used to fail because of BufReader that would fail because the
# file exceeds the buffer capacity
tokenizer = Tokenizer.from_file(albert_base)
def check(tokenizer_file) -> bool:
with open(tokenizer_file, "r") as f:
data = json.load(f)
if "pre_tokenizer" not in data:
return True
if "type" not in data["pre_tokenizer"]:
return False
if data["pre_tokenizer"]["type"] == "Sequence":
for pre_tok in data["pre_tokenizer"]["pretokenizers"]:
if "type" not in pre_tok:
return False
return True
def slow(test_case):
"""
Decorator marking a test as slow.
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
"""
if os.getenv("RUN_SLOW") != "1":
return unittest.skip("use `RUN_SLOW=1` to run")(test_case)
else:
return test_case
@slow
class TestFullDeserialization(unittest.TestCase):
def test_full_deserialization_hub(self):
# Check we can read this file.
# This used to fail because of BufReader that would fail because the
# file exceeds the buffer capacity
api = HfApi()
not_loadable = []
invalid_pre_tokenizer = []
# models = api.list_models(filter="transformers")
# for model in tqdm.tqdm(models):
# model_id = model.modelId
# for model_file in model.siblings:
# filename = model_file.rfilename
# if filename == "tokenizer.json":
# all_models.append((model_id, filename))
all_models = [("HueyNemud/das22-10-camembert_pretrained", "tokenizer.json")]
for (model_id, filename) in tqdm.tqdm(all_models):
tokenizer_file = cached_download(hf_hub_url(model_id, filename=filename))
is_ok = check(tokenizer_file)
if not is_ok:
print(f"{model_id} is affected by no type")
invalid_pre_tokenizer.append(model_id)
try:
Tokenizer.from_file(tokenizer_file)
except Exception as e:
print(f"{model_id} is not loadable: {e}")
not_loadable.append(model_id)
except:
print(f"{model_id} is not loadable: Rust error")
not_loadable.append(model_id)
self.assertEqual(invalid_pre_tokenizer, [])
self.assertEqual(not_loadable, [])

View File

@@ -5,7 +5,7 @@ fn is_bert_punc(x: char) -> bool {
char::is_ascii_punctuation(&x) || x.is_punctuation()
}
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct BertPreTokenizer;
impl_serde_unit_struct!(BertVisitor, BertPreTokenizer);

View File

@ -1,7 +1,7 @@
use std::collections::{HashMap, HashSet};
use onig::Regex;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::{
Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
@@ -40,7 +40,7 @@ lazy_static! {
bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
}
#[derive(Deserialize, Serialize, Copy, Clone, Debug, PartialEq)]
#[derive(Serialize, Copy, Clone, Debug, PartialEq)]
/// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care
/// of all the required processing steps to transform a UTF-8 string as needed before and after the
/// BPE model does its job.
@@ -53,6 +53,30 @@ pub struct ByteLevel {
/// Whether the post processing step should trim offsets to avoid including whitespaces.
pub trim_offsets: bool,
}
impl<'de> Deserialize<'de> for ByteLevel {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
enum Type {
ByteLevel,
}
#[derive(Deserialize)]
pub struct ByteLevelHelper {
#[serde(rename = "type")]
_type: Type,
add_prefix_space: bool,
trim_offsets: bool,
}
let helper = ByteLevelHelper::deserialize(deserializer)?;
Ok(ByteLevel::new(helper.add_prefix_space, helper.trim_offsets))
}
}
impl Default for ByteLevel {
fn default() -> Self {
Self {

View File

@@ -1,14 +1,36 @@
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
#[derive(Copy, Clone, Debug, Serialize, PartialEq)]
#[serde(tag = "type")]
#[non_exhaustive]
pub struct CharDelimiterSplit {
pub delimiter: char,
}
impl<'de> Deserialize<'de> for CharDelimiterSplit {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
enum Type {
CharDelimiterSplit,
}
#[derive(Deserialize)]
pub struct CharDelimiterSplitHelper {
#[serde(rename = "type")]
_type: Type,
delimiter: char,
}
let helper = CharDelimiterSplitHelper::deserialize(deserializer)?;
Ok(CharDelimiterSplit::new(helper.delimiter))
}
}
impl CharDelimiterSplit {
pub fn new(delimiter: char) -> Self {
CharDelimiterSplit { delimiter }

View File

@@ -1,8 +1,8 @@
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Serialize, Clone, Debug, PartialEq)]
/// Pre tokenizes the numbers into single tokens. If individual_digits is set
to true, then all digits are split into individual tokens.
#[serde(tag = "type")]
@@ -11,6 +11,28 @@ pub struct Digits {
pub individual_digits: bool,
}
impl<'de> Deserialize<'de> for Digits {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
enum Type {
Digits,
}
#[derive(Deserialize)]
pub struct DigitsHelper {
#[serde(rename = "type")]
_type: Type,
individual_digits: bool,
}
let helper = DigitsHelper::deserialize(deserializer)?;
Ok(Digits::new(helper.individual_digits))
}
}
impl Digits {
pub fn new(individual_digits: bool) -> Self {
Self { individual_digits }

View File

@@ -1,11 +1,11 @@
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize)]
/// Replaces all the whitespaces by the provided meta character and then
/// splits on this character
#[serde(tag = "type", from = "MetaspaceDeserializer")]
#[serde(tag = "type")]
pub struct Metaspace {
replacement: char,
pub add_prefix_space: bool,
@@ -13,17 +13,28 @@ pub struct Metaspace {
str_rep: String,
}
#[doc(hidden)]
#[derive(Deserialize)]
#[serde(tag = "type")]
pub struct MetaspaceDeserializer {
replacement: char,
add_prefix_space: bool,
}
impl<'de> Deserialize<'de> for Metaspace {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
enum Type {
Metaspace,
}
impl From<MetaspaceDeserializer> for Metaspace {
fn from(v: MetaspaceDeserializer) -> Metaspace {
Metaspace::new(v.replacement, v.add_prefix_space)
#[derive(Deserialize)]
pub struct MetaspaceHelper {
#[serde(rename = "type")]
_type: Type,
replacement: char,
pub add_prefix_space: bool,
#[serde(skip, rename = "str_rep")]
_str_rep: String,
}
let helper = MetaspaceHelper::deserialize(deserializer)?;
Ok(Metaspace::new(helper.replacement, helper.add_prefix_space))
}
}
@@ -109,6 +120,12 @@ mod tests {
serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
metaspace
);
let metaspace_parsed: Metaspace = serde_json::from_str(
r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true}"#,
)
.unwrap();
assert_eq!(metaspace_parsed, metaspace);
}
#[test]

View File

@@ -23,7 +23,7 @@ use crate::pre_tokenizers::unicode_scripts::UnicodeScripts;
use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
use crate::{PreTokenizedString, PreTokenizer};
#[derive(Serialize, Deserialize, Clone, Debug)]
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
#[serde(untagged)]
pub enum PreTokenizerWrapper {
BertPreTokenizer(BertPreTokenizer),
@@ -68,3 +68,51 @@ impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace);
impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit);
impl_enum_from!(Digits, PreTokenizerWrapper, Digits);
impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_deserialize() {
let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","str_rep":"▁","add_prefix_space":true}]}"#).unwrap();
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Sequence(Sequence::new(vec![
PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
]))
);
let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true}"#,
)
.unwrap();
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
);
let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#).unwrap();
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::Sequence(Sequence::new(vec![
PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
]))
);
}
#[test]
fn test_deserialize_whitespace_split() {
let pre_tokenizer: PreTokenizerWrapper =
serde_json::from_str(r#"{"type":"WhitespaceSplit"}"#).unwrap();
assert_eq!(
pre_tokenizer,
PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {})
);
}
}
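Since `PreTokenizerWrapper` stays `#[serde(untagged)]`, serde resolves it by trying each variant in declaration order and returning the first one that deserializes successfully; that is why every variant now has to validate its own `type` rather than rely on field defaults. A reduced illustration of the hazard, using hypothetical `Permissive`/`Strict` types rather than the crate's pre-tokenizers:

use serde::Deserialize;

// With `untagged`, the first variant that accepts the payload wins. A variant whose
// fields all have defaults (as `Punctuation` did with `behavior`) accepts almost any
// JSON object and can shadow the variant the payload was actually written for.
#[derive(Deserialize, Debug, PartialEq)]
struct Permissive {
    #[serde(default)]
    behavior: String,
}

#[derive(Deserialize, Debug, PartialEq)]
struct Strict {
    delimiter: char,
}

#[derive(Deserialize, Debug, PartialEq)]
#[serde(untagged)]
enum Wrapper {
    Permissive(Permissive),
    Strict(Strict),
}

fn main() {
    // Intended as `Strict`, but `Permissive` is tried first, ignores the unknown
    // `delimiter` field, fills `behavior` from its default, and wins.
    let parsed: Wrapper = serde_json::from_str(r#"{"delimiter": "x"}"#).unwrap();
    assert_eq!(
        parsed,
        Wrapper::Permissive(Permissive {
            behavior: String::new()
        })
    );
}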

View File

@@ -1,4 +1,4 @@
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
use unicode_categories::UnicodeCategories;
@@ -7,13 +7,35 @@ fn is_punc(x: char) -> bool {
char::is_ascii_punctuation(&x) || x.is_punctuation()
}
#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
#[derive(Serialize, Copy, Clone, Debug, PartialEq)]
#[serde(tag = "type")]
pub struct Punctuation {
#[serde(default = "default_split")]
behavior: SplitDelimiterBehavior,
}
impl<'de> Deserialize<'de> for Punctuation {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
enum Type {
Punctuation,
}
#[derive(Deserialize)]
pub struct PunctuationHelper {
#[serde(rename = "type")]
_type: Type,
#[serde(default = "default_split")]
behavior: SplitDelimiterBehavior,
}
let helper = PunctuationHelper::deserialize(deserializer)?;
Ok(Punctuation::new(helper.behavior))
}
}
fn default_split() -> SplitDelimiterBehavior {
SplitDelimiterBehavior::Isolated
}
@@ -65,6 +87,18 @@ mod tests {
#[test]
fn deserialization() {
let _punctuation: Punctuation = serde_json::from_str(r#"{"type": "punctuation"}"#).unwrap();
let punctuation: Punctuation = serde_json::from_str(r#"{"type": "Punctuation"}"#).unwrap();
assert_eq!(punctuation, Punctuation::default());
assert_eq!(
punctuation,
Punctuation::new(SplitDelimiterBehavior::Isolated)
);
}
#[test]
#[should_panic]
fn deserialization_erroneous() {
let _punctuation: Punctuation =
serde_json::from_str(r#"{"type": "WhitespaceSplit"}"#).unwrap();
}
}

View File

@@ -1,13 +1,35 @@
use crate::pre_tokenizers::PreTokenizerWrapper;
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, PartialEq)]
#[serde(tag = "type")]
pub struct Sequence {
pretokenizers: Vec<PreTokenizerWrapper>,
}
impl<'de> Deserialize<'de> for Sequence {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
enum Type {
Sequence,
}
#[derive(Deserialize)]
pub struct SequenceHelper {
#[serde(rename = "type")]
_type: Type,
pretokenizers: Vec<PreTokenizerWrapper>,
}
let helper = SequenceHelper::deserialize(deserializer)?;
Ok(Sequence::new(helper.pretokenizers))
}
}
impl Sequence {
pub fn new(pretokenizers: Vec<PreTokenizerWrapper>) -> Self {
Self { pretokenizers }

View File

@@ -1,5 +1,5 @@
use onig::Regex;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::{
pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
@@ -24,26 +24,8 @@ impl From<&str> for SplitPattern {
}
}
/// We use this custom deserializer to provide the value for `regex` for `Split`
#[doc(hidden)]
#[derive(Deserialize)]
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
struct SplitDeserializer {
pattern: SplitPattern,
behavior: SplitDelimiterBehavior,
invert: bool,
}
impl std::convert::TryFrom<SplitDeserializer> for Split {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from(v: SplitDeserializer) -> Result<Self> {
Split::new(v.pattern, v.behavior, v.invert)
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", try_from = "SplitDeserializer")]
pub struct Split {
pattern: SplitPattern,
#[serde(skip)]
@@ -52,6 +34,30 @@ pub struct Split {
invert: bool,
}
impl<'de> Deserialize<'de> for Split {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
enum Type {
Split,
}
#[derive(Deserialize)]
pub struct SplitHelper {
#[serde(rename = "type")]
_type: Type,
pattern: SplitPattern,
behavior: SplitDelimiterBehavior,
invert: bool,
}
let helper = SplitHelper::deserialize(deserializer)?;
Split::new(helper.pattern, helper.behavior, helper.invert).map_err(serde::de::Error::custom)
}
}
impl Clone for Split {
fn clone(&self) -> Self {
Split::new(self.pattern.clone(), self.behavior, self.invert).unwrap()

View File

@@ -1,7 +1,7 @@
use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script};
use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result};
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub struct UnicodeScripts;
impl_serde_unit_struct!(UnicodeScriptsVisitor, UnicodeScripts);

View File

@@ -4,7 +4,7 @@ use crate::tokenizer::{
pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
};
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub struct Whitespace;
impl_serde_unit_struct!(WhitespaceVisitor, Whitespace);
@@ -27,7 +27,7 @@ impl PreTokenizer for Whitespace {
}
}
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct WhitespaceSplit;
impl_serde_unit_struct!(WhitespaceSplitVisitor, WhitespaceSplit);

View File

@@ -403,7 +403,8 @@ impl Tokenizer {
}
pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
let content = read_to_string(file)?;
Ok(serde_json::from_str(&content)?)
let tokenizer = serde_json::from_str(&content)?;
Ok(tokenizer)
}
#[cfg(feature = "http")]
pub fn from_pretrained<S: AsRef<str>>(
@@ -1131,7 +1132,8 @@ where
/// Instantiate a new Tokenizer from the given file
pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
let content = read_to_string(file)?;
Ok(serde_json::from_str(&content)?)
let tokenizer = serde_json::from_str(&content)?;
Ok(tokenizer)
}
}