Rust - Improve TruncationError

This commit is contained in:
Anthony MOI
2020-04-24 11:53:30 -04:00
parent 7d2b59b0aa
commit 02cc97756f
3 changed files with 30 additions and 8 deletions

View File

@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Changed
- Improved errors generated during truncation: When the provided max length is too low are
now handled properly.
## [0.7.0]
### Changed

View File

@@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Changed
- Improved `TruncationError` to handle cases where provided max length is too low.
### Fixed
- [#236]: Fix a bug with offsets being shifted when there are sub-sequences (Usually with
special tokens and/or added tokens in the sequence).

View File

@@ -8,25 +8,34 @@ pub struct TruncationParams {
}
#[derive(Debug)]
pub enum Error {
pub enum TruncationError {
/// We are supposed to truncate the pair sequence, but it has not been provided.
SecondSequenceNotProvided,
/// We cannot truncate the target sequence enough to respect the provided max length.
SequenceTooShort,
/// We cannot truncate with the given constraints.
MaxLengthTooLow,
}
impl std::fmt::Display for Error {
impl std::fmt::Display for TruncationError {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
use TruncationError::*;
match self {
Error::SecondSequenceNotProvided => {
SecondSequenceNotProvided => {
write!(fmt, "Truncation error: Second sequence not provided")
}
Error::SequenceTooShort => write!(
SequenceTooShort => write!(
fmt,
"Truncation error: Sequence to truncate too short to respect the provided max_length"
),
MaxLengthTooLow => write!(
fmt,
"Truncation error: Specified max length is too low \
to respect the various constraints"),
}
}
}
impl std::error::Error for Error {}
impl std::error::Error for TruncationError {}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TruncationStrategy {
@@ -51,7 +60,7 @@ pub fn truncate_encodings(
params: &TruncationParams,
) -> Result<(Encoding, Option<Encoding>)> {
if params.max_length == 0 {
return Ok((encoding, pair_encoding));
return Err(Box::new(TruncationError::MaxLengthTooLow));
}
let total_length = encoding.get_ids().len()
@@ -77,6 +86,10 @@ pub fn truncate_encodings(
}
}
if n_first == 0 || (pair_encoding.is_some() && n_second == 0) {
return Err(Box::new(TruncationError::MaxLengthTooLow));
}
encoding.truncate(n_first, params.stride);
if let Some(encoding) = pair_encoding.as_mut() {
encoding.truncate(n_second, params.stride);
@@ -88,14 +101,14 @@ pub fn truncate_encodings(
} else if let Some(encoding) = pair_encoding.as_mut() {
Ok(encoding)
} else {
Err(Box::new(Error::SecondSequenceNotProvided))
Err(Box::new(TruncationError::SecondSequenceNotProvided))
}?;
let target_len = target.get_ids().len();
if target_len > to_remove {
target.truncate(target_len - to_remove, params.stride);
} else {
return Err(Box::new(Error::SequenceTooShort));
return Err(Box::new(TruncationError::SequenceTooShort));
}
}
}