Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-23 16:49:27 +00:00)

bindings/python/Cargo.lock (generated, 6 lines changed)
@@ -456,14 +456,14 @@ dependencies = [
 "regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
 "regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
 "serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)",
-"unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)",
+"unicode-normalization 0.1.11 (git+https://github.com/n1t0/unicode-normalization)",
 "unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
 
 [[package]]
 name = "unicode-normalization"
 version = "0.1.11"
-source = "registry+https://github.com/rust-lang/crates.io-index"
+source = "git+https://github.com/n1t0/unicode-normalization#894053d92493c55c89fe9b188c0fb2babaa9a84c"
 dependencies = [
 "smallvec 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
 ]
@@ -570,7 +570,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum syn 1.0.11 (registry+https://github.com/rust-lang/crates.io-index)" = "dff0acdb207ae2fe6d5976617f887eb1e35a2ba52c13c7234c790960cdad9238"
 "checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
 "checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b"
-"checksum unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)" = "b561e267b2326bb4cebfc0ef9e68355c7abe6c6f522aeac2f5bf95d56c59bdcf"
+"checksum unicode-normalization 0.1.11 (git+https://github.com/n1t0/unicode-normalization)" = "<none>"
 "checksum unicode-width 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "caaa9d531767d1ff2150b9332433f32a24622147e5ebb1f26409d5da67afd479"
 "checksum unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c"
 "checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"

@@ -6,7 +6,7 @@ import logging
 logging.getLogger('transformers').disabled = True
 logging.getLogger('transformers.tokenization_utils').disabled = True
 
-from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors
+from tokenizers import Tokenizer, models, pre_tokenizers, decoders, processors, normalizers
 from transformers import GPT2Tokenizer, BertTokenizer
 
 parser = argparse.ArgumentParser()
@@ -61,8 +61,19 @@ elif args.type == "bert":
     print("Running Bert tokenizer")
     tok_p = BertTokenizer.from_pretrained('bert-base-uncased')
 
-    tok_r = Tokenizer(models.WordPiece.from_files(args.vocab, unk_token="[UNK]", max_input_chars_per_word=100))
-    tok_r.with_pre_tokenizer(pre_tokenizers.BasicPreTokenizer.new(do_lower_case=True, tokenize_chinese_chars=True, never_split=[]))
+    tok_r = Tokenizer(models.WordPiece.from_files(
+        args.vocab,
+        unk_token="[UNK]",
+        max_input_chars_per_word=100)
+    )
+    tok_r.with_normalizer(normalizers.BertNormalizer.new(
+        clean_text=True,
+        handle_chinese_chars=True,
+        strip_accents=True,
+        lowercase=True,
+    ))
+    # tok_r.with_pre_tokenizer(pre_tokenizers.Whitespace.new())
+    tok_r.with_pre_tokenizer(pre_tokenizers.BertPreTokenizer.new())
     tok_r.with_decoder(decoders.WordPiece.new())
     tok_r.with_post_processor(processors.BertProcessing.new(
         ("[SEP]", tok_r.token_to_id("[SEP]")),
@@ -75,7 +86,7 @@ def tokenize_r():
     return tok_r.encode_batch(text);
 
 def tokenize_p():
-    return [tok_p.encode(sentence) for sentence in tqdm(text)]
+    return [tok_p.encode(sentence, add_special_tokens=True) for sentence in tqdm(text)]
 
 print(f"Tokenizing {len(text)} lines")
 

@@ -1,6 +1,7 @@
 extern crate tokenizers as tk;
 
 use pyo3::prelude::*;
+use pyo3::types::*;
 
 #[pyclass(dict)]
 #[repr(transparent)]
@@ -16,15 +17,43 @@ impl Encoding {
 
 #[pymethods]
 impl Encoding {
-    // #[getter]
-    // fn get_original(&self) -> String {
-    //     self.encoding.get_original().to_owned()
-    // }
+    #[getter]
+    fn get_original(&self) -> String {
+        self.encoding.get_normalized().get_original().to_owned()
+    }
 
-    // #[getter]
-    // fn get_normalized(&self) -> String {
-    //     self.encoding.get_normalized().to_owned()
-    // }
+    #[getter]
+    fn get_normalized(&self) -> String {
+        self.encoding.get_normalized().get().to_owned()
+    }
 
+    #[args(kwargs = "**")]
+    fn get_range(
+        &self,
+        range: (usize, usize),
+        kwargs: Option<&PyDict>,
+    ) -> PyResult<Option<String>> {
+        let mut original = false;
+        if let Some(kwargs) = kwargs {
+            if let Some(koriginal) = kwargs.get_item("original") {
+                original = koriginal.extract()?;
+            }
+        }
+
+        if original {
+            Ok(self
+                .encoding
+                .get_normalized()
+                .get_range_original(range.0..range.1)
+                .map(|s| s.to_owned()))
+        } else {
+            Ok(self
+                .encoding
+                .get_normalized()
+                .get_range(range.0..range.1)
+                .map(|s| s.to_owned()))
+        }
+    }
+
     #[getter]
     fn get_ids(&self) -> Vec<u32> {
@@ -41,10 +70,10 @@ impl Encoding {
         self.encoding.get_type_ids().to_vec()
     }
 
-    // #[getter]
-    // fn get_offsets(&self) -> Vec<(usize, usize)> {
-    //     self.encoding.get_offsets().to_vec()
-    // }
+    #[getter]
+    fn get_offsets(&self) -> Vec<(usize, usize)> {
+        self.encoding.get_offsets().to_vec()
+    }
 
     #[getter]
     fn get_special_tokens_mask(&self) -> Vec<u32> {

@@ -2,6 +2,7 @@ mod decoders;
 mod encoding;
 mod error;
 mod models;
+mod normalizers;
 mod pre_tokenizers;
 mod processors;
 mod token;
@@ -55,6 +56,14 @@ fn processors(_py: Python, m: &PyModule) -> PyResult<()> {
     Ok(())
 }
 
+/// Normalizers Module
+#[pymodule]
+fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
+    m.add_class::<normalizers::Normalizer>()?;
+    m.add_class::<normalizers::BertNormalizer>()?;
+    Ok(())
+}
+
 /// Tokenizers Module
 #[pymodule]
 fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
@@ -63,6 +72,7 @@ fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pymodule!(pre_tokenizers))?;
     m.add_wrapped(wrap_pymodule!(decoders))?;
     m.add_wrapped(wrap_pymodule!(processors))?;
+    m.add_wrapped(wrap_pymodule!(normalizers))?;
     m.add_wrapped(wrap_pymodule!(trainers))?;
     Ok(())
 }

bindings/python/src/normalizers.rs (new file, 46 lines)
@@ -0,0 +1,46 @@
+extern crate tokenizers as tk;
+
+use super::utils::Container;
+use pyo3::prelude::*;
+use pyo3::types::*;
+
+#[pyclass(dict)]
+pub struct Normalizer {
+    pub normalizer: Container<dyn tk::tokenizer::Normalizer + Sync>,
+}
+
+#[pyclass]
+pub struct BertNormalizer {}
+#[pymethods]
+impl BertNormalizer {
+    #[staticmethod]
+    #[args(kwargs = "**")]
+    fn new(kwargs: Option<&PyDict>) -> PyResult<Normalizer> {
+        let mut clean_text = true;
+        let mut handle_chinese_chars = true;
+        let mut strip_accents = true;
+        let mut lowercase = true;
+
+        if let Some(kwargs) = kwargs {
+            for (key, value) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "clean_text" => clean_text = value.extract()?,
+                    "handle_chinese_chars" => handle_chinese_chars = value.extract()?,
+                    "strip_accents" => strip_accents = value.extract()?,
+                    "lowercase" => lowercase = value.extract()?,
+                    _ => println!("Ignored unknown kwargs option {}", key),
+                }
+            }
+        }
+
+        Ok(Normalizer {
+            normalizer: Container::Owned(Box::new(tk::normalizers::bert::BertNormalizer::new(
+                clean_text,
+                handle_chinese_chars,
+                strip_accents,
+                lowercase,
+            ))),
+        })
+    }
+}

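
The kwargs handling above always starts from all-true defaults and overrides only the options the caller passed. A minimal standalone sketch of that pattern, with a plain HashMap standing in for pyo3's PyDict (the function name and map type here are illustrative, not part of the bindings):

    use std::collections::HashMap;

    // Hypothetical stand-in for the pyo3 kwargs handling above: every option
    // defaults to true and is overridden only when explicitly provided.
    fn bert_normalizer_options(kwargs: Option<&HashMap<String, bool>>) -> (bool, bool, bool, bool) {
        let (mut clean_text, mut handle_chinese_chars, mut strip_accents, mut lowercase) =
            (true, true, true, true);
        if let Some(kwargs) = kwargs {
            for (key, value) in kwargs {
                match key.as_str() {
                    "clean_text" => clean_text = *value,
                    "handle_chinese_chars" => handle_chinese_chars = *value,
                    "strip_accents" => strip_accents = *value,
                    "lowercase" => lowercase = *value,
                    _ => println!("Ignored unknown kwargs option {}", key),
                }
            }
        }
        (clean_text, handle_chinese_chars, strip_accents, lowercase)
    }

    fn main() {
        let mut kwargs = HashMap::new();
        kwargs.insert("lowercase".to_string(), false);
        // -> (true, true, true, false)
        println!("{:?}", bert_normalizer_options(Some(&kwargs)));
    }
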
@@ -4,8 +4,7 @@ use super::error::{PyError, ToPyResult};
 use super::utils::Container;
 use pyo3::prelude::*;
 use pyo3::types::*;
-use std::collections::HashSet;
-use tk::tokenizer::Result;
+use tk::tokenizer::{Offsets, Result};
 
 #[pyclass(dict)]
 pub struct PreTokenizer {
@@ -21,7 +20,7 @@ impl PreTokenizer {
         })
     }
 
-    fn pre_tokenize(&self, s: &str) -> PyResult<Vec<String>> {
+    fn pre_tokenize(&self, s: &str) -> PyResult<Vec<(String, Offsets)>> {
         ToPyResult(self.pretok.execute(|pretok| pretok.pre_tokenize(s))).into()
     }
 }
@@ -58,36 +57,9 @@ pub struct BertPreTokenizer {}
 #[pymethods]
 impl BertPreTokenizer {
     #[staticmethod]
-    #[args(kwargs = "**")]
-    fn new(kwargs: Option<&PyDict>) -> PyResult<PreTokenizer> {
-        let mut do_basic_tokenize = true;
-        let mut do_lower_case = true;
-        let mut never_split = HashSet::new();
-        let mut tokenize_chinese_chars = true;
-
-        if let Some(kwargs) = kwargs {
-            for (key, val) in kwargs {
-                let key: &str = key.extract()?;
-                match key {
-                    "do_basic_tokenize" => do_basic_tokenize = val.extract()?,
-                    "do_lower_case" => do_lower_case = val.extract()?,
-                    "tokenize_chinese_chars" => tokenize_chinese_chars = val.extract()?,
-                    "never_split" => {
-                        let values: Vec<String> = val.extract()?;
-                        never_split = values.into_iter().collect();
-                    }
-                    _ => println!("Ignored unknown kwargs option {}", key),
-                }
-            }
-        }
-
+    fn new() -> PyResult<PreTokenizer> {
         Ok(PreTokenizer {
-            pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer::new(
-                do_basic_tokenize,
-                do_lower_case,
-                never_split,
-                tokenize_chinese_chars,
-            ))),
+            pretok: Container::Owned(Box::new(tk::pre_tokenizers::bert::BertPreTokenizer)),
         })
     }
 }
@@ -104,7 +76,7 @@ impl PyPreTokenizer {
 }
 
 impl tk::tokenizer::PreTokenizer for PyPreTokenizer {
-    fn pre_tokenize(&self, sentence: &str) -> Result<Vec<String>> {
+    fn pre_tokenize(&self, sentence: &str) -> Result<Vec<(String, Offsets)>> {
         let gil = Python::acquire_gil();
         let py = gil.python();
 
@@ -112,9 +84,15 @@ impl tk::tokenizer::PreTokenizer for PyPreTokenizer {
         match self.class.call_method(py, "pre_tokenize", args, None) {
             Ok(res) => Ok(res
                 .cast_as::<PyList>(py)
-                .map_err(|_| PyError::from("`pre_tokenize is expected to return a List[str]"))?
-                .extract::<Vec<String>>()
-                .map_err(|_| PyError::from("`pre_tokenize` is expected to return a List[str]"))?),
+                .map_err(|_| {
+                    PyError::from("`pre_tokenize is expected to return a List[(str, (uint, uint))]")
+                })?
+                .extract::<Vec<(String, Offsets)>>()
+                .map_err(|_| {
+                    PyError::from(
+                        "`pre_tokenize` is expected to return a List[(str, (uint, uint))]",
+                    )
+                })?),
             Err(e) => {
                 e.print(py);
                 Err(Box::new(PyError::from(

@@ -8,6 +8,7 @@ use super::decoders::Decoder;
 use super::encoding::Encoding;
 use super::error::{PyError, ToPyResult};
 use super::models::Model;
+use super::normalizers::Normalizer;
 use super::pre_tokenizers::PreTokenizer;
 use super::processors::PostProcessor;
 use super::trainers::Trainer;
@@ -97,6 +98,17 @@ impl Tokenizer {
         }
     }
 
+    fn with_normalizer(&mut self, normalizer: &mut Normalizer) -> PyResult<()> {
+        if let Some(normalizer) = normalizer.normalizer.to_pointer() {
+            self.tokenizer.with_normalizer(normalizer);
+            Ok(())
+        } else {
+            Err(exceptions::Exception::py_err(
+                "The Normalizer is already being used in another Tokenizer",
+            ))
+        }
+    }
+
     #[args(kwargs = "**")]
     fn with_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
         let mut stride = 0;

@@ -1,3 +1,3 @@
 __version__ = "0.0.11"
 
-from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers, processors
+from .tokenizers import Tokenizer, models, decoders, pre_tokenizers, trainers, processors, normalizers

@@ -19,7 +19,7 @@ regex-syntax = "0.6.12"
 rayon = "1.2.0"
 serde_json = "1.0"
 clap = "2.33.0"
-unicode-normalization = "0.1.11"
+unicode-normalization = { git = "https://github.com/n1t0/unicode-normalization" }
 unicode_categories = "0.1.1"
 
 [dev-dependencies]

@@ -1,5 +1,5 @@
 use super::{Cache, Error, Pair, Word};
-use crate::tokenizer::{Model, Result, Token};
+use crate::tokenizer::{Model, Offsets, Result, Token};
 use serde_json::Value;
 use std::{
     collections::HashMap,
@@ -103,15 +103,20 @@ impl Model for BPE {
         self.vocab.len()
     }
 
-    fn tokenize(&self, sentence: Vec<String>) -> Result<Vec<Token>> {
+    fn tokenize(&self, sentence: Vec<(String, Offsets)>) -> Result<Vec<Token>> {
         if sentence.is_empty() {
             return Ok(vec![]);
         }
 
         let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
-        let mut cached_words = self.cache.get_values(&sentence);
+        let mut cached_words = self.cache.get_values(
+            &sentence
+                .iter()
+                .map(|(s, _)| s.to_owned())
+                .collect::<Vec<_>>(),
+        );
 
-        for (i, w) in sentence.iter().enumerate() {
+        for (i, (w, initial_offsets)) in sentence.iter().enumerate() {
             if cached_words[i].is_none() {
                 let mut word = Word::new();
                 for c in w.chars() {
@@ -155,9 +160,6 @@ impl Model for BPE {
                 cached_words[i] = Some(word);
             }
 
-            // Offsets are word-based, we need to translate them to be sentence-based
-            let last_offset = encoded.last().map(|token| token.offsets.1).unwrap_or(0);
-
             let word = cached_words[i].as_ref().unwrap();
             let tokens = word
                 .get_chars()
@@ -167,7 +169,7 @@ impl Model for BPE {
                     Token::new(
                         *id,
                         self.vocab_r[id].clone(),
-                        (last_offset + offsets.0, last_offset + offsets.1),
+                        (initial_offsets.0 + offsets.0, initial_offsets.0 + offsets.1),
                     )
                 })
                 .collect::<Vec<_>>();
@@ -180,7 +182,7 @@ impl Model for BPE {
             .into_iter()
            .zip(cached_words)
             .filter(|(_, v)| v.is_some())
-            .map(|(k, v)| (k, v.unwrap()))
+            .map(|(k, v)| (k.0, v.unwrap()))
             .unzip::<_, _, Vec<String>, Vec<Word>>();
         self.cache.set_values(keys, values);
 

@@ -1,4 +1,4 @@
-use crate::tokenizer::{Model, Result, Token};
+use crate::tokenizer::{Model, Offsets, Result, Token};
 use std::{
     collections::HashMap,
     fmt,
@@ -70,11 +70,10 @@ impl Model for WordPiece {
         self.vocab.len()
     }
 
-    fn tokenize(&self, sentence: Vec<String>) -> Result<Vec<Token>> {
+    fn tokenize(&self, sentence: Vec<(String, Offsets)>) -> Result<Vec<Token>> {
         let mut output_tokens = vec![];
 
-        let mut offset = 0usize;
-        for token in sentence {
+        for (token, initial_offsets) in sentence {
             let char_len = token.chars().count();
             if char_len > self.max_input_chars_per_word {
                 output_tokens.push(Token {
@@ -83,7 +82,7 @@ impl Model for WordPiece {
                         .vocab
                         .get(&self.unk_token)
                         .ok_or(Error::MissingUnkToken)?,
-                    offsets: (offset, offset + char_len),
+                    offsets: initial_offsets,
                 });
                 continue;
             }
@@ -106,7 +105,7 @@ impl Model for WordPiece {
                     cur_str = Some(Token {
                         id: self.vocab[&substr],
                         value: substr,
-                        offsets: (offset + start, offset + end),
+                        offsets: (initial_offsets.0 + start, initial_offsets.0 + end),
                     });
                     break;
                 }
@@ -129,13 +128,11 @@ impl Model for WordPiece {
                         .vocab
                         .get(&self.unk_token)
                        .ok_or(Error::MissingUnkToken)?,
-                    offsets: (offset, offset + char_len),
+                    offsets: initial_offsets,
                 });
             } else {
                 output_tokens.extend(sub_tokens);
             }
-
-            offset += char_len;
         }
 
         Ok(output_tokens)

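
Both model hunks above drop the running `offset` counter and instead shift each sub-token by the `initial_offsets` carried in from the pre-tokenizer. A minimal standalone sketch of that bookkeeping (the `shift` helper and the sample values are illustrative only):

    type Offsets = (usize, usize);

    // Shift word-local sub-token offsets by the word's own start position in the
    // sentence, as the WordPiece/BPE changes above now do with `initial_offsets`.
    fn shift(sub_tokens: &[(String, Offsets)], word_offsets: Offsets) -> Vec<(String, Offsets)> {
        sub_tokens
            .iter()
            .map(|(s, (start, end))| (s.clone(), (word_offsets.0 + start, word_offsets.0 + end)))
            .collect()
    }

    fn main() {
        // "playing" sits at (6, 13) in the sentence; its pieces carry word-local offsets
        let pieces = vec![("play".to_string(), (0, 4)), ("##ing".to_string(), (4, 7))];
        // -> [("play", (6, 10)), ("##ing", (10, 13))]
        println!("{:?}", shift(&pieces, (6, 13)));
    }
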
tokenizers/src/normalizers/bert.rs (new file, 122 lines)
@@ -0,0 +1,122 @@
+use crate::tokenizer::{NormalizedString, Normalizer, Result};
+use unicode_categories::UnicodeCategories;
+
+/// Checks whether a character is whitespace
+fn is_whitespace(c: char) -> bool {
+    // These are technically control characters but we count them as whitespace
+    if c == '\t' || c == '\n' || c == '\r' {
+        true
+    } else {
+        c.is_whitespace()
+    }
+}
+
+/// Checks whether a character is a control character
+fn is_control(c: char) -> bool {
+    // These are technically control characters but we count them as whitespace
+    if c == '\t' || c == '\n' || c == '\r' {
+        false
+    } else {
+        // The definition of `is_control` here is quite large and contains also
+        // Cc, Cf, Cn or Co
+        // cf. https://unicode.org/reports/tr44/ (Table 12)
+        c.is_other()
+    }
+}
+
+/// Checks whether a character is chinese
+/// This defines a "chinese character" as anything in the CJK Unicode block:
+///   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+///
+/// Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+/// despite its name. The modern Korean Hangul alphabet is a different block,
+/// as is Japanese Hiragana and Katakana. Those alphabets are used to write
+/// space-separated words, so they are not treated specially and handled
+/// like for all of the other languages.
+fn is_chinese_char(c: char) -> bool {
+    match c as usize {
+        0x4E00..=0x9FFF => true,
+        0x3400..=0x4DBF => true,
+        0x20000..=0x2A6DF => true,
+        0x2A700..=0x2B73F => true,
+        0x2B740..=0x2B81F => true,
+        0x2B920..=0x2CEAF => true,
+        0xF900..=0xFAFF => true,
+        0x2F800..=0x2FA1F => true,
+        _ => false,
+    }
+}
+
+pub struct BertNormalizer {
+    /// Whether to do the bert basic cleaning:
+    ///   1. Remove any control characters
+    ///   2. Replace all sorts of whitespace by the classic one ` `
+    clean_text: bool,
+    /// Whether to put spaces around chinese characters so they get split
+    handle_chinese_chars: bool,
+    /// Whether to strip accents
+    strip_accents: bool,
+    /// Whether to lowercase the input
+    lowercase: bool,
+}
+
+impl BertNormalizer {
+    pub fn new(
+        clean_text: bool,
+        handle_chinese_chars: bool,
+        strip_accents: bool,
+        lowercase: bool,
+    ) -> Self {
+        BertNormalizer {
+            clean_text,
+            handle_chinese_chars,
+            strip_accents,
+            lowercase,
+        }
+    }
+
+    fn do_clean_text(&self, normalized: &mut NormalizedString) {
+        normalized
+            .filter(|c| !(*c as usize == 0 || *c as usize == 0xfffd || is_control(*c)))
+            .map(|c| if is_whitespace(c) { ' ' } else { c });
+    }
+
+    fn do_handle_chinese_chars(&self, normalized: &mut NormalizedString) {
+        let mut new_chars: Vec<(char, isize)> = vec![];
+        normalized.for_each(|c| {
+            if is_chinese_char(c) {
+                new_chars.extend(&[(' ', 1), (c, 0), (' ', 1)]);
+            } else {
+                new_chars.push((c, 0));
+            }
+        });
+        normalized.transform(new_chars.into_iter());
+    }
+
+    fn do_strip_accents(&self, normalized: &mut NormalizedString) {
+        normalized.nfd().filter(|c| !c.is_mark_nonspacing());
+    }
+
+    fn do_lowercase(&self, normalized: &mut NormalizedString) {
+        normalized.lowercase();
+    }
+}
+
+impl Normalizer for BertNormalizer {
+    fn normalize(&self, mut normalized: &mut NormalizedString) -> Result<()> {
+        if self.clean_text {
+            self.do_clean_text(&mut normalized);
+        }
+        if self.handle_chinese_chars {
+            self.do_handle_chinese_chars(&mut normalized);
+        }
+        if self.strip_accents {
+            self.do_strip_accents(&mut normalized);
+        }
+        if self.lowercase {
+            self.do_lowercase(&mut normalized);
+        }
+
+        Ok(())
+    }
+}

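
A standalone sketch of the CJK handling introduced above, applied to a plain String rather than the crate's NormalizedString (so it does not track the offset adjustments the real `do_handle_chinese_chars` records via its `(char, isize)` pairs):

    /// Same CJK Unified Ideographs ranges as `is_chinese_char` in the new file above.
    fn is_chinese_char(c: char) -> bool {
        matches!(
            c as u32,
            0x4E00..=0x9FFF
                | 0x3400..=0x4DBF
                | 0x20000..=0x2A6DF
                | 0x2A700..=0x2B73F
                | 0x2B740..=0x2B81F
                | 0x2B920..=0x2CEAF
                | 0xF900..=0xFAFF
                | 0x2F800..=0x2FA1F
        )
    }

    // Put spaces around every CJK character so a whitespace-based pre-tokenizer
    // later splits them into single-character tokens.
    fn handle_chinese_chars(s: &str) -> String {
        s.chars()
            .flat_map(|c| {
                if is_chinese_char(c) {
                    vec![' ', c, ' ']
                } else {
                    vec![c]
                }
            })
            .collect()
    }

    fn main() {
        println!("{}", handle_chinese_chars("野口里佳 Noguchi Rika"));
    }
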
@@ -1 +1 @@
-
+pub mod bert;

@@ -1,182 +1,76 @@
-use crate::tokenizer::{PreTokenizer, Result};
-use std::collections::HashSet;
-use unicode_categories::UnicodeCategories;
-use unicode_normalization::UnicodeNormalization;
+use crate::tokenizer::{Offsets, PreTokenizer, Result};
 
-/// Extremely simple tokenization on whitespaces
-fn whitespace_tokenize(s: &str) -> Vec<&str> {
-    s.trim()
-        .split(char::is_whitespace)
-        .filter(|s| *s != " ")
-        .collect()
-}
-
-/// Checks whether a character is whitespace
-fn is_whitespace(c: char) -> bool {
-    // These are technically control characters but we count them as whitespace
-    if c == '\t' || c == '\n' || c == '\r' {
-        true
-    } else {
-        c.is_whitespace()
-    }
-}
-
-/// Checks whether a character is a control character
-fn is_control(c: char) -> bool {
-    // These are technically control characters but we count them as whitespace
-    if c == '\t' || c == '\n' || c == '\r' {
-        false
-    } else {
-        // The definition of `is_control` here is quite large and contains also
-        // Cc, Cf, Cn or Co
-        // cf. https://unicode.org/reports/tr44/ (Table 12)
-        c.is_other()
-    }
-}
-
-/// Checks whether a character is chinese
-/// This defines a "chinese character" as anything in the CJK Unicode block:
-///   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
-///
-/// Note that the CJK Unicode block is NOT all Japanese and Korean characters,
-/// despite its name. The modern Korean Hangul alphabet is a different block,
-/// as is Japanese Hiragana and Katakana. Those alphabets are used to write
-/// space-separated words, so they are not treated specially and handled
-/// like for all of the other languages.
-fn is_chinese_char(c: char) -> bool {
-    match c as usize {
-        0x4E00..=0x9FFF => true,
-        0x3400..=0x4DBF => true,
-        0x20000..=0x2A6DF => true,
-        0x2A700..=0x2B73F => true,
-        0x2B740..=0x2B81F => true,
-        0x2B920..=0x2CEAF => true,
-        0xF900..=0xFAFF => true,
-        0x2F800..=0x2FA1F => true,
-        _ => false,
-    }
-}
-
-pub struct BertPreTokenizer {
-    /// Whether to do the basic tokenization
-    do_basic_tokenize: bool,
-    /// Whether to lower case the input.
-    do_lower_case: bool,
-    /// A list of token not to split.
-    never_split: HashSet<String>,
-    /// Whether to tokenize Chinese characters
-    tokenize_chinese_chars: bool,
-}
-
-impl BertPreTokenizer {
-    pub fn new(
-        do_basic_tokenize: bool,
-        do_lower_case: bool,
-        never_split: HashSet<String>,
-        tokenize_chinese_chars: bool,
-    ) -> Self {
-        BertPreTokenizer {
-            do_basic_tokenize,
-            do_lower_case,
-            never_split,
-            tokenize_chinese_chars,
-        }
-    }
-
-    /// Strips accents from a piece of text
-    fn run_strip_accents(&self, text: &str) -> String {
-        text.nfd()
-            .filter(|c| !c.is_mark_nonspacing())
-            .collect::<String>()
-    }
-
-    /// Splits punctuation on a piece of text.
-    fn run_split_on_punc(&self, text: &str) -> Vec<String> {
-        if self.never_split.contains(text) {
-            return vec![text.to_owned()];
-        }
-
-        let mut output: Vec<Vec<char>> = vec![];
-        let mut start_new_word = true;
-        text.chars().for_each(|c| {
-            if c.is_ascii_punctuation() {
-                output.push(vec![c]);
-                start_new_word = true;
-            } else {
-                if start_new_word {
-                    output.push(vec![]);
-                }
-                start_new_word = false;
-                output.last_mut().unwrap().push(c);
-            }
-        });
-
-        output
-            .into_iter()
-            .map(|cs| cs.into_iter().collect::<String>())
-            .collect()
-    }
-
-    fn tokenize_chinese_chars(&self, text: &str) -> String {
-        text.chars()
-            .map(|c| {
-                if is_chinese_char(c) {
-                    vec![' ', c, ' ']
-                } else {
-                    vec![c]
-                }
-            })
-            .flatten()
-            .collect::<String>()
-    }
-
-    fn clean_text(&self, text: &str) -> String {
-        text.chars()
-            .map(|c| {
-                if c as usize == 0 || c as usize == 0xfffd || is_control(c) {
-                    None
-                } else if is_whitespace(c) {
-                    Some(' ')
-                } else {
-                    Some(c)
-                }
-            })
-            .filter(|c| c.is_some())
-            .map(|c| c.unwrap())
-            .collect::<String>()
-    }
-}
+/// Split the given string as the `should_split` predicate dictates. Keep track of the offsets
+fn split_on<F: Fn(&char) -> bool>(
+    s: &str,
+    should_split: F,
+    include_split_token: bool,
+) -> Vec<(String, Offsets)> {
+    let mut words: Vec<(String, Offsets)> = vec![];
+    let mut offset = 0;
+    let mut word = Vec::with_capacity(50);
+    s.chars().for_each(|c| {
+        if should_split(&c) {
+            if !word.is_empty() {
+                let offsets = (offset - word.len(), offset);
+                words.push((word.drain(0..).collect::<String>(), offsets));
+            }
+            if include_split_token {
+                words.push((c.to_string(), (offset, offset + 1)));
+            }
+        } else if !should_split(&c) {
+            word.push(c);
+        }
+        offset += 1;
+    });
+    // Don't forget the potential last word
+    if !word.is_empty() {
+        let offsets = (offset - word.len(), offset);
+        words.push((word.drain(0..).collect::<String>(), offsets));
+    }
+
+    words
+}
+
+pub struct BertPreTokenizer;
 
 impl PreTokenizer for BertPreTokenizer {
-    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
-        if !self.do_basic_tokenize {
-            Ok(whitespace_tokenize(&s)
-                .into_iter()
-                .map(|s| s.to_owned())
-                .collect())
-        } else {
-            let mut text = self.clean_text(s);
-
-            // This was added on November 1st, 2018 for the multilingual and Chinese
-            // models. This is also applied to the English models now, but it doesn't
-            // matter since the English models were not trained on any Chinese data
-            // and generally don't have any Chinese data in them (there are Chinese
-            // characters in the vocabulary because Wikipedia does have some Chinese
-            // words in the English Wikipedia.).
-            if self.tokenize_chinese_chars {
-                text = self.tokenize_chinese_chars(&text);
-            }
-            let orig_tokens = whitespace_tokenize(&text);
-            let mut split_tokens = vec![];
-            for token in orig_tokens {
-                let mut tk = token.to_owned();
-                if self.do_lower_case && !self.never_split.contains(token) {
-                    tk = self.run_strip_accents(&token.to_lowercase())
-                }
-                split_tokens.extend(self.run_split_on_punc(&tk));
-            }
-
-            Ok(split_tokens)
-        }
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>> {
+        let mut split_tokens = vec![];
+        for (token, offsets) in split_on(&s, |c| char::is_whitespace(*c), false) {
+            split_tokens.extend(
+                split_on(&token, char::is_ascii_punctuation, true)
+                    .into_iter()
+                    .map(|(tok, off)| (tok, (off.0 + offsets.0, off.1 + offsets.0))),
+            );
+        }
+        Ok(split_tokens)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn basic() {
+        let pretok = BertPreTokenizer;
+        let res = pretok
+            .pre_tokenize("Hey friend!     How are you?!?")
+            .unwrap();
+        assert_eq!(
+            &res,
+            &[
+                ("Hey".into(), (0, 3)),
+                ("friend".into(), (4, 10)),
+                ("!".into(), (10, 11)),
+                ("How".into(), (16, 19)),
+                ("are".into(), (20, 23)),
+                ("you".into(), (24, 27)),
+                ("?".into(), (27, 28)),
+                ("!".into(), (28, 29)),
+                ("?".into(), (29, 30)),
+            ]
+        );
+    }
+}

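
The rewritten pre-tokenizer above is built around `split_on`, which tracks the character offset of every emitted token. Here is a standalone copy of that helper with a local `Offsets` alias and a small demo `main` added for illustration; offsets are character-based, exactly as in the hunk:

    type Offsets = (usize, usize);

    /// Split `s` wherever `should_split` says so, keeping each token's (start, end)
    /// character offsets; optionally emit the split character itself as a token.
    fn split_on<F: Fn(&char) -> bool>(
        s: &str,
        should_split: F,
        include_split_token: bool,
    ) -> Vec<(String, Offsets)> {
        let mut words: Vec<(String, Offsets)> = vec![];
        let mut offset = 0;
        let mut word: Vec<char> = Vec::with_capacity(50);
        s.chars().for_each(|c| {
            if should_split(&c) {
                if !word.is_empty() {
                    let offsets = (offset - word.len(), offset);
                    words.push((word.drain(0..).collect::<String>(), offsets));
                }
                if include_split_token {
                    words.push((c.to_string(), (offset, offset + 1)));
                }
            } else {
                word.push(c);
            }
            offset += 1;
        });
        // Don't forget the potential last word
        if !word.is_empty() {
            let offsets = (offset - word.len(), offset);
            words.push((word.drain(0..).collect::<String>(), offsets));
        }
        words
    }

    fn main() {
        // Whitespace split first, then punctuation split (keeping the punctuation),
        // mirroring BertPreTokenizer::pre_tokenize above.
        for (token, offsets) in split_on("Hey friend!", |c| char::is_whitespace(*c), false) {
            for (tok, off) in split_on(&token, char::is_ascii_punctuation, true) {
                println!("{:?} {:?}", tok, (off.0 + offsets.0, off.1 + offsets.0));
            }
        }
    }
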
@@ -1,4 +1,4 @@
-use crate::tokenizer::{Decoder, PreTokenizer, Result};
+use crate::tokenizer::{Decoder, Offsets, PreTokenizer, Result};
 use regex::Regex;
 use std::collections::HashMap;
 use unicode_categories::UnicodeCategories;
@@ -41,7 +41,7 @@ impl ByteLevel {
 }
 
 impl PreTokenizer for ByteLevel {
-    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>> {
         let s = if self.add_prefix_space && !s.starts_with(' ') {
             format!(" {}", s)
         } else {
@@ -59,19 +59,11 @@ impl PreTokenizer for ByteLevel {
                 // we don't want to return it
                 let last = s[start..end].chars().last();
                 let next = s[end..].chars().nth(0);
-                if last.is_some()
-                    && last.unwrap().is_separator_space()
-                    && next.is_some()
-                    && !next.unwrap().is_separator_space()
-                {
-                    if let Some(newstr) = s[start..end]
-                        .chars()
-                        .collect::<Vec<_>>()
-                        .split_last()
-                        .map(|(_, rest)| rest)
-                        .map(|chars| chars.iter().collect::<String>())
-                    {
-                        return newstr;
+                if let (Some(last), Some(next)) = (last, next) {
+                    if last.is_separator_space() && !next.is_separator_space() {
+                        let bytes = s[start..end - 1].as_bytes().to_vec();
+                        let offsets = (start, end - 1);
+                        return (bytes, offsets);
                     }
                 }
                 // if our first char is not a whitespace but the previous one was, we return
@@ -80,17 +72,22 @@ impl PreTokenizer for ByteLevel {
                 let current = s[start..end].chars().nth(0).map(|c| c.is_whitespace());
                 if let (Some(prev), Some(current)) = (prev, current) {
                     if prev.is_separator_space() && !current {
-                        return format!("{}{}", prev, s[start..end].to_owned());
+                        let bytes =
+                            [format!("{}", prev).as_bytes(), s[start..end].as_bytes()].concat();
+                        let offsets = (start - 1, end);
+                        return (bytes, offsets);
                     }
                 }
 
-                s[start..end].to_owned()
+                (s[start..end].as_bytes().to_vec(), (start, end))
             })
-            .map(|s| {
-                s.into_bytes()
-                    .iter()
-                    .map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap())
-                    .collect()
+            .map(|(s, offsets)| {
+                (
+                    s.iter()
+                        .map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap())
+                        .collect(),
+                    offsets,
+                )
             })
             .collect())
     }
@@ -122,7 +119,16 @@ mod tests {
             .pre_tokenize("Hello my friend, how is your day going?")
             .unwrap(),
             vec![
-                "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
+                ("Hello".into(), (0, 5)),
+                ("Ġmy".into(), (5, 8)),
+                ("Ġfriend".into(), (8, 15)),
+                (",".into(), (15, 16)),
+                ("Ġhow".into(), (16, 20)),
+                ("Ġis".into(), (20, 23)),
+                ("Ġyour".into(), (23, 28)),
+                ("Ġday".into(), (28, 32)),
+                ("Ġgoing".into(), (32, 38)),
+                ("?".into(), (38, 39))
             ]
         );
     }
@@ -154,7 +160,16 @@ mod tests {
             .pre_tokenize("Hello my friend, how is your day going?")
             .unwrap(),
             vec![
-                "ĠHello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
+                ("ĠHello".into(), (0, 6)),
+                ("Ġmy".into(), (6, 9)),
+                ("Ġfriend".into(), (9, 16)),
+                (",".into(), (16, 17)),
+                ("Ġhow".into(), (17, 21)),
+                ("Ġis".into(), (21, 24)),
+                ("Ġyour".into(), (24, 29)),
+                ("Ġday".into(), (29, 33)),
+                ("Ġgoing".into(), (33, 39)),
+                ("?".into(), (39, 40))
            ]
         );
     }
@@ -176,7 +191,7 @@ mod tests {
         let pre_tokenized = bl.pre_tokenize(&sample).unwrap();
         let separated_tokens = pre_tokenized
             .into_iter()
-            .map(|token| token.split("").map(|t| t.into()).collect::<Vec<_>>())
+            .map(|(token, _)| token.split("").map(|t| t.into()).collect::<Vec<_>>())
             .flatten()
             .collect::<Vec<_>>();
         assert_eq!(sample, bl.decode(separated_tokens).unwrap());
@@ -192,11 +207,11 @@ mod tests {
         assert_eq!(
             p,
             vec![
-                String::from("Hello"),
-                String::from("Ġthere"),
-                String::from("Ċ"),
-                String::from("Hello"),
-                String::from("Ġthere")
+                ("Hello".into(), (0, 5)),
+                ("Ġthere".into(), (5, 11)),
+                ("Ċ".into(), (11, 12)),
+                ("Hello".into(), (12, 17)),
+                ("Ġthere".into(), (17, 23))
             ]
         );
     }
@@ -210,10 +225,10 @@ mod tests {
         assert_eq!(
             p,
             vec![
-                String::from("Hello"),
-                String::from("Ġthere"),
-                String::from("ĠĠĠĠĠĠ"),
-                String::from("Ġdear")
+                ("Hello".into(), (0, 5)),
+                ("Ġthere".into(), (5, 11)),
+                ("ĠĠĠĠĠĠ".into(), (11, 17)),
+                ("Ġdear".into(), (17, 22))
             ]
         );
     }

@@ -1,9 +1,9 @@
-use crate::tokenizer::{PreTokenizer, Result};
+use crate::tokenizer::{Offsets, PreTokenizer, Result};
 use regex::Regex;
 
 pub struct Whitespace;
 impl PreTokenizer for Whitespace {
-    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>> {
         lazy_static! {
             static ref RE: Regex = Regex::new(r"\w+|[^\w\s]+").unwrap();
         }
@@ -13,11 +13,15 @@ impl PreTokenizer for Whitespace {
                 captures
                     .iter()
                     .map(|m| {
-                        m.map(|capture| s[capture.start()..capture.end()].to_owned())
-                            .unwrap_or_else(|| String::from(""))
+                        m.map(|capture| {
+                            let (start, end) = (capture.start(), capture.end());
+                            (s[start..end].to_owned(), (start, end))
+                        })
+                        .unwrap_or_else(|| (String::from(""), (0, 0)))
                    })
-                    .collect()
+                    .collect::<Vec<(String, Offsets)>>()
             })
+            .flatten()
             .collect())
     }
 }
@@ -30,10 +34,23 @@ mod tests {
     #[test]
     fn basic() {
         let tests = vec![
-            ("Hey man!", vec!["Hey", "man", "!"]),
+            (
+                "Hey man!",
+                vec![
+                    ("Hey".into(), (0, 3)),
+                    ("man".into(), (4, 7)),
+                    ("!".into(), (7, 8)),
+                ],
+            ),
             (
                 "How are you doing?",
-                vec!["How", "are", "you", "doing", "?"],
+                vec![
+                    ("How".into(), (0, 3)),
+                    ("are".into(), (4, 7)),
+                    ("you".into(), (8, 11)),
+                    ("doing".into(), (12, 17)),
+                    ("?".into(), (17, 18)),
+                ],
             ),
         ];
         let pretok = Whitespace;

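
A minimal standalone version of the Whitespace pre-tokenizer change above, assuming the `regex` crate as a dependency; it uses `find_iter` instead of the `captures_iter` loop in the hunk, and the offsets are byte offsets exactly like `capture.start()`/`capture.end()`:

    use regex::Regex;

    // Every match of the same pattern as above is returned with its byte offsets.
    fn whitespace_pretokenize(s: &str) -> Vec<(String, (usize, usize))> {
        let re = Regex::new(r"\w+|[^\w\s]+").unwrap();
        re.find_iter(s)
            .map(|m| (m.as_str().to_owned(), (m.start(), m.end())))
            .collect()
    }

    fn main() {
        // -> [("Hey", (0, 3)), ("man", (4, 7)), ("!", (7, 8))]
        println!("{:?}", whitespace_pretokenize("Hey man!"));
    }
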
@@ -25,70 +25,52 @@ impl PostProcessor for BertProcessing {
     }
 
     fn process(&self, mut encoding: Encoding, pair_encoding: Option<Encoding>) -> Result<Encoding> {
-        // Prepare ids
         let ids = [&[self.cls.1], &encoding.get_ids()[..], &[self.sep.1]].concat();
-        let pair_ids = pair_encoding
-            .as_ref()
-            .map(|encoding| [&encoding.get_ids()[..], &[self.sep.1]].concat());
-
-        // Prepare tokens
+        let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
         let tokens = [
             &[self.cls.0.clone()],
             &encoding.get_tokens()[..],
             &[self.sep.0.clone()],
         ]
         .concat();
-        let pair_tokens = pair_encoding
-            .as_ref()
-            .map(|encoding| [&encoding.get_tokens()[..], &[self.sep.0.clone()]].concat());
-
-        // Prepare offsets
         let offsets = [&[(0, 0)], &encoding.get_offsets()[..], &[(0, 0)]].concat();
-        let pair_offsets = pair_encoding
-            .as_ref()
-            .map(|encoding| [&encoding.get_offsets()[..], &[(0, 0)]].concat());
-
-        // Prepare type ids
-        let type_ids = [&[0], &encoding.get_type_ids()[..], &[0]].concat();
-        let pair_type_ids = pair_encoding
-            .as_ref()
-            .map(|encoding| [&encoding.get_type_ids()[..], &[1]].concat());
-
         let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
-        let pair_special_tokens = pair_encoding
-            .as_ref()
-            .map(|encoding| [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat());
-
-        let attention_mask = vec![1; ids.len() + pair_ids.as_ref().map(|e| e.len()).unwrap_or(0)];
-
-        Ok(Encoding::new(
-            format!(
-                "{}{}",
-                encoding.get_original(),
-                pair_encoding
-                    .as_ref()
-                    .map(|e| e.get_original())
-                    .unwrap_or("")
-            ),
-            format!(
-                "{}{}",
-                encoding.get_normalized(),
-                pair_encoding
-                    .as_ref()
-                    .map(|e| e.get_normalized())
-                    .unwrap_or("")
-            ),
-            [&ids[..], &pair_ids.unwrap_or_else(|| vec![])[..]].concat(),
-            [&type_ids[..], &pair_type_ids.unwrap_or_else(|| vec![])[..]].concat(),
-            [&tokens[..], &pair_tokens.unwrap_or_else(|| vec![])[..]].concat(),
-            [&offsets[..], &pair_offsets.unwrap_or_else(|| vec![])[..]].concat(),
-            [
-                &special_tokens[..],
-                &pair_special_tokens.unwrap_or_else(|| vec![])[..],
-            ]
-            .concat(),
+        let attention_mask = vec![1; ids.len()];
+
+        let mut new_encoding = Encoding::new(
+            encoding.get_normalized().clone(),
+            ids,
+            type_ids,
+            tokens,
+            offsets,
+            special_tokens,
             attention_mask,
             encoding.take_overflowing(),
-        ))
+        );
+
+        if let Some(mut encoding) = pair_encoding {
+            let pair_ids = [&encoding.get_ids()[..], &[self.sep.1]].concat();
+            let pair_type_ids = [&encoding.get_type_ids()[..], &[1]].concat();
+            let pair_tokens = [&encoding.get_tokens()[..], &[self.sep.0.clone()]].concat();
+            let pair_offsets = [&encoding.get_offsets()[..], &[(0, 0)]].concat();
+            let pair_special_tokens =
+                [&vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
+            let pair_attention_mask = vec![1; pair_ids.len()];
+
+            let new_pair_encoding = Encoding::new(
+                encoding.get_normalized().clone(),
+                pair_ids,
+                pair_type_ids,
+                pair_tokens,
+                pair_offsets,
+                pair_special_tokens,
+                pair_attention_mask,
+                encoding.take_overflowing(),
+            );
+
+            new_encoding.merge_with(new_pair_encoding);
+        }
+
+        Ok(new_encoding)
     }
 }

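
A toy sketch of the control flow the `process` rewrite above moves to: build the single-sequence pieces first, and only when a pair is present build its pieces and append them. Only ids and type ids are modelled here, with arbitrary id values; the real code does the same for tokens, offsets, and masks and then calls `merge_with`:

    // `cls` and `sep` stand in for the ("[CLS]", id) / ("[SEP]", id) pairs held by
    // BertProcessing; this is an illustration, not the crate's implementation.
    fn bert_post_process(cls: u32, sep: u32, ids: &[u32], pair: Option<&[u32]>) -> (Vec<u32>, Vec<u32>) {
        let mut out_ids = Vec::with_capacity(ids.len() + 2);
        out_ids.push(cls);
        out_ids.extend_from_slice(ids);
        out_ids.push(sep);
        let mut type_ids = vec![0u32; out_ids.len()];

        if let Some(pair) = pair {
            let mut pair_ids = pair.to_vec();
            pair_ids.push(sep);
            type_ids.extend(std::iter::repeat(1u32).take(pair_ids.len()));
            out_ids.extend(pair_ids);
        }

        (out_ids, type_ids)
    }

    fn main() {
        // -> ([101, 7, 8, 102, 9, 102], [0, 0, 0, 0, 1, 1])
        println!("{:?}", bert_post_process(101, 102, &[7, 8], Some(&[9])));
    }
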
@@ -1,3 +1,5 @@
+use crate::tokenizer::NormalizedString;
+
 /// The various possible padding directions
 #[derive(Debug, Clone)]
 pub enum PaddingDirection {
@@ -8,8 +10,7 @@ pub enum PaddingDirection {
 /// The Encoding struct represents the output of the Tokenizer
 #[derive(Default, PartialEq, Debug, Clone)]
 pub struct Encoding {
-    original: String,
-    normalized: String,
+    normalized: NormalizedString,
     ids: Vec<u32>,
     type_ids: Vec<u32>,
     tokens: Vec<String>,
@@ -21,8 +22,7 @@ pub struct Encoding {
 impl Encoding {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
-        original: String,
-        normalized: String,
+        normalized: NormalizedString,
         ids: Vec<u32>,
         type_ids: Vec<u32>,
         tokens: Vec<String>,
@@ -32,7 +32,6 @@ impl Encoding {
         overflowing: Option<Box<Encoding>>,
     ) -> Self {
         Encoding {
-            original,
             normalized,
             ids,
             type_ids,
@@ -44,11 +43,7 @@ impl Encoding {
         }
     }
 
-    pub fn get_original(&self) -> &str {
-        &self.original
-    }
-
-    pub fn get_normalized(&self) -> &str {
+    pub fn get_normalized(&self) -> &NormalizedString {
         &self.normalized
     }
 
@@ -96,15 +91,7 @@ impl Encoding {
         let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
         let mut o_attent = self.attention_mask.split_off(max_len);
 
-        // Figure out offsets for original and normalized
-        // TODO: We will be able to retrive the right part of original
-        // only when we will have the alignment difference between both
-        // For now we will use the normalized offset...
-        let max = self
-            .offsets
-            .iter()
-            .fold(0, |max, (_, end)| if *end > max { *end } else { max });
-        let trunc_original = self.original.split_off(max);
+        let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0);
         let trunc_normalized = self.normalized.split_off(max);
 
         if stride > 0 {
@@ -117,7 +104,6 @@ impl Encoding {
         }
 
         self.overflowing = Some(Box::new(Encoding {
-            original: trunc_original,
             normalized: trunc_normalized,
             ids: o_ids,
             type_ids: o_type_ids,
@@ -130,8 +116,7 @@ impl Encoding {
     }
 
     pub fn merge_with(&mut self, pair: Encoding) {
-        self.original.push_str(&pair.original);
-        self.normalized.push_str(&pair.normalized);
+        self.normalized.merge_with(&pair.normalized);
         self.ids.extend(pair.ids);
         self.type_ids.extend(pair.type_ids);
         self.tokens.extend(pair.tokens);
@@ -224,8 +209,7 @@ mod tests {
     #[test]
     fn merge_encodings() {
         let mut a = Encoding {
-            original: String::from("Hello "),
-            normalized: String::from("Hello "),
+            normalized: NormalizedString::from("Hello "),
            ids: vec![1],
             type_ids: vec![0],
             tokens: vec![String::from("Hello ")],
@@ -235,8 +219,7 @@ mod tests {
             overflowing: None,
         };
         let b = Encoding {
-            original: String::from("World!"),
-            normalized: String::from("World!"),
+            normalized: NormalizedString::from("World!"),
             ids: vec![2],
             type_ids: vec![1],
             tokens: vec![String::from("World!")],
@@ -250,8 +233,7 @@ mod tests {
         assert_eq!(
             a,
             Encoding {
-                original: String::from("Hello World!"),
-                normalized: String::from("Hello World!"),
+                normalized: NormalizedString::from("Hello World!"),
                 ids: vec![1, 2],
                 type_ids: vec![0, 1],
                 tokens: vec![String::from("Hello "), String::from("World!")],
@@ -266,8 +248,7 @@ mod tests {
     #[test]
     fn truncate() {
         let mut a = Encoding {
-            original: String::from("Hello World!"),
-            normalized: String::from("Hello World!"),
+            normalized: NormalizedString::from("Hello World!"),
             ids: vec![1, 2, 3],
             type_ids: vec![0, 0, 0],
             tokens: vec![
@@ -285,8 +266,7 @@ mod tests {
         assert_eq!(
             a,
             Encoding {
-                original: String::from("Hello World"),
-                normalized: String::from("Hello World"),
+                normalized: NormalizedString::from("Hello World"),
                 ids: vec![1, 2],
                 type_ids: vec![0, 0],
                 tokens: vec![String::from("Hello"), String::from("World")],
@@ -294,8 +274,7 @@ mod tests {
             special_tokens_mask: vec![0, 0],
             attention_mask: vec![1, 1],
             overflowing: Some(Box::new(Encoding {
-                original: String::from("!"),
-                normalized: String::from("!"),
+                normalized: NormalizedString::from("!"),
                 ids: vec![3],
                 type_ids: vec![0],
                 tokens: vec![String::from("!")],

@ -24,24 +24,22 @@ use std::{
 };

 mod encoding;
+mod normalizer;

 pub use encoding::*;
+pub use normalizer::*;

 pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
+pub type Offsets = (usize, usize);

-/// A Normalizer takes care of pre-processing strings
-pub trait Normalizer {
-    fn normalize(&self, s: String) -> Result<String>;
-}
-
 /// A PreTokenizer takes care of pre-tokenizing strings before this goes to the model
 pub trait PreTokenizer {
-    // TODO: Should return offsets with each substring
-    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>>;
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>>;
 }

 /// Represents a `Model` used during Tokenization (Like BPE or Word or Unigram)
 pub trait Model {
-    fn tokenize(&self, tokens: Vec<String>) -> Result<Vec<Token>>;
+    fn tokenize(&self, tokens: Vec<(String, Offsets)>) -> Result<Vec<Token>>;
     fn token_to_id(&self, token: &str) -> Option<u32>;
     fn id_to_token(&self, id: u32) -> Option<String>;
     fn get_vocab_size(&self) -> usize;
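For reference, a minimal sketch (not part of this diff) of what a pre-tokenizer looks like under the new signature: each piece is returned together with its byte offsets in the input. The `WhitespaceSplit` name and the `use super::` path are illustrative assumptions.

// Illustrative only: a pre-tokenizer matching the new trait, splitting on
// whitespace and reporting byte offsets into the original input.
use super::{Offsets, PreTokenizer, Result};

pub struct WhitespaceSplit;

impl PreTokenizer for WhitespaceSplit {
    fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>> {
        Ok(s.split_whitespace()
            .map(|word| {
                // Byte offset of this word within `s`
                let start = word.as_ptr() as usize - s.as_ptr() as usize;
                (word.to_owned(), (start, start + word.len()))
            })
            .collect())
    }
}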
@ -249,8 +247,7 @@ impl Tokenizer {
         // If this is one of our added tokens, lets return an encoding directly
         if let Some(id) = id {
             return Ok(Encoding::new(
-                sentence.clone(),
-                sentence.clone(),
+                NormalizedString::from(&sentence),
                 vec![id],
                 vec![type_id],
                 vec![sentence.to_owned()],
@ -262,13 +259,10 @@ impl Tokenizer {
         }

         // 1. Normalization
-        // TODO: Make sure we have the offsets update necessary to go from the original text to
-        // the normalized one
-        let original = sentence.clone();
-        let normalized = self.normalize(sentence)?;
+        let normalized = self.normalize(&sentence)?;

         // 2. Pre tokenization
-        let pre_tokenized = self.pre_tokenize(&normalized)?;
+        let pre_tokenized = self.pre_tokenize(&normalized.get())?;

         // 3. Model
         let output = self.model.tokenize(pre_tokenized)?;
@ -289,7 +283,6 @@ impl Tokenizer {
         );

         Ok(Encoding::new(
-            original,
             normalized,
             ids,
             vec![type_id; length],
@ -397,9 +390,12 @@ impl Tokenizer {

         for line in file.lines() {
             let line = line?;
-            let normalized = self.normalize(line)?;
-            let pre_tokenized = self.pre_tokenize(&normalized)?;
-            trainer.process_tokens(&mut words, pre_tokenized);
+            let normalized = self.normalize(&line)?;
+            let pre_tokenized = self.pre_tokenize(normalized.get())?;
+            trainer.process_tokens(
+                &mut words,
+                pre_tokenized.into_iter().map(|(t, _)| t).collect(),
+            );
         }

         Ok(words)
@ -422,20 +418,22 @@ impl Tokenizer {
     }

     /// PreTokenization logic, handling the case where there is no PreTokenizer set
-    fn pre_tokenize(&self, sentence: &str) -> Result<Vec<String>> {
+    fn pre_tokenize(&self, sentence: &str) -> Result<Vec<(String, Offsets)>> {
         match &self.pre_tokenizer {
-            None => Ok(vec![sentence.to_owned()]),
+            None => Ok(vec![(sentence.to_owned(), (0, sentence.len()))]),
             Some(pre_tokenizer) => pre_tokenizer.pre_tokenize(sentence),
         }
     }

     /// Normalization logic, go through all normalizers
-    fn normalize(&self, sentence: String) -> Result<String> {
+    fn normalize(&self, sequence: &str) -> Result<NormalizedString> {
+        let mut normalized = NormalizedString::from(sequence);
+
         if let Some(normalizer) = &self.normalizer {
-            normalizer.normalize(sentence)
-        } else {
-            Ok(sentence)
+            normalizer.normalize(&mut normalized)?;
         }
+
+        Ok(normalized)
     }

     /// Post processing logic, handling the case where there is no PostProcessor set
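To illustrate the new in-place normalization API, a normalizer now mutates the NormalizedString it receives instead of returning a fresh String. This sketch is not part of the diff and the `Lowercase` name is hypothetical; it only uses methods that the new NormalizedString provides.

use super::{NormalizedString, Normalizer, Result};

/// Hypothetical normalizer: NFKC-normalize then lowercase, letting
/// NormalizedString keep the alignments to the original text up to date.
pub struct Lowercase;

impl Normalizer for Lowercase {
    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
        normalized.nfkc().lowercase();
        Ok(())
    }
}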
333
tokenizers/src/tokenizer/normalizer.rs
Normal file
@ -0,0 +1,333 @@
use super::Result;
use std::cmp::Ordering;
use unicode_normalization::UnicodeNormalization;

/// A Normalizer takes care of pre-processing strings
pub trait Normalizer {
    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
}

/// A normalized string takes care of keeping both versions of a String, and
/// provides necessary alignments to retrieve ranges of both strings
#[derive(Default, Debug, Clone)]
pub struct NormalizedString {
    original: String,
    normalized: String,
    /// Mapping from normalized string to original one
    /// (pos, changes) where pos is the position in the modified string, and changes an isize
    /// representing the number of insertions or deletions
    alignments: Vec<(usize, usize)>,
}

impl std::cmp::PartialEq for NormalizedString {
    fn eq(&self, other: &NormalizedString) -> bool {
        self.normalized == other.normalized
    }
}

impl NormalizedString {
    pub fn from(s: &str) -> Self {
        NormalizedString {
            original: s.to_owned(),
            normalized: s.to_owned(),
            alignments: (0..s.chars().count()).map(|v| (v, v + 1)).collect(),
        }
    }

    pub fn get(&self) -> &str {
        &self.normalized
    }

    pub fn get_original(&self) -> &str {
        &self.original
    }

    /// Return a range of the normalized string
    pub fn get_range(&self, range: std::ops::Range<usize>) -> Option<&str> {
        self.normalized.get(range)
    }

    /// Return a range of the original string, using a range from the normalized string
    pub fn get_range_original(&self, range: std::ops::Range<usize>) -> Option<&str> {
        self.alignments
            .get(range)
            .map(|alignments| {
                if alignments.is_empty() {
                    None
                } else {
                    let start = alignments[0].0;
                    let end = alignments[alignments.len() - 1].1;
                    self.original.get(start..end)
                }
            })
            .flatten()
    }

    /// Applies transformations to the current normalized version, updating the current
    /// alignments with the new ones.
    /// This method expect an Iterator yielding each char of the new normalized string
    /// with a `change` isize equals to:
    ///   - `1` if this is a new char
    ///   - `-N` if the char is right before N removed chars
    ///   - `0` if this char represents the old one (even if changed)
    ///
    /// `change` should never be more than `1`. If multiple chars are added, each of
    /// them has a `change` of `1`, but more doesn't make any sense.
    /// We treat any value above `1` as `1`.
    pub fn transform<I: Iterator<Item = (char, isize)>>(&mut self, dest: I) {
        let mut offset: isize = 0;
        let (ch, alignments): (Vec<_>, Vec<_>) = dest
            .enumerate()
            .map(|(index, (c, changes))| {
                let uof = if offset < 0 {
                    -offset as usize
                } else {
                    offset as usize
                };
                // A positive offset means we added characters. So we need to remove this offset
                // from the current index to find out the previous id
                let idx = if offset < 0 { index + uof } else { index - uof };
                let align = match changes.cmp(&0) {
                    // This is a newly inserted character, so we use the alignment from the
                    // previous one
                    Ordering::Greater => {
                        if idx < 1 {
                            Some((0, 0))
                        } else {
                            offset += 1;
                            self.alignments.get(idx - 1).copied()
                        }
                    }
                    // No changes required here
                    Ordering::Equal => self.alignments.get(idx).copied(),
                    // Some characters where removed, so we merge our range with the one from the
                    // removed characters as the new alignment
                    Ordering::Less => {
                        let uch = -changes as usize;
                        offset += changes;
                        self.alignments.get(idx..idx + uch).map(|alignments| {
                            let min = alignments
                                .iter()
                                .map(|(start, end)| usize::min(*start, *end))
                                .min()
                                .unwrap();
                            let max = alignments
                                .iter()
                                .map(|(start, end)| usize::max(*start, *end))
                                .max()
                                .unwrap();
                            (min, max)
                        })
                    }
                };

                // Then we keep only the char for string reconstruction
                (
                    c,
                    align.expect("Bad alignement in NormalizedString::transform"),
                )
            })
            .unzip();
        self.alignments = alignments;
        self.normalized = ch.iter().collect::<String>();
    }

    /// Applies NFD normalization
    pub fn nfd(&mut self) -> &mut Self {
        self.transform(self.get().to_owned().nfd());
        self
    }

    /// Applies NFKD normalization
    pub fn nfkd(&mut self) -> &mut Self {
        self.transform(self.get().to_owned().nfkd());
        self
    }

    /// Applies NFC normalization
    pub fn nfc(&mut self) -> &mut Self {
        self.transform(self.get().to_owned().nfc());
        self
    }

    /// Applies NFKC normalization
    pub fn nfkc(&mut self) -> &mut Self {
        self.transform(self.get().to_owned().nfkc());
        self
    }

    /// Applies filtering over our characters
    pub fn filter<F: Fn(&char) -> bool>(&mut self, filter: F) -> &mut Self {
        let mut removed = 0;
        let mut filtered = self
            .normalized
            .chars()
            // We need to collect here to be able to reverse the iterator because Char is not ended
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .map(|c| {
                let keep = filter(&c);
                if keep {
                    if removed > 0 {
                        let res = (c, -removed);
                        removed = 0;
                        Some(res)
                    } else {
                        Some((c, 0))
                    }
                } else {
                    removed += 1;
                    None
                }
            })
            .collect::<Vec<_>>();
        // For some reason, if we use rev, and unwrap directly, some parts of the tuples we return
        // above get mixed up... So we collect first, then reverse in place
        filtered.reverse();
        self.transform(filtered.iter().filter(|o| o.is_some()).map(|o| o.unwrap()));
        self
    }

    /// Map our characters
    pub fn map<F: Fn(char) -> char>(&mut self, map: F) -> &mut Self {
        self.normalized = self.normalized.chars().map(map).collect::<String>();
        self
    }

    /// Calls the given function for each characters
    pub fn for_each<F: FnMut(char)>(&mut self, foreach: F) -> &mut Self {
        self.normalized.chars().for_each(foreach);
        self
    }

    /// Lowercase
    pub fn lowercase(&mut self) -> &mut Self {
        let mut new_chars: Vec<(char, isize)> = vec![];
        self.for_each(|c| {
            c.to_lowercase().enumerate().for_each(|(index, c)| {
                new_chars.push((c, if index > 0 { 1 } else { 0 }));
            })
        });
        self.transform(new_chars.into_iter());
        self
    }

    /// Uppercase
    pub fn uppercase(&mut self) -> &mut Self {
        let mut new_chars: Vec<(char, isize)> = vec![];
        self.for_each(|c| {
            c.to_uppercase().enumerate().for_each(|(index, c)| {
                new_chars.push((c, if index > 0 { 1 } else { 0 }));
            })
        });
        self.transform(new_chars.into_iter());
        self
    }

    /// Split off ourselves, returning a new Self that contains the range [at, len).
    /// self will then contain the range [0, at).
    ///
    /// Panic if at > len
    pub fn split_off(&mut self, at: usize) -> Self {
        let normalized = self.normalized.split_off(at);
        let alignments = self.alignments.split_off(at);
        let original_at = self.alignments.last().map(|(_, end)| *end).unwrap_or(0);
        let original = self.original.split_off(original_at);

        NormalizedString {
            original,
            normalized,
            alignments,
        }
    }

    /// Merge with the given NormalizedString by appending it to self
    pub fn merge_with(&mut self, other: &NormalizedString) {
        self.original.push_str(&other.original);
        let len = self.len();
        self.alignments.extend(
            other
                .alignments
                .iter()
                .map(|(start, end)| (start + len, end + len)),
        );
        self.normalized.push_str(&other.normalized);
    }

    /// Returns the length
    pub fn len(&self) -> usize {
        self.normalized.len()
    }

    /// Whether empty
    pub fn is_empty(&self) -> bool {
        self.normalized.len() == 0
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use unicode_categories::UnicodeCategories;

    #[test]
    fn new_chars() {
        let mut n = NormalizedString::from("élégant");
        n.nfd();
        assert_eq!(
            &n.alignments,
            &[
                (0, 1),
                (0, 1),
                (1, 2),
                (2, 3),
                (2, 3),
                (3, 4),
                (4, 5),
                (5, 6),
                (6, 7)
            ]
        );
    }

    #[test]
    fn unchanged() {
        let mut n = NormalizedString::from("élégant");
        n.nfd().filter(|c| !c.is_mark_nonspacing());
        assert_eq!(
            &n.alignments,
            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
        );
    }

    #[test]
    fn removed_chars() {
        let mut n = NormalizedString::from("élégant");
        n.filter(|c| *c != 'n');
        assert_eq!(
            &n.alignments,
            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
        );
    }

    #[test]
    fn mixed_addition_and_removal() {
        let mut n = NormalizedString::from("élégant");
        n.nfd().filter(|c| !c.is_mark_nonspacing() && *c != 'n');
        assert_eq!(
            &n.alignments,
            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
        );
    }

    #[test]
    fn original_range() {
        let mut n = NormalizedString::from("Hello_______ World!");
        n.filter(|c| *c != '_').lowercase();
        let world_n = n.get_range(6..11).unwrap();
        let world_o = n.get_range_original(6..11).unwrap();
        assert_eq!(world_n, "world");
        assert_eq!(world_o, "World");
    }
}
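A short usage sketch, along the lines of the tests above: because NormalizedString keeps alignments between the two versions of the text, a range taken on the normalized string can be mapped back to the original. The `tokenizers::tokenizer` import path is assumed here.

use tokenizers::tokenizer::NormalizedString;

fn main() {
    let mut n = NormalizedString::from("HELLO WORLD");
    n.lowercase();

    // The normalized view has changed...
    assert_eq!(n.get_range(6..11), Some("world"));
    // ...but the same range still maps back to the original text
    assert_eq!(n.get_range_original(6..11), Some("WORLD"));
    assert_eq!(n.get_original(), "HELLO WORLD");
}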