From 1dc0debe36e39e9092c911e81aa970805c4e0baa Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 18 Dec 2019 16:45:11 -0800 Subject: [PATCH 1/5] add initial test --- tokenizers/Cargo.lock | 103 +++++++++++++++++++++++++++++ tokenizers/Cargo.toml | 10 ++- tokenizers/src/models/bpe/mod.rs | 5 ++ tokenizers/src/models/bpe/model.rs | 30 +++++++++ 4 files changed, 145 insertions(+), 3 deletions(-) diff --git a/tokenizers/Cargo.lock b/tokenizers/Cargo.lock index 0e896fb5..5c6f4fc7 100644 --- a/tokenizers/Cargo.lock +++ b/tokenizers/Cargo.lock @@ -35,6 +35,14 @@ name = "bitflags" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "c2-chacha" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "cfg-if" version = "0.1.10" @@ -108,6 +116,16 @@ name = "either" version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "getrandom" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", + "wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "hermit-abi" version = "0.1.3" @@ -153,6 +171,48 @@ dependencies = [ "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "ppv-lite86" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "rand" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rand_chacha" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "rayon" version = "1.2.0" @@ -175,6 +235,11 @@ dependencies = [ "num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "redox_syscall" +version = "0.1.56" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "regex" version = "1.3.1" @@ -191,6 +256,14 @@ name = "regex-syntax" version = "0.6.12" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "remove_dir_all" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "winapi 
0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "rustc_version" version = "0.2.3" @@ -247,6 +320,19 @@ name = "strsim" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "tempfile" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", + "rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)", + "remove_dir_all 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "textwrap" version = "0.11.0" @@ -273,6 +359,7 @@ dependencies = [ "regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)", "serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)", + "tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)", "unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -300,6 +387,11 @@ name = "vec_map" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "wasi" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "winapi" version = "0.3.8" @@ -325,6 +417,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum atty 0.2.13 (registry+https://github.com/rust-lang/crates.io-index)" = "1803c647a3ec87095e7ae7acfca019e98de5ec9a7d01343f611cf3152ed71a90" "checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2" "checksum bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" +"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb" "checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" "checksum clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9" "checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca" @@ -333,6 +426,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum crossbeam-utils 0.6.6 (registry+https://github.com/rust-lang/crates.io-index)" = "04973fa96e96579258a5091af6003abde64af786b860f18622b82e026cca60e6" "checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4" "checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" +"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407" "checksum hermit-abi 
0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120" "checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f" "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" @@ -340,10 +434,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e" "checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9" "checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72" +"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" +"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412" +"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853" +"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" "checksum rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "83a27732a533a1be0a0035a111fe76db89ad312f6f0347004c220c57f209a123" "checksum rayon-core 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "98dcf634205083b17d0861252431eb2acbfb698ab7478a2d20de07954f47ec7b" +"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84" "checksum regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "dc220bd33bdce8f093101afe22a037b8eb0e5af33592e6a9caafff0d4cb81cbd" "checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716" +"checksum remove_dir_all 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "4a83fa3702a688b9359eccba92d153ac33fd2e8462f9e0e3fdf155239ea7792e" "checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" "checksum ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8" "checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d" @@ -353,12 +454,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)" = "2f72eb2a68a7dc3f9a691bfda9305a1c017a6215e5a4545c258500d2099a37c2" "checksum smallvec 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4ecf3b85f68e8abaa7555aa5abdb1153079387e60b718283d732f03897fcfc86" "checksum strsim 0.8.0 
(registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" +"checksum tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9" "checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" "checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b" "checksum unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)" = "b561e267b2326bb4cebfc0ef9e68355c7abe6c6f522aeac2f5bf95d56c59bdcf" "checksum unicode-width 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7007dbd421b92cc6e28410fe7362e2e0a2503394908f417b68ec8d1c364c4e20" "checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" "checksum vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a" +"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d" "checksum winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)" = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6" "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" "checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 897bf387..b235c216 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -3,6 +3,11 @@ authors = ["Anthony MOI "] edition = "2018" name = "tokenizers-lib" version = "0.0.7" + +[lib] +name = "tokenizers" +path = "src/lib.rs" + [[bin]] name = "cli" path = "src/cli.rs" @@ -17,6 +22,5 @@ clap = "2.33.0" unicode-normalization = "0.1.11" unicode_categories = "0.1.1" -[lib] -name = "tokenizers" -path = "src/lib.rs" +[dev-dependencies] +tempfile = "3.1" diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index 08862e29..156cfe8f 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -17,6 +17,9 @@ pub enum Error { JsonError(serde_json::Error), /// When the vocab.json file is in the wrong format BadVocabulary, + /// When the merges.txt file is in the wrong format. This error holds the line + /// number of the line that caused the error. 
+    BadMerges(usize),
     /// If a token found in merges, is not in the vocab
     MergeTokenOutOfVocabulary(String),
 }
