Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Adding truncation_side within `TruncationParams`. (#860)

* Add truncation to enable_truncation
* Fix typo
* Adding truncation_side within `TruncationParams`.
* Node serialization of this direction param.
* Update the test.
* Fixing warnings/lint.
* Adding stuff (can't local debug :( )
* Slow loop... ;(
* Stub.py.

Co-authored-by: Niels Rogge <niels.rogge1@gmail.com>
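In user-facing terms, the change lets callers pick which side of a too-long sequence gets cut. A minimal sketch of the new parameter, assembled from the Python test added in this commit (the tiny added-token vocabulary is only for illustration):

    from tokenizers import Tokenizer
    from tokenizers.models import BPE

    tokenizer = Tokenizer(BPE())
    tokenizer.add_tokens(["my", "name", "is", "john", "pair"])

    # direction="left" keeps the end of the sequence instead of the beginning
    tokenizer.enable_truncation(2, direction="left")
    output = tokenizer.encode("my name is john")
    assert output.tokens == ["is", "john"]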
@@ -4,6 +4,11 @@ export enum TruncationStrategy {
   OnlySecond = "only_second",
 }
 
+export enum TruncationDirection {
+  Left = "left",
+  Right = "right",
+}
+
 export enum PaddingDirection {
   Left = "left",
   Right = "right",
bindings/node/lib/bindings/tokenizer.d.ts (vendored, 8 lines changed)
@@ -1,5 +1,5 @@
 import { Decoder } from "./decoders";
-import { PaddingDirection, TruncationStrategy } from "./enums";
+import { PaddingDirection, TruncationDirection, TruncationStrategy } from "./enums";
 import { Model } from "./models";
 import { Normalizer } from "./normalizers";
 import { PostProcessor } from "./post-processors";
@@ -35,6 +35,12 @@ export interface TruncationOptions {
    * @default TruncationStrategy.LongestFirst
    */
   strategy?: TruncationStrategy;
+
+  /**
+   * Which side to truncate
+   * @default TruncationDirection.Left
+   */
+  direction?: TruncationDirection;
 }
 
 export interface TruncationConfiguration extends Required<TruncationOptions> {
@@ -3,7 +3,7 @@
 
 import { promisify } from "util";
 
-import { PaddingDirection, TruncationStrategy } from "./enums";
+import { PaddingDirection, TruncationDirection, TruncationStrategy } from "./enums";
 import { BPE } from "./models";
 import { RawEncoding } from "./raw-encoding";
 import {
@@ -376,6 +376,7 @@ describe("Tokenizer", () => {
       maxLength: 2,
       strategy: TruncationStrategy.LongestFirst,
       stride: 0,
+      direction: TruncationDirection.Right,
     };
     expect(truncation).toEqual(expectedConfig);
   });
@@ -36,7 +36,7 @@ export class Encoding {
     return this._rawEncoding.getNSequences();
   }
 
-  setSequenceId(seqId: number) {
+  setSequenceId(seqId: number): void {
     return this._rawEncoding.setSequenceId(seqId);
   }
 
@@ -1,4 +1,8 @@
-import { PaddingDirection, TruncationStrategy } from "../../bindings/enums";
+import {
+  PaddingDirection,
+  TruncationDirection,
+  TruncationStrategy,
+} from "../../bindings/enums";
 import { BPE } from "../../bindings/models";
 import {
   PaddingConfiguration,
@@ -29,6 +33,7 @@ describe("BaseTokenizer", () => {
     const expectedConfig: TruncationConfiguration = {
       maxLength: 2,
       strategy: TruncationStrategy.LongestFirst,
+      direction: TruncationDirection.Right,
       stride: 0,
     };
     expect(tokenizer.truncation).toEqual(expectedConfig);
@@ -259,6 +259,13 @@ pub enum TruncationStrategyDef {
     OnlySecond,
 }
 
+#[derive(Serialize, Deserialize)]
+#[serde(remote = "tk::TruncationDirection", rename_all = "camelCase")]
+pub enum TruncationDirectionDef {
+    Left,
+    Right,
+}
+
 #[derive(Serialize, Deserialize)]
 #[serde(
     remote = "tk::TruncationParams",
@@ -269,6 +276,8 @@ pub struct TruncationParamsDef {
     max_length: usize,
     #[serde(with = "TruncationStrategyDef")]
     strategy: tk::TruncationStrategy,
+    #[serde(with = "TruncationDirectionDef")]
+    direction: tk::TruncationDirection,
     stride: usize,
 }
 
@@ -300,7 +300,7 @@ class Encoding:
             stride (:obj:`int`, defaults to :obj:`0`):
                 The length of previous content to be included in each overflowing piece
 
-            direction (:obj:`str`, defaults to :obj:`right`)
+            direction (:obj:`str`, defaults to :obj:`right`):
                 Truncate direction
         """
        pass
@@ -743,7 +743,7 @@ class Tokenizer:
             the longest sequence in a batch.
         """
         pass
-    def enable_truncation(self, max_length, stride=0, strategy="longest_first"):
+    def enable_truncation(self, max_length, stride=0, strategy="longest_first", direction="right"):
         """
         Enable truncation
 
@@ -758,6 +758,9 @@ class Tokenizer:
             strategy (:obj:`str`, `optional`, defaults to :obj:`longest_first`):
                 The strategy used to truncation. Can be one of ``longest_first``, ``only_first`` or
                 ``only_second``.
+
+            direction (:obj:`str`, defaults to :obj:`right`):
+                Truncate direction
         """
         pass
     def encode(self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True):
@@ -441,7 +441,7 @@ impl PyEncoding {
     ///     stride (:obj:`int`, defaults to :obj:`0`):
     ///         The length of previous content to be included in each overflowing piece
     ///
-    ///     direction (:obj:`str`, defaults to :obj:`right`)
+    ///     direction (:obj:`str`, defaults to :obj:`right`):
     ///         Truncate direction
     #[args(stride = "0")]
     #[args(direction = "\"right\"")]
@@ -10,7 +10,7 @@ use pyo3::PyObjectProtocol;
 use tk::models::bpe::BPE;
 use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
-    TruncationParams, TruncationStrategy,
+    TruncationDirection, TruncationParams, TruncationStrategy,
 };
 use tk::utils::iter::ResultShunt;
 use tokenizers as tk;
@@ -660,8 +660,11 @@ impl PyTokenizer {
     ///     strategy (:obj:`str`, `optional`, defaults to :obj:`longest_first`):
     ///         The strategy used to truncation. Can be one of ``longest_first``, ``only_first`` or
     ///         ``only_second``.
+    ///
+    ///     direction (:obj:`str`, defaults to :obj:`right`):
+    ///         Truncate direction
     #[args(kwargs = "**")]
-    #[text_signature = "(self, max_length, stride=0, strategy='longest_first')"]
+    #[text_signature = "(self, max_length, stride=0, strategy='longest_first', direction='right')"]
     fn enable_truncation(&mut self, max_length: usize, kwargs: Option<&PyDict>) -> PyResult<()> {
         let mut params = TruncationParams {
             max_length,
@@ -687,6 +690,19 @@ impl PyTokenizer {
                         .into_pyerr::<exceptions::PyValueError>()),
                     }?
                 }
+                "direction" => {
+                    let value: &str = value.extract()?;
+                    params.direction = match value {
+                        "left" => Ok(TruncationDirection::Left),
+                        "right" => Ok(TruncationDirection::Right),
+                        _ => Err(PyError(format!(
+                            "Unknown `direction`: `{}`. Use \
+                             one of `left` or `right`.",
+                            value
+                        ))
+                        .into_pyerr::<exceptions::PyValueError>()),
+                    }?
+                }
                 _ => println!("Ignored unknown kwarg option {}", key),
             }
         }
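For illustration, the validation above means any value other than "left" or "right" surfaces as a ValueError on the Python side (a sketch, reusing the tokenizer from the example near the top; the error text is the `format!` message from this hunk):

    tokenizer.enable_truncation(2, direction="left")   # accepted
    tokenizer.enable_truncation(2, direction="right")  # accepted
    try:
        tokenizer.enable_truncation(2, direction="top")
    except ValueError as err:
        print(err)  # Unknown `direction`: `top`. Use one of `left` or `right`.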
@@ -718,6 +734,7 @@ impl PyTokenizer {
                 dict.set_item("max_length", params.max_length)?;
                 dict.set_item("stride", params.stride)?;
                 dict.set_item("strategy", params.strategy.as_ref())?;
+                dict.set_item("direction", params.direction.as_ref())?;
 
                 Ok(Some(dict))
             })
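Adding "direction" to the returned dict keeps Tokenizer.truncation round-trippable, a pattern the Python tests further below already exercise (sketch; the dict contents follow the setters above):

    trunc = tokenizer.truncation
    # e.g. {"max_length": 2, "stride": 0, "strategy": "longest_first", "direction": "right"}
    tokenizer.enable_truncation(**trunc)  # re-applying the returned config is accepted unchanged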
bindings/python/test.py (new file, 12 lines)
@@ -0,0 +1,12 @@
+from tokenizers import ByteLevelBPETokenizer
+from tokenizers import pre_tokenizers, models, Tokenizer, trainers
+
+tokenizer = Tokenizer(models.Unigram())
+tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
+trainer = trainers.UnigramTrainer(
+    vocab_size=400000000,
+    show_progress=True,
+    special_tokens=["<s>", "<pad>", "</s>", "<unk>", "mask"]
+)
+tokenizer.train(["data/big.txt"], trainer)
+
bindings/python/test.txt (new file, 36 lines)
@@ -0,0 +1,36 @@
+<DOCUMENT> \test{bla} thisisatest </DOCUMENT>
[… the same line is repeated on all 36 lines of the file …]
@@ -154,7 +154,7 @@ class TestTemplateProcessing:
         with pytest.raises(Exception, match="Cannot build Piece"):
             processor = TemplateProcessing(single="[CLS] $A: [SEP]")
         # Special tokens must be provided when used in template:
-        with pytest.raises(Exception, match="Missing SpecialToken\(s\) with id\(s\)"):
+        with pytest.raises(Exception, match="Missing SpecialToken\\(s\\) with id\\(s\\)"):
             processor = TemplateProcessing(single=["[CLS]"])
 
     def test_bert_parity(self):
@@ -125,7 +125,9 @@ class TestTokenizer:
         assert type(output.ids) == list
         assert type(output.type_ids) == list
         assert type(output.offsets) == list
-        assert type(output.words) == list
+        with pytest.warns(DeprecationWarning):
+            assert type(output.words) == list
+        assert type(output.word_ids) == list
         assert type(output.special_tokens_mask) == list
         assert type(output.attention_mask) == list
         assert type(output.overflowing) == list
@@ -311,6 +313,14 @@ class TestTokenizer:
         trunc = tokenizer.truncation
         tokenizer.enable_truncation(**trunc)
+
+        # Left truncation direction
+        tokenizer.enable_truncation(2, direction="left")
+        output = tokenizer.encode("my name is john")
+        assert output.tokens == ["is", "john"]
+
+        output = tokenizer.encode("my name is john", "pair")
+        assert output.tokens == ["john", "pair"]
 
     def test_padding(self):
         tokenizer = Tokenizer(BPE())
         tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
bindings/python/tokenizer.json (new file, 340 lines)
@@ -0,0 +1,340 @@
+{
+  "version": "1.0",
+  "truncation": null,
+  "padding": null,
+  "added_tokens": [],
+  "normalizer": null,
+  "pre_tokenizer": {
+    "type": "WhitespaceSplit"
+  },
+  "post_processor": {
+    "type": "ByteLevel",
+    "add_prefix_space": true,
+    "trim_offsets": false
+  },
+  "decoder": {
+    "type": "ByteLevel",
+    "add_prefix_space": true,
+    "trim_offsets": true
+  },
+  "model": {
+    "type": "BPE",
+    "dropout": null,
+    "unk_token": null,
+    "continuing_subword_prefix": null,
+    "end_of_word_suffix": null,
+    "fuse_unk": false,
+    "vocab": {
+      "!": 0,
+      "\"": 1,
[… ids 2–255 map the remaining single characters of the byte-level alphabet, "#" through "Ń", one character per id, as enumerated in the original file …]
+      "CU": 256,
+      "DO": 257,
+      "EN": 258,
+      "MEN": 259,
+      "T>": 260,
+      "es": 261,
+      "is": 262,
+      "tes": 263,
+      "CUMEN": 264,
+      "DOCUMEN": 265,
+      "test": 266,
+      "DOCUMENT>": 267,
+      "/DOCUMENT>": 268,
+      "<DOCUMENT>": 269,
+      "</DOCUMENT>": 270,
+      "\\test": 271,
+      "a}": 272,
+      "atest": 273,
+      "bl": 274,
+      "his": 275,
+      "this": 276,
+      "{bl": 277,
+      "isatest": 278,
+      "\\test{bl": 279,
+      "thisisatest": 280,
+      "\\test{bla}": 281
+    },
+    "merges": [
+      "C U",
+      "D O",
+      "E N",
+      "M EN",
+      "T >",
+      "e s",
+      "i s",
+      "t es",
+      "CU MEN",
+      "DO CUMEN",
+      "tes t",
+      "DOCUMEN T>",
+      "/ DOCUMENT>",
+      "< DOCUMENT>",
+      "< /DOCUMENT>",
+      "\\ test",
+      "a }",
+      "a test",
+      "b l",
+      "h is",
+      "t his",
+      "{ bl",
+      "is atest",
+      "\\test {bl",
+      "this isatest",
+      "\\test{bl a}"
+    ]
+  }
+}
@@ -42,7 +42,9 @@ pub use crate::processors::PostProcessorWrapper;
 // And some other types
 pub use crate::utils::iter::LinesWithEnding;
 pub use crate::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};
-pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
+pub use crate::utils::truncation::{
+    truncate_encodings, TruncationDirection, TruncationParams, TruncationStrategy,
+};
 pub use added_vocabulary::*;
 pub use encoding::*;
 pub use normalizer::{NormalizedString, OffsetReferential, SplitDelimiterBehavior};
@@ -3,13 +3,24 @@ use serde::{Deserialize, Serialize};
 use std::cmp;
 use std::mem;
 
+#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
 pub enum TruncationDirection {
     Left,
     Right,
 }
 
+impl std::convert::AsRef<str> for TruncationDirection {
+    fn as_ref(&self) -> &str {
+        match self {
+            TruncationDirection::Left => "left",
+            TruncationDirection::Right => "right",
+        }
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct TruncationParams {
+    pub direction: TruncationDirection,
     pub max_length: usize,
     pub strategy: TruncationStrategy,
     pub stride: usize,
@@ -21,6 +32,7 @@ impl Default for TruncationParams {
             max_length: 512,
             strategy: TruncationStrategy::LongestFirst,
             stride: 0,
+            direction: TruncationDirection::Right,
         }
     }
 }
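Because the Default impl above picks TruncationDirection::Right, callers that never pass a direction keep the previous right-side truncation behavior. A quick sketch of that default as seen from Python (assuming the tokenizer from the earlier example, and the dict getter added above):

    tokenizer.enable_truncation(2)  # no direction given
    assert tokenizer.truncation["direction"] == "right"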
@@ -72,9 +84,9 @@ pub fn truncate_encodings(
     params: &TruncationParams,
 ) -> Result<(Encoding, Option<Encoding>)> {
     if params.max_length == 0 {
-        encoding.truncate(0, params.stride, TruncationDirection::Right);
+        encoding.truncate(0, params.stride, params.direction);
         if let Some(other_encoding) = pair_encoding.as_mut() {
-            other_encoding.truncate(0, params.stride, TruncationDirection::Right);
+            other_encoding.truncate(0, params.stride, params.direction);
         }
         return Ok((encoding, pair_encoding));
     }
@@ -134,14 +146,10 @@ pub fn truncate_encodings(
             if swap {
                 mem::swap(&mut n1, &mut n2);
             }
-            encoding.truncate(n1, params.stride, TruncationDirection::Right);
-            other_encoding.truncate(n2, params.stride, TruncationDirection::Right);
+            encoding.truncate(n1, params.stride, params.direction);
+            other_encoding.truncate(n2, params.stride, params.direction);
         } else {
-            encoding.truncate(
-                total_length - to_remove,
-                params.stride,
-                TruncationDirection::Right,
-            );
+            encoding.truncate(total_length - to_remove, params.stride, params.direction);
         }
     }
     TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
@@ -155,11 +163,7 @@ pub fn truncate_encodings(
 
             let target_len = target.get_ids().len();
             if target_len > to_remove {
-                target.truncate(
-                    target_len - to_remove,
-                    params.stride,
-                    TruncationDirection::Right,
-                );
+                target.truncate(target_len - to_remove, params.stride, params.direction);
             } else {
                 return Err(Box::new(TruncationError::SequenceTooShort));
             }
@@ -284,6 +288,7 @@ mod tests {
             max_length: 7,
             strategy: TruncationStrategy::LongestFirst,
             stride: 0,
+            direction: TruncationDirection::Right,
         };
 
         truncate_and_assert(get_empty(), get_empty(), &params, 0, 0);
@@ -313,6 +318,7 @@ mod tests {
             max_length: 0,
             strategy: TruncationStrategy::LongestFirst,
             stride: 0,
+            direction: TruncationDirection::Right,
         };
 
         truncate_and_assert(get_empty(), get_short(), &params, 0, 0);