Adding truncation_side within TruncationParams. (#860)

* Add truncation direction to enable_truncation

* Fix typo

* Adding truncation_side within `TruncationParams`.

* Node serialization of this direction param.

* Update the test.

* Fixing warnings/lint.

* Adding stuff (can't debug locally :( )

* Slow loop... ;(

* Stub.py.

Co-authored-by: Niels Rogge <niels.rogge1@gmail.com>
Author: Nicolas Patry
Date: 2021-12-28 12:37:06 +01:00
Committed by: GitHub
Parent: c4c9de23a5
Commit: 152880ab3e
16 changed files with 478 additions and 26 deletions
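
Before the per-file diffs, a minimal usage sketch of the new parameter on the Python side. It mirrors the test added in this commit; the empty-BPE-plus-added-tokens setup is only for illustration and assumes a tokenizers build that includes this change:

from tokenizers import Tokenizer
from tokenizers.models import BPE

# Toy tokenizer: an empty BPE model whose vocabulary comes entirely from added tokens.
tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

# New in this commit: `direction` chooses which side gets cut off.
tokenizer.enable_truncation(2, direction="left")
output = tokenizer.encode("my name is john")
print(output.tokens)  # expected: ["is", "john"] (the right-most tokens are kept)

With the default direction="right", the same call keeps the left-most tokens (["my", "name"]) instead.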


@@ -4,6 +4,11 @@ export enum TruncationStrategy {
OnlySecond = "only_second",
}
export enum TruncationDirection {
Left = "left",
Right = "right",
}
export enum PaddingDirection {
Left = "left",
Right = "right",


@@ -1,5 +1,5 @@
import { Decoder } from "./decoders";
import { PaddingDirection, TruncationStrategy } from "./enums";
import { PaddingDirection, TruncationDirection, TruncationStrategy } from "./enums";
import { Model } from "./models";
import { Normalizer } from "./normalizers";
import { PostProcessor } from "./post-processors";
@@ -35,6 +35,12 @@ export interface TruncationOptions {
* @default TruncationStrategy.LongestFirst
*/
strategy?: TruncationStrategy;
/**
* Which side to truncate
* @default TruncationDirection.Right
*/
direction?: TruncationDirection;
}
export interface TruncationConfiguration extends Required<TruncationOptions> {


@@ -3,7 +3,7 @@
import { promisify } from "util";
import { PaddingDirection, TruncationStrategy } from "./enums";
import { PaddingDirection, TruncationDirection, TruncationStrategy } from "./enums";
import { BPE } from "./models";
import { RawEncoding } from "./raw-encoding";
import {
@@ -376,6 +376,7 @@ describe("Tokenizer", () => {
maxLength: 2,
strategy: TruncationStrategy.LongestFirst,
stride: 0,
direction: TruncationDirection.Right,
};
expect(truncation).toEqual(expectedConfig);
});


@@ -36,7 +36,7 @@ export class Encoding {
return this._rawEncoding.getNSequences();
}
setSequenceId(seqId: number) {
setSequenceId(seqId: number): void {
return this._rawEncoding.setSequenceId(seqId);
}


@@ -1,4 +1,8 @@
import { PaddingDirection, TruncationStrategy } from "../../bindings/enums";
import {
PaddingDirection,
TruncationDirection,
TruncationStrategy,
} from "../../bindings/enums";
import { BPE } from "../../bindings/models";
import {
PaddingConfiguration,
@@ -29,6 +33,7 @@ describe("BaseTokenizer", () => {
const expectedConfig: TruncationConfiguration = {
maxLength: 2,
strategy: TruncationStrategy.LongestFirst,
direction: TruncationDirection.Right,
stride: 0,
};
expect(tokenizer.truncation).toEqual(expectedConfig);


@@ -259,6 +259,13 @@ pub enum TruncationStrategyDef {
OnlySecond,
}
#[derive(Serialize, Deserialize)]
#[serde(remote = "tk::TruncationDirection", rename_all = "camelCase")]
pub enum TruncationDirectionDef {
Left,
Right,
}
#[derive(Serialize, Deserialize)]
#[serde(
remote = "tk::TruncationParams",
@@ -269,6 +276,8 @@ pub struct TruncationParamsDef {
max_length: usize,
#[serde(with = "TruncationStrategyDef")]
strategy: tk::TruncationStrategy,
#[serde(with = "TruncationDirectionDef")]
direction: tk::TruncationDirection,
stride: usize,
}


@@ -300,7 +300,7 @@ class Encoding:
stride (:obj:`int`, defaults to :obj:`0`):
The length of previous content to be included in each overflowing piece
direction (:obj:`str`, defaults to :obj:`right`)
direction (:obj:`str`, defaults to :obj:`right`):
Truncation direction
"""
pass
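
The per-encoding API gains the same keyword. A minimal sketch of Encoding.truncate with the new argument, assuming the same toy setup as the example above (built against this revision):

from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

# Truncate an already-computed Encoding in place, keeping the end of the sequence.
encoding = tokenizer.encode("my name is john")
encoding.truncate(2, direction="left")
print(encoding.tokens)  # expected: ["is", "john"]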
@@ -743,7 +743,7 @@ class Tokenizer:
the longest sequence in a batch.
"""
pass
def enable_truncation(self, max_length, stride=0, strategy="longest_first"):
def enable_truncation(self, max_length, stride=0, strategy="longest_first", direction="right"):
"""
Enable truncation
@@ -758,6 +758,9 @@ class Tokenizer:
strategy (:obj:`str`, `optional`, defaults to :obj:`longest_first`):
The strategy used for truncation. Can be one of ``longest_first``, ``only_first`` or
``only_second``.
direction (:obj:`str`, defaults to :obj:`right`):
Truncation direction
"""
pass
def encode(self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True):


@@ -441,7 +441,7 @@ impl PyEncoding {
/// stride (:obj:`int`, defaults to :obj:`0`):
/// The length of previous content to be included in each overflowing piece
///
/// direction (:obj:`str`, defaults to :obj:`right`)
/// direction (:obj:`str`, defaults to :obj:`right`):
/// Truncation direction
#[args(stride = "0")]
#[args(direction = "\"right\"")]


@@ -10,7 +10,7 @@ use pyo3::PyObjectProtocol;
use tk::models::bpe::BPE;
use tk::tokenizer::{
Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
TruncationParams, TruncationStrategy,
TruncationDirection, TruncationParams, TruncationStrategy,
};
use tk::utils::iter::ResultShunt;
use tokenizers as tk;
@@ -660,8 +660,11 @@ impl PyTokenizer {
/// strategy (:obj:`str`, `optional`, defaults to :obj:`longest_first`):
/// The strategy used for truncation. Can be one of ``longest_first``, ``only_first`` or
/// ``only_second``.
///
/// direction (:obj:`str`, defaults to :obj:`right`):
/// Truncation direction
#[args(kwargs = "**")]
#[text_signature = "(self, max_length, stride=0, strategy='longest_first')"]
#[text_signature = "(self, max_length, stride=0, strategy='longest_first', direction='right')"]
fn enable_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
let mut params = TruncationParams {
max_length,
@@ -687,6 +690,19 @@ impl PyTokenizer {
.into_pyerr::<exceptions::PyValueError>()),
}?
}
"direction" => {
let value: &str = value.extract()?;
params.direction = match value {
"left" => Ok(TruncationDirection::Left),
"right" => Ok(TruncationDirection::Right),
_ => Err(PyError(format!(
"Unknown `direction`: `{}`. Use \
one of `left` or `right`.",
value
))
.into_pyerr::<exceptions::PyValueError>()),
}?
}
_ => println!("Ignored unknown kwarg option {}", key),
}
}
@@ -718,6 +734,7 @@ impl PyTokenizer {
dict.set_item("max_length", params.max_length)?;
dict.set_item("stride", params.stride)?;
dict.set_item("strategy", params.strategy.as_ref())?;
dict.set_item("direction", params.direction.as_ref())?;
Ok(Some(dict))
})
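
Because the getter above now also reports the direction, the truncation configuration can be read back and re-applied as a plain dict, as the updated tests do. A short sketch (the exact keys and lowercase string values are assumptions based on this diff):

from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
tokenizer.enable_truncation(8, direction="left")

# The getter returns a plain dict that now includes "direction"...
params = tokenizer.truncation
print(params)  # e.g. {"max_length": 8, "stride": 0, "strategy": "longest_first", "direction": "left"}

# ...so it can be fed straight back into enable_truncation.
tokenizer.enable_truncation(**params)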

bindings/python/test.py (new file)

@@ -0,0 +1,12 @@
from tokenizers import ByteLevelBPETokenizer
from tokenizers import pre_tokenizers, models, Tokenizer, trainers
tokenizer = Tokenizer(models.Unigram())
tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
trainer = trainers.UnigramTrainer(
vocab_size=400000000,
show_progress=True,
special_tokens=["<s>", "<pad>", "</s>", "<unk>", "mask"]
)
tokenizer.train(["data/big.txt"], trainer)

bindings/python/test.txt (new file)

@@ -0,0 +1,36 @@
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
<DOCUMENT> \test{bla} thisisatest </DOCUMENT>


@@ -154,7 +154,7 @@ class TestTemplateProcessing:
with pytest.raises(Exception, match="Cannot build Piece"):
processor = TemplateProcessing(single="[CLS] $A: [SEP]")
# Special tokens must be provided when used in template:
with pytest.raises(Exception, match="Missing SpecialToken\(s\) with id\(s\)"):
with pytest.raises(Exception, match="Missing SpecialToken\\(s\\) with id\\(s\\)"):
processor = TemplateProcessing(single=["[CLS]"])
def test_bert_parity(self):


@@ -125,7 +125,9 @@ class TestTokenizer:
assert type(output.ids) == list
assert type(output.type_ids) == list
assert type(output.offsets) == list
assert type(output.words) == list
with pytest.warns(DeprecationWarning):
assert type(output.words) == list
assert type(output.word_ids) == list
assert type(output.special_tokens_mask) == list
assert type(output.attention_mask) == list
assert type(output.overflowing) == list
@@ -311,6 +313,14 @@ class TestTokenizer:
trunc = tokenizer.truncation
tokenizer.enable_truncation(**trunc)
# Left truncation direction
tokenizer.enable_truncation(2, direction="left")
output = tokenizer.encode("my name is john")
assert output.tokens == ["is", "john"]
output = tokenizer.encode("my name is john", "pair")
assert output.tokens == ["john", "pair"]
def test_padding(self):
tokenizer = Tokenizer(BPE())
tokenizer.add_tokens(["my", "name", "is", "john", "pair"])


@@ -0,0 +1,340 @@
{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "WhitespaceSplit"
},
"post_processor": {
"type": "ByteLevel",
"add_prefix_space": true,
"trim_offsets": false
},
"decoder": {
"type": "ByteLevel",
"add_prefix_space": true,
"trim_offsets": true
},
"model": {
"type": "BPE",
"dropout": null,
"unk_token": null,
"continuing_subword_prefix": null,
"end_of_word_suffix": null,
"fuse_unk": false,
"vocab": {
"!": 0,
"\"": 1,
"#": 2,
"$": 3,
"%": 4,
"&": 5,
"'": 6,
"(": 7,
")": 8,
"*": 9,
"+": 10,
",": 11,
"-": 12,
".": 13,
"/": 14,
"0": 15,
"1": 16,
"2": 17,
"3": 18,
"4": 19,
"5": 20,
"6": 21,
"7": 22,
"8": 23,
"9": 24,
":": 25,
";": 26,
"<": 27,
"=": 28,
">": 29,
"?": 30,
"@": 31,
"A": 32,
"B": 33,
"C": 34,
"D": 35,
"E": 36,
"F": 37,
"G": 38,
"H": 39,
"I": 40,
"J": 41,
"K": 42,
"L": 43,
"M": 44,
"N": 45,
"O": 46,
"P": 47,
"Q": 48,
"R": 49,
"S": 50,
"T": 51,
"U": 52,
"V": 53,
"W": 54,
"X": 55,
"Y": 56,
"Z": 57,
"[": 58,
"\\": 59,
"]": 60,
"^": 61,
"_": 62,
"`": 63,
"a": 64,
"b": 65,
"c": 66,
"d": 67,
"e": 68,
"f": 69,
"g": 70,
"h": 71,
"i": 72,
"j": 73,
"k": 74,
"l": 75,
"m": 76,
"n": 77,
"o": 78,
"p": 79,
"q": 80,
"r": 81,
"s": 82,
"t": 83,
"u": 84,
"v": 85,
"w": 86,
"x": 87,
"y": 88,
"z": 89,
"{": 90,
"|": 91,
"}": 92,
"~": 93,
"¡": 94,
"¢": 95,
"£": 96,
"¤": 97,
"¥": 98,
"¦": 99,
"§": 100,
"¨": 101,
"©": 102,
"ª": 103,
"«": 104,
"¬": 105,
"®": 106,
"¯": 107,
"°": 108,
"±": 109,
"²": 110,
"³": 111,
"´": 112,
"µ": 113,
"¶": 114,
"·": 115,
"¸": 116,
"¹": 117,
"º": 118,
"»": 119,
"¼": 120,
"½": 121,
"¾": 122,
"¿": 123,
"À": 124,
"Á": 125,
"Â": 126,
"Ã": 127,
"Ä": 128,
"Å": 129,
"Æ": 130,
"Ç": 131,
"È": 132,
"É": 133,
"Ê": 134,
"Ë": 135,
"Ì": 136,
"Í": 137,
"Î": 138,
"Ï": 139,
"Ð": 140,
"Ñ": 141,
"Ò": 142,
"Ó": 143,
"Ô": 144,
"Õ": 145,
"Ö": 146,
"×": 147,
"Ø": 148,
"Ù": 149,
"Ú": 150,
"Û": 151,
"Ü": 152,
"Ý": 153,
"Þ": 154,
"ß": 155,
"à": 156,
"á": 157,
"â": 158,
"ã": 159,
"ä": 160,
"å": 161,
"æ": 162,
"ç": 163,
"è": 164,
"é": 165,
"ê": 166,
"ë": 167,
"ì": 168,
"í": 169,
"î": 170,
"ï": 171,
"ð": 172,
"ñ": 173,
"ò": 174,
"ó": 175,
"ô": 176,
"õ": 177,
"ö": 178,
"÷": 179,
"ø": 180,
"ù": 181,
"ú": 182,
"û": 183,
"ü": 184,
"ý": 185,
"þ": 186,
"ÿ": 187,
"Ā": 188,
"ā": 189,
"Ă": 190,
"ă": 191,
"Ą": 192,
"ą": 193,
"Ć": 194,
"ć": 195,
"Ĉ": 196,
"ĉ": 197,
"Ċ": 198,
"ċ": 199,
"Č": 200,
"č": 201,
"Ď": 202,
"ď": 203,
"Đ": 204,
"đ": 205,
"Ē": 206,
"ē": 207,
"Ĕ": 208,
"ĕ": 209,
"Ė": 210,
"ė": 211,
"Ę": 212,
"ę": 213,
"Ě": 214,
"ě": 215,
"Ĝ": 216,
"ĝ": 217,
"Ğ": 218,
"ğ": 219,
"Ġ": 220,
"ġ": 221,
"Ģ": 222,
"ģ": 223,
"Ĥ": 224,
"ĥ": 225,
"Ħ": 226,
"ħ": 227,
"Ĩ": 228,
"ĩ": 229,
"Ī": 230,
"ī": 231,
"Ĭ": 232,
"ĭ": 233,
"Į": 234,
"į": 235,
"İ": 236,
"ı": 237,
"IJ": 238,
"ij": 239,
"Ĵ": 240,
"ĵ": 241,
"Ķ": 242,
"ķ": 243,
"ĸ": 244,
"Ĺ": 245,
"ĺ": 246,
"Ļ": 247,
"ļ": 248,
"Ľ": 249,
"ľ": 250,
"Ŀ": 251,
"ŀ": 252,
"Ł": 253,
"ł": 254,
"Ń": 255,
"CU": 256,
"DO": 257,
"EN": 258,
"MEN": 259,
"T>": 260,
"es": 261,
"is": 262,
"tes": 263,
"CUMEN": 264,
"DOCUMEN": 265,
"test": 266,
"DOCUMENT>": 267,
"/DOCUMENT>": 268,
"<DOCUMENT>": 269,
"</DOCUMENT>": 270,
"\\test": 271,
"a}": 272,
"atest": 273,
"bl": 274,
"his": 275,
"this": 276,
"{bl": 277,
"isatest": 278,
"\\test{bl": 279,
"thisisatest": 280,
"\\test{bla}": 281
},
"merges": [
"C U",
"D O",
"E N",
"M EN",
"T >",
"e s",
"i s",
"t es",
"CU MEN",
"DO CUMEN",
"tes t",
"DOCUMEN T>",
"/ DOCUMENT>",
"< DOCUMENT>",
"< /DOCUMENT>",
"\\ test",
"a }",
"a test",
"b l",
"h is",
"t his",
"{ bl",
"is atest",
"\\test {bl",
"this isatest",
"\\test{bl a}"
]
}
}