diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 8b3755a9..547bbff3 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -1,6 +1,7 @@
 use crate::tokenizer::{Decoder, PreTokenizer, Result};
 use regex::Regex;
 use std::collections::HashMap;
+use unicode_categories::UnicodeCategories;
 
 fn bytes_char() -> HashMap {
     let mut bs: Vec = vec![];
@@ -59,9 +60,9 @@ impl PreTokenizer for ByteLevel {
                 let last = s[start..end].chars().last();
                 let next = s[end..].chars().nth(0);
                 if last.is_some()
-                    && last.unwrap().is_whitespace()
+                    && last.unwrap().is_separator_space()
                     && next.is_some()
-                    && !next.unwrap().is_whitespace()
+                    && !next.unwrap().is_separator_space()
                 {
                     if let Some(newstr) = s[start..end]
                         .chars()
@@ -77,12 +78,10 @@ impl PreTokenizer for ByteLevel {
                 // a whitespace before our match
                 let prev = s[0..start].chars().last();
                 let current = s[start..end].chars().nth(0).map(|c| c.is_whitespace());
-                if prev.is_some()
-                    && prev.unwrap().is_whitespace()
-                    && current.is_some()
-                    && !current.unwrap()
-                {
-                    return format!(" {}", s[start..end].to_owned());
+                if let (Some(prev), Some(current)) = (prev, current) {
+                    if prev.is_separator_space() && !current {
+                        return format!("{}{}", prev, s[start..end].to_owned());
+                    }
                 }
 
                 s[start..end].to_owned()
@@ -183,4 +182,41 @@ mod tests {
             assert_eq!(sample, bl.decode(separated_tokens).unwrap());
         }
     }
+
+    // A newline is emitted as its own token (Ċ); the word after it gets no Ġ prefix.
+    #[test]
+    fn handling_of_newlines() {
+        let s = String::from("Hello there\nHello there");
+        let pretok = ByteLevel::new(false);
+        let p = pretok.pre_tokenize(&s).unwrap();
+
+        assert_eq!(
+            p,
+            vec![
+                String::from("Hello"),
+                String::from("Ġthere"),
+                String::from("Ċ"),
+                String::from("Hello"),
+                String::from("Ġthere")
+            ]
+        );
+    }
+
+    // A run of spaces is kept in full: the last one prefixes the next word (Ġdear).
+    #[test]
+    fn handling_of_multiple_whitespaces() {
+        let s = String::from("Hello there       dear");
+        let pretok = ByteLevel::new(false);
+        let p = pretok.pre_tokenize(&s).unwrap();
+
+        assert_eq!(
+            p,
+            vec![
+                String::from("Hello"),
+                String::from("Ġthere"),
+                String::from("ĠĠĠĠĠĠ"),
+                String::from("Ġdear")
+            ]
+        );
+    }
 }