Mirror of https://github.com/mii443/tokenizers.git (synced 2025-12-05 04:08:22 +00:00)
Improvements on spm parity (#401)
* Removing all pre_tokenizer logic from the Unigram algorithm.
* Improving the parity check *a lot*:
  - We can now detect many more errors.
  - Special cases have been added temporarily.
* Adding 2 new normalizers that mimic spm's default behavior.
* Adding an `encoding_optimized` version of the `encode` algorithm:
  - Removes the Lattice allocation.
  - Changes the trie's `common_prefix_search` to return an iterator, avoiding allocation of the full results.
* Trie<char> -> Trie<u8>: another speed improvement.
* [WIP] Attempt to create a Precompiled Normalizer from SPM to be 100% compliant with arbitrary models.
* Adding a new `Precompiled` normalizer that replaces `SpmNmtNfkc`:
  - It will be used for direct compatibility with `Spm`, replacing all their custom rules by using the normalizer spec embedded within spm files directly, removing any need for rules on our side.
  - We need the `nom` dependency to parse the binary format of `spm`.
  - We need to add the `sentencepiece_model_pb2.py` file to be able to read the proto file.
  - We reimplemented their `Darts::DoubleArray` compact trie format.
* Fixing a bug with the Precompiled normalizer.
* Fixing some edge cases (now in tests) with this weird precompiled normalizer. It seems even a very hand-crafted trie does not prevent shooting oneself in the foot. Sorry, future reader.
* Keeping the API stable for this PR (API changes should come later, #409).
  - Removed sentencepiece_model_pb2 from the binding and added instructions to make `from_spm` work.
* Adding a model check in `from_spm`.
* Addressing @n1t0's comments.
* Adding a check to make sure alignments stay correct. Also added a bit more documentation on how Precompiled works.
* Extracting `Precompiled` into its own `spm_precompiled` crate.
* Using ranges in `do_nmt`.
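For readers skimming the diff below: the user-visible effect on `SentencePieceUnigramTokenizer` is that its default normalizer becomes a `Sequence` of the new `Nmt` normalizer followed by `NFKC`, instead of plain `NFKC`. A minimal sketch of that setup through the Python bindings as they stand in this PR (the vocab file name is a hypothetical placeholder):

    from tokenizers import Tokenizer, normalizers, pre_tokenizers
    from tokenizers.models import Unigram

    # "unigram.json" is a placeholder for a vocab file in the format Unigram loads here.
    tokenizer = Tokenizer(Unigram("unigram.json"))

    # New spm-like default: Nmt cleanup, then NFKC (previously just NFKC).
    tokenizer.normalizer = normalizers.Sequence([normalizers.Nmt(), normalizers.NFKC()])

    # Pre-tokenization stays the same: whitespace split, then Metaspace with "▁".
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
        [
            pre_tokenizers.WhitespaceSplit(),
            pre_tokenizers.Metaspace(replacement="▁", add_prefix_space=True),
        ]
    )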
54 bindings/python/Cargo.lock generated
@@ -16,6 +16,11 @@ dependencies = [
 "winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "arrayvec"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "atty"
version = "0.2.14"
@@ -350,6 +355,18 @@ name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "lexical-core"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "arrayvec 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
 "bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
 "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
 "ryu 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)",
 "static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "libc"
version = "0.2.77"
@@ -409,6 +426,16 @@ dependencies = [
 "rawpointer 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "nom"
version = "5.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "lexical-core 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)",
 "memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)",
 "version_check 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "num-complex"
version = "0.2.4"
@@ -741,6 +768,21 @@ name = "smallvec"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "spm_precompiled"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "nom 5.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
 "serde 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)",
 "unicode-segmentation 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "strsim"
version = "0.8.0"
@@ -834,6 +876,7 @@ dependencies = [
 "regex-syntax 0.6.18 (registry+https://github.com/rust-lang/crates.io-index)",
 "serde 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)",
 "serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)",
 "spm_precompiled 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
 "unicode-normalization-alignments 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)",
 "unicode-segmentation 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
 "unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -893,6 +936,11 @@ name = "vec_map"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "version_check"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "wasi"
version = "0.9.0+wasi-snapshot-preview1"
@@ -928,6 +976,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[metadata]
"checksum aho-corasick 0.7.13 (registry+https://github.com/rust-lang/crates.io-index)" = "043164d8ba5c4c3035fec9bbee8647c0261d788f3474306f93bb65901cae0e86"
"checksum ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b"
"checksum arrayvec 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cff77d8686867eceff3105329d4698d96c2391c176d5d03adc90c7389162b5b8"
"checksum atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
"checksum autocfg 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
"checksum bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
@@ -966,6 +1015,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum itertools 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b"
"checksum itoa 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)" = "dc6f3ad7b9d11a0c00842ff8de1b60ee58661048eb8049ed33c73594f359d7e6"
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
"checksum lexical-core 0.7.4 (registry+https://github.com/rust-lang/crates.io-index)" = "db65c6da02e61f55dae90a0ae427b2a5f6b3e8db09f58d10efab23af92592616"
"checksum libc 0.2.77 (registry+https://github.com/rust-lang/crates.io-index)" = "f2f96b10ec2560088a8e76961b00d47107b3a625fecb76dedb29ee7ccbf98235"
"checksum lock_api 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "28247cc5a5be2f05fbcd76dd0cf2c7d3b5400cb978a28042abcd4fa0b3f8261c"
"checksum log 0.4.11 (registry+https://github.com/rust-lang/crates.io-index)" = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b"
@@ -974,6 +1024,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum memchr 2.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3728d817d99e5ac407411fa471ff9800a778d88a24685968b36824eaf4bee400"
"checksum memoffset 0.5.5 (registry+https://github.com/rust-lang/crates.io-index)" = "c198b026e1bbf08a937e94c6c60f9ec4a2267f5b0d2eec9c1b21b061ce2be55f"
"checksum ndarray 0.13.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ac06db03ec2f46ee0ecdca1a1c34a99c0d188a0d83439b84bf0cb4b386e4ab09"
"checksum nom 5.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "ffb4262d26ed83a1c0a33a38fe2bb15797329c85770da05e6b828ddb782627af"
"checksum num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
"checksum num-integer 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "8d59457e662d541ba17869cf51cf177c0b5f0cbf476c66bdc90bf1edac4f875b"
"checksum num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)" = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
@@ -1013,6 +1064,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum serde_derive 1.0.116 (registry+https://github.com/rust-lang/crates.io-index)" = "f630a6370fd8e457873b4bd2ffdae75408bc291ba72be773772a4c2a065d9ae8"
"checksum serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)" = "164eacbdb13512ec2745fb09d51fd5b22b0d65ed294a1dcf7285a360c80a675c"
"checksum smallvec 1.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "fbee7696b84bbf3d89a1c2eccff0850e3047ed46bfcd2e92c29a2d074d57e252"
"checksum spm_precompiled 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f78be885c9efc899a7c0348f67c98b488cbeaf2cb608a48fb87ef1484ecab5c5"
"checksum static_assertions 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
"checksum strsim 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
"checksum syn 1.0.40 (registry+https://github.com/rust-lang/crates.io-index)" = "963f7d3cc59b59b9325165add223142bbf1df27655d07789f109896d353d8350"
@@ -1029,6 +1082,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
"checksum unindent 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "af41d708427f8fd0e915dcebb2cae0f0e6acb2a939b2d399c265c39a38a18942"
"checksum vec_map 0.8.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
"checksum version_check 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)" = "b5a972e5669d67ba988ce3dc826706fb0a8b01471c088cb0b6110b805cc36aed"
"checksum wasi 0.9.0+wasi-snapshot-preview1 (registry+https://github.com/rust-lang/crates.io-index)" = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
"checksum winapi 0.3.9 (registry+https://github.com/rust-lang/crates.io-index)" = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
"checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
@@ -1,6 +1,14 @@
from tokenizers import Tokenizer, AddedToken, pre_tokenizers, decoders, trainers
from tokenizers import (
    Tokenizer,
    AddedToken,
    pre_tokenizers,
    decoders,
    trainers,
    normalizers,
)
import os
from tokenizers.models import Unigram
from tokenizers.normalizers import NFKC
import json
from .base_tokenizer import BaseTokenizer

from typing import Optional, List, Union
@@ -16,11 +24,12 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
        self, vocab: Optional[str] = None, replacement: str = "▁", add_prefix_space: bool = True,
    ):
        if vocab is not None:
            # Let Unigram(..) fail if only one of them is None
            tokenizer = Tokenizer(Unigram(vocab))
        else:
            tokenizer = Tokenizer(Unigram())

        tokenizer.normalizer = NFKC()
        tokenizer.normalizer = normalizers.Sequence([normalizers.Nmt(), normalizers.NFKC(),])
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
@@ -57,3 +66,63 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
        if isinstance(files, str):
            files = [files]
        self._tokenizer.train(trainer, files)

    @staticmethod
    def from_spm(filename: str):
        try:
            import sys

            sys.path.append(".")

            import sentencepiece_model_pb2 as model
        except Exception:
            raise Exception(
                "You don't seem to have the required protobuf file, in order to use this function you need to run `pip install protobuf` and `wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_model_pb2.py` for us to be able to read the intrinsics of your spm_file. `pip install sentencepiece` is not required."
            )

        m = model.ModelProto()
        m.ParseFromString(open(filename, "rb").read())

        precompiled_charsmap = m.normalizer_spec.precompiled_charsmap
        vocab = [(piece.piece, piece.score) for piece in m.pieces]
        unk_id = m.trainer_spec.unk_id
        model_type = m.trainer_spec.model_type
        if model_type != 1:
            raise Exception(
                "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
            )

        data = {"unk_id": unk_id, "vocab": vocab}

        replacement = "▁"
        add_prefix_space = True

        out_vocab_filename = f"{filename}.json"
        try:
            with open(out_vocab_filename, "w") as f:
                json.dump(data, f, indent=4)

            tokenizer = Tokenizer(Unigram(out_vocab_filename))
        finally:
            os.remove(out_vocab_filename)

        tokenizer.normalizer = normalizers.Precompiled(precompiled_charsmap)
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
            [
                pre_tokenizers.WhitespaceSplit(),
                pre_tokenizers.Metaspace(
                    replacement=replacement, add_prefix_space=add_prefix_space
                ),
            ]
        )
        tokenizer.decoder = decoders.Metaspace(
            replacement=replacement, add_prefix_space=add_prefix_space
        )

        parameters = {
            "model": "SentencePieceUnigram",
        }

        obj = BaseTokenizer.__new__(SentencePieceUnigramTokenizer, tokenizer, parameters)
        BaseTokenizer.__init__(obj, tokenizer, parameters)
        return obj
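A hedged usage sketch for the `from_spm` helper added above. The model file name is a placeholder, and the protobuf stub has to be fetched into the working directory first, exactly as the helper's error message describes:

    # One-time setup (per the error message above):
    #   pip install protobuf
    #   wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_model_pb2.py
    from tokenizers import SentencePieceUnigramTokenizer

    # "spm.model" is a placeholder for any Unigram-trained SentencePiece model.
    tok = SentencePieceUnigramTokenizer.from_spm("spm.model")

    # The returned tokenizer carries the spm file's own Precompiled normalizer,
    # so its ids should line up with SentencePiece's on the same input.
    print(tok.encode("Hello world").ids)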
@@ -9,6 +9,8 @@ NFKC = normalizers.NFKC
Sequence = normalizers.Sequence
Lowercase = normalizers.Lowercase
Strip = normalizers.Strip
Nmt = normalizers.Nmt
Precompiled = normalizers.Precompiled


NORMALIZERS = {"nfc": NFC, "nfd": NFD, "nfkc": NFKC, "nfkd": NFKD}
@@ -99,6 +99,18 @@ class Strip(Normalizer):
    def __init__(self, left: bool = True, right: bool = True) -> Normalizer:
        pass

class Nmt(Normalizer):
    """ Nmt normalizer """

    def __init__(self) -> Normalizer:
        pass

class Precompiled(Normalizer):
    """ SpmNmtNfkc normalizer """

    def __init__(self, precompiled_charsmap: bytes) -> Normalizer:
        pass

def unicode_normalizer_from_str(normalizer: str) -> Normalizer:
    """
    Instanciate unicode normalizer from the normalizer name
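The stub above is the whole Python surface of the two new normalizers. A minimal sketch of using `Precompiled` directly, assuming the same `sentencepiece_model_pb2.py` stub as in `from_spm` and a placeholder model path:

    import sys

    sys.path.append(".")  # sentencepiece_model_pb2.py is expected in the working directory
    import sentencepiece_model_pb2 as model

    from tokenizers import normalizers

    m = model.ModelProto()
    with open("spm.model", "rb") as f:  # placeholder path
        m.ParseFromString(f.read())

    # The bytes of the embedded normalizer spec drive Precompiled; no hand-written rules needed.
    norm = normalizers.Precompiled(m.normalizer_spec.precompiled_charsmap)
    # ... then attach it to a tokenizer, exactly as from_spm does:
    # tokenizer.normalizer = norm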
@@ -1,7 +1,17 @@
import tokenizers
from argparse import ArgumentParser
import sentencepiece as spm
from collections import Counter
import json
import os
import datetime

try:
    from termcolor import colored

    has_color = True
except Exception:
    has_color = False


def main():
@@ -9,38 +19,62 @@ def main():
    parser.add_argument(
        "--input-file", "-i", type=str, required=True, help="Which files do you want to train from",
    )
    parser.add_argument(
        "--model-file",
        "-m",
        type=str,
        required=False,
        default=None,
        help="Use a pretrained token file",
    )
    parser.add_argument(
        "--model-prefix", type=str, default="spm_parity", help="Model prefix for spm_train",
    )
    parser.add_argument(
        "--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Verbosity",
    )
    parser.add_argument(
        "--train",
        action="store_true",
        help="Instead of checking the encoder part, we check the trainer part",
    )
    parser.add_argument(
        "--from-spm",
        action="store_true",
        help="Directly load the spm file with it's own normalizer",
    )

    args = parser.parse_args()

    spm.SentencePieceTrainer.Train(
        f"--input={args.input_file} --model_prefix={args.model_prefix}"
        f" --character_coverage=1.0"
        f" --max_sentence_length=40000"
        f" --num_threads=1"
        f" --vocab_size={args.vocab_size}"
    )
    trained = False
    if args.model_file is None:
        spm.SentencePieceTrainer.Train(
            f"--input={args.input_file} --model_prefix={args.model_prefix}"
            f" --character_coverage=1.0"
            f" --max_sentence_length=40000"
            f" --num_threads=1"
            f" --vocab_size={args.vocab_size}"
        )
        trained = True
        args.model_file = f"{args.model_prefix}.model"

    if args.train:
        check_train(args)
    else:
        check_encode(args)
    try:
        if args.train:
            check_train(args)
        else:
            check_encode(args)
    finally:
        if trained:
            os.remove(f"{args.model_prefix}.model")
            os.remove(f"{args.model_prefix}.vocab")


def check_train(args):
    sp = spm.SentencePieceProcessor()
    model_filename = f"{args.model_prefix}.model"
    sp.Load(model_filename)
    sp.Load(args.model_file)

    tokenizer = tokenizers.SentencePieceUnigramTokenizer()
    tokenizer.train(args.input_file, show_progress=False)
@@ -77,38 +111,144 @@ def check_train(args):
    assert (
        tokenizer_tokens < spm_tokens
    ), "Our trainer should be at least more efficient than the SPM one"
    print("Ok our trainer is at least more efficient than the SPM one")


def check_diff(spm_diff, tok_diff, sp, tok):
    if spm_diff == list(reversed(tok_diff)):
        # AAA -> AA+A vs A+AA case.
        return True
    elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(tok_diff):
        # Second order OK
        # Barrich -> Barr + ich vs Bar + rich
        return True
    spm_reencoded = sp.encode(sp.decode(spm_diff))
    tok_reencoded = tok.encode(tok.decode(spm_diff)).ids
    if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
        # Type 3 error.
        # Snehagatha ->
        #       Sne, h, aga, th, a
        #       Sne, ha, gat, ha
        # Encoding the wrong with sp does not even recover what spm gave us
        # It fits tokenizer however...
        return True
    return False


def check_details(line, spm_ids, tok_ids, tok, sp):
    # Encoding can be the same with same result AAA -> A + AA vs AA + A
    # We can check that we use at least exactly the same number of tokens.
    for i, (spm_id, tok_id) in enumerate(zip(spm_ids, tok_ids)):
        if spm_id != tok_id:
            break
    first = i
    for i, (spm_id, tok_id) in enumerate(zip(reversed(spm_ids), reversed(tok_ids))):
        if spm_id != tok_id:
            break
    last = len(spm_ids) - i

    spm_diff = spm_ids[first:last]
    tok_diff = tok_ids[first:last]

    if check_diff(spm_diff, tok_diff, sp, tok):
        return True

    if last - first > 5:
        # We might have twice a single problem, attempt to subdivide the disjointed tokens into smaller problems
        spms = Counter(spm_ids[first:last])
        toks = Counter(tok_ids[first:last])

        removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
        min_width = 3
        for i in range(last - first - min_width):
            if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
                possible_matches = [
                    k
                    for k in range(last - first - min_width)
                    if tok_ids[first + k : first + k + min_width]
                    == spm_ids[first + i : first + i + min_width]
                ]
                for j in possible_matches:
                    if check_diff(
                        spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
                    ) and check_diff(spm_ids[first + i : last], tok_ids[first + j : last], sp, tok):
                        return True

    ok_start = tok.decode(spm_ids[:first])
    ok_end = tok.decode(spm_ids[last:])
    wrong = tok.decode(spm_ids[first:last])
    print()
    if has_color:
        print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
    else:
        print(wrong)

    print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}")
    print(f"Tok: {[tok.decode([tok_ids[i]]) for i in range(first, last)]}")
    return False


def check_encode(args):
    sp = spm.SentencePieceProcessor()
    model_filename = f"{args.model_prefix}.model"
    sp.Load(model_filename)
    sp.Load(args.model_file)

    vocab_filename = f"{args.model_prefix}.json"
    if args.from_spm:
        tok = tokenizers.SentencePieceUnigramTokenizer.from_spm(args.model_file)
    else:
    vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
        vocab_filename = f"{args.model_file}.json"
        unk_id = sp.unk_id()

        vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
        data = {"unk_id": unk_id, "vocab": vocab}
        try:
            with open(vocab_filename, "w") as f:
                json.dump(data, f, indent=4)

    data = {"unk_id": sp.unk_id(), "vocab": vocab}
            tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
        finally:
            os.remove(vocab_filename)

    with open(vocab_filename, "w") as f:
        json.dump(data, f, indent=4)

    tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
    with open(args.input_file, "r") as f:
    perfect = 0
    imperfect = 0
    wrong = 0
    now = datetime.datetime.now
    spm_total_time = datetime.timedelta(seconds=0)
    tok_total_time = datetime.timedelta(seconds=0)
    with open(args.input_file, "r", encoding="utf-8-sig") as f:
        for i, line in enumerate(f):
            line = line.strip()

            start = now()
            ids = sp.EncodeAsIds(line)
            spm_time = now()

            encoded = tok.encode(line)
            tok_time = now()

            spm_total_time += spm_time - start
            tok_total_time += tok_time - spm_time

            if args.verbose:
                if i % 10000 == 0:
                    print(
                        f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})"
                    )
                    print(f"SPM: {spm_total_time} - TOK: {tok_total_time}")

            if ids != encoded.ids:
                # Encoding can be the same with same result AAA -> A + AA vs AA + A
                # We can check that we use at least exactly the same number of tokens.
                assert len(ids) == len(encoded.ids)
                continue
                if check_details(line, ids, encoded.ids, tok, sp):
                    imperfect += 1
                    continue
                else:
                    wrong += 1
            else:
                perfect += 1

            assert ids == encoded.ids, f"line {i}: {line} : {ids} != {encoded.ids}"

    print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
    total = perfect + imperfect + wrong
    print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}")


if __name__ == "__main__":
    main()
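To make the diff-localisation step in `check_details` above concrete, here is a small self-contained sketch with made-up id sequences; it reproduces only the common prefix/suffix trimming and the first-order `reversed` check, not the second- and third-order fallbacks:

    def trim_common(spm_ids, tok_ids):
        # Strip the longest common prefix and suffix so only the disagreement remains.
        first = 0
        while first < min(len(spm_ids), len(tok_ids)) and spm_ids[first] == tok_ids[first]:
            first += 1
        last = 0
        while (
            last < min(len(spm_ids), len(tok_ids)) - first
            and spm_ids[-1 - last] == tok_ids[-1 - last]
        ):
            last += 1
        return spm_ids[first : len(spm_ids) - last], tok_ids[first : len(tok_ids) - last]


    # Made-up ids: think "AAA" segmented as AA+A by spm and A+AA by tokenizers.
    spm_ids = [7, 12, 3, 9]
    tok_ids = [7, 3, 12, 9]
    spm_diff, tok_diff = trim_common(spm_ids, tok_ids)
    assert spm_diff == [12, 3] and tok_diff == [3, 12]
    # First-order acceptable difference: one diff is the reverse of the other.
    assert spm_diff == list(reversed(tok_diff))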
@@ -15,7 +15,7 @@ setup(
    author_email="anthony@huggingface.co",
    url="https://github.com/huggingface/tokenizers",
    license="Apache License 2.0",
    rust_extensions=[RustExtension("tokenizers.tokenizers", binding=Binding.PyO3)],
    rust_extensions=[RustExtension("tokenizers.tokenizers", binding=Binding.PyO3, debug=False)],
    extras_require=extras,
    classifiers=[
        "Development Status :: 5 - Production/Stable",
@@ -108,6 +108,8 @@ fn normalizers(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<normalizers::PySequence>()?;
    m.add_class::<normalizers::PyLowercase>()?;
    m.add_class::<normalizers::PyStrip>()?;
    m.add_class::<normalizers::PyNmt>()?;
    m.add_class::<normalizers::PyPrecompiled>()?;
    Ok(())
}
@@ -263,21 +263,19 @@ pub struct PyUnigram {}
#[pymethods]
impl PyUnigram {
    #[new]
    fn new(vocab: Option<&str>) -> PyResult<(Self, PyModel)> {
        if let Some(vocab) = vocab {
            let path = Path::new(vocab);
            match Unigram::load(path) {
    fn new(vocab: Option<String>) -> PyResult<(Self, PyModel)> {
        match vocab {
            Some(vocab) => match Unigram::load(&std::path::Path::new(&vocab)) {
                Err(e) => {
                    println!("Errors: {:?}", e);
                    Err(exceptions::Exception::py_err("Error while loading Unigram"))
                }
                Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
            }
        } else {
            Ok((
            },
            None => Ok((
                PyUnigram {},
                PyModel::new(Arc::new(Unigram::default().into())),
            ))
            )),
        }
    }
}
@@ -7,7 +7,9 @@ use pyo3::types::*;
use crate::error::ToPyResult;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize, Serializer};
use tk::normalizers::{BertNormalizer, Lowercase, NormalizerWrapper, Strip, NFC, NFD, NFKC, NFKD};
use tk::normalizers::{
    BertNormalizer, Lowercase, Nmt, NormalizerWrapper, Precompiled, Strip, NFC, NFD, NFKC, NFKD,
};
use tk::{NormalizedString, Normalizer};
use tokenizers as tk;
@@ -45,6 +47,10 @@ impl PyNormalizer {
                NormalizerWrapper::Lowercase(_) => {
                    Py::new(py, (PyLowercase {}, base)).map(Into::into)
                }
                NormalizerWrapper::Precompiled(_) => {
                    Py::new(py, (PyPrecompiled {}, base)).map(Into::into)
                }
                NormalizerWrapper::Nmt(_) => Py::new(py, (PyNmt {}, base)).map(Into::into),
            },
        }
    }
@@ -273,6 +279,37 @@ impl Normalizer for PyNormalizerWrapper {
    }
}

#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Nmt)]
pub struct PyNmt {}
#[pymethods]
impl PyNmt {
    #[new]
    fn new() -> PyResult<(Self, PyNormalizer)> {
        Ok((PyNmt {}, Nmt.into()))
    }
}

#[pyclass(extends=PyNormalizer, module = "tokenizers.normalizers", name=Precompiled)]
pub struct PyPrecompiled {}
#[pymethods]
impl PyPrecompiled {
    #[new]
    fn new(py_precompiled_charsmap: &PyBytes) -> PyResult<(Self, PyNormalizer)> {
        let precompiled_charsmap: &[u8] = FromPyObject::extract(py_precompiled_charsmap)?;
        Ok((
            PyPrecompiled {},
            Precompiled::from(precompiled_charsmap)
                .map_err(|e| {
                    exceptions::Exception::py_err(format!(
                        "Error while attempting to build Precompiled normalizer: {}",
                        e.to_string()
                    ))
                })?
                .into(),
        ))
    }
}

#[cfg(test)]
mod test {
    use pyo3::{AsPyRef, Python};