ByteLevel PreTokenizer handles offsets

Anthony MOI
2019-12-29 00:08:42 -05:00
parent 35a8dfdd55
commit ad9cc52d83


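The change below switches `pre_tokenize` from returning plain strings to returning each piece together with its offsets in the input. A minimal sketch, assuming `Offsets` is a `(usize, usize)` byte range as the diff suggests, of what a caller of the new signature receives (the `dump` helper is hypothetical, not part of the commit):

```rust
// Illustrative sketch only, not part of the commit.
// `Offsets` is assumed to be a `(usize, usize)` byte range into the input.
type Offsets = (usize, usize);

fn dump(pre_tokenized: &[(String, Offsets)]) {
    for (piece, (start, end)) in pre_tokenized {
        // Each piece carries its byte-level representation plus where it
        // came from in the input string.
        println!("{:<10} [{}, {})", piece, start, end);
    }
}

fn main() {
    // Shape matches the expected values in the tests further down.
    let pre_tokenized: Vec<(String, Offsets)> = vec![
        ("Hello".into(), (0, 5)),
        ("Ġmy".into(), (5, 8)),
    ];
    dump(&pre_tokenized);
}
```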
@@ -1,4 +1,4 @@
use crate::tokenizer::{Decoder, PreTokenizer, Result};
use crate::tokenizer::{Decoder, Offsets, PreTokenizer, Result};
use regex::Regex;
use std::collections::HashMap;
use unicode_categories::UnicodeCategories;
@@ -41,7 +41,7 @@ impl ByteLevel {
}
impl PreTokenizer for ByteLevel {
fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
fn pre_tokenize(&self, s: &str) -> Result<Vec<(String, Offsets)>> {
let s = if self.add_prefix_space && !s.starts_with(' ') {
format!(" {}", s)
} else {
@@ -59,19 +59,11 @@ impl PreTokenizer for ByteLevel {
// we don't want to return it
let last = s[start..end].chars().last();
let next = s[end..].chars().nth(0);
if last.is_some()
&& last.unwrap().is_separator_space()
&& next.is_some()
&& !next.unwrap().is_separator_space()
{
if let Some(newstr) = s[start..end]
.chars()
.collect::<Vec<_>>()
.split_last()
.map(|(_, rest)| rest)
.map(|chars| chars.iter().collect::<String>())
{
return newstr;
if let (Some(last), Some(next)) = (last, next) {
if last.is_separator_space() && !next.is_separator_space() {
let bytes = s[start..end - 1].as_bytes().to_vec();
let offsets = (start, end - 1);
return (bytes, offsets);
}
}
// if our first char is not a whitespace but the previous one was, we return
@@ -80,17 +72,22 @@ impl PreTokenizer for ByteLevel {
let current = s[start..end].chars().nth(0).map(|c| c.is_whitespace());
if let (Some(prev), Some(current)) = (prev, current) {
if prev.is_separator_space() && !current {
return format!("{}{}", prev, s[start..end].to_owned());
let bytes =
[format!("{}", prev).as_bytes(), s[start..end].as_bytes()].concat();
let offsets = (start - 1, end);
return (bytes, offsets);
}
}
s[start..end].to_owned()
(s[start..end].as_bytes().to_vec(), (start, end))
})
.map(|s| {
s.into_bytes()
.iter()
.map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap())
.collect()
.map(|(s, offsets)| {
(
s.iter()
.map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap())
.collect(),
offsets,
)
})
.collect())
}
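The `BYTES_CHAR` table used above maps raw bytes to printable code points in the GPT-2 style, which is why a leading space shows up as `Ġ` and a newline as `Ċ` in the tests below. A rough sketch of how such a table can be built; the crate's actual table lives elsewhere in this file and may differ in detail:

```rust
// Sketch, not the commit's code: a GPT-2 style byte-to-unicode table.
// Printable bytes map to themselves, the remaining ones are shifted into
// unused code points starting at U+0100, so 0x20 becomes 'Ġ' (U+0120)
// and 0x0A becomes 'Ċ' (U+010A).
use std::collections::HashMap;

fn bytes_char() -> HashMap<u8, u32> {
    let mut bs: Vec<u32> = (u32::from(b'!')..=u32::from(b'~'))
        .chain(0xA1..=0xAC)
        .chain(0xAE..=0xFF)
        .collect();
    let mut cs = bs.clone();
    let mut n = 0;
    for b in 0..=255u32 {
        if !bs.contains(&b) {
            bs.push(b);
            cs.push(256 + n);
            n += 1;
        }
    }
    bs.into_iter().zip(cs).map(|(b, c)| (b as u8, c)).collect()
}
```

In the diff above, `std::char::from_u32(BYTES_CHAR[b])` turns each mapped value back into a `char` when assembling the returned pieces.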
@@ -122,7 +119,16 @@ mod tests {
.pre_tokenize("Hello my friend, how is your day going?")
.unwrap(),
vec![
"Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
("Hello".into(), (0, 5)),
("Ġmy".into(), (5, 8)),
("Ġfriend".into(), (8, 15)),
(",".into(), (15, 16)),
("Ġhow".into(), (16, 20)),
("Ġis".into(), (20, 23)),
("Ġyour".into(), (23, 28)),
("Ġday".into(), (28, 32)),
("Ġgoing".into(), (32, 38)),
("?".into(), (38, 39))
]
);
}
@@ -154,7 +160,16 @@ mod tests {
.pre_tokenize("Hello my friend, how is your day going?")
.unwrap(),
vec![
"ĠHello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
("ĠHello".into(), (0, 6)),
("Ġmy".into(), (6, 9)),
("Ġfriend".into(), (9, 16)),
(",".into(), (16, 17)),
("Ġhow".into(), (17, 21)),
("Ġis".into(), (21, 24)),
("Ġyour".into(), (24, 29)),
("Ġday".into(), (29, 33)),
("Ġgoing".into(), (33, 39)),
("?".into(), (39, 40))
]
);
}
@@ -176,7 +191,7 @@ mod tests {
let pre_tokenized = bl.pre_tokenize(&sample).unwrap();
let separated_tokens = pre_tokenized
.into_iter()
.map(|token| token.split("").map(|t| t.into()).collect::<Vec<_>>())
.map(|(token, _)| token.split("").map(|t| t.into()).collect::<Vec<_>>())
.flatten()
.collect::<Vec<_>>();
assert_eq!(sample, bl.decode(separated_tokens).unwrap());
@@ -192,11 +207,11 @@
assert_eq!(
p,
vec![
String::from("Hello"),
String::from("Ġthere"),
String::from("Ċ"),
String::from("Hello"),
String::from("Ġthere")
("Hello".into(), (0, 5)),
("Ġthere".into(), (5, 11)),
("Ċ".into(), (11, 12)),
("Hello".into(), (12, 17)),
("Ġthere".into(), (17, 23))
]
);
}
@@ -210,10 +225,10 @@
assert_eq!(
p,
vec![
String::from("Hello"),
String::from("Ġthere"),
String::from("ĠĠĠĠĠĠ"),
String::from("Ġdear")
("Hello".into(), (0, 5)),
("Ġthere".into(), (5, 11)),
("ĠĠĠĠĠĠ".into(), (11, 17)),
("Ġdear".into(), (17, 22))
]
);
}
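The expected offsets above appear to index into the string the regex actually ran over, i.e. the space-prefixed input when `add_prefix_space` is set. A small standalone check of that reading (an assumption on my part, not asserted by the commit):

```rust
// Standalone sanity check, not part of the commit: how the expected
// offsets above line up with the input text.
fn main() {
    let original = "Hello my friend, how is your day going?";

    // Without a prefix space: ("Ġmy", (5, 8)) covers the substring " my".
    assert_eq!(&original[5..8], " my");

    // With add_prefix_space: offsets appear to index the prefixed string,
    // so ("ĠHello", (0, 6)) covers " Hello" and ("?", (39, 40)) the last byte.
    let prefixed = format!(" {}", original);
    assert_eq!(&prefixed[0..6], " Hello");
    assert_eq!(&prefixed[39..40], "?");
}
```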