mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Feature: Handle invalid truncate direction (#858)
* refactor: TruncateDirection -> TruncationDirection
* feat(node): invalid direction will throw
* feat(python): invalid direction will throw
* Update bindings/node/lib/bindings/raw-encoding.test.ts
* Update bindings/python/tests/bindings/test_encoding.py

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
This commit is contained in:
@ -104,6 +104,10 @@ describe("RawEncoding", () => {
|
||||
it("accepts `undefined` as second parameter", () => {
|
||||
expect(encoding.truncate(10, undefined)).toBeUndefined();
|
||||
});
|
||||
it("should throw an Error on invalid direction", () => {
|
||||
const t = () => encoding.truncate(10, 3, "not_valid");
|
||||
expect(t).toThrow(`Invalid truncation direction value : not_valid`);
|
||||
});
|
||||
});
|
||||
|
||||
describe("getWordIds", () => {
|
||||
|
@ -4,7 +4,7 @@ use crate::extraction::*;
|
||||
use crate::tokenizer::PaddingParams;
|
||||
use neon::prelude::*;
|
||||
|
||||
use tk::utils::truncation::TruncateDirection;
|
||||
use tk::utils::truncation::TruncationDirection;
|
||||
|
||||
/// Encoding
|
||||
pub struct Encoding {
|
||||
@ -349,10 +349,10 @@ declare_types! {
|
||||
let direction = cx.extract_opt::<String>(2)?.unwrap_or_else(|| String::from("right"));
|
||||
|
||||
let tdir = match direction.as_str() {
|
||||
"left" => TruncateDirection::Left,
|
||||
"right" => TruncateDirection::Right,
|
||||
_ => panic!("Invalid truncation direction value : {}", direction),
|
||||
};
|
||||
"left" => Ok(TruncationDirection::Left),
|
||||
"right" => Ok(TruncationDirection::Right),
|
||||
_ => cx.throw_error(format!("Invalid truncation direction value : {}", direction)),
|
||||
}?;
|
||||
|
||||
let mut this = cx.this();
|
||||
let guard = cx.lock();
|
||||
|
@ -3,7 +3,7 @@ use pyo3::prelude::*;
|
||||
use pyo3::types::*;
|
||||
use pyo3::{PyObjectProtocol, PySequenceProtocol};
|
||||
use tk::tokenizer::{Offsets, PaddingDirection};
|
||||
use tk::utils::truncation::TruncateDirection;
|
||||
use tk::utils::truncation::TruncationDirection;
|
||||
use tokenizers as tk;
|
||||
|
||||
use crate::error::{deprecation_warning, PyError};
|
||||
@ -446,13 +446,18 @@ impl PyEncoding {
|
||||
#[args(stride = "0")]
|
||||
#[args(direction = "\"right\"")]
|
||||
#[text_signature = "(self, max_length, stride=0, direction='right')"]
|
||||
fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) {
|
||||
fn truncate(&mut self, max_length: usize, stride: usize, direction: &str) -> PyResult<()> {
|
||||
let tdir = match direction {
|
||||
"left" => TruncateDirection::Left,
|
||||
"right" => TruncateDirection::Right,
|
||||
_ => panic!("Invalid truncation direction value : {}", direction),
|
||||
};
|
||||
"left" => Ok(TruncationDirection::Left),
|
||||
"right" => Ok(TruncationDirection::Right),
|
||||
_ => Err(PyError(format!(
|
||||
"Invalid truncation direction value : {}",
|
||||
direction
|
||||
))
|
||||
.into_pyerr::<exceptions::PyValueError>()),
|
||||
}?;
|
||||
|
||||
self.encoding.truncate(max_length, stride, tdir);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -106,3 +106,15 @@ class TestEncoding:
|
||||
assert pair.char_to_word(2, 0) == 1
|
||||
assert pair.char_to_word(2, 1) == None
|
||||
assert pair.char_to_word(3, 1) == 1
|
||||
|
||||
def test_truncation(self, encodings):
|
||||
single, _ = encodings
|
||||
single.truncate(2, 1, "right")
|
||||
assert single.tokens == ["[CLS]", "i"]
|
||||
assert single.overflowing[0].tokens == ["i", "love"]
|
||||
|
||||
def test_invalid_truncate_direction(self, encodings):
|
||||
single, _ = encodings
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
single.truncate(2, 1, "not_a_direction")
|
||||
assert "Invalid truncation direction value : not_a_direction" == str(excinfo.value)
|
||||
|
Reference in New Issue
Block a user