mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Merge pull request #6 from huggingface/BPE-tests
Add BPE tests and documentation
This commit is contained in:
103
tokenizers/Cargo.lock
generated
103
tokenizers/Cargo.lock
generated
@ -35,6 +35,14 @@ name = "bitflags"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "c2-chacha"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "0.1.10"
|
||||
@ -108,6 +116,16 @@ name = "either"
|
||||
version = "1.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.1.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.3"
|
||||
@ -153,6 +171,48 @@ dependencies = [
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.2.0"
|
||||
@ -175,6 +235,11 @@ dependencies = [
|
||||
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.1.56"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.3.1"
|
||||
@ -191,6 +256,14 @@ name = "regex-syntax"
|
||||
version = "0.6.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "remove_dir_all"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.2.3"
|
||||
@ -247,6 +320,19 @@ name = "strsim"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "tempfile"
|
||||
version = "3.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"remove_dir_all 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textwrap"
|
||||
version = "0.11.0"
|
||||
@ -273,6 +359,7 @@ dependencies = [
|
||||
"regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
@ -300,6 +387,11 @@ name = "vec_map"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.8"
|
||||
@ -325,6 +417,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
"checksum atty 0.2.13 (registry+https://github.com/rust-lang/crates.io-index)" = "1803c647a3ec87095e7ae7acfca019e98de5ec9a7d01343f611cf3152ed71a90"
|
||||
"checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2"
|
||||
"checksum bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
|
||||
"checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb"
|
||||
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
||||
"checksum clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9"
|
||||
"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca"
|
||||
@ -333,6 +426,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
"checksum crossbeam-utils 0.6.6 (registry+https://github.com/rust-lang/crates.io-index)" = "04973fa96e96579258a5091af6003abde64af786b860f18622b82e026cca60e6"
|
||||
"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4"
|
||||
"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
|
||||
"checksum getrandom 0.1.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e7db7ca94ed4cd01190ceee0d8a8052f08a247aa1b469a7f68c6a3b71afcf407"
|
||||
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
|
||||
"checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f"
|
||||
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
@ -340,10 +434,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
"checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e"
|
||||
"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9"
|
||||
"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72"
|
||||
"checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b"
|
||||
"checksum rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3ae1b169243eaf61759b8475a998f0a385e42042370f3a7dbaf35246eacc8412"
|
||||
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
|
||||
"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
|
||||
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
|
||||
"checksum rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "83a27732a533a1be0a0035a111fe76db89ad312f6f0347004c220c57f209a123"
|
||||
"checksum rayon-core 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "98dcf634205083b17d0861252431eb2acbfb698ab7478a2d20de07954f47ec7b"
|
||||
"checksum redox_syscall 0.1.56 (registry+https://github.com/rust-lang/crates.io-index)" = "2439c63f3f6139d1b57529d16bc3b8bb855230c8efcc5d3a896c8bea7c3b1e84"
|
||||
"checksum regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "dc220bd33bdce8f093101afe22a037b8eb0e5af33592e6a9caafff0d4cb81cbd"
|
||||
"checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716"
|
||||
"checksum remove_dir_all 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "4a83fa3702a688b9359eccba92d153ac33fd2e8462f9e0e3fdf155239ea7792e"
|
||||
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
|
||||
"checksum ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8"
|
||||
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
|
||||
@ -353,12 +454,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
"checksum serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)" = "2f72eb2a68a7dc3f9a691bfda9305a1c017a6215e5a4545c258500d2099a37c2"
|
||||
"checksum smallvec 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4ecf3b85f68e8abaa7555aa5abdb1153079387e60b718283d732f03897fcfc86"
|
||||
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
|
||||
"checksum tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9"
|
||||
"checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
|
||||
"checksum thread_local 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b"
|
||||
"checksum unicode-normalization 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)" = "b561e267b2326bb4cebfc0ef9e68355c7abe6c6f522aeac2f5bf95d56c59bdcf"
|
||||
"checksum unicode-width 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "7007dbd421b92cc6e28410fe7362e2e0a2503394908f417b68ec8d1c364c4e20"
|
||||
"checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
|
||||
"checksum vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a"
|
||||
"checksum wasi 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b89c3ce4ce14bdc6fb6beaf9ec7928ca331de5df7e5ea278375642a2f478570d"
|
||||
"checksum winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)" = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6"
|
||||
"checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
"checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
@ -3,6 +3,11 @@ authors = ["Anthony MOI <m.anthony.moi@gmail.com>"]
|
||||
edition = "2018"
|
||||
name = "tokenizers-lib"
|
||||
version = "0.0.8"
|
||||
|
||||
[lib]
|
||||
name = "tokenizers"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "cli"
|
||||
path = "src/cli.rs"
|
||||
@ -17,6 +22,5 @@ clap = "2.33.0"
|
||||
unicode-normalization = "0.1.11"
|
||||
unicode_categories = "0.1.1"
|
||||
|
||||
[lib]
|
||||
name = "tokenizers"
|
||||
path = "src/lib.rs"
|
||||
[dev-dependencies]
|
||||
tempfile = "3.1"
|
||||
|
@ -17,6 +17,9 @@ pub enum Error {
|
||||
JsonError(serde_json::Error),
|
||||
/// When the vocab.json file is in the wrong format
|
||||
BadVocabulary,
|
||||
/// When the merges.txt file is in the wrong format. This error holds the line
|
||||
/// number of the line that caused the error.
|
||||
BadMerges(usize),
|
||||
/// If a token found in merges is not in the vocab
|
||||
MergeTokenOutOfVocabulary(String),
|
||||
}
|
||||
@ -39,6 +42,7 @@ impl std::fmt::Display for Error {
|
||||
Error::Io(e) => write!(f, "IoError: {}", e),
|
||||
Error::JsonError(e) => write!(f, "JsonError: {}", e),
|
||||
Error::BadVocabulary => write!(f, "Bad vocabulary json file"),
|
||||
Error::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line),
|
||||
Error::MergeTokenOutOfVocabulary(token) => {
|
||||
write!(f, "Token {} out of vocabulary", token)
|
||||
}
|
||||
@ -52,6 +56,7 @@ impl std::error::Error for Error {
|
||||
Error::Io(e) => Some(e),
|
||||
Error::JsonError(e) => Some(e),
|
||||
Error::BadVocabulary => None,
|
||||
Error::BadMerges(_) => None,
|
||||
Error::MergeTokenOutOfVocabulary(_) => None,
|
||||
}
|
||||
}
|
||||
|
@ -70,6 +70,9 @@ impl BPE {
|
||||
}
|
||||
|
||||
let parts = line.split(' ').collect::<Vec<_>>();
|
||||
if parts.len() != 2 {
|
||||
return Err(Error::BadMerges(rank + 1).into());
|
||||
}
|
||||
|
||||
let a = vocab
|
||||
.get(parts[0])
|
||||
@ -192,3 +195,91 @@ impl Model for BPE {
|
||||
self.vocab_r.get(&id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[test]
|
||||
// Ensure `BPE::from_files` works as expected.
|
||||
fn test_bpe_from_files() {
|
||||
// Set up vocab file.
|
||||
let mut vocab_file = NamedTempFile::new().unwrap();
|
||||
vocab_file
|
||||
.write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Set up merges file.
|
||||
let mut merges_file = NamedTempFile::new().unwrap();
|
||||
merges_file
|
||||
.write_all("#version: 0.2\na b".as_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Make sure we can instantiate a BPE model from the files.
|
||||
assert!(BPE::from_files(
|
||||
vocab_file.path().to_str().unwrap(),
|
||||
merges_file.path().to_str().unwrap()
|
||||
)
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
// Ensure `MergeTokenOutOfVocabulary` error is returned when it should be.
|
||||
fn test_bpe_from_files_merge_token_oov() {
|
||||
// Set up vocab file.
|
||||
let mut vocab_file = NamedTempFile::new().unwrap();
|
||||
vocab_file
|
||||
.write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Set up merges file.
|
||||
let mut merges_file = NamedTempFile::new().unwrap();
|
||||
merges_file
|
||||
.write_all("#version: 0.2\na b\na d".as_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Ensure the result of BPE::from_files is a MergeTokenOutOfVocabulary error.
|
||||
match BPE::from_files(
|
||||
vocab_file.path().to_str().unwrap(),
|
||||
merges_file.path().to_str().unwrap(),
|
||||
) {
|
||||
Ok(_) => unreachable!(),
|
||||
Err(err) => match err.downcast_ref::<Error>() {
|
||||
Some(Error::MergeTokenOutOfVocabulary(token)) => {
|
||||
assert_eq!(*token, String::from("d"))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
// Ensure `BadMerges` error is returned when there is an invalid line in the
|
||||
// merges.txt file.
|
||||
fn test_bpe_from_files_bad_merges() {
|
||||
// Set up vocab file.
|
||||
let mut vocab_file = NamedTempFile::new().unwrap();
|
||||
vocab_file
|
||||
.write_all("{\"a\": 0, \"b\": 1, \"c\": 2, \"ab\": 3}".as_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Set up merges file with a bad line.
|
||||
let mut merges_file = NamedTempFile::new().unwrap();
|
||||
merges_file
|
||||
.write_all("#version: 0.2\na b\nc".as_bytes())
|
||||
.unwrap();
|
||||
|
||||
// Ensure the result of BPE::from_files is a BadMerges error.
|
||||
match BPE::from_files(
|
||||
vocab_file.path().to_str().unwrap(),
|
||||
merges_file.path().to_str().unwrap(),
|
||||
) {
|
||||
Ok(_) => unreachable!(),
|
||||
Err(err) => match err.downcast_ref::<Error>() {
|
||||
Some(Error::BadMerges(line)) => assert_eq!(*line, 3usize),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,3 @@
|
||||
//!
|
||||
//! # Trainer
|
||||
//!
|
||||
//! In charge of training a BPE model
|
||||
//!
|
||||
#![allow(clippy::map_entry)]
|
||||
|
||||
use super::{Pair, Word, BPE};
|
||||
@ -13,14 +8,15 @@ use std::{
|
||||
};
|
||||
|
||||
pub struct BpeTrainerConfig {
|
||||
vocab_size: usize,
|
||||
min_frequency: u32,
|
||||
vocab_size: usize,
|
||||
}
|
||||
|
||||
impl BpeTrainerConfig {
|
||||
pub fn new(min_frequency: u32, vocab_size: usize) -> Self {
|
||||
BpeTrainerConfig {
|
||||
vocab_size,
|
||||
min_frequency,
|
||||
vocab_size,
|
||||
}
|
||||
}
|
||||
|
||||
@ -32,12 +28,30 @@ impl BpeTrainerConfig {
|
||||
self.min_frequency = value;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BpeTrainerConfig {
|
||||
fn default() -> Self {
|
||||
BpeTrainerConfig::new(0, 30000)
|
||||
}
|
||||
}
|
||||
|
||||
/// In charge of training a BPE model from a mapping of words to word counts.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use std::collections::HashMap;
|
||||
/// use tokenizers::tokenizer::Trainer;
|
||||
/// use tokenizers::models::bpe::BpeTrainer;
|
||||
///
|
||||
/// let word_counts: HashMap<String, u32> = [
|
||||
/// (String::from("Hello"), 1),
|
||||
/// (String::from("World"), 1),
|
||||
/// ].iter().cloned().collect();
|
||||
/// let trainer = BpeTrainer::default();
|
||||
/// let model = trainer.train(word_counts);
|
||||
/// ```
|
||||
#[derive(Default)]
|
||||
pub struct BpeTrainer {
|
||||
// Training parameters
|
||||
config: BpeTrainerConfig,
|
||||
|
@ -1,11 +1,11 @@
|
||||
use super::Pair;
|
||||
|
||||
// TODO: Add tests
|
||||
#[derive(Clone, Default)]
|
||||
pub struct Word {
|
||||
chars: Vec<u32>,
|
||||
sizes: Vec<usize>,
|
||||
}
|
||||
|
||||
impl Word {
|
||||
pub fn new() -> Self {
|
||||
Word {
|
||||
@ -75,3 +75,52 @@ impl Word {
|
||||
offsets
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_merge() {
|
||||
// Let's say we have the word 'hello' and a word-to-id vocab that looks
|
||||
// like this: {'h': 0, 'e': 1, 'l': 2, 'o': 3}.
|
||||
let mut word = Word::new();
|
||||
word.add(0); // 'h'
|
||||
word.add(1); // 'e'
|
||||
word.add(2); // 'l'
|
||||
word.add(2); // 'l'
|
||||
word.add(3); // 'o'
|
||||
|
||||
// We're going to perform a merge on the pair ('l', 'l') ~= (2, 2). Let's
|
||||
// say that 'll' has the ID of 4 in the updated word-to-id vocab.
|
||||
let changes = word.merge(2, 2, 4);
|
||||
|
||||
// So the word should now look like this:
|
||||
assert_eq!(
|
||||
word.get_chars(),
|
||||
&[
|
||||
0u32, // 'h'
|
||||
1u32, // 'e'
|
||||
4u32, // 'll'
|
||||
3u32, // 'o'
|
||||
]
|
||||
);
|
||||
|
||||
// The return value `changes` will be used to update the pair counts during
|
||||
// training. This merge affects the counts for the pairs
|
||||
// ('e', 'l') ~= (1, 2),
|
||||
// ('e', 'll') ~= (1, 4),
|
||||
// ('l', 'o') ~= (2, 3), and
|
||||
// ('ll', 'o') ~= (4, 3).
|
||||
// So the changes should reflect that:
|
||||
assert_eq!(
|
||||
changes,
|
||||
&[
|
||||
((1u32, 2u32), -1i32), // count for ('e', 'l') should be decreased by 1.
|
||||
((1u32, 4u32), 1i32), // count for ('e', 'll') should be increased by 1.
|
||||
((2u32, 3u32), -1i32), // count for ('l', 'o') should be decreased by 1.
|
||||
((4u32, 3u32), 1i32), // count for ('ll', 'o') should be increased by 1.
|
||||
]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user