Mirror of https://github.com/mii443/tokenizers.git (synced 2025-09-03 15:59:25 +00:00)
Improvements on spm parity: (#401)
* Removing all pre_tokenizer logic from the Unigram algorithm.
* Improving the parity check *a lot*:
  - We can now detect many more errors.
  - Special cases have been added temporarily.
* Adding 2 new normalizers that mimic spm's default behavior.
* Adding an `encoding_optimized` version of the `encode` algorithm:
  - Removes the Lattice allocation.
  - Changes the trie `common_prefix_search` to return an iterator, to avoid allocating the full results.
* Trie<char> -> Trie<u8>: another speed improvement.
* [WIP] Attempt to create a Precompiled Normalizer from SPM to be 100% compliant with arbitrary models.
* Adding a new `Precompiled` normalizer that replaces `SpmNmtNfkc`:
  - It will be used for direct compatibility with `Spm`, replacing all of our custom rules by directly using the normalizer spec embedded within spm files, removing any need for hand-written rules on our side.
  - We need the `nom` dependency to parse the binary format of `spm`.
  - We need to add the `sentencepiece_model_pb2.py` file to be able to read the proto file.
  - We reimplemented their `Darts::DoubleArray` compact trie format.
* Fixing a bug with the Precompiled normalizer.
* Fixing some edge cases (now in tests) with this weird precompiled normalizer. It seems even a very handily crafted trie does not prevent shooting oneself in the foot. Sorry, future reader.
* Keeping the API stable for this PR (changes to the API should come later, #409):
  - Removed sentencepiece_model_pb2 from the binding and added instructions to make `from_spm` work.
* Adding a model check in `from_spm`.
* Addressing @n1t0's comments.
* Adding a check to make sure alignments stay correct. Also added a bit more documentation on how Precompiled works.
* Extracting `Precompiled` into its own `spm_precompiled` crate.
* Using ranges in `do_nmt`.
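The normalizer changes are the heart of the parity work: the default Unigram tokenizer now chains an `Nmt` normalizer with `NFKC` instead of using `NFKC` alone, and arbitrary spm models can reuse their embedded `precompiled_charsmap` through the new `Precompiled` normalizer. A minimal sketch of that wiring through the Python bindings (assuming a build that includes this commit; the charsmap bytes in the last line are a placeholder):

```python
from tokenizers import Tokenizer, normalizers
from tokenizers.models import Unigram

tokenizer = Tokenizer(Unigram())
# Default SentencePiece-like behaviour introduced by this PR: NMT cleanup, then NFKC.
tokenizer.normalizer = normalizers.Sequence([normalizers.Nmt(), normalizers.NFKC()])

# For exact parity with an existing spm model, its embedded normalizer spec can be
# reused instead (precompiled_charsmap is a bytes object read from the .model file):
# tokenizer.normalizer = normalizers.Precompiled(precompiled_charsmap)
```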
bindings/node/native/Cargo.lock (generated, 48 lines changed):

@@ -33,6 +33,12 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "arrayvec"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cff77d8686867eceff3105329d4698d96c2391c176d5d03adc90c7389162b5b8"
+
 [[package]]
 name = "atty"
 version = "0.2.14"
@@ -339,6 +345,19 @@ version = "1.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"

+[[package]]
+name = "lexical-core"
+version = "0.7.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "db65c6da02e61f55dae90a0ae427b2a5f6b3e8db09f58d10efab23af92592616"
+dependencies = [
+ "arrayvec",
+ "bitflags",
+ "cfg-if",
+ "ryu",
+ "static_assertions",
+]
+
 [[package]]
 name = "libc"
 version = "0.2.74"
@@ -452,6 +471,17 @@ dependencies = [
  "tokenizers",
 ]

+[[package]]
+name = "nom"
+version = "5.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ffb4262d26ed83a1c0a33a38fe2bb15797329c85770da05e6b828ddb782627af"
+dependencies = [
+ "lexical-core",
+ "memchr",
+ "version_check",
+]
+
 [[package]]
 name = "num"
 version = "0.2.1"
@@ -754,6 +784,23 @@ version = "1.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "3757cb9d89161a2f24e1cf78efa0c1fcff485d18e3f55e0aa3480824ddaa0f3f"

+[[package]]
+name = "spm_precompiled"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f78be885c9efc899a7c0348f67c98b488cbeaf2cb608a48fb87ef1484ecab5c5"
+dependencies = [
+ "nom",
+ "serde",
+ "unicode-segmentation",
+]
+
+[[package]]
+name = "static_assertions"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
+
 [[package]]
 name = "strsim"
 version = "0.8.0"
@@ -833,6 +880,7 @@ dependencies = [
  "regex-syntax",
  "serde",
  "serde_json",
+ "spm_precompiled",
  "unicode-normalization-alignments",
  "unicode-segmentation",
  "unicode_categories",
bindings/python/Cargo.lock (generated, 54 lines changed):

@@ -16,6 +16,11 @@ dependencies = [
  "winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)",
 ]

+[[package]]
+name = "arrayvec"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
 [[package]]
 name = "atty"
 version = "0.2.14"
@@ -350,6 +355,18 @@ name = "lazy_static"
 version = "1.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"

+[[package]]
+name = "lexical-core"
+version = "0.7.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "arrayvec 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
+ "ryu 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)",
+ "static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
 [[package]]
 name = "libc"
 version = "0.2.77"
@@ -409,6 +426,16 @@ dependencies = [
  "rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
 ]

+[[package]]
+name = "nom"
+version = "5.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "lexical-core 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)",
+ "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
+ "version_check 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
 [[package]]
 name = "num-complex"
 version = "0.2.4"
@@ -741,6 +768,21 @@ name = "smallvec"
 version = "1.4.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"

+[[package]]
+name = "spm_precompiled"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "nom 5.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "serde 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)",
+ "unicode-segmentation 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
+name = "static_assertions"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
 [[package]]
 name = "strsim"
 version = "0.8.0"
@@ -834,6 +876,7 @@ dependencies = [
  "regex-syntax 0.6.18 (registry+https://github.com/rust-lang/crates.io-index)",
  "serde 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)",
  "serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)",
+ "spm_precompiled 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
  "unicode-normalization-alignments 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)",
  "unicode-segmentation 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
  "unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -893,6 +936,11 @@ name = "vec_map"
 version = "0.8.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"

+[[package]]
+name = "version_check"
+version = "0.9.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
 [[package]]
 name = "wasi"
 version = "0.9.0+wasi-snapshot-preview1"
@@ -928,6 +976,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 [metadata]
 "checksum aho-corasick 0.7.13 (registry+https://github.com/rust-lang/crates.io-index)" = "043164d8ba5c4c3035fec9bbee8647c0261d788f3474306f93bb65901cae0e86"
 "checksum ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b"
+"checksum arrayvec 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cff77d8686867eceff3105329d4698d96c2391c176d5d03adc90c7389162b5b8"
 "checksum atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
 "checksum autocfg 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
 "checksum bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
@@ -966,6 +1015,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum itertools 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b"
 "checksum itoa 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "dc6f3ad7b9d11a0c00842ff8de1b60ee58661048eb8049ed33c73594f359d7e6"
 "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
+"checksum lexical-core 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)" = "db65c6da02e61f55dae90a0ae427b2a5f6b3e8db09f58d10efab23af92592616"
 "checksum libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)" = "f2f96b10ec2560088a8e76961b00d47107b3a625fecb76dedb29ee7ccbf98235"
 "checksum lock_api 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "28247cc5a5be2f05fbcd76dd0cf2c7d3b5400cb978a28042abcd4fa0b3f8261c"
 "checksum log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)" = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b"
@@ -974,6 +1024,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400"
 "checksum memoffset 0.5.5 (registry+https://github.com/rust-lang/crates.io-index)" = "c198b026e1bbf08a937e94c6c60f9ec4a2267f5b0d2eec9c1b21b061ce2be55f"
 "checksum ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ac06db03ec2f46ee0ecdca1a1c34a99c0d188a0d83439b84bf0cb4b386e4ab09"
+"checksum nom 5.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "ffb4262d26ed83a1c0a33a38fe2bb15797329c85770da05e6b828ddb782627af"
 "checksum num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
 "checksum num-integer 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "8d59457e662d541ba17869cf51cf177c0b5f0cbf476c66bdc90bf1edac4f875b"
 "checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
@@ -1013,6 +1064,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum serde_derive 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)" = "f630a6370fd8e457873b4bd2ffdae75408bc291ba72be773772a4c2a065d9ae8"
 "checksum serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)" = "164eacbdb13512ec2745fb09d51fd5b22b0d65ed294a1dcf7285a360c80a675c"
 "checksum smallvec 1.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "fbee7696b84bbf3d89a1c2eccff0850e3047ed46bfcd2e92c29a2d074d57e252"
+"checksum spm_precompiled 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f78be885c9efc899a7c0348f67c98b488cbeaf2cb608a48fb87ef1484ecab5c5"
+"checksum static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
 "checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
 "checksum strsim 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
 "checksum syn 1.0.40 (registry+https://github.com/rust-lang/crates.io-index)" = "963f7d3cc59b59b9325165add223142bbf1df27655d07789f109896d353d8350"
@@ -1029,6 +1082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 "checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
 "checksum unindent 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "af41d708427f8fd0e915dcebb2cae0f0e6acb2a939b2d399c265c39a38a18942"
 "checksum vec_map 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
+"checksum version_check 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)" = "b5a972e5669d67ba988ce3dc826706fb0a8b01471c088cb0b6110b805cc36aed"
 "checksum wasi 0.9.0+wasi-snapshot-preview1 (registry+https://github.com/rust-lang/crates.io-index)" = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
 "checksum winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)" = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
 "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
@@ -1,6 +1,14 @@
-from tokenizers import Tokenizer, AddedToken, pre_tokenizers, decoders, trainers
+from tokenizers import (
+    Tokenizer,
+    AddedToken,
+    pre_tokenizers,
+    decoders,
+    trainers,
+    normalizers,
+)
+import os
 from tokenizers.models import Unigram
-from tokenizers.normalizers import NFKC
+import json
 from .base_tokenizer import BaseTokenizer

 from typing import Optional, List, Union
@@ -16,11 +24,12 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         self, vocab: Optional[str] = None, replacement: str = "▁", add_prefix_space: bool = True,
     ):
         if vocab is not None:
+            # Let Unigram(..) fail if only one of them is None
             tokenizer = Tokenizer(Unigram(vocab))
         else:
             tokenizer = Tokenizer(Unigram())

-        tokenizer.normalizer = NFKC()
+        tokenizer.normalizer = normalizers.Sequence([normalizers.Nmt(), normalizers.NFKC(),])
         tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
             [
                 pre_tokenizers.WhitespaceSplit(),
@@ -57,3 +66,63 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         if isinstance(files, str):
             files = [files]
         self._tokenizer.train(trainer, files)
+
+    @staticmethod
+    def from_spm(filename: str):
+        try:
+            import sys
+
+            sys.path.append(".")
+
+            import sentencepiece_model_pb2 as model
+        except Exception:
+            raise Exception(
+                "You don't seem to have the required protobuf file, in order to use this function you need to run `pip install protobuf` and `wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_model_pb2.py` for us to be able to read the intrinsics of your spm_file. `pip install sentencepiece` is not required."
+            )
+
+        m = model.ModelProto()
+        m.ParseFromString(open(filename, "rb").read())
+
+        precompiled_charsmap = m.normalizer_spec.precompiled_charsmap
+        vocab = [(piece.piece, piece.score) for piece in m.pieces]
+        unk_id = m.trainer_spec.unk_id
+        model_type = m.trainer_spec.model_type
+        if model_type != 1:
+            raise Exception(
+                "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
+            )
+
+        data = {"unk_id": unk_id, "vocab": vocab}
+
+        replacement = "▁"
+        add_prefix_space = True
+
+        out_vocab_filename = f"{filename}.json"
+        try:
+            with open(out_vocab_filename, "w") as f:
+                json.dump(data, f, indent=4)
+
+            tokenizer = Tokenizer(Unigram(out_vocab_filename))
+        finally:
+            os.remove(out_vocab_filename)
+
+        tokenizer.normalizer = normalizers.Precompiled(precompiled_charsmap)
+        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+            [
+                pre_tokenizers.WhitespaceSplit(),
+                pre_tokenizers.Metaspace(
+                    replacement=replacement, add_prefix_space=add_prefix_space
+                ),
+            ]
+        )
+        tokenizer.decoder = decoders.Metaspace(
+            replacement=replacement, add_prefix_space=add_prefix_space
+        )
+
+        parameters = {
+            "model": "SentencePieceUnigram",
+        }
+
+        obj = BaseTokenizer.__new__(SentencePieceUnigramTokenizer, tokenizer, parameters)
+        BaseTokenizer.__init__(obj, tokenizer, parameters)
+        return obj
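For context, a hedged usage sketch of the `from_spm` helper added above. It assumes `protobuf` is installed and `sentencepiece_model_pb2.py` has been downloaded next to the caller, as the exception message describes; `"my_model.model"` is a placeholder path to a Unigram-type spm file.

```python
from tokenizers import SentencePieceUnigramTokenizer

# Loads the vocab, scores and precompiled charsmap straight from the spm file,
# so the resulting tokenizer is meant to match sentencepiece output exactly.
tok = SentencePieceUnigramTokenizer.from_spm("my_model.model")
print(tok.encode("Hello world").tokens)
```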
@@ -9,6 +9,8 @@ NFKC = normalizers.NFKC
 Sequence = normalizers.Sequence
 Lowercase = normalizers.Lowercase
 Strip = normalizers.Strip
+Nmt = normalizers.Nmt
+Precompiled = normalizers.Precompiled


 NORMALIZERS = {"nfc": NFC, "nfd": NFD, "nfkc": NFKC, "nfkd": NFKD}
@@ -99,6 +99,18 @@ class Strip(Normalizer):
     def __init__(self, left: bool = True, right: bool = True) -> Normalizer:
         pass

+class Nmt(Normalizer):
+    """ Nmt normalizer """
+
+    def __init__(self) -> Normalizer:
+        pass
+
+class Precompiled(Normalizer):
+    """ SpmNmtNfkc normalizer """
+
+    def __init__(self, precompiled_charsmap: bytes) -> Normalizer:
+        pass
+
 def unicode_normalizer_from_str(normalizer: str) -> Normalizer:
     """
     Instanciate unicode normalizer from the normalizer name
@@ -1,7 +1,17 @@
 import tokenizers
 from argparse import ArgumentParser
 import sentencepiece as spm
+from collections import Counter
 import json
+import os
+import datetime
+
+try:
+    from termcolor import colored
+
+    has_color = True
+except Exception:
+    has_color = False


 def main():
@@ -9,38 +19,62 @@ def main():
     parser.add_argument(
         "--input-file", "-i", type=str, required=True, help="Which files do you want to train from",
     )
+    parser.add_argument(
+        "--model-file",
+        "-m",
+        type=str,
+        required=False,
+        default=None,
+        help="Use a pretrained token file",
+    )
     parser.add_argument(
         "--model-prefix", type=str, default="spm_parity", help="Model prefix for spm_train",
     )
     parser.add_argument(
         "--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
     )
+    parser.add_argument(
+        "--verbose", action="store_true", help="Verbosity",
+    )
     parser.add_argument(
         "--train",
         action="store_true",
         help="Instead of checking the encoder part, we check the trainer part",
     )
+    parser.add_argument(
+        "--from-spm",
+        action="store_true",
+        help="Directly load the spm file with it's own normalizer",
+    )

     args = parser.parse_args()

-    spm.SentencePieceTrainer.Train(
-        f"--input={args.input_file} --model_prefix={args.model_prefix}"
-        f" --character_coverage=1.0"
-        f" --max_sentence_length=40000"
-        f" --num_threads=1"
-        f" --vocab_size={args.vocab_size}"
-    )
+    trained = False
+    if args.model_file is None:
+        spm.SentencePieceTrainer.Train(
+            f"--input={args.input_file} --model_prefix={args.model_prefix}"
+            f" --character_coverage=1.0"
+            f" --max_sentence_length=40000"
+            f" --num_threads=1"
+            f" --vocab_size={args.vocab_size}"
+        )
+        trained = True
+        args.model_file = f"{args.model_prefix}.model"

-    if args.train:
-        check_train(args)
-    else:
-        check_encode(args)
+    try:
+        if args.train:
+            check_train(args)
+        else:
+            check_encode(args)
+    finally:
+        if trained:
+            os.remove(f"{args.model_prefix}.model")
+            os.remove(f"{args.model_prefix}.vocab")


 def check_train(args):
     sp = spm.SentencePieceProcessor()
-    model_filename = f"{args.model_prefix}.model"
-    sp.Load(model_filename)
+    sp.Load(args.model_file)

     tokenizer = tokenizers.SentencePieceUnigramTokenizer()
     tokenizer.train(args.input_file, show_progress=False)
@@ -77,38 +111,144 @@ def check_train(args):
     assert (
         tokenizer_tokens < spm_tokens
     ), "Our trainer should be at least more efficient than the SPM one"
+    print("Ok our trainer is at least more efficient than the SPM one")
+
+
+def check_diff(spm_diff, tok_diff, sp, tok):
+    if spm_diff == list(reversed(tok_diff)):
+        # AAA -> AA+A vs A+AA case.
+        return True
+    elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(tok_diff):
+        # Second order OK
+        # Barrich -> Barr + ich vs Bar + rich
+        return True
+    spm_reencoded = sp.encode(sp.decode(spm_diff))
+    tok_reencoded = tok.encode(tok.decode(spm_diff)).ids
+    if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
+        # Type 3 error.
+        # Snehagatha ->
+        #      Sne, h, aga, th, a
+        #      Sne, ha, gat, ha
+        # Encoding the wrong with sp does not even recover what spm gave us
+        # It fits tokenizer however...
+        return True
+    return False
+
+
+def check_details(line, spm_ids, tok_ids, tok, sp):
+    # Encoding can be the same with same result AAA -> A + AA vs AA + A
+    # We can check that we use at least exactly the same number of tokens.
+    for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
+        if spm_id != tok_id:
+            break
+    first = i
+    for i, (spm_id, tok_id) in enumerate(zip(reversed(spm_ids), reversed(tok_ids))):
+        if spm_id != tok_id:
+            break
+    last = len(spm_ids) - i
+
+    spm_diff = spm_ids[first:last]
+    tok_diff = tok_ids[first:last]
+
+    if check_diff(spm_diff, tok_diff, sp, tok):
+        return True
+
+    if last - first > 5:
+        # We might have twice a single problem, attempt to subdivide the disjointed tokens into smaller problems
+        spms = Counter(spm_ids[first:last])
+        toks = Counter(tok_ids[first:last])
+
+        removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
+        min_width = 3
+        for i in range(last - first - min_width):
+            if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
+                possible_matches = [
+                    k
+                    for k in range(last - first - min_width)
+                    if tok_ids[first + k : first + k + min_width]
+                    == spm_ids[first + i : first + i + min_width]
+                ]
+                for j in possible_matches:
+                    if check_diff(
+                        spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
+                    ) and check_diff(spm_ids[first + i : last], tok_ids[first + j : last], sp, tok):
+                        return True
+    ok_start = tok.decode(spm_ids[:first])
+    ok_end = tok.decode(spm_ids[last:])
+    wrong = tok.decode(spm_ids[first:last])
+    print()
+    if has_color:
+        print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
+    else:
+        print(wrong)
+
+    print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}")
+    print(f"Tok: {[tok.decode([tok_ids[i]]) for i in range(first, last)]}")
+    return False


 def check_encode(args):
     sp = spm.SentencePieceProcessor()
-    model_filename = f"{args.model_prefix}.model"
-    sp.Load(model_filename)
-
-    vocab_filename = f"{args.model_prefix}.json"
-
-    vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
-
-    data = {"unk_id": sp.unk_id(), "vocab": vocab}
-
-    with open(vocab_filename, "w") as f:
-        json.dump(data, f, indent=4)
-
-    tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
-    with open(args.input_file, "r") as f:
+    sp.Load(args.model_file)
+
+    if args.from_spm:
+        tok = tokenizers.SentencePieceUnigramTokenizer.from_spm(args.model_file)
+    else:
+        vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
+        vocab_filename = f"{args.model_file}.json"
+        unk_id = sp.unk_id()
+
+        data = {"unk_id": unk_id, "vocab": vocab}
+        try:
+            with open(vocab_filename, "w") as f:
+                json.dump(data, f, indent=4)
+
+            tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
+        finally:
+            os.remove(vocab_filename)
+
+    perfect = 0
+    imperfect = 0
+    wrong = 0
+    now = datetime.datetime.now
+    spm_total_time = datetime.timedelta(seconds=0)
+    tok_total_time = datetime.timedelta(seconds=0)
+    with open(args.input_file, "r", encoding="utf-8-sig") as f:
         for i, line in enumerate(f):
             line = line.strip()
+
+            start = now()
             ids = sp.EncodeAsIds(line)
+            spm_time = now()
+
             encoded = tok.encode(line)
+            tok_time = now()
+
+            spm_total_time += spm_time - start
+            tok_total_time += tok_time - spm_time
+
+            if args.verbose:
+                if i % 10000 == 0:
+                    print(
+                        f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})"
+                    )
+                    print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")

             if ids != encoded.ids:
-                # Encoding can be the same with same result AAA -> A + AA vs AA + A
-                # We can check that we use at least exactly the same number of tokens.
-                assert len(ids) == len(encoded.ids)
-                continue
+                if check_details(line, ids, encoded.ids, tok, sp):
+                    imperfect += 1
+                    continue
+                else:
+                    wrong += 1
+            else:
+                perfect += 1
+
             assert ids == encoded.ids, f"line {i}: {line} : {ids} != {encoded.ids}"

+    print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
+    total = perfect + imperfect + wrong
+    print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}")
+

 if __name__ == "__main__":
     main()
@@ -15,7 +15,7 @@ setup(
     author_email="anthony@huggingface.co",
     url="https://github.com/huggingface/tokenizers",
     license="Apache License 2.0",
-    rust_extensions=[RustExtension("tokenizers.tokenizers", binding=Binding.PyO3)],
+    rust_extensions=[RustExtension("tokenizers.tokenizers", binding=Binding.PyO3, debug=False)],
    extras_require=extras,
    classifiers=[
        "Development Status :: 5 - Production/Stable",
@@ -108,6 +108,8 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<normalizers::PySequence>()?;
     m.add_class::<normalizers::PyLowercase>()?;
     m.add_class::<normalizers::PyStrip>()?;
+    m.add_class::<normalizers::PyNmt>()?;
+    m.add_class::<normalizers::PyPrecompiled>()?;
     Ok(())
 }

@@ -263,21 +263,19 @@ pub struct PyUnigram {}
 #[pymethods]
 impl PyUnigram {
     #[new]
-    fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> {
-        if let Some(vocab) = vocab {
-            let path = Path::new(vocab);
-            match Unigram::load(path) {
+    fn new(vocab: Option<String>) -> PyResult<(Self, PyModel)> {
+        match vocab {
+            Some(vocab) => match Unigram::load(&std::path::Path::new(&vocab)) {
                 Err(e) => {
                     println!("Errors: {:?}", e);
                     Err(exceptions::Exception::py_err("Error while loading Unigram"))
                 }
                 Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
-            }
-        } else {
-            Ok((
+            },
+            None => Ok((
                 PyUnigram {},
                 PyModel::new(Arc::new(Unigram::default().into())),
-            ))
+            )),
         }
     }
 }
@@ -7,7 +7,9 @@ use pyo3::types::*;
 use crate::error::ToPyResult;
 use serde::ser::SerializeStruct;
 use serde::{Deserialize, Serialize, Serializer};
-use tk::normalizers::{BertNormalizer, Lowercase, NormalizerWrapper, Strip, NFC, NFD, NFKC, NFKD};
+use tk::normalizers::{
+    BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Precompiled, Strip, NFC, NFD, NFKC, NFKD,
+};
 use tk::{NormalizedString, Normalizer};
 use tokenizers as tk;

@@ -45,6 +47,10 @@ impl PyNormalizer {
                 NormalizerWrapper::Lowercase(_) => {
                     Py::new(py, (PyLowercase {}, base)).map(Into::into)
                 }
+                NormalizerWrapper::Precompiled(_) => {
+                    Py::new(py, (PyPrecompiled {}, base)).map(Into::into)
+                }
+                NormalizerWrapper::Nmt(_) => Py::new(py, (PyNmt {}, base)).map(Into::into),
             },
         }
     }
@@ -273,6 +279,37 @@ impl Normalizer for PyNormalizerWrapper {
     }
 }

+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Nmt)]
+pub struct PyNmt {}
+#[pymethods]
+impl PyNmt {
+    #[new]
+    fn new() -> PyResult<(Self, PyNormalizer)> {
+        Ok((PyNmt {}, Nmt.into()))
+    }
+}
+
+#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Precompiled)]
+pub struct PyPrecompiled {}
+#[pymethods]
+impl PyPrecompiled {
+    #[new]
+    fn new(py_precompiled_charsmap: &PyBytes) -> PyResult<(Self, PyNormalizer)> {
+        let precompiled_charsmap: &[u8] = FromPyObject::extract(py_precompiled_charsmap)?;
+        Ok((
+            PyPrecompiled {},
+            Precompiled::from(precompiled_charsmap)
+                .map_err(|e| {
+                    exceptions::Exception::py_err(format!(
+                        "Error while attempting to build Precompiled normalizer: {}",
+                        e.to_string()
+                    ))
+                })?
+                .into(),
+        ))
+    }
+}
+
 #[cfg(test)]
 mod test {
     use pyo3::{AsPyRef, Python};
@@ -52,6 +52,7 @@ itertools = "0.9"
 log = "0.4"
 esaxx-rs = "0.1"
 derive_builder = "0.9"
+spm_precompiled = "0.1"

 [dev-dependencies]
 criterion = "0.3"
@@ -4,7 +4,6 @@ mod model;
 mod serialization;
 mod trainer;
 mod trie;
-mod unicode;

 pub use lattice::*;
 pub use model::*;
@@ -1,7 +1,7 @@
 use crate::models::unigram::lattice::Lattice;
 use crate::models::unigram::trie::{Trie, TrieBuilder};
 use crate::tokenizer::{Model, Result, Token};
-use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY};
+use crate::utils::cache::Cache;

 use std::collections::HashMap;
 use std::convert::TryInto;
@@ -17,7 +17,7 @@ pub struct Unigram {
     token_to_ids: TokenMap,
     pub(crate) vocab: Vocab,
     cache: Cache<String, Vec<String>>,
-    trie: Trie<char>,
+    trie: Trie<u8>,
     pub min_score: f64,
     pub(super) unk_id: usize,
     pub(super) bos_id: usize,
@@ -38,13 +38,13 @@ impl Clone for Unigram {
         let fresh_cache = self.cache.fresh();
         Self {
             vocab: self.vocab.clone(),
-            token_to_ids: self.token_to_ids.clone(),
             cache: fresh_cache,
-            unk_id: self.unk_id,
+            token_to_ids: self.token_to_ids.clone(),
+            trie: self.trie.clone(),
             min_score: self.min_score,
+            unk_id: self.unk_id,
             bos_id: self.bos_id,
             eos_id: self.eos_id,
-            trie: self.trie.clone(),
             fuse_unk: self.fuse_unk,
         }
     }
@@ -54,6 +54,7 @@ impl std::fmt::Debug for Unigram {
     fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
         fmt.debug_struct("BPE")
             .field("vocab", &self.vocab.len())
+            .field("unk_id", &self.unk_id)
             .finish()
     }
 }
@@ -113,19 +114,17 @@ impl Unigram {
         let mut min_score = f64::INFINITY;
         for (id, (token, score)) in vocab.iter().enumerate() {
             token_to_ids.insert(token.to_string(), id as u32);
-            let chars: Vec<char> = token.chars().collect();
-            builder.push(&chars);
+            let bytes: Vec<u8> = token.bytes().collect();
+            builder.push(&bytes);
             if score < &min_score {
                 min_score = *score;
             }
         }
         let trie = builder.build();
         let fuse_unk = true;
-        let cache = Cache::new(DEFAULT_CACHE_CAPACITY);

         Ok(Unigram {
             vocab,
-            cache,
             token_to_ids,
             trie,
             min_score,
@@ -133,6 +132,7 @@ impl Unigram {
             eos_id,
             unk_id,
             fuse_unk,
+            cache: Cache::default(),
         })
     }

@@ -151,26 +151,29 @@ impl Unigram {

         let len = lattice.len();

-        for begin_pos in 0..len {
-            let trie_results: Vec<String> = self
-                .trie
-                .common_prefix_search(lattice.sentence.chars().skip(begin_pos))
-                .iter()
-                .map(|chars| chars.iter().collect())
-                .collect();
+        let mut begin_pos = 0;
+        while begin_pos < len {
+            let mblen = lattice.sentence[begin_pos..]
+                .chars()
+                .next()
+                .unwrap()
+                .len_utf8();

             let mut has_single_node = false;

-            for tok in trie_results {
-                let n = tok.chars().count();
+            for bytes in self
+                .trie
+                .common_prefix_search(lattice.sentence.bytes().skip(begin_pos))
+            {
+                let n = bytes.len();
+                let tok = String::from_utf8(bytes).unwrap();
                 let id = *self.token_to_ids.get(&tok).unwrap();

                 let item = &self.vocab[id as usize];
                 assert_eq!(item.0, tok);
                 let score: f64 = item.1;
                 lattice.insert(begin_pos, n, score, id.try_into().unwrap());
-                if !has_single_node && n == 1 {
+                if !has_single_node && n == mblen {
                     has_single_node = true;
                 }
             }
@@ -178,6 +181,7 @@ impl Unigram {
             if !has_single_node {
                 lattice.insert(begin_pos, 1, unk_score, self.unk_id);
             }
+            begin_pos += mblen
         }
     }

@@ -202,17 +206,129 @@ impl Unigram {
     /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
     /// ```
     pub fn encode(&self, sentence: &str) -> Vec<String> {
+        if sentence.is_empty() {
+            return vec![];
+        }
         if let Some(result) = self.cache.get(sentence) {
-            result
+            result.to_vec()
         } else {
-            let result = self.encode_no_cache(sentence);
+            let result = self.encode_optimized(sentence);
             self.cache.set(sentence.to_owned(), result.clone());
             result
         }
     }
-    fn encode_no_cache(&self, sentence: &str) -> Vec<String> {
-        // TODO optimized version
+
+    fn encode_optimized(&self, sentence: &str) -> Vec<String> {
         // https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600
+        #[derive(Debug, Clone)]
+        struct BestPathNode {
+            /// The vocab id. (maybe UNK)
+            id: usize,
+            /// The total score of the best path ending at this node.
+            best_path_score: f64,
+            /// The starting position (in utf-8) of this node. The entire best
+            /// path can be constructed by backtracking along this link.
+            starts_at: Option<usize>,
+        };
+        impl Default for BestPathNode {
+            fn default() -> Self {
+                Self {
+                    id: 0,
+                    best_path_score: 0.0,
+                    starts_at: None,
+                }
+            }
+        }
+        let size = sentence.len();
+        let unk_score = self.min_score - K_UNK_PENALTY;
+
+        let mut best_path_ends_at = vec![BestPathNode::default(); size + 1];
+        let mut starts_at = 0;
+        while starts_at < size {
+            let best_path_score_till_here = best_path_ends_at[starts_at].best_path_score;
+            let mut has_single_node = false;
+            let mblen = sentence[starts_at..].chars().next().unwrap().len_utf8();
+            for tok_bytes in self
+                .trie
+                .common_prefix_search(sentence.bytes().skip(starts_at))
+            {
+                let key_pos = starts_at + tok_bytes.len();
+                let token: String = String::from_utf8(tok_bytes).unwrap();
+                let mut target_node = &mut best_path_ends_at[key_pos];
+                let length = key_pos - starts_at;
+                let id = self.token_to_ids.get(&token).unwrap();
+                let score = self.vocab.get(*id as usize).unwrap().1;
+                let candidate_best_path_score = score + best_path_score_till_here;
+                if target_node.starts_at.is_none()
+                    || candidate_best_path_score > target_node.best_path_score
+                {
+                    target_node.best_path_score = candidate_best_path_score;
+                    target_node.starts_at = Some(starts_at);
+                    target_node.id = *id as usize;
+                }
+                if !has_single_node && length == mblen {
+                    has_single_node = true;
+                }
+            }
+            if !has_single_node {
+                let mut target_node = &mut best_path_ends_at[starts_at + mblen];
+                let candidate_best_path_score = unk_score + best_path_score_till_here;
+                if target_node.starts_at.is_none()
+                    || candidate_best_path_score > target_node.best_path_score
+                {
+                    target_node.best_path_score = candidate_best_path_score;
+                    target_node.starts_at = Some(starts_at);
+                    target_node.id = self.unk_id;
+                }
+            }
+            starts_at += mblen
+        }
+        let mut ends_at = size;
+        let mut results: Vec<String> = vec![];
+        let mut token = vec![];
+        while ends_at > 0 {
+            let node = &best_path_ends_at[ends_at];
+            let starts_at = node.starts_at.unwrap();
+            if self.fuse_unk && node.id == self.unk_id {
+                token.push(
+                    String::from_utf8(
+                        sentence
+                            .bytes()
+                            .skip(starts_at)
+                            .take(ends_at - starts_at)
+                            .collect(),
+                    )
+                    .unwrap(),
+                );
+            } else {
+                if !token.is_empty() {
+                    token.reverse();
+                    results.push(token.concat());
+                    token = vec![];
+                }
+                results.push(
+                    String::from_utf8(
+                        sentence
+                            .bytes()
+                            .skip(starts_at)
+                            .take(ends_at - starts_at)
+                            .collect(),
+                    )
+                    .unwrap(),
+                );
+            }
+            ends_at = starts_at;
+        }
+        if !token.is_empty() {
+            token.reverse();
+            results.push(token.concat());
+        }
+        results.reverse();
+        results
+    }
+
+    #[allow(dead_code)]
+    fn encode_unoptimized(&self, sentence: &str) -> Vec<String> {
         let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id);
         self.populate_nodes(&mut lattice);
         if self.fuse_unk {
@@ -296,11 +412,14 @@ impl Model for Unigram {
         Ok(tokens
             .iter()
             .map(|string| {
-                let id = self.token_to_ids.get(string).unwrap_or(&0);
+                let id: u32 = match self.token_to_ids.get(string) {
+                    Some(id) => *id,
+                    None => self.unk_id as u32,
+                };
                 let len = string.len();
                 let offsets = (offset, offset + len);
                 offset += len;
-                Token::new(*id, string.to_string(), offsets)
+                Token::new(id, string.to_string(), offsets)
             })
             .collect())
     }
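The `encode_optimized` method above is a Viterbi-style best-path search: one `BestPathNode` per byte boundary, a forward pass that scores every trie match starting at each position (falling back to the unknown token when nothing of single-character length matches), and a backward pass that reconstructs the pieces. A rough Python rendering of the same idea, illustrative only: it uses a plain dict instead of the trie, character positions instead of UTF-8 byte offsets, and skips the `fuse_unk` post-processing.

```python
# Illustrative sketch of the best-path (Viterbi) segmentation used by
# encode_optimized above. `vocab` maps piece -> log score; `unk_score` stands
# in for min_score - K_UNK_PENALTY.
def viterbi_encode(sentence, vocab, unk_score, max_piece_len=16):
    n = len(sentence)
    # best[i] = (score of best path ending at i, start of last piece, last piece)
    best = [None] * (n + 1)
    best[0] = (0.0, 0, "")
    for start in range(n):
        if best[start] is None:
            continue
        base = best[start][0]
        matched_single = False
        for end in range(start + 1, min(n, start + max_piece_len) + 1):
            piece = sentence[start:end]
            if piece in vocab:
                cand = base + vocab[piece]
                if best[end] is None or cand > best[end][0]:
                    best[end] = (cand, start, piece)
                if end == start + 1:
                    matched_single = True
        if not matched_single:  # fall back to an <unk> covering one position
            cand = base + unk_score
            if best[start + 1] is None or cand > best[start + 1][0]:
                best[start + 1] = (cand, start, sentence[start:start + 1])
    # Backtrack from the end of the sentence to recover the pieces.
    pieces, pos = [], n
    while pos > 0:
        _, start, piece = best[pos]
        pieces.append(piece)
        pos = start
    return list(reversed(pieces))
```

With a toy vocab such as `{"a": -1.0, "b": -2.0, "ab": -1.5}` and `unk_score=-10.0`, `viterbi_encode("ab", vocab, -10.0)` returns `["ab"]`, since the single piece scores better than `a` plus `b`, which is the same decision the Rust path makes before its `fuse_unk` post-processing.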
@@ -75,4 +75,15 @@ mod test {

         assert_eq!(model, reconstructed);
     }
+
+    #[test]
+    fn test_serialization_unk_id_not_zero() {
+        let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
+        let model = Unigram::from(vocab, 1).unwrap();
+
+        let data = serde_json::to_string(&model).unwrap();
+        let reconstructed = serde_json::from_str(&data).unwrap();
+
+        assert_eq!(model, reconstructed);
+    }
 }
@@ -1,8 +1,4 @@
-use crate::models::unigram::{
-    lattice::Lattice,
-    model::Unigram,
-    unicode::{get_script, Script},
-};
+use crate::models::unigram::{lattice::Lattice, model::Unigram};
 use crate::tokenizer::{AddedToken, Result, Trainer};
 use indicatif::{ProgressBar, ProgressStyle};
 use log::debug;
@@ -53,33 +49,9 @@ pub struct UnigramTrainer {
     #[builder(default = "vec![]")]
     special_tokens: Vec<AddedToken>,

-    #[builder(default = "' '")]
-    space_char: char,
-
     #[builder(default = "String::from(\"<unk>\")")]
     unk_token: String,

-    #[builder(default = "false")]
-    treat_whitespace_as_suffix: bool,
-
-    #[builder(default = "true")]
-    split_by_unicode_script: bool,
-
-    #[builder(default = "true")]
-    split_by_number: bool,
-
-    #[builder(default = "false")]
-    split_by_digits: bool,
-
-    /// In spm this parameter defaults to true,
-    /// we set it to false here because it's supposed
-    /// to be a job taken care by the pretokenizer. We still
-    /// have it here to enable easier testing as it does make a difference
-    /// in `is_valid_sentencepiece` method were we discard seed_pieces if they
-    /// contain a whitespace. This job can/could be taken elsewhere.
-    #[builder(default = "false")]
-    split_by_whitespace: bool,
-
     #[builder(default = "16")]
     max_piece_length: usize,
     #[builder(default = "1_000_000")]
@@ -106,72 +78,19 @@ impl UnigramTrainer {
     }

     fn is_valid_sentencepiece(&self, char_string: &[char]) -> bool {
-        // TODO check more formally but should be ok.
-        // Checks string length, space not in the substring, numbers, hiragana and more
+        // Checks string length
+        // Space not in the substring, numbers, hiragana and more should be taken
+        // care of within pre_tokenizers.
         // https://github.com/google/sentencepiece/blob/26be9516cd81d5315ee31c48d2438018e0eab879/src/trainer_interface.cc#L203
         let n = char_string.len();
         if char_string.is_empty() || n > self.max_piece_length {
             return false;
         }
-        let mut last_script = Script::Any;
-        for (i, c) in char_string.iter().enumerate() {
-            if *c == '\0' {
-                return false;
-            }
-            if *c == self.space_char {
-                if self.treat_whitespace_as_suffix {
-                    let is_not_suffix_no_whitespace = self.split_by_whitespace && i != n - 1;
-                    let is_prefix = !self.split_by_whitespace && i == 0;
-                    if is_not_suffix_no_whitespace || is_prefix {
-                        return false;
-                    }
-                } else {
-                    let is_not_prefix_no_whitespace = self.split_by_whitespace && i != 0;
-                    let is_suffix = !self.split_by_whitespace && i == n - 1;
-                    if is_not_prefix_no_whitespace || is_suffix {
-                        return false;
-                    }
-                }
-            }
-
-            // This function checks that unicode "scripts" are consistent, so we cannot have romaji and
-            // hiragana for instance. Seems pretty specific. Also Hiragana and katakana are mixed
-            let raw_script = get_script(c);
-
-            let script = if *c as u32 == 0x30FC {
-                Script::Han
-            } else if *c == self.space_char || !self.split_by_number && c.is_numeric() {
-                Script::Any
-            } else {
-                match raw_script {
-                    Script::Hiragana => Script::Han,
-                    Script::Katakana => Script::Han,
-                    script => script,
-                }
-            };
-
-            if self.split_by_digits && c.is_numeric() && n > 1 {
-                return false;
-            }
-            if self.split_by_unicode_script
-                && script != Script::Any
-                && last_script != Script::Any
-                && script != last_script
-            {
-                return false;
-            }
-            last_script = script;
-        }

         true
-
-        // true
     }

     fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
-        // let mut pieces: Vec<SentencePiece> =
-        //     Vec::with_capacity(self.vocab_size.try_into().unwrap());
-
         let mut min_score_penalty = 0.0;
         let min_score_penalty_delta = 0.0001;

@@ -204,7 +123,6 @@ impl UnigramTrainer {
     }

     fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
-        // TODO more logic needed if this required chars > vocab_size
         word_counts
             .iter()
             .map(|(s, _count)| s.chars())
@@ -292,7 +210,6 @@ impl UnigramTrainer {
         pieces: &[SentencePiece],
         sentences: &[Sentence],
     ) -> Vec<SentencePiece> {
-        // TODO
         let mut always_keep = vec![true; pieces.len()];
         let mut alternatives: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];

@@ -494,11 +411,11 @@ impl UnigramTrainer {
             .collect();
         new_pieces
     }
-    pub fn _train(&self, mut sentences: Vec<Sentence>) -> Result<(Unigram, Vec<AddedToken>)> {
+    pub fn _train(&self, sentences: Vec<Sentence>) -> Result<(Unigram, Vec<AddedToken>)> {
         let progress = self.setup_progress();
         //
         // 1. Compute frequent substrings
-        // TODO should be either i64 or i32
+        // TODO Should be able to upgrade to u64 when needed
         self.update_progress(&progress, sentences.len(), "Suffix array seeds");
         let mut pieces: Vec<SentencePiece> =
             Vec::with_capacity(self.vocab_size.try_into().unwrap());
@@ -507,26 +424,6 @@ impl UnigramTrainer {
         pieces.extend(self.make_seed_sentence_pieces(&sentences, &progress)?);
         self.finalize_progress(&progress, sentences.len());

-        if self.split_by_whitespace {
-            self.update_progress(&progress, sentences.len(), "Splitting by whitespace");
-            let mut words: HashMap<String, u32> = HashMap::new();
-            for (sentence, count) in &sentences {
-                for word in sentence.split(self.space_char) {
-                    if word.is_empty() {
-                        continue;
-                    }
-                    *words
-                        .entry(format!("{}{}", self.space_char, word))
-                        .or_insert(0) += count;
-                }
-                if let Some(p) = &progress {
-                    p.inc(1);
-                }
-            }
-            self.finalize_progress(&progress, sentences.len());
-            sentences = words.into_iter().collect();
-        }
-
         // Useful to check compatibility with spm.
         debug!(
             "Using {} pieces on {} sentences for EM training",
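The block removed above did whitespace-based word counting inside `_train`; that responsibility now belongs to pre-tokenization. For reference, a standalone sketch equivalent to the removed logic (split each sentence on the space character and re-attach it as a prefix marker), should you need to reproduce the old behaviour outside the trainer:

    use std::collections::HashMap;

    fn count_words(sentences: &[(String, u32)], space_char: char) -> HashMap<String, u32> {
        let mut words: HashMap<String, u32> = HashMap::new();
        for (sentence, count) in sentences {
            for word in sentence.split(space_char) {
                if word.is_empty() {
                    continue;
                }
                // Keep the leading space marker, as the removed code did.
                *words.entry(format!("{}{}", space_char, word)).or_insert(0) += count;
            }
        }
        words
    }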
@@ -625,8 +522,6 @@ mod tests {
     fn test_unigram_chars() {
         let trainer = UnigramTrainerBuilder::default()
             .show_progress(false)
-            .split_by_whitespace(false)
-            .treat_whitespace_as_suffix(true)
             .build()
             .unwrap();

@@ -30,23 +30,40 @@ impl<Label: Eq + Hash + Copy> Trie<Label> {
         node.is_leaf = true;
     }

-    pub fn common_prefix_search(&self, iterator: impl Iterator<Item = Label>) -> Vec<Vec<Label>> {
-        let mut node = &self.root;
-        let mut results = vec![];
-        let mut prefix = vec![];
-        for label in iterator {
-            prefix.push(label);
-            let child_opt = node.children.get(&label);
-            if let Some(child) = child_opt {
-                node = child;
-                if node.is_leaf {
-                    results.push(prefix.clone());
-                }
-            } else {
-                return results;
-            }
-        }
-        results
-    }
-}
+    pub fn common_prefix_search<T>(&self, iterator: T) -> TrieIterator<Label, T>
+    where
+        T: Iterator<Item = Label>,
+    {
+        TrieIterator {
+            node: &self.root,
+            prefix: vec![],
+            iterator,
+        }
+    }
+}
+
+pub struct TrieIterator<'a, Label, T> {
+    node: &'a Node<Label>,
+    prefix: Vec<Label>,
+    iterator: T,
+}
+
+impl<Label, T> Iterator for TrieIterator<'_, Label, T>
+where
+    Label: Eq + Hash + Copy,
+    T: Iterator<Item = Label>,
+{
+    type Item = Vec<Label>;
+    fn next(&mut self) -> Option<Self::Item> {
+        loop {
+            let label = self.iterator.next()?;
+            self.prefix.push(label);
+            let child = self.node.children.get(&label)?;
+            self.node = child;
+            if self.node.is_leaf {
+                return Some(self.prefix.clone());
+            }
+        }
+    }
+}
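`common_prefix_search` now returns a lazy iterator instead of building a `Vec<Vec<Label>>`, so callers that only need the first few matches never allocate the full result set. A hypothetical usage sketch (the `push` insert method and `Default` constructor are assumed from the surrounding trie.rs and are not shown in this hunk):

    let mut trie: Trie<u8> = Trie::default();   // assumed constructor
    trie.push("ab".as_bytes());                 // assumed insert API
    trie.push("abcd".as_bytes());               // assumed insert API
    // Matches are produced on demand; taking only the first one stops the walk early.
    let first = trie.common_prefix_search("abcdef".bytes()).next();
    assert_eq!(first, Some(b"ab".to_vec()));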
@@ -1,11 +1,13 @@
 pub mod bert;
+pub mod precompiled;
 pub mod strip;
 pub mod unicode;
 pub mod utils;

 pub use crate::normalizers::bert::BertNormalizer;
+pub use crate::normalizers::precompiled::Precompiled;
 pub use crate::normalizers::strip::Strip;
-pub use crate::normalizers::unicode::{NFC, NFD, NFKC, NFKD};
+pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
 pub use crate::normalizers::utils::{Lowercase, Sequence};

 use serde::{Deserialize, Serialize};
@@ -24,6 +26,8 @@ pub enum NormalizerWrapper {
     NFKD(NFKD),
     Sequence(Sequence),
     Lowercase(Lowercase),
+    Nmt(Nmt),
+    Precompiled(Precompiled),
 }

 impl Normalizer for NormalizerWrapper {
@@ -37,6 +41,8 @@ impl Normalizer for NormalizerWrapper {
             NormalizerWrapper::NFKD(nfkd) => nfkd.normalize(normalized),
             NormalizerWrapper::Sequence(sequence) => sequence.normalize(normalized),
             NormalizerWrapper::Lowercase(lc) => lc.normalize(normalized),
+            NormalizerWrapper::Nmt(lc) => lc.normalize(normalized),
+            NormalizerWrapper::Precompiled(lc) => lc.normalize(normalized),
         }
     }
 }
@@ -49,3 +55,5 @@ impl_enum_from!(NFD, NormalizerWrapper, NFD);
 impl_enum_from!(Strip, NormalizerWrapper, StripNormalizer);
 impl_enum_from!(Sequence, NormalizerWrapper, Sequence);
 impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase);
+impl_enum_from!(Nmt, NormalizerWrapper, Nmt);
+impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);
tokenizers/src/normalizers/precompiled.rs (new file, 53 lines)
@@ -0,0 +1,53 @@
+use crate::tokenizer::{NormalizedString, Normalizer, Result};
+pub use spm_precompiled::Precompiled;
+use unicode_segmentation::UnicodeSegmentation;
+
+impl Normalizer for Precompiled {
+    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
+        let mut transformations = Vec::with_capacity(normalized.get().len());
+        // Future reader. From @Narsil.
+        // Yes, this is weird,
+        // Yes, this seems broken
+        // No, I don't know why Google did this.
+        // If you question this code, check this normalizer against
+        // XNLI database (all languages) with Unigram model against
+        // Mbart, XLMRoberta *AND* Marian. If you don't get 100% or
+        // break a single test.
+        // You don't pass.
+        normalized.get().graphemes(true).for_each(|grapheme| {
+            let old_count = grapheme.chars().count() as isize;
+            if grapheme.len() < 6 {
+                if let Some(norm) = self.transform(grapheme) {
+                    let new_count = norm.chars().count() as isize;
+                    for (i, c) in norm.chars().enumerate() {
+                        let n = if i == 0 {
+                            new_count - old_count
+                        } else {
+                            i as isize
+                        };
+                        transformations.push((c, n));
+                    }
+                    return;
+                }
+            }
+            for (char_index, c) in grapheme.char_indices() {
+                let part = &grapheme[char_index..char_index + c.len_utf8()];
+                if let Some(norm) = self.transform(part) {
+                    let new_count = norm.chars().count() as isize;
+                    for (i, c) in norm.chars().enumerate() {
+                        let n = if i == 0 {
+                            new_count - old_count
+                        } else {
+                            i as isize
+                        };
+                        transformations.push((c, n));
+                    }
+                } else {
+                    transformations.push((c, 0));
+                }
+            }
+        });
+        normalized.transform(transformations.into_iter(), 0);
+        Ok(())
+    }
+}
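The comment block above hints at the lookup strategy: try to rewrite a short grapheme as a whole, and only fall back to rewriting it character by character. A toy, self-contained sketch of just that strategy, with a plain `HashMap` standing in for the compiled charsmap and without the alignment bookkeeping the real normalizer also performs:

    use std::collections::HashMap;
    use unicode_segmentation::UnicodeSegmentation;

    fn normalize_with(map: &HashMap<&str, &str>, input: &str) -> String {
        let mut out = String::new();
        for grapheme in input.graphemes(true) {
            // 1. Try to replace short graphemes as a whole.
            if grapheme.len() < 6 {
                if let Some(norm) = map.get(grapheme) {
                    out.push_str(norm);
                    continue;
                }
            }
            // 2. Otherwise replace each character independently.
            for c in grapheme.chars() {
                let s = c.to_string();
                match map.get(s.as_str()) {
                    Some(norm) => out.push_str(norm),
                    None => out.push(c),
                }
            }
        }
        out
    }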
@@ -36,7 +36,70 @@ impl Normalizer for NFKC {
     }
 }

+fn do_nmt(normalized: &mut NormalizedString) {
+    // Ascii Control characters
+    normalized
+        .filter(|c| match c as u32 {
+            0x0001..=0x0008 => false,
+            0x000B => false,
+            0x000E..=0x001F => false,
+            0x007F => false,
+            0x008F => false,
+            0x009F => false,
+            _ => true,
+        })
+        // Other code points considered as whitespace.
+        .map(|c| match c as u32 {
+            0x0009 => ' ',
+            0x000A => ' ',
+            0x000C => ' ',
+            0x000D => ' ',
+            0x1680 => ' ',
+            0x200B..=0x200F => ' ',
+            0x2028 => ' ',
+            0x2029 => ' ',
+            0x2581 => ' ',
+            0xFEFF => ' ',
+            0xFFFD => ' ',
+            _ => c,
+        });
+}
+
+#[derive(Default, Copy, Clone, Debug)]
+pub struct Nmt;
+impl Normalizer for Nmt {
+    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
+        do_nmt(normalized);
+        Ok(())
+    }
+}
+
 impl_serde_unit_struct!(NFCVisitor, NFC);
 impl_serde_unit_struct!(NFCKVisitor, NFKC);
 impl_serde_unit_struct!(NFKDVisitor, NFKD);
 impl_serde_unit_struct!(NFDVisitor, NFD);
+impl_serde_unit_struct!(NMTVisitor, Nmt);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_nfkc() {
+        let original = "\u{fb01}".to_string();
+        let normalized = "fi".to_string();
+        let mut n = NormalizedString::from(original.clone());
+        NFKC.normalize(&mut n).unwrap();
+
+        assert_eq!(
+            n,
+            NormalizedString::new(
+                original,
+                normalized,
+                vec![(0, 3), (0, 3)],
+                vec![(0, 2), (0, 2), (0, 2)],
+                0
+            )
+        )
+    }
+}
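Behaviourally, `Nmt` drops the listed control characters and maps the whitespace-like code points to a plain space. A hedged usage example, assuming the re-exports added in the normalizers/mod.rs hunk above and the usual `Normalizer` trait import:

    use tokenizers::normalizers::Nmt;
    use tokenizers::tokenizer::{NormalizedString, Normalizer};

    let mut n = NormalizedString::from("a\u{0001}b\tc");
    Nmt.normalize(&mut n).unwrap();
    // The 0x01 control character is removed, the tab becomes a space.
    assert_eq!(n.get(), "ab c");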
tokenizers/src/pre_tokenizers/unicode_scripts/pre_tokenizer.rs (new file, 144 lines)
@@ -0,0 +1,144 @@
+use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script};
+use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result};
+
+#[derive(Clone, Debug)]
+pub struct UnicodeScripts;
+impl_serde_unit_struct!(UnicodeScriptsVisitor, UnicodeScripts);
+
+impl UnicodeScripts {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl Default for UnicodeScripts {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+// This code exists in the Unigram default IsValidSentencePiece.
+// It could be integrated directly within `get_script` but I
+// think it's kind of tricky to see those modifications later
+// I am guessing release mode will optimize this away anyway.
+fn fixed_script(c: char) -> Script {
+    let raw_script = get_script(c);
+    if c as u32 == 0x30FC {
+        Script::Han
+    } else if c == ' ' {
+        Script::Any
+    } else {
+        match raw_script {
+            Script::Hiragana => Script::Han,
+            Script::Katakana => Script::Han,
+            script => script,
+        }
+    }
+}
+
+impl PreTokenizer for UnicodeScripts {
+    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
+        pretokenized.split(|_, normalized| {
+            let mut last_script = None;
+            let mut offset = 0;
+            let mut ranges: Vec<_> = normalized
+                .get()
+                .chars()
+                .filter_map(|c| {
+                    let script = Some(fixed_script(c));
+                    let result = if script != Some(Script::Any)
+                        && last_script != Some(Script::Any)
+                        && last_script != script
+                    {
+                        Some(offset)
+                    } else {
+                        None
+                    };
+                    offset += c.len_utf8();
+                    if script != Some(Script::Any) {
+                        last_script = script;
+                    }
+
+                    result
+                })
+                .collect();
+            ranges.push(normalized.get().len());
+            Ok(ranges
+                .windows(2)
+                .map(|item| {
+                    normalized
+                        .slice(Range::Normalized(item[0]..item[1]))
+                        .expect("NormalizedString bad split")
+                })
+                .collect::<Vec<_>>())
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::OffsetReferential;
+
+    #[test]
+    fn basic() {
+        let pretok = UnicodeScripts::default();
+        let mut pretokenized = PreTokenizedString::from("どこで生れ。Yes");
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Normalized)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![("どこで生れ", (0, 15)), ("。", (15, 18)), ("Yes", (18, 21))]
+        );
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Original)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![("どこで生れ", (0, 15)), ("。", (15, 18)), ("Yes", (18, 21))]
+        );
+    }
+
+    #[test]
+    fn spaces_are_included_in_every_script() {
+        let pretok = UnicodeScripts::default();
+        let mut pretokenized = PreTokenizedString::from("Apples are りんご 林檎");
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Normalized)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![("Apples are ", (0, 11)), ("りんご 林檎", (11, 27))]
+        );
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Original)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![("Apples are ", (0, 11)), ("りんご 林檎", (11, 27))]
+        );
+    }
+
+    #[test]
+    fn test_unicode_script() {
+        assert_eq!(Script::Han, fixed_script('京'));
+        assert_eq!(Script::Han, fixed_script('太'));
+        assert_eq!(Script::Han, fixed_script('い'));
+        assert_eq!(Script::Han, fixed_script('グ'));
+        assert_eq!(Script::Han, fixed_script('ー'));
+        assert_eq!(Script::Latin, fixed_script('a'));
+        assert_eq!(Script::Latin, fixed_script('A'));
+        assert_eq!(Script::Common, fixed_script('0'));
+        assert_eq!(Script::Common, fixed_script('$'));
+        assert_eq!(Script::Common, fixed_script('@'));
+        assert_eq!(Script::Common, fixed_script('-'));
+        assert_eq!(Script::Any, fixed_script(' '));
+    }
+}
@@ -144,8 +144,8 @@ pub enum Script {
     Yi,
 }

-pub fn get_script(c: &char) -> Script {
-    match *c as u32 {
+pub fn get_script(c: char) -> Script {
+    match c as u32 {
         0x0000..=0x001F => Script::Common,
         0x0020 => Script::Common,
         0x0021..=0x0023 => Script::Common,
@@ -2078,16 +2078,18 @@ mod tests {

     #[test]
     fn test_unicode_script() {
-        assert_eq!(Script::Han, get_script(&'京'));
-        assert_eq!(Script::Han, get_script(&'太'));
-        assert_eq!(Script::Hiragana, get_script(&'い'));
-        assert_eq!(Script::Katakana, get_script(&'グ'));
-        assert_eq!(Script::Common, get_script(&'ー'));
-        assert_eq!(Script::Latin, get_script(&'a'));
-        assert_eq!(Script::Latin, get_script(&'A'));
-        assert_eq!(Script::Common, get_script(&'0'));
-        assert_eq!(Script::Common, get_script(&'$'));
-        assert_eq!(Script::Common, get_script(&'@'));
-        assert_eq!(Script::Common, get_script(&'-'));
+        assert_eq!(Script::Han, get_script('京'));
+        assert_eq!(Script::Han, get_script('太'));
+        assert_eq!(Script::Hiragana, get_script('い'));
+        assert_eq!(Script::Katakana, get_script('グ'));
+        assert_eq!(Script::Common, get_script('ー'));
+        assert_eq!(Script::Latin, get_script('a'));
+        assert_eq!(Script::Latin, get_script('A'));
+        assert_eq!(Script::Common, get_script('0'));
+        assert_eq!(Script::Common, get_script('$'));
+        assert_eq!(Script::Common, get_script('@'));
+        assert_eq!(Script::Common, get_script('-'));
+        assert_eq!(Script::Common, get_script(' '));
+        assert_eq!(Script::Common, get_script('<27>'));
     }
 }
@@ -124,6 +124,22 @@ pub struct NormalizedString {
 }

 impl NormalizedString {
+    #[cfg(test)]
+    pub(crate) fn new(
+        original: String,
+        normalized: String,
+        alignments: Vec<(usize, usize)>,
+        alignments_original: Vec<(usize, usize)>,
+        original_shift: usize,
+    ) -> Self {
+        Self {
+            original,
+            normalized,
+            alignments,
+            alignments_original,
+            original_shift,
+        }
+    }
     /// Return the normalized string
     pub fn get(&self) -> &str {
         &self.normalized
Block a user