mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 00:35:35 +00:00
Basic python bindings
This commit is contained in:
145
bindings/python/Cargo.lock
generated
145
bindings/python/Cargo.lock
generated
@ -124,6 +124,55 @@ dependencies = [
|
||||
"vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-queue"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-utils 0.6.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.6.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctor"
|
||||
version = "0.1.12"
|
||||
@ -133,6 +182,11 @@ dependencies = [
|
||||
"syn 1.0.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.6.2"
|
||||
@ -176,6 +230,14 @@ name = "glob"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "humantime"
|
||||
version = "1.3.0"
|
||||
@ -262,6 +324,14 @@ name = "memchr"
|
||||
version = "2.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "4.2.3"
|
||||
@ -279,6 +349,15 @@ dependencies = [
|
||||
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "onig"
|
||||
version = "5.0.0"
|
||||
@ -418,6 +497,28 @@ dependencies = [
|
||||
"proc-macro2 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon-core 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-queue 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-utils 0.6.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.3.1"
|
||||
@ -439,11 +540,37 @@ name = "rustc-demangle"
|
||||
version = "0.1.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver-parser"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.102"
|
||||
@ -533,8 +660,11 @@ dependencies = [
|
||||
name = "tokenizers-lib"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"onig 5.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -632,12 +762,19 @@ dependencies = [
|
||||
"checksum cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)" = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
||||
"checksum clang-sys 0.28.1 (registry+https://github.com/rust-lang/crates.io-index)" = "81de550971c976f176130da4b2978d3b524eaa0fd9ac31f3ceb5ae1231fb4853"
|
||||
"checksum clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9"
|
||||
"checksum crossbeam-deque 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c3aa945d63861bfe624b55d153a39684da1e8c0bc8fba932f7ee3a3c16cea3ca"
|
||||
"checksum crossbeam-epoch 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5064ebdbf05ce3cb95e45c8b086f72263f4166b29b97f6baff7ef7fe047b55ac"
|
||||
"checksum crossbeam-queue 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7c979cd6cfe72335896575c6b5688da489e420d36a27a0b9eb0c73db574b4a4b"
|
||||
"checksum crossbeam-utils 0.6.6 (registry+https://github.com/rust-lang/crates.io-index)" = "04973fa96e96579258a5091af6003abde64af786b860f18622b82e026cca60e6"
|
||||
"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4"
|
||||
"checksum ctor 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)" = "cd8ce37ad4184ab2ce004c33bf6379185d3b1c95801cab51026bd271bf68eedc"
|
||||
"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
|
||||
"checksum env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)" = "aafcde04e90a5226a6443b7aabdb016ba2f8307c847d524724bd9b346dd1a2d3"
|
||||
"checksum failure 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "f8273f13c977665c5db7eb2b99ae520952fe5ac831ae4cd09d80c4c7042b5ed9"
|
||||
"checksum fxhash 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
|
||||
"checksum ghost 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "2a36606a68532b5640dc86bb1f33c64b45c4682aad4c50f3937b317ea387f3d6"
|
||||
"checksum glob 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
|
||||
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
|
||||
"checksum humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
|
||||
"checksum indoc 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "3f9553c1e16c114b8b77ebeb329e5f2876eed62a8d51178c8bc6bff0d65f98f8"
|
||||
"checksum indoc-impl 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b714fc08d0961716390977cdff1536234415ac37b509e34e5a983def8340fb75"
|
||||
@ -649,8 +786,10 @@ dependencies = [
|
||||
"checksum libloading 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f2b111a074963af1d37a139918ac6d49ad1d0d5e47f72fd55388619691a7d753"
|
||||
"checksum log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)" = "14b6052be84e6b71ab17edffc2eeabf5c2c3ae1fdb464aae35ac50c67a44e1f7"
|
||||
"checksum memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "88579771288728879b57485cc7d6b07d648c9f0141eb955f8ab7f9d45394468e"
|
||||
"checksum memoffset 0.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "75189eb85871ea5c2e2c15abbdd541185f63b408415e5051f5cac122d8c774b9"
|
||||
"checksum nom 4.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2ad2a91a8e869eeb30b9cb3119ae87773a8f4ae617f41b1eb9c154b2905f7bd6"
|
||||
"checksum num-traits 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "6ba9a427cfca2be13aa6f6403b0b7e7368fe982bfa16fccc450ce74c46cd9b32"
|
||||
"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72"
|
||||
"checksum onig 5.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e4e723fc996fff1aeab8f62205f3e8528bf498bdd5eadb2784d2d31f30077947"
|
||||
"checksum onig_sys 69.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "0a8d4efbf5f59cece01f539305191485b651acb3785b9d5eef05749f0496514e"
|
||||
"checksum paste 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "423a519e1c6e828f1e73b720f9d9ed2fa643dce8a7737fb43235ce0b41eeaa49"
|
||||
@ -666,10 +805,16 @@ dependencies = [
|
||||
"checksum quick-error 1.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "9274b940887ce9addde99c4eee6b5c44cc494b182b97e73dc8ffdcb3397fd3f0"
|
||||
"checksum quote 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "6ce23b6b870e8f94f81fb0a363d65d86675884b34a09043c81e5562f11c1f8e1"
|
||||
"checksum quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "053a8c8bcc71fcce321828dc897a98ab9760bef03a4fc36693c231e5b3216cfe"
|
||||
"checksum rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "83a27732a533a1be0a0035a111fe76db89ad312f6f0347004c220c57f209a123"
|
||||
"checksum rayon-core 1.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "98dcf634205083b17d0861252431eb2acbfb698ab7478a2d20de07954f47ec7b"
|
||||
"checksum regex 1.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "dc220bd33bdce8f093101afe22a037b8eb0e5af33592e6a9caafff0d4cb81cbd"
|
||||
"checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716"
|
||||
"checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783"
|
||||
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
|
||||
"checksum ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8"
|
||||
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
|
||||
"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
|
||||
"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
|
||||
"checksum serde 1.0.102 (registry+https://github.com/rust-lang/crates.io-index)" = "0c4b39bd9b0b087684013a792c59e3e07a46a01d2322518d8a1104641a0b1be0"
|
||||
"checksum serde_derive 1.0.102 (registry+https://github.com/rust-lang/crates.io-index)" = "ca13fc1a832f793322228923fbb3aba9f3f44444898f835d31ad1b74fa0a2bf8"
|
||||
"checksum serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)" = "2f72eb2a68a7dc3f9a691bfda9305a1c017a6215e5a4545c258500d2099a37c2"
|
||||
|
@ -1,3 +1,71 @@
|
||||
from tokenizers import WhitespaceTokenizer
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
|
||||
print(WhitespaceTokenizer.tokenize("Hey man!"))
|
||||
from tokenizers import Tokenizer
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--file", default=None, type=str, help="The file to encode")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.file is not None:
|
||||
current_dir = os.path.abspath(os.path.dirname(__file__))
|
||||
path = os.path.join(current_dir, args.file)
|
||||
|
||||
with open(path, "r") as fp:
|
||||
text = [ line.strip() for line in fp ]
|
||||
else:
|
||||
text = """
|
||||
The Zen of Python, by Tim Peters
|
||||
Beautiful is better than ugly.
|
||||
Explicit is better than implicit.
|
||||
Simple is better than complex.
|
||||
Complex is better than complicated.
|
||||
Flat is better than nested.
|
||||
Sparse is better than dense.
|
||||
Readability counts.
|
||||
Special cases aren't special enough to break the rules.
|
||||
Although practicality beats purity.
|
||||
Errors should never pass silently.
|
||||
Unless explicitly silenced.
|
||||
In the face of ambiguity, refuse the temptation to guess.
|
||||
There should be one-- and preferably only one --obvious way to do it.
|
||||
Although that way may not be obvious at first unless you're Dutch.
|
||||
Now is better than never.
|
||||
Although never is often better than *right* now.
|
||||
If the implementation is hard to explain, it's a bad idea.
|
||||
If the implementation is easy to explain, it may be a good idea.
|
||||
Namespaces are one honking great idea -- let's do more of those!
|
||||
""".split("\n")
|
||||
|
||||
|
||||
tok_p = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tok_r = Tokenizer.bpe_from_files(
|
||||
"../../data/gpt2-vocab.json",
|
||||
"../../data/gpt2-merges.txt",
|
||||
pre_tokenizer="ByteLevel",
|
||||
)
|
||||
|
||||
def tokenize_r():
    """Encode every entry of ``text`` with the Rust tokenizer ``tok_r``.

    All entries are submitted in a single ``encode_batch`` call. The
    per-sentence alternative would be
    ``[tok_r.encode(sentence) for sentence in text]``.
    """
    encoded_batch = tok_r.encode_batch(text)
    return encoded_batch
|
||||
|
||||
def tokenize_p():
    """Encode every entry of ``text`` one at a time with ``tok_p``.

    ``tok_p`` is the GPT-2 tokenizer loaded from the ``transformers`` package.
    Returns one encoding per entry, in the same order as ``text``.
    """
    encoded = []
    for sentence in text:
        encoded.append(tok_p.encode(sentence))
    return encoded
|
||||
|
||||
print(f"Tokenizing {len(text)} lines")

# Rust version
# time.perf_counter() is the clock intended for measuring short durations;
# time.time() follows the wall clock and can be adjusted while we measure.
start = time.perf_counter()
encoded_r = tokenize_r()
end = time.perf_counter()
print(f"Rust tokenizer took: {end - start} sec")

# Python version
start = time.perf_counter()
encoded_p = tokenize_p()
end = time.perf_counter()
print(f"Transformer tokenizer took: {end - start} sec")

# Both implementations must yield the same token ids for every sentence.
ids_r = [[token.id for token in sentence] for sentence in encoded_r]
assert ids_r == encoded_p, "Rust and Transformers tokenizers produced different token ids"
|
||||
|
@ -1,20 +1,162 @@
|
||||
extern crate tokenizers as tk;
|
||||
use tk::models::bpe::Error as BpeError;
|
||||
|
||||
use pyo3::exceptions;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyDict, PyList};
|
||||
|
||||
#[pyclass]
|
||||
struct WhitespaceTokenizer {}
|
||||
#[repr(transparent)]
|
||||
struct Token {
|
||||
tok: tk::tokenizer::Token,
|
||||
}
|
||||
impl Token {
|
||||
pub fn new(tok: tk::tokenizer::Token) -> Self {
|
||||
Token { tok }
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl WhitespaceTokenizer {
|
||||
impl Token {
|
||||
#[getter]
|
||||
fn get_id(&self) -> PyResult<u32> {
|
||||
Ok(self.tok.id)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_value(&self) -> PyResult<&str> {
|
||||
Ok(&self.tok.value)
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_offsets(&self) -> PyResult<(usize, usize)> {
|
||||
Ok(self.tok.offsets)
|
||||
}
|
||||
|
||||
/// Returns the wrapped token as an `(id, value, offsets)` tuple.
fn as_tuple(&self) -> PyResult<(u32, &str, (usize, usize))> {
    let inner = &self.tok;
    let value: &str = &inner.value;
    Ok((inner.id, value, inner.offsets))
}
|
||||
}
|
||||
|
||||
/// Looks up a pre-tokenizer by name.
///
/// Recognised names are `"ByteLevel"` and `"Whitespace"`; any other name
/// yields `None`.
fn get_pre_tokenizer(name: &str) -> Option<Box<dyn tk::tokenizer::PreTokenizer + Sync>> {
    let pre_tokenizer: Box<dyn tk::tokenizer::PreTokenizer + Sync> = match name {
        "ByteLevel" => Box::new(tk::pre_tokenizers::byte_level::ByteLevel),
        "Whitespace" => Box::new(tk::pre_tokenizers::whitespace::Whitespace),
        _ => return None,
    };
    Some(pre_tokenizer)
}
|
||||
|
||||
fn get_normalizer(_name: &str) -> Option<Box<dyn tk::tokenizer::Normalizer + Sync>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn get_post_processor(_name: &str) -> Option<Box<dyn tk::tokenizer::PostProcessor + Sync>> {
|
||||
None
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct Tokenizer {
|
||||
tokenizer: tk::tokenizer::Tokenizer,
|
||||
}
|
||||
#[pymethods]
|
||||
impl Tokenizer {
|
||||
#[staticmethod]
|
||||
fn tokenize(s: String) -> PyResult<Vec<String>> {
|
||||
Ok(tk::WhitespaceTokenizer::tokenize(&s))
|
||||
#[args(kwargs = "**")]
|
||||
fn bpe_from_files(vocab: &str, merges: &str, kwargs: Option<&PyDict>) -> PyResult<Self> {
|
||||
let model = match tk::models::bpe::BPE::from_files(vocab, merges) {
|
||||
Ok(bpe) => Ok(Box::new(bpe)),
|
||||
Err(e) => match e {
|
||||
BpeError::BadVocabulary => {
|
||||
Err(exceptions::Exception::py_err("Bad vocab.json format"))
|
||||
}
|
||||
BpeError::Io(io) => Err(PyErr::from(io)),
|
||||
BpeError::JsonError(_) => Err(exceptions::Exception::py_err(
|
||||
"Error while parsing vocab json file",
|
||||
)),
|
||||
BpeError::MergeTokenOutOfVocabulary(token) => Err(exceptions::Exception::py_err(
|
||||
format!("Merge token out of vocabulary: {}", token),
|
||||
)),
|
||||
},
|
||||
}?;
|
||||
|
||||
let mut tokenizer = tk::tokenizer::Tokenizer::new(model);
|
||||
|
||||
if let Some(kwargs) = kwargs {
|
||||
for (option, value) in kwargs {
|
||||
match option.to_string().as_ref() {
|
||||
"pre_tokenizer" => {
|
||||
let value = value.to_string();
|
||||
if let Some(pre_tokenizer) = get_pre_tokenizer(&value) {
|
||||
tokenizer.with_pre_tokenizer(pre_tokenizer);
|
||||
} else {
|
||||
return Err(exceptions::Exception::py_err(format!(
|
||||
"PreTokenizer `{}` not found",
|
||||
value
|
||||
)));
|
||||
}
|
||||
}
|
||||
"normalizers" => {
|
||||
let mut normalizers = vec![];
|
||||
let values = value.cast_as::<PyList>()?;
|
||||
for value in values {
|
||||
let value = value.to_string();
|
||||
if let Some(normalizer) = get_normalizer(&value) {
|
||||
normalizers.push(normalizer);
|
||||
} else {
|
||||
return Err(exceptions::Exception::py_err(format!(
|
||||
"Normalizer `{}` not found",
|
||||
value
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
"post_processors" => {
|
||||
let mut processors = vec![];
|
||||
let values = value.cast_as::<PyList>()?;
|
||||
for value in values {
|
||||
let value = value.to_string();
|
||||
if let Some(processor) = get_post_processor(&value) {
|
||||
processors.push(processor);
|
||||
} else {
|
||||
return Err(exceptions::Exception::py_err(format!(
|
||||
"PostProcessor `{}` not found",
|
||||
value
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => println!("Ignored unknown kwarg `{}`", option),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Tokenizer { tokenizer })
|
||||
}
|
||||
|
||||
fn encode(&self, sentence: &str) -> Vec<Token> {
|
||||
self.tokenizer
|
||||
.encode(sentence)
|
||||
.into_iter()
|
||||
.map(|token| Token::new(token))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn encode_batch(&self, sentences: Vec<&str>) -> Vec<Vec<Token>> {
|
||||
self.tokenizer
|
||||
.encode_batch(sentences)
|
||||
.into_iter()
|
||||
.map(|sentence| {
|
||||
sentence
|
||||
.into_iter()
|
||||
.map(|token| Token::new(token))
|
||||
.collect()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn tokenizers(py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<WhitespaceTokenizer>()?;
|
||||
fn tokenizers(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<Tokenizer>()?;
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user