Encoding uses NormalizedString

This commit is contained in:
Anthony MOI
2019-12-28 15:25:50 -05:00
parent 162829b7a9
commit 8c40c89836

View File

@ -1,3 +1,5 @@
use crate::tokenizer::NormalizedString;
/// The various possible padding directions
#[derive(Debug, Clone)]
pub enum PaddingDirection {
@ -8,8 +10,7 @@ pub enum PaddingDirection {
/// The Encoding struct represents the output of the Tokenizer
#[derive(Default, PartialEq, Debug, Clone)]
pub struct Encoding {
original: String,
normalized: String,
normalized: NormalizedString,
ids: Vec<u32>,
type_ids: Vec<u32>,
tokens: Vec<String>,
@ -21,8 +22,7 @@ pub struct Encoding {
impl Encoding {
#[allow(clippy::too_many_arguments)]
pub fn new(
original: String,
normalized: String,
normalized: NormalizedString,
ids: Vec<u32>,
type_ids: Vec<u32>,
tokens: Vec<String>,
@ -32,7 +32,6 @@ impl Encoding {
overflowing: Option<Box<Encoding>>,
) -> Self {
Encoding {
original,
normalized,
ids,
type_ids,
@ -44,11 +43,7 @@ impl Encoding {
}
}
pub fn get_original(&self) -> &str {
&self.original
}
pub fn get_normalized(&self) -> &str {
pub fn get_normalized(&self) -> &NormalizedString {
&self.normalized
}
@ -96,15 +91,7 @@ impl Encoding {
let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
let mut o_attent = self.attention_mask.split_off(max_len);
// Figure out offsets for original and normalized
// TODO: We will be able to retrive the right part of original
// only when we will have the alignment difference between both
// For now we will use the normalized offset...
let max = self
.offsets
.iter()
.fold(0, |max, (_, end)| if *end > max { *end } else { max });
let trunc_original = self.original.split_off(max);
let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0);
let trunc_normalized = self.normalized.split_off(max);
if stride > 0 {
@ -117,7 +104,6 @@ impl Encoding {
}
self.overflowing = Some(Box::new(Encoding {
original: trunc_original,
normalized: trunc_normalized,
ids: o_ids,
type_ids: o_type_ids,
@ -130,8 +116,7 @@ impl Encoding {
}
pub fn merge_with(&mut self, pair: Encoding) {
self.original.push_str(&pair.original);
self.normalized.push_str(&pair.normalized);
self.normalized.merge_with(&pair.normalized);
self.ids.extend(pair.ids);
self.type_ids.extend(pair.type_ids);
self.tokens.extend(pair.tokens);
@ -224,8 +209,7 @@ mod tests {
#[test]
fn merge_encodings() {
let mut a = Encoding {
original: String::from("Hello "),
normalized: String::from("Hello "),
normalized: NormalizedString::from("Hello "),
ids: vec![1],
type_ids: vec![0],
tokens: vec![String::from("Hello ")],
@ -235,8 +219,7 @@ mod tests {
overflowing: None,
};
let b = Encoding {
original: String::from("World!"),
normalized: String::from("World!"),
normalized: NormalizedString::from("World!"),
ids: vec![2],
type_ids: vec![1],
tokens: vec![String::from("World!")],
@ -250,8 +233,7 @@ mod tests {
assert_eq!(
a,
Encoding {
original: String::from("Hello World!"),
normalized: String::from("Hello World!"),
normalized: NormalizedString::from("Hello World!"),
ids: vec![1, 2],
type_ids: vec![0, 1],
tokens: vec![String::from("Hello "), String::from("World!")],
@ -266,8 +248,7 @@ mod tests {
#[test]
fn truncate() {
let mut a = Encoding {
original: String::from("Hello World!"),
normalized: String::from("Hello World!"),
normalized: NormalizedString::from("Hello World!"),
ids: vec![1, 2, 3],
type_ids: vec![0, 0, 0],
tokens: vec![
@ -285,8 +266,7 @@ mod tests {
assert_eq!(
a,
Encoding {
original: String::from("Hello World"),
normalized: String::from("Hello World"),
normalized: NormalizedString::from("Hello World"),
ids: vec![1, 2],
type_ids: vec![0, 0],
tokens: vec![String::from("Hello"), String::from("World")],
@ -294,8 +274,7 @@ mod tests {
special_tokens_mask: vec![0, 0],
attention_mask: vec![1, 1],
overflowing: Some(Box::new(Encoding {
original: String::from("!"),
normalized: String::from("!"),
normalized: NormalizedString::from("!"),
ids: vec![3],
type_ids: vec![0],
tokens: vec![String::from("!")],