mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Ability to encode with added tokens
This commit is contained in:
@ -191,44 +191,80 @@ impl Tokenizer {
|
||||
/// Encode the given sentence
|
||||
pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {
|
||||
let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
|
||||
// 1. Normalization
|
||||
// TODO: Make sure we have the offsets update necessary to go from the original text to
|
||||
// the normalized one
|
||||
let original = sentence.clone();
|
||||
let normalized = self.normalize(sentence)?;
|
||||
// First we need to split into as many sequences as needed to avoid splitting
|
||||
// on our added tokens
|
||||
let mut encodings = self
|
||||
.split_on_added_tokens(&sentence)
|
||||
.into_iter()
|
||||
.map(|(sentence, id)| -> Result<Encoding> {
|
||||
// If this is one of our added tokens, lets return an encoding directly
|
||||
if let Some(id) = id {
|
||||
return Ok(Encoding::new(
|
||||
sentence.clone(),
|
||||
sentence.clone(),
|
||||
vec![id],
|
||||
vec![type_id],
|
||||
vec![sentence.to_owned()],
|
||||
vec![(0, sentence.len())],
|
||||
vec![0],
|
||||
vec![1],
|
||||
None,
|
||||
));
|
||||
}
|
||||
|
||||
// 2. Pre tokenization
|
||||
let pre_tokenized = self.pre_tokenize(&normalized)?;
|
||||
// 1. Normalization
|
||||
// TODO: Make sure we have the offsets update necessary to go from the original text to
|
||||
// the normalized one
|
||||
let original = sentence.clone();
|
||||
let normalized = self.normalize(sentence)?;
|
||||
|
||||
// 3. Model
|
||||
let output = self.model.tokenize(pre_tokenized)?;
|
||||
let length = output.len();
|
||||
// 2. Pre tokenization
|
||||
let pre_tokenized = self.pre_tokenize(&normalized)?;
|
||||
|
||||
let (ids, tokens, offsets) = output.into_iter().fold(
|
||||
(
|
||||
Vec::with_capacity(length),
|
||||
Vec::with_capacity(length),
|
||||
Vec::with_capacity(length),
|
||||
),
|
||||
|(mut ids, mut tokens, mut offsets), t| {
|
||||
ids.push(t.id);
|
||||
tokens.push(t.value);
|
||||
offsets.push(t.offsets);
|
||||
(ids, tokens, offsets)
|
||||
},
|
||||
);
|
||||
// 3. Model
|
||||
let output = self.model.tokenize(pre_tokenized)?;
|
||||
let length = output.len();
|
||||
|
||||
Ok(Encoding::new(
|
||||
original,
|
||||
normalized,
|
||||
ids,
|
||||
vec![type_id; length],
|
||||
tokens,
|
||||
offsets,
|
||||
vec![0; length],
|
||||
vec![1; length],
|
||||
None,
|
||||
))
|
||||
let (ids, tokens, offsets) = output.into_iter().fold(
|
||||
(
|
||||
Vec::with_capacity(length),
|
||||
Vec::with_capacity(length),
|
||||
Vec::with_capacity(length),
|
||||
),
|
||||
|(mut ids, mut tokens, mut offsets), t| {
|
||||
ids.push(t.id);
|
||||
tokens.push(t.value);
|
||||
offsets.push(t.offsets);
|
||||
(ids, tokens, offsets)
|
||||
},
|
||||
);
|
||||
|
||||
Ok(Encoding::new(
|
||||
original,
|
||||
normalized,
|
||||
ids,
|
||||
vec![type_id; length],
|
||||
tokens,
|
||||
offsets,
|
||||
vec![0; length],
|
||||
vec![1; length],
|
||||
None,
|
||||
))
|
||||
})
|
||||
.collect::<Result<Vec<Encoding>>>()?;
|
||||
|
||||
if encodings.is_empty() {
|
||||
return Ok(Encoding::default());
|
||||
}
|
||||
|
||||
let others = encodings.split_off(1);
|
||||
let mut first: Encoding = encodings.into_iter().nth(0).unwrap();
|
||||
|
||||
for encoding in others {
|
||||
first.merge_with(encoding);
|
||||
}
|
||||
|
||||
Ok(first)
|
||||
};
|
||||
|
||||
let (sentence, pair) = match input {
|
||||
@ -412,4 +448,54 @@ impl Tokenizer {
|
||||
// Return the number of added tokens
|
||||
tokens.len() - ignored
|
||||
}
|
||||
|
||||
/// Split the given sentence on multiple parts, finding the added tokens and their id in the process
|
||||
fn split_on_added_tokens(&self, sentence: &str) -> Vec<(String, Option<u32>)> {
|
||||
if let Some(split_re) = &self.split_re {
|
||||
let splits = split_re
|
||||
.find_iter(&sentence)
|
||||
.map(|m| (m.start(), m.end()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// We also insert the splits that are inbetween the added tokens, to split the entire string
|
||||
let mut start_offset = 0;
|
||||
let mut splits = splits
|
||||
.into_iter()
|
||||
.map(|(start, end)| {
|
||||
let mut splits = vec![];
|
||||
if start_offset < start {
|
||||
splits.push((start_offset, start));
|
||||
}
|
||||
splits.push((start, end));
|
||||
start_offset = end;
|
||||
|
||||
splits
|
||||
})
|
||||
.flatten()
|
||||
.collect::<Vec<_>>();
|
||||
if let Some((_, end)) = splits.iter().last().copied() {
|
||||
if end < sentence.len() {
|
||||
splits.push((end, sentence.len()));
|
||||
}
|
||||
}
|
||||
|
||||
if splits.is_empty() {
|
||||
vec![(sentence.to_owned(), None)]
|
||||
} else {
|
||||
splits
|
||||
.into_iter()
|
||||
.map(|(start, end)| unsafe {
|
||||
let s = sentence.get_unchecked(start..end).to_owned();
|
||||
let id = self.added_tokens.get(&AddedToken {
|
||||
content: s.clone(),
|
||||
..Default::default()
|
||||
});
|
||||
(s, id.copied())
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
} else {
|
||||
vec![(sentence.to_owned(), None)]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user