From 162829b7a99083900dab1f1efa87a81622bb6ef9 Mon Sep 17 00:00:00 2001
From: Anthony MOI
Date: Sat, 28 Dec 2019 15:24:09 -0500
Subject: [PATCH] Introduce NormalizedString

---
 tokenizers/src/tokenizer/mod.rs        |   8 +-
 tokenizers/src/tokenizer/normalizer.rs | 372 +++++++++++++++++++++++++
 2 files changed, 375 insertions(+), 5 deletions(-)
 create mode 100644 tokenizers/src/tokenizer/normalizer.rs

diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index f147659e..0f220611 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -24,14 +24,12 @@ use std::{
 };
 
 mod encoding;
+mod normalizer;
+
 pub use encoding::*;
+pub use normalizer::*;
 
 pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
-
-/// A Normalizer takes care of pre-processing strings
-pub trait Normalizer {
-    fn normalize(&self, s: String) -> Result<String>;
-}
 pub type Offsets = (usize, usize);
 
 /// A PreTokenizer takes care of pre-tokenizing strings before this goes to the model
diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs
new file mode 100644
index 00000000..c30337d3
--- /dev/null
+++ b/tokenizers/src/tokenizer/normalizer.rs
@@ -0,0 +1,372 @@
+use super::Result;
+use std::cmp::Ordering;
+use unicode_normalization::UnicodeNormalization;
+
+/// A Normalizer takes care of pre-processing strings
+pub trait Normalizer {
+    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()>;
+}
+
+/// A normalized string takes care of keeping both versions of a String, and
+/// provides necessary alignments to retrieve ranges of both strings
+#[derive(Default, Debug, Clone)]
+pub struct NormalizedString {
+    original: String,
+    normalized: String,
+    /// Mapping from the normalized string to the original one: the entry at index `i`
+    /// is the `(start, end)` range, expressed in chars of the original string, that
+    /// the `i`-th char of the normalized string is aligned with
+    alignments: Vec<(usize, usize)>,
+}
+
+impl std::cmp::PartialEq for NormalizedString {
+    /// Two `NormalizedString` are equal as soon as their normalized versions match
+    fn eq(&self, other: &NormalizedString) -> bool {
+        self.normalized == other.normalized
+    }
+}
+
+/// Converts a char index into the matching byte offset in `s`, saturating at `s.len()`
+fn byte_offset(s: &str, char_index: usize) -> usize {
+    s.char_indices()
+        .nth(char_index)
+        .map_or(s.len(), |(offset, _)| offset)
+}
+
+impl NormalizedString {
+    /// Builds a `NormalizedString` whose normalized version is `s` itself, each char
+    /// being aligned with its own position
+    pub fn from(s: &str) -> Self {
+        NormalizedString {
+            original: s.to_owned(),
+            normalized: s.to_owned(),
+            alignments: (0..s.chars().count()).map(|v| (v, v + 1)).collect(),
+        }
+    }
+
+    /// Returns the normalized version
+    pub fn get(&self) -> &str {
+        &self.normalized
+    }
+
+    /// Returns the original version
+    pub fn get_original(&self) -> &str {
+        &self.original
+    }
+
+    /// Applies transformations to the current normalized version, updating the current
+    /// alignments with the new ones.
+    /// This method expects an Iterator yielding each char of the new normalized string
+    /// with a `change` isize equal to:
+    ///   - `1` if this is a new char
+    ///   - `-N` if the char is right before N removed chars
+    ///   - `0` if this char represents the old one (even if changed)
+    ///
+    /// `change` should never be more than `1`. If multiple chars are added, each of
+    /// them has a `change` of `1`, but more doesn't make any sense.
+    /// We treat any value above `1` as `1`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if a yielded char cannot be matched with one of the current alignments.
+    pub fn transform<I: Iterator<Item = (char, isize)>>(&mut self, dest: I) {
+        let mut offset: isize = 0;
+        let (ch, alignments): (Vec<_>, Vec<_>) = dest
+            .enumerate()
+            .map(|(index, (c, changes))| {
+                let uof = if offset < 0 {
+                    -offset as usize
+                } else {
+                    offset as usize
+                };
+                // A positive offset means we added characters. So we need to remove this offset
+                // from the current index to find out the previous id
+                let idx = if offset < 0 { index + uof } else { index - uof };
+                let align = match changes.cmp(&0) {
+                    // This is a newly inserted character, so we use the alignment from the
+                    // previous one
+                    Ordering::Greater => {
+                        // Every insertion shifts the following chars, including one that
+                        // happens at the very beginning of the string
+                        offset += 1;
+                        if idx < 1 {
+                            Some((0, 0))
+                        } else {
+                            self.alignments.get(idx - 1).copied()
+                        }
+                    }
+                    // No changes required here
+                    Ordering::Equal => self.alignments.get(idx).copied(),
+                    // Some characters were removed right after this one, so the new alignment
+                    // spans the `uch` alignments starting at the current char
+                    Ordering::Less => {
+                        let uch = -changes as usize;
+                        offset += changes;
+                        self.alignments.get(idx..idx + uch).map(|alignments| {
+                            let min = alignments
+                                .iter()
+                                .map(|(start, end)| usize::min(*start, *end))
+                                .min()
+                                .unwrap();
+                            let max = alignments
+                                .iter()
+                                .map(|(start, end)| usize::max(*start, *end))
+                                .max()
+                                .unwrap();
+                            (min, max)
+                        })
+                    }
+                };
+
+                // Then we keep only the char for string reconstruction
+                (
+                    c,
+                    align.expect("Bad alignement in NormalizedString::transform"),
+                )
+            })
+            .unzip();
+        self.alignments = alignments;
+        self.normalized = ch.iter().collect::<String>();
+    }
+
+    /// Applies NFD normalization
+    pub fn nfd(&mut self) -> &mut Self {
+        self.transform(self.get().to_owned().nfd());
+        self
+    }
+
+    /// Applies NFKD normalization
+    pub fn nfkd(&mut self) -> &mut Self {
+        self.transform(self.get().to_owned().nfkd());
+        self
+    }
+
+    /// Applies NFC normalization
+    pub fn nfc(&mut self) -> &mut Self {
+        self.transform(self.get().to_owned().nfc());
+        self
+    }
+
+    /// Applies NFKC normalization
+    pub fn nfkc(&mut self) -> &mut Self {
+        self.transform(self.get().to_owned().nfkc());
+        self
+    }
+
+    /// Applies filtering over our characters, keeping only the ones for which `filter`
+    /// returns `true`
+    pub fn filter<F: Fn(&char) -> bool>(&mut self, filter: F) -> &mut Self {
+        let mut removed: isize = 0;
+        let mut filtered: Vec<(char, isize)> = Vec::with_capacity(self.alignments.len());
+        // We walk backwards, so that each kept char knows how many removed chars follow it
+        for c in self.normalized.chars().rev() {
+            if filter(&c) {
+                filtered.push((c, -removed));
+                removed = 0;
+            } else {
+                removed += 1;
+            }
+        }
+        filtered.reverse();
+
+        // Chars removed at the very beginning have no kept char in front of them to carry
+        // the change, so we drop their alignments before applying the transformation
+        let leading = (removed as usize).min(self.alignments.len());
+        self.alignments = self.alignments.split_off(leading);
+
+        self.transform(filtered.into_iter());
+        self
+    }
+
+    /// Map our characters
+    ///
+    /// Each char is replaced by exactly one char, so the alignments stay untouched
+    pub fn map<F: Fn(char) -> char>(&mut self, map: F) -> &mut Self {
+        self.normalized = self.normalized.chars().map(map).collect::<String>();
+        self
+    }
+
+    /// Calls the given function for each characters
+    pub fn for_each<F: FnMut(char)>(&mut self, foreach: F) -> &mut Self {
+        self.normalized.chars().for_each(foreach);
+        self
+    }
+
+    /// Lowercase
+    ///
+    /// A char can lowercase to multiple chars: the additional ones are reported as
+    /// insertions, so that the alignments keep matching the normalized version
+    pub fn lowercase(&mut self) -> &mut Self {
+        let lowercased = self
+            .normalized
+            .chars()
+            .flat_map(|c| {
+                c.to_lowercase()
+                    .enumerate()
+                    .map(|(i, lc)| (lc, if i > 0 { 1 } else { 0 }))
+            })
+            .collect::<Vec<(char, isize)>>();
+        self.transform(lowercased.into_iter());
+        self
+    }
+
+    /// Uppercase
+    ///
+    /// A char can uppercase to multiple chars: the additional ones are reported as
+    /// insertions, so that the alignments keep matching the normalized version
+    pub fn uppercase(&mut self) -> &mut Self {
+        let uppercased = self
+            .normalized
+            .chars()
+            .flat_map(|c| {
+                c.to_uppercase()
+                    .enumerate()
+                    .map(|(i, uc)| (uc, if i > 0 { 1 } else { 0 }))
+            })
+            .collect::<Vec<(char, isize)>>();
+        self.transform(uppercased.into_iter());
+        self
+    }
+
+    /// Split off ourselves, returning a new Self that contains the range [at, len).
+    /// self will then contain the range [0, at).
+    ///
+    /// `at` is expressed in chars of the normalized string.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `at` is greater than the number of alignments (one per normalized char)
+    pub fn split_off(&mut self, at: usize) -> Self {
+        // `String::split_off` works with byte offsets, while `at` and the alignments
+        // are expressed in chars
+        let normalized_at = byte_offset(&self.normalized, at);
+        let alignments = self.alignments.split_off(at);
+        let normalized = self.normalized.split_off(normalized_at);
+        let original_at = self.alignments.last().map(|(_, end)| *end).unwrap_or(0);
+        let original_at = byte_offset(&self.original, original_at);
+        let original = self.original.split_off(original_at);
+
+        NormalizedString {
+            original,
+            normalized,
+            alignments,
+        }
+    }
+
+    /// Merge with the given NormalizedString by appending it to self
+    pub fn merge_with(&mut self, other: &NormalizedString) {
+        // The alignments of `other` are relative to its own original string, so they
+        // get shifted by the number of chars that now precede it
+        let shift = self.original.chars().count();
+        self.original.push_str(&other.original);
+        self.alignments.extend(
+            other
+                .alignments
+                .iter()
+                .map(|(start, end)| (start + shift, end + shift)),
+        );
+        self.normalized.push_str(&other.normalized);
+    }
+
+    /// Returns the length, in bytes, of the normalized version
+    pub fn len(&self) -> usize {
+        self.normalized.len()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use unicode_categories::UnicodeCategories;
+
+    #[test]
+    fn new_chars() {
+        let mut n = NormalizedString::from("élégant");
+        n.nfd();
+        assert_eq!(
+            &n.alignments,
+            &[
+                (0, 1),
+                (0, 1),
+                (1, 2),
+                (2, 3),
+                (2, 3),
+                (3, 4),
+                (4, 5),
+                (5, 6),
+                (6, 7)
+            ]
+        );
+    }
+
+    #[test]
+    fn unchanged() {
+        let mut n = NormalizedString::from("élégant");
+        n.nfd().filter(|c| !c.is_mark_nonspacing());
+        assert_eq!(
+            &n.alignments,
+            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]
+        );
+    }
+
+    #[test]
+    fn removed_chars() {
+        let mut n = NormalizedString::from("élégant");
+        n.filter(|c| *c != 'n');
+        assert_eq!(
+            &n.alignments,
+            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
+        );
+    }
+
+    #[test]
+    fn mixed_addition_and_removal() {
+        let mut n = NormalizedString::from("élégant");
+        n.nfd().filter(|c| !c.is_mark_nonspacing() && *c != 'n');
+        assert_eq!(
+            &n.alignments,
+            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
+        );
+    }
+
+    #[test]
+    fn removed_leading_chars() {
+        let mut n = NormalizedString::from("nab");
+        n.filter(|c| *c != 'n');
+        assert_eq!(n.get(), "ab");
+        assert_eq!(&n.alignments, &[(1, 2), (2, 3)]);
+    }
+
+    #[test]
+    fn lowercase_and_uppercase() {
+        let mut n = NormalizedString::from("Straße");
+        n.lowercase();
+        assert_eq!(n.get(), "straße");
+        n.uppercase();
+        assert_eq!(n.get(), "STRASSE");
+        assert_eq!(
+            &n.alignments,
+            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (4, 5), (5, 6)]
+        );
+    }
+
+    #[test]
+    fn split_off_multibyte() {
+        let mut n = NormalizedString::from("élégant");
+        let tail = n.split_off(3);
+        assert_eq!(n.get(), "élé");
+        assert_eq!(n.get_original(), "élé");
+        assert_eq!(tail.get(), "gant");
+        assert_eq!(tail.get_original(), "gant");
+        assert_eq!(&tail.alignments, &[(3, 4), (4, 5), (5, 6), (6, 7)]);
+    }
+
+    #[test]
+    fn merge_with_multibyte() {
+        let mut n = NormalizedString::from("é");
+        n.merge_with(&NormalizedString::from("a"));
+        assert_eq!(n.get(), "éa");
+        assert_eq!(n.get_original(), "éa");
+        assert_eq!(&n.alignments, &[(0, 1), (1, 2)]);
+    }
+}