mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 08:45:38 +00:00
ByteLevel is also a Decoder
This commit is contained in:
@ -0,0 +1,2 @@
|
|||||||
|
// Re-export this as a decoder
|
||||||
|
pub use super::pre_tokenizers::byte_level;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use crate::tokenizer::PreTokenizer;
|
use crate::tokenizer::{Decoder, PreTokenizer};
|
||||||
use onig::Regex;
|
use onig::Regex;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
@ -22,17 +22,18 @@ fn bytes_char() -> HashMap<u8, u32> {
|
|||||||
bs.into_iter().zip(cs).collect()
|
bs.into_iter().zip(cs).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lazy_static! {
|
||||||
|
static ref RE: Regex =
|
||||||
|
Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+")
|
||||||
|
.unwrap();
|
||||||
|
static ref BYTES_CHAR: HashMap<u8, u32> = bytes_char();
|
||||||
|
static ref CHAR_BYTES: HashMap<u32, u8> =
|
||||||
|
bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ByteLevel;
|
pub struct ByteLevel;
|
||||||
impl PreTokenizer for ByteLevel {
|
impl PreTokenizer for ByteLevel {
|
||||||
fn pre_tokenize(&self, s: &str) -> Vec<String> {
|
fn pre_tokenize(&self, s: &str) -> Vec<String> {
|
||||||
lazy_static! {
|
|
||||||
static ref RE: Regex = Regex::new(
|
|
||||||
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
static ref BYTES_CHAR: HashMap<u8, u32> = bytes_char();
|
|
||||||
}
|
|
||||||
|
|
||||||
RE.find_iter(s)
|
RE.find_iter(s)
|
||||||
.map(|(start, end)| s[start..end].to_owned())
|
.map(|(start, end)| s[start..end].to_owned())
|
||||||
.map(|s| {
|
.map(|s| {
|
||||||
@ -45,13 +46,29 @@ impl PreTokenizer for ByteLevel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Decoder for ByteLevel {
|
||||||
|
fn decode(&self, tokens: Vec<String>) -> String {
|
||||||
|
tokens
|
||||||
|
.into_iter()
|
||||||
|
.map(|token| {
|
||||||
|
let bytes = token
|
||||||
|
.chars()
|
||||||
|
.map(|c| CHAR_BYTES[&(c as u32)])
|
||||||
|
.collect::<Vec<u8>>();
|
||||||
|
String::from_utf8_lossy(&bytes).into_owned()
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::ByteLevel;
|
use super::ByteLevel;
|
||||||
use crate::tokenizer::PreTokenizer;
|
use crate::tokenizer::{Decoder, PreTokenizer};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn basic() {
|
fn pre_tokenization() {
|
||||||
let pre_tok = ByteLevel;
|
let pre_tok = ByteLevel;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
pre_tok.pre_tokenize("Hello my friend, how is your day going?"),
|
pre_tok.pre_tokenize("Hello my friend, how is your day going?"),
|
||||||
@ -60,4 +77,20 @@ mod tests {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decoding() {
|
||||||
|
let decoder = ByteLevel;
|
||||||
|
assert_eq!(
|
||||||
|
"Hello my friend, how is your day going?",
|
||||||
|
decoder.decode(
|
||||||
|
vec![
|
||||||
|
"Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
|
||||||
|
]
|
||||||
|
.into_iter()
|
||||||
|
.map(|s| s.into())
|
||||||
|
.collect::<Vec<String>>()
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user