mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-05 20:28:22 +00:00
Fixes and cleanup, suggestions by @n1t0.
This commit is contained in:
@@ -301,7 +301,7 @@ mod test {
|
||||
match py_dec.decoder {
|
||||
PyDecoderWrapper::Wrapped(msp) => match msp.as_ref() {
|
||||
DecoderWrapper::Metaspace(_) => {}
|
||||
_ => panic!("Expected Whitespace"),
|
||||
_ => panic!("Expected Metaspace"),
|
||||
},
|
||||
_ => panic!("Expected wrapped, not custom."),
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::collections::hash_map::RandomState;
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
@@ -10,11 +9,11 @@ use serde::{Deserialize, Serialize};
|
||||
use tk::models::bpe::BPE;
|
||||
use tk::models::wordlevel::WordLevel;
|
||||
use tk::models::wordpiece::WordPiece;
|
||||
use tk::models::ModelWrapper;
|
||||
use tk::{Model, Token};
|
||||
use tokenizers as tk;
|
||||
|
||||
use super::error::ToPyResult;
|
||||
use tk::models::ModelWrapper;
|
||||
|
||||
/// A Model represents some tokenization algorithm like BPE or Word
|
||||
/// This class cannot be constructed directly. Please use one of the concrete models.
|
||||
@@ -55,7 +54,7 @@ impl Model for PyModel {
|
||||
self.model.id_to_token(id)
|
||||
}
|
||||
|
||||
fn get_vocab(&self) -> &HashMap<String, u32, RandomState> {
|
||||
fn get_vocab(&self) -> &HashMap<String, u32> {
|
||||
self.model.get_vocab()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use crate::tokenizer::{Decoder, Result};
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Deserialize, Clone, Debug, Serialize)]
|
||||
/// Allows decoding Original BPE by joining all the tokens and then replacing
|
||||
/// the suffix used to identify end-of-words by whitespaces
|
||||
#[serde(tag = "type")]
|
||||
pub struct BPEDecoder {
|
||||
suffix: String,
|
||||
}
|
||||
@@ -26,15 +27,3 @@ impl Decoder for BPEDecoder {
|
||||
Ok(tokens.join("").replace(&self.suffix, " ").trim().to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for BPEDecoder {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut m = serializer.serialize_struct("BPEDecoder", 2)?;
|
||||
m.serialize_field("type", "BPEDecoder")?;
|
||||
m.serialize_field("suffix", &self.suffix)?;
|
||||
m.end()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use crate::tokenizer::{Decoder, Result};
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Deserialize, Clone, Debug, Serialize)]
|
||||
/// The WordPiece decoder takes care of decoding a list of wordpiece tokens
|
||||
/// back into a readable string.
|
||||
#[serde(tag = "type")]
|
||||
pub struct WordPiece {
|
||||
/// The prefix to be used for continuing subwords
|
||||
prefix: String,
|
||||
@@ -48,16 +49,3 @@ impl Decoder for WordPiece {
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for WordPiece {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut m = serializer.serialize_struct("BPEDecoder", 3)?;
|
||||
m.serialize_field("type", "BPEDecoder")?;
|
||||
m.serialize_field("prefix", &self.prefix)?;
|
||||
m.serialize_field("cleanup", &self.cleanup)?;
|
||||
m.end()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::tokenizer::{NormalizedString, Normalizer, Result};
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use unicode_categories::UnicodeCategories;
|
||||
|
||||
/// Checks whether a character is whitespace
|
||||
@@ -49,7 +49,8 @@ fn is_chinese_char(c: char) -> bool {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Deserialize)]
|
||||
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub struct BertNormalizer {
|
||||
/// Whether to do the bert basic cleaning:
|
||||
/// 1. Remove any control characters
|
||||
@@ -63,21 +64,6 @@ pub struct BertNormalizer {
|
||||
lowercase: bool,
|
||||
}
|
||||
|
||||
impl Serialize for BertNormalizer {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut m = serializer.serialize_struct("BertNormalizer", 5)?;
|
||||
m.serialize_field("type", "BertNormalizer")?;
|
||||
m.serialize_field("clean_text", &self.clean_text)?;
|
||||
m.serialize_field("handle_chinese_chars", &self.handle_chinese_chars)?;
|
||||
m.serialize_field("strip_accents", &self.strip_accents)?;
|
||||
m.serialize_field("lowercase", &self.lowercase)?;
|
||||
m.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BertNormalizer {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
||||
@@ -1,26 +1,13 @@
|
||||
use crate::tokenizer::{NormalizedString, Normalizer, Result};
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Deserialize)]
|
||||
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub struct Strip {
|
||||
strip_left: bool,
|
||||
strip_right: bool,
|
||||
}
|
||||
|
||||
impl Serialize for Strip {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut m = serializer.serialize_struct("Strip", 5)?;
|
||||
m.serialize_field("type", "BertNormalizer")?;
|
||||
m.serialize_field("strip_left", &self.strip_left)?;
|
||||
m.serialize_field("strip_right", &self.strip_right)?;
|
||||
m.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl Strip {
|
||||
pub fn new(strip_left: bool, strip_right: bool) -> Self {
|
||||
Strip {
|
||||
|
||||
@@ -1,27 +1,16 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::normalizers::NormalizerWrapper;
|
||||
use crate::tokenizer::{NormalizedString, Normalizer, Result};
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
#[derive(Clone, Deserialize, Debug, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
/// Allows concatenating multiple other Normalizer as a Sequence.
|
||||
/// All the normalizers run in sequence in the given order against the same NormalizedString.
|
||||
pub struct Sequence {
|
||||
normalizers: Vec<NormalizerWrapper>,
|
||||
}
|
||||
|
||||
impl Serialize for Sequence {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut m = serializer.serialize_struct("Sequence", 2)?;
|
||||
m.serialize_field("type", "Sequence")?;
|
||||
m.serialize_field("normalizers", &self.normalizers)?;
|
||||
m.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl Sequence {
|
||||
pub fn new(normalizers: Vec<NormalizerWrapper>) -> Self {
|
||||
Self { normalizers }
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use onig::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::tokenizer::{
|
||||
normalizer::Range, Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
|
||||
};
|
||||
use onig::Regex;
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
fn bytes_char() -> HashMap<u8, char> {
|
||||
let mut bs: Vec<u8> = vec![];
|
||||
@@ -38,10 +39,11 @@ lazy_static! {
|
||||
bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Copy, Clone, Debug)]
|
||||
#[derive(Deserialize, Serialize, Copy, Clone, Debug)]
|
||||
/// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care
|
||||
/// of all the required processing steps to transform a UTF-8 string as needed before and after the
|
||||
/// BPE model does its job.
|
||||
#[serde(tag = "type")]
|
||||
pub struct ByteLevel {
|
||||
/// Whether to add a leading space to the first word. This allows to treat the leading word
|
||||
/// just as any other word.
|
||||
@@ -212,19 +214,6 @@ pub fn process_offsets(encoding: &mut Encoding, add_prefix_space: bool) {
|
||||
});
|
||||
}
|
||||
|
||||
impl Serialize for ByteLevel {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut m = serializer.serialize_struct("ByteLevel", 3)?;
|
||||
m.serialize_field("type", "ByteLevel")?;
|
||||
m.serialize_field("add_prefix_space", &self.add_prefix_space)?;
|
||||
m.serialize_field("trim_offsets", &self.trim_offsets)?;
|
||||
m.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ByteLevel;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Deserialize)]
|
||||
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub struct CharDelimiterSplit {
|
||||
delimiter: char,
|
||||
}
|
||||
@@ -22,15 +22,3 @@ impl PreTokenizer for CharDelimiterSplit {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for CharDelimiterSplit {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut m = serializer.serialize_struct("CharDelimiterSplit", 2)?;
|
||||
m.serialize_field("type", "CharDelimiterSplit")?;
|
||||
m.serialize_field("delimiter", &self.delimiter)?;
|
||||
m.end()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Serialize, Serializer};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
/// Replaces all the whitespaces by the provided meta character and then
|
||||
/// splits on this character
|
||||
#[serde(tag = "type")]
|
||||
pub struct Metaspace {
|
||||
replacement: char,
|
||||
str_rep: String,
|
||||
@@ -68,20 +68,6 @@ impl Decoder for Metaspace {
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Metaspace {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut m = serializer.serialize_struct("Metaspace", 3)?;
|
||||
m.serialize_field("type", "Metaspace")?;
|
||||
m.serialize_field("replacement", &self.replacement)?;
|
||||
m.serialize_field("str_rep", &self.str_rep)?;
|
||||
m.serialize_field("add_prefix_space", &self.add_prefix_space)?;
|
||||
m.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
use std::fmt;
|
||||
|
||||
use regex::Regex;
|
||||
use serde::de::{Error, Visitor};
|
||||
use serde::ser::SerializeStruct;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
use crate::tokenizer::{
|
||||
pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
|
||||
};
|
||||
use serde::de::{Error, Visitor};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub struct Whitespace {
|
||||
#[serde(default = "default_regex", skip)]
|
||||
re: Regex,
|
||||
}
|
||||
|
||||
@@ -34,19 +35,8 @@ impl PreTokenizer for Whitespace {
|
||||
}
|
||||
}
|
||||
|
||||
// manually implement serialize / deserialize because Whitespace is not a unit-struct but is
|
||||
// manually implement deserialize because Whitespace is not a unit-struct but is
|
||||
// serialized like one.
|
||||
impl Serialize for Whitespace {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut m = serializer.serialize_struct("Whitespace", 1)?;
|
||||
m.serialize_field("type", "Whitespace")?;
|
||||
m.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Whitespace {
|
||||
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||
where
|
||||
|
||||
@@ -657,10 +657,6 @@ where
|
||||
/// ```
|
||||
/// # use tokenizers::Tokenizer;
|
||||
/// # use tokenizers::models::bpe::BPE;
|
||||
/// # use tokenizers::normalizers::NormalizerWrapper;
|
||||
/// # use tokenizers::pre_tokenizers::PreTokenizerWrapper;
|
||||
/// # use tokenizers::processors::PostProcessorWrapper;
|
||||
/// # use tokenizers::decoders::DecoderWrapper;
|
||||
/// # let mut tokenizer = Tokenizer::new(BPE::default());
|
||||
/// #
|
||||
/// // Sequences:
|
||||
|
||||
@@ -125,6 +125,15 @@ fn pretoks() {
|
||||
// wrapped serializes same way as inner
|
||||
let ser_wrapped = serde_json::to_string(&ch_wrapped).unwrap();
|
||||
assert_eq!(ser_wrapped, ch_ser);
|
||||
|
||||
let wsp = Whitespace::default();
|
||||
let wsp_ser = serde_json::to_string(&wsp).unwrap();
|
||||
serde_json::from_str::<Whitespace>(&wsp_ser).unwrap();
|
||||
let err: Result<BertPreTokenizer, _> = serde_json::from_str(&wsp_ser);
|
||||
assert!(
|
||||
err.is_err(),
|
||||
"BertPreTokenizer shouldn't be deserializable from Whitespace"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user