mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-05 20:28:22 +00:00
Rust - Improve TruncationError
This commit is contained in:
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Changed
|
||||
- Improved errors generated during truncation: cases where the provided max length is too low are
|
||||
now handled properly.
|
||||
|
||||
## [0.7.0]
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Changed
|
||||
- Improved `TruncationError` to handle cases where the provided max length is too low.
|
||||
|
||||
### Fixed
|
||||
- [#236]: Fix a bug with offsets being shifted when there are sub-sequences (Usually with
|
||||
special tokens and/or added tokens in the sequence).
|
||||
|
||||
@@ -8,25 +8,34 @@ pub struct TruncationParams {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
pub enum TruncationError {
|
||||
/// We are supposed to truncate the pair sequence, but it has not been provided.
|
||||
SecondSequenceNotProvided,
|
||||
/// We cannot truncate the target sequence enough to respect the provided max length.
|
||||
SequenceTooShort,
|
||||
/// We cannot truncate with the given constraints.
|
||||
MaxLengthTooLow,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
impl std::fmt::Display for TruncationError {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
use TruncationError::*;
|
||||
match self {
|
||||
Error::SecondSequenceNotProvided => {
|
||||
SecondSequenceNotProvided => {
|
||||
write!(fmt, "Truncation error: Second sequence not provided")
|
||||
}
|
||||
Error::SequenceTooShort => write!(
|
||||
SequenceTooShort => write!(
|
||||
fmt,
|
||||
"Truncation error: Sequence to truncate too short to respect the provided max_length"
|
||||
),
|
||||
MaxLengthTooLow => write!(
|
||||
fmt,
|
||||
"Truncation error: Specified max length is too low \
|
||||
to respect the various constraints"),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl std::error::Error for Error {}
|
||||
impl std::error::Error for TruncationError {}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum TruncationStrategy {
|
||||
@@ -51,7 +60,7 @@ pub fn truncate_encodings(
|
||||
params: &TruncationParams,
|
||||
) -> Result<(Encoding, Option<Encoding>)> {
|
||||
if params.max_length == 0 {
|
||||
return Ok((encoding, pair_encoding));
|
||||
return Err(Box::new(TruncationError::MaxLengthTooLow));
|
||||
}
|
||||
|
||||
let total_length = encoding.get_ids().len()
|
||||
@@ -77,6 +86,10 @@ pub fn truncate_encodings(
|
||||
}
|
||||
}
|
||||
|
||||
if n_first == 0 || (pair_encoding.is_some() && n_second == 0) {
|
||||
return Err(Box::new(TruncationError::MaxLengthTooLow));
|
||||
}
|
||||
|
||||
encoding.truncate(n_first, params.stride);
|
||||
if let Some(encoding) = pair_encoding.as_mut() {
|
||||
encoding.truncate(n_second, params.stride);
|
||||
@@ -88,14 +101,14 @@ pub fn truncate_encodings(
|
||||
} else if let Some(encoding) = pair_encoding.as_mut() {
|
||||
Ok(encoding)
|
||||
} else {
|
||||
Err(Box::new(Error::SecondSequenceNotProvided))
|
||||
Err(Box::new(TruncationError::SecondSequenceNotProvided))
|
||||
}?;
|
||||
|
||||
let target_len = target.get_ids().len();
|
||||
if target_len > to_remove {
|
||||
target.truncate(target_len - to_remove, params.stride);
|
||||
} else {
|
||||
return Err(Box::new(Error::SequenceTooShort));
|
||||
return Err(Box::new(TruncationError::SequenceTooShort));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user