From e2e2ca28c95a6447fb6d2b82de2e74b8cb7d080e Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Fri, 1 May 2020 15:51:21 -0400 Subject: [PATCH] Node - Update bindings for new `encode` --- bindings/node/native/Cargo.lock | 178 +++++++- bindings/node/native/Cargo.toml | 7 +- bindings/node/native/src/lib.rs | 5 +- bindings/node/native/src/tasks/tokenizer.rs | 98 +---- bindings/node/native/src/tokenizer.rs | 457 +++++++++++--------- 5 files changed, 438 insertions(+), 307 deletions(-) diff --git a/bindings/node/native/Cargo.lock b/bindings/node/native/Cargo.lock index b951aa0c..cc6aa024 100644 --- a/bindings/node/native/Cargo.lock +++ b/bindings/node/native/Cargo.lock @@ -31,6 +31,31 @@ name = "autocfg" version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "autocfg" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "backtrace" +version = "0.3.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "backtrace-sys 0.1.37 (registry+https://github.com/rust-lang/crates.io-index)", + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)", + "rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "backtrace-sys" +version = "0.1.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cc 1.0.50 (registry+https://github.com/rust-lang/crates.io-index)", + "libc 0.2.66 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "bitflags" version = "1.2.1" @@ -150,6 +175,15 @@ name = "encode_unicode" version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "error-chain" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "backtrace 0.3.46 (registry+https://github.com/rust-lang/crates.io-index)", + "version_check 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "getrandom" version = "0.1.14" @@ -236,6 +270,18 @@ dependencies = [ "neon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "neon-serde" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "error-chain 0.12.2 (registry+https://github.com/rust-lang/crates.io-index)", + "neon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "neon-runtime 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "num 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)", + "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "neon-sys" version = "0.3.3" @@ -251,10 +297,71 @@ version = "0.1.0" dependencies = [ "neon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", "neon-build 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", - "rayon 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "neon-runtime 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)", + "neon-serde 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)", + "serde_derive 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)", "tokenizers 0.10.1", ] +[[package]] +name = "num" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)", + "num-integer 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)", + "num-iter 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", + "num-rational 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num-complex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num-integer" +version = "0.1.42" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num-iter" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "num-integer 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num-rational" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "num-integer 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)", + "num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "num-traits" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "num_cpus" version = "1.12.0" @@ -274,6 +381,22 @@ name = "ppv-lite86" version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "proc-macro2" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "quote" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "rand" version = "0.7.3" @@ -349,6 +472,11 @@ name = "regex-syntax" version = "0.6.13" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "rustc-demangle" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "rustc_version" version = "0.2.3" @@ -385,6 +513,16 @@ name = "serde" version = "1.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "serde_derive" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "serde_json" version = "1.0.45" @@ -405,6 +543,16 @@ name = "strsim" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "syn" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "proc-macro2 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)", + "unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "termios" version = "0.3.1" @@ -459,6 +607,11 @@ name = "unicode-width" version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "unicode-xid" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "unicode_categories" version = "0.1.1" @@ -469,6 +622,11 @@ name = "vec_map" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "version_check" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "wasi" version = "0.9.0+wasi-snapshot-preview1" @@ -498,6 +656,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum ansi_term 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" "checksum atty 0.2.14 (registry+https://github.com/rust-lang/crates.io-index)" = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" "checksum autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2" +"checksum autocfg 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f8aac770f1885fd7e387acedd76065302551364496e46b3dd00860b2f8359b9d" +"checksum backtrace 0.3.46 (registry+https://github.com/rust-lang/crates.io-index)" = "b1e692897359247cc6bb902933361652380af0f1b7651ae5c5013407f30e109e" +"checksum backtrace-sys 0.1.37 (registry+https://github.com/rust-lang/crates.io-index)" = "18fbebbe1c9d1f383a9cc7e8ccdb471b91c8d024ee9c2ca5b5346121fe8b4399" "checksum bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" "checksum c2-chacha 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "214238caa1bf3a496ec3392968969cab8549f96ff30652c9e56885329315f6bb" "checksum cc 1.0.50 (registry+https://github.com/rust-lang/crates.io-index)" = "95e28fa049fda1c330bcf9d723be7663a899c4679724b34c81e9f5a326aab8cd" @@ -512,6 +673,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum cslice 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "697c714f50560202b1f4e2e09cd50a421881c83e9025db75d15f276616f04f40" "checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" "checksum encode_unicode 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +"checksum error-chain 0.12.2 (registry+https://github.com/rust-lang/crates.io-index)" = "d371106cc88ffdfb1eabd7111e432da544f16f3e2d7bf1dfe8bf575f1df045cd" "checksum getrandom 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)" = "7abc8dd8451921606d809ba32e95b6111925cd2906060d2dcc29c070220503eb" "checksum hermit-abi 0.1.6 (registry+https://github.com/rust-lang/crates.io-index)" = "eff2656d88f158ce120947499e971d743c05dbcbed62e5bd2f38f1698bbc3772" "checksum indicatif 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)" = "49a68371cf417889c9d7f98235b7102ea7c54fc59bcbd22f3dea785be9d27e40" @@ -523,10 +685,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum neon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "53b85accbbd250627f899a6fc1f220bbb4c8c2ff6dc71830dc6b752b39c2eb97" "checksum neon-build 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ae406bf1065c4399e69d328a3bd8d4f088f2a205dc3881bf68c0ac775bfef337" "checksum neon-runtime 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "d8465ac4ed3f340dead85e053b75a5f639f48ac6343b3523eff90a751758eead" +"checksum neon-serde 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4b45847cc4cec46db1ff2e921cd9948b38afa0c2348c99c06f9ae70406a30b60" "checksum neon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8ae4cf3871ca5a395077e68144c1754e94e9e1e3329e7f8399d999ca573ed89a" +"checksum num 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" +"checksum num-complex 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" +"checksum num-integer 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "3f6ea62e9d81a77cd3ee9a2a5b9b609447857f3d358704331e4ef39eb247fcba" +"checksum num-iter 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)" = "dfb0800a0291891dd9f4fe7bd9c19384f98f7fbe0cd0f39a2c6b88b9868bbc00" +"checksum num-rational 0.2.4 (registry+https://github.com/rust-lang/crates.io-index)" = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef" +"checksum num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)" = "c62be47e61d1842b9170f0fdeec8eba98e60e90e5446449a0545e5152acd7096" "checksum num_cpus 1.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "46203554f085ff89c235cd12f7075f3233af9b11ed7c9e16dfe2560d03313ce6" "checksum number_prefix 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a" "checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" +"checksum proc-macro2 1.0.10 (registry+https://github.com/rust-lang/crates.io-index)" = "df246d292ff63439fea9bc8c0a270bed0e390d5ebd4db4ba15aba81111b5abe3" +"checksum quote 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4c1f4b0efa5fc5e8ceb705136bfee52cfdb6a4e3509f770b478cd6ed434232a7" "checksum rand 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" "checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853" "checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" @@ -535,22 +706,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum rayon-core 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a89b46efaf957e52b18062fb2f4660f8b8a4dde1807ca002690868ef2c85a9" "checksum regex 1.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "b5508c1941e4e7cb19965abef075d35a9a8b5cdf0846f30b4050e9b55dc55e87" "checksum regex-syntax 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "e734e891f5b408a29efbf8309e656876276f49ab6a6ac208600b4419bd893d90" +"checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783" "checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" "checksum ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8" "checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d" "checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" "checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" "checksum serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)" = "414115f25f818d7dfccec8ee535d76949ae78584fc4f79a6f45a904bf8ab4449" +"checksum serde_derive 1.0.106 (registry+https://github.com/rust-lang/crates.io-index)" = "9e549e3abf4fb8621bd1609f11dfc9f5e50320802273b12f3811a67e6716ea6c" "checksum serde_json 1.0.45 (registry+https://github.com/rust-lang/crates.io-index)" = "eab8f15f15d6c41a154c1b128a22f2dfabe350ef53c40953d84e36155c91192b" "checksum smallvec 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5c2fb2ec9bcd216a5b0d0ccf31ab17b5ed1d627960edff65bbe95d3ce221cefc" "checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" +"checksum syn 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)" = "410a7488c0a728c7ceb4ad59b9567eb4053d02e8cc7f5c0e0eeeb39518369213" "checksum termios 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "72b620c5ea021d75a735c943269bb07d30c9b77d6ac6b236bc8b5c496ef05625" "checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" "checksum thread_local 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d40c6d1b69745a6ec6fb1ca717914848da4b44ae29d9b3080cbee91d72a69b14" "checksum unicode-normalization-alignments 0.1.12 (registry+https://github.com/rust-lang/crates.io-index)" = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" "checksum unicode-width 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "caaa9d531767d1ff2150b9332433f32a24622147e5ebb1f26409d5da67afd479" +"checksum unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c" "checksum unicode_categories 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" "checksum vec_map 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "05c78687fb1a80548ae3250346c3db86a80a7cdd77bda190189f2d0a0987c81a" +"checksum version_check 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)" = "078775d0255232fb988e6fccf26ddc9d1ac274299aaedcedce21c6f72cc533ce" "checksum wasi 0.9.0+wasi-snapshot-preview1 (registry+https://github.com/rust-lang/crates.io-index)" = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" "checksum winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)" = "8093091eeb260906a183e6ae1abdba2ef5ef2257a21801128899c3fc699229c6" "checksum winapi-i686-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" diff --git a/bindings/node/native/Cargo.toml b/bindings/node/native/Cargo.toml index 2e648d6d..58540874 100644 --- a/bindings/node/native/Cargo.toml +++ b/bindings/node/native/Cargo.toml @@ -14,6 +14,9 @@ crate-type = ["cdylib"] neon-build = "0.3.3" [dependencies] -neon = "0.3.3" -rayon = "1.2.0" +neon = "0.3" +neon-runtime = "0.3" +neon-serde = "0.3" +serde = "1.0" +serde_derive = "1.0" tokenizers = { path = "../../../tokenizers" } diff --git a/bindings/node/native/src/lib.rs b/bindings/node/native/src/lib.rs index 3bfb3084..f9c04c3a 100644 --- a/bindings/node/native/src/lib.rs +++ b/bindings/node/native/src/lib.rs @@ -1,7 +1,10 @@ #![warn(clippy::all)] extern crate neon; -extern crate rayon; +extern crate neon_serde; +extern crate serde; +#[macro_use] +extern crate serde_derive; extern crate tokenizers as tk; mod container; diff --git a/bindings/node/native/src/tasks/tokenizer.rs b/bindings/node/native/src/tasks/tokenizer.rs index ad07a2a3..2db9023a 100644 --- a/bindings/node/native/src/tasks/tokenizer.rs +++ b/bindings/node/native/src/tasks/tokenizer.rs @@ -2,7 +2,6 @@ extern crate tokenizers as tk; use crate::encoding::*; use neon::prelude::*; -use rayon::prelude::*; use tk::tokenizer::{EncodeInput, Encoding, Tokenizer}; pub struct WorkingTokenizer { @@ -40,7 +39,8 @@ impl Task for EncodeTask { fn perform(&self) -> Result { match self { EncodeTask::Single(worker, input, add_special_tokens) => { - let mut input = unsafe { std::ptr::replace(input as *const _ as *mut _, None) }; + let mut input: Option = + unsafe { std::ptr::replace(input as *const _ as *mut _, None) }; let tokenizer: &Tokenizer = unsafe { &*worker.ptr }; tokenizer .encode( @@ -51,7 +51,8 @@ impl Task for EncodeTask { .map(|encoding| EncodeOutput::Single(encoding)) } EncodeTask::Batch(worker, input, add_special_tokens) => { - let mut input = unsafe { std::ptr::replace(input as *const _ as *mut _, None) }; + let mut input: Option> = + unsafe { std::ptr::replace(input as *const _ as *mut _, None) }; let tokenizer: &Tokenizer = unsafe { &*worker.ptr }; tokenizer .encode_batch( @@ -101,97 +102,6 @@ impl Task for EncodeTask { } } -pub enum EncodeTokenizedTask { - Single(WorkingTokenizer, Option>, u32), - Batch( - WorkingTokenizer, - Option>>, - u32, - ), -} - -pub enum EncodeTokenizedOutput { - Single(Encoding), - Batch(Vec), -} - -impl Task for EncodeTokenizedTask { - type Output = EncodeTokenizedOutput; - type Error = String; - type JsEvent = JsValue; - - fn perform(&self) -> Result { - match self { - EncodeTokenizedTask::Single(worker, input, type_id) => { - let input = unsafe { std::ptr::replace(input as *const _ as *mut _, None) }; - let tokenizer: &Tokenizer = unsafe { &*worker.ptr }; - - tokenizer - .get_model() - .tokenize(input.unwrap()) - .map_err(|e| format!("{}", e)) - .map(|tokens| { - EncodeTokenizedOutput::Single(Encoding::from_tokens(tokens, *type_id)) - }) - } - EncodeTokenizedTask::Batch(worker, input, type_id) => { - let input: Option> = - unsafe { std::ptr::replace(input as *const _ as *mut _, None) }; - let tokenizer: &Tokenizer = unsafe { &*worker.ptr }; - - input - .unwrap() - .into_par_iter() - .map(|input| { - tokenizer - .get_model() - .tokenize(input) - .map_err(|e| format!("{}", e)) - .map(|tokens| Encoding::from_tokens(tokens, *type_id)) - }) - .collect::>() - .map(EncodeTokenizedOutput::Batch) - } - } - } - - fn complete( - self, - mut cx: TaskContext, - result: Result, - ) -> JsResult { - match result.map_err(|e| cx.throw_error::<_, ()>(e).unwrap_err())? { - EncodeTokenizedOutput::Single(encoding) => { - let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?; - // Set the actual encoding - let guard = cx.lock(); - js_encoding - .borrow_mut(&guard) - .encoding - .to_owned(Box::new(encoding)); - - Ok(js_encoding.upcast()) - } - EncodeTokenizedOutput::Batch(encodings) => { - let result = JsArray::new(&mut cx, encodings.len() as u32); - for (i, encoding) in encodings.into_iter().enumerate() { - let mut js_encoding = JsEncoding::new::<_, JsEncoding, _>(&mut cx, vec![])?; - - // Set the actual encoding - let guard = cx.lock(); - js_encoding - .borrow_mut(&guard) - .encoding - .to_owned(Box::new(encoding)); - - result.set(&mut cx, i as u32, js_encoding)?; - } - Ok(result.upcast()) - } - } - } -} - pub enum DecodeTask { Single(WorkingTokenizer, Vec, bool), Batch(WorkingTokenizer, Vec>, bool), diff --git a/bindings/node/native/src/tokenizer.rs b/bindings/node/native/src/tokenizer.rs index b5e56764..8032eb98 100644 --- a/bindings/node/native/src/tokenizer.rs +++ b/bindings/node/native/src/tokenizer.rs @@ -7,9 +7,10 @@ use crate::models::JsModel; use crate::normalizers::JsNormalizer; use crate::pre_tokenizers::JsPreTokenizer; use crate::processors::JsPostProcessor; -use crate::tasks::tokenizer::{DecodeTask, EncodeTask, EncodeTokenizedTask, WorkingTokenizer}; +use crate::tasks::tokenizer::{DecodeTask, EncodeTask, WorkingTokenizer}; use crate::trainers::JsTrainer; use neon::prelude::*; +use serde::de::DeserializeOwned; use tk::tokenizer::{ PaddingDirection, PaddingParams, PaddingStrategy, TruncationParams, TruncationStrategy, @@ -69,6 +70,190 @@ declare_types! { } } +pub struct Error(String); +impl From for Error +where + T: std::fmt::Display, +{ + fn from(e: T) -> Self { + Self(format!("{}", e)) + } +} +impl From for neon::result::Throw { + fn from(err: Error) -> Self { + let msg = err.0; + unsafe { + neon_runtime::error::throw_error_from_utf8(msg.as_ptr(), msg.len() as i32); + neon::result::Throw + } + } +} + +pub type LibResult = std::result::Result; + +trait FromJsValue: Sized { + fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult; +} + +impl FromJsValue for T +where + T: DeserializeOwned, +{ + fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult { + let val: T = neon_serde::from_value(cx, from)?; + Ok(val) + } +} + +trait Extract { + fn extract(&mut self, pos: i32) -> LibResult; + fn extract_opt(&mut self, pos: i32) -> LibResult>; + fn extract_vec(&mut self, pos: i32) -> LibResult>; + fn extract_vec_opt(&mut self, pos: i32) -> LibResult>>; +} + +impl<'c, T: neon::object::This> Extract for CallContext<'c, T> { + fn extract(&mut self, pos: i32) -> LibResult { + let val = self + .argument_opt(pos) + .ok_or_else(|| Error(format!("Argument {} is missing", pos)))?; + let ext = E::from_value(val, self)?; + Ok(ext) + } + + fn extract_opt(&mut self, pos: i32) -> LibResult> { + let val = self.argument_opt(pos); + match val { + None => Ok(None), + Some(v) => { + // For any optional value, we accept both `undefined` and `null` + if v.downcast::().is_ok() || v.downcast::().is_ok() { + Ok(None) + } else { + Ok(Some(E::from_value(v, self)?)) + } + } + } + } + + fn extract_vec(&mut self, pos: i32) -> LibResult> { + let vec = self + .argument_opt(pos) + .ok_or_else(|| Error(format!("Argument {} is missing", pos)))? + .downcast::()? + .to_vec(self)?; + + vec.into_iter().map(|v| E::from_value(v, self)).collect() + } + + fn extract_vec_opt(&mut self, pos: i32) -> LibResult>> { + self.argument_opt(pos) + .map(|v| { + let vec = v.downcast::()?.to_vec(self)?; + Ok(vec + .into_iter() + .map(|v| E::from_value(v, self)) + .collect::>>()?) + }) + .map_or(Ok(None), |v| v.map(Some)) + } +} + +struct TextInputSequence(tk::InputSequence); +struct PreTokenizedInputSequence(tk::InputSequence); +impl FromJsValue for PreTokenizedInputSequence { + fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult { + let sequence = from + .downcast::()? + .to_vec(cx)? + .into_iter() + .map(|v| Ok(v.downcast::()?.value())) + .collect::>>()?; + Ok(Self(sequence.into())) + } +} +impl From for tk::InputSequence { + fn from(v: PreTokenizedInputSequence) -> Self { + v.0 + } +} +impl FromJsValue for TextInputSequence { + fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, _cx: &mut C) -> LibResult { + Ok(Self(from.downcast::()?.value().into())) + } +} +impl From for tk::InputSequence { + fn from(v: TextInputSequence) -> Self { + v.0 + } +} + +struct TextEncodeInput(tk::EncodeInput); +struct PreTokenizedEncodeInput(tk::EncodeInput); +impl FromJsValue for PreTokenizedEncodeInput { + fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult { + // If array is of size 2, and the first element is also an array, we'll parse a pair + let array = from.downcast::()?; + let is_pair = array.len() == 2 + && array + .get(cx, 0) + .map_or(false, |a| a.downcast::().is_ok()); + + if is_pair { + let first_seq: tk::InputSequence = + PreTokenizedInputSequence::from_value(array.get(cx, 0)?, cx)?.into(); + let pair_seq: tk::InputSequence = + PreTokenizedInputSequence::from_value(array.get(cx, 1)?, cx)?.into(); + Ok(Self((first_seq, pair_seq).into())) + } else { + Ok(Self( + PreTokenizedInputSequence::from_value(from, cx)?.into(), + )) + } + } +} +impl From for tk::EncodeInput { + fn from(v: PreTokenizedEncodeInput) -> Self { + v.0 + } +} +impl FromJsValue for TextEncodeInput { + fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult { + // If we get an array, it's a pair of sequences + if let Ok(array) = from.downcast::() { + let first_seq: tk::InputSequence = + TextInputSequence::from_value(array.get(cx, 0)?, cx)?.into(); + let pair_seq: tk::InputSequence = + TextInputSequence::from_value(array.get(cx, 1)?, cx)?.into(); + Ok(Self((first_seq, pair_seq).into())) + } else { + Ok(Self(TextInputSequence::from_value(from, cx)?.into())) + } + } +} +impl From for tk::EncodeInput { + fn from(v: TextEncodeInput) -> Self { + v.0 + } +} + +#[allow(non_snake_case)] +#[derive(Debug, Serialize, Deserialize)] +struct EncodeOptions { + #[serde(default)] + isPretokenized: bool, + #[serde(default)] + addSpecialTokens: bool, +} +impl Default for EncodeOptions { + fn default() -> Self { + Self { + isPretokenized: false, + addSpecialTokens: true, + } + } +} + /// Tokenizer pub struct Tokenizer { tokenizer: tk::tokenizer::Tokenizer, @@ -184,28 +369,47 @@ declare_types! { } method encode(mut cx) { + // type InputSequence = string | string[]; // encode( - // sentence: String, - // pair: String | null, - // add_special_tokens: boolean, + // sentence: InputSequence, + // pair?: InputSequence, + // options?: { + // addSpecialTokens?: boolean, + // isPretokenized?: boolean, + // } | (err, encoding) -> void, // __callback: (err, encoding) -> void // ) - let sentence = cx.argument::(0)?.value(); - let mut pair: Option = None; - if let Some(args) = cx.argument_opt(1) { - if let Ok(p) = args.downcast::() { - pair = Some(p.value()); - } else if args.downcast::().is_err() { - return cx.throw_error("Second arg must be of type `String | null`"); - } - } - let add_special_tokens = cx.argument::(2)?.value(); - let callback = cx.argument::(3)?; - let input = if let Some(pair) = pair { - tk::tokenizer::EncodeInput::Dual(sentence, pair) + // Start by extracting options and callback + let (options, callback) = match cx.extract_opt::(2) { + // Options were there, and extracted + Ok(Some(options)) => { + (options, cx.argument::(3)?) + }, + // Options were undefined or null + Ok(None) => { + (EncodeOptions::default(), cx.argument::(3)?) + } + // Options not specified, callback instead + Err(_) => { + (EncodeOptions::default(), cx.argument::(2)?) + } + }; + + // Then we extract our input sequences + let sentence: tk::InputSequence = if options.isPretokenized { + cx.extract::(0)?.into() } else { - tk::tokenizer::EncodeInput::Single(sentence) + cx.extract::(0)?.into() + }; + let pair: Option = if options.isPretokenized { + cx.extract_opt::(1)?.map(|v| v.into()) + } else { + cx.extract_opt::(1)?.map(|v| v.into()) + }; + let input: tk::EncodeInput = match pair { + Some(pair) => (sentence, pair).into(), + None => sentence.into() }; let worker = { @@ -215,212 +419,47 @@ declare_types! { worker }; - let task = EncodeTask::Single(worker, Some(input), add_special_tokens); + let task = EncodeTask::Single(worker, Some(input), options.addSpecialTokens); task.schedule(callback); Ok(cx.undefined().upcast()) } method encodeBatch(mut cx) { - // type EncodeInput = (String | [String, String])[] + // type InputSequence = string | string[]; + // type EncodeInput = (InputSequence | [InputSequence, InputSequence])[] // encode_batch( - // sentences: EncodeInput[], - // add_special_tokens: boolean, + // inputs: EncodeInput[], + // options?: { + // addSpecialTokens?: boolean, + // isPretokenized?: boolean, + // } | (err, encodings) -> void, // __callback: (err, encodings) -> void // ) - let inputs = cx.argument::(0)?.to_vec(&mut cx)?; - let inputs = inputs.into_iter().map(|value| { - if let Ok(s) = value.downcast::() { - Ok(tk::tokenizer::EncodeInput::Single(s.value())) - } else if let Ok(arr) = value.downcast::() { - if arr.len() != 2 { - cx.throw_error("Input must be an array of `String | [String, String]`") - } else { - Ok(tk::tokenizer::EncodeInput::Dual( - arr.get(&mut cx, 0)? - .downcast::() - .or_throw(&mut cx)? - .value(), - arr.get(&mut cx, 1)? - .downcast::() - .or_throw(&mut cx)? - .value()) - ) - } - } else { - cx.throw_error("Input must be an array of `String | [String, String]`") - } - }).collect::>>()?; - let add_special_tokens = cx.argument::(1)?.value(); - let callback = cx.argument::(2)?; - let worker = { - let this = cx.this(); - let guard = cx.lock(); - let worker = this.borrow(&guard).prepare_for_task(); - worker - }; - - let task = EncodeTask::Batch(worker, Some(inputs), add_special_tokens); - task.schedule(callback); - Ok(cx.undefined().upcast()) - } - - method encodeTokenized(mut cx) { - /// encodeTokenized( - /// sequence: (String | [String, [number, number]])[], - /// typeId?: number = 0, - /// callback: (err, Encoding) - /// ) - - let sequence = cx.argument::(0)?.to_vec(&mut cx)?; - - let type_arg = cx.argument::(1)?; - let type_id = if type_arg.downcast::().is_err() { - type_arg.downcast_or_throw::(&mut cx)?.value() as u32 - } else { - 0 - }; - - enum Mode { - NoOffsets, - Offsets, - }; - let mode = sequence.iter().next().map(|item| { - if item.downcast::().is_ok() { - Ok(Mode::NoOffsets) - } else if item.downcast::().is_ok() { - Ok(Mode::Offsets) - } else { - Err("Input must be (String | [String, [number, number]])[]") - } - }) - .unwrap() - .map_err(|e| cx.throw_error::<_, ()>(e.to_string()).unwrap_err())?; - - let mut total_len = 0; - let sequence = sequence.iter().map(|item| match mode { - Mode::NoOffsets => { - let s = item.downcast::().or_throw(&mut cx)?.value(); - let len = s.chars().count(); - total_len += len; - Ok((s, (total_len - len, total_len))) + // Start by extracting options and callback + let (options, callback) = match cx.extract_opt::(1) { + // Options were there, and extracted + Ok(Some(options)) => { + (options, cx.argument::(2)?) }, - Mode::Offsets => { - let tuple = item.downcast::().or_throw(&mut cx)?; - let s = tuple.get(&mut cx, 0)? - .downcast::() - .or_throw(&mut cx)? - .value(); - let offsets = tuple.get(&mut cx, 1)? - .downcast::() - .or_throw(&mut cx)?; - let (start, end) = ( - offsets.get(&mut cx, 0)? - .downcast::() - .or_throw(&mut cx)? - .value() as usize, - offsets.get(&mut cx, 1)? - .downcast::(). - or_throw(&mut cx)? - .value() as usize, - ); - Ok((s, (start, end))) + // Options were undefined or null + Ok(None) => { + (EncodeOptions::default(), cx.argument::(2)?) + } + // Options not specified, callback instead + Err(_) => { + (EncodeOptions::default(), cx.argument::(1)?) } - }).collect::, _>>()?; - let callback = cx.argument::(2)?; - - let worker = { - let this = cx.this(); - let guard = cx.lock(); - let worker = this.borrow(&guard).prepare_for_task(); - worker }; - let task = EncodeTokenizedTask::Single(worker, Some(sequence), type_id); - task.schedule(callback); - Ok(cx.undefined().upcast()) - } - - method encodeTokenizedBatch(mut cx) { - /// encodeTokenizedBatch( - /// sequences: (String | [String, [number, number]])[][], - /// typeId?: number = 0, - /// callback: (err, Encoding) - /// ) - - let sequences = cx.argument::(0)?.to_vec(&mut cx)?; - - let type_arg = cx.argument::(1)?; - let type_id = if type_arg.downcast::().is_err() { - type_arg.downcast_or_throw::(&mut cx)?.value() as u32 + let inputs: Vec = if options.isPretokenized { + cx.extract_vec::(0)? + .into_iter().map(|v| v.into()).collect() } else { - 0 + cx.extract_vec::(0)? + .into_iter().map(|v| v.into()).collect() }; - enum Mode { - NoOffsets, - Offsets, - }; - let mode = sequences.iter().next().map(|sequence| { - if let Ok(sequence) = sequence.downcast::().or_throw(&mut cx) { - sequence.to_vec(&mut cx).ok().map(|s| s.iter().next().map(|item| { - if item.downcast::().is_ok() { - Some(Mode::NoOffsets) - } else if item.downcast::().is_ok() { - Some(Mode::Offsets) - } else { - None - } - }).flatten()).flatten() - } else { - None - } - }) - .flatten() - .ok_or_else(|| - cx.throw_error::<_, ()>( - "Input must be (String | [String, [number, number]])[]" - ).unwrap_err() - )?; - - let sequences = sequences.into_iter().map(|sequence| { - let mut total_len = 0; - sequence.downcast::().or_throw(&mut cx)? - .to_vec(&mut cx)? - .into_iter() - .map(|item| match mode { - Mode::NoOffsets => { - let s = item.downcast::().or_throw(&mut cx)?.value(); - let len = s.chars().count(); - total_len += len; - Ok((s, (total_len - len, total_len))) - }, - Mode::Offsets => { - let tuple = item.downcast::().or_throw(&mut cx)?; - let s = tuple.get(&mut cx, 0)? - .downcast::() - .or_throw(&mut cx)? - .value(); - let offsets = tuple.get(&mut cx, 1)? - .downcast::() - .or_throw(&mut cx)?; - let (start, end) = ( - offsets.get(&mut cx, 0)? - .downcast::() - .or_throw(&mut cx)? - .value() as usize, - offsets.get(&mut cx, 1)? - .downcast::(). - or_throw(&mut cx)? - .value() as usize, - ); - Ok((s, (start, end))) - } - }).collect::, _>>() - }) - .collect::, _>>()?; - let callback = cx.argument::(2)?; - let worker = { let this = cx.this(); let guard = cx.lock(); @@ -428,7 +467,7 @@ declare_types! { worker }; - let task = EncodeTokenizedTask::Batch(worker, Some(sequences), type_id); + let task = EncodeTask::Batch(worker, Some(inputs), options.addSpecialTokens); task.schedule(callback); Ok(cx.undefined().upcast()) }