From 8c40c89836606491c0c68852e817f36c1cd69296 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Sat, 28 Dec 2019 15:25:50 -0500 Subject: [PATCH] Encoding uses NormalizedString --- tokenizers/src/tokenizer/encoding.rs | 47 ++++++++-------------------- 1 file changed, 13 insertions(+), 34 deletions(-) diff --git a/tokenizers/src/tokenizer/encoding.rs b/tokenizers/src/tokenizer/encoding.rs index 9f0cc1ed..55613460 100644 --- a/tokenizers/src/tokenizer/encoding.rs +++ b/tokenizers/src/tokenizer/encoding.rs @@ -1,3 +1,5 @@ +use crate::tokenizer::NormalizedString; + /// The various possible padding directions #[derive(Debug, Clone)] pub enum PaddingDirection { @@ -8,8 +10,7 @@ pub enum PaddingDirection { /// The Encoding struct represents the output of the Tokenizer #[derive(Default, PartialEq, Debug, Clone)] pub struct Encoding { - original: String, - normalized: String, + normalized: NormalizedString, ids: Vec, type_ids: Vec, tokens: Vec, @@ -21,8 +22,7 @@ pub struct Encoding { impl Encoding { #[allow(clippy::too_many_arguments)] pub fn new( - original: String, - normalized: String, + normalized: NormalizedString, ids: Vec, type_ids: Vec, tokens: Vec, @@ -32,7 +32,6 @@ impl Encoding { overflowing: Option>, ) -> Self { Encoding { - original, normalized, ids, type_ids, @@ -44,11 +43,7 @@ impl Encoding { } } - pub fn get_original(&self) -> &str { - &self.original - } - - pub fn get_normalized(&self) -> &str { + pub fn get_normalized(&self) -> &NormalizedString { &self.normalized } @@ -96,15 +91,7 @@ impl Encoding { let mut o_spe_toks = self.special_tokens_mask.split_off(max_len); let mut o_attent = self.attention_mask.split_off(max_len); - // Figure out offsets for original and normalized - // TODO: We will be able to retrive the right part of original - // only when we will have the alignment difference between both - // For now we will use the normalized offset... - let max = self - .offsets - .iter() - .fold(0, |max, (_, end)| if *end > max { *end } else { max }); - let trunc_original = self.original.split_off(max); + let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0); let trunc_normalized = self.normalized.split_off(max); if stride > 0 { @@ -117,7 +104,6 @@ impl Encoding { } self.overflowing = Some(Box::new(Encoding { - original: trunc_original, normalized: trunc_normalized, ids: o_ids, type_ids: o_type_ids, @@ -130,8 +116,7 @@ impl Encoding { } pub fn merge_with(&mut self, pair: Encoding) { - self.original.push_str(&pair.original); - self.normalized.push_str(&pair.normalized); + self.normalized.merge_with(&pair.normalized); self.ids.extend(pair.ids); self.type_ids.extend(pair.type_ids); self.tokens.extend(pair.tokens); @@ -224,8 +209,7 @@ mod tests { #[test] fn merge_encodings() { let mut a = Encoding { - original: String::from("Hello "), - normalized: String::from("Hello "), + normalized: NormalizedString::from("Hello "), ids: vec![1], type_ids: vec![0], tokens: vec![String::from("Hello ")], @@ -235,8 +219,7 @@ mod tests { overflowing: None, }; let b = Encoding { - original: String::from("World!"), - normalized: String::from("World!"), + normalized: NormalizedString::from("World!"), ids: vec![2], type_ids: vec![1], tokens: vec![String::from("World!")], @@ -250,8 +233,7 @@ mod tests { assert_eq!( a, Encoding { - original: String::from("Hello World!"), - normalized: String::from("Hello World!"), + normalized: NormalizedString::from("Hello World!"), ids: vec![1, 2], type_ids: vec![0, 1], tokens: vec![String::from("Hello "), String::from("World!")], @@ -266,8 +248,7 @@ mod tests { #[test] fn truncate() { let mut a = Encoding { - original: String::from("Hello World!"), - normalized: String::from("Hello World!"), + normalized: NormalizedString::from("Hello World!"), ids: vec![1, 2, 3], type_ids: vec![0, 0, 0], tokens: vec![ @@ -285,8 +266,7 @@ mod tests { assert_eq!( a, Encoding { - original: String::from("Hello World"), - normalized: String::from("Hello World"), + normalized: NormalizedString::from("Hello World"), ids: vec![1, 2], type_ids: vec![0, 0], tokens: vec![String::from("Hello"), String::from("World")], @@ -294,8 +274,7 @@ mod tests { special_tokens_mask: vec![0, 0], attention_mask: vec![1, 1], overflowing: Some(Box::new(Encoding { - original: String::from("!"), - normalized: String::from("!"), + normalized: NormalizedString::from("!"), ids: vec![3], type_ids: vec![0], tokens: vec![String::from("!")],