mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Feature: Handle invalid truncate direction (#858)
* refactor: TruncateDirection -> TruncationDirection
* feat(node): invalid direction will throw
* feat(python): invalid direction will throw
* Update bindings/node/lib/bindings/raw-encoding.test.ts
* Update bindings/python/tests/bindings/test_encoding.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@ -104,6 +104,10 @@ describe("RawEncoding", () => {
|
||||
it("accepts `undefined` as second parameter", () => {
|
||||
expect(encoding.truncate(10, undefined)).toBeUndefined();
|
||||
});
|
||||
it("should throw an Error on invalid direction", () => {
|
||||
const t = () => encoding.truncate(10, 3, "not_valid");
|
||||
expect(t).toThrow(`Invalid truncation direction value : not_valid`);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getWordIds", () => {
|
||||
|
@ -4,7 +4,7 @@ use crate::extraction::*;
|
||||
use crate::tokenizer::PaddingParams;
|
||||
use neon::prelude::*;
|
||||
|
||||
use tk::utils::truncation::TruncateDirection;
|
||||
use tk::utils::truncation::TruncationDirection;
|
||||
|
||||
/// Encoding
|
||||
pub struct Encoding {
|
||||
@ -349,10 +349,10 @@ declare_types! {
|
||||
let direction = cx.extract_opt::<String>(2)?.unwrap_or_else(|| String::from("right"));
|
||||
|
||||
let tdir = match direction.as_str() {
|
||||
"left" => TruncateDirection::Left,
|
||||
"right" => TruncateDirection::Right,
|
||||
_ => panic!("Invalid truncation direction value : {}", direction),
|
||||
};
|
||||
"left" => Ok(TruncationDirection::Left),
|
||||
"right" => Ok(TruncationDirection::Right),
|
||||
_ => cx.throw_error(format!("Invalid truncation direction value : {}", direction)),
|
||||
}?;
|
||||
|
||||
let mut this = cx.this();
|
||||
let guard = cx.lock();
|
||||
|
@ -3,7 +3,7 @@ use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
use pyo3::{PyObjectProtocol, PySequenceProtocol};
|
||||
use tk::tokenizer::{Offsets, PaddingDirection};
|
||||
use tk::utils::truncation::TruncateDirection;
|
||||
use tk::utils::truncation::TruncationDirection;
|
||||
use tokenizers as tk;
|
||||
|
||||
use crate::error::{deprecation_warning, PyError};
|
||||
@ -446,13 +446,18 @@ impl PyEncoding {
|
||||
#[args(stride = "0")]
|
||||
#[args(direction = "\"right\"")]
|
||||
#[text_signature = "(self, max_length, stride=0, direction='right')"]
|
||||
fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) {
|
||||
fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) -> PyResult<()> {
|
||||
let tdir = match direction {
|
||||
"left" => TruncateDirection::Left,
|
||||
"right" => TruncateDirection::Right,
|
||||
_ => panic!("Invalid truncation direction value : {}", direction),
|
||||
};
|
||||
"left" => Ok(TruncationDirection::Left),
|
||||
"right" => Ok(TruncationDirection::Right),
|
||||
_ => Err(PyError(format!(
|
||||
"Invalid truncation direction value : {}",
|
||||
direction
|
||||
))
|
||||
.into_pyerr::<exceptions::PyValueError>()),
|
||||
}?;
|
||||
|
||||
self.encoding.truncate(max_length, stride, tdir);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -106,3 +106,15 @@ class TestEncoding:
|
||||
assert pair.char_to_word(2, 0) == 1
|
||||
assert pair.char_to_word(2, 1) == None
|
||||
assert pair.char_to_word(3, 1) == 1
|
||||
|
||||
def test_truncation(self, encodings):
|
||||
single, _ = encodings
|
||||
single.truncate(2, 1, "right")
|
||||
assert single.tokens == ["[CLS]", "i"]
|
||||
assert single.overflowing[0].tokens == ["i", "love"]
|
||||
|
||||
def test_invalid_truncate_direction(self, encodings):
|
||||
single, _ = encodings
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
single.truncate(2, 1, "not_a_direction")
|
||||
assert "Invalid truncation direction value : not_a_direction" == str(excinfo.value)
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::parallelism::*;
|
||||
use crate::tokenizer::{Offsets, Token};
|
||||
use crate::utils::padding::PaddingDirection;
|
||||
use crate::utils::truncation::TruncateDirection;
|
||||
use crate::utils::truncation::TruncationDirection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::ops::Range;
|
||||
@ -296,7 +296,7 @@ impl Encoding {
|
||||
/// Truncate the current `Encoding`.
|
||||
///
|
||||
/// Panics if `stride >= max_len`
|
||||
pub fn truncate(&mut self, max_len: usize, stride: usize, direction: TruncateDirection) {
|
||||
pub fn truncate(&mut self, max_len: usize, stride: usize, direction: TruncationDirection) {
|
||||
let encoding_len = self.ids.len();
|
||||
if max_len >= encoding_len {
|
||||
return;
|
||||
@ -316,7 +316,7 @@ impl Encoding {
|
||||
let offset = max_len - stride;
|
||||
let mut end = false;
|
||||
let parts_ranges: Vec<(usize, usize)> = match direction {
|
||||
TruncateDirection::Right => (0..encoding_len)
|
||||
TruncationDirection::Right => (0..encoding_len)
|
||||
.step_by(offset)
|
||||
.filter_map(|start| {
|
||||
if !end {
|
||||
@ -328,7 +328,7 @@ impl Encoding {
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
TruncateDirection::Left => (0..encoding_len)
|
||||
TruncationDirection::Left => (0..encoding_len)
|
||||
.rev()
|
||||
.step_by(offset)
|
||||
.filter_map(|stop| {
|
||||
@ -602,7 +602,7 @@ mod tests {
|
||||
attention_mask: vec![1, 1, 1],
|
||||
..Default::default()
|
||||
};
|
||||
a.truncate(2, 0, TruncateDirection::Right);
|
||||
a.truncate(2, 0, TruncationDirection::Right);
|
||||
|
||||
assert_eq!(
|
||||
a,
|
||||
@ -645,7 +645,7 @@ mod tests {
|
||||
attention_mask: vec![1, 1, 1],
|
||||
..Default::default()
|
||||
};
|
||||
a.truncate(0, 0, TruncateDirection::Right);
|
||||
a.truncate(0, 0, TruncationDirection::Right);
|
||||
|
||||
assert_eq!(
|
||||
a,
|
||||
@ -689,7 +689,7 @@ mod tests {
|
||||
overflowing: vec![],
|
||||
..Default::default()
|
||||
};
|
||||
enc.truncate(4, 2, TruncateDirection::Right);
|
||||
enc.truncate(4, 2, TruncationDirection::Right);
|
||||
|
||||
assert_eq!(
|
||||
enc,
|
||||
@ -742,7 +742,7 @@ mod tests {
|
||||
attention_mask: vec![1, 1, 1],
|
||||
..Default::default()
|
||||
};
|
||||
a.truncate(2, 0, TruncateDirection::Left);
|
||||
a.truncate(2, 0, TruncationDirection::Left);
|
||||
|
||||
assert_eq!(
|
||||
a,
|
||||
|
@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
|
||||
use std::cmp;
|
||||
use std::mem;
|
||||
|
||||
pub enum TruncateDirection {
|
||||
pub enum TruncationDirection {
|
||||
Left,
|
||||
Right,
|
||||
}
|
||||
@ -72,9 +72,9 @@ pub fn truncate_encodings(
|
||||
params: &TruncationParams,
|
||||
) -> Result<(Encoding, Option<Encoding>)> {
|
||||
if params.max_length == 0 {
|
||||
encoding.truncate(0, params.stride, TruncateDirection::Right);
|
||||
encoding.truncate(0, params.stride, TruncationDirection::Right);
|
||||
if let Some(other_encoding) = pair_encoding.as_mut() {
|
||||
other_encoding.truncate(0, params.stride, TruncateDirection::Right);
|
||||
other_encoding.truncate(0, params.stride, TruncationDirection::Right);
|
||||
}
|
||||
return Ok((encoding, pair_encoding));
|
||||
}
|
||||
@ -134,13 +134,13 @@ pub fn truncate_encodings(
|
||||
if swap {
|
||||
mem::swap(&mut n1, &mut n2);
|
||||
}
|
||||
encoding.truncate(n1, params.stride, TruncateDirection::Right);
|
||||
other_encoding.truncate(n2, params.stride, TruncateDirection::Right);
|
||||
encoding.truncate(n1, params.stride, TruncationDirection::Right);
|
||||
other_encoding.truncate(n2, params.stride, TruncationDirection::Right);
|
||||
} else {
|
||||
encoding.truncate(
|
||||
total_length - to_remove,
|
||||
params.stride,
|
||||
TruncateDirection::Right,
|
||||
TruncationDirection::Right,
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -158,7 +158,7 @@ pub fn truncate_encodings(
|
||||
target.truncate(
|
||||
target_len - to_remove,
|
||||
params.stride,
|
||||
TruncateDirection::Right,
|
||||
TruncationDirection::Right,
|
||||
);
|
||||
} else {
|
||||
return Err(Box::new(TruncationError::SequenceTooShort));
|
||||
|
Reference in New Issue
Block a user