Feature: Handle invalid truncate direction (#858)

* refacto: TruncateDirection -> TruncationDirection

* feat(node): invalid direction will throw

* feat(python): invalid direction will throw

* Update bindings/node/lib/bindings/raw-encoding.test.ts

* Update bindings/python/tests/bindings/test_encoding.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
Luc Georges
2021-12-27 14:31:57 +01:00
committed by GitHub
parent 38a85b2112
commit c4c9de23a5
6 changed files with 47 additions and 26 deletions

View File

@ -104,6 +104,10 @@ describe("RawEncoding", () => {
it("accepts `undefined` as second parameter", () => {
expect(encoding.truncate(10, undefined)).toBeUndefined();
});
it("should throw an Error on invalid direction", () => {
const t = () => encoding.truncate(10, 3, "not_valid");
expect(t).toThrow(`Invalid truncation direction value : not_valid`);
});
});
describe("getWordIds", () => {

View File

@ -4,7 +4,7 @@ use crate::extraction::*;
use crate::tokenizer::PaddingParams;
use neon::prelude::*;
use tk::utils::truncation::TruncateDirection;
use tk::utils::truncation::TruncationDirection;
/// Encoding
pub struct Encoding {
@ -349,10 +349,10 @@ declare_types! {
let direction = cx.extract_opt::<String>(2)?.unwrap_or_else(|| String::from("right"));
let tdir = match direction.as_str() {
"left" => TruncateDirection::Left,
"right" => TruncateDirection::Right,
_ => panic!("Invalid truncation direction value : {}", direction),
};
"left" => Ok(TruncationDirection::Left),
"right" => Ok(TruncationDirection::Right),
_ => cx.throw_error(format!("Invalid truncation direction value : {}", direction)),
}?;
let mut this = cx.this();
let guard = cx.lock();

View File

@ -3,7 +3,7 @@ use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::{PyObjectProtocol, PySequenceProtocol};
use tk::tokenizer::{Offsets, PaddingDirection};
use tk::utils::truncation::TruncateDirection;
use tk::utils::truncation::TruncationDirection;
use tokenizers as tk;
use crate::error::{deprecation_warning, PyError};
@ -446,13 +446,18 @@ impl PyEncoding {
#[args(stride = "0")]
#[args(direction = "\"right\"")]
#[text_signature = "(self, max_length, stride=0, direction='right')"]
fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) {
fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) -> PyResult<()> {
let tdir = match direction {
"left" => TruncateDirection::Left,
"right" => TruncateDirection::Right,
_ => panic!("Invalid truncation direction value : {}", direction),
};
"left" => Ok(TruncationDirection::Left),
"right" => Ok(TruncationDirection::Right),
_ => Err(PyError(format!(
"Invalid truncation direction value : {}",
direction
))
.into_pyerr::<exceptions::PyValueError>()),
}?;
self.encoding.truncate(max_length, stride, tdir);
Ok(())
}
}

View File

@ -106,3 +106,15 @@ class TestEncoding:
assert pair.char_to_word(2, 0) == 1
assert pair.char_to_word(2, 1) == None
assert pair.char_to_word(3, 1) == 1
def test_truncation(self, encodings):
single, _ = encodings
single.truncate(2, 1, "right")
assert single.tokens == ["[CLS]", "i"]
assert single.overflowing[0].tokens == ["i", "love"]
def test_invalid_truncate_direction(self, encodings):
single, _ = encodings
with pytest.raises(ValueError) as excinfo:
single.truncate(2, 1, "not_a_direction")
assert "Invalid truncation direction value : not_a_direction" == str(excinfo.value)