mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-10 22:58:25 +00:00
Truncate Right (#841)
* feat(tokenizers): add truncate test case * !feat(tokenizer): truncate right * refacto(tokenizers): clippy * feat(bindings): update bindings for truncate() * fix(tokenizers): remove unsafe code * refacto(tokenizers): truncate direction * truncate direction enum * compute parts ranges beforehand * 2n space because encoding is dropped at the end of procedure * update bindings * add pip install in python bindings' make test * fix(node): clippy asks to use unwrap_or_else * fix(node): lint * refacto(tokenizers): replace Vec<Range<usize>> by Vec<(usize, usize)> * refacto(bindings): add match syntax * refacto(tokenizers): use mem::replace instead of mem::swap * refacto(tokenizers): assign value the normal way
This commit is contained in:
3
bindings/node/lib/bindings/raw-encoding.d.ts
vendored
3
bindings/node/lib/bindings/raw-encoding.d.ts
vendored
@@ -142,8 +142,9 @@ export interface RawEncoding {
|
||||
* @param length The maximum length to be kept
|
||||
* @param [stride=0] The length of the previous first sequence
|
||||
* to be included in the overflowing sequence
|
||||
* @param [direction='right'] Truncate direction
|
||||
*/
|
||||
truncate(length: number, stride?: number): void;
|
||||
truncate(length: number, stride?: number, direction?: string): void;
|
||||
}
|
||||
|
||||
interface PaddingOptions {
|
||||
|
||||
@@ -254,9 +254,10 @@ export class Encoding {
|
||||
* @param length The maximum length to be kept
|
||||
* @param [stride=0] The length of the previous first sequence
|
||||
* to be included in the overflowing sequence
|
||||
* @param [direction='right'] Truncate direction
|
||||
*/
|
||||
truncate(length: number, stride?: number): void {
|
||||
this._rawEncoding.truncate(length, stride);
|
||||
truncate(length: number, stride?: number, direction = "right"): void {
|
||||
this._rawEncoding.truncate(length, stride, direction);
|
||||
this.resetInternalProperties();
|
||||
}
|
||||
|
||||
|
||||
238
bindings/node/native/Cargo.lock
generated
238
bindings/node/native/Cargo.lock
generated
@@ -28,9 +28,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "ansi_term"
|
||||
version = "0.11.0"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b"
|
||||
checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
@@ -93,9 +93,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
|
||||
|
||||
[[package]]
|
||||
name = "bitvec"
|
||||
version = "0.19.5"
|
||||
version = "0.19.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8942c8d352ae1838c9dda0b0ca2ab657696ef2232a20147cf1b30ae1a9cb4321"
|
||||
checksum = "55f93d0ef3363c364d5976646a38f04cf67cfe1d4c8d160cdea02cab2c116b33"
|
||||
dependencies = [
|
||||
"funty",
|
||||
"radium",
|
||||
@@ -114,9 +114,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.7.0"
|
||||
version = "3.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c59e7af012c713f529e7a3ee57ce9b31ddd858d4b512923602f74608b009631"
|
||||
checksum = "8f1e260c3a9040a7c19a12468758f4c16f31a81a1fe087482be9570ec864bb6c"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
@@ -176,9 +176,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.70"
|
||||
version = "1.0.72"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d26a6ce4b6a484fa3edb70f7efa6fc430fd2b87285fe8b84304fd0936faa0dc0"
|
||||
checksum = "22a9137b95ea06864e018375b72adfb7db6e6f68cfc8df5a04d00288050485ee"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
@@ -194,9 +194,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "2.33.3"
|
||||
version = "2.34.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37e58ac78573c40708d45522f0d80fa2f01cc4f9b4e2bf749807255454312002"
|
||||
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
|
||||
dependencies = [
|
||||
"ansi_term",
|
||||
"atty",
|
||||
@@ -209,13 +209,13 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "console"
|
||||
version = "0.14.1"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3993e6445baa160675931ec041a5e03ca84b9c6e32a056150d3aa2bdda0a1f45"
|
||||
checksum = "a28b32d32ca44b70c3e4acd7db1babf555fa026e385fb95f18028f88848b3c31"
|
||||
dependencies = [
|
||||
"encode_unicode",
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"regex",
|
||||
"terminal_size",
|
||||
"unicode-width",
|
||||
@@ -224,9 +224,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.1"
|
||||
version = "0.9.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0a89e2ae426ea83155dccf10c0fa6b1463ef6d5fcb44cee0b224a408fa640a62"
|
||||
checksum = "6888e10551bb93e424d8df1d07f1a8b4fceb0001a3a4b048bfc47554946f47b3"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
@@ -234,9 +234,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation-sys"
|
||||
version = "0.8.2"
|
||||
version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea221b5284a47e40033bf9b66f35f984ec0ea2931eb03505246cd27a963f981b"
|
||||
checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc"
|
||||
|
||||
[[package]]
|
||||
name = "cpufeatures"
|
||||
@@ -249,9 +249,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.2.1"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81156fece84ab6a9f2afdb109ce3ae577e42b1228441eded99bd77f627953b1a"
|
||||
checksum = "738c290dfaea84fc1ca15ad9c168d083b05a714e1efddd8edaab678dc28d2836"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
]
|
||||
@@ -409,9 +409,9 @@ checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
|
||||
|
||||
[[package]]
|
||||
name = "encoding_rs"
|
||||
version = "0.8.28"
|
||||
version = "0.8.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "80df024fbc5ac80f87dfef0d9f5209a252f2a497f7f42944cff24d8253cac065"
|
||||
checksum = "a74ea89a0a1b98f6332de42c95baff457ada66d1cb4030f9ff151b2041a1c746"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
]
|
||||
@@ -449,9 +449,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.0.21"
|
||||
version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "80edafed416a46fb378521624fab1cfa2eb514784fd8921adbe8a8d8321da811"
|
||||
checksum = "1e6988e897c1c9c485f43b47a529cef42fde0547f9d8d41a7062518f1d8fc53f"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"crc32fast",
|
||||
@@ -508,44 +508,43 @@ checksum = "fed34cd105917e91daa4da6b3728c47b068749d6a62c59811f06ed2ac71d9da7"
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.17"
|
||||
version = "0.3.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5da6ba8c3bb3c165d3c7319fc1cc8304facf1fb8db99c5de877183c08a273888"
|
||||
checksum = "7fc8cd39e3dbf865f7340dce6a2d401d24fd37c6fe6c4f0ee0de8bfca2252d27"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.17"
|
||||
version = "0.3.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "88d1c26957f23603395cd326b0ffe64124b818f4449552f960d815cfba83a53d"
|
||||
checksum = "629316e42fe7c2a0b9a65b47d159ceaa5453ab14e8f0a3c5eedbb8cd55b4a445"
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.17"
|
||||
version = "0.3.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "522de2a0fe3e380f1bc577ba0474108faf3f6b18321dbf60b3b9c39a75073377"
|
||||
checksum = "e481354db6b5c353246ccf6a728b0c5511d752c08da7260546fc0933869daa11"
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.17"
|
||||
version = "0.3.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "36ea153c13024fe480590b3e3d4cad89a0cfacecc24577b68f86c6ced9c2bc11"
|
||||
checksum = "996c6442437b62d21a32cd9906f9c41e7dc1e19a9579843fad948696769305af"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.17"
|
||||
version = "0.3.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d3d00f4eddb73e498a54394f228cd55853bdf059259e8e7bc6e69d408892e99"
|
||||
checksum = "dabf1872aaab32c886832f2276d2f5399887e2bd613698a02359e4ea83f8de12"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.17"
|
||||
version = "0.3.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "36568465210a3a6ee45e1f165136d68671471a501e632e9a98d96872222b5481"
|
||||
checksum = "41d22213122356472061ac0f1ab2cee28d2bac8491410fd68c2af53d1cedb83e"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-task",
|
||||
@@ -601,9 +600,9 @@ checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.4"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7f3675cfef6a30c8031cf9e6493ebdc3bb3272a3fea3923c4210d1830e6a472"
|
||||
checksum = "7fd819562fcebdac5afc5c113c3ec36f902840b70fd4fc458799c8ce4607ae55"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
@@ -635,9 +634,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "http"
|
||||
version = "0.2.4"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11"
|
||||
checksum = "1323096b05d41827dadeaee54c9981958c0f94e670bc94ed80037d1a7b8b186b"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
@@ -646,9 +645,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "0.4.3"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "399c583b2979440c60be0821a6199eca73bc3c8dcd9d070d75ac726e2c6186e5"
|
||||
checksum = "1ff4f84919677303da5f147645dbea6b1881f368d03ac84e1dc09031ebd7b2c6"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http",
|
||||
@@ -663,15 +662,15 @@ checksum = "acd94fdbe1d4ff688b67b04eee2e17bd50995534a61539e45adfefb45e5e5503"
|
||||
|
||||
[[package]]
|
||||
name = "httpdate"
|
||||
version = "1.0.1"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6456b8a6c8f33fee7d958fcd1b60d55b11940a79e63ae87013e6d22e26034440"
|
||||
checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "0.14.12"
|
||||
version = "0.14.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "13f67199e765030fa08fe0bd581af683f0d5bc04ea09c2b1102012c5fb90e7fd"
|
||||
checksum = "436ec0091e4f20e655156a30a0df3770fe2900aa301e548e08446ec794b6953c"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
@@ -775,9 +774,9 @@ checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4"
|
||||
|
||||
[[package]]
|
||||
name = "js-sys"
|
||||
version = "0.3.53"
|
||||
version = "0.3.55"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e4bf49d50e2961077d9c99f4b7997d770a1114f087c3c2e0069b36c13fc2979d"
|
||||
checksum = "7cc9ffccd38c451a86bf13657df244e9c3f37493cce8e5e21e940963777acc84"
|
||||
dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
@@ -803,9 +802,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.101"
|
||||
version = "0.2.109"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3cb00336871be5ed2c8ed44b60ae9959dc5b9f08539422ed43f09e34ecaeba21"
|
||||
checksum = "f98a04dce437184842841303488f70d0188c5f51437d2a834dc097eafa909a01"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
@@ -830,9 +829,9 @@ checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525"
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.6.4"
|
||||
version = "0.6.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59accc507f1338036a0477ef61afdae33cde60840f4dfe481319ce3ad116ddf9"
|
||||
checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
@@ -855,9 +854,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "0.7.13"
|
||||
version = "0.7.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8c2bdb6314ec10835cd3293dd268473a835c02b7b352e788be788b3c6ca6bb16"
|
||||
checksum = "8067b404fe97c70829f082dec8bcf4f71225d7eaea1d8645349cb76fa06205cc"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
@@ -1077,9 +1076,9 @@ checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56"
|
||||
|
||||
[[package]]
|
||||
name = "onig"
|
||||
version = "6.2.0"
|
||||
version = "6.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b16fd3c0e73b516af509c13c4ba76ec0c987ce20d78b38cff356b8d01fc6a6c0"
|
||||
checksum = "67ddfe2c93bb389eea6e6d713306880c7f6dcc99a75b659ce145d962c861b225"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"lazy_static",
|
||||
@@ -1089,9 +1088,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "onig_sys"
|
||||
version = "69.7.0"
|
||||
version = "69.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fd9442a09e4fbd08d196ddf419b2c79a43c3a46c800320cc841d45c2449a240"
|
||||
checksum = "5dd3eee045c84695b53b20255bb7317063df090b68e18bfac0abb6c39cf7f33e"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"pkg-config",
|
||||
@@ -1105,9 +1104,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.36"
|
||||
version = "0.10.38"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8d9facdb76fec0b73c406f125d44d86fdad818d66fef0531eec9233ca425ff4a"
|
||||
checksum = "0c7ae222234c30df141154f159066c5093ff73b63204dcda7121eb082fc56a95"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cfg-if 1.0.0",
|
||||
@@ -1125,9 +1124,9 @@ checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.66"
|
||||
version = "0.9.71"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1996d2d305e561b70d1ee0c53f1542833f4e1ac6ce9a6708b6ff2738ca67dc82"
|
||||
checksum = "7df13d165e607909b363a4757a6f133f8a818a74e9d3a98d09c6128e15fa4c73"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"cc",
|
||||
@@ -1156,30 +1155,30 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "pkg-config"
|
||||
version = "0.3.19"
|
||||
version = "0.3.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c"
|
||||
checksum = "12295df4f294471248581bc09bef3c38a5e46f1e36d6a37353621a0c6c357e1f"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.10"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
|
||||
checksum = "ed0cfbc8191465bed66e1718596ee0b0b35d5ee1f41c5df2189d0fe8bde535ba"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.29"
|
||||
version = "1.0.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9f5105d4fdaab20335ca9565e106a5d9b82b6219b5ba735731124ac6711d23d"
|
||||
checksum = "ba508cc11742c0dc5c1659771673afbab7a0efab23aa17e854cbab0837ed0b43"
|
||||
dependencies = [
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.9"
|
||||
version = "1.0.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7"
|
||||
checksum = "38bc8cc6a5f2e3655e0899c1b848643b2562f853f114bfec7be120678e3ace05"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
@@ -1354,9 +1353,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.11.4"
|
||||
version = "0.11.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "246e9f61b9bb77df069a947682be06e31ac43ea37862e244a69f177694ea6d22"
|
||||
checksum = "07bea77bc708afa10e59905c3d4af7c8fd43c9214251673095ff8b14345fcbc5"
|
||||
dependencies = [
|
||||
"base64 0.13.0",
|
||||
"bytes",
|
||||
@@ -1376,6 +1375,7 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
@@ -1394,9 +1394,9 @@ checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.5"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"
|
||||
checksum = "3c9613b5a66ab9ba26415184cfc41156594925a9cf3a2057e57f31ff145f6568"
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
@@ -1474,9 +1474,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.67"
|
||||
version = "1.0.72"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a7f9e390c27c3c0ce8bc5d725f6e4d30a29d26659494aa4b17535f7522c5c950"
|
||||
checksum = "d0ffa0837f2dfa6fb90868c2b5468cad482e175f7dad97e7421951e663f2b527"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
@@ -1497,9 +1497,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.9.6"
|
||||
version = "0.9.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9204c41a1597a8c5af23c82d1c921cb01ec0a4c59e07a9c7306062829a3903f3"
|
||||
checksum = "b69f9a4c9740d74c5baa3fd2e547f9525fa8088a8a958e0ca2409a514e33f5fa"
|
||||
dependencies = [
|
||||
"block-buffer",
|
||||
"cfg-if 1.0.0",
|
||||
@@ -1510,21 +1510,21 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.4"
|
||||
version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c307a32c1c5c437f38c7fd45d753050587732ba8628319fbdf12a7e289ccc590"
|
||||
checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5"
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.6.1"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e"
|
||||
checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309"
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.4.1"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "765f090f0e423d2b55843402a07915add955e7d60657db13707a159727326cad"
|
||||
checksum = "5dc90fe6c7be1a323296982db1836d1ea9e47b6839496dde9a541bc496df3516"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"winapi",
|
||||
@@ -1562,9 +1562,9 @@ checksum = "6446ced80d6c486436db5c078dde11a9f73d42b57fb273121e160b84f63d894c"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.75"
|
||||
version = "1.0.82"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b7f58f7e8eaa0009c5fec437aabf511bd9933e4b2d7407bd05273c01a8906ea7"
|
||||
checksum = "8daf5dd0bb60cbd4137b1b587d2fc0ae729bc07cf01cd70b36a1ed5ade3b9d59"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -1623,18 +1623,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.28"
|
||||
version = "1.0.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "283d5230e63df9608ac7d9691adc1dfb6e701225436eb64d0b9a7f0a5a04f6ec"
|
||||
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.28"
|
||||
version = "1.0.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa3884228611f5cd3608e2d409bf7dce832e4eb3135e3f11addbd7e41bd68e71"
|
||||
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -1653,9 +1653,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.3.1"
|
||||
version = "1.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "848a1e1181b9f6753b5e96a092749e29b11d19ede67dfbbd6c7dc7e0f49b5338"
|
||||
checksum = "2c1c1d5a42b6245520c249549ec267180beaffcc0615401ac8e31853d4b6d8d2"
|
||||
dependencies = [
|
||||
"tinyvec_macros",
|
||||
]
|
||||
@@ -1696,9 +1696,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.11.0"
|
||||
version = "1.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4efe6fc2395938c8155973d7be49fe8d03a843726e285e100a8a383cc0154ce"
|
||||
checksum = "70e992e41e0d2fb9f755b37446f20900f64446ef54874f40a60c78f021ac6144"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"bytes",
|
||||
@@ -1722,9 +1722,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.6.7"
|
||||
version = "0.6.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1caa0b0c8d94a049db56b5acf8cba99dc0623aab1b26d5b5f5e2d945846b3592"
|
||||
checksum = "9e99e1983e5d376cd8eb4b66604d2e99e79f5bd988c3055891dcd8c9e2604cc0"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
@@ -1742,9 +1742,9 @@ checksum = "360dfd1d6d30e05fda32ace2c8c70e9c0a9da713275777f5a4dbb8a1893930c6"
|
||||
|
||||
[[package]]
|
||||
name = "tracing"
|
||||
version = "0.1.26"
|
||||
version = "0.1.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09adeb8c97449311ccd28a427f96fb563e7fd31aabf994189879d9da2394b89d"
|
||||
checksum = "375a639232caf30edfc78e8d89b2d4c375515393e7af7e16f01cd96917fb2105"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"pin-project-lite",
|
||||
@@ -1753,9 +1753,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tracing-core"
|
||||
version = "0.1.19"
|
||||
version = "0.1.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2ca517f43f0fb96e0c3072ed5c275fe5eece87e8cb52f4a77b69226d3b1c9df8"
|
||||
checksum = "1f4ed65637b8390770814083d20756f87bfa2c21bf2f110babdc5438351746e4"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
]
|
||||
@@ -1774,9 +1774,9 @@ checksum = "b63708a265f51345575b27fe43f9500ad611579e764c79edbc2037b1121959ec"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-bidi"
|
||||
version = "0.3.6"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "246f4c42e67e7a4e3c6106ff716a5d067d4132a642840b242e357e468a2a0085"
|
||||
checksum = "1a01404663e3db436ed2746d9fefef640d868edae3cceb81c3b8d5732fda678f"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-normalization"
|
||||
@@ -1804,9 +1804,9 @@ checksum = "8895849a949e7845e06bd6dc1aa51731a103c42707010a5b591c0038fb73385b"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.1.8"
|
||||
version = "0.1.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3"
|
||||
checksum = "3ed742d4ea2bd1176e236172c8429aaf54486e7ac098db29ffe6529e0ce50973"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
@@ -1874,21 +1874,19 @@ checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.76"
|
||||
version = "0.2.78"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ce9b1b516211d33767048e5d47fa2a381ed8b76fc48d2ce4aa39877f9f183e0"
|
||||
checksum = "632f73e236b219150ea279196e54e610f5dbafa5d61786303d4da54f84e47fce"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"wasm-bindgen-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-backend"
|
||||
version = "0.2.76"
|
||||
version = "0.2.78"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cfe8dc78e2326ba5f845f4b5bf548401604fa20b1dd1d365fb73b6c1d6364041"
|
||||
checksum = "a317bf8f9fba2476b4b2c85ef4c4af8ff39c3c7f0cdfeed4f82c34a880aa837b"
|
||||
dependencies = [
|
||||
"bumpalo",
|
||||
"lazy_static",
|
||||
@@ -1901,9 +1899,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-futures"
|
||||
version = "0.4.26"
|
||||
version = "0.4.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95fded345a6559c2cfee778d562300c581f7d4ff3edb9b0d230d69800d213972"
|
||||
checksum = "8e8d7523cb1f2a4c96c1317ca690031b714a51cc14e05f712446691f413f5d39"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"js-sys",
|
||||
@@ -1913,9 +1911,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro"
|
||||
version = "0.2.76"
|
||||
version = "0.2.78"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44468aa53335841d9d6b6c023eaab07c0cd4bddbcfdee3e2bb1e8d2cb8069fef"
|
||||
checksum = "d56146e7c495528bf6587663bea13a8eb588d39b36b679d83972e1a2dbbdacf9"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"wasm-bindgen-macro-support",
|
||||
@@ -1923,9 +1921,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro-support"
|
||||
version = "0.2.76"
|
||||
version = "0.2.78"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0195807922713af1e67dc66132c7328206ed9766af3858164fb583eedc25fbad"
|
||||
checksum = "7803e0eea25835f8abdc585cd3021b3deb11543c6fe226dcd30b228857c5c5ab"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -1936,15 +1934,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-shared"
|
||||
version = "0.2.76"
|
||||
version = "0.2.78"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "acdb075a845574a1fa5f09fd77e43f7747599301ea3417a9fbffdeedfc1f4a29"
|
||||
checksum = "0237232789cf037d5480773fe568aac745bfe2afbc11a863e97901780a6b47cc"
|
||||
|
||||
[[package]]
|
||||
name = "web-sys"
|
||||
version = "0.3.53"
|
||||
version = "0.3.55"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "224b2f6b67919060055ef1a67807367c2066ed520c3862cc013d26cf893a783c"
|
||||
checksum = "38eb105f1c59d9eaa6b5cdc92b859d85b926e82cb2e0945cd0c9259faa6fe9fb"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
|
||||
@@ -4,6 +4,8 @@ use crate::extraction::*;
|
||||
use crate::tokenizer::PaddingParams;
|
||||
use neon::prelude::*;
|
||||
|
||||
use tk::utils::truncation::TruncateDirection;
|
||||
|
||||
/// Encoding
|
||||
pub struct Encoding {
|
||||
pub encoding: Option<tk::tokenizer::Encoding>,
|
||||
@@ -340,16 +342,23 @@ declare_types! {
|
||||
}
|
||||
|
||||
method truncate(mut cx) {
|
||||
// truncate(length: number, stride: number = 0)
|
||||
// truncate(length: number, stride: number = 0, direction: string = 'right')
|
||||
|
||||
let length = cx.extract::<usize>(0)?;
|
||||
let stride = cx.extract_opt::<usize>(1)?.unwrap_or(0);
|
||||
let direction = cx.extract_opt::<String>(2)?.unwrap_or_else(|| String::from("right"));
|
||||
|
||||
let tdir = match direction.as_str() {
|
||||
"left" => TruncateDirection::Left,
|
||||
"right" => TruncateDirection::Right,
|
||||
_ => panic!("Invalid truncation direction value : {}", direction),
|
||||
};
|
||||
|
||||
let mut this = cx.this();
|
||||
let guard = cx.lock();
|
||||
this.borrow_mut(&guard)
|
||||
.encoding.as_mut().expect("Uninitialized Encoding")
|
||||
.truncate(length, stride);
|
||||
.truncate(length, stride, tdir);
|
||||
|
||||
Ok(cx.undefined().upcast())
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ TESTS_RESOURCES = $(DATA_DIR)/small.txt $(DATA_DIR)/roberta.json
|
||||
|
||||
# Launch the test suite
|
||||
test: $(TESTS_RESOURCES)
|
||||
pip install pytest requests setuptools_rust numpy pyarrow datasets
|
||||
python -m pytest -s -v tests
|
||||
cargo test --no-default-features
|
||||
|
||||
|
||||
@@ -286,7 +286,7 @@ class Encoding:
|
||||
:obj:`List[str]`: The list of tokens
|
||||
"""
|
||||
pass
|
||||
def truncate(self, max_length, stride=0):
|
||||
def truncate(self, max_length, stride=0, direction="right"):
|
||||
"""
|
||||
Truncate the :class:`~tokenizers.Encoding` at the given length
|
||||
|
||||
@@ -299,6 +299,9 @@ class Encoding:
|
||||
|
||||
stride (:obj:`int`, defaults to :obj:`0`):
|
||||
The length of previous content to be included in each overflowing piece
|
||||
|
||||
direction (:obj:`str`, defaults to :obj:`right`)
|
||||
Truncate direction
|
||||
"""
|
||||
pass
|
||||
@property
|
||||
|
||||
@@ -3,6 +3,7 @@ use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
use pyo3::{PyObjectProtocol, PySequenceProtocol};
|
||||
use tk::tokenizer::{Offsets, PaddingDirection};
|
||||
use tk::utils::truncation::TruncateDirection;
|
||||
use tokenizers as tk;
|
||||
|
||||
use crate::error::{deprecation_warning, PyError};
|
||||
@@ -439,9 +440,19 @@ impl PyEncoding {
|
||||
///
|
||||
/// stride (:obj:`int`, defaults to :obj:`0`):
|
||||
/// The length of previous content to be included in each overflowing piece
|
||||
///
|
||||
/// direction (:obj:`str`, defaults to :obj:`right`)
|
||||
/// Truncate direction
|
||||
#[args(stride = "0")]
|
||||
#[text_signature = "(self, max_length, stride=0)"]
|
||||
fn truncate(&mut self, max_length: usize, stride: usize) {
|
||||
self.encoding.truncate(max_length, stride);
|
||||
#[args(direction = "\"right\"")]
|
||||
#[text_signature = "(self, max_length, stride=0, direction='right')"]
|
||||
fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) {
|
||||
let tdir = match direction {
|
||||
"left" => TruncateDirection::Left,
|
||||
"right" => TruncateDirection::Right,
|
||||
_ => panic!("Invalid truncation direction value : {}", direction),
|
||||
};
|
||||
|
||||
self.encoding.truncate(max_length, stride, tdir);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::parallelism::*;
|
||||
use crate::tokenizer::{Offsets, Token};
|
||||
use crate::utils::padding::PaddingDirection;
|
||||
use crate::utils::truncation::TruncateDirection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Range;
|
||||
@@ -294,9 +295,10 @@ impl Encoding {
|
||||
|
||||
/// Truncate the current `Encoding`.
|
||||
///
|
||||
/// Panic if `stride >= max_len`
|
||||
pub fn truncate(&mut self, max_len: usize, stride: usize) {
|
||||
if max_len >= self.ids.len() {
|
||||
/// Panics if `stride >= max_len`
|
||||
pub fn truncate(&mut self, max_len: usize, stride: usize, direction: TruncateDirection) {
|
||||
let encoding_len = self.ids.len();
|
||||
if max_len >= encoding_len {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -306,78 +308,75 @@ impl Encoding {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the main overflowing part
|
||||
let o_ids = self.ids.split_off(max_len);
|
||||
let o_type_ids = self.type_ids.split_off(max_len);
|
||||
let o_tokens = self.tokens.split_off(max_len);
|
||||
let o_words = self.words.split_off(max_len);
|
||||
let o_offsets = self.offsets.split_off(max_len);
|
||||
let o_spe_toks = self.special_tokens_mask.split_off(max_len);
|
||||
let o_attent = self.attention_mask.split_off(max_len);
|
||||
assert!(stride < max_len);
|
||||
|
||||
// When truncating, we loose the `sequence_ranges` information.
|
||||
// When truncating, we lose the `sequence_ranges` information.
|
||||
self.sequence_ranges.clear();
|
||||
|
||||
// Now we need to separate the overflowing part into as many Encoding as needed
|
||||
assert!(stride < max_len);
|
||||
let part_size = max_len - stride;
|
||||
let mut overflowing = vec![];
|
||||
let mut part_id = 0;
|
||||
let mut prev_encoding: &Encoding = self;
|
||||
|
||||
loop {
|
||||
if part_size * part_id >= o_ids.len() {
|
||||
break;
|
||||
let offset = max_len - stride;
|
||||
let mut end = false;
|
||||
let parts_ranges: Vec<(usize, usize)> = match direction {
|
||||
TruncateDirection::Right => (0..encoding_len)
|
||||
.step_by(offset)
|
||||
.filter_map(|start| {
|
||||
if !end {
|
||||
let stop = std::cmp::min(start + max_len, encoding_len);
|
||||
end = stop == encoding_len;
|
||||
Some((start, stop))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
TruncateDirection::Left => (0..encoding_len)
|
||||
.rev()
|
||||
.step_by(offset)
|
||||
.filter_map(|stop| {
|
||||
let stop = stop + 1;
|
||||
let start = if stop < max_len { 0 } else { stop - max_len };
|
||||
if start < stop && !end {
|
||||
end = start == 0;
|
||||
Some((start, stop))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let o = Encoding {
|
||||
ids: get_current_part(&prev_encoding.ids, &o_ids, part_size, part_id, stride),
|
||||
type_ids: get_current_part(
|
||||
&prev_encoding.type_ids,
|
||||
&o_type_ids,
|
||||
part_size,
|
||||
part_id,
|
||||
stride,
|
||||
),
|
||||
tokens: get_current_part(
|
||||
&prev_encoding.tokens,
|
||||
&o_tokens,
|
||||
part_size,
|
||||
part_id,
|
||||
stride,
|
||||
),
|
||||
words: get_current_part(&prev_encoding.words, &o_words, part_size, part_id, stride),
|
||||
offsets: get_current_part(
|
||||
&prev_encoding.offsets,
|
||||
&o_offsets,
|
||||
part_size,
|
||||
part_id,
|
||||
stride,
|
||||
),
|
||||
special_tokens_mask: get_current_part(
|
||||
&prev_encoding.special_tokens_mask,
|
||||
&o_spe_toks,
|
||||
part_size,
|
||||
part_id,
|
||||
stride,
|
||||
),
|
||||
attention_mask: get_current_part(
|
||||
&prev_encoding.attention_mask,
|
||||
&o_attent,
|
||||
part_size,
|
||||
part_id,
|
||||
stride,
|
||||
),
|
||||
let mut i = 0;
|
||||
let (start, stop) = parts_ranges[i];
|
||||
let mut new_encoding = Encoding {
|
||||
ids: self.ids[start..stop].to_vec(),
|
||||
type_ids: self.type_ids[start..stop].to_vec(),
|
||||
tokens: self.tokens[start..stop].to_vec(),
|
||||
words: self.words[start..stop].to_vec(),
|
||||
offsets: self.offsets[start..stop].to_vec(),
|
||||
special_tokens_mask: self.special_tokens_mask[start..stop].to_vec(),
|
||||
attention_mask: self.attention_mask[start..stop].to_vec(),
|
||||
overflowing: vec![],
|
||||
sequence_ranges: HashMap::new(),
|
||||
};
|
||||
|
||||
part_id += 1;
|
||||
overflowing.push(o);
|
||||
prev_encoding = overflowing.last().unwrap();
|
||||
loop {
|
||||
if i == parts_ranges.len() - 1 {
|
||||
break;
|
||||
}
|
||||
|
||||
self.overflowing = overflowing;
|
||||
i += 1;
|
||||
let (start, stop) = parts_ranges[i];
|
||||
new_encoding.overflowing.push(Encoding {
|
||||
ids: self.ids[start..stop].to_vec(),
|
||||
type_ids: self.type_ids[start..stop].to_vec(),
|
||||
tokens: self.tokens[start..stop].to_vec(),
|
||||
words: self.words[start..stop].to_vec(),
|
||||
offsets: self.offsets[start..stop].to_vec(),
|
||||
special_tokens_mask: self.special_tokens_mask[start..stop].to_vec(),
|
||||
attention_mask: self.attention_mask[start..stop].to_vec(),
|
||||
overflowing: vec![],
|
||||
sequence_ranges: HashMap::new(),
|
||||
});
|
||||
}
|
||||
*self = new_encoding;
|
||||
}
|
||||
|
||||
/// Merge all Encodings together
|
||||
@@ -543,23 +542,6 @@ impl std::iter::FromIterator<(u32, String, (usize, usize), Option<u32>, u32)> fo
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_current_part<T: Clone>(
|
||||
prev: &[T],
|
||||
current: &[T],
|
||||
size: usize,
|
||||
idx: usize,
|
||||
stride: usize,
|
||||
) -> Vec<T> {
|
||||
let curr_slice = if (idx + 1) * size > current.len() {
|
||||
¤t[idx * size..]
|
||||
} else {
|
||||
¤t[idx * size..(idx + 1) * size]
|
||||
};
|
||||
let prev_slice = &prev[prev.len() - stride..];
|
||||
[prev_slice, curr_slice].concat()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -620,7 +602,7 @@ mod tests {
|
||||
attention_mask: vec![1, 1, 1],
|
||||
..Default::default()
|
||||
};
|
||||
a.truncate(2, 0);
|
||||
a.truncate(2, 0, TruncateDirection::Right);
|
||||
|
||||
assert_eq!(
|
||||
a,
|
||||
@@ -663,7 +645,7 @@ mod tests {
|
||||
attention_mask: vec![1, 1, 1],
|
||||
..Default::default()
|
||||
};
|
||||
a.truncate(0, 0);
|
||||
a.truncate(0, 0, TruncateDirection::Right);
|
||||
|
||||
assert_eq!(
|
||||
a,
|
||||
@@ -688,6 +670,105 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_overflow_with_stride() {
|
||||
let mut enc = Encoding {
|
||||
ids: vec![1, 2, 3, 4, 5],
|
||||
type_ids: vec![0, 0, 0, 0, 0],
|
||||
tokens: vec![
|
||||
String::from("42"),
|
||||
String::from("is"),
|
||||
String::from("the"),
|
||||
String::from("answer"),
|
||||
String::from("!"),
|
||||
],
|
||||
words: vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
|
||||
offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13), (13, 14)],
|
||||
special_tokens_mask: vec![0, 0, 0, 0, 0],
|
||||
attention_mask: vec![1, 1, 1, 1, 1],
|
||||
overflowing: vec![],
|
||||
..Default::default()
|
||||
};
|
||||
enc.truncate(4, 2, TruncateDirection::Right);
|
||||
|
||||
assert_eq!(
|
||||
enc,
|
||||
Encoding {
|
||||
ids: vec![1, 2, 3, 4],
|
||||
type_ids: vec![0, 0, 0, 0],
|
||||
tokens: vec![
|
||||
String::from("42"),
|
||||
String::from("is"),
|
||||
String::from("the"),
|
||||
String::from("answer"),
|
||||
],
|
||||
words: vec![Some(0), Some(1), Some(2), Some(3)],
|
||||
offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13)],
|
||||
special_tokens_mask: vec![0, 0, 0, 0],
|
||||
attention_mask: vec![1, 1, 1, 1],
|
||||
overflowing: vec![Encoding {
|
||||
ids: vec![3, 4, 5],
|
||||
type_ids: vec![0, 0, 0],
|
||||
tokens: vec![
|
||||
String::from("the"),
|
||||
String::from("answer"),
|
||||
String::from("!"),
|
||||
],
|
||||
words: vec![Some(2), Some(3), Some(4)],
|
||||
offsets: vec![(4, 7), (7, 13), (13, 14)],
|
||||
special_tokens_mask: vec![0, 0, 0],
|
||||
attention_mask: vec![1, 1, 1],
|
||||
overflowing: vec![],
|
||||
..Default::default()
|
||||
}],
|
||||
..Default::default()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_left() {
|
||||
let mut a = Encoding {
|
||||
ids: vec![1, 2, 3],
|
||||
type_ids: vec![0, 0, 0],
|
||||
tokens: vec![
|
||||
String::from("Hello"),
|
||||
String::from("World"),
|
||||
String::from("!"),
|
||||
],
|
||||
words: vec![Some(0), Some(1), Some(2)],
|
||||
offsets: vec![(0, 5), (6, 11), (11, 12)],
|
||||
special_tokens_mask: vec![0, 0, 0],
|
||||
attention_mask: vec![1, 1, 1],
|
||||
..Default::default()
|
||||
};
|
||||
a.truncate(2, 0, TruncateDirection::Left);
|
||||
|
||||
assert_eq!(
|
||||
a,
|
||||
Encoding {
|
||||
ids: vec![2, 3],
|
||||
type_ids: vec![0, 0],
|
||||
tokens: vec![String::from("World"), String::from("!")],
|
||||
words: vec![Some(1), Some(2)],
|
||||
offsets: vec![(6, 11), (11, 12)],
|
||||
special_tokens_mask: vec![0, 0],
|
||||
attention_mask: vec![1, 1],
|
||||
overflowing: vec![Encoding {
|
||||
ids: vec![1],
|
||||
type_ids: vec![0],
|
||||
tokens: vec![String::from("Hello")],
|
||||
words: vec![Some(0)],
|
||||
offsets: vec![(0, 5)],
|
||||
special_tokens_mask: vec![0],
|
||||
attention_mask: vec![1],
|
||||
..Default::default()
|
||||
}],
|
||||
..Default::default()
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mappings() {
|
||||
let encoding = Encoding {
|
||||
|
||||
@@ -3,6 +3,11 @@ use serde::{Deserialize, Serialize};
|
||||
use std::cmp;
|
||||
use std::mem;
|
||||
|
||||
pub enum TruncateDirection {
|
||||
Left,
|
||||
Right,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TruncationParams {
|
||||
pub max_length: usize,
|
||||
@@ -67,9 +72,9 @@ pub fn truncate_encodings(
|
||||
params: &TruncationParams,
|
||||
) -> Result<(Encoding, Option<Encoding>)> {
|
||||
if params.max_length == 0 {
|
||||
encoding.truncate(0, params.stride);
|
||||
encoding.truncate(0, params.stride, TruncateDirection::Right);
|
||||
if let Some(other_encoding) = pair_encoding.as_mut() {
|
||||
other_encoding.truncate(0, params.stride);
|
||||
other_encoding.truncate(0, params.stride, TruncateDirection::Right);
|
||||
}
|
||||
return Ok((encoding, pair_encoding));
|
||||
}
|
||||
@@ -129,10 +134,14 @@ pub fn truncate_encodings(
|
||||
if swap {
|
||||
mem::swap(&mut n1, &mut n2);
|
||||
}
|
||||
encoding.truncate(n1, params.stride);
|
||||
other_encoding.truncate(n2, params.stride);
|
||||
encoding.truncate(n1, params.stride, TruncateDirection::Right);
|
||||
other_encoding.truncate(n2, params.stride, TruncateDirection::Right);
|
||||
} else {
|
||||
encoding.truncate(total_length - to_remove, params.stride);
|
||||
encoding.truncate(
|
||||
total_length - to_remove,
|
||||
params.stride,
|
||||
TruncateDirection::Right,
|
||||
);
|
||||
}
|
||||
}
|
||||
TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
|
||||
@@ -146,7 +155,11 @@ pub fn truncate_encodings(
|
||||
|
||||
let target_len = target.get_ids().len();
|
||||
if target_len > to_remove {
|
||||
target.truncate(target_len - to_remove, params.stride);
|
||||
target.truncate(
|
||||
target_len - to_remove,
|
||||
params.stride,
|
||||
TruncateDirection::Right,
|
||||
);
|
||||
} else {
|
||||
return Err(Box::new(TruncationError::SequenceTooShort));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user