@@ -39,6 +42,7 @@ impl std::fmt::Display for Error {
             Error::Io(e) => write!(f, "IoError: {}", e),
             Error::JsonError(e) => write!(f, "JsonError: {}", e),
             Error::BadVocabulary => write!(f, "Bad vocabulary json file"),
+            Error::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line),
             Error::MergeTokenOutOfVocabulary(token) => {
                 write!(f, "Token {} out of vocabulary", token)
             }
@@ -52,6 +56,7 @@ impl std::error::Error for Error {
             Error::Io(e) => Some(e),
             Error::JsonError(e) => Some(e),
             Error::BadVocabulary => None,
+            Error::BadMerges(_) => None,
             Error::MergeTokenOutOfVocabulary(_) => None,
         }
     }
diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 665f32e0..2f49bd94 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -70,6 +70,9 @@ impl BPE {
                 }
 
                 let parts = line.split(' ').collect::<Vec<_>>();
+                if parts.len() != 2 {
+                    return Err(Error::BadMerges(rank + 1).into());
+                }
 
                 let a = vocab
                     .get(parts[0])
@@ -192,3 +195,30 @@ impl Model for BPE {
         self.vocab_r.get(&id).cloned()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use tempfile::NamedTempFile;
+
+    #[test]
+    fn test_bpe_from_files() {
+        // Set up vocab file.
+        let mut vocab_file = NamedTempFile::new().unwrap();
+        vocab_file
+            .write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
+            .unwrap();
+
+        // Set up merges file.
+        let mut merges_file = NamedTempFile::new().unwrap();
+        merges_file
+            .write_all("#version: 0.2\na b".as_bytes())
+            .unwrap();
+
+        assert!(BPE::from_files(
+            vocab_file.path().to_str().unwrap(),
+            merges_file.path().to_str().unwrap()
+        )
+        .is_ok());
+    }
+}

From 184b09e3acc8c4af4473f5802a782c77b5fbd0f5 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 18 Dec 2019 17:40:13 -0800
Subject: [PATCH 2/5] add more tests

---
 tokenizers/src/models/bpe/model.rs | 57 ++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)

diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 2f49bd94..be4b77a2 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -215,10 +215,67 @@ mod tests {
             .write_all("#version: 0.2\na b".as_bytes())
             .unwrap();
 
+        // Make sure we can instantiate a BPE model from the files.
         assert!(BPE::from_files(
            vocab_file.path().to_str().unwrap(),
            merges_file.path().to_str().unwrap()
        )
        .is_ok());
     }
