Encoding uses NormalizedString

This commit is contained in:
Anthony MOI
2019-12-28 15:25:50 -05:00
parent 162829b7a9
commit 8c40c89836

View File

@ -1,3 +1,5 @@
use crate::tokenizer::NormalizedString;
/// The various possible padding directions /// The various possible padding directions
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum PaddingDirection { pub enum PaddingDirection {
@ -8,8 +10,7 @@ pub enum PaddingDirection {
/// The Encoding struct represents the output of the Tokenizer /// The Encoding struct represents the output of the Tokenizer
#[derive(Default, PartialEq, Debug, Clone)] #[derive(Default, PartialEq, Debug, Clone)]
pub struct Encoding { pub struct Encoding {
original: String, normalized: NormalizedString,
normalized: String,
ids: Vec<u32>, ids: Vec<u32>,
type_ids: Vec<u32>, type_ids: Vec<u32>,
tokens: Vec<String>, tokens: Vec<String>,
@ -21,8 +22,7 @@ pub struct Encoding {
impl Encoding { impl Encoding {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
original: String, normalized: NormalizedString,
normalized: String,
ids: Vec<u32>, ids: Vec<u32>,
type_ids: Vec<u32>, type_ids: Vec<u32>,
tokens: Vec<String>, tokens: Vec<String>,
@ -32,7 +32,6 @@ impl Encoding {
overflowing: Option<Box<Encoding>>, overflowing: Option<Box<Encoding>>,
) -> Self { ) -> Self {
Encoding { Encoding {
original,
normalized, normalized,
ids, ids,
type_ids, type_ids,
@ -44,11 +43,7 @@ impl Encoding {
} }
} }
pub fn get_original(&self) -> &str { pub fn get_normalized(&self) -> &NormalizedString {
&self.original
}
pub fn get_normalized(&self) -> &str {
&self.normalized &self.normalized
} }
@ -96,15 +91,7 @@ impl Encoding {
let mut o_spe_toks = self.special_tokens_mask.split_off(max_len); let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
let mut o_attent = self.attention_mask.split_off(max_len); let mut o_attent = self.attention_mask.split_off(max_len);
// Figure out offsets for original and normalized let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0);
// TODO: We will be able to retrieve the right part of original
// only when we will have the alignment difference between both
// For now we will use the normalized offset...
let max = self
.offsets
.iter()
.fold(0, |max, (_, end)| if *end > max { *end } else { max });
let trunc_original = self.original.split_off(max);
let trunc_normalized = self.normalized.split_off(max); let trunc_normalized = self.normalized.split_off(max);
if stride > 0 { if stride > 0 {
@ -117,7 +104,6 @@ impl Encoding {
} }
self.overflowing = Some(Box::new(Encoding { self.overflowing = Some(Box::new(Encoding {
original: trunc_original,
normalized: trunc_normalized, normalized: trunc_normalized,
ids: o_ids, ids: o_ids,
type_ids: o_type_ids, type_ids: o_type_ids,
@ -130,8 +116,7 @@ impl Encoding {
} }
pub fn merge_with(&mut self, pair: Encoding) { pub fn merge_with(&mut self, pair: Encoding) {
self.original.push_str(&pair.original); self.normalized.merge_with(&pair.normalized);
self.normalized.push_str(&pair.normalized);
self.ids.extend(pair.ids); self.ids.extend(pair.ids);
self.type_ids.extend(pair.type_ids); self.type_ids.extend(pair.type_ids);
self.tokens.extend(pair.tokens); self.tokens.extend(pair.tokens);
@ -224,8 +209,7 @@ mod tests {
#[test] #[test]
fn merge_encodings() { fn merge_encodings() {
let mut a = Encoding { let mut a = Encoding {
original: String::from("Hello "), normalized: NormalizedString::from("Hello "),
normalized: String::from("Hello "),
ids: vec![1], ids: vec![1],
type_ids: vec![0], type_ids: vec![0],
tokens: vec![String::from("Hello ")], tokens: vec![String::from("Hello ")],
@ -235,8 +219,7 @@ mod tests {
overflowing: None, overflowing: None,
}; };
let b = Encoding { let b = Encoding {
original: String::from("World!"), normalized: NormalizedString::from("World!"),
normalized: String::from("World!"),
ids: vec![2], ids: vec![2],
type_ids: vec![1], type_ids: vec![1],
tokens: vec![String::from("World!")], tokens: vec![String::from("World!")],
@ -250,8 +233,7 @@ mod tests {
assert_eq!( assert_eq!(
a, a,
Encoding { Encoding {
original: String::from("Hello World!"), normalized: NormalizedString::from("Hello World!"),
normalized: String::from("Hello World!"),
ids: vec![1, 2], ids: vec![1, 2],
type_ids: vec![0, 1], type_ids: vec![0, 1],
tokens: vec![String::from("Hello "), String::from("World!")], tokens: vec![String::from("Hello "), String::from("World!")],
@ -266,8 +248,7 @@ mod tests {
#[test] #[test]
fn truncate() { fn truncate() {
let mut a = Encoding { let mut a = Encoding {
original: String::from("Hello World!"), normalized: NormalizedString::from("Hello World!"),
normalized: String::from("Hello World!"),
ids: vec![1, 2, 3], ids: vec![1, 2, 3],
type_ids: vec![0, 0, 0], type_ids: vec![0, 0, 0],
tokens: vec![ tokens: vec![
@ -285,8 +266,7 @@ mod tests {
assert_eq!( assert_eq!(
a, a,
Encoding { Encoding {
original: String::from("Hello World"), normalized: NormalizedString::from("Hello World"),
normalized: String::from("Hello World"),
ids: vec![1, 2], ids: vec![1, 2],
type_ids: vec![0, 0], type_ids: vec![0, 0],
tokens: vec![String::from("Hello"), String::from("World")], tokens: vec![String::from("Hello"), String::from("World")],
@ -294,8 +274,7 @@ mod tests {
special_tokens_mask: vec![0, 0], special_tokens_mask: vec![0, 0],
attention_mask: vec![1, 1], attention_mask: vec![1, 1],
overflowing: Some(Box::new(Encoding { overflowing: Some(Box::new(Encoding {
original: String::from("!"), normalized: NormalizedString::from("!"),
normalized: String::from("!"),
ids: vec![3], ids: vec![3],
type_ids: vec![0], type_ids: vec![0],
tokens: vec![String::from("!")], tokens: vec![String::from("!")],