From 1dc0debe36e39e9092c911e81aa970805c4e0baa Mon Sep 17 00:00:00 2001 From: epwalsh Date: Wed, 18 Dec 2019 16:45:11 -0800 Subject: [PATCH] add initial test --- tokenizers/Cargo.lock | 103 +++++++++++++++++++++++++++++ tokenizers/Cargo.toml | 10 ++- tokenizers/src/models/bpe/mod.rs | 5 ++ tokenizers/src/models/bpe/model.rs | 30 +++++++++ 4 files changed, 145 insertions(+), 3 deletions(-) diff --git a/tokenizers/Cargo.lock b/tokenizers/Cargo.lock index 0e896fb5..5c6f4fc7 100644 --- a/tokenizers/Cargo.lock +++ b/tokenizers/Cargo.lock @@ -35,6 +35,14 @@ name = "bitflags" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "c2-chacha" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "cfg-if" version = "0.1.10" @@ -108,6 +116,16 @@ name = "either" version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "getrandom" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", + "wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "hermit-abi" version = "0.1.3" @@ -153,6 +171,48 @@ dependencies = [ "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "ppv-lite86" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "rand" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rand_chacha" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)", + "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "rayon" version = "1.2.0" @@ -175,6 +235,11 @@ dependencies = [ "num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "redox_syscall" +version = "0.1.56" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "regex" version = "1.3.1" @@ -191,6 +256,14 @@ name = "regex-syntax" version = "0.6.12" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "remove_dir_all" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "rustc_version" version = "0.2.3" @@ -247,6 +320,19 @@ name = "strsim" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "tempfile" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)", + "rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)", + "redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)", + "remove_dir_all 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "textwrap" version = "0.11.0" @@ -273,6 +359,7 @@ dependencies = [ "regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)", "serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)", + "tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)", "unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", ] @@ -300,6 +387,11 @@ name = "vec_map" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "wasi" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "winapi" version = "0.3.8" @@ -325,6 +417,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum atty 0.2.13 (registry+https://github.com/rust-lang/crates.io-index)" = "1803c647a3ec87095e7ae7acfca019e98de5ec9a7d01343f611cf3152ed71a90" "checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2" "checksum bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" +"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb" "checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" "checksum clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9" "checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca" @@ -333,6 +426,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum crossbeam-utils 0.6.6 (registry+https://github.com/rust-lang/crates.io-index)" = "04973fa96e96579258a5091af6003abde64af786b860f18622b82e026cca60e6" "checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4" "checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" +"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407" "checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120" "checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f" "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" @@ -340,10 +434,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e" "checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9" "checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72" +"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" +"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412" +"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853" +"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" "checksum rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "83a27732a533a1be0a0035a111fe76db89ad312f6f0347004c220c57f209a123" "checksum rayon-core 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "98dcf634205083b17d0861252431eb2acbfb698ab7478a2d20de07954f47ec7b" +"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84" "checksum regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "dc220bd33bdce8f093101afe22a037b8eb0e5af33592e6a9caafff0d4cb81cbd" "checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716" +"checksum remove_dir_all 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "4a83fa3702a688b9359eccba92d153ac33fd2e8462f9e0e3fdf155239ea7792e" "checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" "checksum ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8" "checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d" @@ -353,12 +454,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)" = "2f72eb2a68a7dc3f9a691bfda9305a1c017a6215e5a4545c258500d2099a37c2" "checksum smallvec 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4ecf3b85f68e8abaa7555aa5abdb1153079387e60b718283d732f03897fcfc86" "checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" +"checksum tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9" "checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" "checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b" "checksum unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)" = "b561e267b2326bb4cebfc0ef9e68355c7abe6c6f522aeac2f5bf95d56c59bdcf" "checksum unicode-width 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7007dbd421b92cc6e28410fe7362e2e0a2503394908f417b68ec8d1c364c4e20" "checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" "checksum vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a" +"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d" "checksum winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)" = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6" "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" "checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 897bf387..b235c216 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -3,6 +3,11 @@ authors = ["Anthony MOI "] edition = "2018" name = "tokenizers-lib" version = "0.0.7" + +[lib] +name = "tokenizers" +path = "src/lib.rs" + [[bin]] name = "cli" path = "src/cli.rs" @@ -17,6 +22,5 @@ clap = "2.33.0" unicode-normalization = "0.1.11" unicode_categories = "0.1.1" -[lib] -name = "tokenizers" -path = "src/lib.rs" +[dev-dependencies] +tempfile = "3.1" diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index 08862e29..156cfe8f 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -17,6 +17,9 @@ pub enum Error { JsonError(serde_json::Error), /// When the vocab.json file is in the wrong format BadVocabulary, + /// When the merges.txt file is in the wrong format. This error holds the line + /// number of the line that caused the error. + BadMerges(usize), /// If a token found in merges, is not in the vocab MergeTokenOutOfVocabulary(String), } @@ -39,6 +42,7 @@ impl std::fmt::Display for Error { Error::Io(e) => write!(f, "IoError: {}", e), Error::JsonError(e) => write!(f, "JsonError: {}", e), Error::BadVocabulary => write!(f, "Bad vocabulary json file"), + Error::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line), Error::MergeTokenOutOfVocabulary(token) => { write!(f, "Token {} out of vocabulary", token) } @@ -52,6 +56,7 @@ impl std::error::Error for Error { Error::Io(e) => Some(e), Error::JsonError(e) => Some(e), Error::BadVocabulary => None, + Error::BadMerges(_) => None, Error::MergeTokenOutOfVocabulary(_) => None, } } diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 665f32e0..2f49bd94 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -70,6 +70,9 @@ impl BPE { } let parts = line.split(' ').collect::>(); + if parts.len() != 2 { + return Err(Error::BadMerges(rank + 1).into()); + } let a = vocab .get(parts[0]) @@ -192,3 +195,30 @@ impl Model for BPE { self.vocab_r.get(&id).cloned() } } + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + + #[test] + fn test_bpe_from_files() { + // Set up vocab file. + let mut vocab_file = NamedTempFile::new().unwrap(); + vocab_file + .write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes()) + .unwrap(); + + // Set up merges file. + let mut merges_file = NamedTempFile::new().unwrap(); + merges_file + .write_all("#version: 0.2\na b".as_bytes()) + .unwrap(); + + assert!(BPE::from_files( + vocab_file.path().to_str().unwrap(), + merges_file.path().to_str().unwrap() + ) + .is_ok()); + } +}