Merge pull request #19 from huggingface/handle-offsets

Handle offsets
Authored by MOI Anthony on 2019-12-30 10:46:30 -05:00; committed by GitHub
20 changed files with 826 additions and 401 deletions

View File

@ -456,14 +456,14 @@ dependencies = [
"regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-normalization 0.1.11 (git+https://github.com/n1t0/unicode-normalization)",
"unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "unicode-normalization"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
source = "git+https://github.com/n1t0/unicode-normalization#894053d92493c55c89fe9b188c0fb2babaa9a84c"
dependencies = [
"smallvec 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
@ -570,7 +570,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum syn 1.0.11 (registry+https://github.com/rust-lang/crates.io-index)" = "dff0acdb207ae2fe6d5976617f887eb1e35a2ba52c13c7234c790960cdad9238"
"checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
"checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b"
"checksum unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)" = "b561e267b2326bb4cebfc0ef9e68355c7abe6c6f522aeac2f5bf95d56c59bdcf"
"checksum unicode-normalization 0.1.11 (git+https://github.com/n1t0/unicode-normalization)" = "<none>"
"checksum unicode-width 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "caaa9d531767d1ff2150b9332433f32a24622147e5ebb1f26409d5da67afd479"
"checksum unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c"
"checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"

View File

@ -6,7 +6,7 @@ import logging
logging.getLogger('transformers').disabled = True
logging.getLogger('transformers.tokenization_utils').disabled = True
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors, normalizers
from transformers import GPT2Tokenizer, BertTokenizer
parser = argparse.ArgumentParser()
@ -61,8 +61,19 @@ elif args.type == "bert":
print("Running Bert tokenizer")
tok_p = BertTokenizer.from_pretrained('bert-base-uncased')
tok_r = Tokenizer(models.WordPiece.from_files(args.vocab, unk_token="[UNK]", max_input_chars_per_word=100))
tok_r.with_pre_tokenizer(pre_tokenizers.BasicPreTokenizer.new(do_lower_case=True, tokenize_chinese_chars=True, never_split=[]))
tok_r = Tokenizer(models.WordPiece.from_files(
args.vocab,
unk_token="[UNK]",
max_input_chars_per_word=100)
)
tok_r.with_normalizer(normalizers.BertNormalizer.new(
clean_text=True,
handle_chinese_chars=True,
strip_accents=True,
lowercase=True,
))
# tok_r.with_pre_tokenizer(pre_tokenizers.Whitespace.new())
tok_r.with_pre_tokenizer(pre_tokenizers.BertPreTokenizer.new())
tok_r.with_decoder(decoders.WordPiece.new())
tok_r.with_post_processor(processors.BertProcessing.new(
("[SEP]", tok_r.token_to_id("[SEP]")),
@ -75,7 +86,7 @@ def tokenize_r():
return tok_r.encode_batch(text);
def tokenize_p():
return [tok_p.encode(sentence) for sentence in tqdm(text)]
return [tok_p.encode(sentence, add_special_tokens=True) for sentence in tqdm(text)]
print(f"Tokenizing {len(text)} lines")
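For reference, a minimal Rust sketch (not part of this diff) of the pipeline the bindings above now configure: BertNormalizer takes over the cleaning, accent stripping and lowercasing that the removed BasicPreTokenizer used to do, while BertPreTokenizer only splits and reports offsets into the normalized text. The sample sentence and the expected values are illustrative.

```rust
use tokenizers::normalizers::bert::BertNormalizer;
use tokenizers::pre_tokenizers::bert::BertPreTokenizer;
use tokenizers::tokenizer::{NormalizedString, Normalizer, PreTokenizer};

fn main() {
    // Same flags as the Python call above: clean_text, handle_chinese_chars,
    // strip_accents, lowercase
    let normalizer = BertNormalizer::new(true, true, true, true);

    let mut normalized = NormalizedString::from("Hey friend!");
    normalizer.normalize(&mut normalized).unwrap();
    assert_eq!(normalized.get(), "hey friend!");

    // Pre-tokenization now also returns offsets into the normalized text
    let pretok = BertPreTokenizer;
    let words = pretok.pre_tokenize(normalized.get()).unwrap();
    assert_eq!(
        words,
        vec![
            ("hey".to_string(), (0, 3)),
            ("friend".to_string(), (4, 10)),
            ("!".to_string(), (10, 11)),
        ]
    );
}
```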

View File

@ -1,6 +1,7 @@
extern crate tokenizers as tk;
use pyo3::prelude::*;
use pyo3::types::*;
#[pyclass(dict)]
#[repr(transparent)]
@ -16,15 +17,43 @@ impl Encoding {
#[pymethods]
impl Encoding {
// #[getter]
// fn get_original(&self) -> String {
// self.encoding.get_original().to_owned()
// }
#[getter]
fn get_original(&self) -> String {
self.encoding.get_normalized().get_original().to_owned()
}
// #[getter]
// fn get_normalized(&self) -> String {
// self.encoding.get_normalized().to_owned()
// }
#[getter]
fn get_normalized(&self) -> String {
self.encoding.get_normalized().get().to_owned()
}
#[args(kwargs = "**")]
fn get_range(
&self,
range: (usize, usize),
kwargs: Option<&PyDict>,
) -> PyResult<Option<String>> {
let mut original = false;
if let Some(kwargs) = kwargs {
if let Some(koriginal) = kwargs.get_item("original") {
original = koriginal.extract()?;
}
}
if original {
Ok(self
.encoding
.get_normalized()
.get_range_original(range.0..range.1)
.map(|s| s.to_owned()))
} else {
Ok(self
.encoding
.get_normalized()
.get_range(range.0..range.1)
.map(|s| s.to_owned()))
}
}
#[getter]
fn get_ids(&self) -> Vec<u32> {
@ -41,10 +70,10 @@ impl Encoding {
self.encoding.get_type_ids().to_vec()
}
// #[getter]
// fn get_offsets(&self) -> Vec<(usize, usize)> {
// self.encoding.get_offsets().to_vec()
// }
#[getter]
fn get_offsets(&self) -> Vec<(usize, usize)> {
self.encoding.get_offsets().to_vec()
}
#[getter]
fn get_special_tokens_mask(&self) -> Vec<u32> {
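The new `original` keyword on `get_range` simply switches between the two lookups on the underlying NormalizedString. A small sketch of the difference, mirroring the `original_range` test added later in this PR:

```rust
use tokenizers::tokenizer::NormalizedString;

fn main() {
    let mut n = NormalizedString::from("Hello_______ World!");
    n.filter(|c| *c != '_').lowercase();

    // Range taken in the normalized string (what `get_range(...)` returns)
    assert_eq!(n.get_range(6..11), Some("world"));
    // Same range mapped back onto the original input
    // (what `get_range(..., original=True)` returns via `get_range_original`)
    assert_eq!(n.get_range_original(6..11), Some("World"));
}
```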

View File

@ -2,6 +2,7 @@ mod decoders;
mod encoding;
mod error;
mod models;
mod normalizers;
mod pre_tokenizers;
mod processors;
mod token;
@ -55,6 +56,14 @@ fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
Ok(())
}
/// Normalizers Module
#[pymodule]
fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<normalizers::Normalizer>()?;
m.add_class::<normalizers::BertNormalizer>()?;
Ok(())
}
/// Tokenizers Module
#[pymodule]
fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
@ -63,6 +72,7 @@ fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(pre_tokenizers))?;
m.add_wrapped(wrap_pymodule!(decoders))?;
m.add_wrapped(wrap_pymodule!(processors))?;
m.add_wrapped(wrap_pymodule!(normalizers))?;
m.add_wrapped(wrap_pymodule!(trainers))?;
Ok(())
}

View File

@ -0,0 +1,46 @@
extern crate tokenizers as tk;
use super::utils::Container;
use pyo3::prelude::*;
use pyo3::types::*;
#[pyclass(dict)]
pub struct Normalizer {
pub normalizer: Container<dyn tk::tokenizer::Normalizer + Sync>,
}
#[pyclass]
pub struct BertNormalizer {}
#[pymethods]
impl BertNormalizer {
#[staticmethod]
#[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<Normalizer> {
let mut clean_text = true;
let mut handle_chinese_chars = true;
let mut strip_accents = true;
let mut lowercase = true;
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
match key {
"clean_text" => clean_text = value.extract()?,
"handle_chinese_chars" => handle_chinese_chars = value.extract()?,
"strip_accents" => strip_accents = value.extract()?,
"lowercase" => lowercase = value.extract()?,
_ => println!("Ignored unknown kwargs option {}", key),
}
}
}
Ok(Normalizer {
normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new(
clean_text,
handle_chinese_chars,
strip_accents,
lowercase,
))),
})
}
}

View File

@ -4,8 +4,7 @@ use super::error::{PyError, ToPyResult};
use super::utils::Container;
use pyo3::prelude::*;
use pyo3::types::*;
use std::collections::HashSet;
use tk::tokenizer::Result;
use tk::tokenizer::{Offsets, Result};
#[pyclass(dict)]
pub struct PreTokenizer {
@ -21,7 +20,7 @@ impl PreTokenizer {
})
}
fn pre_tokenize(&self, s: &str) -> PyResult<Vec<String>> {
fn pre_tokenize(&self, s: &str) -> PyResult<Vec<(String, Offsets)>> {
ToPyResult(self.pretok.execute(|pretok| pretok.pre_tokenize(s))).into()
}
}
@ -58,36 +57,9 @@ pub struct BertPreTokenizer {}
#[pymethods]
impl BertPreTokenizer {
#[staticmethod]
#[args(kwargs = "**")]
fn new(kwargs: Option<&PyDict>) -> PyResult<PreTokenizer> {
let mut do_basic_tokenize = true;
let mut do_lower_case = true;
let mut never_split = HashSet::new();
let mut tokenize_chinese_chars = true;
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
let key: &str = key.extract()?;
match key {
"do_basic_tokenize" => do_basic_tokenize = val.extract()?,
"do_lower_case" => do_lower_case = val.extract()?,
"tokenize_chinese_chars" => tokenize_chinese_chars = val.extract()?,
"never_split" => {
let values: Vec<String> = val.extract()?;
never_split = values.into_iter().collect();
}
_ => println!("Ignored unknown kwargs option {}", key),
}
}
}
fn new() -> PyResult<PreTokenizer> {
Ok(PreTokenizer {
pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer::new(
do_basic_tokenize,
do_lower_case,
never_split,
tokenize_chinese_chars,
))),
pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer)),
})
}
}
@ -104,7 +76,7 @@ impl PyPreTokenizer {
}
impl tk::tokenizer::PreTokenizer for PyPreTokenizer {
fn pre_tokenize(&self, sentence: &str) -> Result<Vec<String>> {
fn pre_tokenize(&self, sentence: &str) -> Result<Vec<(String, Offsets)>> {
let gil = Python::acquire_gil();
let py = gil.python();
@ -112,9 +84,15 @@ impl tk::tokenizer::PreTokenizer for PyPreTokenizer {
match self.class.call_method(py, "pre_tokenize", args, None) {
Ok(res) => Ok(res
.cast_as::<PyList>(py)
.map_err(|_| PyError::from("`pre_tokenize is expected to return a List[str]"))?
.extract::<Vec<String>>()
.map_err(|_| PyError::from("`pre_tokenize` is expected to return a List[str]"))?),
.map_err(|_| {
PyError::from("`pre_tokenize` is expected to return a List[(str, (uint, uint))]")
})?
.extract::<Vec<(String, Offsets)>>()
.map_err(|_| {
PyError::from(
"`pre_tokenize` is expected to return a List[(str, (uint, uint))]",
)
})?),
Err(e) => {
e.print(py);
Err(Box::new(PyError::from(

View File

@ -8,6 +8,7 @@ use super::decoders::Decoder;
use super::encoding::Encoding;
use super::error::{PyError, ToPyResult};
use super::models::Model;
use super::normalizers::Normalizer;
use super::pre_tokenizers::PreTokenizer;
use super::processors::PostProcessor;
use super::trainers::Trainer;
@ -97,6 +98,17 @@ impl Tokenizer {
}
}
fn with_normalizer(&mut self, normalizer: &mut Normalizer) -> PyResult<()> {
if let Some(normalizer) = normalizer.normalizer.to_pointer() {
self.tokenizer.with_normalizer(normalizer);
Ok(())
} else {
Err(exceptions::Exception::py_err(
"The Normalizer is already being used in another Tokenizer",
))
}
}
#[args(kwargs = "**")]
fn with_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut stride = 0;

View File

@ -1,3 +1,3 @@
__version__ = "0.0.11"
from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers, processors
from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers, processors, normalizers

View File

@ -19,7 +19,7 @@ regex-syntax = "0.6.12"
rayon = "1.2.0"
serde_json = "1.0"
clap = "2.33.0"
unicode-normalization = "0.1.11"
unicode-normalization = { git = "https://github.com/n1t0/unicode-normalization" }
unicode_categories = "0.1.1"
[dev-dependencies]

View File

@ -1,5 +1,5 @@
use super::{Cache, Error, Pair, Word};
use crate::tokenizer::{Model, Result, Token};
use crate::tokenizer::{Model, Offsets, Result, Token};
use serde_json::Value;
use std::{
collections::HashMap,
@ -103,15 +103,20 @@ impl Model for BPE {
self.vocab.len()
}
fn tokenize(&self, sentence: Vec<String>) -> Result<Vec<Token>> {
fn tokenize(&self, sentence: Vec<(String, Offsets)>) -> Result<Vec<Token>> {
if sentence.is_empty() {
return Ok(vec![]);
}
let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
let mut cached_words = self.cache.get_values(&sentence);
let mut cached_words = self.cache.get_values(
&sentence
.iter()
.map(|(s, _)| s.to_owned())
.collect::<Vec<_>>(),
);
for (i, w) in sentence.iter().enumerate() {
for (i, (w, initial_offsets)) in sentence.iter().enumerate() {
if cached_words[i].is_none() {
let mut word = Word::new();
for c in w.chars() {
@ -155,9 +160,6 @@ impl Model for BPE {
cached_words[i] = Some(word);
}
// Offsets are word-based, we need to translate them to be sentence-based
let last_offset = encoded.last().map(|token| token.offsets.1).unwrap_or(0);
let word = cached_words[i].as_ref().unwrap();
let tokens = word
.get_chars()
@ -167,7 +169,7 @@ impl Model for BPE {
Token::new(
*id,
self.vocab_r[id].clone(),
(last_offset + offsets.0, last_offset + offsets.1),
(initial_offsets.0 + offsets.0, initial_offsets.0 + offsets.1),
)
})
.collect::<Vec<_>>();
@ -180,7 +182,7 @@ impl Model for BPE {
.into_iter()
.zip(cached_words)
.filter(|(_, v)| v.is_some())
.map(|(k, v)| (k, v.unwrap()))
.map(|(k, v)| (k.0, v.unwrap()))
.unzip::<_, _, Vec<String>, Vec<Word>>();
self.cache.set_values(keys, values);

View File

@ -1,4 +1,4 @@
use crate::tokenizer::{Model, Result, Token};
use crate::tokenizer::{Model, Offsets, Result, Token};
use std::{
collections::HashMap,
fmt,
@ -70,11 +70,10 @@ impl Model for WordPiece {
self.vocab.len()
}
fn tokenize(&self, sentence: Vec<String>) -> Result<Vec<Token>> {
fn tokenize(&self, sentence: Vec<(String, Offsets)>) -> Result<Vec<Token>> {
let mut output_tokens = vec![];
let mut offset = 0usize;
for token in sentence {
for (token, initial_offsets) in sentence {
let char_len = token.chars().count();
if char_len > self.max_input_chars_per_word {
output_tokens.push(Token {
@ -83,7 +82,7 @@ impl Model for WordPiece {
.vocab
.get(&self.unk_token)
.ok_or(Error::MissingUnkToken)?,
offsets: (offset, offset + char_len),
offsets: initial_offsets,
});
continue;
}
@ -106,7 +105,7 @@ impl Model for WordPiece {
cur_str = Some(Token {
id: self.vocab[&substr],
value: substr,
offsets: (offset + start, offset + end),
offsets: (initial_offsets.0 + start, initial_offsets.0 + end),
});
break;
}
@ -129,13 +128,11 @@ impl Model for WordPiece {
.vocab
.get(&self.unk_token)
.ok_or(Error::MissingUnkToken)?,
offsets: (offset, offset + char_len),
offsets: initial_offsets,
});
} else {
output_tokens.extend(sub_tokens);
}
offset += char_len;
}
Ok(output_tokens)

View File

@ -0,0 +1,122 @@
use crate::tokenizer::{NormalizedString, Normalizer, Result};
use unicode_categories::UnicodeCategories;
/// Checks whether a character is whitespace
fn is_whitespace(c: char) -> bool {
// These are technically control characters but we count them as whitespace
if c == '\t' || c == '\n' || c == '\r' {
true
} else {
c.is_whitespace()
}
}
/// Checks whether a character is a control character
fn is_control(c: char) -> bool {
// These are technically control characters but we count them as whitespace
if c == '\t' || c == '\n' || c == '\r' {
false
} else {
// The definition of `is_control` here is quite large and contains also
// Cc, Cf, Cn or Co
// cf. https://unicode.org/reports/tr44/ (Table 12)
c.is_other()
}
}
/// Checks whether a character is Chinese
/// This defines a "chinese character" as anything in the CJK Unicode block:
/// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
///
/// Note that the CJK Unicode block is NOT all Japanese and Korean characters,
/// despite its name. The modern Korean Hangul alphabet is a different block,
/// as are Japanese Hiragana and Katakana. Those alphabets are used to write
/// space-separated words, so they are not treated specially and are handled
/// like all of the other languages.
fn is_chinese_char(c: char) -> bool {
match c as usize {
0x4E00..=0x9FFF => true,
0x3400..=0x4DBF => true,
0x20000..=0x2A6DF => true,
0x2A700..=0x2B73F => true,
0x2B740..=0x2B81F => true,
0x2B920..=0x2CEAF => true,
0xF900..=0xFAFF => true,
0x2F800..=0x2FA1F => true,
_ => false,
}
}
pub struct BertNormalizer {
/// Whether to do the bert basic cleaning:
/// 1. Remove any control characters
/// 2. Replace all sorts of whitespace by the classic one ` `
clean_text: bool,
/// Whether to put spaces around chinese characters so they get split
handle_chinese_chars: bool,
/// Whether to strip accents
strip_accents: bool,
/// Whether to lowercase the input
lowercase: bool,
}
impl BertNormalizer {
pub fn new(
clean_text: bool,
handle_chinese_chars: bool,
strip_accents: bool,
lowercase: bool,
) -> Self {
BertNormalizer {
clean_text,
handle_chinese_chars,
strip_accents,
lowercase,
}
}
fn do_clean_text(&self, normalized: &mut NormalizedString) {
normalized
.filter(|c| !(*c as usize == 0 || *c as usize == 0xfffd || is_control(*c)))
.map(|c| if is_whitespace(c) { ' ' } else { c });
}
fn do_handle_chinese_chars(&self, normalized: &mut NormalizedString) {
let mut new_chars: Vec<(char, isize)> = vec![];
normalized.for_each(|c| {
if is_chinese_char(c) {
new_chars.extend(&[(' ', 1), (c, 0), (' ', 1)]);
} else {
new_chars.push((c, 0));
}
});
normalized.transform(new_chars.into_iter());
}
fn do_strip_accents(&self, normalized: &mut NormalizedString) {
normalized.nfd().filter(|c| !c.is_mark_nonspacing());
}
fn do_lowercase(&self, normalized: &mut NormalizedString) {
normalized.lowercase();
}
}
impl Normalizer for BertNormalizer {
fn normalize(&self, mut normalized: &mut NormalizedString) -> Result<()> {
if self.clean_text {
self.do_clean_text(&mut normalized);
}
if self.handle_chinese_chars {
self.do_handle_chinese_chars(&mut normalized);
}
if self.strip_accents {
self.do_strip_accents(&mut normalized);
}
if self.lowercase {
self.do_lowercase(&mut normalized);
}
Ok(())
}
}
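A small usage sketch (not part of the diff) showing the handle_chinese_chars step in isolation: the padding spaces are inserted through NormalizedString::transform, so the normalized text grows while get_original still returns the untouched input. The input string is illustrative.

```rust
use tokenizers::normalizers::bert::BertNormalizer;
use tokenizers::tokenizer::{NormalizedString, Normalizer};

fn main() {
    // Only pad CJK characters; leave cleaning, accents and casing alone.
    let normalizer = BertNormalizer::new(false, true, false, false);

    let mut n = NormalizedString::from("abc中def");
    normalizer.normalize(&mut n).unwrap();

    assert_eq!(n.get(), "abc 中 def");
    assert_eq!(n.get_original(), "abc中def");
}
```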

View File

@ -1 +1 @@
pub mod bert;

View File

@ -1,182 +1,76 @@
use crate::tokenizer::{PreTokenizer, Result};
use std::collections::HashSet;
use unicode_categories::UnicodeCategories;
use unicode_normalization::UnicodeNormalization;
use crate::tokenizer::{Offsets, PreTokenizer, Result};
/// Extremely simple tokenization on whitespaces
fn whitespace_tokenize(s: &str) -> Vec<&str> {
s.trim()
.split(char::is_whitespace)
.filter(|s| *s != " ")
.collect()
}
/// Checks whether a character is whitespace
fn is_whitespace(c: char) -> bool {
// These are technically control characters but we count them as whitespace
if c == '\t' || c == '\n' || c == '\r' {
true
} else {
c.is_whitespace()
/// Split the given string as the `should_split` predicate dictates. Keep track of the offsets
fn split_on<F: Fn(&char) -> bool>(
s: &str,
should_split: F,
include_split_token: bool,
) -> Vec<(String, Offsets)> {
let mut words: Vec<(String, Offsets)> = vec![];
let mut offset = 0;
let mut word = Vec::with_capacity(50);
s.chars().for_each(|c| {
if should_split(&c) {
if !word.is_empty() {
let offsets = (offset - word.len(), offset);
words.push((word.drain(0..).collect::<String>(), offsets));
}
}
/// Checks whether a character is a control character
fn is_control(c: char) -> bool {
// These are technically control characters but we count them as whitespace
if c == '\t' || c == '\n' || c == '\r' {
false
} else {
// The definition of `is_control` here is quite large and contains also
// Cc, Cf, Cn or Co
// cf. https://unicode.org/reports/tr44/ (Table 12)
c.is_other()
if include_split_token {
words.push((c.to_string(), (offset, offset + 1)));
}
}
/// Checks whether a character is chinese
/// This defines a "chinese character" as anything in the CJK Unicode block:
/// https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
///
/// Note that the CJK Unicode block is NOT all Japanese and Korean characters,
/// despite its name. The modern Korean Hangul alphabet is a different block,
/// as is Japanese Hiragana and Katakana. Those alphabets are used to write
/// space-separated words, so they are not treated specially and handled
/// like for all of the other languages.
fn is_chinese_char(c: char) -> bool {
match c as usize {
0x4E00..=0x9FFF => true,
0x3400..=0x4DBF => true,
0x20000..=0x2A6DF => true,
0x2A700..=0x2B73F => true,
0x2B740..=0x2B81F => true,
0x2B920..=0x2CEAF => true,
0xF900..=0xFAFF => true,
0x2F800..=0x2FA1F => true,
_ => false,
}
}
pub struct BertPreTokenizer {
/// Whether to do the basic tokenization
do_basic_tokenize: bool,
/// Whether to lower case the input.
do_lower_case: bool,
/// A list of token not to split.
never_split: HashSet<String>,
/// Whether to tokenize Chinese characters
tokenize_chinese_chars: bool,
}
impl BertPreTokenizer {
pub fn new(
do_basic_tokenize: bool,
do_lower_case: bool,
never_split: HashSet<String>,
tokenize_chinese_chars: bool,
) -> Self {
BertPreTokenizer {
do_basic_tokenize,
do_lower_case,
never_split,
tokenize_chinese_chars,
}
}
/// Strips accents from a piece of text
fn run_strip_accents(&self, text: &str) -> String {
text.nfd()
.filter(|c| !c.is_mark_nonspacing())
.collect::<String>()
}
/// Splits punctuation on a piece of text.
fn run_split_on_punc(&self, text: &str) -> Vec<String> {
if self.never_split.contains(text) {
return vec![text.to_owned()];
}
let mut output: Vec<Vec<char>> = vec![];
let mut start_new_word = true;
text.chars().for_each(|c| {
if c.is_ascii_punctuation() {
output.push(vec![c]);
start_new_word = true;
} else {
if start_new_word {
output.push(vec![]);
}
start_new_word = false;
output.last_mut().unwrap().push(c);
} else if !should_split(&c) {
word.push(c);
}
offset += 1;
});
output
.into_iter()
.map(|cs| cs.into_iter().collect::<String>())
.collect()
// Don't forget the potential last word
if !word.is_empty() {
let offsets = (offset - word.len(), offset);
words.push((word.drain(0..).collect::<String>(), offsets));
}
fn tokenize_chinese_chars(&self, text: &str) -> String {
text.chars()
.map(|c| {
if is_chinese_char(c) {
vec![' ', c, ' ']
} else {
vec![c]
}
})
.flatten()
.collect::<String>()
}
fn clean_text(&self, text: &str) -> String {
text.chars()
.map(|c| {
if c as usize == 0 || c as usize == 0xfffd || is_control(c) {
None
} else if is_whitespace(c) {
Some(' ')
} else {
Some(c)
}
})
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect::<String>()
}
words
}
pub struct BertPreTokenizer;
impl PreTokenizer for BertPreTokenizer {
fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
if !self.do_basic_tokenize {
Ok(whitespace_tokenize(&s)
.into_iter()
.map(|s| s.to_owned())
.collect())
} else {
let mut text = self.clean_text(s);
// This was added on November 1st, 2018 for the multilingual and Chinese
// models. This is also applied to the English models now, but it doesn't
// matter since the English models were not trained on any Chinese data
// and generally don't have any Chinese data in them (there are Chinese
// characters in the vocabulary because Wikipedia does have some Chinese
// words in the English Wikipedia.).
if self.tokenize_chinese_chars {
text = self.tokenize_chinese_chars(&text);
}
let orig_tokens = whitespace_tokenize(&text);
fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>> {
let mut split_tokens = vec![];
for token in orig_tokens {
let mut tk = token.to_owned();
if self.do_lower_case && !self.never_split.contains(token) {
tk = self.run_strip_accents(&token.to_lowercase())
for (token, offsets) in split_on(&s, |c| char::is_whitespace(*c), false) {
split_tokens.extend(
split_on(&token, char::is_ascii_punctuation, true)
.into_iter()
.map(|(tok, off)| (tok, (off.0 + offsets.0, off.1 + offsets.0))),
);
}
split_tokens.extend(self.run_split_on_punc(&tk));
}
Ok(split_tokens)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn basic() {
let pretok = BertPreTokenizer;
let res = pretok
.pre_tokenize("Hey friend! How are you?!?")
.unwrap();
assert_eq!(
&res,
&[
("Hey".into(), (0, 3)),
("friend".into(), (4, 10)),
("!".into(), (10, 11)),
("How".into(), (16, 19)),
("are".into(), (20, 23)),
("you".into(), (24, 27)),
("?".into(), (27, 28)),
("!".into(), (28, 29)),
("?".into(), (29, 30)),
]
);
}
}

View File

@ -1,4 +1,4 @@
use crate::tokenizer::{Decoder, PreTokenizer, Result};
use crate::tokenizer::{Decoder, Offsets, PreTokenizer, Result};
use regex::Regex;
use std::collections::HashMap;
use unicode_categories::UnicodeCategories;
@ -41,7 +41,7 @@ impl ByteLevel {
}
impl PreTokenizer for ByteLevel {
fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>> {
let s = if self.add_prefix_space && !s.starts_with(' ') {
format!(" {}", s)
} else {
@ -59,19 +59,11 @@ impl PreTokenizer for ByteLevel {
// we don't want to return it
let last = s[start..end].chars().last();
let next = s[end..].chars().nth(0);
if last.is_some()
&& last.unwrap().is_separator_space()
&& next.is_some()
&& !next.unwrap().is_separator_space()
{
if let Some(newstr) = s[start..end]
.chars()
.collect::<Vec<_>>()
.split_last()
.map(|(_, rest)| rest)
.map(|chars| chars.iter().collect::<String>())
{
return newstr;
if let (Some(last), Some(next)) = (last, next) {
if last.is_separator_space() && !next.is_separator_space() {
let bytes = s[start..end - 1].as_bytes().to_vec();
let offsets = (start, end - 1);
return (bytes, offsets);
}
}
// if our first char is not a whitespace but the previous one was, we return
@ -80,17 +72,22 @@ impl PreTokenizer for ByteLevel {
let current = s[start..end].chars().nth(0).map(|c| c.is_whitespace());
if let (Some(prev), Some(current)) = (prev, current) {
if prev.is_separator_space() && !current {
return format!("{}{}", prev, s[start..end].to_owned());
let bytes =
[format!("{}", prev).as_bytes(), s[start..end].as_bytes()].concat();
let offsets = (start - 1, end);
return (bytes, offsets);
}
}
s[start..end].to_owned()
(s[start..end].as_bytes().to_vec(), (start, end))
})
.map(|s| {
s.into_bytes()
.iter()
.map(|(s, offsets)| {
(
s.iter()
.map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap())
.collect()
.collect(),
offsets,
)
})
.collect())
}
@ -122,7 +119,16 @@ mod tests {
.pre_tokenize("Hello my friend, how is your day going?")
.unwrap(),
vec![
"Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
("Hello".into(), (0, 5)),
("Ġmy".into(), (5, 8)),
("Ġfriend".into(), (8, 15)),
(",".into(), (15, 16)),
("Ġhow".into(), (16, 20)),
("Ġis".into(), (20, 23)),
("Ġyour".into(), (23, 28)),
("Ġday".into(), (28, 32)),
("Ġgoing".into(), (32, 38)),
("?".into(), (38, 39))
]
);
}
@ -154,7 +160,16 @@ mod tests {
.pre_tokenize("Hello my friend, how is your day going?")
.unwrap(),
vec![
"ĠHello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
("ĠHello".into(), (0, 6)),
("Ġmy".into(), (6, 9)),
("Ġfriend".into(), (9, 16)),
(",".into(), (16, 17)),
("Ġhow".into(), (17, 21)),
("Ġis".into(), (21, 24)),
("Ġyour".into(), (24, 29)),
("Ġday".into(), (29, 33)),
("Ġgoing".into(), (33, 39)),
("?".into(), (39, 40))
]
);
}
@ -176,7 +191,7 @@ mod tests {
let pre_tokenized = bl.pre_tokenize(&sample).unwrap();
let separated_tokens = pre_tokenized
.into_iter()
.map(|token| token.split("").map(|t| t.into()).collect::<Vec<_>>())
.map(|(token, _)| token.split("").map(|t| t.into()).collect::<Vec<_>>())
.flatten()
.collect::<Vec<_>>();
assert_eq!(sample, bl.decode(separated_tokens).unwrap());
@ -192,11 +207,11 @@ mod tests {
assert_eq!(
p,
vec![
String::from("Hello"),
String::from("Ġthere"),
String::from("Ċ"),
String::from("Hello"),
String::from("Ġthere")
("Hello".into(), (0, 5)),
("Ġthere".into(), (5, 11)),
("Ċ".into(), (11, 12)),
("Hello".into(), (12, 17)),
("Ġthere".into(), (17, 23))
]
);
}
@ -210,10 +225,10 @@ mod tests {
assert_eq!(
p,
vec![
String::from("Hello"),
String::from("Ġthere"),
String::from("ĠĠĠĠĠĠ"),
String::from("Ġdear")
("Hello".into(), (0, 5)),
("Ġthere".into(), (5, 11)),
("ĠĠĠĠĠĠ".into(), (11, 17)),
("Ġdear".into(), (17, 22))
]
);
}

View File

@ -1,9 +1,9 @@
use crate::tokenizer::{PreTokenizer, Result};
use crate::tokenizer::{Offsets, PreTokenizer, Result};
use regex::Regex;
pub struct Whitespace;
impl PreTokenizer for Whitespace {
fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>> {
lazy_static! {
static ref RE: Regex = Regex::new(r"\w+|[^\w\s]+").unwrap();
}
@ -13,11 +13,15 @@ impl PreTokenizer for Whitespace {
captures
.iter()
.map(|m| {
m.map(|capture| s[capture.start()..capture.end()].to_owned())
.unwrap_or_else(|| String::from(""))
m.map(|capture| {
let (start, end) = (capture.start(), capture.end());
(s[start..end].to_owned(), (start, end))
})
.collect()
.unwrap_or_else(|| (String::from(""), (0, 0)))
})
.collect::<Vec<(String, Offsets)>>()
})
.flatten()
.collect())
}
}
@ -30,10 +34,23 @@ mod tests {
#[test]
fn basic() {
let tests = vec![
("Hey man!", vec!["Hey", "man", "!"]),
(
"Hey man!",
vec![
("Hey".into(), (0, 3)),
("man".into(), (4, 7)),
("!".into(), (7, 8)),
],
),
(
"How are you doing?",
vec!["How", "are", "you", "doing", "?"],
vec![
("How".into(), (0, 3)),
("are".into(), (4, 7)),
("you".into(), (8, 11)),
("doing".into(), (12, 17)),
("?".into(), (17, 18)),
],
),
];
let pretok = Whitespace;

View File

@ -25,70 +25,52 @@ impl PostProcessor for BertProcessing {
}
fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
// Prepare ids
let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
let pair_ids = pair_encoding
.as_ref()
.map(|encoding| [&encoding.get_ids()[..], &[self.sep.1]].concat());
// Prepare tokens
let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
let tokens = [
&[self.cls.0.clone()],
&encoding.get_tokens()[..],
&[self.sep.0.clone()],
]
.concat();
let pair_tokens = pair_encoding
.as_ref()
.map(|encoding| [&encoding.get_tokens()[..], &[self.sep.0.clone()]].concat());
// Prepare offsets
let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
let pair_offsets = pair_encoding
.as_ref()
.map(|encoding| [&encoding.get_offsets()[..], &[(0, 0)]].concat());
// Prepare type ids
let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
let pair_type_ids = pair_encoding
.as_ref()
.map(|encoding| [&encoding.get_type_ids()[..], &[1]].concat());
let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
let pair_special_tokens = pair_encoding
.as_ref()
.map(|encoding| [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat());
let attention_mask = vec![1; ids.len()];
let attention_mask = vec![1; ids.len() + pair_ids.as_ref().map(|e| e.len()).unwrap_or(0)];
Ok(Encoding::new(
format!(
"{}{}",
encoding.get_original(),
pair_encoding
.as_ref()
.map(|e| e.get_original())
.unwrap_or("")
),
format!(
"{}{}",
encoding.get_normalized(),
pair_encoding
.as_ref()
.map(|e| e.get_normalized())
.unwrap_or("")
),
[&ids[..], &pair_ids.unwrap_or_else(|| vec![])[..]].concat(),
[&type_ids[..], &pair_type_ids.unwrap_or_else(|| vec![])[..]].concat(),
[&tokens[..], &pair_tokens.unwrap_or_else(|| vec![])[..]].concat(),
[&offsets[..], &pair_offsets.unwrap_or_else(|| vec![])[..]].concat(),
[
&special_tokens[..],
&pair_special_tokens.unwrap_or_else(|| vec![])[..],
]
.concat(),
let mut new_encoding = Encoding::new(
encoding.get_normalized().clone(),
ids,
type_ids,
tokens,
offsets,
special_tokens,
attention_mask,
encoding.take_overflowing(),
))
);
if let Some(mut encoding) = pair_encoding {
let pair_ids = [&encoding.get_ids()[..], &[self.sep.1]].concat();
let pair_type_ids = [&encoding.get_type_ids()[..], &[1]].concat();
let pair_tokens = [&encoding.get_tokens()[..], &[self.sep.0.clone()]].concat();
let pair_offsets = [&encoding.get_offsets()[..], &[(0, 0)]].concat();
let pair_special_tokens =
[&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
let pair_attention_mask = vec![1; pair_ids.len()];
let new_pair_encoding = Encoding::new(
encoding.get_normalized().clone(),
pair_ids,
pair_type_ids,
pair_tokens,
pair_offsets,
pair_special_tokens,
pair_attention_mask,
encoding.take_overflowing(),
);
new_encoding.merge_with(new_pair_encoding);
}
Ok(new_encoding)
}
}

View File

@ -1,3 +1,5 @@
use crate::tokenizer::NormalizedString;
/// The various possible padding directions
#[derive(Debug, Clone)]
pub enum PaddingDirection {
@ -8,8 +10,7 @@ pub enum PaddingDirection {
/// The Encoding struct represents the output of the Tokenizer
#[derive(Default, PartialEq, Debug, Clone)]
pub struct Encoding {
original: String,
normalized: String,
normalized: NormalizedString,
ids: Vec<u32>,
type_ids: Vec<u32>,
tokens: Vec<String>,
@ -21,8 +22,7 @@ pub struct Encoding {
impl Encoding {
#[allow(clippy::too_many_arguments)]
pub fn new(
original: String,
normalized: String,
normalized: NormalizedString,
ids: Vec<u32>,
type_ids: Vec<u32>,
tokens: Vec<String>,
@ -32,7 +32,6 @@ impl Encoding {
overflowing: Option<Box<Encoding>>,
) -> Self {
Encoding {
original,
normalized,
ids,
type_ids,
@ -44,11 +43,7 @@ impl Encoding {
}
}
pub fn get_original(&self) -> &str {
&self.original
}
pub fn get_normalized(&self) -> &str {
pub fn get_normalized(&self) -> &NormalizedString {
&self.normalized
}
@ -96,15 +91,7 @@ impl Encoding {
let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
let mut o_attent = self.attention_mask.split_off(max_len);
// Figure out offsets for original and normalized
// TODO: We will be able to retrive the right part of original
// only when we will have the alignment difference between both
// For now we will use the normalized offset...
let max = self
.offsets
.iter()
.fold(0, |max, (_, end)| if *end > max { *end } else { max });
let trunc_original = self.original.split_off(max);
let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0);
let trunc_normalized = self.normalized.split_off(max);
if stride > 0 {
@ -117,7 +104,6 @@ impl Encoding {
}
self.overflowing = Some(Box::new(Encoding {
original: trunc_original,
normalized: trunc_normalized,
ids: o_ids,
type_ids: o_type_ids,
@ -130,8 +116,7 @@ impl Encoding {
}
pub fn merge_with(&mut self, pair: Encoding) {
self.original.push_str(&pair.original);
self.normalized.push_str(&pair.normalized);
self.normalized.merge_with(&pair.normalized);
self.ids.extend(pair.ids);
self.type_ids.extend(pair.type_ids);
self.tokens.extend(pair.tokens);
@ -224,8 +209,7 @@ mod tests {
#[test]
fn merge_encodings() {
let mut a = Encoding {
original: String::from("Hello "),
normalized: String::from("Hello "),
normalized: NormalizedString::from("Hello "),
ids: vec![1],
type_ids: vec![0],
tokens: vec![String::from("Hello ")],
@ -235,8 +219,7 @@ mod tests {
overflowing: None,
};
let b = Encoding {
original: String::from("World!"),
normalized: String::from("World!"),
normalized: NormalizedString::from("World!"),
ids: vec![2],
type_ids: vec![1],
tokens: vec![String::from("World!")],
@ -250,8 +233,7 @@ mod tests {
assert_eq!(
a,
Encoding {
original: String::from("Hello World!"),
normalized: String::from("Hello World!"),
normalized: NormalizedString::from("Hello World!"),
ids: vec![1, 2],
type_ids: vec![0, 1],
tokens: vec![String::from("Hello "), String::from("World!")],
@ -266,8 +248,7 @@ mod tests {
#[test]
fn truncate() {
let mut a = Encoding {
original: String::from("Hello World!"),
normalized: String::from("Hello World!"),
normalized: NormalizedString::from("Hello World!"),
ids: vec![1, 2, 3],
type_ids: vec![0, 0, 0],
tokens: vec![
@ -285,8 +266,7 @@ mod tests {
assert_eq!(
a,
Encoding {
original: String::from("Hello World"),
normalized: String::from("Hello World"),
normalized: NormalizedString::from("Hello World"),
ids: vec![1, 2],
type_ids: vec![0, 0],
tokens: vec![String::from("Hello"), String::from("World")],
@ -294,8 +274,7 @@ mod tests {
special_tokens_mask: vec![0, 0],
attention_mask: vec![1, 1],
overflowing: Some(Box::new(Encoding {
original: String::from("!"),
normalized: String::from("!"),
normalized: NormalizedString::from("!"),
ids: vec![3],
type_ids: vec![0],
tokens: vec![String::from("!")],

View File

@ -24,24 +24,22 @@ use std::{
};
mod encoding;
mod normalizer;
pub use encoding::*;
pub use normalizer::*;
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
/// A Normalizer takes care of pre-processing strings
pub trait Normalizer {
fn normalize(&self, s: String) -> Result<String>;
}
pub type Offsets = (usize, usize);
/// A PreTokenizer takes care of pre-tokenizing strings before this goes to the model
pub trait PreTokenizer {
// TODO: Should return offsets with each substring
fn pre_tokenize(&self, s: &str) -> Result<Vec<String>>;
fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>>;
}
/// Represents a `Model` used during Tokenization (Like BPE or Word or Unigram)
pub trait Model {
fn tokenize(&self, tokens: Vec<String>) -> Result<Vec<Token>>;
fn tokenize(&self, tokens: Vec<(String, Offsets)>) -> Result<Vec<Token>>;
fn token_to_id(&self, token: &str) -> Option<u32>;
fn id_to_token(&self, id: u32) -> Option<String>;
fn get_vocab_size(&self) -> usize;
@ -249,8 +247,7 @@ impl Tokenizer {
// If this is one of our added tokens, lets return an encoding directly
if let Some(id) = id {
return Ok(Encoding::new(
sentence.clone(),
sentence.clone(),
NormalizedString::from(&sentence),
vec![id],
vec![type_id],
vec![sentence.to_owned()],
@ -262,13 +259,10 @@ impl Tokenizer {
}
// 1. Normalization
// TODO: Make sure we have the offsets update necessary to go from the original text to
// the normalized one
let original = sentence.clone();
let normalized = self.normalize(sentence)?;
let normalized = self.normalize(&sentence)?;
// 2. Pre tokenization
let pre_tokenized = self.pre_tokenize(&normalized)?;
let pre_tokenized = self.pre_tokenize(&normalized.get())?;
// 3. Model
let output = self.model.tokenize(pre_tokenized)?;
@ -289,7 +283,6 @@ impl Tokenizer {
);
Ok(Encoding::new(
original,
normalized,
ids,
vec![type_id; length],
@ -397,9 +390,12 @@ impl Tokenizer {
for line in file.lines() {
let line = line?;
let normalized = self.normalize(line)?;
let pre_tokenized = self.pre_tokenize(&normalized)?;
trainer.process_tokens(&mut words, pre_tokenized);
let normalized = self.normalize(&line)?;
let pre_tokenized = self.pre_tokenize(normalized.get())?;
trainer.process_tokens(
&mut words,
pre_tokenized.into_iter().map(|(t, _)| t).collect(),
);
}
Ok(words)
@ -422,20 +418,22 @@ impl Tokenizer {
}
/// PreTokenization logic, handling the case where there is no PreTokenizer set
fn pre_tokenize(&self, sentence: &str) -> Result<Vec<String>> {
fn pre_tokenize(&self, sentence: &str) -> Result<Vec<(String, Offsets)>> {
match &self.pre_tokenizer {
None => Ok(vec![sentence.to_owned()]),
None => Ok(vec![(sentence.to_owned(), (0, sentence.len()))]),
Some(pre_tokenizer) => pre_tokenizer.pre_tokenize(sentence),
}
}
/// Normalization logic, go through all normalizers
fn normalize(&self, sentence: String) -> Result<String> {
fn normalize(&self, sequence: &str) -> Result<NormalizedString> {
let mut normalized = NormalizedString::from(sequence);
if let Some(normalizer) = &self.normalizer {
normalizer.normalize(sentence)
} else {
Ok(sentence)
normalizer.normalize(&mut normalized)?;
}
Ok(normalized)
}
/// Post processing logic, handling the case where there is no PostProcessor set

View File

@ -0,0 +1,333 @@
use super::Result;
use std::cmp::Ordering;
use unicode_normalization::UnicodeNormalization;
/// A Normalizer takes care of pre-processing strings
pub trait Normalizer {
fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
}
/// A normalized string takes care of keeping both versions of a String, and
/// provides necessary alignments to retrieve ranges of both strings
#[derive(Default, Debug, Clone)]
pub struct NormalizedString {
original: String,
normalized: String,
/// Mapping from the normalized string to the original one: for each character
/// of the normalized string, the (start, end) range it covers in the original string
alignments: Vec<(usize, usize)>,
}
impl std::cmp::PartialEq for NormalizedString {
fn eq(&self, other: &NormalizedString) -> bool {
self.normalized == other.normalized
}
}
impl NormalizedString {
pub fn from(s: &str) -> Self {
NormalizedString {
original: s.to_owned(),
normalized: s.to_owned(),
alignments: (0..s.chars().count()).map(|v| (v, v + 1)).collect(),
}
}
pub fn get(&self) -> &str {
&self.normalized
}
pub fn get_original(&self) -> &str {
&self.original
}
/// Return a range of the normalized string
pub fn get_range(&self, range: std::ops::Range<usize>) -> Option<&str> {
self.normalized.get(range)
}
/// Return a range of the original string, using a range from the normalized string
pub fn get_range_original(&self, range: std::ops::Range<usize>) -> Option<&str> {
self.alignments
.get(range)
.map(|alignments| {
if alignments.is_empty() {
None
} else {
let start = alignments[0].0;
let end = alignments[alignments.len() - 1].1;
self.original.get(start..end)
}
})
.flatten()
}
/// Applies transformations to the current normalized version, updating the current
/// alignments with the new ones.
/// This method expects an Iterator yielding each char of the new normalized string
/// with a `change` isize equal to:
/// - `1` if this is a new char
/// - `-N` if the char is right before N removed chars
/// - `0` if this char represents the old one (even if changed)
///
/// `change` should never be more than `1`. If multiple chars are added, each of
/// them has a `change` of `1`, but more doesn't make any sense.
/// We treat any value above `1` as `1`.
pub fn transform<I: Iterator<Item = (char, isize)>>(&mut self, dest: I) {
let mut offset: isize = 0;
let (ch, alignments): (Vec<_>, Vec<_>) = dest
.enumerate()
.map(|(index, (c, changes))| {
let uof = if offset < 0 {
-offset as usize
} else {
offset as usize
};
// A positive offset means we added characters. So we need to remove this offset
// from the current index to find out the previous id
let idx = if offset < 0 { index + uof } else { index - uof };
let align = match changes.cmp(&0) {
// This is a newly inserted character, so we use the alignment from the
// previous one
Ordering::Greater => {
if idx < 1 {
Some((0, 0))
} else {
offset += 1;
self.alignments.get(idx - 1).copied()
}
}
// No changes required here
Ordering::Equal => self.alignments.get(idx).copied(),
// Some characters were removed, so we merge our range with the one from the
// removed characters as the new alignment
Ordering::Less => {
let uch = -changes as usize;
offset += changes;
self.alignments.get(idx..idx + uch).map(|alignments| {
let min = alignments
.iter()
.map(|(start, end)| usize::min(*start, *end))
.min()
.unwrap();
let max = alignments
.iter()
.map(|(start, end)| usize::max(*start, *end))
.max()
.unwrap();
(min, max)
})
}
};
// Then we keep only the char for string reconstruction
(
c,
align.expect("Bad alignment in NormalizedString::transform"),
)
})
.unzip();
self.alignments = alignments;
self.normalized = ch.iter().collect::<String>();
}
/// Applies NFD normalization
pub fn nfd(&mut self) -> &mut Self {
self.transform(self.get().to_owned().nfd());
self
}
/// Applies NFKD normalization
pub fn nfkd(&mut self) -> &mut Self {
self.transform(self.get().to_owned().nfkd());
self
}
/// Applies NFC normalization
pub fn nfc(&mut self) -> &mut Self {
self.transform(self.get().to_owned().nfc());
self
}
/// Applies NFKC normalization
pub fn nfkc(&mut self) -> &mut Self {
self.transform(self.get().to_owned().nfkc());
self
}
/// Applies filtering over our characters
pub fn filter<F: Fn(&char) -> bool>(&mut self, filter: F) -> &mut Self {
let mut removed = 0;
let mut filtered = self
.normalized
.chars()
// We need to collect here to be able to reverse the iterator because Char is not ended
.collect::<Vec<_>>()
.into_iter()
.rev()
.map(|c| {
let keep = filter(&c);
if keep {
if removed > 0 {
let res = (c, -removed);
removed = 0;
Some(res)
} else {
Some((c, 0))
}
} else {
removed += 1;
None
}
})
.collect::<Vec<_>>();
// For some reason, if we use rev, and unwrap directly, some parts of the tuples we return
// above get mixed up... So we collect first, then reverse in place
filtered.reverse();
self.transform(filtered.iter().filter(|o| o.is_some()).map(|o| o.unwrap()));
self
}
/// Map our characters
pub fn map<F: Fn(char) -> char>(&mut self, map: F) -> &mut Self {
self.normalized = self.normalized.chars().map(map).collect::<String>();
self
}
/// Calls the given function for each characters
pub fn for_each<F: FnMut(char)>(&mut self, foreach: F) -> &mut Self {
self.normalized.chars().for_each(foreach);
self
}
/// Lowercase
pub fn lowercase(&mut self) -> &mut Self {
let mut new_chars: Vec<(char, isize)> = vec![];
self.for_each(|c| {
c.to_lowercase().enumerate().for_each(|(index, c)| {
new_chars.push((c, if index > 0 { 1 } else { 0 }));
})
});
self.transform(new_chars.into_iter());
self
}
/// Uppercase
pub fn uppercase(&mut self) -> &mut Self {
let mut new_chars: Vec<(char, isize)> = vec![];
self.for_each(|c| {
c.to_uppercase().enumerate().for_each(|(index, c)| {
new_chars.push((c, if index > 0 { 1 } else { 0 }));
})
});
self.transform(new_chars.into_iter());
self
}
/// Split off ourselves, returning a new Self that contains the range [at, len).
/// self will then contain the range [0, at).
///
/// Panic if at > len
pub fn split_off(&mut self, at: usize) -> Self {
let normalized = self.normalized.split_off(at);
let alignments = self.alignments.split_off(at);
let original_at = self.alignments.last().map(|(_, end)| *end).unwrap_or(0);
let original = self.original.split_off(original_at);
NormalizedString {
original,
normalized,
alignments,
}
}
/// Merge with the given NormalizedString by appending it to self
pub fn merge_with(&mut self, other: &NormalizedString) {
self.original.push_str(&other.original);
let len = self.len();
self.alignments.extend(
other
.alignments
.iter()
.map(|(start, end)| (start + len, end + len)),
);
self.normalized.push_str(&other.normalized);
}
/// Returns the length
pub fn len(&self) -> usize {
self.normalized.len()
}
/// Whether empty
pub fn is_empty(&self) -> bool {
self.normalized.len() == 0
}
}
#[cfg(test)]
mod tests {
use super::*;
use unicode_categories::UnicodeCategories;
#[test]
fn new_chars() {
let mut n = NormalizedString::from("élégant");
n.nfd();
assert_eq!(
&n.alignments,
&[
(0, 1),
(0, 1),
(1, 2),
(2, 3),
(2, 3),
(3, 4),
(4, 5),
(5, 6),
(6, 7)
]
);
}
#[test]
fn unchanged() {
let mut n = NormalizedString::from("élégant");
n.nfd().filter(|c| !c.is_mark_nonspacing());
assert_eq!(
&n.alignments,
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
);
}
#[test]
fn removed_chars() {
let mut n = NormalizedString::from("élégant");
n.filter(|c| *c != 'n');
assert_eq!(
&n.alignments,
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
);
}
#[test]
fn mixed_addition_and_removal() {
let mut n = NormalizedString::from("élégant");
n.nfd().filter(|c| !c.is_mark_nonspacing() && *c != 'n');
assert_eq!(
&n.alignments,
&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
);
}
#[test]
fn original_range() {
let mut n = NormalizedString::from("Hello_______ World!");
n.filter(|c| *c != '_').lowercase();
let world_n = n.get_range(6..11).unwrap();
let world_o = n.get_range_original(6..11).unwrap();
assert_eq!(world_n, "world");
assert_eq!(world_o, "World");
}
}
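To make the (char, isize) contract of transform concrete, a small sketch (not part of the diff) of a one-character replacement that also inserts an extra character; the input string and the expected values are illustrative:

```rust
use tokenizers::tokenizer::NormalizedString;

fn main() {
    let mut n = NormalizedString::from("abc");

    // Replace 'b' with "xx": the first 'x' stands in for the old 'b' (change 0),
    // the second 'x' is newly inserted (change 1).
    n.transform(vec![('a', 0), ('x', 0), ('x', 1), ('c', 0)].into_iter());

    assert_eq!(n.get(), "axxc");
    assert_eq!(n.get_original(), "abc");
    // Both 'x' characters map back to the original 'b'.
    assert_eq!(n.get_range_original(1..3), Some("b"));
}
```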