ByteLevel is also a Decoder

This commit is contained in:
Anthony MOI
2019-11-21 11:52:55 -05:00
parent 56e37475c3
commit 2419c14e42
2 changed files with 46 additions and 11 deletions

View File

@ -0,0 +1,2 @@
// Re-export this as a decoder
pub use super::pre_tokenizers::byte_level;

View File

@ -1,4 +1,4 @@
use crate::tokenizer::PreTokenizer;
use crate::tokenizer::{Decoder, PreTokenizer};
use onig::Regex;
use std::collections::HashMap;
@ -22,17 +22,18 @@ fn bytes_char() -> HashMap<u8, u32> {
bs.into_iter().zip(cs).collect()
}
lazy_static! {
static ref RE: Regex =
Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+")
.unwrap();
static ref BYTES_CHAR: HashMap<u8, u32> = bytes_char();
static ref CHAR_BYTES: HashMap<u32, u8> =
bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
}
pub struct ByteLevel;
impl PreTokenizer for ByteLevel {
fn pre_tokenize(&self, s: &str) -> Vec<String> {
lazy_static! {
static ref RE: Regex = Regex::new(
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
)
.unwrap();
static ref BYTES_CHAR: HashMap<u8, u32> = bytes_char();
}
RE.find_iter(s)
.map(|(start, end)| s[start..end].to_owned())
.map(|s| {
@ -45,13 +46,29 @@ impl PreTokenizer for ByteLevel {
}
}
impl Decoder for ByteLevel {
fn decode(&self, tokens: Vec<String>) -> String {
tokens
.into_iter()
.map(|token| {
let bytes = token
.chars()
.map(|c| CHAR_BYTES[&(c as u32)])
.collect::<Vec<u8>>();
String::from_utf8_lossy(&bytes).into_owned()
})
.collect::<Vec<_>>()
.join("")
}
}
#[cfg(test)]
mod tests {
use super::ByteLevel;
use crate::tokenizer::PreTokenizer;
use crate::tokenizer::{Decoder, PreTokenizer};
#[test]
fn basic() {
fn pre_tokenization() {
let pre_tok = ByteLevel;
assert_eq!(
pre_tok.pre_tokenize("Hello my friend, how is your day going?"),
@ -60,4 +77,20 @@ mod tests {
]
);
}
#[test]
fn decoding() {
let decoder = ByteLevel;
assert_eq!(
"Hello my friend, how is your day going?",
decoder.decode(
vec![
"Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
]
.into_iter()
.map(|s| s.into())
.collect::<Vec<String>>()
)
);
}
}