diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs
index 678702b7..4ca661cc 100644
--- a/tokenizers/src/tokenizer/encoding.rs
+++ b/tokenizers/src/tokenizer/encoding.rs
@@ -1,21 +1,6 @@
+use crate::utils::padding::PaddingDirection;
 use rayon::prelude::*;
 
-/// The various possible padding directions.
-#[derive(Debug, Clone, Copy)]
-pub enum PaddingDirection {
-    Left,
-    Right,
-}
-
-impl std::convert::AsRef<str> for PaddingDirection {
-    fn as_ref(&self) -> &str {
-        match self {
-            PaddingDirection::Left => "left",
-            PaddingDirection::Right => "right",
-        }
-    }
-}
-
 /// Represents the output of a `Tokenizer`.
 #[derive(Default, PartialEq, Debug, Clone)]
 pub struct Encoding {
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 107db142..bd406401 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -9,10 +9,9 @@
 //! - [`PostProcessor`](trait.PostProcessor.html): Takes care of the processing after tokenization (like truncating, padding,
 //!   ...).
 
-pub use crate::utils::{
-    pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
-    TruncationStrategy,
-};
+use crate::utils::iter::ResultShunt;
+pub use crate::utils::padding::{pad_encodings, PaddingParams, PaddingStrategy};
+pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
 use indicatif::{ProgressBar, ProgressStyle};
 use rayon::prelude::*;
 use std::{
@@ -302,10 +301,8 @@ impl Tokenizer {
         move |sentence: String, type_id: u32| -> Result<(Encoding, NormalizedString)> {
             // First we need to split into as many sequences as needed to avoid splitting
             // on our added tokens
-            let (mut encodings, mut normalized) = self
-                .split_on_added_tokens(&sentence)
-                .into_iter()
-                .map(|(sentence, id)| -> Result<(Encoding, NormalizedString)> {
+            let results = self.split_on_added_tokens(&sentence).into_iter().map(
+                |(sentence, id)| -> Result<(Encoding, NormalizedString)> {
                     // If this is one of our added tokens, lets return an encoding directly
                     if let Some(id) = id {
                         return Ok((
@@ -358,11 +355,11 @@ impl Tokenizer {
                         ),
                         normalized,
                     ))
-                })
-                // TODO: Improve this and avoid collecting first...
-                .collect::<Result<Vec<_>>>()?
-                .into_iter()
-                .unzip::<_, _, Vec<_>, Vec<_>>();
+                },
+            );
+
+            let (mut encodings, mut normalized) =
+                ResultShunt::process(results, |iter| iter.unzip::<_, _, Vec<_>, Vec<_>>())?;
 
             if encodings.is_empty() {
                 return Ok((Encoding::default(), NormalizedString::from("")));
diff --git a/tokenizers/src/utils/iter.rs b/tokenizers/src/utils/iter.rs
new file mode 100644
index 00000000..45878af2
--- /dev/null
+++ b/tokenizers/src/utils/iter.rs
@@ -0,0 +1,58 @@
+//! This comes from the Rust libcore and is duplicated here because it is not exported
+//! (cf https://github.com/rust-lang/rust/blob/25091ed9b7739e12466fb2490baa1e8a2815121c/src/libcore/iter/adapters/mod.rs#L2664)
+//! We are now using the version from https://stackoverflow.com/questions/44544323/how-to-unzip-a-sequence-of-resulta-b-e-to-a-veca-vecb-and-stop-on-f
+//! because the one from the libcore seems to cause overflowing stacks in some cases
+
+pub struct ResultShunt<I, E> {
+    iter: I,
+    error: Option<E>,
+}
+
+impl<I, T, E> ResultShunt<I, E>
+where
+    I: Iterator<Item = Result<T, E>>,
+{
+    /// Process the given iterator as if it yielded a `T` instead of a
+    /// `Result<T, _>`. Any errors will stop the inner iterator and
+    /// the overall result will be an error.
+    pub fn process<F, U>(iter: I, mut f: F) -> Result<U, E>
+    where
+        F: FnMut(&mut Self) -> U,
+    {
+        let mut shunt = ResultShunt::new(iter);
+        let value = f(shunt.by_ref());
+        shunt.reconstruct(value)
+    }
+
+    fn new(iter: I) -> Self {
+        ResultShunt { iter, error: None }
+    }
+
+    /// Consume the adapter and rebuild a `Result` value. This should
+    /// *always* be called, otherwise any potential error would be
+    /// lost.
+    fn reconstruct<U>(self, val: U) -> Result<U, E> {
+        match self.error {
+            None => Ok(val),
+            Some(e) => Err(e),
+        }
+    }
+}
+
+impl<I, T, E> Iterator for ResultShunt<I, E>
+where
+    I: Iterator<Item = Result<T, E>>,
+{
+    type Item = T;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self.iter.next() {
+            Some(Ok(v)) => Some(v),
+            Some(Err(e)) => {
+                self.error = Some(e);
+                None
+            }
+            None => None,
+        }
+    }
+}
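For context on the new adapter: `ResultShunt` wraps a fallible iterator, yields the `Ok` values, and parks the first `Err` so it can be surfaced after the consumer finishes. Below is a minimal sketch of how `process` behaves, mirroring its use in `encode` above. The integer/string pairs are placeholders for the real `(Encoding, NormalizedString)` items, and the import path assumes `utils` is reachable from the crate root (the `lib.rs` side isn't shown in this diff):

```rust
use tokenizers::utils::iter::ResultShunt; // assumed path; `utils` visibility is not part of this diff

fn main() {
    // All items are Ok: unzip runs in a single pass, with no intermediate
    // collect into a Vec<Result<_>>.
    let ok_items = vec![Ok((1, "a")), Ok((2, "b"))].into_iter();
    let unzipped: Result<(Vec<i32>, Vec<&str>), String> =
        ResultShunt::process(ok_items, |iter| iter.unzip::<_, _, Vec<_>, Vec<_>>());
    assert_eq!(unzipped, Ok((vec![1, 2], vec!["a", "b"])));

    // The first Err stops the inner iterator and becomes the overall result.
    let failing = vec![Ok((1, "a")), Err("boom".to_string()), Ok((2, "b"))].into_iter();
    let unzipped: Result<(Vec<i32>, Vec<&str>), String> =
        ResultShunt::process(failing, |iter| iter.unzip::<_, _, Vec<_>, Vec<_>>());
    assert_eq!(unzipped, Err("boom".to_string()));
}
```

Compared with the previous `collect::<Result<Vec<_>>>()?` followed by `into_iter().unzip()`, this addresses the old TODO: one pass, no intermediate allocation, and the error still short-circuits.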
diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs
new file mode 100644
index 00000000..f961fd7d
--- /dev/null
+++ b/tokenizers/src/utils/mod.rs
@@ -0,0 +1,3 @@
+pub mod iter;
+pub mod padding;
+pub mod truncation;
diff --git a/tokenizers/src/utils/padding.rs b/tokenizers/src/utils/padding.rs
new file mode 100644
index 00000000..a37b4ddf
--- /dev/null
+++ b/tokenizers/src/utils/padding.rs
@@ -0,0 +1,63 @@
+use crate::tokenizer::{Encoding, Result};
+use rayon::prelude::*;
+
+/// The various possible padding directions.
+#[derive(Debug, Clone, Copy)]
+pub enum PaddingDirection {
+    Left,
+    Right,
+}
+
+impl std::convert::AsRef<str> for PaddingDirection {
+    fn as_ref(&self) -> &str {
+        match self {
+            PaddingDirection::Left => "left",
+            PaddingDirection::Right => "right",
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct PaddingParams {
+    pub strategy: PaddingStrategy,
+    pub direction: PaddingDirection,
+    pub pad_id: u32,
+    pub pad_type_id: u32,
+    pub pad_token: String,
+}
+
+#[derive(Debug, Clone)]
+pub enum PaddingStrategy {
+    BatchLongest,
+    Fixed(usize),
+}
+
+pub fn pad_encodings(
+    mut encodings: Vec<Encoding>,
+    params: &PaddingParams,
+) -> Result<Vec<Encoding>> {
+    if encodings.is_empty() {
+        return Ok(encodings);
+    }
+
+    let pad_length = match params.strategy {
+        PaddingStrategy::Fixed(size) => size,
+        PaddingStrategy::BatchLongest => encodings
+            .par_iter()
+            .map(|e| e.get_ids().len())
+            .max()
+            .unwrap(),
+    };
+
+    encodings.par_iter_mut().for_each(|encoding| {
+        encoding.pad(
+            pad_length,
+            params.pad_id,
+            params.pad_type_id,
+            &params.pad_token,
+            params.direction,
+        )
+    });
+
+    Ok(encodings)
+}
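As a usage sketch (not part of the diff): building `PaddingParams` and padding a batch. The import paths assume these items stay reachable under `utils::padding`, the `Result` alias path is assumed, and `Encoding::default()` stands in for real `encode` output:

```rust
use tokenizers::tokenizer::Encoding;
use tokenizers::utils::padding::{pad_encodings, PaddingDirection, PaddingParams, PaddingStrategy};

fn main() -> tokenizers::tokenizer::Result<()> {
    // Empty default encodings stand in for real tokenizer output here.
    let encodings = vec![Encoding::default(), Encoding::default()];

    let params = PaddingParams {
        // BatchLongest pads every encoding to the longest one in the batch;
        // Fixed(n) pads everything to the constant length n instead.
        strategy: PaddingStrategy::BatchLongest,
        direction: PaddingDirection::Right,
        pad_id: 0,
        pad_type_id: 0,
        pad_token: String::from("[PAD]"),
    };

    let padded = pad_encodings(encodings, &params)?;
    assert_eq!(padded.len(), 2);
    Ok(())
}
```

Note that both the longest-length scan (`par_iter`) and the padding itself (`par_iter_mut`) run in parallel across the batch via rayon, which is why the module keeps the `rayon::prelude` import that `truncation.rs` drops below.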
diff --git a/tokenizers/src/utils.rs b/tokenizers/src/utils/truncation.rs
similarity index 74%
rename from tokenizers/src/utils.rs
rename to tokenizers/src/utils/truncation.rs
index efc3ca5e..981276c0 100644
--- a/tokenizers/src/utils.rs
+++ b/tokenizers/src/utils/truncation.rs
@@ -1,5 +1,4 @@
-use crate::tokenizer::{Encoding, PaddingDirection, Result};
-use rayon::prelude::*;
+use crate::tokenizer::{Encoding, Result};
 
 #[derive(Debug, Clone)]
 pub struct TruncationParams {
@@ -8,21 +7,6 @@ pub struct TruncationParams {
     pub stride: usize,
 }
 
-#[derive(Debug, Clone)]
-pub struct PaddingParams {
-    pub strategy: PaddingStrategy,
-    pub direction: PaddingDirection,
-    pub pad_id: u32,
-    pub pad_type_id: u32,
-    pub pad_token: String,
-}
-
-#[derive(Debug, Clone)]
-pub enum PaddingStrategy {
-    BatchLongest,
-    Fixed(usize),
-}
-
 #[derive(Debug)]
 pub enum Error {
     SecondSequenceNotProvided,
@@ -118,33 +102,3 @@ pub fn truncate_encodings(
 
     Ok((encoding, pair_encoding))
 }
-
-pub fn pad_encodings(
-    mut encodings: Vec<Encoding>,
-    params: &PaddingParams,
-) -> Result<Vec<Encoding>> {
-    if encodings.is_empty() {
-        return Ok(encodings);
-    }
-
-    let pad_length = match params.strategy {
-        PaddingStrategy::Fixed(size) => size,
-        PaddingStrategy::BatchLongest => encodings
-            .par_iter()
-            .map(|e| e.get_ids().len())
-            .max()
-            .unwrap(),
-    };
-
-    encodings.par_iter_mut().for_each(|encoding| {
-        encoding.pad(
-            pad_length,
-            params.pad_id,
-            params.pad_type_id,
-            &params.pad_token,
-            params.direction,
-        )
-    });
-
-    Ok(encodings)
-}
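For the relocated truncation half, a hedged sketch of calling `truncate_encodings`. Its full signature is outside this diff, so the `(encoding, pair, &params)` shape and the `LongestFirst` variant are assumptions based on the surrounding crate (the diff only shows the `TruncationParams` fields and the `Ok((encoding, pair_encoding))` return):

```rust
use tokenizers::tokenizer::Encoding;
use tokenizers::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};

fn main() -> tokenizers::tokenizer::Result<()> {
    // Placeholder encodings; real ones would come from `Tokenizer::encode`.
    let encoding = Encoding::default();
    let pair = Some(Encoding::default());

    let params = TruncationParams {
        max_length: 512,
        // Assumed variant: trim whichever sequence is longer until the
        // combined length fits within max_length.
        strategy: TruncationStrategy::LongestFirst,
        stride: 0,
    };

    let (encoding, _pair) = truncate_encodings(encoding, pair, &params)?;
    assert!(encoding.get_ids().len() <= 512);
    Ok(())
}
```

After this rename, `utils.rs` keeps only the truncation logic (hence the dropped `PaddingDirection` and rayon imports), while padding lives in `utils/padding.rs` and the iterator adapter in `utils/iter.rs`, tying the three new modules of `utils/mod.rs` together.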