Failing test for compatibility of SentencePieceUnigramTokenizer.

- We are failing on ambiguous tokenizations (AAA -> A + AA vs. AA + A). This could be linked to float precision and may be hard or impossible to fix (it should not hinder model performance).
- We now enable fusing_unk by default, as is the case with spm_train.
- We are still failing on at least space deduplication; this should probably be handled by a pre-tokenizer.
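As an illustration of the tie (with made-up scores, not taken from any real model): a unigram model scores a segmentation by summing the log-probabilities of its pieces, so two distinct splits can total exactly the same score and float precision ends up deciding which one the search returns.

# Hypothetical unigram log-probabilities; illustration only.
scores = {"A": -2.5, "AA": -4.0}

# Both segmentations of "AAA" have the same total score:
left = scores["A"] + scores["AA"]   # A + AA
right = scores["AA"] + scores["A"]  # AA + A
assert left == right == -6.5

# On longer sentences the two sums accumulate in different orders, so
# tiny float rounding differences can pick either split arbitrarily.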
96 bindings/python/Cargo.lock (generated)
@@ -137,6 +137,61 @@ dependencies = [
 "syn 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "darling"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "darling_core 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)",
 "darling_macro 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "darling_core"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "fnv 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)",
 "ident_case 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
 "proc-macro2 1.0.19 (registry+https://github.com/rust-lang/crates.io-index)",
 "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)",
 "strsim 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)",
 "syn 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "darling_macro"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "darling_core 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)",
 "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)",
 "syn 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "derive_builder"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "darling 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)",
 "derive_builder_core 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
 "proc-macro2 1.0.19 (registry+https://github.com/rust-lang/crates.io-index)",
 "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)",
 "syn 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "derive_builder_core"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "darling 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)",
 "proc-macro2 1.0.19 (registry+https://github.com/rust-lang/crates.io-index)",
 "quote 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)",
 "syn 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "either"
version = "1.5.3"
@@ -159,6 +214,19 @@ dependencies = [
 "termcolor 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "esaxx-rs"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
 "cc 1.0.58 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "fnv"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "getrandom"
version = "0.1.14"
@@ -195,6 +263,11 @@ dependencies = [
 "quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "indicatif"
version = "0.14.0"
@@ -674,6 +747,11 @@ name = "strsim"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "strsim"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "syn"
version = "1.0.37"
@@ -743,6 +821,8 @@ name = "tokenizers"
version = "0.10.1"
dependencies = [
 "clap 2.33.1 (registry+https://github.com/rust-lang/crates.io-index)",
 "derive_builder 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
 "esaxx-rs 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
 "indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)",
 "itertools 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
 "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -756,6 +836,7 @@ dependencies = [
 "serde 1.0.114 (registry+https://github.com/rust-lang/crates.io-index)",
 "serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)",
 "unicode-normalization-alignments 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)",
 "unicode-segmentation 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
 "unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
@@ -783,6 +864,11 @@ dependencies = [
 "smallvec 1.4.1 (registry+https://github.com/rust-lang/crates.io-index)",
]

[[package]]
name = "unicode-segmentation"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"

[[package]]
name = "unicode-width"
version = "0.1.8"
@@ -856,13 +942,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum crossbeam-queue 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "774ba60a54c213d409d5353bda12d49cd68d14e45036a285234c8d6f91f92570"
"checksum crossbeam-utils 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3c7c73a2d1e9fc0886a08b93e98eb643461230d5f1925e4036204d5f2e261a8"
"checksum ctor 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)" = "39858aa5bac06462d4dd4b9164848eb81ffc4aa5c479746393598fd193afa227"
"checksum darling 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)" = "0d706e75d87e35569db781a9b5e2416cff1236a47ed380831f959382ccd5f858"
"checksum darling_core 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f0c960ae2da4de88a91b2d920c2a7233b400bc33cb28453a2987822d8392519b"
"checksum darling_macro 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b5a2f4ac4969822c62224815d069952656cadc7084fdca9751e6d959189b72"
"checksum derive_builder 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a2658621297f2cf68762a6f7dc0bb7e1ff2cfd6583daef8ee0fed6f7ec468ec0"
"checksum derive_builder_core 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2791ea3e372c8495c0bc2033991d76b512cd799d07491fbd6890124db9458bef"
"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
"checksum encode_unicode 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
"checksum env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "44533bbbb3bb3c1fa17d9f2e4e38bbbaf8396ba82193c4cb1b6445d711445d36"
"checksum esaxx-rs 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a3f0bf221d15f92461d05eea094c77aec5a00e3574740159e178beab2c58ea64"
"checksum fnv 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)" = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
"checksum getrandom 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)" = "7abc8dd8451921606d809ba32e95b6111925cd2906060d2dcc29c070220503eb"
"checksum ghost 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "1a5bcf1bbeab73aa4cf2fde60a846858dc036163c7c33bec309f8d17de785479"
"checksum hermit-abi 0.1.15 (registry+https://github.com/rust-lang/crates.io-index)" = "3deed196b6e7f9e44a2ae8d94225d80302d81208b1bb673fd21fe634645c85a9"
"checksum humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
"checksum ident_case 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
"checksum indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)" = "49a68371cf417889c9d7f98235b7102ea7c54fc59bcbd22f3dea785be9d27e40"
"checksum indoc 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "47741a8bc60fb26eb8d6e0238bbb26d8575ff623fdc97b1a2c00c050b9684ed8"
"checksum indoc-impl 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "ce046d161f000fffde5f432a0d034d0341dc152643b2598ed5bfce44c4f3a8f0"
@@ -921,6 +1015,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum serde_json 1.0.57 (registry+https://github.com/rust-lang/crates.io-index)" = "164eacbdb13512ec2745fb09d51fd5b22b0d65ed294a1dcf7285a360c80a675c"
"checksum smallvec 1.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "3757cb9d89161a2f24e1cf78efa0c1fcff485d18e3f55e0aa3480824ddaa0f3f"
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
"checksum strsim 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
"checksum syn 1.0.37 (registry+https://github.com/rust-lang/crates.io-index)" = "239f255b9e3429350f188c27b807fc9920a15eb9145230ff1a7d054c08fec319"
"checksum tempfile 3.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6e24d9338a0a5be79593e2fa15a648add6138caa803e2d5bc782c371732ca9"
"checksum termcolor 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "bb6bfa289a4d7c5766392812c0a1f4c1ba45afa1ad47803c11e1f407d846d75f"
@@ -929,6 +1024,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
"checksum thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14"
"checksum unicode-normalization-alignments 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)" = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de"
"checksum unicode-segmentation 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e83e153d1053cbb5a118eeff7fd5be06ed99153f00dbcd8ae310c5fb2b22edc0"
"checksum unicode-width 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)" = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3"
"checksum unicode-xid 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564"
"checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
@@ -8,7 +8,8 @@ TextInputSequence = str
 PreTokenizedInputSequence = Union[List[str], Tuple[str]]
 TextEncodeInput = Union[TextInputSequence, Tuple[TextInputSequence, TextInputSequence]]
 PreTokenizedEncodeInput = Union[
-    PreTokenizedInputSequence, Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]
+    PreTokenizedInputSequence,
+    Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence],
 ]

 InputSequence = Union[TextInputSequence, PreTokenizedInputSequence]
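For clarity, some values that satisfy these aliases (illustrative only; the aliases are mirrored locally rather than imported, so this runs standalone):

from typing import List, Tuple, Union

# Mirroring the stub's aliases for illustration.
TextInputSequence = str
PreTokenizedInputSequence = Union[List[str], Tuple[str]]

text_input: TextInputSequence = "Hello world"
pretok_input: PreTokenizedInputSequence = ["Hello", "world"]
# A pair of sequences, as accepted by encode() for sentence pairs:
pair = ("Hello world", "Bonjour le monde")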
@@ -25,5 +26,6 @@ from .implementations import (
    ByteLevelBPETokenizer,
    CharBPETokenizer,
    SentencePieceBPETokenizer,
    SentencePieceUnigramTokenizer,
    BertWordPieceTokenizer,
)
@@ -2,4 +2,5 @@ from .base_tokenizer import BaseTokenizer
from .byte_level_bpe import ByteLevelBPETokenizer
from .char_level_bpe import CharBPETokenizer
from .sentencepiece_bpe import SentencePieceBPETokenizer
from .sentencepiece_unigram import SentencePieceUnigramTokenizer
from .bert_wordpiece import BertWordPieceTokenizer
@@ -0,0 +1,40 @@
from tokenizers import Tokenizer, AddedToken, pre_tokenizers, decoders, trainers
from tokenizers.models import Unigram
from tokenizers.normalizers import NFKC
from .base_tokenizer import BaseTokenizer

from typing import Optional, List, Union


class SentencePieceUnigramTokenizer(BaseTokenizer):
    """SentencePiece Unigram Tokenizer

    Represents the Unigram algorithm, with the pretokenization used by SentencePiece
    """

    def __init__(
        self,
        vocab: Optional[str] = None,
        replacement: str = "▁",
        add_prefix_space: bool = True,
    ):
        if vocab is not None:
            tokenizer = Tokenizer(Unigram(vocab))
        else:
            tokenizer = Tokenizer(Unigram())

        tokenizer.normalizer = NFKC()
        tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
            replacement=replacement, add_prefix_space=add_prefix_space
        )
        tokenizer.decoder = decoders.Metaspace(
            replacement=replacement, add_prefix_space=add_prefix_space
        )

        parameters = {
            "model": "SentencePieceUnigram",
            "replacement": replacement,
            "add_prefix_space": add_prefix_space,
        }

        super().__init__(tokenizer, parameters)
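A quick usage sketch of the class above; the vocab path is a placeholder, pointing at a JSON file in the format produced by the parity script further down:

from tokenizers import SentencePieceUnigramTokenizer

# "unigram.json" is a hypothetical converted vocab file.
tok = SentencePieceUnigramTokenizer("unigram.json")
encoded = tok.encode("Hello world")
print(encoded.tokens)  # e.g. ['▁Hello', '▁world'], depending on the vocab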
@@ -6,3 +6,4 @@ Model = models.Model
BPE = models.BPE
WordPiece = models.WordPiece
WordLevel = models.WordLevel
Unigram = models.Unigram
@@ -100,3 +100,25 @@ class WordLevel(Model):

    def __init__(self, vocab: Optional[str], unk_token: Optional[str]):
        pass


class Unigram(Model):
    """UnigramEncoding model class

    Instantiate a Unigram Model from the given model file.

    Args:
        vocab: (`optional`) string:
            Path to a vocabulary JSON file.

        is_spm_file: (`optional`) bool:
            If the file came out of sentencepiece, we need to load it differently.
    """

    def __init__(
        self,
        vocab: Optional[str],
        is_spm_file: Optional[bool],
    ):
        pass
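A hedged sketch of constructing the model directly, per the signature above (both file names are placeholders):

from tokenizers import Tokenizer
from tokenizers.models import Unigram

# From a converted JSON vocabulary:
tokenizer = Tokenizer(Unigram("unigram.json"))

# Or straight from a sentencepiece model file, which is loaded differently:
tokenizer_spm = Tokenizer(Unigram("spm_parity.model", is_spm_file=True))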
80 bindings/python/scripts/spm_parity_check.py (new file)
@@ -0,0 +1,80 @@
import tokenizers
from argparse import ArgumentParser
import sentencepiece as spm
import json


def main():
    parser = ArgumentParser("SentencePiece parity checker")
    parser.add_argument(
        "--input-file",
        "-i",
        type=str,
        required=True,
        help="Which files do you want to train from",
    )
    parser.add_argument(
        "--model-prefix",
        type=str,
        default="spm_parity",
        help="Model prefix for spm_train",
    )
    parser.add_argument(
        "--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
    )

    args = parser.parse_args()

    # Train a sentencepiece unigram model on the input file.
    spm.SentencePieceTrainer.Train(
        f"--input={args.input_file} --model_prefix={args.model_prefix}"
        f" --vocab_size={args.vocab_size}"
    )

    sp = spm.SentencePieceProcessor()
    model_filename = f"{args.model_prefix}.model"
    sp.Load(model_filename)

    # Dump the spm vocab as (piece, score) pairs plus the unk id, so the
    # tokenizers Unigram model can load it as JSON.
    vocab_filename = f"{args.model_prefix}.json"
    vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
    data = {"unk_id": sp.unk_id(), "vocab": vocab}

    with open(vocab_filename, "w") as f:
        json.dump(data, f, indent=4)

    tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
    with open(args.input_file, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            ids = sp.EncodeAsIds(line)
            encoded = tok.encode(line)

            if ids != encoded.ids:
                # Encodings can differ yet be equally valid: AAA -> A + AA vs AA + A.
                # We only check that the disagreement does not involve unk tokens,
                # using equal lengths as a proxy.
                if len(ids) != len(encoded.ids):
                    N = len(ids)
                    M = len(encoded.ids)
                    first_index_error = [
                        j for j in range(min(N, M)) if ids[j] != encoded.ids[j]
                    ][0]
                    last_index_error = [
                        min(N, M) - j
                        for j in range(min(N, M))
                        if ids[-j - 1] != encoded.ids[-j - 1]
                    ][0]
                    print(ids[first_index_error : last_index_error + 1])
                    print(encoded.ids[first_index_error : last_index_error + 1])
                    # Drop into a debugger to inspect the mismatching span.
                    import ipdb

                    ipdb.set_trace()
                    assert len(ids) == len(encoded.ids)
                continue

            assert ids == encoded.ids, f"line {i}: {line} : {ids} != {encoded.ids}"


if __name__ == "__main__":
    main()
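To reproduce the failures described in the commit message, run the script against any raw text corpus, e.g. python spm_parity_check.py --input-file corpus.txt --vocab-size 8000 (corpus.txt is a placeholder; the flags are the ones defined by the argparse setup above). It trains an spm model, converts its vocab to JSON, then compares sp.EncodeAsIds against SentencePieceUnigramTokenizer.encode line by line.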
@@ -51,6 +51,7 @@ fn models(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<models::PyBPE>()?;
    m.add_class::<models::PyWordPiece>()?;
    m.add_class::<models::PyWordLevel>()?;
    m.add_class::<models::PyUnigram>()?;
    Ok(())
}
@@ -7,6 +7,7 @@ use pyo3::prelude::*;
use pyo3::types::*;
use serde::{Deserialize, Serialize};
use tk::models::bpe::BPE;
use tk::models::unigram::Unigram;
use tk::models::wordlevel::WordLevel;
use tk::models::wordpiece::WordPiece;
use tk::models::ModelWrapper;

@@ -37,6 +38,7 @@ impl PyModel {
            ModelWrapper::BPE(_) => Py::new(py, (PyBPE {}, base)).map(Into::into),
            ModelWrapper::WordPiece(_) => Py::new(py, (PyWordPiece {}, base)).map(Into::into),
            ModelWrapper::WordLevel(_) => Py::new(py, (PyWordLevel {}, base)).map(Into::into),
            ModelWrapper::Unigram(_) => Py::new(py, (PyUnigram {}, base)).map(Into::into),
        }
    }
}
@@ -255,6 +257,57 @@ impl PyWordLevel {
    }
}

#[pyclass(extends=PyModel, module = "tokenizers.models", name=Unigram)]
pub struct PyUnigram {}

#[pymethods]
impl PyUnigram {
    #[new]
    #[args(kwargs = "**")]
    fn new(vocab: Option<&str>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
        let mut is_spm_file = false;
        if let Some(kwargs) = kwargs {
            for (key, val) in kwargs {
                let key: &str = key.extract()?;
                match key {
                    "is_spm_file" => is_spm_file = val.extract()?,
                    _ => println!("Ignored unknown kwargs option {}", key),
                }
            }
        }

        if let Some(vocab) = vocab {
            let path = Path::new(vocab);
            if is_spm_file {
                match Unigram::load_spm(path) {
                    Err(e) => {
                        println!("Errors: {:?}", e);
                        Err(exceptions::Exception::py_err(
                            "Error while initializing Unigram from spm file",
                        ))
                    }
                    Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
                }
            } else {
                match Unigram::load(path) {
                    Err(e) => {
                        println!("Errors: {:?}", e);
                        Err(exceptions::Exception::py_err(
                            "Error while initializing Unigram",
                        ))
                    }
                    Ok(model) => Ok((PyUnigram {}, PyModel::new(Arc::new(model.into())))),
                }
            }
        } else {
            Ok((
                PyUnigram {},
                PyModel::new(Arc::new(Unigram::default().into())),
            ))
        }
    }
}
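From the Python side, the kwargs parsing above means both of these calls reach the corresponding branch (paths are placeholders), and an unknown kwarg is merely printed and ignored rather than raising:

from tokenizers.models import Unigram

m1 = Unigram("spm_parity.json")                     # -> Unigram::load
m2 = Unigram("spm_parity.model", is_spm_file=True)  # -> Unigram::load_spm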

#[cfg(test)]
mod test {
    use crate::models::PyModel;