mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Encoding uses NormalizedString
This commit is contained in:
@ -1,3 +1,5 @@
|
||||
use crate::tokenizer::NormalizedString;
|
||||
|
||||
/// The various possible padding directions
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PaddingDirection {
|
||||
@ -8,8 +10,7 @@ pub enum PaddingDirection {
|
||||
/// The Encoding struct represents the output of the Tokenizer
|
||||
#[derive(Default, PartialEq, Debug, Clone)]
|
||||
pub struct Encoding {
|
||||
original: String,
|
||||
normalized: String,
|
||||
normalized: NormalizedString,
|
||||
ids: Vec<u32>,
|
||||
type_ids: Vec<u32>,
|
||||
tokens: Vec<String>,
|
||||
@ -21,8 +22,7 @@ pub struct Encoding {
|
||||
impl Encoding {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
original: String,
|
||||
normalized: String,
|
||||
normalized: NormalizedString,
|
||||
ids: Vec<u32>,
|
||||
type_ids: Vec<u32>,
|
||||
tokens: Vec<String>,
|
||||
@ -32,7 +32,6 @@ impl Encoding {
|
||||
overflowing: Option<Box<Encoding>>,
|
||||
) -> Self {
|
||||
Encoding {
|
||||
original,
|
||||
normalized,
|
||||
ids,
|
||||
type_ids,
|
||||
@ -44,11 +43,7 @@ impl Encoding {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_original(&self) -> &str {
|
||||
&self.original
|
||||
}
|
||||
|
||||
pub fn get_normalized(&self) -> &str {
|
||||
pub fn get_normalized(&self) -> &NormalizedString {
|
||||
&self.normalized
|
||||
}
|
||||
|
||||
@ -96,15 +91,7 @@ impl Encoding {
|
||||
let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
|
||||
let mut o_attent = self.attention_mask.split_off(max_len);
|
||||
|
||||
// Figure out offsets for original and normalized
|
||||
// TODO: We will be able to retrieve the right part of original
|
||||
// only when we will have the alignment difference between both
|
||||
// For now we will use the normalized offset...
|
||||
let max = self
|
||||
.offsets
|
||||
.iter()
|
||||
.fold(0, |max, (_, end)| if *end > max { *end } else { max });
|
||||
let trunc_original = self.original.split_off(max);
|
||||
let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0);
|
||||
let trunc_normalized = self.normalized.split_off(max);
|
||||
|
||||
if stride > 0 {
|
||||
@ -117,7 +104,6 @@ impl Encoding {
|
||||
}
|
||||
|
||||
self.overflowing = Some(Box::new(Encoding {
|
||||
original: trunc_original,
|
||||
normalized: trunc_normalized,
|
||||
ids: o_ids,
|
||||
type_ids: o_type_ids,
|
||||
@ -130,8 +116,7 @@ impl Encoding {
|
||||
}
|
||||
|
||||
pub fn merge_with(&mut self, pair: Encoding) {
|
||||
self.original.push_str(&pair.original);
|
||||
self.normalized.push_str(&pair.normalized);
|
||||
self.normalized.merge_with(&pair.normalized);
|
||||
self.ids.extend(pair.ids);
|
||||
self.type_ids.extend(pair.type_ids);
|
||||
self.tokens.extend(pair.tokens);
|
||||
@ -224,8 +209,7 @@ mod tests {
|
||||
#[test]
|
||||
fn merge_encodings() {
|
||||
let mut a = Encoding {
|
||||
original: String::from("Hello "),
|
||||
normalized: String::from("Hello "),
|
||||
normalized: NormalizedString::from("Hello "),
|
||||
ids: vec![1],
|
||||
type_ids: vec![0],
|
||||
tokens: vec![String::from("Hello ")],
|
||||
@ -235,8 +219,7 @@ mod tests {
|
||||
overflowing: None,
|
||||
};
|
||||
let b = Encoding {
|
||||
original: String::from("World!"),
|
||||
normalized: String::from("World!"),
|
||||
normalized: NormalizedString::from("World!"),
|
||||
ids: vec![2],
|
||||
type_ids: vec![1],
|
||||
tokens: vec![String::from("World!")],
|
||||
@ -250,8 +233,7 @@ mod tests {
|
||||
assert_eq!(
|
||||
a,
|
||||
Encoding {
|
||||
original: String::from("Hello World!"),
|
||||
normalized: String::from("Hello World!"),
|
||||
normalized: NormalizedString::from("Hello World!"),
|
||||
ids: vec![1, 2],
|
||||
type_ids: vec![0, 1],
|
||||
tokens: vec![String::from("Hello "), String::from("World!")],
|
||||
@ -266,8 +248,7 @@ mod tests {
|
||||
#[test]
|
||||
fn truncate() {
|
||||
let mut a = Encoding {
|
||||
original: String::from("Hello World!"),
|
||||
normalized: String::from("Hello World!"),
|
||||
normalized: NormalizedString::from("Hello World!"),
|
||||
ids: vec![1, 2, 3],
|
||||
type_ids: vec![0, 0, 0],
|
||||
tokens: vec![
|
||||
@ -285,8 +266,7 @@ mod tests {
|
||||
assert_eq!(
|
||||
a,
|
||||
Encoding {
|
||||
original: String::from("Hello World"),
|
||||
normalized: String::from("Hello World"),
|
||||
normalized: NormalizedString::from("Hello World"),
|
||||
ids: vec![1, 2],
|
||||
type_ids: vec![0, 0],
|
||||
tokens: vec![String::from("Hello"), String::from("World")],
|
||||
@ -294,8 +274,7 @@ mod tests {
|
||||
special_tokens_mask: vec![0, 0],
|
||||
attention_mask: vec![1, 1],
|
||||
overflowing: Some(Box::new(Encoding {
|
||||
original: String::from("!"),
|
||||
normalized: String::from("!"),
|
||||
normalized: NormalizedString::from("!"),
|
||||
ids: vec![3],
|
||||
type_ids: vec![0],
|
||||
tokens: vec![String::from("!")],
|
||||
|
Reference in New Issue
Block a user