Merge pull request #6 from huggingface/BPE-tests

Add BPE tests and documentation
This commit is contained in:
MOI Anthony
2019-12-20 15:34:38 -05:00
committed by GitHub
6 changed files with 277 additions and 11 deletions

103
tokenizers/Cargo.lock generated
View File

@ -35,6 +35,14 @@ name = "bitflags"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "c2-chacha"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "cfg-if"
version = "0.1.10"
@ -108,6 +116,16 @@ name = "either"
version = "1.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "getrandom"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
"wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "hermit-abi"
version = "0.1.3"
@ -153,6 +171,48 @@ dependencies = [
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "ppv-lite86"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "rand"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand_chacha"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand_core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rayon"
version = "1.2.0"
@ -175,6 +235,11 @@ dependencies = [
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "redox_syscall"
version = "0.1.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "regex"
version = "1.3.1"
@ -191,6 +256,14 @@ name = "regex-syntax"
version = "0.6.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "remove_dir_all"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "rustc_version"
version = "0.2.3"
@ -247,6 +320,19 @@ name = "strsim"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "tempfile"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
"redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)",
"remove_dir_all 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)",
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "textwrap"
version = "0.11.0"
@ -273,6 +359,7 @@ dependencies = [
"regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)",
"tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
@ -300,6 +387,11 @@ name = "vec_map"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "wasi"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "winapi"
version = "0.3.8"
@ -325,6 +417,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum atty 0.2.13 (registry+https://github.com/rust-lang/crates.io-index)" = "1803c647a3ec87095e7ae7acfca019e98de5ec9a7d01343f611cf3152ed71a90"
"checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2"
"checksum bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb"
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
"checksum clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9"
"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca"
@ -333,6 +426,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum crossbeam-utils 0.6.6 (registry+https://github.com/rust-lang/crates.io-index)" = "04973fa96e96579258a5091af6003abde64af786b860f18622b82e026cca60e6"
"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4"
"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
"checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f"
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
@ -340,10 +434,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e"
"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9"
"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72"
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "83a27732a533a1be0a0035a111fe76db89ad312f6f0347004c220c57f209a123"
"checksum rayon-core 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "98dcf634205083b17d0861252431eb2acbfb698ab7478a2d20de07954f47ec7b"
"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
"checksum regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "dc220bd33bdce8f093101afe22a037b8eb0e5af33592e6a9caafff0d4cb81cbd"
"checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716"
"checksum remove_dir_all 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "4a83fa3702a688b9359eccba92d153ac33fd2e8462f9e0e3fdf155239ea7792e"
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
"checksum ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8"
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
@ -353,12 +454,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)" = "2f72eb2a68a7dc3f9a691bfda9305a1c017a6215e5a4545c258500d2099a37c2"
"checksum smallvec 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4ecf3b85f68e8abaa7555aa5abdb1153079387e60b718283d732f03897fcfc86"
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
"checksum tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9"
"checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
"checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b"
"checksum unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)" = "b561e267b2326bb4cebfc0ef9e68355c7abe6c6f522aeac2f5bf95d56c59bdcf"
"checksum unicode-width 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7007dbd421b92cc6e28410fe7362e2e0a2503394908f417b68ec8d1c364c4e20"
"checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
"checksum vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a"
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"
"checksum winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)" = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6"
"checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
"checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"

View File

@ -3,6 +3,11 @@ authors = ["Anthony MOI <m.anthony.moi@gmail.com>"]
edition = "2018"
name = "tokenizers-lib"
version = "0.0.8"
[lib]
name = "tokenizers"
path = "src/lib.rs"
[[bin]]
name = "cli"
path = "src/cli.rs"
@ -17,6 +22,5 @@ clap = "2.33.0"
unicode-normalization = "0.1.11"
unicode_categories = "0.1.1"
[lib]
name = "tokenizers"
path = "src/lib.rs"
[dev-dependencies]
tempfile = "3.1"

View File

@ -17,6 +17,9 @@ pub enum Error {
JsonError(serde_json::Error),
/// When the vocab.json file is in the wrong format
BadVocabulary,
/// When the merges.txt file is in the wrong format. This error holds the line
/// number of the line that caused the error.
BadMerges(usize),
/// If a token found in merges, is not in the vocab
MergeTokenOutOfVocabulary(String),
}
@ -39,6 +42,7 @@ impl std::fmt::Display for Error {
Error::Io(e) => write!(f, "IoError: {}", e),
Error::JsonError(e) => write!(f, "JsonError: {}", e),
Error::BadVocabulary => write!(f, "Bad vocabulary json file"),
Error::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line),
Error::MergeTokenOutOfVocabulary(token) => {
write!(f, "Token {} out of vocabulary", token)
}
@ -52,6 +56,7 @@ impl std::error::Error for Error {
Error::Io(e) => Some(e),
Error::JsonError(e) => Some(e),
Error::BadVocabulary => None,
Error::BadMerges(_) => None,
Error::MergeTokenOutOfVocabulary(_) => None,
}
}

View File

@ -70,6 +70,9 @@ impl BPE {
}
let parts = line.split(' ').collect::<Vec<_>>();
if parts.len() != 2 {
return Err(Error::BadMerges(rank + 1).into());
}
let a = vocab
.get(parts[0])
@ -192,3 +195,91 @@ impl Model for BPE {
self.vocab_r.get(&id).cloned()
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::NamedTempFile;
#[test]
// Ensure `BPE::from_files` works as expected.
fn test_bpe_from_files() {
// Set up vocab file.
let mut vocab_file = NamedTempFile::new().unwrap();
vocab_file
.write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
.unwrap();
// Set up merges file.
let mut merges_file = NamedTempFile::new().unwrap();
merges_file
.write_all("#version: 0.2\na b".as_bytes())
.unwrap();
// Make sure we can instatiate a BPE model from the files.
assert!(BPE::from_files(
vocab_file.path().to_str().unwrap(),
merges_file.path().to_str().unwrap()
)
.is_ok());
}
#[test]
// Ensure `MergeTokenOutOfVocabulary` error is returned when it should be.
fn test_bpe_from_files_merge_token_oov() {
// Set up vocab file.
let mut vocab_file = NamedTempFile::new().unwrap();
vocab_file
.write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
.unwrap();
// Set up merges file.
let mut merges_file = NamedTempFile::new().unwrap();
merges_file
.write_all("#version: 0.2\na b\na d".as_bytes())
.unwrap();
// Ensure the result of BPE::from_files is a MergeTokenOutOfVocabulary error.
match BPE::from_files(
vocab_file.path().to_str().unwrap(),
merges_file.path().to_str().unwrap(),
) {
Ok(_) => unreachable!(),
Err(err) => match err.downcast_ref::<Error>() {
Some(Error::MergeTokenOutOfVocabulary(token)) => {
assert_eq!(*token, String::from("d"))
}
_ => unreachable!(),
},
}
}
#[test]
// Ensure `BadMerges` error is returned when there is an invalid line in the
// merges.txt file.
fn test_bpe_from_files_bad_merges() {
// Set up vocab file.
let mut vocab_file = NamedTempFile::new().unwrap();
vocab_file
.write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
.unwrap();
// Set up merges file with a bad line.
let mut merges_file = NamedTempFile::new().unwrap();
merges_file
.write_all("#version: 0.2\na b\nc".as_bytes())
.unwrap();
// Ensure the result of BPE::from_files is a BadMerges error.
match BPE::from_files(
vocab_file.path().to_str().unwrap(),
merges_file.path().to_str().unwrap(),
) {
Ok(_) => unreachable!(),
Err(err) => match err.downcast_ref::<Error>() {
Some(Error::BadMerges(line)) => assert_eq!(*line, 3usize),
_ => unreachable!(),
},
}
}
}

View File

@ -1,8 +1,3 @@
//!
//! # Trainer
//!
//! In charge of training a BPE model
//!
#![allow(clippy::map_entry)]
use super::{Pair, Word, BPE};
@ -13,14 +8,15 @@ use std::{
};
pub struct BpeTrainerConfig {
vocab_size: usize,
min_frequency: u32,
vocab_size: usize,
}
impl BpeTrainerConfig {
pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
BpeTrainerConfig {
vocab_size,
min_frequency,
vocab_size,
}
}
@ -32,12 +28,30 @@ impl BpeTrainerConfig {
self.min_frequency = value;
}
}
impl Default for BpeTrainerConfig {
fn default() -> Self {
BpeTrainerConfig::new(0, 30000)
}
}
/// In charge of training a BPE model from a mapping of words to word counts.
///
/// # Examples
///
/// ```
/// use std::collections::HashMap;
/// use tokenizers::tokenizer::Trainer;
/// use tokenizers::models::bpe::BpeTrainer;
///
/// let word_counts: HashMap<String, u32> = [
/// (String::from("Hello"), 1),
/// (String::from("World"), 1),
/// ].iter().cloned().collect();
/// let trainer = BpeTrainer::default();
/// let model = trainer.train(word_counts);
/// ```
#[derive(Default)]
pub struct BpeTrainer {
// Training parameters
config: BpeTrainerConfig,

View File

@ -1,11 +1,11 @@
use super::Pair;
// TODO: Add tests
#[derive(Clone, Default)]
pub struct Word {
chars: Vec<u32>,
sizes: Vec<usize>,
}
impl Word {
pub fn new() -> Self {
Word {
@ -75,3 +75,52 @@ impl Word {
offsets
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_merge() {
// Let's say we have the word 'hello' and a word-to-id vocab that looks
// like this: {'h': 0, 'e': 1, 'l': 2, 'o': 3}.
let mut word = Word::new();
word.add(0); // 'h'
word.add(1); // 'e'
word.add(2); // 'l'
word.add(2); // 'l'
word.add(3); // 'o'
// We're going to perform a merge on the pair ('l', 'l') ~= (2, 2). Let's
// say that 'll' has the ID of 4 in the updated word-to-id vocab.
let changes = word.merge(2, 2, 4);
// So the word should now look like this:
assert_eq!(
word.get_chars(),
&[
0u32, // 'h'
1u32, // 'e'
4u32, // 'll'
3u32, // 'o'
]
);
// The return value `changes` will be used to update the pair counts during
// training. This merge affects the counts for the pairs
// ('e', 'l') ~= (1, 2),
// ('e', 'll') ~= (1, 4),
// ('l', 'o') ~= (2, 3), and
// ('ll', 'o') ~= (4, 3).
// So the changes should reflect that:
assert_eq!(
changes,
&[
((1u32, 2u32), -1i32), // count for ('e', 'l') should be decreased by 1.
((1u32, 4u32), 1i32), // count for ('e', 'll') should be increased by 1.
((2u32, 3u32), -1i32), // count for ('l', 'o') should be decreased by 1.
((4u32, 3u32), 1i32), // count for ('ll', 'o') should be increased by 1.
]
);
}
}