+
+    #[test]
+    fn test_bpe_from_files_merge_token_oov() {
+        // Set up vocab file.
+        let mut vocab_file = NamedTempFile::new().unwrap();
+        vocab_file
+            .write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
+            .unwrap();
+
+        // Set up merges file.
+        let mut merges_file = NamedTempFile::new().unwrap();
+        merges_file
+            .write_all("#version: 0.2\na b\na d".as_bytes())
+            .unwrap();
+
+        // Ensure the result of BPE::from_files is a MergeTokenOutOfVocabulary error.
+        match BPE::from_files(
+            vocab_file.path().to_str().unwrap(),
+            merges_file.path().to_str().unwrap(),
+        ) {
+            Ok(_) => unreachable!(),
+            Err(err) => match err.downcast_ref::<Error>() {
+                Some(Error::MergeTokenOutOfVocabulary(token)) => {
+                    assert_eq!(*token, String::from("d"))
+                }
+                _ => unreachable!(),
+            },
+        }
+    }
+
+    #[test]
+    fn test_bpe_from_files_bad_merges() {
+        // Set up vocab file.
+        let mut vocab_file = NamedTempFile::new().unwrap();
+        vocab_file
+            .write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
+            .unwrap();
+
+        // Set up merges file with a bad line.
+        let mut merges_file = NamedTempFile::new().unwrap();
+        merges_file
+            .write_all("#version: 0.2\na b\nc".as_bytes())
+            .unwrap();
+
+        // Ensure the result of BPE::from_files is a BadMerges error.
+        match BPE::from_files(
+            vocab_file.path().to_str().unwrap(),
+            merges_file.path().to_str().unwrap(),
+        ) {
+            Ok(_) => unreachable!(),
+            Err(err) => match err.downcast_ref::<Error>() {
+                Some(Error::BadMerges(line)) => assert_eq!(*line, 3usize),
+                _ => unreachable!(),
+            },
+        }
+    }
 }

From a16daa78f1469001b5f17ef4fbf21a4ae7234a63 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 19 Dec 2019 14:45:38 -0800
Subject: [PATCH 3/5] add test for word merge

---
 tokenizers/src/models/bpe/word.rs | 50 ++++++++++++++++++++++++++++++-
 1 file changed, 49 insertions(+), 1 deletion(-)

diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs
index 541ddac5..b61b5b49 100644
--- a/tokenizers/src/models/bpe/word.rs
+++ b/tokenizers/src/models/bpe/word.rs
@@ -1,6 +1,5 @@
 use super::Pair;
 
-// TODO: Add tests
 #[derive(Clone, Default)]
 pub struct Word {
     chars: Vec<u32>,
     sizes: Vec<usize>,
 }
@@ -75,3 +74,52 @@ impl Word {
         offsets
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_merge() {
+        // Let's say we have the word 'hello' and a word-to-id vocab that looks
+        // like this: {'h': 0, 'e': 1, 'l': 2, 'o': 3}.
+        let mut word = Word::new();
+        word.add(0); // 'h'
+        word.add(1); // 'e'
+        word.add(2); // 'l'
+        word.add(2); // 'l'
+        word.add(3); // 'o'
+
+        // We're going to perform a merge on the pair ('l', 'l') ~= (2, 2). Let's
+        // say that 'll' has the ID of 4 in the updated word-to-id vocab.
+        let changes = word.merge(2, 2, 4);
+
+        // So the word should now look like this:
+        assert_eq!(
+            word.get_chars(),
+            &[
+                0u32, // 'h'
+                1u32, // 'e'
+                4u32, // 'll'
+                3u32, // 'o'
+            ]
+        );
+
+        // The return value `changes` will be used to update the pair counts during
+        // training. This merge affects the counts for the pairs
+        // ('e', 'l') ~= (1, 2),
+        // ('e', 'll') ~= (1, 4),
+        // ('ll', 'o') ~= (4, 3), and
+        // ('l', 'o') ~= (2, 3).
+        // So the changes should reflect that:
+        assert_eq!(
+            changes,
+            &[
+                ((1u32, 2u32), -1i32), // count for ('e', 'l') should be decreased by 1.
+                ((1u32, 4u32), 1i32),  // count for ('e', 'll') should be increased by 1.
+                ((2u32, 3u32), -1i32), // count for ('l', 'o') should be decreased by 1.
+                ((4u32, 3u32), 1i32),  // count for ('ll', 'o') should be increased by 1.
+            ]
+        );
+    }
+}

From 69212e17e9853bb3fe70829e19a04a9755f59211 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 19 Dec 2019 15:07:27 -0800
Subject: [PATCH 4/5] formatting

