Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Fixing bad deserialization following inclusion of a default for `Punctuation`. (#884)
* Fixing bad deserialization following inclusion of a default for `Punctuation`.
* Don't remove the type now...
* Adding slow test to run on all the tokenizers of the hub.
* `PartialEq` everywhere.
* Forcing `type` to exist on the `pre_tokenizers`.
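The fix applied to every pre-tokenizer in the hunks below follows one pattern: the previously derived (or `From`-based) deserialization is replaced by a hand-written `Deserialize` impl that goes through a helper struct whose `type` tag is a single-variant enum, so a `tokenizer.json` entry missing its `"type"` field (or carrying the wrong one) is rejected instead of silently falling back to a default. A minimal, self-contained sketch of that idea, assuming the serde and serde_json crates; the `Example` struct, `Helper`, and the `individual_digits` field are illustrative stand-ins, not part of the crate:

use serde::{Deserialize, Deserializer};

#[derive(Debug, PartialEq)]
pub struct Example {
    pub individual_digits: bool,
}

impl<'de> Deserialize<'de> for Example {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // The single-variant enum only accepts the expected tag, so `type` must be
        // present and equal to "Example" for deserialization to succeed.
        #[derive(Deserialize)]
        enum Type {
            Example,
        }

        #[derive(Deserialize)]
        struct Helper {
            #[serde(rename = "type")]
            _type: Type,
            individual_digits: bool,
        }

        let helper = Helper::deserialize(deserializer)?;
        Ok(Example {
            individual_digits: helper.individual_digits,
        })
    }
}

fn main() {
    // A payload that carries the tag parses into the struct...
    let ok: Example =
        serde_json::from_str(r#"{"type":"Example","individual_digits":true}"#).unwrap();
    assert_eq!(ok, Example { individual_digits: true });

    // ...while one without `type` is now an error instead of a defaulted value.
    assert!(serde_json::from_str::<Example>(r#"{"individual_digits":true}"#).is_err());
}

The same shape appears for `ByteLevel`, `CharDelimiterSplit`, `Digits`, `Metaspace`, `Punctuation`, `Sequence`, and `Split` in the diff below.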
bindings/node/native/Cargo.lock (generated)
@@ -1668,7 +1668,7 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
 
 [[package]]
 name = "tokenizers"
-version = "0.11.0"
+version = "0.11.1"
 dependencies = [
  "aho-corasick",
  "cached-path",
@@ -1,5 +1,10 @@
-from tokenizers import Tokenizer, models, normalizers
+from tokenizers import Tokenizer
+import os
+import unittest
 from .utils import data_dir, albert_base
+import json
+from huggingface_hub import HfApi, hf_hub_url, cached_download
+import tqdm
 
 
 class TestSerialization:
@@ -8,3 +13,70 @@ class TestSerialization:
         # This used to fail because of BufReader that would fail because the
         # file exceeds the buffer capacity
         tokenizer = Tokenizer.from_file(albert_base)
+
+
+def check(tokenizer_file) -> bool:
+    with open(tokenizer_file, "r") as f:
+        data = json.load(f)
+    if "pre_tokenizer" not in data:
+        return True
+    if "type" not in data["pre_tokenizer"]:
+        return False
+    if data["pre_tokenizer"]["type"] == "Sequence":
+        for pre_tok in data["pre_tokenizer"]["pretokenizers"]:
+            if "type" not in pre_tok:
+                return False
+    return True
+
+
+def slow(test_case):
+    """
+    Decorator marking a test as slow.
+
+    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+    """
+    if os.getenv("RUN_SLOW") != "1":
+        return unittest.skip("use `RUN_SLOW=1` to run")(test_case)
+    else:
+        return test_case
+
+
+@slow
+class TestFullDeserialization(unittest.TestCase):
+    def test_full_deserialization_hub(self):
+        # Check we can read this file.
+        # This used to fail because of BufReader that would fail because the
+        # file exceeds the buffer capacity
+        api = HfApi()
+
+        not_loadable = []
+        invalid_pre_tokenizer = []
+
+        # models = api.list_models(filter="transformers")
+        # for model in tqdm.tqdm(models):
+        #     model_id = model.modelId
+        #     for model_file in model.siblings:
+        #         filename = model_file.rfilename
+        #         if filename == "tokenizer.json":
+        #             all_models.append((model_id, filename))
+
+        all_models = [("HueyNemud/das22-10-camembert_pretrained", "tokenizer.json")]
+        for (model_id, filename) in tqdm.tqdm(all_models):
+            tokenizer_file = cached_download(hf_hub_url(model_id, filename=filename))
+
+            is_ok = check(tokenizer_file)
+            if not is_ok:
+                print(f"{model_id} is affected by no type")
+                invalid_pre_tokenizer.append(model_id)
+            try:
+                Tokenizer.from_file(tokenizer_file)
+            except Exception as e:
+                print(f"{model_id} is not loadable: {e}")
+                not_loadable.append(model_id)
+            except:
+                print(f"{model_id} is not loadable: Rust error")
+                not_loadable.append(model_id)
+
+        self.assertEqual(invalid_pre_tokenizer, [])
+        self.assertEqual(not_loadable, [])
@@ -5,7 +5,7 @@ fn is_bert_punc(x: char) -> bool {
     char::is_ascii_punctuation(&x) || x.is_punctuation()
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, PartialEq)]
 pub struct BertPreTokenizer;
 impl_serde_unit_struct!(BertVisitor, BertPreTokenizer);
 
@@ -1,7 +1,7 @@
 use std::collections::{HashMap, HashSet};
 
 use onig::Regex;
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{
     Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
@@ -40,7 +40,7 @@ lazy_static! {
         bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
 }
 
-#[derive(Deserialize, Serialize, Copy, Clone, Debug, PartialEq)]
+#[derive(Serialize, Copy, Clone, Debug, PartialEq)]
 /// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care
 /// of all the required processing steps to transform a UTF-8 string as needed before and after the
 /// BPE model does its job.
@@ -53,6 +53,30 @@ pub struct ByteLevel {
     /// Whether the post processing step should trim offsets to avoid including whitespaces.
     pub trim_offsets: bool,
 }
 
+impl<'de> Deserialize<'de> for ByteLevel {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            ByteLevel,
+        }
+
+        #[derive(Deserialize)]
+        pub struct ByteLevelHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            add_prefix_space: bool,
+            trim_offsets: bool,
+        }
+
+        let helper = ByteLevelHelper::deserialize(deserializer)?;
+        Ok(ByteLevel::new(helper.add_prefix_space, helper.trim_offsets))
+    }
+}
+
 impl Default for ByteLevel {
     fn default() -> Self {
         Self {
@@ -1,14 +1,36 @@
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 
-#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
+#[derive(Copy, Clone, Debug, Serialize, PartialEq)]
 #[serde(tag = "type")]
 #[non_exhaustive]
 pub struct CharDelimiterSplit {
     pub delimiter: char,
 }
 
+impl<'de> Deserialize<'de> for CharDelimiterSplit {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            CharDelimiterSplit,
+        }
+
+        #[derive(Deserialize)]
+        pub struct CharDelimiterSplitHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            delimiter: char,
+        }
+
+        let helper = CharDelimiterSplitHelper::deserialize(deserializer)?;
+        Ok(CharDelimiterSplit::new(helper.delimiter))
+    }
+}
+
 impl CharDelimiterSplit {
     pub fn new(delimiter: char) -> Self {
         CharDelimiterSplit { delimiter }
@@ -1,8 +1,8 @@
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 
-#[derive(Serialize, Deserialize, Clone, Debug)]
+#[derive(Serialize, Clone, Debug, PartialEq)]
 /// Pre tokenizes the numbers into single tokens. If individual_digits is set
 /// to true, then all digits are splitted into individual tokens.
 #[serde(tag = "type")]
@@ -11,6 +11,28 @@ pub struct Digits {
     pub individual_digits: bool,
 }
 
+impl<'de> Deserialize<'de> for Digits {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            Digits,
+        }
+
+        #[derive(Deserialize)]
+        pub struct DigitsHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            individual_digits: bool,
+        }
+
+        let helper = DigitsHelper::deserialize(deserializer)?;
+        Ok(Digits::new(helper.individual_digits))
+    }
+}
+
 impl Digits {
     pub fn new(individual_digits: bool) -> Self {
         Self { individual_digits }
@@ -1,11 +1,11 @@
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 
-#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Serialize)]
 /// Replaces all the whitespaces by the provided meta character and then
 /// splits on this character
-#[serde(tag = "type", from = "MetaspaceDeserializer")]
+#[serde(tag = "type")]
 pub struct Metaspace {
     replacement: char,
     pub add_prefix_space: bool,
@@ -13,17 +13,28 @@ pub struct Metaspace {
     str_rep: String,
 }
 
-#[doc(hidden)]
-#[derive(Deserialize)]
-#[serde(tag = "type")]
-pub struct MetaspaceDeserializer {
-    replacement: char,
-    add_prefix_space: bool,
-}
+impl<'de> Deserialize<'de> for Metaspace {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            Metaspace,
+        }
 
-impl From<MetaspaceDeserializer> for Metaspace {
-    fn from(v: MetaspaceDeserializer) -> Metaspace {
-        Metaspace::new(v.replacement, v.add_prefix_space)
+        #[derive(Deserialize)]
+        pub struct MetaspaceHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            replacement: char,
+            pub add_prefix_space: bool,
+            #[serde(skip, rename = "str_rep")]
+            _str_rep: String,
+        }
+
+        let helper = MetaspaceHelper::deserialize(deserializer)?;
+        Ok(Metaspace::new(helper.replacement, helper.add_prefix_space))
     }
 }
@@ -109,6 +120,12 @@ mod tests {
             serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
             metaspace
         );
+
+        let metaspace_parsed: Metaspace = serde_json::from_str(
+            r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true}"#,
+        )
+        .unwrap();
+        assert_eq!(metaspace_parsed, metaspace);
     }
 
     #[test]
@@ -23,7 +23,7 @@ use crate::pre_tokenizers::unicode_scripts::UnicodeScripts;
 use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
 use crate::{PreTokenizedString, PreTokenizer};
 
-#[derive(Serialize, Deserialize, Clone, Debug)]
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
 #[serde(untagged)]
 pub enum PreTokenizerWrapper {
     BertPreTokenizer(BertPreTokenizer),
@@ -68,3 +68,51 @@ impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace);
 impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit);
 impl_enum_from!(Digits, PreTokenizerWrapper, Digits);
 impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_deserialize() {
+        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","str_rep":"▁","add_prefix_space":true}]}"#).unwrap();
+
+        assert_eq!(
+            pre_tokenizer,
+            PreTokenizerWrapper::Sequence(Sequence::new(vec![
+                PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
+                PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
+            ]))
+        );
+
+        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
+            r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true}"#,
+        )
+        .unwrap();
+
+        assert_eq!(
+            pre_tokenizer,
+            PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
+        );
+
+        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#).unwrap();
+
+        assert_eq!(
+            pre_tokenizer,
+            PreTokenizerWrapper::Sequence(Sequence::new(vec![
+                PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
+                PreTokenizerWrapper::Metaspace(Metaspace::new('▁', true))
+            ]))
+        );
+    }
+
+    #[test]
+    fn test_deserialize_whitespace_split() {
+        let pre_tokenizer: PreTokenizerWrapper =
+            serde_json::from_str(r#"{"type":"WhitespaceSplit"}"#).unwrap();
+        assert_eq!(
+            pre_tokenizer,
+            PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {})
+        );
+    }
+}
@@ -1,4 +1,4 @@
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
 use unicode_categories::UnicodeCategories;
@@ -7,13 +7,35 @@ fn is_punc(x: char) -> bool {
     char::is_ascii_punctuation(&x) || x.is_punctuation()
 }
 
-#[derive(Serialize, Deserialize, Copy, Clone, Debug)]
+#[derive(Serialize, Copy, Clone, Debug, PartialEq)]
 #[serde(tag = "type")]
 pub struct Punctuation {
     #[serde(default = "default_split")]
     behavior: SplitDelimiterBehavior,
 }
 
+impl<'de> Deserialize<'de> for Punctuation {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            Punctuation,
+        }
+
+        #[derive(Deserialize)]
+        pub struct PunctuationHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            #[serde(default = "default_split")]
+            behavior: SplitDelimiterBehavior,
+        }
+
+        let helper = PunctuationHelper::deserialize(deserializer)?;
+        Ok(Punctuation::new(helper.behavior))
+    }
+}
+
 fn default_split() -> SplitDelimiterBehavior {
     SplitDelimiterBehavior::Isolated
 }
@@ -65,6 +87,18 @@ mod tests {
 
     #[test]
     fn deserialization() {
-        let _punctuation: Punctuation = serde_json::from_str(r#"{"type": "punctuation"}"#).unwrap();
+        let punctuation: Punctuation = serde_json::from_str(r#"{"type": "Punctuation"}"#).unwrap();
+        assert_eq!(punctuation, Punctuation::default());
+        assert_eq!(
+            punctuation,
+            Punctuation::new(SplitDelimiterBehavior::Isolated)
+        );
     }
+
+    #[test]
+    #[should_panic]
+    fn deserialization_erroneous() {
+        let _punctuation: Punctuation =
+            serde_json::from_str(r#"{"type": "WhitespaceSplit"}"#).unwrap();
+    }
 }
@@ -1,13 +1,35 @@
 use crate::pre_tokenizers::PreTokenizerWrapper;
 use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result};
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
-#[derive(Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, PartialEq)]
 #[serde(tag = "type")]
 pub struct Sequence {
     pretokenizers: Vec<PreTokenizerWrapper>,
 }
 
+impl<'de> Deserialize<'de> for Sequence {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            Sequence,
+        }
+
+        #[derive(Deserialize)]
+        pub struct SequenceHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            pretokenizers: Vec<PreTokenizerWrapper>,
+        }
+
+        let helper = SequenceHelper::deserialize(deserializer)?;
+        Ok(Sequence::new(helper.pretokenizers))
+    }
+}
+
 impl Sequence {
     pub fn new(pretokenizers: Vec<PreTokenizerWrapper>) -> Self {
         Self { pretokenizers }
@@ -1,5 +1,5 @@
 use onig::Regex;
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 
 use crate::tokenizer::{
     pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
@@ -24,26 +24,8 @@ impl From<&str> for SplitPattern {
     }
 }
 
-/// We use this custom deserializer to provide the value for `regex` for `Split`
-#[doc(hidden)]
-#[derive(Deserialize)]
+#[derive(Debug, Serialize)]
 #[serde(tag = "type")]
-struct SplitDeserializer {
-    pattern: SplitPattern,
-    behavior: SplitDelimiterBehavior,
-    invert: bool,
-}
-
-impl std::convert::TryFrom<SplitDeserializer> for Split {
-    type Error = Box<dyn std::error::Error + Send + Sync>;
-
-    fn try_from(v: SplitDeserializer) -> Result<Self> {
-        Split::new(v.pattern, v.behavior, v.invert)
-    }
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(tag = "type", try_from = "SplitDeserializer")]
 pub struct Split {
     pattern: SplitPattern,
     #[serde(skip)]
@@ -52,6 +34,30 @@ pub struct Split {
     invert: bool,
 }
 
+impl<'de> Deserialize<'de> for Split {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        #[derive(Deserialize)]
+        enum Type {
+            Split,
+        }
+
+        #[derive(Deserialize)]
+        pub struct SplitHelper {
+            #[serde(rename = "type")]
+            _type: Type,
+            pattern: SplitPattern,
+            behavior: SplitDelimiterBehavior,
+            invert: bool,
+        }
+
+        let helper = SplitHelper::deserialize(deserializer)?;
+        Split::new(helper.pattern, helper.behavior, helper.invert).map_err(serde::de::Error::custom)
+    }
+}
+
 impl Clone for Split {
     fn clone(&self) -> Self {
         Split::new(self.pattern.clone(), self.behavior, self.invert).unwrap()
@@ -1,7 +1,7 @@
 use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script};
 use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result};
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 pub struct UnicodeScripts;
 impl_serde_unit_struct!(UnicodeScriptsVisitor, UnicodeScripts);
 
@@ -4,7 +4,7 @@ use crate::tokenizer::{
     pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
 };
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 pub struct Whitespace;
 impl_serde_unit_struct!(WhitespaceVisitor, Whitespace);
 
@@ -27,7 +27,7 @@ impl PreTokenizer for Whitespace {
     }
 }
 
-#[derive(Copy, Clone, Debug)]
+#[derive(Copy, Clone, Debug, PartialEq)]
 pub struct WhitespaceSplit;
 impl_serde_unit_struct!(WhitespaceSplitVisitor, WhitespaceSplit);
 
@@ -403,7 +403,8 @@ impl Tokenizer {
     }
     pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
         let content = read_to_string(file)?;
-        Ok(serde_json::from_str(&content)?)
+        let tokenizer = serde_json::from_str(&content)?;
+        Ok(tokenizer)
     }
     #[cfg(feature = "http")]
     pub fn from_pretrained<S: AsRef<str>>(
@@ -1131,7 +1132,8 @@ where
     /// Instantiate a new Tokenizer from the given file
     pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
         let content = read_to_string(file)?;
-        Ok(serde_json::from_str(&content)?)
+        let tokenizer = serde_json::from_str(&content)?;
+        Ok(tokenizer)
     }
 }
 