Implement suggestions by @sebpuetz

Co-authored-by: Sebastian Pütz <sebastian.puetz@uni-tuebingen.de>
This commit is contained in:
Anthony MOI
2020-07-31 16:49:00 -04:00
committed by Anthony MOI
parent efb29582c6
commit dad70e8e85
11 changed files with 161 additions and 157 deletions

View File

@@ -61,7 +61,7 @@ impl PreTokenizer {
.into_py()?;
Ok(pretokenized
.get_normalized(true)
.get_normalized(tk::OffsetReferential::Original)
.into_iter()
.map(|(s, o)| (s.to_owned(), o))
.collect())

View File

@@ -32,6 +32,7 @@ impl PreTokenizer for BertPreTokenizer {
#[cfg(test)]
mod tests {
use super::*;
use crate::OffsetReferential;
#[test]
fn basic() {
@@ -39,7 +40,7 @@ mod tests {
let mut pretokenized: PreTokenizedString = "Hey friend! How are you?!?".into();
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized.get_normalized(true),
pretokenized.get_normalized(OffsetReferential::Original),
vec![
("Hey", (0, 3)),
("", (3, 4)),

View File

@@ -218,8 +218,8 @@ pub fn process_offsets(encoding: &mut Encoding, add_prefix_space: bool) {
mod tests {
use super::ByteLevel;
use crate::tokenizer::{
normalizer::Range, Decoder, Encoding, NormalizedString, PostProcessor, PreTokenizedString,
PreTokenizer,
normalizer::Range, Decoder, Encoding, NormalizedString, OffsetReferential, PostProcessor,
PreTokenizedString, PreTokenizer,
};
#[test]
@@ -228,7 +228,7 @@ mod tests {
let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
bytelevel.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized.get_normalized(true),
pretokenized.get_normalized(OffsetReferential::Original),
vec![
("Hello", (0, 5)),
("Ġmy", (5, 8)),
@@ -273,7 +273,7 @@ mod tests {
let mut pretokenized = PreTokenizedString::from(*s);
bytelevel.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized.get_normalized(false),
pretokenized.get_normalized(OffsetReferential::Normalized),
vec![
("ĠHello", (0, 6)),
("Ġmy", (6, 9)),
@@ -317,7 +317,7 @@ mod tests {
bytelevel.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized.get_normalized(true),
pretokenized.get_normalized(OffsetReferential::Original),
vec![
("Hello", (0, 5)),
("Ġthere", (5, 11)),
@@ -335,7 +335,7 @@ mod tests {
bytelevel.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized.get_normalized(true),
pretokenized.get_normalized(OffsetReferential::Original),
vec![
("Hello", (0, 5)),
("Ġthere", (5, 11)),
@@ -352,11 +352,11 @@ mod tests {
bytelevel.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized.get_normalized(true),
pretokenized.get_normalized(OffsetReferential::Original),
vec![("i", (0, 1)), ("âŃ¢", (1, 2)), ("j", (2, 3))]
);
assert_eq!(
pretokenized.get_normalized(false),
pretokenized.get_normalized(OffsetReferential::Normalized),
vec![("i", (0, 1)), ("âŃ¢", (1, 4)), ("j", (4, 5))]
);
assert_eq!(

View File

@@ -6,25 +6,18 @@ use serde::{Deserialize, Serialize};
/// splits on this character
pub struct Metaspace {
replacement: char,
str_bytes: [u8; 4],
str_rep: String,
add_prefix_space: bool,
}
impl Metaspace {
pub fn new(replacement: char, add_prefix_space: bool) -> Self {
let mut str_bytes = [0; 4];
replacement.encode_utf8(&mut str_bytes);
Self {
replacement,
str_bytes,
str_rep: replacement.to_string(),
add_prefix_space,
}
}
#[inline]
fn replacement(&self) -> &str {
unsafe { std::str::from_utf8_unchecked(&self.str_bytes[..self.replacement.len_utf8()]) }
}
}
impl Default for Metaspace {
@@ -38,14 +31,14 @@ impl PreTokenizer for Metaspace {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
pretokenized.split(|_, mut normalized| {
if self.add_prefix_space {
normalized.prepend(&self.replacement());
normalized.prepend(&self.str_rep);
}
Ok(normalized
.split(' ', SplitDelimiterBehavior::MergedWithNext)?
.into_iter()
.map(|mut normalized| {
normalized.replace(' ', self.replacement())?;
normalized.replace(' ', &self.str_rep)?;
Ok(normalized)
})
.collect::<Result<Vec<_>>>()?)
@@ -78,6 +71,7 @@ impl Decoder for Metaspace {
#[cfg(test)]
mod tests {
use super::*;
use crate::OffsetReferential;
#[test]
fn basic() {
@@ -85,11 +79,11 @@ mod tests {
let mut pretokenized = PreTokenizedString::from("Hey friend!");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized.get_normalized(false),
pretokenized.get_normalized(OffsetReferential::Normalized),
vec![("▁Hey", (0, 4)), ("▁friend!", (4, 12))]
);
assert_eq!(
pretokenized.get_normalized(true),
pretokenized.get_normalized(OffsetReferential::Original),
vec![("▁Hey", (0, 3)), ("▁friend!", (3, 11))]
);
}
@@ -100,7 +94,7 @@ mod tests {
let mut pretokenized = PreTokenizedString::from("Hey friend!");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized.get_normalized(false),
pretokenized.get_normalized(OffsetReferential::Normalized),
vec![
("▁Hey", (0, 4)),
("", (4, 5)),
@@ -109,7 +103,7 @@ mod tests {
]
);
assert_eq!(
pretokenized.get_normalized(true),
pretokenized.get_normalized(OffsetReferential::Original),
vec![
("▁Hey", (0, 3)),
("", (3, 4)),

View File

@@ -45,7 +45,7 @@ impl PreTokenizer for WhitespaceSplit {
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::PreTokenizer;
use crate::{OffsetReferential, PreTokenizer};
#[test]
fn basic() {
@@ -77,7 +77,10 @@ mod tests {
for (s, res) in tests {
let mut pretokenized = PreTokenizedString::from(s);
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(pretokenized.get_normalized(true), res);
assert_eq!(
pretokenized.get_normalized(OffsetReferential::Original),
res
);
}
}
@@ -103,7 +106,10 @@ mod tests {
for (s, res) in tests {
let mut pretokenized = PreTokenizedString::from(s);
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(pretokenized.get_normalized(true), res);
assert_eq!(
pretokenized.get_normalized(OffsetReferential::Original),
res
);
}
}
}

View File

@@ -454,24 +454,22 @@ impl AddedVocabulary {
pretokenized
.split(|i, mut sequence| {
if let Some(id) = indices[i] {
multi_indices.push(vec![Some(id)]);
multi_indices.push(Some(id));
Ok(itertools::Either::Left(std::iter::once(sequence)))
} else {
normalizer.map(|n| n.normalize(&mut sequence));
let (idcs, split) =
self.split_with_indices(sequence, &self.split_normalized_re);
multi_indices.push(idcs);
multi_indices.extend(idcs);
Ok(itertools::Either::Right(split))
}
})
.expect("AddedVocabulary bad split");
let indices = multi_indices.into_iter().flatten().collect::<Vec<_>>();
pretokenized
.into_iter()
.zip(indices)
.zip(multi_indices)
.map(|(substring, id)| {
(
substring.normalized,

View File

@@ -47,6 +47,19 @@ impl Encoding {
}
}
pub fn with_capacity(len: usize) -> Self {
Encoding {
ids: Vec::with_capacity(len),
type_ids: Vec::with_capacity(len),
tokens: Vec::with_capacity(len),
words: Vec::with_capacity(len),
offsets: Vec::with_capacity(len),
special_tokens_mask: Vec::with_capacity(len),
attention_mask: Vec::with_capacity(len),
overflowing: vec![],
}
}
pub fn from_tokens(tokens: Vec<Token>, type_id: u32) -> Self {
let length = tokens.len();
let (ids, tokens, offsets) = tokens.into_iter().fold(
@@ -404,6 +417,27 @@ impl std::iter::FromIterator<Encoding> for Encoding {
}
}
impl std::iter::FromIterator<(u32, String, (usize, usize), Option<u32>, u32)> for Encoding {
fn from_iter<I: IntoIterator<Item = (u32, String, (usize, usize), Option<u32>, u32)>>(
iter: I,
) -> Self {
let items = iter.into_iter();
let (lower, upper) = items.size_hint();
let length = upper.unwrap_or(lower);
let mut encoding = Self::with_capacity(length);
for (id, token, offsets, word, type_id) in items {
encoding.ids.push(id);
encoding.tokens.push(token);
encoding.offsets.push(offsets);
encoding.type_ids.push(type_id);
encoding.words.push(word);
}
encoding
}
}
#[inline]
fn get_current_part<T: Clone>(
prev: &[T],

View File

@@ -31,7 +31,7 @@ mod serialization;
pub use added_vocabulary::*;
pub use encoding::*;
pub use normalizer::{NormalizedString, SplitDelimiterBehavior};
pub use normalizer::{NormalizedString, OffsetReferential, SplitDelimiterBehavior};
pub use pre_tokenizer::*;
pub type Error = Box<dyn std::error::Error + Send + Sync>;
@@ -423,6 +423,7 @@ impl Tokenizer {
.added_vocabulary
.extract_and_normalize(self.normalizer.as_deref(), &subseq)
.map(|(normalized, original_offsets, id)| match id {
// This is an added token, no need to tokenize, we have the ID
Some(id) => {
let mut encoding = Encoding::from_tokens(
vec![Token::new(
@@ -435,6 +436,7 @@ impl Tokenizer {
encoding.get_words_mut()[0] = Some(0);
Ok(encoding)
}
// Let's tokenize
None => self.do_tokenize(
self.do_pre_tokenize(normalized)?,
original_offsets,
@@ -675,45 +677,40 @@ impl Tokenizer {
) -> Result<Encoding> {
let pretokenized: PreTokenizedString = pretokenized.into();
let mut empty_words = 0;
pretokenized
.into_iter()
.filter(|substr| !substr.normalized.is_empty())
.enumerate()
.map(|(word_idx, substr)| {
if substr.normalized.is_empty() {
empty_words += 1;
return Ok(Encoding::default());
}
let mut tokens = self.model.tokenize(substr.normalized.get())?;
// Update the offsets to match the original input
tokens.iter_mut().for_each(|token| {
.flat_map(|(word_idx, substr)| {
match self.model.tokenize(substr.normalized.get()) {
Ok(tokens) => {
itertools::Either::Left(tokens.into_iter().map(move |token| {
// We convert the normalized offsets back to the original
let converted_offsets = substr
.normalized
.convert_offsets(Range::Normalized(token.offsets.0..token.offsets.1))
.convert_offsets(Range::Normalized(
token.offsets.0..token.offsets.1,
))
.map_or(token.offsets, |range| {
(
original_offsets.0 + substr.original_offsets.0 + range.start,
original_offsets.0
+ substr.original_offsets.0
+ range.start,
original_offsets.0 + substr.original_offsets.0 + range.end,
)
});
// And we update the token to these original offsets, applying the original offset
// of the sequence we just tokenized.
token.offsets = converted_offsets;
});
// Then build the encoding from these tokens, setting the `words` as relevant
let mut encoding = Encoding::from_tokens(tokens, type_id);
encoding.get_words_mut().iter_mut().for_each(|word| {
// empty words are generally spaces, and other things
// that were normalized out, so we dont want to count them in.
*word = Some(word_idx as u32 - empty_words);
});
Ok(encoding)
Ok((
token.id,
token.value,
converted_offsets,
Some(word_idx as u32),
type_id,
))
}))
}
Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
}
})
.collect()
}

View File

@@ -5,6 +5,12 @@ use crate::{Offsets, Result};
use std::ops::{Bound, RangeBounds};
use unicode_normalization_alignments::UnicodeNormalization;
/// The possible offsets referential
pub enum OffsetReferential {
Original,
Normalized,
}
/// Represents a Range usable by the NormalizedString to index its content.
/// A Range can use indices relative to either the `Original` or the `Normalized` string
#[derive(Debug, Clone)]
@@ -300,11 +306,11 @@ impl NormalizedString {
Some(Self {
original: get_range_of(&self.original, r_original)
.unwrap_or("")
.to_owned(),
.unwrap_or_default()
.into(),
normalized: get_range_of(&self.normalized, r_normalized.clone())
.unwrap_or("")
.to_owned(),
.unwrap_or_default()
.into(),
alignments: self
.alignments
.get(r_normalized)?
@@ -462,7 +468,7 @@ impl NormalizedString {
pub fn replace<P: Pattern>(&mut self, pattern: P, content: &str) -> Result<()> {
let matches = pattern.find_matches(&self.normalized)?;
let (normalized, alignments): (Vec<char>, Vec<Offsets>) = matches
let (normalized, alignments): (String, Vec<Offsets>) = matches
.into_iter()
.flat_map(|((start, end), is_match)| {
let len = end - start;
@@ -490,7 +496,7 @@ impl NormalizedString {
})
.unzip();
self.normalized = normalized.into_iter().collect();
self.normalized = normalized;
self.alignments = alignments;
Ok(())

View File

@@ -31,34 +31,18 @@ impl Pattern for &Regex {
return Ok(vec![((0, 0), false)]);
}
// Find initial matches
let matches = self
.find_iter(inside)
.map(|m| ((m.start(), m.end()), true))
.collect::<Vec<_>>();
// Then add missing splits inbetween
let mut start_offset = 0;
let mut splits = matches
.into_iter()
.flat_map(|((start, end), flag)| {
let mut splits = vec![];
if start_offset < start {
splits.push(((start_offset, start), false));
let mut prev = 0;
let mut splits = Vec::with_capacity(inside.len());
for m in self.find_iter(inside) {
if prev != m.start() {
splits.push(((prev, m.start()), false));
}
splits.push(((start, end), flag));
start_offset = end;
splits
})
.collect::<Vec<_>>();
if let Some(((_, end), _)) = splits.iter().last().copied() {
if end < inside.len() {
splits.push(((end, inside.len()), false));
splits.push(((m.start(), m.end()), true));
prev = m.end();
}
if prev != inside.len() {
splits.push(((prev, inside.len()), false))
}
Ok(splits)
}
}

View File

@@ -1,4 +1,4 @@
use crate::{NormalizedString, Offsets, Result};
use crate::{NormalizedString, OffsetReferential, Offsets, Result};
/// Wrapper for a subpart of a `NormalizedString`.
///
@@ -16,7 +16,15 @@ pub struct SubString {
pub original_offsets: Offsets,
}
/// A `PreTokenizedString` takes care of splitting the input string in multiple
impl SubString {
pub fn new(normalized: NormalizedString, original_offsets: Offsets) -> Self {
Self {
normalized,
original_offsets,
}
}
}
/// sub strings, while ensuring that they form a coherend group. This let us keep
/// track of the offsets during the whole normalization and pre-tokenization steps.
#[derive(Debug)]
@@ -42,52 +50,28 @@ impl PreTokenizedString {
F: FnMut(usize, NormalizedString) -> Result<U>,
U: IntoIterator<Item = NormalizedString>,
{
self.parts = self
.parts
.drain(..)
.enumerate()
.flat_map(|(i, sub)| {
// new_parts is at least as big as self.parts
let mut new_parts = Vec::with_capacity(self.parts.len());
for (i, sub) in self.parts.drain(..).enumerate() {
let original_len = sub.normalized.len_original();
let original_offsets = sub.original_offsets;
let mut new_len = 0;
let res = split_fn(i, sub.normalized);
if let Err(e) = res {
return itertools::Either::Left(std::iter::once(Err(e)));
}
let parts = res
.unwrap()
.into_iter()
.map(|normalized| {
new_parts.extend(split_fn(i, sub.normalized)?.into_iter().map(|normalized| {
let len = normalized.len_original();
let new_s = SubString {
normalized,
original_offsets: (
original_offsets.0 + new_len,
original_offsets.0 + new_len + len,
),
};
let start = original_offsets.0 + new_len;
let end = original_offsets.0 + new_len + len;
let new_s = SubString::new(normalized, (start, end));
new_len += len;
new_s
})
.collect::<Vec<_>>();
if new_len != original_len {
println!(
"Original offsets: {:?}\nNew: {:?}",
(0, original_len),
(0, new_len)
}));
if original_len != new_len {
return Err(
"Split pre-tokenized string must represent the entire original string".into(),
);
itertools::Either::Left(std::iter::once(Err(
"Split pre-tokenized string must represent the entire original string"
.into(),
)))
} else {
itertools::Either::Right(parts.into_iter().map(Ok))
}
})
.collect::<Result<Vec<_>>>()?;
}
self.parts = new_parts;
Ok(())
}
@@ -98,19 +82,20 @@ impl PreTokenizedString {
/// Returns a list of normalized string and the associated offsets,
/// either in original or normalized referential
pub fn get_normalized(&self, original: bool) -> Vec<(&str, Offsets)> {
pub fn get_normalized(&self, offset_type: OffsetReferential) -> Vec<(&str, Offsets)> {
let mut offset = 0;
self.iter()
.map(|sub| {
let offsets = if original {
(
let offsets = match offset_type {
OffsetReferential::Original => (
sub.original_offsets.0,
sub.original_offsets.0 + sub.normalized.len_original(),
)
} else {
),
OffsetReferential::Normalized => {
let len = sub.normalized.len();
offset += len;
(offset - len, offset)
}
};
(sub.normalized.get(), offsets)
@@ -176,4 +161,3 @@ impl<'a> IntoIterator for &'a PreTokenizedString {
self.parts.iter()
}
}