Fix whitespace handling in ByteLevel

This commit is contained in:
Anthony MOI
2019-12-24 11:20:26 -05:00
parent 9f1421a04b
commit cf0e8917cd

View File

@ -1,6 +1,7 @@
use crate::tokenizer::{Decoder, PreTokenizer, Result};
use regex::Regex;
use std::collections::HashMap;
use unicode_categories::UnicodeCategories;
fn bytes_char() -> HashMap<u8, u32> {
let mut bs: Vec<u8> = vec![];
@ -59,9 +60,9 @@ impl PreTokenizer for ByteLevel {
let last = s[start..end].chars().last();
let next = s[end..].chars().nth(0);
if last.is_some()
&& last.unwrap().is_whitespace()
&& last.unwrap().is_separator_space()
&& next.is_some()
&& !next.unwrap().is_whitespace()
&& !next.unwrap().is_separator_space()
{
if let Some(newstr) = s[start..end]
.chars()
@ -77,12 +78,10 @@ impl PreTokenizer for ByteLevel {
// a whitespace before our match
let prev = s[0..start].chars().last();
let current = s[start..end].chars().nth(0).map(|c| c.is_whitespace());
if prev.is_some()
&& prev.unwrap().is_whitespace()
&& current.is_some()
&& !current.unwrap()
{
return format!(" {}", s[start..end].to_owned());
if let (Some(prev), Some(current)) = (prev, current) {
if prev.is_separator_space() && !current {
return format!("{}{}", prev, s[start..end].to_owned());
}
}
s[start..end].to_owned()
@ -183,4 +182,39 @@ mod tests {
assert_eq!(sample, bl.decode(separated_tokens).unwrap());
}
}
#[test]
fn handling_of_newlines() {
    // Two identical lines joined by a line feed: the line feed is expected to come out as
    // its own token ("Ċ") rather than being merged into either neighbouring word.
    let input = String::from("Hello there\nHello there");
    let expected: Vec<String> = ["Hello", "Ġthere", "Ċ", "Hello", "Ġthere"]
        .iter()
        .map(|token| token.to_string())
        .collect();

    let tokens = ByteLevel::new(false).pre_tokenize(&input).unwrap();

    assert_eq!(tokens, expected);
}
#[test]
fn handling_of_multiple_whitespaces() {
    // The gap between "there" and "dear" is a run of seven spaces: six of them are expected
    // to form a token of their own, and the seventh stays attached to the front of "dear".
    // The run is generated with `repeat` so its width is explicit in the code and cannot be
    // silently collapsed by editors or renderers that normalise whitespace.
    let s = format!("Hello there{}dear", " ".repeat(7));
    let pretok = ByteLevel::new(false);
    let p = pretok.pre_tokenize(&s).unwrap();
    assert_eq!(
        p,
        vec![
            String::from("Hello"),
            String::from("Ġthere"),
            // Six byte-level space markers, one per standalone space.
            "Ġ".repeat(6),
            String::from("Ġdear")
        ]
    );
}
}