Rust - Improve TruncationError

This commit is contained in:
Anthony MOI
2020-04-24 11:53:30 -04:00
parent 7d2b59b0aa
commit 02cc97756f
3 changed files with 30 additions and 8 deletions

View File

@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Changed
- Improved errors generated during truncation: When the provided max length is too low are
now handled properly.
## [0.7.0] ## [0.7.0]
### Changed ### Changed

View File

@@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
### Changed
- Improved `TruncationError` to handle cases where provided max length is too low.
### Fixed ### Fixed
- [#236]: Fix a bug with offsets being shifted when there are sub-sequences (Usually with - [#236]: Fix a bug with offsets being shifted when there are sub-sequences (Usually with
special tokens and/or added tokens in the sequence). special tokens and/or added tokens in the sequence).

View File

@@ -8,25 +8,34 @@ pub struct TruncationParams {
} }
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum TruncationError {
/// We are supposed to truncate the pair sequence, but it has not been provided.
SecondSequenceNotProvided, SecondSequenceNotProvided,
/// We cannot truncate the target sequence enough to respect the provided max length.
SequenceTooShort, SequenceTooShort,
/// We cannot truncate with the given constraints.
MaxLengthTooLow,
} }
impl std::fmt::Display for Error { impl std::fmt::Display for TruncationError {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
use TruncationError::*;
match self { match self {
Error::SecondSequenceNotProvided => { SecondSequenceNotProvided => {
write!(fmt, "Truncation error: Second sequence not provided") write!(fmt, "Truncation error: Second sequence not provided")
} }
Error::SequenceTooShort => write!( SequenceTooShort => write!(
fmt, fmt,
"Truncation error: Sequence to truncate too short to respect the provided max_length" "Truncation error: Sequence to truncate too short to respect the provided max_length"
), ),
MaxLengthTooLow => write!(
fmt,
"Truncation error: Specified max length is too low \
to respect the various constraints"),
} }
} }
} }
impl std::error::Error for Error {} impl std::error::Error for TruncationError {}
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone, Copy, PartialEq)]
pub enum TruncationStrategy { pub enum TruncationStrategy {
@@ -51,7 +60,7 @@ pub fn truncate_encodings(
params: &TruncationParams, params: &TruncationParams,
) -> Result<(Encoding, Option<Encoding>)> { ) -> Result<(Encoding, Option<Encoding>)> {
if params.max_length == 0 { if params.max_length == 0 {
return Ok((encoding, pair_encoding)); return Err(Box::new(TruncationError::MaxLengthTooLow));
} }
let total_length = encoding.get_ids().len() let total_length = encoding.get_ids().len()
@@ -77,6 +86,10 @@ pub fn truncate_encodings(
} }
} }
if n_first == 0 || (pair_encoding.is_some() && n_second == 0) {
return Err(Box::new(TruncationError::MaxLengthTooLow));
}
encoding.truncate(n_first, params.stride); encoding.truncate(n_first, params.stride);
if let Some(encoding) = pair_encoding.as_mut() { if let Some(encoding) = pair_encoding.as_mut() {
encoding.truncate(n_second, params.stride); encoding.truncate(n_second, params.stride);
@@ -88,14 +101,14 @@ pub fn truncate_encodings(
} else if let Some(encoding) = pair_encoding.as_mut() { } else if let Some(encoding) = pair_encoding.as_mut() {
Ok(encoding) Ok(encoding)
} else { } else {
Err(Box::new(Error::SecondSequenceNotProvided)) Err(Box::new(TruncationError::SecondSequenceNotProvided))
}?; }?;
let target_len = target.get_ids().len(); let target_len = target.get_ids().len();
if target_len > to_remove { if target_len > to_remove {
target.truncate(target_len - to_remove, params.stride); target.truncate(target_len - to_remove, params.stride);
} else { } else {
return Err(Box::new(Error::SequenceTooShort)); return Err(Box::new(TruncationError::SequenceTooShort));
} }
} }
} }