mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Encoding uses NormalizedString
This commit is contained in:
@ -1,3 +1,5 @@
|
|||||||
|
use crate::tokenizer::NormalizedString;
|
||||||
|
|
||||||
/// The various possible padding directions
|
/// The various possible padding directions
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum PaddingDirection {
|
pub enum PaddingDirection {
|
||||||
@ -8,8 +10,7 @@ pub enum PaddingDirection {
|
|||||||
/// The Encoding struct represents the output of the Tokenizer
|
/// The Encoding struct represents the output of the Tokenizer
|
||||||
#[derive(Default, PartialEq, Debug, Clone)]
|
#[derive(Default, PartialEq, Debug, Clone)]
|
||||||
pub struct Encoding {
|
pub struct Encoding {
|
||||||
original: String,
|
normalized: NormalizedString,
|
||||||
normalized: String,
|
|
||||||
ids: Vec<u32>,
|
ids: Vec<u32>,
|
||||||
type_ids: Vec<u32>,
|
type_ids: Vec<u32>,
|
||||||
tokens: Vec<String>,
|
tokens: Vec<String>,
|
||||||
@ -21,8 +22,7 @@ pub struct Encoding {
|
|||||||
impl Encoding {
|
impl Encoding {
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn new(
|
pub fn new(
|
||||||
original: String,
|
normalized: NormalizedString,
|
||||||
normalized: String,
|
|
||||||
ids: Vec<u32>,
|
ids: Vec<u32>,
|
||||||
type_ids: Vec<u32>,
|
type_ids: Vec<u32>,
|
||||||
tokens: Vec<String>,
|
tokens: Vec<String>,
|
||||||
@ -32,7 +32,6 @@ impl Encoding {
|
|||||||
overflowing: Option<Box<Encoding>>,
|
overflowing: Option<Box<Encoding>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Encoding {
|
Encoding {
|
||||||
original,
|
|
||||||
normalized,
|
normalized,
|
||||||
ids,
|
ids,
|
||||||
type_ids,
|
type_ids,
|
||||||
@ -44,11 +43,7 @@ impl Encoding {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_original(&self) -> &str {
|
pub fn get_normalized(&self) -> &NormalizedString {
|
||||||
&self.original
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_normalized(&self) -> &str {
|
|
||||||
&self.normalized
|
&self.normalized
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,15 +91,7 @@ impl Encoding {
|
|||||||
let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
|
let mut o_spe_toks = self.special_tokens_mask.split_off(max_len);
|
||||||
let mut o_attent = self.attention_mask.split_off(max_len);
|
let mut o_attent = self.attention_mask.split_off(max_len);
|
||||||
|
|
||||||
// Figure out offsets for original and normalized
|
let max = self.offsets.last().map(|(_, end)| *end).unwrap_or(0);
|
||||||
// TODO: We will be able to retrieve the right part of original
|
|
||||||
// only when we will have the alignment difference between both
|
|
||||||
// For now we will use the normalized offset...
|
|
||||||
let max = self
|
|
||||||
.offsets
|
|
||||||
.iter()
|
|
||||||
.fold(0, |max, (_, end)| if *end > max { *end } else { max });
|
|
||||||
let trunc_original = self.original.split_off(max);
|
|
||||||
let trunc_normalized = self.normalized.split_off(max);
|
let trunc_normalized = self.normalized.split_off(max);
|
||||||
|
|
||||||
if stride > 0 {
|
if stride > 0 {
|
||||||
@ -117,7 +104,6 @@ impl Encoding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
self.overflowing = Some(Box::new(Encoding {
|
self.overflowing = Some(Box::new(Encoding {
|
||||||
original: trunc_original,
|
|
||||||
normalized: trunc_normalized,
|
normalized: trunc_normalized,
|
||||||
ids: o_ids,
|
ids: o_ids,
|
||||||
type_ids: o_type_ids,
|
type_ids: o_type_ids,
|
||||||
@ -130,8 +116,7 @@ impl Encoding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn merge_with(&mut self, pair: Encoding) {
|
pub fn merge_with(&mut self, pair: Encoding) {
|
||||||
self.original.push_str(&pair.original);
|
self.normalized.merge_with(&pair.normalized);
|
||||||
self.normalized.push_str(&pair.normalized);
|
|
||||||
self.ids.extend(pair.ids);
|
self.ids.extend(pair.ids);
|
||||||
self.type_ids.extend(pair.type_ids);
|
self.type_ids.extend(pair.type_ids);
|
||||||
self.tokens.extend(pair.tokens);
|
self.tokens.extend(pair.tokens);
|
||||||
@ -224,8 +209,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn merge_encodings() {
|
fn merge_encodings() {
|
||||||
let mut a = Encoding {
|
let mut a = Encoding {
|
||||||
original: String::from("Hello "),
|
normalized: NormalizedString::from("Hello "),
|
||||||
normalized: String::from("Hello "),
|
|
||||||
ids: vec![1],
|
ids: vec![1],
|
||||||
type_ids: vec![0],
|
type_ids: vec![0],
|
||||||
tokens: vec![String::from("Hello ")],
|
tokens: vec![String::from("Hello ")],
|
||||||
@ -235,8 +219,7 @@ mod tests {
|
|||||||
overflowing: None,
|
overflowing: None,
|
||||||
};
|
};
|
||||||
let b = Encoding {
|
let b = Encoding {
|
||||||
original: String::from("World!"),
|
normalized: NormalizedString::from("World!"),
|
||||||
normalized: String::from("World!"),
|
|
||||||
ids: vec![2],
|
ids: vec![2],
|
||||||
type_ids: vec![1],
|
type_ids: vec![1],
|
||||||
tokens: vec![String::from("World!")],
|
tokens: vec![String::from("World!")],
|
||||||
@ -250,8 +233,7 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
a,
|
a,
|
||||||
Encoding {
|
Encoding {
|
||||||
original: String::from("Hello World!"),
|
normalized: NormalizedString::from("Hello World!"),
|
||||||
normalized: String::from("Hello World!"),
|
|
||||||
ids: vec![1, 2],
|
ids: vec![1, 2],
|
||||||
type_ids: vec![0, 1],
|
type_ids: vec![0, 1],
|
||||||
tokens: vec![String::from("Hello "), String::from("World!")],
|
tokens: vec![String::from("Hello "), String::from("World!")],
|
||||||
@ -266,8 +248,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn truncate() {
|
fn truncate() {
|
||||||
let mut a = Encoding {
|
let mut a = Encoding {
|
||||||
original: String::from("Hello World!"),
|
normalized: NormalizedString::from("Hello World!"),
|
||||||
normalized: String::from("Hello World!"),
|
|
||||||
ids: vec![1, 2, 3],
|
ids: vec![1, 2, 3],
|
||||||
type_ids: vec![0, 0, 0],
|
type_ids: vec![0, 0, 0],
|
||||||
tokens: vec![
|
tokens: vec![
|
||||||
@ -285,8 +266,7 @@ mod tests {
|
|||||||
assert_eq!(
|
assert_eq!(
|
||||||
a,
|
a,
|
||||||
Encoding {
|
Encoding {
|
||||||
original: String::from("Hello World"),
|
normalized: NormalizedString::from("Hello World"),
|
||||||
normalized: String::from("Hello World"),
|
|
||||||
ids: vec![1, 2],
|
ids: vec![1, 2],
|
||||||
type_ids: vec![0, 0],
|
type_ids: vec![0, 0],
|
||||||
tokens: vec![String::from("Hello"), String::from("World")],
|
tokens: vec![String::from("Hello"), String::from("World")],
|
||||||
@ -294,8 +274,7 @@ mod tests {
|
|||||||
special_tokens_mask: vec![0, 0],
|
special_tokens_mask: vec![0, 0],
|
||||||
attention_mask: vec![1, 1],
|
attention_mask: vec![1, 1],
|
||||||
overflowing: Some(Box::new(Encoding {
|
overflowing: Some(Box::new(Encoding {
|
||||||
original: String::from("!"),
|
normalized: NormalizedString::from("!"),
|
||||||
normalized: String::from("!"),
|
|
||||||
ids: vec![3],
|
ids: vec![3],
|
||||||
type_ids: vec![0],
|
type_ids: vec![0],
|
||||||
tokens: vec![String::from("!")],
|
tokens: vec![String::from("!")],
|
||||||
|
Reference in New Issue
Block a user