From 2419c14e42844e0b51a80ca0be4c346e1128b32c Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Thu, 21 Nov 2019 11:52:55 -0500 Subject: [PATCH] ByteLevel is also a Decoder --- tokenizers/src/decoders/mod.rs | 2 + tokenizers/src/pre_tokenizers/byte_level.rs | 55 ++++++++++++++++----- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/tokenizers/src/decoders/mod.rs b/tokenizers/src/decoders/mod.rs index e69de29b..e014c665 100644 --- a/tokenizers/src/decoders/mod.rs +++ b/tokenizers/src/decoders/mod.rs @@ -0,0 +1,2 @@ +// Re-export this as a decoder +pub use super::pre_tokenizers::byte_level; diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 577beaea..864409af 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -1,4 +1,4 @@ -use crate::tokenizer::PreTokenizer; +use crate::tokenizer::{Decoder, PreTokenizer}; use onig::Regex; use std::collections::HashMap; @@ -22,17 +22,18 @@ fn bytes_char() -> HashMap { bs.into_iter().zip(cs).collect() } +lazy_static! { + static ref RE: Regex = + Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+") + .unwrap(); + static ref BYTES_CHAR: HashMap = bytes_char(); + static ref CHAR_BYTES: HashMap = + bytes_char().into_iter().map(|(c, b)| (b, c)).collect(); +} + pub struct ByteLevel; impl PreTokenizer for ByteLevel { fn pre_tokenize(&self, s: &str) -> Vec { - lazy_static! { - static ref RE: Regex = Regex::new( - r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+" - ) - .unwrap(); - static ref BYTES_CHAR: HashMap = bytes_char(); - } - RE.find_iter(s) .map(|(start, end)| s[start..end].to_owned()) .map(|s| { @@ -45,13 +46,29 @@ impl PreTokenizer for ByteLevel { } } +impl Decoder for ByteLevel { + fn decode(&self, tokens: Vec) -> String { + tokens + .into_iter() + .map(|token| { + let bytes = token + .chars() + .map(|c| CHAR_BYTES[&(c as u32)]) + .collect::>(); + String::from_utf8_lossy(&bytes).into_owned() + }) + .collect::>() + .join("") + } +} + #[cfg(test)] mod tests { use super::ByteLevel; - use crate::tokenizer::PreTokenizer; + use crate::tokenizer::{Decoder, PreTokenizer}; #[test] - fn basic() { + fn pre_tokenization() { let pre_tok = ByteLevel; assert_eq!( pre_tok.pre_tokenize("Hello my friend, how is your day going?"), @@ -60,4 +77,20 @@ mod tests { ] ); } + + #[test] + fn decoding() { + let decoder = ByteLevel; + assert_eq!( + "Hello my friend, how is your day going?", + decoder.decode( + vec![ + "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?" + ] + .into_iter() + .map(|s| s.into()) + .collect::>() + ) + ); + } }