From 02cc97756ffb9193b5d6d8dfcdeb7bf08adf2516 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Fri, 24 Apr 2020 11:53:30 -0400 Subject: [PATCH] Rust - Improve TruncationError --- bindings/python/CHANGELOG.md | 6 ++++++ tokenizers/CHANGELOG.md | 3 +++ tokenizers/src/utils/truncation.rs | 29 +++++++++++++++++++++-------- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/bindings/python/CHANGELOG.md b/bindings/python/CHANGELOG.md index f5b49a08..2e3027b2 100644 --- a/bindings/python/CHANGELOG.md +++ b/bindings/python/CHANGELOG.md @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Changed +- Improved errors generated during truncation: When the provided max length is too low are +now handled properly. + ## [0.7.0] ### Changed diff --git a/tokenizers/CHANGELOG.md b/tokenizers/CHANGELOG.md index fd26a7e2..06a732d7 100644 --- a/tokenizers/CHANGELOG.md +++ b/tokenizers/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed +- Improved `TruncationError` to handle cases where provided max length is too low. + ### Fixed - [#236]: Fix a bug with offsets being shifted when there are sub-sequences (Usually with special tokens and/or added tokens in the sequence). diff --git a/tokenizers/src/utils/truncation.rs b/tokenizers/src/utils/truncation.rs index 981276c0..3a21af56 100644 --- a/tokenizers/src/utils/truncation.rs +++ b/tokenizers/src/utils/truncation.rs @@ -8,25 +8,34 @@ pub struct TruncationParams { } #[derive(Debug)] -pub enum Error { +pub enum TruncationError { + /// We are supposed to truncate the pair sequence, but it has not been provided. SecondSequenceNotProvided, + /// We cannot truncate the target sequence enough to respect the provided max length. SequenceTooShort, + /// We cannot truncate with the given constraints. + MaxLengthTooLow, } -impl std::fmt::Display for Error { +impl std::fmt::Display for TruncationError { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + use TruncationError::*; match self { - Error::SecondSequenceNotProvided => { + SecondSequenceNotProvided => { write!(fmt, "Truncation error: Second sequence not provided") } - Error::SequenceTooShort => write!( + SequenceTooShort => write!( fmt, "Truncation error: Sequence to truncate too short to respect the provided max_length" ), + MaxLengthTooLow => write!( + fmt, + "Truncation error: Specified max length is too low \ + to respect the various constraints"), } } } -impl std::error::Error for Error {} +impl std::error::Error for TruncationError {} #[derive(Debug, Clone, Copy, PartialEq)] pub enum TruncationStrategy { @@ -51,7 +60,7 @@ pub fn truncate_encodings( params: &TruncationParams, ) -> Result<(Encoding, Option)> { if params.max_length == 0 { - return Ok((encoding, pair_encoding)); + return Err(Box::new(TruncationError::MaxLengthTooLow)); } let total_length = encoding.get_ids().len() @@ -77,6 +86,10 @@ pub fn truncate_encodings( } } + if n_first == 0 || (pair_encoding.is_some() && n_second == 0) { + return Err(Box::new(TruncationError::MaxLengthTooLow)); + } + encoding.truncate(n_first, params.stride); if let Some(encoding) = pair_encoding.as_mut() { encoding.truncate(n_second, params.stride); @@ -88,14 +101,14 @@ pub fn truncate_encodings( } else if let Some(encoding) = pair_encoding.as_mut() { Ok(encoding) } else { - Err(Box::new(Error::SecondSequenceNotProvided)) + Err(Box::new(TruncationError::SecondSequenceNotProvided)) }?; let target_len = target.get_ids().len(); if target_len > to_remove { target.truncate(target_len - to_remove, params.stride); } else { - return Err(Box::new(Error::SequenceTooShort)); + return Err(Box::new(TruncationError::SequenceTooShort)); } } }