Rust - Improve utils and unzipping encode results

This commit is contained in:
Anthony MOI
2020-03-15 13:41:59 -04:00
parent 45f3449096
commit e0779c50b5
6 changed files with 136 additions and 76 deletions

View File

@ -1,21 +1,6 @@
use crate::utils::padding::PaddingDirection;
use rayon::prelude::*;
/// The various possible padding directions.
#[derive(Debug, Clone, Copy)]
pub enum PaddingDirection {
Left,
Right,
}
impl std::convert::AsRef<str> for PaddingDirection {
fn as_ref(&self) -> &str {
match self {
PaddingDirection::Left => "left",
PaddingDirection::Right => "right",
}
}
}
/// Represents the output of a `Tokenizer`.
#[derive(Default, PartialEq, Debug, Clone)]
pub struct Encoding {

View File

@ -9,10 +9,9 @@
//! - [`PostProcessor`](trait.PostProcessor.html): Takes care of the processing after tokenization (like truncating, padding,
//! ...).
pub use crate::utils::{
pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
TruncationStrategy,
};
use crate::utils::iter::ResultShunt;
pub use crate::utils::padding::{pad_encodings, PaddingParams, PaddingStrategy};
pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use std::{
@ -302,10 +301,8 @@ impl Tokenizer {
move |sentence: String, type_id: u32| -> Result<(Encoding, NormalizedString)> {
// First we need to split into as many sequences as needed to avoid splitting
// on our added tokens
let (mut encodings, mut normalized) = self
.split_on_added_tokens(&sentence)
.into_iter()
.map(|(sentence, id)| -> Result<(Encoding, NormalizedString)> {
let results = self.split_on_added_tokens(&sentence).into_iter().map(
|(sentence, id)| -> Result<(Encoding, NormalizedString)> {
// If this is one of our added tokens, lets return an encoding directly
if let Some(id) = id {
return Ok((
@ -358,11 +355,11 @@ impl Tokenizer {
),
normalized,
))
})
// TODO: Improve this and avoid collecting first...
.collect::<Result<Vec<(Encoding, NormalizedString)>>>()?
.into_iter()
.unzip::<_, _, Vec<_>, Vec<_>>();
},
);
let (mut encodings, mut normalized) =
ResultShunt::process(results, |iter| iter.unzip::<_, _, Vec<_>, Vec<_>>())?;
if encodings.is_empty() {
return Ok((Encoding::default(), NormalizedString::from("")));

View File

@ -0,0 +1,58 @@
//! This comes from the Rust libcore and is duplicated here because it is not exported
//! (cf https://github.com/rust-lang/rust/blob/25091ed9b7739e12466fb2490baa1e8a2815121c/src/libcore/iter/adapters/mod.rs#L2664)
//! We are now using the version from https://stackoverflow.com/questions/44544323/how-to-unzip-a-sequence-of-resulta-b-e-to-a-veca-vecb-and-stop-on-f
//! because the one from the libcore seems to cause overflowing stacks in some cases
/// Iterator adapter that yields the `Ok` values of an inner
/// `Iterator<Item = Result<T, E>>`, stopping at (and remembering) the first
/// `Err` so it can be reported once the overall computation is finished.
pub struct ResultShunt<I, E> {
    iter: I,
    error: Option<E>,
}

impl<I, T, E> ResultShunt<I, E>
where
    I: Iterator<Item = Result<T, E>>,
{
    /// Process the given iterator as if it yielded a `T` instead of a
    /// `Result<T, _>`. Any errors will stop the inner iterator and
    /// the overall result will be an error.
    pub fn process<F, U>(iter: I, mut f: F) -> Result<U, E>
    where
        F: FnMut(&mut Self) -> U,
    {
        let mut shunt = Self::new(iter);
        // Run the caller's computation over the "shunted" iterator...
        let value = f(&mut shunt);
        // ...then surface any error that interrupted it.
        shunt.reconstruct(value)
    }

    fn new(iter: I) -> Self {
        Self { iter, error: None }
    }

    /// Consume the adapter and rebuild a `Result` value. This should
    /// *always* be called, otherwise any potential error would be
    /// lost.
    fn reconstruct<U>(self, val: U) -> Result<U, E> {
        if let Some(err) = self.error {
            Err(err)
        } else {
            Ok(val)
        }
    }
}

impl<I, T, E> Iterator for ResultShunt<I, E>
where
    I: Iterator<Item = Result<T, E>>,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // `?` ends iteration as soon as the inner iterator is exhausted.
        match self.iter.next()? {
            Ok(item) => Some(item),
            Err(err) => {
                // Remember the failure and cut iteration short; the error
                // is returned later by `reconstruct`.
                self.error = Some(err);
                None
            }
        }
    }
}

View File

@ -0,0 +1,3 @@
// Submodules of `utils`:
pub mod iter; // `ResultShunt` iterator adapter (vendored, see module docs)
pub mod padding; // Padding direction/strategy/params and `pad_encodings`
pub mod truncation; // Truncation params and `truncate_encodings`

View File

@ -0,0 +1,63 @@
use crate::tokenizer::{Encoding, Result};
use rayon::prelude::*;
/// The various possible padding directions.
// Improvement: derive `PartialEq`/`Eq` as well — this is a fieldless `Copy`
// enum, so equality is free and lets callers compare/match directions
// directly instead of going through `as_ref()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PaddingDirection {
    /// Add padding tokens at the beginning of the sequence.
    Left,
    /// Add padding tokens at the end of the sequence.
    Right,
}

impl std::convert::AsRef<str> for PaddingDirection {
    /// Lowercase string form of the direction ("left" / "right").
    fn as_ref(&self) -> &str {
        match self {
            PaddingDirection::Left => "left",
            PaddingDirection::Right => "right",
        }
    }
}
/// Parameters controlling how a batch of `Encoding`s gets padded
/// (consumed by `pad_encodings`).
#[derive(Debug, Clone)]
pub struct PaddingParams {
/// How the target length is chosen (longest in batch or a fixed size).
pub strategy: PaddingStrategy,
/// Which side of each sequence receives the padding.
pub direction: PaddingDirection,
/// Id handed to `Encoding::pad` for padding positions.
pub pad_id: u32,
/// Type id handed to `Encoding::pad` for padding positions.
pub pad_type_id: u32,
/// String form of the padding token handed to `Encoding::pad`.
pub pad_token: String,
}
/// Strategy used to pick the length every encoding in a batch is padded to.
#[derive(Debug, Clone)]
pub enum PaddingStrategy {
/// Pad everything to the length of the longest encoding in the batch.
BatchLongest,
/// Pad everything to the given fixed length.
Fixed(usize),
}
/// Pad all given `encodings` in place so they share a single target length,
/// as dictated by `params`, and return them.
///
/// With `PaddingStrategy::Fixed(n)` the target length is `n`; with
/// `PaddingStrategy::BatchLongest` it is the length of the longest encoding
/// in the batch. An empty batch is returned untouched.
pub fn pad_encodings(
    mut encodings: Vec<Encoding>,
    params: &PaddingParams,
) -> Result<Vec<Encoding>> {
    // Nothing to do for an empty batch (this also protects `max()` below).
    if encodings.is_empty() {
        return Ok(encodings);
    }

    let target_length = match params.strategy {
        PaddingStrategy::Fixed(size) => size,
        PaddingStrategy::BatchLongest => {
            // Safe to unwrap: the batch was just checked to be non-empty.
            encodings
                .par_iter()
                .map(|encoding| encoding.get_ids().len())
                .max()
                .unwrap()
        }
    };

    // Pad every encoding in parallel up to the computed target length.
    encodings.par_iter_mut().for_each(|encoding| {
        encoding.pad(
            target_length,
            params.pad_id,
            params.pad_type_id,
            &params.pad_token,
            params.direction,
        )
    });

    Ok(encodings)
}

View File

@ -1,5 +1,4 @@
use crate::tokenizer::{Encoding, PaddingDirection, Result};
use rayon::prelude::*;
use crate::tokenizer::{Encoding, Result};
#[derive(Debug, Clone)]
pub struct TruncationParams {
@ -8,21 +7,6 @@ pub struct TruncationParams {
pub stride: usize,
}
#[derive(Debug, Clone)]
pub struct PaddingParams {
pub strategy: PaddingStrategy,
pub direction: PaddingDirection,
pub pad_id: u32,
pub pad_type_id: u32,
pub pad_token: String,
}
#[derive(Debug, Clone)]
pub enum PaddingStrategy {
BatchLongest,
Fixed(usize),
}
#[derive(Debug)]
pub enum Error {
SecondSequenceNotProvided,
@ -118,33 +102,3 @@ pub fn truncate_encodings(
Ok((encoding, pair_encoding))
}
pub fn pad_encodings(
mut encodings: Vec<Encoding>,
params: &PaddingParams,
) -> Result<Vec<Encoding>> {
if encodings.is_empty() {
return Ok(encodings);
}
let pad_length = match params.strategy {
PaddingStrategy::Fixed(size) => size,
PaddingStrategy::BatchLongest => encodings
.par_iter()
.map(|e| e.get_ids().len())
.max()
.unwrap(),
};
encodings.par_iter_mut().for_each(|encoding| {
encoding.pad(
pad_length,
params.pad_id,
params.pad_type_id,
&params.pad_token,
params.direction,
)
});
Ok(encodings)
}