Mirror of https://github.com/mii443/tokenizers.git (synced 2025-12-07 13:18:31 +00:00)
Fixing bad deserialization following inclusion of a default for Punctuation. (#884)
* Fixing bad deserialization following inclusion of a default for `Punctuation`.
* don't remove the type now...
* Adding slow test to run on all the tokenizers of the hub.
* `PartialEq` everywhere.
* Forcing `type` to exist on the `pre_tokenizers`.
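For context, the core of the fix below is a hand-written `Deserialize` impl per pre-tokenizer that routes through a helper struct whose `type` field is a single-variant enum, so a payload whose `type` is missing or names a different pre-tokenizer is rejected instead of silently matching the first variant of the `#[serde(untagged)]` `PreTokenizerWrapper`. Here is a minimal, self-contained sketch of that pattern; `Demo` and `DemoHelper` are hypothetical names used only for illustration, and the snippet only assumes `serde` (with the `derive` feature) and `serde_json` as dependencies — it is not the tokenizers crate's actual code.

```rust
use serde::{Deserialize, Deserializer, Serialize};

/// Hypothetical pre-tokenizer-like struct, for illustration only.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(tag = "type")]
pub struct Demo {
    pub flag: bool,
}

impl<'de> Deserialize<'de> for Demo {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Single-variant enum: parsing the `type` field fails unless it is exactly "Demo".
        #[derive(Deserialize)]
        enum Type {
            Demo,
        }

        // Helper mirror of the real struct, with the mandatory `type` tag.
        #[derive(Deserialize)]
        struct DemoHelper {
            #[serde(rename = "type")]
            _type: Type,
            flag: bool,
        }

        let helper = DemoHelper::deserialize(deserializer)?;
        Ok(Demo { flag: helper.flag })
    }
}

fn main() {
    // A matching `type` tag deserializes as before.
    let ok: Demo = serde_json::from_str(r#"{"type":"Demo","flag":true}"#).unwrap();
    assert_eq!(ok, Demo { flag: true });

    // A missing or mismatched `type` is now an error instead of a silent match.
    assert!(serde_json::from_str::<Demo>(r#"{"flag":true}"#).is_err());
    assert!(serde_json::from_str::<Demo>(r#"{"type":"Other","flag":true}"#).is_err());
}
```

The same shape is repeated in the diff for `ByteLevel`, `CharDelimiterSplit`, `Digits`, `Metaspace`, `Punctuation`, `Sequence`, and `Split`. The new Hub-wide deserialization test is opt-in: it only runs when `RUN_SLOW=1` is set.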
bindings/node/native/Cargo.lock (generated, 2 changed lines)
@@ -1668,7 +1668,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
 
 [[package]]
 name = "tokenizers"
-version = "0.11.0"
+version = "0.11.1"
 dependencies = [
  "aho-corasick",
  "cached-path",
@@ -1,5 +1,10 @@
-from tokenizers import Tokenizer, models, normalizers
+from tokenizers import Tokenizer
+import os
+import unittest
 from .utils import data_dir, albert_base
+import json
+from huggingface_hub import HfApi, hf_hub_url, cached_download
+import tqdm
 
 
 class TestSerialization:
@@ -8,3 +13,70 @@ class TestSerialization:
         # This used to fail because of BufReader that would fail because the
         # file exceeds the buffer capacity
         tokenizer = Tokenizer.from_file(albert_base)
+
+
+def check(tokenizer_file) -> bool:
+    with open(tokenizer_file, "r") as f:
+        data = json.load(f)
+    if "pre_tokenizer" not in data:
+        return True
+    if "type" not in data["pre_tokenizer"]:
+        return False
+    if data["pre_tokenizer"]["type"] == "Sequence":
+        for pre_tok in data["pre_tokenizer"]["pretokenizers"]:
+            if "type" not in pre_tok:
+                return False
+    return True
+
+
+def slow(test_case):
+    """
+    Decorator marking a test as slow.
+
+    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+    """
+    if os.getenv("RUN_SLOW") != "1":
+        return unittest.skip("use `RUN_SLOW=1` to run")(test_case)
+    else:
+        return test_case
+
+
+@slow
+class TestFullDeserialization(unittest.TestCase):
+    def test_full_deserialization_hub(self):
+        # Check we can read this file.
+        # This used to fail because of BufReader that would fail because the
+        # file exceeds the buffer capacity
+        api = HfApi()
+
+        not_loadable = []
+        invalid_pre_tokenizer = []
+
+        # models = api.list_models(filter="transformers")
+        # for model in tqdm.tqdm(models):
+        #     model_id = model.modelId
+        #     for model_file in model.siblings:
+        #         filename = model_file.rfilename
+        #         if filename == "tokenizer.json":
+        #             all_models.append((model_id, filename))
+
+        all_models = [("HueyNemud/das22-10-camembert_pretrained", "tokenizer.json")]
+        for (model_id, filename) in tqdm.tqdm(all_models):
+            tokenizer_file = cached_download(hf_hub_url(model_id, filename=filename))
+
+            is_ok = check(tokenizer_file)
+            if not is_ok:
+                print(f"{model_id} is affected by no type")
+                invalid_pre_tokenizer.append(model_id)
+            try:
+                Tokenizer.from_file(tokenizer_file)
+            except Exception as e:
+                print(f"{model_id} is not loadable: {e}")
+                not_loadable.append(model_id)
+            except:
+                print(f"{model_id} is not loadable: Rust error")
+                not_loadable.append(model_id)
+
+        self.assertEqual(invalid_pre_tokenizer, [])
+        self.assertEqual(not_loadable, [])
@@ -5,7 +5,7 @@ fn is_bert_punc(x: char) -> bool {
     char::is_ascii_punctuation(&x) || x.is_punctuation()
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, PartialEq)]
 pub struct BertPreTokenizer;
 impl_serde_unit_struct!(BertVisitor, BertPreTokenizer);
 
@@ -1,7 +1,7 @@
 use std::collections::{HashMap, HashSet};
 
 use onig::Regex;
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{
     Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
@@ -40,7 +40,7 @@ lazy_static! {
         bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
 }
 
-#[derive(Deserialize, Serialize, Copy, Clone, Debug, PartialEq)]
+#[derive(Serialize, Copy, Clone, Debug, PartialEq)]
 /// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care
 /// of all the required processing steps to transform a UTF-8 string as needed before and after the
 /// BPE model does its job.
@@ -53,6 +53,30 @@ pub struct ByteLevel {
     /// Whether the post processing step should trim offsets to avoid including whitespaces.
     pub trim_offsets: bool,
 }
 
+impl<'de> Deserialize<'de> for ByteLevel {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            ByteLevel,
+        }
+
+        #[derive(Deserialize)]
+        pub struct ByteLevelHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            add_prefix_space: bool,
+            trim_offsets: bool,
+        }
+
+        let helper = ByteLevelHelper::deserialize(deserializer)?;
+        Ok(ByteLevel::new(helper.add_prefix_space, helper.trim_offsets))
+    }
+}
+
 impl Default for ByteLevel {
     fn default() -> Self {
         Self {
@@ -1,14 +1,36 @@
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 
-#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
+#[derive(Copy, Clone, Debug, Serialize, PartialEq)]
 #[serde(tag = "type")]
 #[non_exhaustive]
 pub struct CharDelimiterSplit {
     pub delimiter: char,
 }
 
+impl<'de> Deserialize<'de> for CharDelimiterSplit {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            CharDelimiterSplit,
+        }
+
+        #[derive(Deserialize)]
+        pub struct CharDelimiterSplitHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            delimiter: char,
+        }
+
+        let helper = CharDelimiterSplitHelper::deserialize(deserializer)?;
+        Ok(CharDelimiterSplit::new(helper.delimiter))
+    }
+}
+
 impl CharDelimiterSplit {
     pub fn new(delimiter: char) -> Self {
         CharDelimiterSplit { delimiter }
@@ -1,8 +1,8 @@
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 
-#[derive(Serialize, Deserialize, Clone, Debug)]
+#[derive(Serialize, Clone, Debug, PartialEq)]
 /// Pre tokenizes the numbers into single tokens. If individual_digits is set
 /// to true, then all digits are splitted into individual tokens.
 #[serde(tag = "type")]
@@ -11,6 +11,28 @@ pub struct Digits {
     pub individual_digits: bool,
 }
 
+impl<'de> Deserialize<'de> for Digits {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            Digits,
+        }
+
+        #[derive(Deserialize)]
+        pub struct DigitsHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            individual_digits: bool,
+        }
+
+        let helper = DigitsHelper::deserialize(deserializer)?;
+        Ok(Digits::new(helper.individual_digits))
+    }
+}
+
 impl Digits {
     pub fn new(individual_digits: bool) -> Self {
         Self { individual_digits }
@@ -1,11 +1,11 @@
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 
-#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Serialize)]
 /// Replaces all the whitespaces by the provided meta character and then
 /// splits on this character
-#[serde(tag = "type", from = "MetaspaceDeserializer")]
+#[serde(tag = "type")]
 pub struct Metaspace {
     replacement: char,
     pub add_prefix_space: bool,
@@ -13,17 +13,28 @@ pub struct Metaspace {
     str_rep: String,
 }
 
-#[doc(hidden)]
-#[derive(Deserialize)]
-#[serde(tag = "type")]
-pub struct MetaspaceDeserializer {
-    replacement: char,
-    add_prefix_space: bool,
-}
-
-impl From<MetaspaceDeserializer> for Metaspace {
-    fn from(v: MetaspaceDeserializer) -> Metaspace {
-        Metaspace::new(v.replacement, v.add_prefix_space)
+impl<'de> Deserialize<'de> for Metaspace {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            Metaspace,
+        }
+
+        #[derive(Deserialize)]
+        pub struct MetaspaceHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            replacement: char,
+            pub add_prefix_space: bool,
+            #[serde(skip, rename = "str_rep")]
+            _str_rep: String,
+        }
+
+        let helper = MetaspaceHelper::deserialize(deserializer)?;
+        Ok(Metaspace::new(helper.replacement, helper.add_prefix_space))
     }
 }
 
@@ -109,6 +120,12 @@ mod tests {
             serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
             metaspace
         );
+
+        let metaspace_parsed: Metaspace = serde_json::from_str(
+            r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true}"#,
+        )
+        .unwrap();
+        assert_eq!(metaspace_parsed, metaspace);
     }
 
     #[test]
@@ -23,7 +23,7 @@ use crate::pre_tokenizers::unicode_scripts::UnicodeScripts;
 use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
 use crate::{PreTokenizedString, PreTokenizer};
 
-#[derive(Serialize, Deserialize, Clone, Debug)]
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
 #[serde(untagged)]
 pub enum PreTokenizerWrapper {
     BertPreTokenizer(BertPreTokenizer),
@@ -68,3 +68,51 @@ impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace);
 impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit);
 impl_enum_from!(Digits, PreTokenizerWrapper, Digits);
 impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_deserialize() {
+        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","str_rep":"▁","add_prefix_space":true}]}"#).unwrap();
+
+        assert_eq!(
+            pre_tokenizer,
+            PreTokenizerWrapper::Sequence(Sequence::new(vec![
+                PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
+                PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
+            ]))
+        );
+
+        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
+            r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true}"#,
+        )
+        .unwrap();
+
+        assert_eq!(
+            pre_tokenizer,
+            PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
+        );
+
+        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#).unwrap();
+
+        assert_eq!(
+            pre_tokenizer,
+            PreTokenizerWrapper::Sequence(Sequence::new(vec![
+                PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
+                PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
+            ]))
+        );
+    }
+
+    #[test]
+    fn test_deserialize_whitespace_split() {
+        let pre_tokenizer: PreTokenizerWrapper =
+            serde_json::from_str(r#"{"type":"WhitespaceSplit"}"#).unwrap();
+        assert_eq!(
+            pre_tokenizer,
+            PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {})
+        );
+    }
+}
@@ -1,4 +1,4 @@
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 use unicode_categories::UnicodeCategories;
@@ -7,13 +7,35 @@ fn is_punc(x: char) -> bool {
     char::is_ascii_punctuation(&x) || x.is_punctuation()
 }
 
-#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
+#[derive(Serialize, Copy, Clone, Debug, PartialEq)]
 #[serde(tag = "type")]
 pub struct Punctuation {
-    #[serde(default = "default_split")]
     behavior: SplitDelimiterBehavior,
 }
 
+impl<'de> Deserialize<'de> for Punctuation {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            Punctuation,
+        }
+
+        #[derive(Deserialize)]
+        pub struct PunctuationHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            #[serde(default = "default_split")]
+            behavior: SplitDelimiterBehavior,
+        }
+
+        let helper = PunctuationHelper::deserialize(deserializer)?;
+        Ok(Punctuation::new(helper.behavior))
+    }
+}
+
 fn default_split() -> SplitDelimiterBehavior {
     SplitDelimiterBehavior::Isolated
 }
@@ -65,6 +87,18 @@ mod tests {
 
     #[test]
     fn deserialization() {
-        let _punctuation: Punctuation = serde_json::from_str(r#"{"type": "punctuation"}"#).unwrap();
+        let punctuation: Punctuation = serde_json::from_str(r#"{"type": "Punctuation"}"#).unwrap();
+        assert_eq!(punctuation, Punctuation::default());
+        assert_eq!(
+            punctuation,
+            Punctuation::new(SplitDelimiterBehavior::Isolated)
+        );
+    }
+
+    #[test]
+    #[should_panic]
+    fn deserialization_erroneous() {
+        let _punctuation: Punctuation =
+            serde_json::from_str(r#"{"type": "WhitespaceSplit"}"#).unwrap();
     }
 }
@@ -1,13 +1,35 @@
 use crate::pre_tokenizers::PreTokenizerWrapper;
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result};
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, PartialEq)]
 #[serde(tag = "type")]
 pub struct Sequence {
     pretokenizers: Vec<PreTokenizerWrapper>,
 }
 
+impl<'de> Deserialize<'de> for Sequence {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            Sequence,
+        }
+
+        #[derive(Deserialize)]
+        pub struct SequenceHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            pretokenizers: Vec<PreTokenizerWrapper>,
+        }
+
+        let helper = SequenceHelper::deserialize(deserializer)?;
+        Ok(Sequence::new(helper.pretokenizers))
+    }
+}
+
 impl Sequence {
     pub fn new(pretokenizers: Vec<PreTokenizerWrapper>) -> Self {
         Self { pretokenizers }
@@ -1,5 +1,5 @@
 use onig::Regex;
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{
     pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
@@ -24,26 +24,8 @@ impl From<&str> for SplitPattern {
     }
 }
 
-/// We use this custom deserializer to provide the value for `regex` for `Split`
-#[doc(hidden)]
-#[derive(Deserialize)]
-#[serde(tag = "type")]
-struct SplitDeserializer {
-    pattern: SplitPattern,
-    behavior: SplitDelimiterBehavior,
-    invert: bool,
-}
-
-impl std::convert::TryFrom<SplitDeserializer> for Split {
-    type Error = Box<dyn std::error::Error + Send + Sync>;
-
-    fn try_from(v: SplitDeserializer) -> Result<Self> {
-        Split::new(v.pattern, v.behavior, v.invert)
-    }
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(tag = "type", try_from = "SplitDeserializer")]
+#[derive(Debug, Serialize)]
+#[serde(tag = "type")]
 pub struct Split {
     pattern: SplitPattern,
     #[serde(skip)]
@@ -52,6 +34,30 @@ pub struct Split {
     invert: bool,
 }
 
+impl<'de> Deserialize<'de> for Split {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            Split,
+        }
+
+        #[derive(Deserialize)]
+        pub struct SplitHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            pattern: SplitPattern,
+            behavior: SplitDelimiterBehavior,
+            invert: bool,
+        }
+
+        let helper = SplitHelper::deserialize(deserializer)?;
+        Split::new(helper.pattern, helper.behavior, helper.invert).map_err(serde::de::Error::custom)
+    }
+}
+
 impl Clone for Split {
     fn clone(&self) -> Self {
         Split::new(self.pattern.clone(), self.behavior, self.invert).unwrap()
@@ -1,7 +1,7 @@
 use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script};
 use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result};
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 pub struct UnicodeScripts;
 impl_serde_unit_struct!(UnicodeScriptsVisitor, UnicodeScripts);
 
@@ -4,7 +4,7 @@ use crate::tokenizer::{
     pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
 };
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 pub struct Whitespace;
 impl_serde_unit_struct!(WhitespaceVisitor, Whitespace);
 
@@ -27,7 +27,7 @@ impl PreTokenizer for Whitespace {
     }
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, PartialEq)]
 pub struct WhitespaceSplit;
 impl_serde_unit_struct!(WhitespaceSplitVisitor, WhitespaceSplit);
 
@@ -403,7 +403,8 @@ impl Tokenizer {
     }
     pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
         let content = read_to_string(file)?;
-        Ok(serde_json::from_str(&content)?)
+        let tokenizer = serde_json::from_str(&content)?;
+        Ok(tokenizer)
     }
     #[cfg(feature = "http")]
     pub fn from_pretrained<S: AsRef<str>>(
@@ -1131,7 +1132,8 @@ where
     /// Instantiate a new Tokenizer from the given file
     pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
         let content = read_to_string(file)?;
-        Ok(serde_json::from_str(&content)?)
+        let tokenizer = serde_json::from_str(&content)?;
+        Ok(tokenizer)
     }
 }
 