Ability to encode with added tokens

Author: Anthony MOI
Date:   2019-12-16 18:22:17 -05:00
parent 45c2d25a9f
commit 4c7f6e1f04


@@ -191,44 +191,80 @@ impl Tokenizer {
     /// Encode the given sentence
     pub fn encode(&self, input: EncodeInput) -> Result<Encoding> {
         let generate_output = move |sentence: String, type_id: u32| -> Result<Encoding> {
-            // 1. Normalization
-            // TODO: Make sure we have the offsets update necessary to go from the original text to
-            // the normalized one
-            let original = sentence.clone();
-            let normalized = self.normalize(sentence)?;
+            // First we need to split into as many sequences as needed to avoid splitting
+            // on our added tokens
+            let mut encodings = self
+                .split_on_added_tokens(&sentence)
+                .into_iter()
+                .map(|(sentence, id)| -> Result<Encoding> {
+                    // If this is one of our added tokens, lets return an encoding directly
+                    if let Some(id) = id {
+                        return Ok(Encoding::new(
+                            sentence.clone(),
+                            sentence.clone(),
+                            vec![id],
+                            vec![type_id],
+                            vec![sentence.to_owned()],
+                            vec![(0, sentence.len())],
+                            vec![0],
+                            vec![1],
+                            None,
+                        ));
+                    }

-            // 2. Pre tokenization
-            let pre_tokenized = self.pre_tokenize(&normalized)?;
+                    // 1. Normalization
+                    // TODO: Make sure we have the offsets update necessary to go from the original text to
+                    // the normalized one
+                    let original = sentence.clone();
+                    let normalized = self.normalize(sentence)?;

-            // 3. Model
-            let output = self.model.tokenize(pre_tokenized)?;
-            let length = output.len();
+                    // 2. Pre tokenization
+                    let pre_tokenized = self.pre_tokenize(&normalized)?;

-            let (ids, tokens, offsets) = output.into_iter().fold(
-                (
-                    Vec::with_capacity(length),
-                    Vec::with_capacity(length),
-                    Vec::with_capacity(length),
-                ),
-                |(mut ids, mut tokens, mut offsets), t| {
-                    ids.push(t.id);
-                    tokens.push(t.value);
-                    offsets.push(t.offsets);
-                    (ids, tokens, offsets)
-                },
-            );
+                    // 3. Model
+                    let output = self.model.tokenize(pre_tokenized)?;
+                    let length = output.len();

-            Ok(Encoding::new(
-                original,
-                normalized,
-                ids,
-                vec![type_id; length],
-                tokens,
-                offsets,
-                vec![0; length],
-                vec![1; length],
-                None,
-            ))
+                    let (ids, tokens, offsets) = output.into_iter().fold(
+                        (
+                            Vec::with_capacity(length),
+                            Vec::with_capacity(length),
+                            Vec::with_capacity(length),
+                        ),
+                        |(mut ids, mut tokens, mut offsets), t| {
+                            ids.push(t.id);
+                            tokens.push(t.value);
+                            offsets.push(t.offsets);
+                            (ids, tokens, offsets)
+                        },
+                    );
+
+                    Ok(Encoding::new(
+                        original,
+                        normalized,
+                        ids,
+                        vec![type_id; length],
+                        tokens,
+                        offsets,
+                        vec![0; length],
+                        vec![1; length],
+                        None,
+                    ))
+                })
+                .collect::<Result<Vec<Encoding>>>()?;
+
+            if encodings.is_empty() {
+                return Ok(Encoding::default());
+            }
+
+            let others = encodings.split_off(1);
+            let mut first: Encoding = encodings.into_iter().nth(0).unwrap();
+            for encoding in others {
+                first.merge_with(encoding);
+            }
+
+            Ok(first)
         };

         let (sentence, pair) = match input {
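
Note on the new flow: the input is first split around the registered added tokens; an added token maps straight to its registered id, while every other piece goes through the usual normalize, pre-tokenize, model pipeline; the per-piece encodings are then merged back into one. Below is a minimal, self-contained sketch of that final merge step. ToyEncoding is a stand-in for the crate's Encoding type (which also carries offsets, type ids, and masks), not the real API.

    // Toy stand-in for `Encoding`, used only to illustrate the merge pattern.
    #[derive(Debug, Default)]
    struct ToyEncoding {
        ids: Vec<u32>,
        tokens: Vec<String>,
    }

    impl ToyEncoding {
        // Mirrors the role of `Encoding::merge_with`: append the other encoding's content.
        fn merge_with(&mut self, other: ToyEncoding) {
            self.ids.extend(other.ids);
            self.tokens.extend(other.tokens);
        }
    }

    fn main() {
        // One encoding per split piece; the second piece was an added token.
        let mut encodings = vec![
            ToyEncoding { ids: vec![0, 1], tokens: vec!["Hel".into(), "lo".into()] },
            ToyEncoding { ids: vec![99], tokens: vec!["[MASK]".into()] },
        ];

        // Same pattern as the diff: keep the first encoding, merge the rest into it.
        let others = encodings.split_off(1);
        let mut first = encodings.into_iter().next().unwrap();
        for encoding in others {
            first.merge_with(encoding);
        }
        println!("{:?}", first); // ids [0, 1, 99], tokens ["Hel", "lo", "[MASK]"]
    }
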
@@ -412,4 +448,54 @@ impl Tokenizer {
         // Return the number of added tokens
         tokens.len() - ignored
     }
+
+    /// Split the given sentence on multiple parts, finding the added tokens and their id in the process
+    fn split_on_added_tokens(&self, sentence: &str) -> Vec<(String, Option<u32>)> {
+        if let Some(split_re) = &self.split_re {
+            let splits = split_re
+                .find_iter(&sentence)
+                .map(|m| (m.start(), m.end()))
+                .collect::<Vec<_>>();
+
+            // We also insert the splits that are inbetween the added tokens, to split the entire string
+            let mut start_offset = 0;
+            let mut splits = splits
+                .into_iter()
+                .map(|(start, end)| {
+                    let mut splits = vec![];
+                    if start_offset < start {
+                        splits.push((start_offset, start));
+                    }
+                    splits.push((start, end));
+                    start_offset = end;
+                    splits
+                })
+                .flatten()
+                .collect::<Vec<_>>();
+
+            if let Some((_, end)) = splits.iter().last().copied() {
+                if end < sentence.len() {
+                    splits.push((end, sentence.len()));
+                }
+            }
+
+            if splits.is_empty() {
+                vec![(sentence.to_owned(), None)]
+            } else {
+                splits
+                    .into_iter()
+                    .map(|(start, end)| unsafe {
+                        let s = sentence.get_unchecked(start..end).to_owned();
+                        let id = self.added_tokens.get(&AddedToken {
+                            content: s.clone(),
+                            ..Default::default()
+                        });
+                        (s, id.copied())
+                    })
+                    .collect()
+            }
+        } else {
+            vec![(sentence.to_owned(), None)]
+        }
+    }
 }
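
For reference, here is a self-contained sketch of the splitting behavior that split_on_added_tokens implements, using the regex crate directly. The real method additionally looks up each matched piece in self.added_tokens to recover its id (the is_added flag below stands in for that lookup), and its use of get_unchecked is sound because regex match boundaries always fall on UTF-8 character boundaries.

    // Sketch only; requires the `regex` crate (regex = "1").
    use regex::Regex;

    fn split_on_matches(sentence: &str, re: &Regex) -> Vec<(String, bool)> {
        let mut pieces = Vec::new();
        let mut start_offset = 0;

        for m in re.find_iter(sentence) {
            // Text sitting between the previous match (or the start) and this match.
            if start_offset < m.start() {
                pieces.push((sentence[start_offset..m.start()].to_owned(), false));
            }
            // The matched added token itself.
            pieces.push((m.as_str().to_owned(), true));
            start_offset = m.end();
        }

        // Trailing text after the last match.
        if start_offset < sentence.len() {
            pieces.push((sentence[start_offset..].to_owned(), false));
        }

        // No matches at all: the whole sentence is a single, regular piece.
        if pieces.is_empty() {
            pieces.push((sentence.to_owned(), false));
        }
        pieces
    }

    fn main() {
        let re = Regex::new(r"\[MASK\]").unwrap();
        for (piece, is_added) in split_on_matches("Hello [MASK] there", &re) {
            println!("{piece:?} (added token: {is_added})");
        }
        // "Hello " (added token: false)
        // "[MASK]" (added token: true)
        // " there" (added token: false)
    }
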