From dad70e8e8505476bd8e14a1a897948c1af17394e Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Fri, 31 Jul 2020 16:49:00 -0400 Subject: [PATCH] Implement suggestions by @sebpuetz MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sebastian Pütz --- bindings/python/src/pre_tokenizers.rs | 2 +- tokenizers/src/pre_tokenizers/bert.rs | 3 +- tokenizers/src/pre_tokenizers/byte_level.rs | 16 ++-- tokenizers/src/pre_tokenizers/metaspace.rs | 24 ++--- tokenizers/src/pre_tokenizers/whitespace.rs | 12 ++- tokenizers/src/tokenizer/added_vocabulary.rs | 8 +- tokenizers/src/tokenizer/encoding.rs | 34 +++++++ tokenizers/src/tokenizer/mod.rs | 69 +++++++------- tokenizers/src/tokenizer/normalizer.rs | 18 ++-- tokenizers/src/tokenizer/pattern.rs | 36 ++------ tokenizers/src/tokenizer/pre_tokenizer.rs | 96 ++++++++------------ 11 files changed, 161 insertions(+), 157 deletions(-) diff --git a/bindings/python/src/pre_tokenizers.rs b/bindings/python/src/pre_tokenizers.rs index 70b87473..7685b557 100644 --- a/bindings/python/src/pre_tokenizers.rs +++ b/bindings/python/src/pre_tokenizers.rs @@ -61,7 +61,7 @@ impl PreTokenizer { .into_py()?; Ok(pretokenized - .get_normalized(true) + .get_normalized(tk::OffsetReferential::Original) .into_iter() .map(|(s, o)| (s.to_owned(), o)) .collect()) diff --git a/tokenizers/src/pre_tokenizers/bert.rs b/tokenizers/src/pre_tokenizers/bert.rs index e44e29dc..285c87a5 100644 --- a/tokenizers/src/pre_tokenizers/bert.rs +++ b/tokenizers/src/pre_tokenizers/bert.rs @@ -32,6 +32,7 @@ impl PreTokenizer for BertPreTokenizer { #[cfg(test)] mod tests { use super::*; + use crate::OffsetReferential; #[test] fn basic() { @@ -39,7 +40,7 @@ mod tests { let mut pretokenized: PreTokenizedString = "Hey friend! 
How are you?!?".into(); pretok.pre_tokenize(&mut pretokenized).unwrap(); assert_eq!( - pretokenized.get_normalized(true), + pretokenized.get_normalized(OffsetReferential::Original), vec![ ("Hey", (0, 3)), ("", (3, 4)), diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 75d3023f..e2e48c01 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -218,8 +218,8 @@ pub fn process_offsets(encoding: &mut Encoding, add_prefix_space: bool) { mod tests { use super::ByteLevel; use crate::tokenizer::{ - normalizer::Range, Decoder, Encoding, NormalizedString, PostProcessor, PreTokenizedString, - PreTokenizer, + normalizer::Range, Decoder, Encoding, NormalizedString, OffsetReferential, PostProcessor, + PreTokenizedString, PreTokenizer, }; #[test] @@ -228,7 +228,7 @@ mod tests { let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into(); bytelevel.pre_tokenize(&mut pretokenized).unwrap(); assert_eq!( - pretokenized.get_normalized(true), + pretokenized.get_normalized(OffsetReferential::Original), vec![ ("Hello", (0, 5)), ("Ġmy", (5, 8)), @@ -273,7 +273,7 @@ mod tests { let mut pretokenized = PreTokenizedString::from(*s); bytelevel.pre_tokenize(&mut pretokenized).unwrap(); assert_eq!( - pretokenized.get_normalized(false), + pretokenized.get_normalized(OffsetReferential::Normalized), vec![ ("ĠHello", (0, 6)), ("Ġmy", (6, 9)), @@ -317,7 +317,7 @@ mod tests { bytelevel.pre_tokenize(&mut pretokenized).unwrap(); assert_eq!( - pretokenized.get_normalized(true), + pretokenized.get_normalized(OffsetReferential::Original), vec![ ("Hello", (0, 5)), ("Ġthere", (5, 11)), @@ -335,7 +335,7 @@ mod tests { bytelevel.pre_tokenize(&mut pretokenized).unwrap(); assert_eq!( - pretokenized.get_normalized(true), + pretokenized.get_normalized(OffsetReferential::Original), vec![ ("Hello", (0, 5)), ("Ġthere", (5, 11)), @@ -352,11 +352,11 @@ mod tests { bytelevel.pre_tokenize(&mut pretokenized).unwrap(); assert_eq!( - pretokenized.get_normalized(true), + pretokenized.get_normalized(OffsetReferential::Original), vec![("i", (0, 1)), ("âŃ¢", (1, 2)), ("j", (2, 3))] ); assert_eq!( - pretokenized.get_normalized(false), + pretokenized.get_normalized(OffsetReferential::Normalized), vec![("i", (0, 1)), ("âŃ¢", (1, 4)), ("j", (4, 5))] ); assert_eq!( diff --git a/tokenizers/src/pre_tokenizers/metaspace.rs b/tokenizers/src/pre_tokenizers/metaspace.rs index 14791de4..78222e9d 100644 --- a/tokenizers/src/pre_tokenizers/metaspace.rs +++ b/tokenizers/src/pre_tokenizers/metaspace.rs @@ -6,25 +6,18 @@ use serde::{Deserialize, Serialize}; /// splits on this character pub struct Metaspace { replacement: char, - str_bytes: [u8; 4], + str_rep: String, add_prefix_space: bool, } impl Metaspace { pub fn new(replacement: char, add_prefix_space: bool) -> Self { - let mut str_bytes = [0; 4]; - replacement.encode_utf8(&mut str_bytes); Self { replacement, - str_bytes, + str_rep: replacement.to_string(), add_prefix_space, } } - - #[inline] - fn replacement(&self) -> &str { - unsafe { std::str::from_utf8_unchecked(&self.str_bytes[..self.replacement.len_utf8()]) } - } } impl Default for Metaspace { @@ -38,14 +31,14 @@ impl PreTokenizer for Metaspace { fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> { pretokenized.split(|_, mut normalized| { if self.add_prefix_space { - normalized.prepend(&self.replacement()); + normalized.prepend(&self.str_rep); } Ok(normalized .split(' ', 
SplitDelimiterBehavior::MergedWithNext)? .into_iter() .map(|mut normalized| { - normalized.replace(' ', self.replacement())?; + normalized.replace(' ', &self.str_rep)?; Ok(normalized) }) .collect::<Result<Vec<_>>>()?) @@ -78,6 +71,7 @@ impl Decoder for Metaspace { #[cfg(test)] mod tests { use super::*; + use crate::OffsetReferential; #[test] fn basic() { @@ -85,11 +79,11 @@ let mut pretokenized = PreTokenizedString::from("Hey friend!"); pretok.pre_tokenize(&mut pretokenized).unwrap(); assert_eq!( - pretokenized.get_normalized(false), + pretokenized.get_normalized(OffsetReferential::Normalized), vec![("▁Hey", (0, 4)), ("▁friend!", (4, 12))] ); assert_eq!( - pretokenized.get_normalized(true), + pretokenized.get_normalized(OffsetReferential::Original), vec![("▁Hey", (0, 3)), ("▁friend!", (3, 11))] ); } @@ -100,7 +94,7 @@ let mut pretokenized = PreTokenizedString::from("Hey friend!"); pretok.pre_tokenize(&mut pretokenized).unwrap(); assert_eq!( - pretokenized.get_normalized(false), + pretokenized.get_normalized(OffsetReferential::Normalized), vec![ ("▁Hey", (0, 4)), ("▁", (4, 5)), ("▁friend!", (5, 13)), ] ); assert_eq!( - pretokenized.get_normalized(true), + pretokenized.get_normalized(OffsetReferential::Original), vec![ ("▁Hey", (0, 3)), ("▁", (3, 4)), diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs index 24d186a3..c872fd67 100644 --- a/tokenizers/src/pre_tokenizers/whitespace.rs +++ b/tokenizers/src/pre_tokenizers/whitespace.rs @@ -45,7 +45,7 @@ impl PreTokenizer for WhitespaceSplit { #[cfg(test)] mod tests { use super::*; - use crate::tokenizer::PreTokenizer; + use crate::{OffsetReferential, PreTokenizer}; #[test] fn basic() { @@ -77,7 +77,10 @@ for (s, res) in tests { let mut pretokenized = PreTokenizedString::from(s); pretok.pre_tokenize(&mut pretokenized).unwrap(); - assert_eq!(pretokenized.get_normalized(true), res); + assert_eq!( + pretokenized.get_normalized(OffsetReferential::Original), + res + ); } } @@ -103,7 +106,10 @@ for (s, res) in tests { let mut pretokenized = PreTokenizedString::from(s); pretok.pre_tokenize(&mut pretokenized).unwrap(); - assert_eq!(pretokenized.get_normalized(true), res); + assert_eq!( + pretokenized.get_normalized(OffsetReferential::Original), + res + ); } } } diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 05ac958d..0bd5bc9a 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -454,24 +454,22 @@ impl AddedVocabulary { pretokenized .split(|i, mut sequence| { if let Some(id) = indices[i] { - multi_indices.push(vec![Some(id)]); + multi_indices.push(Some(id)); Ok(itertools::Either::Left(std::iter::once(sequence))) } else { normalizer.map(|n| n.normalize(&mut sequence)); let (idcs, split) = self.split_with_indices(sequence, &self.split_normalized_re); - multi_indices.push(idcs); + multi_indices.extend(idcs); Ok(itertools::Either::Right(split)) } }) .expect("AddedVocabulary bad split"); - let indices = multi_indices.into_iter().flatten().collect::<Vec<_>>(); - pretokenized .into_iter() - .zip(indices) + .zip(multi_indices) .map(|(substring, id)| { ( substring.normalized, diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs index 092b2707..6e178179 100644 --- a/tokenizers/src/tokenizer/encoding.rs +++ b/tokenizers/src/tokenizer/encoding.rs @@ -47,6 +47,19 @@ impl Encoding { } } + pub fn with_capacity(len: usize) ->
Self { + Encoding { + ids: Vec::with_capacity(len), + type_ids: Vec::with_capacity(len), + tokens: Vec::with_capacity(len), + words: Vec::with_capacity(len), + offsets: Vec::with_capacity(len), + special_tokens_mask: Vec::with_capacity(len), + attention_mask: Vec::with_capacity(len), + overflowing: vec![], + } + } + pub fn from_tokens(tokens: Vec<Token>, type_id: u32) -> Self { let length = tokens.len(); let (ids, tokens, offsets) = tokens.into_iter().fold( @@ -404,6 +417,27 @@ impl std::iter::FromIterator<Token> for Encoding { } } +impl std::iter::FromIterator<(u32, String, (usize, usize), Option<u32>, u32)> for Encoding { + fn from_iter<I: IntoIterator<Item = (u32, String, (usize, usize), Option<u32>, u32)>>( + iter: I, + ) -> Self { + let items = iter.into_iter(); + let (lower, upper) = items.size_hint(); + let length = upper.unwrap_or(lower); + let mut encoding = Self::with_capacity(length); + + for (id, token, offsets, word, type_id) in items { + encoding.ids.push(id); + encoding.tokens.push(token); + encoding.offsets.push(offsets); + encoding.type_ids.push(type_id); + encoding.words.push(word); + } + + encoding + } +} + #[inline] fn get_current_part<T: Clone>( prev: &[T], diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index fad16f8f..d0497c11 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -31,7 +31,7 @@ mod serialization; pub use added_vocabulary::*; pub use encoding::*; -pub use normalizer::{NormalizedString, SplitDelimiterBehavior}; +pub use normalizer::{NormalizedString, OffsetReferential, SplitDelimiterBehavior}; pub use pre_tokenizer::*; pub type Error = Box<dyn std::error::Error + Send + Sync>; @@ -423,6 +423,7 @@ impl Tokenizer { .added_vocabulary .extract_and_normalize(self.normalizer.as_deref(), &subseq) .map(|(normalized, original_offsets, id)| match id { + // This is an added token, no need to tokenize, we have the ID Some(id) => { let mut encoding = Encoding::from_tokens( vec![Token::new( @@ -435,6 +436,7 @@ impl Tokenizer { encoding.get_words_mut()[0] = Some(0); Ok(encoding) } + // Let's tokenize None => self.do_tokenize( self.do_pre_tokenize(normalized)?, original_offsets, @@ -675,45 +677,40 @@ impl Tokenizer { ) -> Result<Encoding> { let pretokenized: PreTokenizedString = pretokenized.into(); - let mut empty_words = 0; pretokenized .into_iter() + .filter(|substr| !substr.normalized.is_empty()) .enumerate() - .map(|(word_idx, substr)| { - if substr.normalized.is_empty() { - empty_words += 1; - return Ok(Encoding::default()); + .flat_map(|(word_idx, substr)| { + match self.model.tokenize(substr.normalized.get()) { + Ok(tokens) => { + itertools::Either::Left(tokens.into_iter().map(move |token| { + // We convert the normalized offsets back to the original + let converted_offsets = substr + .normalized + .convert_offsets(Range::Normalized( + token.offsets.0..token.offsets.1, + )) + .map_or(token.offsets, |range| { + ( + original_offsets.0 + + substr.original_offsets.0 + + range.start, + original_offsets.0 + substr.original_offsets.0 + range.end, + ) + }); + + Ok(( + token.id, + token.value, + converted_offsets, + Some(word_idx as u32), + type_id, + )) + })) + } + Err(e) => itertools::Either::Right(std::iter::once(Err(e))), } - - let mut tokens = self.model.tokenize(substr.normalized.get())?; - - // Update the offsets to match the original input - tokens.iter_mut().for_each(|token| { - // We convert the normalized offsets back to the original - let converted_offsets = substr - .normalized - .convert_offsets(Range::Normalized(token.offsets.0..token.offsets.1)) - .map_or(token.offsets, |range| { - ( - original_offsets.0 +
substr.original_offsets.0 + range.start, - original_offsets.0 + substr.original_offsets.0 + range.end, - ) - }); - - // And we update the token to these original offsets, applying the original offset - // of the sequence we just tokenized. - token.offsets = converted_offsets; - }); - - // Then build the encoding from these tokens, setting the `words` as relevant - let mut encoding = Encoding::from_tokens(tokens, type_id); - encoding.get_words_mut().iter_mut().for_each(|word| { - // empty words are generally spaces, and other things - // that were normalized out, so we dont want to count them in. - *word = Some(word_idx as u32 - empty_words); - }); - - Ok(encoding) }) .collect() } diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs index 75622f14..8bca7e5d 100644 --- a/tokenizers/src/tokenizer/normalizer.rs +++ b/tokenizers/src/tokenizer/normalizer.rs @@ -5,6 +5,12 @@ use crate::{Offsets, Result}; use std::ops::{Bound, RangeBounds}; use unicode_normalization_alignments::UnicodeNormalization; +/// The possible offsets referential +pub enum OffsetReferential { + Original, + Normalized, +} + /// Represents a Range usable by the NormalizedString to index its content. /// A Range can use indices relative to either the `Original` or the `Normalized` string #[derive(Debug, Clone)] @@ -300,11 +306,11 @@ impl NormalizedString { Some(Self { original: get_range_of(&self.original, r_original) - .unwrap_or("") - .to_owned(), + .unwrap_or_default() + .into(), normalized: get_range_of(&self.normalized, r_normalized.clone()) - .unwrap_or("") - .to_owned(), + .unwrap_or_default() + .into(), alignments: self .alignments .get(r_normalized)? @@ -462,7 +468,7 @@ impl NormalizedString { pub fn replace<P: Pattern>(&mut self, pattern: P, content: &str) -> Result<()> { let matches = pattern.find_matches(&self.normalized)?; - let (normalized, alignments): (Vec<char>, Vec<Offsets>) = matches + let (normalized, alignments): (String, Vec<Offsets>) = matches .into_iter() .flat_map(|((start, end), is_match)| { let len = end - start; @@ -490,7 +496,7 @@ impl NormalizedString { }) .unzip(); - self.normalized = normalized.into_iter().collect(); + self.normalized = normalized; self.alignments = alignments; Ok(()) diff --git a/tokenizers/src/tokenizer/pattern.rs b/tokenizers/src/tokenizer/pattern.rs index d977dea8..44a2fbfd 100644 --- a/tokenizers/src/tokenizer/pattern.rs +++ b/tokenizers/src/tokenizer/pattern.rs @@ -31,34 +31,18 @@ impl Pattern for &Regex { return Ok(vec![((0, 0), false)]); } - // Find initial matches - let matches = self - .find_iter(inside) - .map(|m| ((m.start(), m.end()), true)) - .collect::<Vec<_>>(); - - // Then add missing splits inbetween - let mut start_offset = 0; - let mut splits = matches - .into_iter() - .flat_map(|((start, end), flag)| { - let mut splits = vec![]; - if start_offset < start { - splits.push(((start_offset, start), false)); - } - splits.push(((start, end), flag)); - start_offset = end; - - splits - }) - .collect::<Vec<_>>(); - - if let Some(((_, end), _)) = splits.iter().last().copied() { - if end < inside.len() { - splits.push(((end, inside.len()), false)); + let mut prev = 0; + let mut splits = Vec::with_capacity(inside.len()); + for m in self.find_iter(inside) { + if prev != m.start() { + splits.push(((prev, m.start()), false)); } + splits.push(((m.start(), m.end()), true)); + prev = m.end(); + } + if prev != inside.len() { + splits.push(((prev, inside.len()), false)) } - Ok(splits) } } diff --git a/tokenizers/src/tokenizer/pre_tokenizer.rs b/tokenizers/src/tokenizer/pre_tokenizer.rs
index 8499da59..04522661 100644 --- a/tokenizers/src/tokenizer/pre_tokenizer.rs +++ b/tokenizers/src/tokenizer/pre_tokenizer.rs @@ -1,4 +1,4 @@ -use crate::{NormalizedString, Offsets, Result}; +use crate::{NormalizedString, OffsetReferential, Offsets, Result}; /// Wrapper for a subpart of a `NormalizedString`. /// @@ -16,7 +16,15 @@ pub struct SubString { pub original_offsets: Offsets, } -/// A `PreTokenizedString` takes care of splitting the input string in multiple +impl SubString { + pub fn new(normalized: NormalizedString, original_offsets: Offsets) -> Self { + Self { + normalized, + original_offsets, + } + } +} + /// sub strings, while ensuring that they form a coherend group. This let us keep /// track of the offsets during the whole normalization and pre-tokenization steps. #[derive(Debug)] @@ -42,52 +50,28 @@ impl PreTokenizedString { F: FnMut(usize, NormalizedString) -> Result<U>, U: IntoIterator<Item = NormalizedString>, { - self.parts = self - .parts - .drain(..) - .enumerate() - .flat_map(|(i, sub)| { - let original_len = sub.normalized.len_original(); - let original_offsets = sub.original_offsets; + // new_parts is at least as big as self.parts + let mut new_parts = Vec::with_capacity(self.parts.len()); + for (i, sub) in self.parts.drain(..).enumerate() { + let original_len = sub.normalized.len_original(); + let original_offsets = sub.original_offsets; - let mut new_len = 0; - let res = split_fn(i, sub.normalized); - if let Err(e) = res { - return itertools::Either::Left(std::iter::once(Err(e))); - } - - let parts = res - .unwrap() - .into_iter() - .map(|normalized| { - let len = normalized.len_original(); - let new_s = SubString { - normalized, - original_offsets: ( - original_offsets.0 + new_len, - original_offsets.0 + new_len + len, - ), - }; - new_len += len; - new_s - }) - .collect::<Vec<_>>(); - - if new_len != original_len { - println!( - "Original offsets: {:?}\nNew: {:?}", - (0, original_len), - (0, new_len) - ); - itertools::Either::Left(std::iter::once(Err( - "Split pre-tokenized string must represent the entire original string" - .into(), - ))) - } else { - itertools::Either::Right(parts.into_iter().map(Ok)) - } - }) - .collect::<Result<Vec<_>>>()?; + let mut new_len = 0; + new_parts.extend(split_fn(i, sub.normalized)?.into_iter().map(|normalized| { + let len = normalized.len_original(); + let start = original_offsets.0 + new_len; + let end = original_offsets.0 + new_len + len; + let new_s = SubString::new(normalized, (start, end)); + new_len += len; + new_s + })); + if original_len != new_len { + return Err( + "Split pre-tokenized string must represent the entire original string".into(), + ); + } + } + self.parts = new_parts; Ok(()) } @@ -98,19 +82,20 @@ impl PreTokenizedString { /// Returns a list of normalized string and the associated offsets, /// either in original or normalized referential - pub fn get_normalized(&self, original: bool) -> Vec<(&str, Offsets)> { + pub fn get_normalized(&self, offset_type: OffsetReferential) -> Vec<(&str, Offsets)> { let mut offset = 0; self.iter() .map(|sub| { - let offsets = if original { - ( + let offsets = match offset_type { + OffsetReferential::Original => ( sub.original_offsets.0, sub.original_offsets.0 + sub.normalized.len_original(), - ) - } else { - let len = sub.normalized.len(); - offset += len; - (offset - len, offset) + ), + OffsetReferential::Normalized => { + let len = sub.normalized.len(); + offset += len; + (offset - len, offset) + } }; (sub.normalized.get(), offsets) @@ -176,4 +161,3 @@ impl<'a> IntoIterator for &'a PreTokenizedString { self.parts.iter() }
} -
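[Illustration, not part of the patch] The change above replaces the `bool` argument of `PreTokenizedString::get_normalized` with the `OffsetReferential` enum added in `normalizer.rs`, so every call site now states which referential it wants; the Python binding switches to `tk::OffsetReferential::Original` accordingly. Below is a minimal sketch of the new call pattern, reusing the exact expected values from the updated `metaspace.rs` test; the `use` paths are an assumption (the in-tree tests simply rely on `use super::*;` plus `use crate::OffsetReferential;`).

    // Hypothetical imports; inside the crate the tests use `super::*` instead.
    use crate::pre_tokenizers::metaspace::Metaspace;
    use crate::tokenizer::{OffsetReferential, PreTokenizedString, PreTokenizer, Result};

    fn offset_referential_demo() -> Result<()> {
        let pretok = Metaspace::new('▁', true);
        let mut pretokenized = PreTokenizedString::from("Hey friend!");
        pretok.pre_tokenize(&mut pretokenized)?;

        // Offsets measured in the normalized string "▁Hey▁friend!"
        assert_eq!(
            pretokenized.get_normalized(OffsetReferential::Normalized),
            vec![("▁Hey", (0, 4)), ("▁friend!", (4, 12))]
        );
        // Same splits, with offsets pointing back into the original "Hey friend!"
        assert_eq!(
            pretokenized.get_normalized(OffsetReferential::Original),
            vec![("▁Hey", (0, 3)), ("▁friend!", (3, 11))]
        );
        Ok(())
    }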