Fix whitespace handling in ByteLevel

This commit is contained in:
Anthony MOI
2019-12-24 11:20:26 -05:00
parent 9f1421a04b
commit cf0e8917cd

View File

@ -1,6 +1,7 @@
use crate::tokenizer::{Decoder, PreTokenizer, Result}; use crate::tokenizer::{Decoder, PreTokenizer, Result};
use regex::Regex; use regex::Regex;
use std::collections::HashMap; use std::collections::HashMap;
use unicode_categories::UnicodeCategories;
fn bytes_char() -> HashMap<u8, u32> { fn bytes_char() -> HashMap<u8, u32> {
let mut bs: Vec<u8> = vec![]; let mut bs: Vec<u8> = vec![];
@ -59,9 +60,9 @@ impl PreTokenizer for ByteLevel {
let last = s[start..end].chars().last(); let last = s[start..end].chars().last();
let next = s[end..].chars().nth(0); let next = s[end..].chars().nth(0);
if last.is_some() if last.is_some()
&& last.unwrap().is_whitespace() && last.unwrap().is_separator_space()
&& next.is_some() && next.is_some()
&& !next.unwrap().is_whitespace() && !next.unwrap().is_separator_space()
{ {
if let Some(newstr) = s[start..end] if let Some(newstr) = s[start..end]
.chars() .chars()
@ -77,12 +78,10 @@ impl PreTokenizer for ByteLevel {
// a whitespace before our match // a whitespace before our match
let prev = s[0..start].chars().last(); let prev = s[0..start].chars().last();
let current = s[start..end].chars().nth(0).map(|c| c.is_whitespace()); let current = s[start..end].chars().nth(0).map(|c| c.is_whitespace());
if prev.is_some() if let (Some(prev), Some(current)) = (prev, current) {
&& prev.unwrap().is_whitespace() if prev.is_separator_space() && !current {
&& current.is_some() return format!("{}{}", prev, s[start..end].to_owned());
&& !current.unwrap() }
{
return format!(" {}", s[start..end].to_owned());
} }
s[start..end].to_owned() s[start..end].to_owned()
@ -183,4 +182,39 @@ mod tests {
assert_eq!(sample, bl.decode(separated_tokens).unwrap()); assert_eq!(sample, bl.decode(separated_tokens).unwrap());
} }
} }
#[test]
fn handling_of_newlines() {
let s = String::from("Hello there\nHello there");
let pretok = ByteLevel::new(false);
let p = pretok.pre_tokenize(&s).unwrap();
assert_eq!(
p,
vec![
String::from("Hello"),
String::from("Ġthere"),
String::from("Ċ"),
String::from("Hello"),
String::from("Ġthere")
]
);
}
#[test]
fn handling_of_multiple_whitespaces() {
let s = String::from("Hello there dear");
let pretok = ByteLevel::new(false);
let p = pretok.pre_tokenize(&s).unwrap();
assert_eq!(
p,
vec![
String::from("Hello"),
String::from("Ġthere"),
String::from("ĠĠĠĠĠĠ"),
String::from("Ġdear")
]
);
}
} }