mirror of https://github.com/mii443/tokenizers.git

Commit: Rust - Improve utils and unzipping encode results

This commit splits the `utils` module into `utils/iter.rs`, `utils/padding.rs`, and `utils/truncation.rs`, moves `PaddingDirection` and the padding utilities out of the tokenizer module, and replaces a collect-then-unzip of encode results with a streaming unzip built on a new `ResultShunt` iterator adapter.
@@ -1,21 +1,6 @@
+use crate::utils::padding::PaddingDirection;
 use rayon::prelude::*;
-
-/// The various possible padding directions.
-#[derive(Debug, Clone, Copy)]
-pub enum PaddingDirection {
-    Left,
-    Right,
-}
-
-impl std::convert::AsRef<str> for PaddingDirection {
-    fn as_ref(&self) -> &str {
-        match self {
-            PaddingDirection::Left => "left",
-            PaddingDirection::Right => "right",
-        }
-    }
-}
 
 /// Represents the output of a `Tokenizer`.
 #[derive(Default, PartialEq, Debug, Clone)]
 pub struct Encoding {
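The `AsRef<str>` impl that moves along with the enum converts a direction to its lowercase name. A minimal standalone sketch with a local copy of the enum from the diff (`log_direction` is a hypothetical caller, not part of the crate):

#[derive(Debug, Clone, Copy)]
pub enum PaddingDirection {
    Left,
    Right,
}

impl std::convert::AsRef<str> for PaddingDirection {
    fn as_ref(&self) -> &str {
        match self {
            PaddingDirection::Left => "left",
            PaddingDirection::Right => "right",
        }
    }
}

// Hypothetical caller: anything generic over `AsRef<str>` (logging,
// serialization, ...) can now take a direction directly.
fn log_direction(dir: impl AsRef<str>) {
    println!("padding direction: {}", dir.as_ref());
}

fn main() {
    log_direction(PaddingDirection::Right); // prints "padding direction: right"
    assert_eq!(PaddingDirection::Left.as_ref(), "left");
}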
@@ -9,10 +9,9 @@
 //! - [`PostProcessor`](trait.PostProcessor.html): Takes care of the processing after tokenization (like truncating, padding,
 //! ...).
 
-pub use crate::utils::{
-    pad_encodings, truncate_encodings, PaddingParams, PaddingStrategy, TruncationParams,
-    TruncationStrategy,
-};
+use crate::utils::iter::ResultShunt;
+pub use crate::utils::padding::{pad_encodings, PaddingParams, PaddingStrategy};
+pub use crate::utils::truncation::{truncate_encodings, TruncationParams, TruncationStrategy};
 use indicatif::{ProgressBar, ProgressStyle};
 use rayon::prelude::*;
 use std::{
@@ -302,10 +301,8 @@ impl Tokenizer {
         move |sentence: String, type_id: u32| -> Result<(Encoding, NormalizedString)> {
             // First we need to split into as many sequences as needed to avoid splitting
             // on our added tokens
-            let (mut encodings, mut normalized) = self
-                .split_on_added_tokens(&sentence)
-                .into_iter()
-                .map(|(sentence, id)| -> Result<(Encoding, NormalizedString)> {
+            let results = self.split_on_added_tokens(&sentence).into_iter().map(
+                |(sentence, id)| -> Result<(Encoding, NormalizedString)> {
                     // If this is one of our added tokens, lets return an encoding directly
                     if let Some(id) = id {
                         return Ok((
@@ -358,11 +355,11 @@ impl Tokenizer {
                         ),
                         normalized,
                     ))
-                })
-                // TODO: Improve this and avoid collecting first...
-                .collect::<Result<Vec<(Encoding, NormalizedString)>>>()?
-                .into_iter()
-                .unzip::<_, _, Vec<_>, Vec<_>>();
+                },
+            );
+
+            let (mut encodings, mut normalized) =
+                ResultShunt::process(results, |iter| iter.unzip::<_, _, Vec<_>, Vec<_>>())?;
 
             if encodings.is_empty() {
                 return Ok((Encoding::default(), NormalizedString::from("")));
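In miniature, the change above replaces a collect-then-unzip (the shape the removed TODO complained about) with a single streaming pass. A standalone sketch of the old shape, with a hypothetical per-item encoder standing in for the real one:

// Old shape: collect into an intermediate Vec of pairs, then unzip.
// The `?` propagates the first error, but only after a full collect.
fn encode_all(inputs: Vec<&str>) -> Result<(Vec<u32>, Vec<String>), String> {
    let (ids, tokens) = inputs
        .into_iter()
        .map(|s| -> Result<(u32, String), String> {
            // Hypothetical per-item encoding step.
            Ok((s.len() as u32, s.to_uppercase()))
        })
        .collect::<Result<Vec<(u32, String)>, String>>()?
        .into_iter()
        .unzip::<_, _, Vec<_>, Vec<_>>();
    Ok((ids, tokens))
}

fn main() {
    let (ids, tokens) = encode_all(vec!["ab", "c"]).unwrap();
    assert_eq!(ids, vec![2, 1]);
    assert_eq!(tokens, vec!["AB", "C"]);
}

The new code instead hands the un-collected iterator to `ResultShunt::process` (added in `utils/iter.rs` below), which unzips in one pass and still short-circuits on the first error, skipping the intermediate `Vec<(Encoding, NormalizedString)>`.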
tokenizers/src/utils/iter.rs (new file, 58 lines)
@@ -0,0 +1,58 @@
+//! This comes from the Rust libcore and is duplicated here because it is not exported
+//! (cf https://github.com/rust-lang/rust/blob/25091ed9b7739e12466fb2490baa1e8a2815121c/src/libcore/iter/adapters/mod.rs#L2664)
+//! We are now using the version from https://stackoverflow.com/questions/44544323/how-to-unzip-a-sequence-of-resulta-b-e-to-a-veca-vecb-and-stop-on-f
+//! because the one from the libcore seems to cause overflowing stacks in some cases
+
+pub struct ResultShunt<I, E> {
+    iter: I,
+    error: Option<E>,
+}
+
+impl<I, T, E> ResultShunt<I, E>
+where
+    I: Iterator<Item = Result<T, E>>,
+{
+    /// Process the given iterator as if it yielded a `T` instead of a
+    /// `Result<T, _>`. Any errors will stop the inner iterator and
+    /// the overall result will be an error.
+    pub fn process<F, U>(iter: I, mut f: F) -> Result<U, E>
+    where
+        F: FnMut(&mut Self) -> U,
+    {
+        let mut shunt = ResultShunt::new(iter);
+        let value = f(shunt.by_ref());
+        shunt.reconstruct(value)
+    }
+
+    fn new(iter: I) -> Self {
+        ResultShunt { iter, error: None }
+    }
+
+    /// Consume the adapter and rebuild a `Result` value. This should
+    /// *always* be called, otherwise any potential error would be
+    /// lost.
+    fn reconstruct<U>(self, val: U) -> Result<U, E> {
+        match self.error {
+            None => Ok(val),
+            Some(e) => Err(e),
+        }
+    }
+}
+
+impl<I, T, E> Iterator for ResultShunt<I, E>
+where
+    I: Iterator<Item = Result<T, E>>,
+{
+    type Item = T;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self.iter.next() {
+            Some(Ok(v)) => Some(v),
+            Some(Err(e)) => {
+                self.error = Some(e);
+                None
+            }
+            None => None,
+        }
+    }
+}
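A minimal usage sketch of the adapter, assuming the `ResultShunt` defined above is in scope; the pair values and the `String` error type are made up for illustration:

fn main() {
    // All items Ok: the pairs are unzipped into two Vecs in one pass.
    let ok_items: Vec<Result<(u32, &str), String>> =
        vec![Ok((1, "a")), Ok((2, "b")), Ok((3, "c"))];
    let (ids, tokens) = ResultShunt::process(ok_items.into_iter(), |iter| {
        iter.unzip::<_, _, Vec<u32>, Vec<&str>>()
    })
    .unwrap();
    assert_eq!(ids, vec![1, 2, 3]);
    assert_eq!(tokens, vec!["a", "b", "c"]);

    // An Err stops the inner iterator and becomes the overall result.
    let failing: Vec<Result<(u32, &str), String>> =
        vec![Ok((1, "a")), Err("boom".to_string()), Ok((3, "c"))];
    let res = ResultShunt::process(failing.into_iter(), |iter| {
        iter.unzip::<_, _, Vec<u32>, Vec<&str>>()
    });
    assert_eq!(res.unwrap_err(), "boom");
}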
tokenizers/src/utils/mod.rs (new file, 3 lines)
@@ -0,0 +1,3 @@
+pub mod iter;
+pub mod padding;
+pub mod truncation;
tokenizers/src/utils/padding.rs (new file, 63 lines)
@@ -0,0 +1,63 @@
+use crate::tokenizer::{Encoding, Result};
+use rayon::prelude::*;
+
+/// The various possible padding directions.
+#[derive(Debug, Clone, Copy)]
+pub enum PaddingDirection {
+    Left,
+    Right,
+}
+
+impl std::convert::AsRef<str> for PaddingDirection {
+    fn as_ref(&self) -> &str {
+        match self {
+            PaddingDirection::Left => "left",
+            PaddingDirection::Right => "right",
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct PaddingParams {
+    pub strategy: PaddingStrategy,
+    pub direction: PaddingDirection,
+    pub pad_id: u32,
+    pub pad_type_id: u32,
+    pub pad_token: String,
+}
+
+#[derive(Debug, Clone)]
+pub enum PaddingStrategy {
+    BatchLongest,
+    Fixed(usize),
+}
+
+pub fn pad_encodings(
+    mut encodings: Vec<Encoding>,
+    params: &PaddingParams,
+) -> Result<Vec<Encoding>> {
+    if encodings.is_empty() {
+        return Ok(encodings);
+    }
+
+    let pad_length = match params.strategy {
+        PaddingStrategy::Fixed(size) => size,
+        PaddingStrategy::BatchLongest => encodings
+            .par_iter()
+            .map(|e| e.get_ids().len())
+            .max()
+            .unwrap(),
+    };
+
+    encodings.par_iter_mut().for_each(|encoding| {
+        encoding.pad(
+            pad_length,
+            params.pad_id,
+            params.pad_type_id,
+            &params.pad_token,
+            params.direction,
+        )
+    });
+
+    Ok(encodings)
+}
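For illustration, the `BatchLongest` strategy in a standalone sketch over plain `Vec<Vec<u32>>` instead of `Encoding`s (`pad_sequences` and `PAD_ID` are hypothetical names, not part of the crate):

use rayon::prelude::*;

const PAD_ID: u32 = 0; // hypothetical padding id

/// Standalone sketch of the `BatchLongest` strategy: every sequence is
/// right-padded with `PAD_ID` up to the length of the longest one.
fn pad_sequences(mut seqs: Vec<Vec<u32>>) -> Vec<Vec<u32>> {
    if seqs.is_empty() {
        return seqs;
    }
    // Find the target length in parallel, as `pad_encodings` does.
    let pad_length = seqs.par_iter().map(|s| s.len()).max().unwrap();
    // Pad every sequence in place, also in parallel.
    seqs.par_iter_mut()
        .for_each(|s| s.resize(pad_length, PAD_ID));
    seqs
}

fn main() {
    let batch = vec![vec![1, 2, 3], vec![4], vec![5, 6]];
    let padded = pad_sequences(batch);
    assert_eq!(padded, vec![vec![1, 2, 3], vec![4, 0, 0], vec![5, 6, 0]]);
}

As in `pad_encodings`, both the max-length scan and the padding itself run in parallel through rayon; the `unwrap` is safe because the empty case returns early.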
@@ -1,5 +1,4 @@
-use crate::tokenizer::{Encoding, PaddingDirection, Result};
-use rayon::prelude::*;
+use crate::tokenizer::{Encoding, Result};
 
 #[derive(Debug, Clone)]
 pub struct TruncationParams {
@@ -8,21 +7,6 @@ pub struct TruncationParams {
     pub stride: usize,
 }
-
-#[derive(Debug, Clone)]
-pub struct PaddingParams {
-    pub strategy: PaddingStrategy,
-    pub direction: PaddingDirection,
-    pub pad_id: u32,
-    pub pad_type_id: u32,
-    pub pad_token: String,
-}
-
-#[derive(Debug, Clone)]
-pub enum PaddingStrategy {
-    BatchLongest,
-    Fixed(usize),
-}
 
 #[derive(Debug)]
 pub enum Error {
     SecondSequenceNotProvided,
@@ -118,33 +102,3 @@ pub fn truncate_encodings(
 
     Ok((encoding, pair_encoding))
 }
-
-pub fn pad_encodings(
-    mut encodings: Vec<Encoding>,
-    params: &PaddingParams,
-) -> Result<Vec<Encoding>> {
-    if encodings.is_empty() {
-        return Ok(encodings);
-    }
-
-    let pad_length = match params.strategy {
-        PaddingStrategy::Fixed(size) => size,
-        PaddingStrategy::BatchLongest => encodings
-            .par_iter()
-            .map(|e| e.get_ids().len())
-            .max()
-            .unwrap(),
-    };
-
-    encodings.par_iter_mut().for_each(|encoding| {
-        encoding.pad(
-            pad_length,
-            params.pad_id,
-            params.pad_type_id,
-            &params.pad_token,
-            params.direction,
-        )
-    });
-
-    Ok(encodings)
-}