Add some tests for Encoding

This commit is contained in:
Anthony MOI
2019-12-12 19:03:42 -05:00
parent da45a1d6d0
commit 7711946882

View File

@ -1,5 +1,5 @@
/// The Encoding struct represents the output of the Tokenizer
#[derive(Default)]
#[derive(Default, PartialEq, Debug)]
pub struct Encoding {
original: String,
normalized: String,
@ -140,6 +140,8 @@ impl Encoding {
}
}
/// Prepend the last `stride` elements of the `previous` Vec to the `current` Vec.
/// Neither input is modified in place: the result is returned as a newly allocated Vec.
fn prepend_stride<T: Clone>(previous: &Vec<T>, current: Vec<T>, stride: usize) -> Vec<T> {
let prev = previous
.iter()
@ -151,3 +153,103 @@ fn prepend_stride<T: Clone>(previous: &Vec<T>, current: Vec<T>, stride: usize) -
[&prev[..], &current[..]].concat()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With a stride of 3, the last three elements of the previous vector
    /// must appear in front of the current one.
    #[test]
    fn test_prepend_stride() {
        let previous = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let current = vec![9, 10, 11, 12];

        let expected = vec![6, 7, 8, 9, 10, 11, 12];
        let actual = prepend_stride(&previous, current, 3);

        assert_eq!(actual, expected);
    }

    /// Merging two single-token encodings is expected to concatenate every
    /// field, with the offsets of the second encoding shifted past the first.
    #[test]
    fn merge_encodings() {
        let mut left = Encoding {
            original: "Hello ".to_owned(),
            normalized: "Hello ".to_owned(),
            ids: vec![1],
            type_ids: vec![0],
            tokens: vec!["Hello ".to_owned()],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            overflowing: None,
        };
        let right = Encoding {
            original: "World!".to_owned(),
            normalized: "World!".to_owned(),
            ids: vec![2],
            type_ids: vec![1],
            tokens: vec!["World!".to_owned()],
            offsets: vec![(0, 6)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            overflowing: None,
        };

        left.merge_with(right);

        let expected = Encoding {
            original: "Hello World!".to_owned(),
            normalized: "Hello World!".to_owned(),
            ids: vec![1, 2],
            type_ids: vec![0, 1],
            tokens: vec!["Hello ".to_owned(), "World!".to_owned()],
            // The second token's offsets are relative to the merged text.
            offsets: vec![(0, 6), (6, 12)],
            special_tokens_mask: vec![0, 0],
            attention_mask: vec![1, 1],
            overflowing: None,
        };
        assert_eq!(left, expected);
    }

    /// Truncating a three-token encoding with `truncate(2, 0)` is expected to
    /// keep the first two tokens and move the third into `overflowing`.
    #[test]
    fn truncate() {
        let mut encoding = Encoding {
            original: "Hello World!".to_owned(),
            normalized: "Hello World!".to_owned(),
            ids: vec![1, 2, 3],
            type_ids: vec![0, 0, 0],
            tokens: vec!["Hello".to_owned(), "World".to_owned(), "!".to_owned()],
            offsets: vec![(0, 5), (6, 11), (11, 12)],
            special_tokens_mask: vec![0, 0, 0],
            attention_mask: vec![1, 1, 1],
            overflowing: None,
        };

        encoding.truncate(2, 0);

        // The part that was cut off: only the trailing "!" token.
        let overflow = Encoding {
            original: "!".to_owned(),
            normalized: "!".to_owned(),
            ids: vec![3],
            type_ids: vec![0],
            tokens: vec!["!".to_owned()],
            offsets: vec![(11, 12)],
            special_tokens_mask: vec![0],
            attention_mask: vec![1],
            overflowing: None,
        };
        let expected = Encoding {
            original: "Hello World".to_owned(),
            normalized: "Hello World".to_owned(),
            ids: vec![1, 2],
            type_ids: vec![0, 0],
            tokens: vec!["Hello".to_owned(), "World".to_owned()],
            offsets: vec![(0, 5), (6, 11)],
            special_tokens_mask: vec![0, 0],
            attention_mask: vec![1, 1],
            overflowing: Some(Box::new(overflow)),
        };
        assert_eq!(encoding, expected);
    }
}