Fixes and cleanup, suggestions by @n1t0.

Anthony MOI
2020-08-04 09:56:36 +02:00
committed by Anthony MOI
parent f6adcf0e7c
commit 363adedb4c
14 changed files with 51 additions and 155 deletions

View File

@@ -301,7 +301,7 @@ mod test {
match py_dec.decoder {
PyDecoderWrapper::Wrapped(msp) => match msp.as_ref() {
DecoderWrapper::Metaspace(_) => {}
-_ => panic!("Expected Whitespace"),
+_ => panic!("Expected Metaspace"),
},
_ => panic!("Expected wrapped, not custom."),
}

View File

@@ -1,4 +1,3 @@
-use std::collections::hash_map::RandomState;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
@@ -10,11 +9,11 @@ use serde::{Deserialize, Serialize};
use tk::models::bpe::BPE;
use tk::models::wordlevel::WordLevel;
use tk::models::wordpiece::WordPiece;
+use tk::models::ModelWrapper;
use tk::{Model, Token};
use tokenizers as tk;
use super::error::ToPyResult;
-use tk::models::ModelWrapper;
/// A Model represents some tokenization algorithm like BPE or Word
/// This class cannot be constructed directly. Please use one of the concrete models.
@@ -55,7 +54,7 @@ impl Model for PyModel {
self.model.id_to_token(id)
}
-fn get_vocab(&self) -> &HashMap<String, u32, RandomState> {
+fn get_vocab(&self) -> &HashMap<String, u32> {
self.model.get_vocab()
}
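
For context on the get_vocab change above: RandomState is already the default hasher parameter of std::collections::HashMap, so HashMap<String, u32> names exactly the same type and the dedicated import becomes unnecessary. A minimal standalone sketch (not part of the commit) illustrating the equivalence:

use std::collections::hash_map::RandomState;
use std::collections::HashMap;

fn main() {
    // `HashMap<K, V>` defaults its third type parameter to `RandomState`,
    // so both annotations below describe the exact same type.
    let explicit: HashMap<String, u32, RandomState> = HashMap::new();
    let implicit: HashMap<String, u32> = explicit; // compiles: identical types
    println!("{} entries", implicit.len());
}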

View File

@@ -1,10 +1,11 @@
use crate::tokenizer::{Decoder, Result};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
-#[derive(Deserialize, Clone, Debug)]
+use serde::{Deserialize, Serialize};
+#[derive(Deserialize, Clone, Debug, Serialize)]
/// Allows decoding Original BPE by joining all the tokens and then replacing
/// the suffix used to identify end-of-words by whitespaces
+#[serde(tag = "type")]
pub struct BPEDecoder {
suffix: String,
}
@@ -26,15 +27,3 @@ impl Decoder for BPEDecoder {
Ok(tokens.join("").replace(&self.suffix, " ").trim().to_owned())
}
}
-impl Serialize for BPEDecoder {
-fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-where
-S: Serializer,
-{
-let mut m = serializer.serialize_struct("BPEDecoder", 2)?;
-m.serialize_field("type", "BPEDecoder")?;
-m.serialize_field("suffix", &self.suffix)?;
-m.end()
-}
-}
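
The change in this file (and mirrored in the files below) replaces the hand-written SerializeStruct impl with #[derive(Serialize)] plus #[serde(tag = "type")], which makes serde inject a "type" field carrying the struct name, just as the removed impl did by hand. A minimal standalone sketch of that behavior, assuming serde and serde_json as dependencies; the struct here is an illustrative stand-in, not the crate's own type:

use serde::{Deserialize, Serialize};

// Deriving Serialize with an internal "type" tag reproduces what the removed
// manual impl wrote: a map containing the struct name plus its fields.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type")]
struct BPEDecoder {
    suffix: String,
}

fn main() -> serde_json::Result<()> {
    let dec = BPEDecoder { suffix: "</w>".into() };
    let json = serde_json::to_string(&dec)?;
    println!("{}", json); // {"type":"BPEDecoder","suffix":"</w>"}
    let _roundtrip: BPEDecoder = serde_json::from_str(&json)?;
    Ok(())
}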

View File

@@ -1,10 +1,11 @@
use crate::tokenizer::{Decoder, Result};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
-#[derive(Deserialize, Clone, Debug)]
+use serde::{Deserialize, Serialize};
+#[derive(Deserialize, Clone, Debug, Serialize)]
/// The WordPiece decoder takes care of decoding a list of wordpiece tokens
/// back into a readable string.
+#[serde(tag = "type")]
pub struct WordPiece {
/// The prefix to be used for continuing subwords
prefix: String,
@@ -48,16 +49,3 @@ impl Decoder for WordPiece {
Ok(output)
}
}
-impl Serialize for WordPiece {
-fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-where
-S: Serializer,
-{
-let mut m = serializer.serialize_struct("BPEDecoder", 3)?;
-m.serialize_field("type", "BPEDecoder")?;
-m.serialize_field("prefix", &self.prefix)?;
-m.serialize_field("cleanup", &self.cleanup)?;
-m.end()
-}
-}

View File

@@ -1,6 +1,6 @@
use crate::tokenizer::{NormalizedString, Normalizer, Result};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
+use serde::{Deserialize, Serialize};
use unicode_categories::UnicodeCategories;
/// Checks whether a character is whitespace
@@ -49,7 +49,8 @@ fn is_chinese_char(c: char) -> bool {
}
}
-#[derive(Copy, Clone, Debug, Deserialize)]
+#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
+#[serde(tag = "type")]
pub struct BertNormalizer {
/// Whether to do the bert basic cleaning:
/// 1. Remove any control characters
@@ -63,21 +64,6 @@ pub struct BertNormalizer {
lowercase: bool,
}
-impl Serialize for BertNormalizer {
-fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-where
-S: Serializer,
-{
-let mut m = serializer.serialize_struct("BertNormalizer", 5)?;
-m.serialize_field("type", "BertNormalizer")?;
-m.serialize_field("clean_text", &self.clean_text)?;
-m.serialize_field("handle_chinese_chars", &self.handle_chinese_chars)?;
-m.serialize_field("strip_accents", &self.strip_accents)?;
-m.serialize_field("lowercase", &self.lowercase)?;
-m.end()
-}
-}
impl Default for BertNormalizer {
fn default() -> Self {
Self {

View File

@@ -1,26 +1,13 @@
use crate::tokenizer::{NormalizedString, Normalizer, Result};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
+use serde::{Deserialize, Serialize};
-#[derive(Copy, Clone, Debug, Deserialize)]
+#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
+#[serde(tag = "type")]
pub struct Strip {
strip_left: bool,
strip_right: bool,
}
-impl Serialize for Strip {
-fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-where
-S: Serializer,
-{
-let mut m = serializer.serialize_struct("Strip", 5)?;
-m.serialize_field("type", "BertNormalizer")?;
-m.serialize_field("strip_left", &self.strip_left)?;
-m.serialize_field("strip_right", &self.strip_right)?;
-m.end()
-}
-}
impl Strip {
pub fn new(strip_left: bool, strip_right: bool) -> Self {
Strip {

View File

@@ -1,27 +1,16 @@
+use serde::{Deserialize, Serialize};
use crate::normalizers::NormalizerWrapper;
use crate::tokenizer::{NormalizedString, Normalizer, Result};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
-#[derive(Clone, Deserialize, Debug)]
+#[derive(Clone, Deserialize, Debug, Serialize)]
+#[serde(tag = "type")]
/// Allows concatenating multiple other Normalizer as a Sequence.
/// All the normalizers run in sequence in the given order against the same NormalizedString.
pub struct Sequence {
normalizers: Vec<NormalizerWrapper>,
}
-impl Serialize for Sequence {
-fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-where
-S: Serializer,
-{
-let mut m = serializer.serialize_struct("Sequence", 2)?;
-m.serialize_field("type", "Sequence")?;
-m.serialize_field("normalizers", &self.normalizers)?;
-m.end()
-}
-}
impl Sequence {
pub fn new(normalizers: Vec<NormalizerWrapper>) -> Self {
Self { normalizers }

View File

@@ -1,10 +1,11 @@
+use std::collections::{HashMap, HashSet};
+use onig::Regex;
+use serde::{Deserialize, Serialize};
use crate::tokenizer::{
normalizer::Range, Decoder, Encoding, PostProcessor, PreTokenizedString, PreTokenizer, Result,
};
-use onig::Regex;
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
-use std::collections::{HashMap, HashSet};
fn bytes_char() -> HashMap<u8, char> {
let mut bs: Vec<u8> = vec![];
@@ -38,10 +39,11 @@ lazy_static! {
bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
}
-#[derive(Deserialize, Copy, Clone, Debug)]
+#[derive(Deserialize, Serialize, Copy, Clone, Debug)]
/// Provides all the necessary steps to handle the BPE tokenization at the byte-level. Takes care
/// of all the required processing steps to transform a UTF-8 string as needed before and after the
/// BPE model does its job.
+#[serde(tag = "type")]
pub struct ByteLevel {
/// Whether to add a leading space to the first word. This allows to treat the leading word
/// just as any other word.
@@ -212,19 +214,6 @@ pub fn process_offsets(encoding: &mut Encoding, add_prefix_space: bool) {
});
}
-impl Serialize for ByteLevel {
-fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-where
-S: Serializer,
-{
-let mut m = serializer.serialize_struct("ByteLevel", 3)?;
-m.serialize_field("type", "ByteLevel")?;
-m.serialize_field("add_prefix_space", &self.add_prefix_space)?;
-m.serialize_field("trim_offsets", &self.trim_offsets)?;
-m.end()
-}
-}
#[cfg(test)]
mod tests {
use super::ByteLevel;

View File

@@ -1,9 +1,9 @@
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
+use serde::{Deserialize, Serialize};
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
-#[derive(Copy, Clone, Debug, Deserialize)]
+#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
+#[serde(tag = "type")]
pub struct CharDelimiterSplit {
delimiter: char,
}
@@ -22,15 +22,3 @@ impl PreTokenizer for CharDelimiterSplit {
})
}
}
-impl Serialize for CharDelimiterSplit {
-fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-where
-S: Serializer,
-{
-let mut m = serializer.serialize_struct("CharDelimiterSplit", 2)?;
-m.serialize_field("type", "CharDelimiterSplit")?;
-m.serialize_field("delimiter", &self.delimiter)?;
-m.end()
-}
-}

View File

@@ -1,11 +1,11 @@
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Serialize, Serializer};
+use serde::{Deserialize, Serialize};
use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
-#[derive(Deserialize, Clone, Debug)]
+#[derive(Serialize, Deserialize, Clone, Debug)]
/// Replaces all the whitespaces by the provided meta character and then
/// splits on this character
+#[serde(tag = "type")]
pub struct Metaspace {
replacement: char,
str_rep: String,
@@ -68,20 +68,6 @@ impl Decoder for Metaspace {
}
}
-impl Serialize for Metaspace {
-fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-where
-S: Serializer,
-{
-let mut m = serializer.serialize_struct("Metaspace", 3)?;
-m.serialize_field("type", "Metaspace")?;
-m.serialize_field("replacement", &self.replacement)?;
-m.serialize_field("str_rep", &self.str_rep)?;
-m.serialize_field("add_prefix_space", &self.add_prefix_space)?;
-m.end()
-}
-}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,16 +1,17 @@
use std::fmt;
use regex::Regex;
+use serde::de::{Error, Visitor};
-use serde::ser::SerializeStruct;
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use serde::{Deserialize, Deserializer, Serialize};
use crate::tokenizer::{
pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
};
-use serde::de::{Error, Visitor};
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, Serialize)]
+#[serde(tag = "type")]
pub struct Whitespace {
+#[serde(default = "default_regex", skip)]
re: Regex,
}
@@ -34,19 +35,8 @@ impl PreTokenizer for Whitespace {
}
}
-// manually implement serialize / deserialize because Whitespace is not a unit-struct but is
+// manually implement deserialize because Whitespace is not a unit-struct but is
// serialized like one.
-impl Serialize for Whitespace {
-fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
-where
-S: Serializer,
-{
-let mut m = serializer.serialize_struct("Whitespace", 1)?;
-m.serialize_field("type", "Whitespace")?;
-m.end()
-}
-}
impl<'de> Deserialize<'de> for Whitespace {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
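
In this file the re: Regex field has no Serialize impl, so it is excluded with #[serde(skip)]; combined with the "type" tag, the struct then serializes like a unit struct, which is why (per the updated comment above) only Deserialize remains hand-written. A standalone sketch of the skip-plus-tag behavior, using a stand-in type for the regex field:

use serde::Serialize;

// Stand-in for a field type that does not implement Serialize (like regex::Regex).
#[derive(Default)]
struct CompiledPattern;

#[derive(Serialize, Default)]
#[serde(tag = "type")]
struct Whitespace {
    // The skipped field never appears in the output, so only the tag is emitted.
    #[serde(skip)]
    re: CompiledPattern,
}

fn main() -> serde_json::Result<()> {
    let json = serde_json::to_string(&Whitespace::default())?;
    assert_eq!(json, r#"{"type":"Whitespace"}"#);
    Ok(())
}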

View File

@@ -657,10 +657,6 @@ where
/// ```
/// # use tokenizers::Tokenizer;
/// # use tokenizers::models::bpe::BPE;
-/// # use tokenizers::normalizers::NormalizerWrapper;
-/// # use tokenizers::pre_tokenizers::PreTokenizerWrapper;
-/// # use tokenizers::processors::PostProcessorWrapper;
-/// # use tokenizers::decoders::DecoderWrapper;
/// # let mut tokenizer = Tokenizer::new(BPE::default());
/// #
/// // Sequences:

View File

@@ -125,6 +125,15 @@ fn pretoks() {
// wrapped serializes same way as inner
let ser_wrapped = serde_json::to_string(&ch_wrapped).unwrap();
assert_eq!(ser_wrapped, ch_ser);
+let wsp = Whitespace::default();
+let wsp_ser = serde_json::to_string(&wsp).unwrap();
+serde_json::from_str::<Whitespace>(&wsp_ser).unwrap();
+let err: Result<BertPreTokenizer, _> = serde_json::from_str(&wsp_ser);
+assert!(
+err.is_err(),
+"BertPreTokenizer shouldn't be deserializable from Whitespace"
+);
}
#[test]
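
The new test added above relies on the derived, internally tagged Deserialize impls rejecting a payload whose "type" tag names a different struct. A standalone sketch of that round-trip and mismatch check, using illustrative stand-in types rather than the crate's own pre-tokenizers:

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(tag = "type")]
struct Whitespace {}

#[derive(Serialize, Deserialize, Debug, Default)]
#[serde(tag = "type")]
struct BertPreTokenizer {}

fn main() -> serde_json::Result<()> {
    let ser = serde_json::to_string(&Whitespace::default())?; // {"type":"Whitespace"}
    // The same type deserializes back fine...
    let _ok: Whitespace = serde_json::from_str(&ser)?;
    // ...while a type expecting a different tag value refuses the payload.
    assert!(serde_json::from_str::<BertPreTokenizer>(&ser).is_err());
    Ok(())
}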