Mirror of https://github.com/mii443/tokenizers.git
Synced 2025-08-22 16:25:30 +00:00
Adding truncation_side within TruncationParams. (#860)

* Add truncation to enable_truncation
* Fix typo
* Adding truncation_side within `TruncationParams`.
* Node serialization of this direction param.
* Update the test.
* Fixing warnings/lint.
* Adding stuff (can't local debug :( )
* Slow loop... ;(
* Stub.py.

Co-authored-by: Niels Rogge <niels.rogge1@gmail.com>
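The change is easiest to see from the Python side. A minimal usage sketch, mirroring the test added in this commit; the `["my", "name"]` result for the default direction is inferred from the pre-existing right-truncation behavior, not taken from the diff:

    from tokenizers import Tokenizer
    from tokenizers.models import BPE

    tokenizer = Tokenizer(BPE())
    tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

    # Default direction is "right": tokens are cut from the end.
    tokenizer.enable_truncation(max_length=2)
    assert tokenizer.encode("my name is john").tokens == ["my", "name"]

    # direction="left" cuts from the front instead, keeping the end.
    tokenizer.enable_truncation(max_length=2, direction="left")
    assert tokenizer.encode("my name is john").tokens == ["is", "john"]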
bindings/node/lib/bindings/enums.ts

@@ -4,6 +4,11 @@ export enum TruncationStrategy {
   OnlySecond = "only_second",
 }
 
+export enum TruncationDirection {
+  Left = "left",
+  Right = "right",
+}
+
 export enum PaddingDirection {
   Left = "left",
   Right = "right",
8 bindings/node/lib/bindings/tokenizer.d.ts (vendored)
@@ -1,5 +1,5 @@
 import { Decoder } from "./decoders";
-import { PaddingDirection, TruncationStrategy } from "./enums";
+import { PaddingDirection, TruncationDirection, TruncationStrategy } from "./enums";
 import { Model } from "./models";
 import { Normalizer } from "./normalizers";
 import { PostProcessor } from "./post-processors";
@@ -35,6 +35,12 @@ export interface TruncationOptions {
    * @default TruncationStrategy.LongestFirst
    */
   strategy?: TruncationStrategy;
+
+  /**
+   * Which side to truncate
+   * @default TruncationDirection.Left
+   */
+  direction?: TruncationDirection;
 }
 
 export interface TruncationConfiguration extends Required<TruncationOptions> {
bindings/node/lib/bindings/tokenizer.test.ts

@@ -3,7 +3,7 @@
 
 import { promisify } from "util";
 
-import { PaddingDirection, TruncationStrategy } from "./enums";
+import { PaddingDirection, TruncationDirection, TruncationStrategy } from "./enums";
 import { BPE } from "./models";
 import { RawEncoding } from "./raw-encoding";
 import {
@@ -376,6 +376,7 @@ describe("Tokenizer", () => {
       maxLength: 2,
       strategy: TruncationStrategy.LongestFirst,
       stride: 0,
+      direction: TruncationDirection.Right,
     };
     expect(truncation).toEqual(expectedConfig);
   });
bindings/node/lib/implementations/encoding.ts

@@ -36,7 +36,7 @@ export class Encoding {
     return this._rawEncoding.getNSequences();
   }
 
-  setSequenceId(seqId: number) {
+  setSequenceId(seqId: number): void {
     return this._rawEncoding.setSequenceId(seqId);
   }
 
bindings/node/lib/implementations/tokenizers/base.tokenizer.test.ts

@@ -1,4 +1,8 @@
-import { PaddingDirection, TruncationStrategy } from "../../bindings/enums";
+import {
+  PaddingDirection,
+  TruncationDirection,
+  TruncationStrategy,
+} from "../../bindings/enums";
 import { BPE } from "../../bindings/models";
 import {
   PaddingConfiguration,
@@ -29,6 +33,7 @@ describe("BaseTokenizer", () => {
     const expectedConfig: TruncationConfiguration = {
       maxLength: 2,
       strategy: TruncationStrategy.LongestFirst,
+      direction: TruncationDirection.Right,
       stride: 0,
     };
     expect(tokenizer.truncation).toEqual(expectedConfig);
bindings/node/native/src/tokenizer.rs

@@ -259,6 +259,13 @@ pub enum TruncationStrategyDef {
     OnlySecond,
 }
 
+#[derive(Serialize, Deserialize)]
+#[serde(remote = "tk::TruncationDirection", rename_all = "camelCase")]
+pub enum TruncationDirectionDef {
+    Left,
+    Right,
+}
+
 #[derive(Serialize, Deserialize)]
 #[serde(
     remote = "tk::TruncationParams",
@@ -269,6 +276,8 @@ pub struct TruncationParamsDef {
     max_length: usize,
     #[serde(with = "TruncationStrategyDef")]
     strategy: tk::TruncationStrategy,
+    #[serde(with = "TruncationDirectionDef")]
+    direction: tk::TruncationDirection,
     stride: usize,
 }
 
bindings/python/py_src/tokenizers/__init__.pyi

@@ -300,7 +300,7 @@ class Encoding:
             stride (:obj:`int`, defaults to :obj:`0`):
                 The length of previous content to be included in each overflowing piece
 
-            direction (:obj:`str`, defaults to :obj:`right`)
+            direction (:obj:`str`, defaults to :obj:`right`):
                 Truncate direction
         """
         pass
@@ -743,7 +743,7 @@ class Tokenizer:
             the longest sequence in a batch.
         """
         pass
-    def enable_truncation(self, max_length, stride=0, strategy="longest_first"):
+    def enable_truncation(self, max_length, stride=0, strategy="longest_first", direction="right"):
        """
        Enable truncation

@@ -758,6 +758,9 @@ class Tokenizer:
             strategy (:obj:`str`, `optional`, defaults to :obj:`longest_first`):
                 The strategy used to truncation. Can be one of ``longest_first``, ``only_first`` or
                 ``only_second``.
+
+            direction (:obj:`str`, defaults to :obj:`right`):
+                Truncate direction
         """
         pass
     def encode(self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True):
bindings/python/src/encoding.rs

@@ -441,7 +441,7 @@ impl PyEncoding {
     ///     stride (:obj:`int`, defaults to :obj:`0`):
     ///         The length of previous content to be included in each overflowing piece
     ///
-    ///     direction (:obj:`str`, defaults to :obj:`right`)
+    ///     direction (:obj:`str`, defaults to :obj:`right`):
     ///         Truncate direction
     #[args(stride = "0")]
     #[args(direction = "\"right\"")]
bindings/python/src/tokenizer.rs

@@ -10,7 +10,7 @@ use pyo3::PyObjectProtocol;
 use tk::models::bpe::BPE;
 use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
-    TruncationParams, TruncationStrategy,
+    TruncationDirection, TruncationParams, TruncationStrategy,
 };
 use tk::utils::iter::ResultShunt;
 use tokenizers as tk;
@@ -660,8 +660,11 @@ impl PyTokenizer {
     ///     strategy (:obj:`str`, `optional`, defaults to :obj:`longest_first`):
     ///         The strategy used to truncation. Can be one of ``longest_first``, ``only_first`` or
     ///         ``only_second``.
+    ///
+    ///     direction (:obj:`str`, defaults to :obj:`right`):
+    ///         Truncate direction
     #[args(kwargs = "**")]
-    #[text_signature = "(self, max_length, stride=0, strategy='longest_first')"]
+    #[text_signature = "(self, max_length, stride=0, strategy='longest_first', direction='right')"]
     fn enable_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
         let mut params = TruncationParams {
             max_length,
@@ -687,6 +690,19 @@ impl PyTokenizer {
                     .into_pyerr::<exceptions::PyValueError>()),
                     }?
                 }
+                "direction" => {
+                    let value: &str = value.extract()?;
+                    params.direction = match value {
+                        "left" => Ok(TruncationDirection::Left),
+                        "right" => Ok(TruncationDirection::Right),
+                        _ => Err(PyError(format!(
+                            "Unknown `direction`: `{}`. Use \
+                             one of `left` or `right`.",
+                            value
+                        ))
+                        .into_pyerr::<exceptions::PyValueError>()),
+                    }?
+                }
                 _ => println!("Ignored unknown kwarg option {}", key),
             }
         }
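Seen from Python, this match arm makes an invalid value fail fast as a ValueError; a hypothetical call for illustration:

    try:
        tokenizer.enable_truncation(2, direction="middle")
    except ValueError as e:
        # Unknown `direction`: `middle`. Use one of `left` or `right`.
        print(e)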
@@ -718,6 +734,7 @@ impl PyTokenizer {
             dict.set_item("max_length", params.max_length)?;
             dict.set_item("stride", params.stride)?;
             dict.set_item("strategy", params.strategy.as_ref())?;
+            dict.set_item("direction", params.direction.as_ref())?;
 
             Ok(Some(dict))
         })
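With this extra set_item, the dict returned by the truncation getter round-trips cleanly through enable_truncation, which the Python test below exercises; a sketch, where the concrete values assume the max_length=2 setup used in the tests:

    trunc = tokenizer.truncation
    # e.g. {"max_length": 2, "stride": 0, "strategy": "longest_first",
    #       "direction": "right"}
    tokenizer.enable_truncation(**trunc)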
12 bindings/python/test.py (new file)

@@ -0,0 +1,12 @@
+from tokenizers import ByteLevelBPETokenizer
+from tokenizers import pre_tokenizers, models, Tokenizer, trainers
+
+tokenizer = Tokenizer(models.Unigram())
+tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
+trainer = trainers.UnigramTrainer(
+    vocab_size=400000000,
+    show_progress=True,
+    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "mask"]
+)
+tokenizer.train(["data/big.txt"], trainer)
+
36 bindings/python/test.txt (new file)

@@ -0,0 +1,36 @@
+<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
(this single line is repeated for all 36 lines of the file)
bindings/python/tests/bindings/test_processors.py

@@ -154,7 +154,7 @@ class TestTemplateProcessing:
         with pytest.raises(Exception, match="Cannot build Piece"):
             processor = TemplateProcessing(single="[CLS] $A: [SEP]")
         # Special tokens must be provided when used in template:
-        with pytest.raises(Exception, match="Missing SpecialToken\(s\) with id\(s\)"):
+        with pytest.raises(Exception, match="Missing SpecialToken\\(s\\) with id\\(s\\)"):
             processor = TemplateProcessing(single=["[CLS]"])
 
     def test_bert_parity(self):
bindings/python/tests/bindings/test_tokenizer.py

@@ -125,7 +125,9 @@ class TestTokenizer:
         assert type(output.ids) == list
         assert type(output.type_ids) == list
         assert type(output.offsets) == list
-        assert type(output.words) == list
+        with pytest.warns(DeprecationWarning):
+            assert type(output.words) == list
+        assert type(output.word_ids) == list
         assert type(output.special_tokens_mask) == list
         assert type(output.attention_mask) == list
         assert type(output.overflowing) == list
@@ -311,6 +313,14 @@ class TestTokenizer:
         trunc = tokenizer.truncation
         tokenizer.enable_truncation(**trunc)
 
+        # Left truncation direction
+        tokenizer.enable_truncation(2, direction="left")
+        output = tokenizer.encode("my name is john")
+        assert output.tokens == ["is", "john"]
+
+        output = tokenizer.encode("my name is john", "pair")
+        assert output.tokens == ["john", "pair"]
+
     def test_padding(self):
         tokenizer = Tokenizer(BPE())
         tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
340 bindings/python/tokenizer.json (new file)

@@ -0,0 +1,340 @@
+{
+  "version": "1.0",
+  "truncation": null,
+  "padding": null,
+  "added_tokens": [],
+  "normalizer": null,
+  "pre_tokenizer": {
+    "type": "WhitespaceSplit"
+  },
+  "post_processor": {
+    "type": "ByteLevel",
+    "add_prefix_space": true,
+    "trim_offsets": false
+  },
+  "decoder": {
+    "type": "ByteLevel",
+    "add_prefix_space": true,
+    "trim_offsets": true
+  },
+  "model": {
+    "type": "BPE",
+    "dropout": null,
+    "unk_token": null,
+    "continuing_subword_prefix": null,
+    "end_of_word_suffix": null,
+    "fuse_unk": false,
+    "vocab": {
+      "!": 0,
+      "\"": 1,
+      "#": 2,
(entries "$": 3 through "Ń": 255 continue the byte-level alphabet and are elided here)
+      "CU": 256,
+      "DO": 257,
+      "EN": 258,
+      "MEN": 259,
+      "T>": 260,
+      "es": 261,
+      "is": 262,
+      "tes": 263,
+      "CUMEN": 264,
+      "DOCUMEN": 265,
+      "test": 266,
+      "DOCUMENT>": 267,
+      "/DOCUMENT>": 268,
+      "<DOCUMENT>": 269,
+      "</DOCUMENT>": 270,
+      "\\test": 271,
+      "a}": 272,
+      "atest": 273,
+      "bl": 274,
+      "his": 275,
+      "this": 276,
+      "{bl": 277,
+      "isatest": 278,
+      "\\test{bl": 279,
+      "thisisatest": 280,
+      "\\test{bla}": 281
+    },
+    "merges": [
+      "C U",
+      "D O",
+      "E N",
+      "M EN",
+      "T >",
+      "e s",
+      "i s",
+      "t es",
+      "CU MEN",
+      "DO CUMEN",
+      "tes t",
+      "DOCUMEN T>",
+      "/ DOCUMENT>",
+      "< DOCUMENT>",
+      "< /DOCUMENT>",
+      "\\ test",
+      "a }",
+      "a test",
+      "b l",
+      "h is",
+      "t his",
+      "{ bl",
+      "is atest",
+      "\\test {bl",
+      "this isatest",
+      "\\test{bl a}"
+    ]
+  }
+}
tokenizers/src/tokenizer/mod.rs

@@ -42,7 +42,9 @@ pub use crate::processors::PostProcessorWrapper;
 // And some other types
 pub use crate::utils::iter::LinesWithEnding;
 pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};
-pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
+pub use crate::utils::truncation::{
+    truncate_encodings, TruncationDirection, TruncationParams, TruncationStrategy,
+};
 pub use added_vocabulary::*;
 pub use encoding::*;
 pub use normalizer::{NormalizedString, OffsetReferential, SplitDelimiterBehavior};
tokenizers/src/utils/truncation.rs

@@ -3,13 +3,24 @@ use serde::{Deserialize, Serialize};
 use std::cmp;
 use std::mem;
 
+#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
+pub enum TruncationDirection {
+    Left,
+    Right,
+}
+
+impl std::convert::AsRef<str> for TruncationDirection {
+    fn as_ref(&self) -> &str {
+        match self {
+            TruncationDirection::Left => "left",
+            TruncationDirection::Right => "right",
+        }
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct TruncationParams {
+    pub direction: TruncationDirection,
     pub max_length: usize,
     pub strategy: TruncationStrategy,
     pub stride: usize,
@@ -21,6 +32,7 @@ impl Default for TruncationParams {
             max_length: 512,
             strategy: TruncationStrategy::LongestFirst,
             stride: 0,
+            direction: TruncationDirection::Right,
         }
     }
 }
@@ -72,9 +84,9 @@ pub fn truncate_encodings(
     params: &TruncationParams,
 ) -> Result<(Encoding, Option<Encoding>)> {
     if params.max_length == 0 {
-        encoding.truncate(0, params.stride, TruncationDirection::Right);
+        encoding.truncate(0, params.stride, params.direction);
         if let Some(other_encoding) = pair_encoding.as_mut() {
-            other_encoding.truncate(0, params.stride, TruncationDirection::Right);
+            other_encoding.truncate(0, params.stride, params.direction);
         }
         return Ok((encoding, pair_encoding));
     }
@@ -134,14 +146,10 @@ pub fn truncate_encodings(
             if swap {
                 mem::swap(&mut n1, &mut n2);
             }
-            encoding.truncate(n1, params.stride, TruncationDirection::Right);
-            other_encoding.truncate(n2, params.stride, TruncationDirection::Right);
+            encoding.truncate(n1, params.stride, params.direction);
+            other_encoding.truncate(n2, params.stride, params.direction);
         } else {
-            encoding.truncate(
-                total_length - to_remove,
-                params.stride,
-                TruncationDirection::Right,
-            );
+            encoding.truncate(total_length - to_remove, params.stride, params.direction);
         }
     }
     TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
@@ -155,11 +163,7 @@ pub fn truncate_encodings(
 
     let target_len = target.get_ids().len();
     if target_len > to_remove {
-        target.truncate(
-            target_len - to_remove,
-            params.stride,
-            TruncationDirection::Right,
-        );
+        target.truncate(target_len - to_remove, params.stride, params.direction);
     } else {
         return Err(Box::new(TruncationError::SequenceTooShort));
     }
@@ -284,6 +288,7 @@ mod tests {
             max_length: 7,
             strategy: TruncationStrategy::LongestFirst,
             stride: 0,
+            direction: TruncationDirection::Right,
         };
 
         truncate_and_assert(get_empty(), get_empty(), &params, 0, 0);
@@ -313,6 +318,7 @@ mod tests {
             max_length: 0,
             strategy: TruncationStrategy::LongestFirst,
             stride: 0,
+            direction: TruncationDirection::Right,
         };
 
         truncate_and_assert(get_empty(), get_short(), &params, 0, 0);
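The direction now threads through every truncate call in the core, including pair truncation under LongestFirst. The observable effect, restated from the Python test added above:

    tokenizer.enable_truncation(2, direction="left")
    output = tokenizer.encode("my name is john", "pair")
    # LongestFirst trims the longer sequence first, here from the left:
    assert output.tokens == ["john", "pair"]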