From 708a63514a6cf1c60bc598f86fba62856c87d722 Mon Sep 17 00:00:00 2001
From: Anthony MOI
Date: Sun, 29 Dec 2019 01:22:16 -0500
Subject: [PATCH] Add ability to retrieve ranges of NormalizedString

---
 tokenizers/src/tokenizer/normalizer.rs | 31 ++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs
index 0defc276..e67dadbf 100644
--- a/tokenizers/src/tokenizer/normalizer.rs
+++ b/tokenizers/src/tokenizer/normalizer.rs
@@ -42,6 +42,27 @@ impl NormalizedString {
         &self.original
     }
 
+    /// Return a range of the normalized string
+    pub fn get_range(&self, range: std::ops::Range<usize>) -> Option<&str> {
+        self.normalized.get(range)
+    }
+
+    /// Return a range of the original string, using a range from the normalized string
+    pub fn get_range_original(&self, range: std::ops::Range<usize>) -> Option<&str> {
+        self.alignments
+            .get(range)
+            .map(|alignments| {
+                if alignments.is_empty() {
+                    None
+                } else {
+                    let start = alignments[0].0;
+                    let end = alignments[alignments.len() - 1].1;
+                    self.original.get(start..end)
+                }
+            })
+            .flatten()
+    }
+
     /// Applies transformations to the current normalized version, updating the current
     /// alignments with the new ones.
     /// This method expect an Iterator yielding each char of the new normalized string
@@ -299,4 +320,14 @@ mod tests {
             &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (6, 7)]
         );
     }
+
+    #[test]
+    fn original_range() {
+        let mut n = NormalizedString::from("Hello_______ World!");
+        n.filter(|c| *c != '_').lowercase();
+        let world_n = n.get_range(6..11).unwrap();
+        let world_o = n.get_range_original(6..11).unwrap();
+        assert_eq!(world_n, "world");
+        assert_eq!(world_o, "World");
+    }
 }