Mirror of https://github.com/mii443/tokenizers.git
Rust - Improve utils and unzipping encode results
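Summary of the diff below: the old monolithic utils module is split into utils::iter, utils::padding, and utils::truncation; PaddingDirection, PaddingParams, PaddingStrategy, and pad_encodings move into utils::padding; and the encoding path in impl Tokenizer stops collecting per-sequence results into an intermediate Vec before unzipping, streaming them through the new ResultShunt adapter instead.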
@@ -1,21 +1,6 @@
+use crate::utils::padding::PaddingDirection;
 use rayon::prelude::*;

-/// The various possible padding directions.
-#[derive(Debug, Clone, Copy)]
-pub enum PaddingDirection {
-    Left,
-    Right,
-}
-
-impl std::convert::AsRef<str> for PaddingDirection {
-    fn as_ref(&self) -> &str {
-        match self {
-            PaddingDirection::Left => "left",
-            PaddingDirection::Right => "right",
-        }
-    }
-}
-
 /// Represents the output of a `Tokenizer`.
 #[derive(Default, PartialEq, Debug, Clone)]
 pub struct Encoding {
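Moving the enum does not change its behavior; as a minimal check of its AsRef<str> impl (not part of the commit, and assuming PaddingDirection is in scope):

fn main() {
    // The padding direction serializes to a lowercase string.
    assert_eq!(PaddingDirection::Left.as_ref(), "left");
    assert_eq!(PaddingDirection::Right.as_ref(), "right");
}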
@@ -9,10 +9,9 @@
 //! - [`PostProcessor`](trait.PostProcessor.html): Takes care of the processing after tokenization (like truncating, padding,
 //! ...).

-pub use crate::utils::{
-    pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
-    TruncationStrategy,
-};
+use crate::utils::iter::ResultShunt;
+pub use crate::utils::padding::{pad_encodings, PaddingParams, PaddingStrategy};
+pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
 use indicatif::{ProgressBar, ProgressStyle};
 use rayon::prelude::*;
 use std::{
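These pub use re-exports keep the public API stable: pad_encodings, PaddingParams, PaddingStrategy, truncate_encodings, TruncationParams, and TruncationStrategy remain importable from this module even though their definitions now live in the new utils::padding and utils::truncation submodules. ResultShunt, by contrast, is brought in with a plain use for internal consumption only.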
@@ -302,10 +301,8 @@ impl Tokenizer {
         move |sentence: String, type_id: u32| -> Result<(Encoding, NormalizedString)> {
             // First we need to split into as many sequences as needed to avoid splitting
             // on our added tokens
-            let (mut encodings, mut normalized) = self
-                .split_on_added_tokens(&sentence)
-                .into_iter()
-                .map(|(sentence, id)| -> Result<(Encoding, NormalizedString)> {
+            let results = self.split_on_added_tokens(&sentence).into_iter().map(
+                |(sentence, id)| -> Result<(Encoding, NormalizedString)> {
                 // If this is one of our added tokens, lets return an encoding directly
                 if let Some(id) = id {
                     return Ok((
@@ -358,11 +355,11 @@ impl Tokenizer {
                        ),
                        normalized,
                    ))
-                })
-                // TODO: Improve this and avoid collecting first...
-                .collect::<Result<Vec<(Encoding, NormalizedString)>>>()?
-                .into_iter()
-                .unzip::<_, _, Vec<_>, Vec<_>>();
+                },
+            );
+
+            let (mut encodings, mut normalized) =
+                ResultShunt::process(results, |iter| iter.unzip::<_, _, Vec<_>, Vec<_>>())?;

             if encodings.is_empty() {
                 return Ok((Encoding::default(), NormalizedString::from("")));
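For context: previously each per-sequence result was first collected into a Result<Vec<(Encoding, NormalizedString)>> and only then unzipped into separate vectors, allocating an intermediate Vec that the removed TODO comment flagged. ResultShunt::process instead feeds the items straight into unzip and short-circuits on the first error, so the intermediate collection disappears.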
tokenizers/src/utils/iter.rs (new file, 58 lines)
@@ -0,0 +1,58 @@
+//! This comes from the Rust libcore and is duplicated here because it is not exported
+//! (cf https://github.com/rust-lang/rust/blob/25091ed9b7739e12466fb2490baa1e8a2815121c/src/libcore/iter/adapters/mod.rs#L2664)
+//! We are now using the version from https://stackoverflow.com/questions/44544323/how-to-unzip-a-sequence-of-resulta-b-e-to-a-veca-vecb-and-stop-on-f
+//! because the one from the libcore seems to cause overflowing stacks in some cases
+
+pub struct ResultShunt<I, E> {
+    iter: I,
+    error: Option<E>,
+}
+
+impl<I, T, E> ResultShunt<I, E>
+where
+    I: Iterator<Item = Result<T, E>>,
+{
+    /// Process the given iterator as if it yielded a `T` instead of a
+    /// `Result<T, _>`. Any errors will stop the inner iterator and
+    /// the overall result will be an error.
+    pub fn process<F, U>(iter: I, mut f: F) -> Result<U, E>
+    where
+        F: FnMut(&mut Self) -> U,
+    {
+        let mut shunt = ResultShunt::new(iter);
+        let value = f(shunt.by_ref());
+        shunt.reconstruct(value)
+    }
+
+    fn new(iter: I) -> Self {
+        ResultShunt { iter, error: None }
+    }
+
+    /// Consume the adapter and rebuild a `Result` value. This should
+    /// *always* be called, otherwise any potential error would be
+    /// lost.
+    fn reconstruct<U>(self, val: U) -> Result<U, E> {
+        match self.error {
+            None => Ok(val),
+            Some(e) => Err(e),
+        }
+    }
+}
+
+impl<I, T, E> Iterator for ResultShunt<I, E>
+where
+    I: Iterator<Item = Result<T, E>>,
+{
+    type Item = T;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self.iter.next() {
+            Some(Ok(v)) => Some(v),
+            Some(Err(e)) => {
+                self.error = Some(e);
+                None
+            }
+            None => None,
+        }
+    }
+}
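As a quick illustration (not part of the commit), a minimal sketch of using ResultShunt::process to unzip a sequence of Result pairs, assuming the ResultShunt type above is in scope:

fn main() {
    // All items are Ok: process returns the unzipped vectors.
    let ok_items: Vec<Result<(u32, &str), String>> =
        vec![Ok((1, "a")), Ok((2, "b")), Ok((3, "c"))];
    let unzipped = ResultShunt::process(ok_items.into_iter(), |iter| {
        iter.unzip::<_, _, Vec<_>, Vec<_>>()
    });
    assert_eq!(unzipped, Ok((vec![1, 2, 3], vec!["a", "b", "c"])));

    // An Err stops the inner iterator and surfaces as the overall result.
    let failing: Vec<Result<(u32, &str), String>> =
        vec![Ok((1, "a")), Err("boom".to_string()), Ok((3, "c"))];
    let result = ResultShunt::process(failing.into_iter(), |iter| {
        iter.unzip::<_, _, Vec<_>, Vec<_>>()
    });
    assert_eq!(result, Err("boom".to_string()));
}

This is the same shape as the call in the Tokenizer hunk above, with (Encoding, NormalizedString) in place of (u32, &str).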
tokenizers/src/utils/mod.rs (new file, 3 lines)
@@ -0,0 +1,3 @@
+pub mod iter;
+pub mod padding;
+pub mod truncation;
tokenizers/src/utils/padding.rs (new file, 63 lines)
@@ -0,0 +1,63 @@
+use crate::tokenizer::{Encoding, Result};
+use rayon::prelude::*;
+
+/// The various possible padding directions.
+#[derive(Debug, Clone, Copy)]
+pub enum PaddingDirection {
+    Left,
+    Right,
+}
+
+impl std::convert::AsRef<str> for PaddingDirection {
+    fn as_ref(&self) -> &str {
+        match self {
+            PaddingDirection::Left => "left",
+            PaddingDirection::Right => "right",
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct PaddingParams {
+    pub strategy: PaddingStrategy,
+    pub direction: PaddingDirection,
+    pub pad_id: u32,
+    pub pad_type_id: u32,
+    pub pad_token: String,
+}
+
+#[derive(Debug, Clone)]
+pub enum PaddingStrategy {
+    BatchLongest,
+    Fixed(usize),
+}
+
+pub fn pad_encodings(
+    mut encodings: Vec<Encoding>,
+    params: &PaddingParams,
+) -> Result<Vec<Encoding>> {
+    if encodings.is_empty() {
+        return Ok(encodings);
+    }
+
+    let pad_length = match params.strategy {
+        PaddingStrategy::Fixed(size) => size,
+        PaddingStrategy::BatchLongest => encodings
+            .par_iter()
+            .map(|e| e.get_ids().len())
+            .max()
+            .unwrap(),
+    };
+
+    encodings.par_iter_mut().for_each(|encoding| {
+        encoding.pad(
+            pad_length,
+            params.pad_id,
+            params.pad_type_id,
+            &params.pad_token,
+            params.direction,
+        )
+    });
+
+    Ok(encodings)
+}
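To make the strategy selection concrete, here is a small self-contained sketch (not from the commit; pad_length_for and lengths are illustrative names) of how PaddingStrategy chooses the target length. The real pad_encodings computes the same maximum over each encoding's get_ids().len() in parallel with rayon, and its early return on an empty batch is what makes the max().unwrap() safe:

enum PaddingStrategy {
    BatchLongest,
    Fixed(usize),
}

// Pick the length every sequence in the batch should be padded to.
// Returns None for an empty batch under BatchLongest, mirroring the
// is_empty() early return in pad_encodings above.
fn pad_length_for(strategy: &PaddingStrategy, lengths: &[usize]) -> Option<usize> {
    match strategy {
        PaddingStrategy::Fixed(size) => Some(*size),
        PaddingStrategy::BatchLongest => lengths.iter().copied().max(),
    }
}

fn main() {
    let lengths = [3, 7, 5];
    assert_eq!(pad_length_for(&PaddingStrategy::BatchLongest, &lengths), Some(7));
    assert_eq!(pad_length_for(&PaddingStrategy::Fixed(16), &lengths), Some(16));
    assert_eq!(pad_length_for(&PaddingStrategy::BatchLongest, &[]), None);
}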
@@ -1,5 +1,4 @@
-use crate::tokenizer::{Encoding, PaddingDirection, Result};
-use rayon::prelude::*;
+use crate::tokenizer::{Encoding, Result};

 #[derive(Debug, Clone)]
 pub struct TruncationParams {
@@ -8,21 +7,6 @@ pub struct TruncationParams {
     pub stride: usize,
 }

-#[derive(Debug, Clone)]
-pub struct PaddingParams {
-    pub strategy: PaddingStrategy,
-    pub direction: PaddingDirection,
-    pub pad_id: u32,
-    pub pad_type_id: u32,
-    pub pad_token: String,
-}
-
-#[derive(Debug, Clone)]
-pub enum PaddingStrategy {
-    BatchLongest,
-    Fixed(usize),
-}
-
 #[derive(Debug)]
 pub enum Error {
     SecondSequenceNotProvided,
@@ -118,33 +102,3 @@ pub fn truncate_encodings(

     Ok((encoding, pair_encoding))
 }
-
-pub fn pad_encodings(
-    mut encodings: Vec<Encoding>,
-    params: &PaddingParams,
-) -> Result<Vec<Encoding>> {
-    if encodings.is_empty() {
-        return Ok(encodings);
-    }
-
-    let pad_length = match params.strategy {
-        PaddingStrategy::Fixed(size) => size,
-        PaddingStrategy::BatchLongest => encodings
-            .par_iter()
-            .map(|e| e.get_ids().len())
-            .max()
-            .unwrap(),
-    };
-
-    encodings.par_iter_mut().for_each(|encoding| {
-        encoding.pad(
-            pad_length,
-            params.pad_id,
-            params.pad_type_id,
-            &params.pad_token,
-            params.direction,
-        )
-    });
-
-    Ok(encodings)
-}
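Note that everything removed in this final hunk is the same PaddingParams/PaddingStrategy definitions and pad_encodings function that reappear verbatim in the new tokenizers/src/utils/padding.rs above: the code changes files, not behavior.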