ByteLevel PreTokenizer handles offsets

This commit is contained in:
Anthony MOI
2019-12-29 00:08:42 -05:00
parent 35a8dfdd55
commit ad9cc52d83

View File

@ -1,4 +1,4 @@
use crate::tokenizer::{Decoder, PreTokenizer, Result}; use crate::tokenizer::{Decoder, Offsets, PreTokenizer, Result};
use regex::Regex; use regex::Regex;
use std::collections::HashMap; use std::collections::HashMap;
use unicode_categories::UnicodeCategories; use unicode_categories::UnicodeCategories;
@ -41,7 +41,7 @@ impl ByteLevel {
} }
impl PreTokenizer for ByteLevel { impl PreTokenizer for ByteLevel {
fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> { fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>> {
let s = if self.add_prefix_space && !s.starts_with(' ') { let s = if self.add_prefix_space && !s.starts_with(' ') {
format!(" {}", s) format!(" {}", s)
} else { } else {
@ -59,19 +59,11 @@ impl PreTokenizer for ByteLevel {
// we don't want to return it // we don't want to return it
let last = s[start..end].chars().last(); let last = s[start..end].chars().last();
let next = s[end..].chars().nth(0); let next = s[end..].chars().nth(0);
if last.is_some() if let (Some(last), Some(next)) = (last, next) {
&& last.unwrap().is_separator_space() if last.is_separator_space() && !next.is_separator_space() {
&& next.is_some() let bytes = s[start..end - 1].as_bytes().to_vec();
&& !next.unwrap().is_separator_space() let offsets = (start, end - 1);
{ return (bytes, offsets);
if let Some(newstr) = s[start..end]
.chars()
.collect::<Vec<_>>()
.split_last()
.map(|(_, rest)| rest)
.map(|chars| chars.iter().collect::<String>())
{
return newstr;
} }
} }
// if our first char is not a whitespace but the previous one was, we return // if our first char is not a whitespace but the previous one was, we return
@ -80,17 +72,22 @@ impl PreTokenizer for ByteLevel {
let current = s[start..end].chars().nth(0).map(|c| c.is_whitespace()); let current = s[start..end].chars().nth(0).map(|c| c.is_whitespace());
if let (Some(prev), Some(current)) = (prev, current) { if let (Some(prev), Some(current)) = (prev, current) {
if prev.is_separator_space() && !current { if prev.is_separator_space() && !current {
return format!("{}{}", prev, s[start..end].to_owned()); let bytes =
[format!("{}", prev).as_bytes(), s[start..end].as_bytes()].concat();
let offsets = (start - 1, end);
return (bytes, offsets);
} }
} }
s[start..end].to_owned() (s[start..end].as_bytes().to_vec(), (start, end))
}) })
.map(|s| { .map(|(s, offsets)| {
s.into_bytes() (
.iter() s.iter()
.map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap()) .map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap())
.collect() .collect(),
offsets,
)
}) })
.collect()) .collect())
} }
@ -122,7 +119,16 @@ mod tests {
.pre_tokenize("Hello my friend, how is your day going?") .pre_tokenize("Hello my friend, how is your day going?")
.unwrap(), .unwrap(),
vec![ vec![
"Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?" ("Hello".into(), (0, 5)),
("Ġmy".into(), (5, 8)),
("Ġfriend".into(), (8, 15)),
(",".into(), (15, 16)),
("Ġhow".into(), (16, 20)),
("Ġis".into(), (20, 23)),
("Ġyour".into(), (23, 28)),
("Ġday".into(), (28, 32)),
("Ġgoing".into(), (32, 38)),
("?".into(), (38, 39))
] ]
); );
} }
@ -154,7 +160,16 @@ mod tests {
.pre_tokenize("Hello my friend, how is your day going?") .pre_tokenize("Hello my friend, how is your day going?")
.unwrap(), .unwrap(),
vec![ vec![
"ĠHello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?" ("ĠHello".into(), (0, 6)),
("Ġmy".into(), (6, 9)),
("Ġfriend".into(), (9, 16)),
(",".into(), (16, 17)),
("Ġhow".into(), (17, 21)),
("Ġis".into(), (21, 24)),
("Ġyour".into(), (24, 29)),
("Ġday".into(), (29, 33)),
("Ġgoing".into(), (33, 39)),
("?".into(), (39, 40))
] ]
); );
} }
@ -176,7 +191,7 @@ mod tests {
let pre_tokenized = bl.pre_tokenize(&sample).unwrap(); let pre_tokenized = bl.pre_tokenize(&sample).unwrap();
let separated_tokens = pre_tokenized let separated_tokens = pre_tokenized
.into_iter() .into_iter()
.map(|token| token.split("").map(|t| t.into()).collect::<Vec<_>>()) .map(|(token, _)| token.split("").map(|t| t.into()).collect::<Vec<_>>())
.flatten() .flatten()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
assert_eq!(sample, bl.decode(separated_tokens).unwrap()); assert_eq!(sample, bl.decode(separated_tokens).unwrap());
@ -192,11 +207,11 @@ mod tests {
assert_eq!( assert_eq!(
p, p,
vec![ vec![
String::from("Hello"), ("Hello".into(), (0, 5)),
String::from("Ġthere"), ("Ġthere".into(), (5, 11)),
String::from("Ċ"), ("Ċ".into(), (11, 12)),
String::from("Hello"), ("Hello".into(), (12, 17)),
String::from("Ġthere") ("Ġthere".into(), (17, 23))
] ]
); );
} }
@ -210,10 +225,10 @@ mod tests {
assert_eq!( assert_eq!(
p, p,
vec![ vec![
String::from("Hello"), ("Hello".into(), (0, 5)),
String::from("Ġthere"), ("Ġthere".into(), (5, 11)),
String::from("ĠĠĠĠĠĠ"), ("ĠĠĠĠĠĠ".into(), (11, 17)),
String::from("Ġdear") ("Ġdear".into(), (17, 22))
] ]
); );
} }