Ability to encode with added tokens

This commit is contained in:
Anthony MOI
2019-12-16 18:22:17 -05:00
parent 45c2d25a9f
commit 4c7f6e1f04

View File

@ -191,6 +191,27 @@ impl Tokenizer {
/// Encode the given sentence
pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {
let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
// First we need to split into as many sequences as needed to avoid splitting
// on our added tokens
let mut encodings = self
.split_on_added_tokens(&sentence)
.into_iter()
.map(|(sentence, id)| -> Result<Encoding> {
// If this is one of our added tokens, lets return an encoding directly
if let Some(id) = id {
return Ok(Encoding::new(
sentence.clone(),
sentence.clone(),
vec![id],
vec![type_id],
vec![sentence.to_owned()],
vec![(0, sentence.len())],
vec![0],
vec![1],
None,
));
}
// 1. Normalization
// TODO: Make sure we have the offsets update necessary to go from the original text to
// the normalized one
@ -229,6 +250,21 @@ impl Tokenizer {
vec![1; length],
None,
))
})
.collect::<Result<Vec<Encoding>>>()?;
if encodings.is_empty() {
return Ok(Encoding::default());
}
let others = encodings.split_off(1);
let mut first: Encoding = encodings.into_iter().nth(0).unwrap();
for encoding in others {
first.merge_with(encoding);
}
Ok(first)
};
let (sentence, pair) = match input {
@ -412,4 +448,54 @@ impl Tokenizer {
// Return the number of added tokens
tokens.len() - ignored
}
/// Split the given sentence on multiple parts, finding the added tokens and their id in the process
fn split_on_added_tokens(&self, sentence: &str) -> Vec<(String, Option<u32>)> {
if let Some(split_re) = &self.split_re {
let splits = split_re
.find_iter(&sentence)
.map(|m| (m.start(), m.end()))
.collect::<Vec<_>>();
// We also insert the splits that are inbetween the added tokens, to split the entire string
let mut start_offset = 0;
let mut splits = splits
.into_iter()
.map(|(start, end)| {
let mut splits = vec![];
if start_offset < start {
splits.push((start_offset, start));
}
splits.push((start, end));
start_offset = end;
splits
})
.flatten()
.collect::<Vec<_>>();
if let Some((_, end)) = splits.iter().last().copied() {
if end < sentence.len() {
splits.push((end, sentence.len()));
}
}
if splits.is_empty() {
vec![(sentence.to_owned(), None)]
} else {
splits
.into_iter()
.map(|(start, end)| unsafe {
let s = sentence.get_unchecked(start..end).to_owned();
let id = self.added_tokens.get(&AddedToken {
content: s.clone(),
..Default::default()
});
(s, id.copied())
})
.collect()
}
} else {
vec![(sentence.to_owned(), None)]
}
}
}