Improvements on spm parity (#401)

* Removing all pre_tokenizer logic from the Unigram algorithm.

* Improving the parity check *a lot*.

- We can now detect many more errors.
- Special cases have been added temporarily.

* Adding 2 new normalizers that mimic spm's default behavior (see the sketch below).
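
A minimal sketch of what the new `Nmt` normalizer does on the Rust side; the sample input containing U+200B is only an illustrative assumption:

```rust
use tokenizers::normalizers::Nmt;
use tokenizers::{NormalizedString, Normalizer};

fn main() {
    // U+200B (zero-width space) is one of the code points the NMT rules map
    // to a plain space; ASCII control characters are simply filtered out.
    let mut n = NormalizedString::from("Hello\u{200B}world".to_string());
    Nmt.normalize(&mut n).unwrap();
    assert_eq!(n.get(), "Hello world");
}
```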

* Adding an `encode_optimized` version of the `encode` algorithm (usage sketched below).

- Removes the Lattice allocation.
- Changes the trie's `common_prefix_search` to return an iterator to avoid
  allocating the full results.
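
The change is internal to the model: `Unigram::encode` keeps the same signature, it just runs the Viterbi-style search without allocating a Lattice and caches results per sentence. A minimal usage sketch with a hypothetical toy vocab (pieces and scores are made up):

```rust
use tokenizers::models::unigram::Unigram;

fn main() {
    // Hypothetical (piece, log-probability) pairs; index 0 is the unk token.
    let vocab = vec![
        ("<unk>".to_string(), 0.0),
        ("a".to_string(), -0.5),
        ("b".to_string(), -0.5),
        ("ab".to_string(), -0.6),
    ];
    let model = Unigram::from(vocab, 0).unwrap();
    // Goes through the optimized path (and the cache) under the hood.
    let tokens: Vec<String> = model.encode("abab");
    println!("{:?}", tokens);
}
```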

* Trie<char> -> Trie<u8>: another speed improvement (see the note below).
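
One subtlety with the byte-level trie, sketched here on an arbitrary string: "covered by a single character" now has to be measured in UTF-8 bytes, which is why `populate_nodes` tracks the `mblen` of the character at the current position:

```rust
fn main() {
    // With Trie<u8>, prefix matches are byte slices, so checking that a
    // position is covered by at least one single-character token means
    // comparing match lengths against the UTF-8 width of that character.
    let sentence = "é abc";
    let mblen = sentence.chars().next().unwrap().len_utf8();
    assert_eq!(mblen, 2); // 'é' spans two bytes
}
```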

* [WIP] Attempt to create a `Precompiled` normalizer from SPM to be 100%
compliant with arbitrary models.

* Adding a new `Precompiled` normalizer that replaces `SpmNmtNfkc`.

- It will be used for direct compatibility with `Spm`, replacing all of
their custom rules by directly using the normalizer spec embedded
within spm files, removing any need for us to maintain such rules
(see the sketch below).
- We need the `nom` dependency to parse the binary format of `spm`.
- We need to add the `sentencepiece_model_pb2.py` file to be able to read
  the proto file.
- We reimplemented their `Darts::DoubleArray` compact trie format.
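
A minimal sketch of how the `Precompiled` normalizer can be built from the raw `precompiled_charsmap` bytes (extracting those bytes from an spm model is shown in the Python `from_spm` helper; the function below is hypothetical):

```rust
use tokenizers::normalizers::Precompiled;
use tokenizers::{NormalizedString, Normalizer};

// Hypothetical helper: `charsmap` is the `precompiled_charsmap` field read
// from an spm model's normalizer_spec.
fn normalize_like_spm(charsmap: &[u8], text: &str) -> String {
    let normalizer = Precompiled::from(charsmap).expect("invalid precompiled charsmap");
    let mut n = NormalizedString::from(text.to_string());
    normalizer.normalize(&mut n).expect("normalization failed");
    n.get().to_string()
}
```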

* Fixing a bug with the `Precompiled` normalizer.

* Fixing some edge cases (now covered by tests) with this weird precompiled
normalizer.

It seems even a very hand-crafted trie does not prevent one from shooting
oneself in the foot. Sorry, future reader.

* Keeping the API stable for this PR (API changes should come later, see #409).

- Removed `sentencepiece_model_pb2` from the bindings and added instructions to
make `from_spm` work.

* Adding a model-type check in `from_spm`.

* Addressing @n1t0's comments.

* Adding a check to make sure alignments stay correct.

Also added a bit more documentation on how `Precompiled` works.

* Extracting `Precompiled` into its own `spm_precompiled` crate.

* Using ranges in `do_nmt`.
Author: Nicolas Patry
Date: 2020-09-15 22:21:02 +02:00
Committed by: GitHub
Parent: 62c3d40f11
Commit: 330876ae02
22 changed files with 897 additions and 207 deletions

View File

@ -33,6 +33,12 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "arrayvec"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cff77d8686867eceff3105329d4698d96c2391c176d5d03adc90c7389162b5b8"
[[package]] [[package]]
name = "atty" name = "atty"
version = "0.2.14" version = "0.2.14"
@ -339,6 +345,19 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "lexical-core"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db65c6da02e61f55dae90a0ae427b2a5f6b3e8db09f58d10efab23af92592616"
dependencies = [
"arrayvec",
"bitflags",
"cfg-if",
"ryu",
"static_assertions",
]
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.74" version = "0.2.74"
@ -452,6 +471,17 @@ dependencies = [
"tokenizers", "tokenizers",
] ]
[[package]]
name = "nom"
version = "5.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffb4262d26ed83a1c0a33a38fe2bb15797329c85770da05e6b828ddb782627af"
dependencies = [
"lexical-core",
"memchr",
"version_check",
]
[[package]] [[package]]
name = "num" name = "num"
version = "0.2.1" version = "0.2.1"
@ -754,6 +784,23 @@ version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3757cb9d89161a2f24e1cf78efa0c1fcff485d18e3f55e0aa3480824ddaa0f3f" checksum = "3757cb9d89161a2f24e1cf78efa0c1fcff485d18e3f55e0aa3480824ddaa0f3f"
[[package]]
name = "spm_precompiled"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78be885c9efc899a7c0348f67c98b488cbeaf2cb608a48fb87ef1484ecab5c5"
dependencies = [
"nom",
"serde",
"unicode-segmentation",
]
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.8.0" version = "0.8.0"
@ -833,6 +880,7 @@ dependencies = [
"regex-syntax", "regex-syntax",
"serde", "serde",
"serde_json", "serde_json",
"spm_precompiled",
"unicode-normalization-alignments", "unicode-normalization-alignments",
"unicode-segmentation", "unicode-segmentation",
"unicode_categories", "unicode_categories",

View File

@ -16,6 +16,11 @@ dependencies = [
"winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)", "winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)",
] ]
[[package]]
name = "arrayvec"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]] [[package]]
name = "atty" name = "atty"
version = "0.2.14" version = "0.2.14"
@ -350,6 +355,18 @@ name = "lazy_static"
version = "1.4.0" version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "lexical-core"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"arrayvec 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
"bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
"ryu 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)",
"static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.77" version = "0.2.77"
@ -409,6 +426,16 @@ dependencies = [
"rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", "rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
] ]
[[package]]
name = "nom"
version = "5.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"lexical-core 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)",
"memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
"version_check 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]] [[package]]
name = "num-complex" name = "num-complex"
version = "0.2.4" version = "0.2.4"
@ -741,6 +768,21 @@ name = "smallvec"
version = "1.4.2" version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "spm_precompiled"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"nom 5.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-segmentation 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]] [[package]]
name = "strsim" name = "strsim"
version = "0.8.0" version = "0.8.0"
@ -834,6 +876,7 @@ dependencies = [
"regex-syntax 0.6.18 (registry+https://github.com/rust-lang/crates.io-index)", "regex-syntax 0.6.18 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)", "serde 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)", "serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)",
"spm_precompiled 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-normalization-alignments 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)", "unicode-normalization-alignments 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode-segmentation 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)", "unicode-segmentation 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
"unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", "unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
@ -893,6 +936,11 @@ name = "vec_map"
version = "0.8.2" version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "version_check"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]] [[package]]
name = "wasi" name = "wasi"
version = "0.9.0+wasi-snapshot-preview1" version = "0.9.0+wasi-snapshot-preview1"
@ -928,6 +976,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[metadata] [metadata]
"checksum aho-corasick 0.7.13 (registry+https://github.com/rust-lang/crates.io-index)" = "043164d8ba5c4c3035fec9bbee8647c0261d788f3474306f93bb65901cae0e86" "checksum aho-corasick 0.7.13 (registry+https://github.com/rust-lang/crates.io-index)" = "043164d8ba5c4c3035fec9bbee8647c0261d788f3474306f93bb65901cae0e86"
"checksum ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" "checksum ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b"
"checksum arrayvec 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cff77d8686867eceff3105329d4698d96c2391c176d5d03adc90c7389162b5b8"
"checksum atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" "checksum atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
"checksum autocfg 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" "checksum autocfg 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
"checksum bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" "checksum bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
@ -966,6 +1015,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum itertools 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" "checksum itertools 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b"
"checksum itoa 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "dc6f3ad7b9d11a0c00842ff8de1b60ee58661048eb8049ed33c73594f359d7e6" "checksum itoa 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "dc6f3ad7b9d11a0c00842ff8de1b60ee58661048eb8049ed33c73594f359d7e6"
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
"checksum lexical-core 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)" = "db65c6da02e61f55dae90a0ae427b2a5f6b3e8db09f58d10efab23af92592616"
"checksum libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)" = "f2f96b10ec2560088a8e76961b00d47107b3a625fecb76dedb29ee7ccbf98235" "checksum libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)" = "f2f96b10ec2560088a8e76961b00d47107b3a625fecb76dedb29ee7ccbf98235"
"checksum lock_api 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "28247cc5a5be2f05fbcd76dd0cf2c7d3b5400cb978a28042abcd4fa0b3f8261c" "checksum lock_api 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "28247cc5a5be2f05fbcd76dd0cf2c7d3b5400cb978a28042abcd4fa0b3f8261c"
"checksum log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)" = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b" "checksum log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)" = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b"
@ -974,6 +1024,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400" "checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400"
"checksum memoffset 0.5.5 (registry+https://github.com/rust-lang/crates.io-index)" = "c198b026e1bbf08a937e94c6c60f9ec4a2267f5b0d2eec9c1b21b061ce2be55f" "checksum memoffset 0.5.5 (registry+https://github.com/rust-lang/crates.io-index)" = "c198b026e1bbf08a937e94c6c60f9ec4a2267f5b0d2eec9c1b21b061ce2be55f"
"checksum ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ac06db03ec2f46ee0ecdca1a1c34a99c0d188a0d83439b84bf0cb4b386e4ab09" "checksum ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ac06db03ec2f46ee0ecdca1a1c34a99c0d188a0d83439b84bf0cb4b386e4ab09"
"checksum nom 5.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "ffb4262d26ed83a1c0a33a38fe2bb15797329c85770da05e6b828ddb782627af"
"checksum num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" "checksum num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
"checksum num-integer 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "8d59457e662d541ba17869cf51cf177c0b5f0cbf476c66bdc90bf1edac4f875b" "checksum num-integer 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "8d59457e662d541ba17869cf51cf177c0b5f0cbf476c66bdc90bf1edac4f875b"
"checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611" "checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
@ -1013,6 +1064,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum serde_derive 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)" = "f630a6370fd8e457873b4bd2ffdae75408bc291ba72be773772a4c2a065d9ae8" "checksum serde_derive 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)" = "f630a6370fd8e457873b4bd2ffdae75408bc291ba72be773772a4c2a065d9ae8"
"checksum serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)" = "164eacbdb13512ec2745fb09d51fd5b22b0d65ed294a1dcf7285a360c80a675c" "checksum serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)" = "164eacbdb13512ec2745fb09d51fd5b22b0d65ed294a1dcf7285a360c80a675c"
"checksum smallvec 1.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "fbee7696b84bbf3d89a1c2eccff0850e3047ed46bfcd2e92c29a2d074d57e252" "checksum smallvec 1.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "fbee7696b84bbf3d89a1c2eccff0850e3047ed46bfcd2e92c29a2d074d57e252"
"checksum spm_precompiled 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f78be885c9efc899a7c0348f67c98b488cbeaf2cb608a48fb87ef1484ecab5c5"
"checksum static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" "checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
"checksum strsim 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c" "checksum strsim 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
"checksum syn 1.0.40 (registry+https://github.com/rust-lang/crates.io-index)" = "963f7d3cc59b59b9325165add223142bbf1df27655d07789f109896d353d8350" "checksum syn 1.0.40 (registry+https://github.com/rust-lang/crates.io-index)" = "963f7d3cc59b59b9325165add223142bbf1df27655d07789f109896d353d8350"
@ -1029,6 +1082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" "checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
"checksum unindent 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "af41d708427f8fd0e915dcebb2cae0f0e6acb2a939b2d399c265c39a38a18942" "checksum unindent 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "af41d708427f8fd0e915dcebb2cae0f0e6acb2a939b2d399c265c39a38a18942"
"checksum vec_map 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" "checksum vec_map 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
"checksum version_check 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)" = "b5a972e5669d67ba988ce3dc826706fb0a8b01471c088cb0b6110b805cc36aed"
"checksum wasi 0.9.0+wasi-snapshot-preview1 (registry+https://github.com/rust-lang/crates.io-index)" = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" "checksum wasi 0.9.0+wasi-snapshot-preview1 (registry+https://github.com/rust-lang/crates.io-index)" = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
"checksum winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)" = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" "checksum winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)" = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
"checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"

View File

@@ -1,6 +1,14 @@
-from tokenizers import Tokenizer, AddedToken, pre_tokenizers, decoders, trainers
+from tokenizers import (
+    Tokenizer,
+    AddedToken,
+    pre_tokenizers,
+    decoders,
+    trainers,
+    normalizers,
+)
+import os
 from tokenizers.models import Unigram
-from tokenizers.normalizers import NFKC
+import json
 from .base_tokenizer import BaseTokenizer
 from typing import Optional, List, Union
@@ -16,11 +24,12 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         self, vocab: Optional[str] = None, replacement: str = "▁", add_prefix_space: bool = True,
     ):
         if vocab is not None:
+            # Let Unigram(..) fail if only one of them is None
             tokenizer = Tokenizer(Unigram(vocab))
         else:
             tokenizer = Tokenizer(Unigram())
-        tokenizer.normalizer = NFKC()
+        tokenizer.normalizer = normalizers.Sequence([normalizers.Nmt(), normalizers.NFKC(),])
         tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
             [
                 pre_tokenizers.WhitespaceSplit(),
@@ -57,3 +66,63 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         if isinstance(files, str):
             files = [files]
         self._tokenizer.train(trainer, files)
+
+    @staticmethod
+    def from_spm(filename: str):
+        try:
+            import sys
+
+            sys.path.append(".")
+
+            import sentencepiece_model_pb2 as model
+        except Exception:
+            raise Exception(
+                "You don't seem to have the required protobuf file, in order to use this function you need to run `pip install protobuf` and `wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_model_pb2.py` for us to be able to read the intrinsics of your spm_file. `pip install sentencepiece` is not required."
+            )
+
+        m = model.ModelProto()
+        m.ParseFromString(open(filename, "rb").read())
+
+        precompiled_charsmap = m.normalizer_spec.precompiled_charsmap
+        vocab = [(piece.piece, piece.score) for piece in m.pieces]
+        unk_id = m.trainer_spec.unk_id
+        model_type = m.trainer_spec.model_type
+        if model_type != 1:
+            raise Exception(
+                "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
+            )
+        data = {"unk_id": unk_id, "vocab": vocab}
+
+        replacement = "▁"
+        add_prefix_space = True
+
+        out_vocab_filename = f"{filename}.json"
+        try:
+            with open(out_vocab_filename, "w") as f:
+                json.dump(data, f, indent=4)
+
+            tokenizer = Tokenizer(Unigram(out_vocab_filename))
+        finally:
+            os.remove(out_vocab_filename)
+
+        tokenizer.normalizer = normalizers.Precompiled(precompiled_charsmap)
+        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+            [
+                pre_tokenizers.WhitespaceSplit(),
+                pre_tokenizers.Metaspace(
+                    replacement=replacement, add_prefix_space=add_prefix_space
+                ),
+            ]
+        )
+        tokenizer.decoder = decoders.Metaspace(
+            replacement=replacement, add_prefix_space=add_prefix_space
+        )
+        parameters = {
+            "model": "SentencePieceUnigram",
+        }
+
+        obj = BaseTokenizer.__new__(SentencePieceUnigramTokenizer, tokenizer, parameters)
+        BaseTokenizer.__init__(obj, tokenizer, parameters)
+        return obj

View File

@@ -9,6 +9,8 @@ NFKC = normalizers.NFKC
 Sequence = normalizers.Sequence
 Lowercase = normalizers.Lowercase
 Strip = normalizers.Strip
+Nmt = normalizers.Nmt
+Precompiled = normalizers.Precompiled

 NORMALIZERS = {"nfc": NFC, "nfd": NFD, "nfkc": NFKC, "nfkd": NFKD}

View File

@@ -99,6 +99,18 @@ class Strip(Normalizer):
     def __init__(self, left: bool = True, right: bool = True) -> Normalizer:
         pass
+
+class Nmt(Normalizer):
+    """ Nmt normalizer """
+
+    def __init__(self) -> Normalizer:
+        pass
+
+class Precompiled(Normalizer):
+    """ SpmNmtNfkc normalizer """
+
+    def __init__(self, precompiled_charsmap: bytes) -> Normalizer:
+        pass
 def unicode_normalizer_from_str(normalizer: str) -> Normalizer:
     """
     Instanciate unicode normalizer from the normalizer name

View File

@ -1,7 +1,17 @@
import tokenizers import tokenizers
from argparse import ArgumentParser from argparse import ArgumentParser
import sentencepiece as spm import sentencepiece as spm
from collections import Counter
import json import json
import os
import datetime
try:
from termcolor import colored
has_color = True
except Exception:
has_color = False
def main(): def main():
@ -9,38 +19,62 @@ def main():
parser.add_argument( parser.add_argument(
"--input-file", "-i", type=str, required=True, help="Which files do you want to train from", "--input-file", "-i", type=str, required=True, help="Which files do you want to train from",
) )
parser.add_argument(
"--model-file",
"-m",
type=str,
required=False,
default=None,
help="Use a pretrained token file",
)
parser.add_argument( parser.add_argument(
"--model-prefix", type=str, default="spm_parity", help="Model prefix for spm_train", "--model-prefix", type=str, default="spm_parity", help="Model prefix for spm_train",
) )
parser.add_argument( parser.add_argument(
"--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train", "--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
) )
parser.add_argument(
"--verbose", action="store_true", help="Verbosity",
)
parser.add_argument( parser.add_argument(
"--train", "--train",
action="store_true", action="store_true",
help="Instead of checking the encoder part, we check the trainer part", help="Instead of checking the encoder part, we check the trainer part",
) )
parser.add_argument(
"--from-spm",
action="store_true",
help="Directly load the spm file with it's own normalizer",
)
args = parser.parse_args() args = parser.parse_args()
spm.SentencePieceTrainer.Train( trained = False
f"--input={args.input_file} --model_prefix={args.model_prefix}" if args.model_file is None:
f" --character_coverage=1.0" spm.SentencePieceTrainer.Train(
f" --max_sentence_length=40000" f"--input={args.input_file} --model_prefix={args.model_prefix}"
f" --num_threads=1" f" --character_coverage=1.0"
f" --vocab_size={args.vocab_size}" f" --max_sentence_length=40000"
) f" --num_threads=1"
f" --vocab_size={args.vocab_size}"
)
trained = True
args.model_file = f"{args.model_prefix}.model"
if args.train: try:
check_train(args) if args.train:
else: check_train(args)
check_encode(args) else:
check_encode(args)
finally:
if trained:
os.remove(f"{args.model_prefix}.model")
os.remove(f"{args.model_prefix}.vocab")
def check_train(args): def check_train(args):
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
model_filename = f"{args.model_prefix}.model" sp.Load(args.model_file)
sp.Load(model_filename)
tokenizer = tokenizers.SentencePieceUnigramTokenizer() tokenizer = tokenizers.SentencePieceUnigramTokenizer()
tokenizer.train(args.input_file, show_progress=False) tokenizer.train(args.input_file, show_progress=False)
@ -77,38 +111,144 @@ def check_train(args):
assert ( assert (
tokenizer_tokens < spm_tokens tokenizer_tokens < spm_tokens
), "Our trainer should be at least more efficient than the SPM one" ), "Our trainer should be at least more efficient than the SPM one"
print("Ok our trainer is at least more efficient than the SPM one")
def check_diff(spm_diff, tok_diff, sp, tok):
if spm_diff == list(reversed(tok_diff)):
# AAA -> AA+A vs A+AA case.
return True
elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(tok_diff):
# Second order OK
# Barrich -> Barr + ich vs Bar + rich
return True
spm_reencoded = sp.encode(sp.decode(spm_diff))
tok_reencoded = tok.encode(tok.decode(spm_diff)).ids
if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
# Type 3 error.
# Snehagatha ->
# Sne, h, aga, th, a
# Sne, ha, gat, ha
# Encoding the wrong with sp does not even recover what spm gave us
# It fits tokenizer however...
return True
return False
def check_details(line, spm_ids, tok_ids, tok, sp):
# Encoding can be the same with same result AAA -> A + AA vs AA + A
# We can check that we use at least exactly the same number of tokens.
for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
if spm_id != tok_id:
break
first = i
for i, (spm_id, tok_id) in enumerate(zip(reversed(spm_ids), reversed(tok_ids))):
if spm_id != tok_id:
break
last = len(spm_ids) - i
spm_diff = spm_ids[first:last]
tok_diff = tok_ids[first:last]
if check_diff(spm_diff, tok_diff, sp, tok):
return True
if last - first > 5:
# We might have twice a single problem, attempt to subdivide the disjointed tokens into smaller problems
spms = Counter(spm_ids[first:last])
toks = Counter(tok_ids[first:last])
removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
min_width = 3
for i in range(last - first - min_width):
if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
possible_matches = [
k
for k in range(last - first - min_width)
if tok_ids[first + k : first + k + min_width]
== spm_ids[first + i : first + i + min_width]
]
for j in possible_matches:
if check_diff(
spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
) and check_diff(spm_ids[first + i : last], tok_ids[first + j : last], sp, tok):
return True
ok_start = tok.decode(spm_ids[:first])
ok_end = tok.decode(spm_ids[last:])
wrong = tok.decode(spm_ids[first:last])
print()
if has_color:
print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
else:
print(wrong)
print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}")
print(f"Tok: {[tok.decode([tok_ids[i]]) for i in range(first, last)]}")
return False
def check_encode(args): def check_encode(args):
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
model_filename = f"{args.model_prefix}.model" sp.Load(args.model_file)
sp.Load(model_filename)
vocab_filename = f"{args.model_prefix}.json" if args.from_spm:
tok = tokenizers.SentencePieceUnigramTokenizer.from_spm(args.model_file)
else:
vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
vocab_filename = f"{args.model_file}.json"
unk_id = sp.unk_id()
vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())] data = {"unk_id": unk_id, "vocab": vocab}
try:
with open(vocab_filename, "w") as f:
json.dump(data, f, indent=4)
data = {"unk_id": sp.unk_id(), "vocab": vocab} tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
finally:
os.remove(vocab_filename)
with open(vocab_filename, "w") as f: perfect = 0
json.dump(data, f, indent=4) imperfect = 0
wrong = 0
tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename) now = datetime.datetime.now
with open(args.input_file, "r") as f: spm_total_time = datetime.timedelta(seconds=0)
tok_total_time = datetime.timedelta(seconds=0)
with open(args.input_file, "r", encoding="utf-8-sig") as f:
for i, line in enumerate(f): for i, line in enumerate(f):
line = line.strip() line = line.strip()
start = now()
ids = sp.EncodeAsIds(line) ids = sp.EncodeAsIds(line)
spm_time = now()
encoded = tok.encode(line) encoded = tok.encode(line)
tok_time = now()
spm_total_time += spm_time - start
tok_total_time += tok_time - spm_time
if args.verbose:
if i % 10000 == 0:
print(
f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})"
)
print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")
if ids != encoded.ids: if ids != encoded.ids:
# Encoding can be the same with same result AAA -> A + AA vs AA + A if check_details(line, ids, encoded.ids, tok, sp):
# We can check that we use at least exactly the same number of tokens. imperfect += 1
assert len(ids) == len(encoded.ids) continue
continue else:
wrong += 1
else:
perfect += 1
assert ids == encoded.ids, f"line {i}: {line} : {ids} != {encoded.ids}" assert ids == encoded.ids, f"line {i}: {line} : {ids} != {encoded.ids}"
print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
total = perfect + imperfect + wrong
print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -15,7 +15,7 @@ setup(
     author_email="anthony@huggingface.co",
     url="https://github.com/huggingface/tokenizers",
     license="Apache License 2.0",
-    rust_extensions=[RustExtension("tokenizers.tokenizers", binding=Binding.PyO3)],
+    rust_extensions=[RustExtension("tokenizers.tokenizers", binding=Binding.PyO3, debug=False)],
    extras_require=extras,
    classifiers=[
        "Development Status :: 5 - Production/Stable",

View File

@@ -108,6 +108,8 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<normalizers::PySequence>()?;
     m.add_class::<normalizers::PyLowercase>()?;
     m.add_class::<normalizers::PyStrip>()?;
+    m.add_class::<normalizers::PyNmt>()?;
+    m.add_class::<normalizers::PyPrecompiled>()?;
     Ok(())
 }

View File

@ -263,21 +263,19 @@ pub struct PyUnigram {}
#[pymethods] #[pymethods]
impl PyUnigram { impl PyUnigram {
#[new] #[new]
fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> { fn new(vocab: Option<String>) -> PyResult<(Self, PyModel)> {
if let Some(vocab) = vocab { match vocab {
let path = Path::new(vocab); Some(vocab) => match Unigram::load(&std::path::Path::new(&vocab)) {
match Unigram::load(path) {
Err(e) => { Err(e) => {
println!("Errors: {:?}", e); println!("Errors: {:?}", e);
Err(exceptions::Exception::py_err("Error while loading Unigram")) Err(exceptions::Exception::py_err("Error while loading Unigram"))
} }
Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))), Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
} },
} else { None => Ok((
Ok((
PyUnigram {}, PyUnigram {},
PyModel::new(Arc::new(Unigram::default().into())), PyModel::new(Arc::new(Unigram::default().into())),
)) )),
} }
} }
} }

View File

@ -7,7 +7,9 @@ use pyo3::types::*;
use crate::error::ToPyResult; use crate::error::ToPyResult;
use serde::ser::SerializeStruct; use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize, Serializer}; use serde::{Deserialize, Serialize, Serializer};
use tk::normalizers::{BertNormalizer, Lowercase, NormalizerWrapper, Strip, NFC, NFD, NFKC, NFKD}; use tk::normalizers::{
BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Precompiled, Strip, NFC, NFD, NFKC, NFKD,
};
use tk::{NormalizedString, Normalizer}; use tk::{NormalizedString, Normalizer};
use tokenizers as tk; use tokenizers as tk;
@ -45,6 +47,10 @@ impl PyNormalizer {
NormalizerWrapper::Lowercase(_) => { NormalizerWrapper::Lowercase(_) => {
Py::new(py, (PyLowercase {}, base)).map(Into::into) Py::new(py, (PyLowercase {}, base)).map(Into::into)
} }
NormalizerWrapper::Precompiled(_) => {
Py::new(py, (PyPrecompiled {}, base)).map(Into::into)
}
NormalizerWrapper::Nmt(_) => Py::new(py, (PyNmt {}, base)).map(Into::into),
}, },
} }
} }
@ -273,6 +279,37 @@ impl Normalizer for PyNormalizerWrapper {
} }
} }
#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Nmt)]
pub struct PyNmt {}
#[pymethods]
impl PyNmt {
    #[new]
    fn new() -> PyResult<(Self, PyNormalizer)> {
        Ok((PyNmt {}, Nmt.into()))
    }
}

#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Precompiled)]
pub struct PyPrecompiled {}
#[pymethods]
impl PyPrecompiled {
    #[new]
    fn new(py_precompiled_charsmap: &PyBytes) -> PyResult<(Self, PyNormalizer)> {
        let precompiled_charsmap: &[u8] = FromPyObject::extract(py_precompiled_charsmap)?;
        Ok((
            PyPrecompiled {},
            Precompiled::from(precompiled_charsmap)
                .map_err(|e| {
                    exceptions::Exception::py_err(format!(
                        "Error while attempting to build Precompiled normalizer: {}",
                        e.to_string()
                    ))
                })?
                .into(),
        ))
    }
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use pyo3::{AsPyRef, Python}; use pyo3::{AsPyRef, Python};

View File

@@ -52,6 +52,7 @@ itertools = "0.9"
 log = "0.4"
 esaxx-rs = "0.1"
 derive_builder = "0.9"
+spm_precompiled = "0.1"

 [dev-dependencies]
 criterion = "0.3"
View File

@@ -4,7 +4,6 @@ mod model;
 mod serialization;
 mod trainer;
 mod trie;
-mod unicode;

 pub use lattice::*;
 pub use model::*;
View File

@ -1,7 +1,7 @@
use crate::models::unigram::lattice::Lattice; use crate::models::unigram::lattice::Lattice;
use crate::models::unigram::trie::{Trie, TrieBuilder}; use crate::models::unigram::trie::{Trie, TrieBuilder};
use crate::tokenizer::{Model, Result, Token}; use crate::tokenizer::{Model, Result, Token};
use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY}; use crate::utils::cache::Cache;
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::TryInto; use std::convert::TryInto;
@ -17,7 +17,7 @@ pub struct Unigram {
token_to_ids: TokenMap, token_to_ids: TokenMap,
pub(crate) vocab: Vocab, pub(crate) vocab: Vocab,
cache: Cache<String, Vec<String>>, cache: Cache<String, Vec<String>>,
trie: Trie<char>, trie: Trie<u8>,
pub min_score: f64, pub min_score: f64,
pub(super) unk_id: usize, pub(super) unk_id: usize,
pub(super) bos_id: usize, pub(super) bos_id: usize,
@ -38,13 +38,13 @@ impl Clone for Unigram {
let fresh_cache = self.cache.fresh(); let fresh_cache = self.cache.fresh();
Self { Self {
vocab: self.vocab.clone(), vocab: self.vocab.clone(),
token_to_ids: self.token_to_ids.clone(),
cache: fresh_cache, cache: fresh_cache,
unk_id: self.unk_id, token_to_ids: self.token_to_ids.clone(),
trie: self.trie.clone(),
min_score: self.min_score, min_score: self.min_score,
unk_id: self.unk_id,
bos_id: self.bos_id, bos_id: self.bos_id,
eos_id: self.eos_id, eos_id: self.eos_id,
trie: self.trie.clone(),
fuse_unk: self.fuse_unk, fuse_unk: self.fuse_unk,
} }
} }
@ -54,6 +54,7 @@ impl std::fmt::Debug for Unigram {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("BPE") fmt.debug_struct("BPE")
.field("vocab", &self.vocab.len()) .field("vocab", &self.vocab.len())
.field("unk_id", &self.unk_id)
.finish() .finish()
} }
} }
@ -113,19 +114,17 @@ impl Unigram {
let mut min_score = f64::INFINITY; let mut min_score = f64::INFINITY;
for (id, (token, score)) in vocab.iter().enumerate() { for (id, (token, score)) in vocab.iter().enumerate() {
token_to_ids.insert(token.to_string(), id as u32); token_to_ids.insert(token.to_string(), id as u32);
let chars: Vec<char> = token.chars().collect(); let bytes: Vec<u8> = token.bytes().collect();
builder.push(&chars); builder.push(&bytes);
if score < &min_score { if score < &min_score {
min_score = *score; min_score = *score;
} }
} }
let trie = builder.build(); let trie = builder.build();
let fuse_unk = true; let fuse_unk = true;
let cache = Cache::new(DEFAULT_CACHE_CAPACITY);
Ok(Unigram { Ok(Unigram {
vocab, vocab,
cache,
token_to_ids, token_to_ids,
trie, trie,
min_score, min_score,
@ -133,6 +132,7 @@ impl Unigram {
eos_id, eos_id,
unk_id, unk_id,
fuse_unk, fuse_unk,
cache: Cache::default(),
}) })
} }
@ -151,26 +151,29 @@ impl Unigram {
let len = lattice.len(); let len = lattice.len();
for begin_pos in 0..len { let mut begin_pos = 0;
let trie_results: Vec<String> = self while begin_pos < len {
.trie let mblen = lattice.sentence[begin_pos..]
.common_prefix_search(lattice.sentence.chars().skip(begin_pos)) .chars()
.iter() .next()
.map(|chars| chars.iter().collect()) .unwrap()
.collect(); .len_utf8();
let mut has_single_node = false; let mut has_single_node = false;
for tok in trie_results { for bytes in self
let n = tok.chars().count(); .trie
.common_prefix_search(lattice.sentence.bytes().skip(begin_pos))
{
let n = bytes.len();
let tok = String::from_utf8(bytes).unwrap();
let id = *self.token_to_ids.get(&tok).unwrap(); let id = *self.token_to_ids.get(&tok).unwrap();
let item = &self.vocab[id as usize]; let item = &self.vocab[id as usize];
assert_eq!(item.0, tok); assert_eq!(item.0, tok);
let score: f64 = item.1; let score: f64 = item.1;
lattice.insert(begin_pos, n, score, id.try_into().unwrap()); lattice.insert(begin_pos, n, score, id.try_into().unwrap());
if !has_single_node && n == 1 { if !has_single_node && n == mblen {
has_single_node = true; has_single_node = true;
} }
} }
@ -178,6 +181,7 @@ impl Unigram {
if !has_single_node { if !has_single_node {
lattice.insert(begin_pos, 1, unk_score, self.unk_id); lattice.insert(begin_pos, 1, unk_score, self.unk_id);
} }
begin_pos += mblen
} }
} }
@ -202,17 +206,129 @@ impl Unigram {
/// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]); /// assert_eq!(result, vec!["abcd", "a", "cd", "xx"]);
/// ``` /// ```
pub fn encode(&self, sentence: &str) -> Vec<String> { pub fn encode(&self, sentence: &str) -> Vec<String> {
if sentence.is_empty() {
return vec![];
}
if let Some(result) = self.cache.get(sentence) { if let Some(result) = self.cache.get(sentence) {
result result.to_vec()
} else { } else {
let result = self.encode_no_cache(sentence); let result = self.encode_optimized(sentence);
self.cache.set(sentence.to_owned(), result.clone()); self.cache.set(sentence.to_owned(), result.clone());
result result
} }
} }
fn encode_no_cache(&self, sentence: &str) -> Vec<String> {
// TODO optimized version fn encode_optimized(&self, sentence: &str) -> Vec<String> {
// https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600 // https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600
#[derive(Debug, Clone)]
struct BestPathNode {
/// The vocab id. (maybe UNK)
id: usize,
/// The total score of the best path ending at this node.
best_path_score: f64,
/// The starting position (in utf-8) of this node. The entire best
/// path can be constructed by backtracking along this link.
starts_at: Option<usize>,
};
impl Default for BestPathNode {
fn default() -> Self {
Self {
id: 0,
best_path_score: 0.0,
starts_at: None,
}
}
}
let size = sentence.len();
let unk_score = self.min_score - K_UNK_PENALTY;
let mut best_path_ends_at = vec![BestPathNode::default(); size + 1];
let mut starts_at = 0;
while starts_at < size {
let best_path_score_till_here = best_path_ends_at[starts_at].best_path_score;
let mut has_single_node = false;
let mblen = sentence[starts_at..].chars().next().unwrap().len_utf8();
for tok_bytes in self
.trie
.common_prefix_search(sentence.bytes().skip(starts_at))
{
let key_pos = starts_at + tok_bytes.len();
let token: String = String::from_utf8(tok_bytes).unwrap();
let mut target_node = &mut best_path_ends_at[key_pos];
let length = key_pos - starts_at;
let id = self.token_to_ids.get(&token).unwrap();
let score = self.vocab.get(*id as usize).unwrap().1;
let candidate_best_path_score = score + best_path_score_till_here;
if target_node.starts_at.is_none()
|| candidate_best_path_score > target_node.best_path_score
{
target_node.best_path_score = candidate_best_path_score;
target_node.starts_at = Some(starts_at);
target_node.id = *id as usize;
}
if !has_single_node && length == mblen {
has_single_node = true;
}
}
if !has_single_node {
let mut target_node = &mut best_path_ends_at[starts_at + mblen];
let candidate_best_path_score = unk_score + best_path_score_till_here;
if target_node.starts_at.is_none()
|| candidate_best_path_score > target_node.best_path_score
{
target_node.best_path_score = candidate_best_path_score;
target_node.starts_at = Some(starts_at);
target_node.id = self.unk_id;
}
}
starts_at += mblen
}
let mut ends_at = size;
let mut results: Vec<String> = vec![];
let mut token = vec![];
while ends_at > 0 {
let node = &best_path_ends_at[ends_at];
let starts_at = node.starts_at.unwrap();
if self.fuse_unk && node.id == self.unk_id {
token.push(
String::from_utf8(
sentence
.bytes()
.skip(starts_at)
.take(ends_at - starts_at)
.collect(),
)
.unwrap(),
);
} else {
if !token.is_empty() {
token.reverse();
results.push(token.concat());
token = vec![];
}
results.push(
String::from_utf8(
sentence
.bytes()
.skip(starts_at)
.take(ends_at - starts_at)
.collect(),
)
.unwrap(),
);
}
ends_at = starts_at;
}
if !token.is_empty() {
token.reverse();
results.push(token.concat());
}
results.reverse();
results
}
#[allow(dead_code)]
fn encode_unoptimized(&self, sentence: &str) -> Vec<String> {
let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id); let mut lattice = Lattice::from(sentence, self.unk_id, self.bos_id, self.eos_id);
self.populate_nodes(&mut lattice); self.populate_nodes(&mut lattice);
if self.fuse_unk { if self.fuse_unk {
@ -296,11 +412,14 @@ impl Model for Unigram {
Ok(tokens Ok(tokens
.iter() .iter()
.map(|string| { .map(|string| {
let id = self.token_to_ids.get(string).unwrap_or(&0); let id: u32 = match self.token_to_ids.get(string) {
Some(id) => *id,
None => self.unk_id as u32,
};
let len = string.len(); let len = string.len();
let offsets = (offset, offset + len); let offsets = (offset, offset + len);
offset += len; offset += len;
Token::new(*id, string.to_string(), offsets) Token::new(id, string.to_string(), offsets)
}) })
.collect()) .collect())
} }

View File

@@ -75,4 +75,15 @@ mod test {
         assert_eq!(model, reconstructed);
     }
+
+    #[test]
+    fn test_serialization_unk_id_not_zero() {
+        let vocab = vec![("a".to_string(), -0.5), ("<unk>".to_string(), 0.0)];
+        let model = Unigram::from(vocab, 1).unwrap();
+
+        let data = serde_json::to_string(&model).unwrap();
+        let reconstructed = serde_json::from_str(&data).unwrap();
+
+        assert_eq!(model, reconstructed);
+    }
 }

View File

@ -1,8 +1,4 @@
use crate::models::unigram::{ use crate::models::unigram::{lattice::Lattice, model::Unigram};
lattice::Lattice,
model::Unigram,
unicode::{get_script, Script},
};
use crate::tokenizer::{AddedToken, Result, Trainer}; use crate::tokenizer::{AddedToken, Result, Trainer};
use indicatif::{ProgressBar, ProgressStyle}; use indicatif::{ProgressBar, ProgressStyle};
use log::debug; use log::debug;
@ -53,33 +49,9 @@ pub struct UnigramTrainer {
#[builder(default = "vec![]")] #[builder(default = "vec![]")]
special_tokens: Vec<AddedToken>, special_tokens: Vec<AddedToken>,
#[builder(default = "' '")]
space_char: char,
#[builder(default = "String::from(\"<unk>\")")] #[builder(default = "String::from(\"<unk>\")")]
unk_token: String, unk_token: String,
#[builder(default = "false")]
treat_whitespace_as_suffix: bool,
#[builder(default = "true")]
split_by_unicode_script: bool,
#[builder(default = "true")]
split_by_number: bool,
#[builder(default = "false")]
split_by_digits: bool,
/// In spm this parameter defaults to true,
/// we set it to false here because it's supposed
/// to be a job taken care by the pretokenizer. We still
/// have it here to enable easier testing as it does make a difference
/// in `is_valid_sentencepiece` method were we discard seed_pieces if they
/// contain a whitespace. This job can/could be taken elsewhere.
#[builder(default = "false")]
split_by_whitespace: bool,
#[builder(default = "16")] #[builder(default = "16")]
max_piece_length: usize, max_piece_length: usize,
#[builder(default = "1_000_000")] #[builder(default = "1_000_000")]
@ -106,72 +78,19 @@ impl UnigramTrainer {
} }
fn is_valid_sentencepiece(&self, char_string: &[char]) -> bool { fn is_valid_sentencepiece(&self, char_string: &[char]) -> bool {
// TODO check more formally but should be ok. // Checks string length
// Checks string length, space not in the substring, numbers, hiragana and more // Space not in the substring, numbers, hiragana and more should be taken
// care of within pre_tokenizers.
// https://github.com/google/sentencepiece/blob/26be9516cd81d5315ee31c48d2438018e0eab879/src/trainer_interface.cc#L203 // https://github.com/google/sentencepiece/blob/26be9516cd81d5315ee31c48d2438018e0eab879/src/trainer_interface.cc#L203
let n = char_string.len(); let n = char_string.len();
if char_string.is_empty() || n > self.max_piece_length { if char_string.is_empty() || n > self.max_piece_length {
return false; return false;
} }
let mut last_script = Script::Any;
for (i, c) in char_string.iter().enumerate() {
if *c == '\0' {
return false;
}
if *c == self.space_char {
if self.treat_whitespace_as_suffix {
let is_not_suffix_no_whitespace = self.split_by_whitespace && i != n - 1;
let is_prefix = !self.split_by_whitespace && i == 0;
if is_not_suffix_no_whitespace || is_prefix {
return false;
}
} else {
let is_not_prefix_no_whitespace = self.split_by_whitespace && i != 0;
let is_suffix = !self.split_by_whitespace && i == n - 1;
if is_not_prefix_no_whitespace || is_suffix {
return false;
}
}
}
// This function checks that unicode "scripts" are consistent, so we cannot have romaji and
// hiragana for instance. Seems pretty specific. Also Hiragana and katakana are mixed
let raw_script = get_script(c);
let script = if *c as u32 == 0x30FC {
Script::Han
} else if *c == self.space_char || !self.split_by_number && c.is_numeric() {
Script::Any
} else {
match raw_script {
Script::Hiragana => Script::Han,
Script::Katakana => Script::Han,
script => script,
}
};
if self.split_by_digits && c.is_numeric() && n > 1 {
return false;
}
if self.split_by_unicode_script
&& script != Script::Any
&& last_script != Script::Any
&& script != last_script
{
return false;
}
last_script = script;
}
true true
// true
} }
fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> { fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
// let mut pieces: Vec<SentencePiece> =
// Vec::with_capacity(self.vocab_size.try_into().unwrap());
let mut min_score_penalty = 0.0; let mut min_score_penalty = 0.0;
let min_score_penalty_delta = 0.0001; let min_score_penalty_delta = 0.0001;
@ -204,7 +123,6 @@ impl UnigramTrainer {
} }
fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> { fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
// TODO more logic needed if this required chars > vocab_size
word_counts word_counts
.iter() .iter()
.map(|(s, _count)| s.chars()) .map(|(s, _count)| s.chars())
@ -292,7 +210,6 @@ impl UnigramTrainer {
pieces: &[SentencePiece], pieces: &[SentencePiece],
sentences: &[Sentence], sentences: &[Sentence],
) -> Vec<SentencePiece> { ) -> Vec<SentencePiece> {
// TODO
let mut always_keep = vec![true; pieces.len()]; let mut always_keep = vec![true; pieces.len()];
let mut alternatives: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()]; let mut alternatives: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];
@ -494,11 +411,11 @@ impl UnigramTrainer {
.collect(); .collect();
new_pieces new_pieces
} }
pub fn _train(&self, mut sentences: Vec<Sentence>) -> Result<(Unigram, Vec<AddedToken>)> { pub fn _train(&self, sentences: Vec<Sentence>) -> Result<(Unigram, Vec<AddedToken>)> {
let progress = self.setup_progress(); let progress = self.setup_progress();
// //
// 1. Compute frequent substrings // 1. Compute frequent substrings
// TODO should be either i64 or i32 // TODO Should be able to upgrade to u64 when needed
self.update_progress(&progress, sentences.len(), "Suffix array seeds"); self.update_progress(&progress, sentences.len(), "Suffix array seeds");
let mut pieces: Vec<SentencePiece> = let mut pieces: Vec<SentencePiece> =
Vec::with_capacity(self.vocab_size.try_into().unwrap()); Vec::with_capacity(self.vocab_size.try_into().unwrap());
@ -507,26 +424,6 @@ impl UnigramTrainer {
pieces.extend(self.make_seed_sentence_pieces(&sentences, &progress)?); pieces.extend(self.make_seed_sentence_pieces(&sentences, &progress)?);
self.finalize_progress(&progress, sentences.len()); self.finalize_progress(&progress, sentences.len());
if self.split_by_whitespace {
self.update_progress(&progress, sentences.len(), "Splitting by whitespace");
let mut words: HashMap<String, u32> = HashMap::new();
for (sentence, count) in &sentences {
for word in sentence.split(self.space_char) {
if word.is_empty() {
continue;
}
*words
.entry(format!("{}{}", self.space_char, word))
.or_insert(0) += count;
}
if let Some(p) = &progress {
p.inc(1);
}
}
self.finalize_progress(&progress, sentences.len());
sentences = words.into_iter().collect();
}
// Useful to check compatibility with spm. // Useful to check compatibility with spm.
debug!( debug!(
"Using {} pieces on {} sentences for EM training", "Using {} pieces on {} sentences for EM training",
@ -625,8 +522,6 @@ mod tests {
fn test_unigram_chars() { fn test_unigram_chars() {
let trainer = UnigramTrainerBuilder::default() let trainer = UnigramTrainerBuilder::default()
.show_progress(false) .show_progress(false)
.split_by_whitespace(false)
.treat_whitespace_as_suffix(true)
.build() .build()
.unwrap(); .unwrap();

View File

@ -30,23 +30,40 @@ impl<Label: Eq + Hash + Copy> Trie<Label> {
node.is_leaf = true; node.is_leaf = true;
} }
pub fn common_prefix_search(&self, iterator: impl Iterator<Item = Label>) -> Vec<Vec<Label>> { pub fn common_prefix_search<T>(&self, iterator: T) -> TrieIterator<Label, T>
let mut node = &self.root; where
let mut results = vec![]; T: Iterator<Item = Label>,
let mut prefix = vec![]; {
for label in iterator { TrieIterator {
prefix.push(label); node: &self.root,
let child_opt = node.children.get(&label); prefix: vec![],
if let Some(child) = child_opt { iterator,
node = child; }
if node.is_leaf { }
results.push(prefix.clone()); }
}
} else { pub struct TrieIterator<'a, Label, T> {
return results; node: &'a Node<Label>,
prefix: Vec<Label>,
iterator: T,
}
impl<Label, T> Iterator for TrieIterator<'_, Label, T>
where
Label: Eq + Hash + Copy,
T: Iterator<Item = Label>,
{
type Item = Vec<Label>;
fn next(&mut self) -> Option<Self::Item> {
loop {
let label = self.iterator.next()?;
self.prefix.push(label);
let child = self.node.children.get(&label)?;
self.node = child;
if self.node.is_leaf {
return Some(self.prefix.clone());
} }
} }
results
} }
} }

View File

@ -1,11 +1,13 @@
pub mod bert; pub mod bert;
pub mod precompiled;
pub mod strip; pub mod strip;
pub mod unicode; pub mod unicode;
pub mod utils; pub mod utils;
pub use crate::normalizers::bert::BertNormalizer; pub use crate::normalizers::bert::BertNormalizer;
pub use crate::normalizers::precompiled::Precompiled;
pub use crate::normalizers::strip::Strip; pub use crate::normalizers::strip::Strip;
pub use crate::normalizers::unicode::{NFC, NFD, NFKC, NFKD}; pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
pub use crate::normalizers::utils::{Lowercase, Sequence}; pub use crate::normalizers::utils::{Lowercase, Sequence};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -24,6 +26,8 @@ pub enum NormalizerWrapper {
NFKD(NFKD), NFKD(NFKD),
Sequence(Sequence), Sequence(Sequence),
Lowercase(Lowercase), Lowercase(Lowercase),
Nmt(Nmt),
Precompiled(Precompiled),
} }
impl Normalizer for NormalizerWrapper { impl Normalizer for NormalizerWrapper {
@ -37,6 +41,8 @@ impl Normalizer for NormalizerWrapper {
NormalizerWrapper::NFKD(nfkd) => nfkd.normalize(normalized), NormalizerWrapper::NFKD(nfkd) => nfkd.normalize(normalized),
NormalizerWrapper::Sequence(sequence) => sequence.normalize(normalized), NormalizerWrapper::Sequence(sequence) => sequence.normalize(normalized),
NormalizerWrapper::Lowercase(lc) => lc.normalize(normalized), NormalizerWrapper::Lowercase(lc) => lc.normalize(normalized),
NormalizerWrapper::Nmt(lc) => lc.normalize(normalized),
NormalizerWrapper::Precompiled(lc) => lc.normalize(normalized),
} }
} }
} }
@ -49,3 +55,5 @@ impl_enum_from!(NFD, NormalizerWrapper, NFD);
impl_enum_from!(Strip, NormalizerWrapper, StripNormalizer); impl_enum_from!(Strip, NormalizerWrapper, StripNormalizer);
impl_enum_from!(Sequence, NormalizerWrapper, Sequence); impl_enum_from!(Sequence, NormalizerWrapper, Sequence);
impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase); impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase);
impl_enum_from!(Nmt, NormalizerWrapper, Nmt);
impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);

View File

@@ -0,0 +1,53 @@
+use crate::tokenizer::{NormalizedString, Normalizer, Result};
+pub use spm_precompiled::Precompiled;
+use unicode_segmentation::UnicodeSegmentation;
+
+impl Normalizer for Precompiled {
+    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
+        let mut transformations = Vec::with_capacity(normalized.get().len());
+        // Future reader. From @Narsil.
+        // Yes, this is weird,
+        // Yes, this seems broken
+        // No, I don't know why Google did this.
+        // If you question this code, check this normalizer against
+        // XNLI database (all languages) with Unigram model against
+        // Mbart, XLMRoberta *AND* Marian. If you don't get 100% or
+        // break a single test.
+        // You don't pass.
+        normalized.get().graphemes(true).for_each(|grapheme| {
+            let old_count = grapheme.chars().count() as isize;
+            if grapheme.len() < 6 {
+                if let Some(norm) = self.transform(grapheme) {
+                    let new_count = norm.chars().count() as isize;
+                    for (i, c) in norm.chars().enumerate() {
+                        let n = if i == 0 {
+                            new_count - old_count
+                        } else {
+                            i as isize
+                        };
+                        transformations.push((c, n));
+                    }
+                    return;
+                }
+            }
+            for (char_index, c) in grapheme.char_indices() {
+                let part = &grapheme[char_index..char_index + c.len_utf8()];
+                if let Some(norm) = self.transform(part) {
+                    let new_count = norm.chars().count() as isize;
+                    for (i, c) in norm.chars().enumerate() {
+                        let n = if i == 0 {
+                            new_count - old_count
+                        } else {
+                            i as isize
+                        };
+                        transformations.push((c, n));
+                    }
+                } else {
+                    transformations.push((c, 0));
+                }
+            }
+        });
+        normalized.transform(transformations.into_iter(), 0);
+        Ok(())
+    }
+}
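
`Precompiled` itself lives in the external `spm_precompiled` crate; this file only gives it a `Normalizer` impl. A hedged sketch of how the two fit together, assuming the crate exposes a `Precompiled::from(&[u8])` constructor and leaving it to the caller to extract the `precompiled_charsmap` bytes from a sentencepiece model:

use crate::tokenizer::{NormalizedString, Normalizer};
use spm_precompiled::Precompiled;

/// Normalize `text` with the charsmap bytes taken from an spm model.
fn normalize_like_spm(charsmap: &[u8], text: &str) -> String {
    // Assumed constructor from the spm_precompiled crate.
    let precompiled = Precompiled::from(charsmap).expect("valid precompiled_charsmap");
    let mut n = NormalizedString::from(text.to_string());
    // Runs through the grapheme-first / per-char fallback logic above.
    precompiled.normalize(&mut n).expect("normalization should not fail");
    n.get().to_string()
}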

@@ -36,7 +36,70 @@ impl Normalizer for NFKC {
     }
 }
+
+fn do_nmt(normalized: &mut NormalizedString) {
+    // Ascii Control characters
+    normalized
+        .filter(|c| match c as u32 {
+            0x0001..=0x0008 => false,
+            0x000B => false,
+            0x000E..=0x001F => false,
+            0x007F => false,
+            0x008F => false,
+            0x009F => false,
+            _ => true,
+        })
+        // Other code points considered as whitespace.
+        .map(|c| match c as u32 {
+            0x0009 => ' ',
+            0x000A => ' ',
+            0x000C => ' ',
+            0x000D => ' ',
+            0x1680 => ' ',
+            0x200B..=0x200F => ' ',
+            0x2028 => ' ',
+            0x2029 => ' ',
+            0x2581 => ' ',
+            0xFEFF => ' ',
+            0xFFFD => ' ',
+            _ => c,
+        });
+}
+
+#[derive(Default, Copy, Clone, Debug)]
+pub struct Nmt;
+impl Normalizer for Nmt {
+    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
+        do_nmt(normalized);
+        Ok(())
+    }
+}
+
 impl_serde_unit_struct!(NFCVisitor, NFC);
 impl_serde_unit_struct!(NFCKVisitor, NFKC);
 impl_serde_unit_struct!(NFKDVisitor, NFKD);
 impl_serde_unit_struct!(NFDVisitor, NFD);
+impl_serde_unit_struct!(NMTVisitor, Nmt);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_nfkc() {
+        let original = "\u{fb01}".to_string();
+        let normalized = "fi".to_string();
+        let mut n = NormalizedString::from(original.clone());
+        NFKC.normalize(&mut n).unwrap();
+        assert_eq!(
+            n,
+            NormalizedString::new(
+                original,
+                normalized,
+                vec![(0, 3), (0, 3)],
+                vec![(0, 2), (0, 2), (0, 2)],
+                0
+            )
+        )
+    }
+}
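
To make the new rules concrete, here is a small extra test in the same style as `test_nfkc` above (a sketch; the input and expected output are illustrative):

    #[test]
    fn test_nmt() {
        // U+0001 falls in the filtered control-character range and is dropped;
        // '\t' (0x0009) and U+2581 (sentencepiece's space marker) both map to ' '.
        let mut n = NormalizedString::from("a\u{0001}b\tc\u{2581}d".to_string());
        Nmt.normalize(&mut n).unwrap();
        assert_eq!(n.get(), "ab c d");
    }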

@@ -0,0 +1,144 @@
+use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script};
+use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result};
+
+#[derive(Clone, Debug)]
+pub struct UnicodeScripts;
+impl_serde_unit_struct!(UnicodeScriptsVisitor, UnicodeScripts);
+
+impl UnicodeScripts {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl Default for UnicodeScripts {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+// This code exists in the Unigram default IsValidSentencePiece.
+// It could be integrated directly within `get_script` but I
+// think it's kind of tricky to see those modifications later
+// I am guessing release mode will optimize this away anyway.
+fn fixed_script(c: char) -> Script {
+    let raw_script = get_script(c);
+    if c as u32 == 0x30FC {
+        Script::Han
+    } else if c == ' ' {
+        Script::Any
+    } else {
+        match raw_script {
+            Script::Hiragana => Script::Han,
+            Script::Katakana => Script::Han,
+            script => script,
+        }
+    }
+}
+
+impl PreTokenizer for UnicodeScripts {
+    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
+        pretokenized.split(|_, normalized| {
+            let mut last_script = None;
+            let mut offset = 0;
+            let mut ranges: Vec<_> = normalized
+                .get()
+                .chars()
+                .filter_map(|c| {
+                    let script = Some(fixed_script(c));
+                    let result = if script != Some(Script::Any)
+                        && last_script != Some(Script::Any)
+                        && last_script != script
+                    {
+                        Some(offset)
+                    } else {
+                        None
+                    };
+                    offset += c.len_utf8();
+                    if script != Some(Script::Any) {
+                        last_script = script;
+                    }
+                    result
+                })
+                .collect();
+            ranges.push(normalized.get().len());
+            Ok(ranges
+                .windows(2)
+                .map(|item| {
+                    normalized
+                        .slice(Range::Normalized(item[0]..item[1]))
+                        .expect("NormalizedString bad split")
+                })
+                .collect::<Vec<_>>())
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::OffsetReferential;
+
+    #[test]
+    fn basic() {
+        let pretok = UnicodeScripts::default();
+        let mut pretokenized = PreTokenizedString::from("どこで生れ。Yes");
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Normalized)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![("どこで生れ", (0, 15)), ("。", (15, 18)), ("Yes", (18, 21))]
+        );
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Original)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![("どこで生れ", (0, 15)), ("。", (15, 18)), ("Yes", (18, 21))]
+        );
+    }
+
+    #[test]
+    fn spaces_are_included_in_every_script() {
+        let pretok = UnicodeScripts::default();
+        let mut pretokenized = PreTokenizedString::from("Apples are りんご 林檎");
+        pretok.pre_tokenize(&mut pretokenized).unwrap();
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Normalized)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![("Apples are ", (0, 11)), ("りんご 林檎", (11, 27))]
+        );
+        assert_eq!(
+            pretokenized
+                .get_splits(OffsetReferential::Original)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>(),
+            vec![("Apples are ", (0, 11)), ("りんご 林檎", (11, 27))]
+        );
+    }
+
+    #[test]
+    fn test_unicode_script() {
+        assert_eq!(Script::Han, fixed_script('京'));
+        assert_eq!(Script::Han, fixed_script('太'));
+        assert_eq!(Script::Han, fixed_script('い'));
+        assert_eq!(Script::Han, fixed_script('グ'));
+        assert_eq!(Script::Han, fixed_script('ー'));
+        assert_eq!(Script::Latin, fixed_script('a'));
+        assert_eq!(Script::Latin, fixed_script('A'));
+        assert_eq!(Script::Common, fixed_script('0'));
+        assert_eq!(Script::Common, fixed_script('$'));
+        assert_eq!(Script::Common, fixed_script('@'));
+        assert_eq!(Script::Common, fixed_script('-'));
+        assert_eq!(Script::Any, fixed_script(' '));
+    }
+}
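
Two details of `fixed_script` drive the splits above: U+30FC (`ー`) is Common in the raw `get_script` table (see the scripts.rs tests below) but pinned to Han here so it stays attached to Japanese text, and ' ' maps to `Script::Any` so a space extends the chunk it follows instead of opening a new one. A sketch of an extra test for this module making that explicit (input chosen for illustration):

    #[test]
    fn spaces_never_open_a_split() {
        assert_eq!(Script::Han, fixed_script('ー'));
        assert_eq!(Script::Any, fixed_script(' '));

        let pretok = UnicodeScripts::default();
        let mut pretokenized = PreTokenizedString::from("abc 京 def");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized)
                .into_iter()
                .map(|(s, _, _)| s)
                .collect::<Vec<_>>(),
            // Trailing spaces stay glued to the chunk on their left.
            vec!["abc ", "京 ", "def"]
        );
    }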

@@ -144,8 +144,8 @@ pub enum Script {
     Yi,
 }
-pub fn get_script(c: &char) -> Script {
-    match *c as u32 {
+pub fn get_script(c: char) -> Script {
+    match c as u32 {
         0x0000..=0x001F => Script::Common,
         0x0020 => Script::Common,
         0x0021..=0x0023 => Script::Common,
@@ -2078,16 +2078,18 @@ mod tests {
     #[test]
     fn test_unicode_script() {
-        assert_eq!(Script::Han, get_script(&'京'));
-        assert_eq!(Script::Han, get_script(&'太'));
-        assert_eq!(Script::Hiragana, get_script(&'い'));
-        assert_eq!(Script::Katakana, get_script(&'グ'));
-        assert_eq!(Script::Common, get_script(&'ー'));
-        assert_eq!(Script::Latin, get_script(&'a'));
-        assert_eq!(Script::Latin, get_script(&'A'));
-        assert_eq!(Script::Common, get_script(&'0'));
-        assert_eq!(Script::Common, get_script(&'$'));
-        assert_eq!(Script::Common, get_script(&'@'));
-        assert_eq!(Script::Common, get_script(&'-'));
+        assert_eq!(Script::Han, get_script('京'));
+        assert_eq!(Script::Han, get_script('太'));
+        assert_eq!(Script::Hiragana, get_script('い'));
+        assert_eq!(Script::Katakana, get_script('グ'));
+        assert_eq!(Script::Common, get_script('ー'));
+        assert_eq!(Script::Latin, get_script('a'));
+        assert_eq!(Script::Latin, get_script('A'));
+        assert_eq!(Script::Common, get_script('0'));
+        assert_eq!(Script::Common, get_script('$'));
+        assert_eq!(Script::Common, get_script('@'));
+        assert_eq!(Script::Common, get_script('-'));
+        assert_eq!(Script::Common, get_script(' '));
+        assert_eq!(Script::Common, get_script('�'));
     }
 }

@@ -124,6 +124,22 @@ pub struct NormalizedString {
 }
 impl NormalizedString {
+    #[cfg(test)]
+    pub(crate) fn new(
+        original: String,
+        normalized: String,
+        alignments: Vec<(usize, usize)>,
+        alignments_original: Vec<(usize, usize)>,
+        original_shift: usize,
+    ) -> Self {
+        Self {
+            original,
+            normalized,
+            alignments,
+            alignments_original,
+            original_shift,
+        }
+    }
+
     /// Return the normalized string
     pub fn get(&self) -> &str {
         &self.normalized
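
The test-only constructor makes the alignment bookkeeping easy to read off `test_nfkc` above: "\u{fb01}" is a single 3-byte character, its NFKC expansion "fi" is two 1-byte characters, `alignments` carries one range into the original string per normalized byte, and `alignments_original` one range into the normalized string per original byte. A sketch restating that (usable only inside the crate's own tests, since `new` is `pub(crate)` and `#[cfg(test)]`):

    // "ﬁ" (U+FB01) occupies bytes 0..3 of the original, "fi" bytes 0..2 of the
    // normalized string, so every normalized byte maps back to (0, 3) and every
    // original byte maps to (0, 2).
    let n = NormalizedString::new(
        "\u{fb01}".to_string(),
        "fi".to_string(),
        vec![(0, 3), (0, 3)],
        vec![(0, 2), (0, 2), (0, 2)],
        0, // original_shift, as in the test above
    );
    assert_eq!(n.get(), "fi");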