Feature: Handle invalid truncate direction (#858)

* refactor: TruncateDirection -> TruncationDirection

* feat(node): invalid direction will throw

* feat(python): invalid direction will throw

* Update bindings/node/lib/bindings/raw-encoding.test.ts

* Update bindings/python/tests/bindings/test_encoding.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Luc Georges
2021-12-27 14:31:57 +01:00
committed by GitHub
parent 38a85b2112
commit c4c9de23a5
6 changed files with 47 additions and 26 deletions

View File

@ -104,6 +104,10 @@ describe("RawEncoding", () => {
it("accepts `undefined` as second parameter", () => {
expect(encoding.truncate(10, undefined)).toBeUndefined();
});
it("should throw an Error on invalid direction", () => {
const t = () => encoding.truncate(10, 3, "not_valid");
expect(t).toThrow(`Invalid truncation direction value : not_valid`);
});
});
describe("getWordIds", () => {

View File

@ -4,7 +4,7 @@ use crate::extraction::*;
use crate::tokenizer::PaddingParams;
use neon::prelude::*;
use tk::utils::truncation::TruncateDirection;
use tk::utils::truncation::TruncationDirection;
/// Encoding
pub struct Encoding {
@ -349,10 +349,10 @@ declare_types! {
let direction = cx.extract_opt::<String>(2)?.unwrap_or_else(|| String::from("right"));
let tdir = match direction.as_str() {
"left" => TruncateDirection::Left,
"right" => TruncateDirection::Right,
_ => panic!("Invalid truncation direction value : {}", direction),
};
"left" => Ok(TruncationDirection::Left),
"right" => Ok(TruncationDirection::Right),
_ => cx.throw_error(format!("Invalid truncation direction value : {}", direction)),
}?;
let mut this = cx.this();
let guard = cx.lock();

View File

@ -3,7 +3,7 @@ use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::{PyObjectProtocol, PySequenceProtocol};
use tk::tokenizer::{Offsets, PaddingDirection};
use tk::utils::truncation::TruncateDirection;
use tk::utils::truncation::TruncationDirection;
use tokenizers as tk;
use crate::error::{deprecation_warning, PyError};
@ -446,13 +446,18 @@ impl PyEncoding {
#[args(stride = "0")]
#[args(direction = "\"right\"")]
#[text_signature = "(self, max_length, stride=0, direction='right')"]
fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) {
fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) -> PyResult<()> {
let tdir = match direction {
"left" => TruncateDirection::Left,
"right" => TruncateDirection::Right,
_ => panic!("Invalid truncation direction value : {}", direction),
};
"left" => Ok(TruncationDirection::Left),
"right" => Ok(TruncationDirection::Right),
_ => Err(PyError(format!(
"Invalid truncation direction value : {}",
direction
))
.into_pyerr::<exceptions::PyValueError>()),
}?;
self.encoding.truncate(max_length, stride, tdir);
Ok(())
}
}

View File

@ -106,3 +106,15 @@ class TestEncoding:
assert pair.char_to_word(2, 0) == 1
assert pair.char_to_word(2, 1) == None
assert pair.char_to_word(3, 1) == 1
def test_truncation(self, encodings):
    """Truncate the single-sequence encoding to two tokens from the right,
    with a stride of one, then check the kept tokens and the tokens of the
    first overflowing encoding."""
    expected_kept = ["[CLS]", "i"]
    expected_overflow = ["i", "love"]

    encoding, _ = encodings
    encoding.truncate(2, 1, "right")

    assert encoding.tokens == expected_kept
    assert encoding.overflowing[0].tokens == expected_overflow
def test_invalid_truncate_direction(self, encodings):
    """An unrecognized direction string must raise ValueError carrying the
    exact message compared below."""
    expected_message = "Invalid truncation direction value : not_a_direction"

    encoding, _ = encodings
    with pytest.raises(ValueError) as raised:
        encoding.truncate(2, 1, "not_a_direction")

    # Compared after the `with` block: nothing placed after the raising call
    # inside the block would ever execute.
    assert str(raised.value) == expected_message

View File

@ -1,7 +1,7 @@
use crate::parallelism::*;
use crate::tokenizer::{Offsets, Token};
use crate::utils::padding::PaddingDirection;
use crate::utils::truncation::TruncateDirection;
use crate::utils::truncation::TruncationDirection;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::Range;
@ -296,7 +296,7 @@ impl Encoding {
/// Truncate the current `Encoding`.
///
/// Panics if `stride >= max_len`
pub fn truncate(&mut self, max_len: usize, stride: usize, direction: TruncateDirection) {
pub fn truncate(&mut self, max_len: usize, stride: usize, direction: TruncationDirection) {
let encoding_len = self.ids.len();
if max_len >= encoding_len {
return;
@ -316,7 +316,7 @@ impl Encoding {
let offset = max_len - stride;
let mut end = false;
let parts_ranges: Vec<(usize, usize)> = match direction {
TruncateDirection::Right => (0..encoding_len)
TruncationDirection::Right => (0..encoding_len)
.step_by(offset)
.filter_map(|start| {
if !end {
@ -328,7 +328,7 @@ impl Encoding {
}
})
.collect(),
TruncateDirection::Left => (0..encoding_len)
TruncationDirection::Left => (0..encoding_len)
.rev()
.step_by(offset)
.filter_map(|stop| {
@ -602,7 +602,7 @@ mod tests {
attention_mask: vec![1, 1, 1],
..Default::default()
};
a.truncate(2, 0, TruncateDirection::Right);
a.truncate(2, 0, TruncationDirection::Right);
assert_eq!(
a,
@ -645,7 +645,7 @@ mod tests {
attention_mask: vec![1, 1, 1],
..Default::default()
};
a.truncate(0, 0, TruncateDirection::Right);
a.truncate(0, 0, TruncationDirection::Right);
assert_eq!(
a,
@ -689,7 +689,7 @@ mod tests {
overflowing: vec![],
..Default::default()
};
enc.truncate(4, 2, TruncateDirection::Right);
enc.truncate(4, 2, TruncationDirection::Right);
assert_eq!(
enc,
@ -742,7 +742,7 @@ mod tests {
attention_mask: vec![1, 1, 1],
..Default::default()
};
a.truncate(2, 0, TruncateDirection::Left);
a.truncate(2, 0, TruncationDirection::Left);
assert_eq!(
a,

View File

@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use std::cmp;
use std::mem;
pub enum TruncateDirection {
pub enum TruncationDirection {
Left,
Right,
}
@ -72,9 +72,9 @@ pub fn truncate_encodings(
params: &TruncationParams,
) -> Result<(Encoding, Option<Encoding>)> {
if params.max_length == 0 {
encoding.truncate(0, params.stride, TruncateDirection::Right);
encoding.truncate(0, params.stride, TruncationDirection::Right);
if let Some(other_encoding) = pair_encoding.as_mut() {
other_encoding.truncate(0, params.stride, TruncateDirection::Right);
other_encoding.truncate(0, params.stride, TruncationDirection::Right);
}
return Ok((encoding, pair_encoding));
}
@ -134,13 +134,13 @@ pub fn truncate_encodings(
if swap {
mem::swap(&mut n1, &mut n2);
}
encoding.truncate(n1, params.stride, TruncateDirection::Right);
other_encoding.truncate(n2, params.stride, TruncateDirection::Right);
encoding.truncate(n1, params.stride, TruncationDirection::Right);
other_encoding.truncate(n2, params.stride, TruncationDirection::Right);
} else {
encoding.truncate(
total_length - to_remove,
params.stride,
TruncateDirection::Right,
TruncationDirection::Right,
);
}
}
@ -158,7 +158,7 @@ pub fn truncate_encodings(
target.truncate(
target_len - to_remove,
params.stride,
TruncateDirection::Right,
TruncationDirection::Right,
);
} else {
return Err(Box::new(TruncationError::SequenceTooShort));