diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 547bbff3..896c20d3 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -1,4 +1,4 @@ -use crate::tokenizer::{Decoder, PreTokenizer, Result}; +use crate::tokenizer::{Decoder, Offsets, PreTokenizer, Result}; use regex::Regex; use std::collections::HashMap; use unicode_categories::UnicodeCategories; @@ -41,7 +41,7 @@ impl ByteLevel { } impl PreTokenizer for ByteLevel { - fn pre_tokenize(&self, s: &str) -> Result> { + fn pre_tokenize(&self, s: &str) -> Result> { let s = if self.add_prefix_space && !s.starts_with(' ') { format!(" {}", s) } else { @@ -59,19 +59,11 @@ impl PreTokenizer for ByteLevel { // we don't want to return it let last = s[start..end].chars().last(); let next = s[end..].chars().nth(0); - if last.is_some() - && last.unwrap().is_separator_space() - && next.is_some() - && !next.unwrap().is_separator_space() - { - if let Some(newstr) = s[start..end] - .chars() - .collect::>() - .split_last() - .map(|(_, rest)| rest) - .map(|chars| chars.iter().collect::()) - { - return newstr; + if let (Some(last), Some(next)) = (last, next) { + if last.is_separator_space() && !next.is_separator_space() { + let bytes = s[start..end - 1].as_bytes().to_vec(); + let offsets = (start, end - 1); + return (bytes, offsets); } } // if our first char is not a whitespace but the previous one was, we return @@ -80,17 +72,22 @@ impl PreTokenizer for ByteLevel { let current = s[start..end].chars().nth(0).map(|c| c.is_whitespace()); if let (Some(prev), Some(current)) = (prev, current) { if prev.is_separator_space() && !current { - return format!("{}{}", prev, s[start..end].to_owned()); + let bytes = + [format!("{}", prev).as_bytes(), s[start..end].as_bytes()].concat(); + let offsets = (start - 1, end); + return (bytes, offsets); } } - s[start..end].to_owned() + (s[start..end].as_bytes().to_vec(), (start, end)) }) - .map(|s| { - s.into_bytes() - .iter() - .map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap()) - .collect() + .map(|(s, offsets)| { + ( + s.iter() + .map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap()) + .collect(), + offsets, + ) }) .collect()) } @@ -122,7 +119,16 @@ mod tests { .pre_tokenize("Hello my friend, how is your day going?") .unwrap(), vec![ - "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?" + ("Hello".into(), (0, 5)), + ("Ġmy".into(), (5, 8)), + ("Ġfriend".into(), (8, 15)), + (",".into(), (15, 16)), + ("Ġhow".into(), (16, 20)), + ("Ġis".into(), (20, 23)), + ("Ġyour".into(), (23, 28)), + ("Ġday".into(), (28, 32)), + ("Ġgoing".into(), (32, 38)), + ("?".into(), (38, 39)) ] ); } @@ -154,7 +160,16 @@ mod tests { .pre_tokenize("Hello my friend, how is your day going?") .unwrap(), vec![ - "ĠHello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?" + ("ĠHello".into(), (0, 6)), + ("Ġmy".into(), (6, 9)), + ("Ġfriend".into(), (9, 16)), + (",".into(), (16, 17)), + ("Ġhow".into(), (17, 21)), + ("Ġis".into(), (21, 24)), + ("Ġyour".into(), (24, 29)), + ("Ġday".into(), (29, 33)), + ("Ġgoing".into(), (33, 39)), + ("?".into(), (39, 40)) ] ); } @@ -176,7 +191,7 @@ mod tests { let pre_tokenized = bl.pre_tokenize(&sample).unwrap(); let separated_tokens = pre_tokenized .into_iter() - .map(|token| token.split("").map(|t| t.into()).collect::>()) + .map(|(token, _)| token.split("").map(|t| t.into()).collect::>()) .flatten() .collect::>(); assert_eq!(sample, bl.decode(separated_tokens).unwrap()); @@ -192,11 +207,11 @@ mod tests { assert_eq!( p, vec![ - String::from("Hello"), - String::from("Ġthere"), - String::from("Ċ"), - String::from("Hello"), - String::from("Ġthere") + ("Hello".into(), (0, 5)), + ("Ġthere".into(), (5, 11)), + ("Ċ".into(), (11, 12)), + ("Hello".into(), (12, 17)), + ("Ġthere".into(), (17, 23)) ] ); } @@ -210,10 +225,10 @@ mod tests { assert_eq!( p, vec![ - String::from("Hello"), - String::from("Ġthere"), - String::from("ĠĠĠĠĠĠ"), - String::from("Ġdear") + ("Hello".into(), (0, 5)), + ("Ġthere".into(), (5, 11)), + ("ĠĠĠĠĠĠ".into(), (11, 17)), + ("Ġdear".into(), (17, 22)) ] ); }