---
 tokenizers/src/models/bpe/model.rs   | 4 ++++
 tokenizers/src/models/bpe/trainer.rs | 7 +++++--
 tokenizers/src/models/bpe/word.rs    | 5 +++--
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index be4b77a2..c938227e 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -202,6 +202,7 @@ mod tests {
     use tempfile::NamedTempFile;
 
     #[test]
+    // Ensure `BPE::from_files` works as expected.
     fn test_bpe_from_files() {
         // Set up vocab file.
         let mut vocab_file = NamedTempFile::new().unwrap();
         vocab_file
             .write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
             .unwrap();
@@ -224,6 +225,7 @@ mod tests {
     }
 
     #[test]
+    // Ensure `MergeTokenOutOfVocabulary` error is returned when it should be.
     fn test_bpe_from_files_merge_token_oov() {
         // Set up vocab file.
         let mut vocab_file = NamedTempFile::new().unwrap();
         vocab_file
             .write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
             .unwrap();
@@ -253,6 +255,8 @@ mod tests {
     }
 
     #[test]
+    // Ensure `BadMerges` error is returned when there is an invalid line in the
+    // merges.txt file.
     fn test_bpe_from_files_bad_merges() {
         // Set up vocab file.
         let mut vocab_file = NamedTempFile::new().unwrap();
         vocab_file
             .write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
             .unwrap();
diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index 88769606..e4c4741a 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -13,14 +13,15 @@ use std::{
 };
 
 pub struct BpeTrainerConfig {
-    vocab_size: usize,
     min_frequency: u32,
+    vocab_size: usize,
 }
+
 impl BpeTrainerConfig {
     pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
         BpeTrainerConfig {
-            vocab_size,
             min_frequency,
+            vocab_size,
         }
     }
@@ -32,12 +33,14 @@ impl BpeTrainerConfig {
         self.min_frequency = value;
     }
 }
+
 impl Default for BpeTrainerConfig {
     fn default() -> Self {
         BpeTrainerConfig::new(0, 30000)
     }
 }
 
+#[derive(Default)]
 pub struct BpeTrainer {
     // Training parameters
     config: BpeTrainerConfig,
diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs
index b61b5b49..b853290f 100644
--- a/tokenizers/src/models/bpe/word.rs
+++ b/tokenizers/src/models/bpe/word.rs
@@ -5,6 +5,7 @@ pub struct Word {
     chars: Vec<u32>,
     sizes: Vec<usize>,
 }
+
 impl Word {
     pub fn new() -> Self {
         Word {
@@ -109,8 +110,8 @@ mod tests {
         // training. This merge affects the counts for the pairs
         // ('e', 'l') ~= (1, 2),
         // ('e', 'll') ~= (1, 4),
-        // ('ll', 'o') ~= (4, 3), and
-        // ('l', 'o') ~= (2, 3).
+        // ('l', 'o') ~= (2, 3), and
+        // ('ll', 'o') ~= (4, 3).
         // So the changes should reflect that:
         assert_eq!(
             changes,

From 6d51e7a393f18f1838e4617af6a0d33ade3abef9 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 19 Dec 2019 15:28:58 -0800
Subject: [PATCH 5/5] add example / doc test for BPE trainer

---
 tokenizers/src/models/bpe/trainer.rs | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index e4c4741a..0331f771 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -1,8 +1,3 @@
-//!
-//! # Trainer
-//!
-//! In charge of training a BPE model
-//!
 #![allow(clippy::map_entry)]
 
 use super::{Pair, Word, BPE};
@@ -40,6 +35,22 @@ impl Default for BpeTrainerConfig {
     }
 }
 
+/// In charge of training a BPE model from a mapping of words to word counts.
+///
+/// # Examples
+///
+/// ```
+/// use std::collections::HashMap;
+/// use tokenizers::tokenizer::Trainer;
+/// use tokenizers::models::bpe::BpeTrainer;
+///
+/// let word_counts: HashMap<String, u32> = [
+///     (String::from("Hello"), 1),
+///     (String::from("World"), 1),
+/// ].iter().cloned().collect();
+/// let trainer = BpeTrainer::default();
+/// let model = trainer.train(word_counts);
+/// ```
 #[derive(Default)]
 pub struct BpeTrainer {
     // Training parameters