mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
ByteLevel handles prefix spaces
This commit is contained in:
@ -32,7 +32,7 @@ impl ByteLevel {
|
||||
#[staticmethod]
|
||||
fn new() -> PyResult<Decoder> {
|
||||
Ok(Decoder {
|
||||
decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel)),
|
||||
decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::new(false))),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -31,9 +31,11 @@ pub struct ByteLevel {}
|
||||
#[pymethods]
|
||||
impl ByteLevel {
|
||||
#[staticmethod]
|
||||
fn new() -> PyResult<PreTokenizer> {
|
||||
fn new(add_prefix_space: bool) -> PyResult<PreTokenizer> {
|
||||
Ok(PreTokenizer {
|
||||
pretok: Container::Owned(Box::new(tk::pre_tokenizers::byte_level::ByteLevel)),
|
||||
pretok: Container::Owned(Box::new(tk::pre_tokenizers::byte_level::ByteLevel::new(
|
||||
add_prefix_space,
|
||||
))),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -18,8 +18,8 @@ fn shell(matches: &ArgMatches) -> Result<()> {
|
||||
|
||||
let bpe = BPE::from_files(vocab, merges)?;
|
||||
let mut tokenizer = Tokenizer::new(Box::new(bpe));
|
||||
tokenizer.with_pre_tokenizer(Box::new(ByteLevel));
|
||||
tokenizer.with_decoder(Box::new(ByteLevel));
|
||||
tokenizer.with_pre_tokenizer(Box::new(ByteLevel::new(true)));
|
||||
tokenizer.with_decoder(Box::new(ByteLevel::new(false)));
|
||||
|
||||
tokenizer.add_tokens(&[
|
||||
AddedToken {
|
||||
|
@ -30,11 +30,25 @@ lazy_static! {
|
||||
bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
|
||||
}
|
||||
|
||||
pub struct ByteLevel;
|
||||
/// Byte-level component built through `ByteLevel::new`.
///
/// The same type is handed to the tokenizer both as a pre-tokenizer and as a
/// decoder; the flag below is read by `pre_tokenize`.
pub struct ByteLevel {
|
||||
/// When `true`, `pre_tokenize` prepends a single `' '` to any input that
/// does not already start with a space.
add_prefix_space: bool,
|
||||
}
|
||||
impl ByteLevel {
|
||||
/// Creates a `ByteLevel` that stores the given `add_prefix_space` flag.
///
/// Pass `true` to have `pre_tokenize` prepend a space to input that does not
/// already start with one. Call sites visible in this change pass `false`
/// when the value is installed as a decoder.
pub fn new(add_prefix_space: bool) -> Self {
|
||||
ByteLevel { add_prefix_space }
|
||||
}
|
||||
}
|
||||
|
||||
impl PreTokenizer for ByteLevel {
|
||||
fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
|
||||
let s = if self.add_prefix_space && !s.starts_with(' ') {
|
||||
format!(" {}", s)
|
||||
} else {
|
||||
s.to_owned()
|
||||
};
|
||||
|
||||
Ok(RE
|
||||
.captures_iter(s)
|
||||
.captures_iter(&s)
|
||||
.map(|capture| {
|
||||
let capture = capture.get(0).unwrap();
|
||||
let start = capture.start();
|
||||
@ -103,7 +117,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn pre_tokenization() {
|
||||
let pre_tok = ByteLevel;
|
||||
let pre_tok = ByteLevel::new(false);
|
||||
assert_eq!(
|
||||
pre_tok
|
||||
.pre_tokenize("Hello my friend, how is your day going?")
|
||||
@ -116,7 +130,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn decoding() {
|
||||
let decoder = ByteLevel;
|
||||
let decoder = ByteLevel::new(false);
|
||||
assert_eq!(
|
||||
"Hello my friend, how is your day going?",
|
||||
decoder
|
||||
@ -133,6 +147,19 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
// With `add_prefix_space` enabled, an input that does not start with a space
// must be pre-tokenized as if a space had been prepended: the first expected
// token is "ĠHello", carrying the same `Ġ` prefix as the following words.
#[test]
|
||||
fn add_prefix_space() {
|
||||
let pre_tok = ByteLevel::new(true);
|
||||
assert_eq!(
|
||||
pre_tok
|
||||
.pre_tokenize("Hello my friend, how is your day going?")
|
||||
.unwrap(),
|
||||
vec![
|
||||
"ĠHello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn decode_works_on_separated_tokens() {
|
||||
let samples = vec![
|
||||
@ -145,7 +172,7 @@ mod tests {
|
||||
),
|
||||
];
|
||||
|
||||
let bl = ByteLevel;
|
||||
let bl = ByteLevel::new(false);
|
||||
for sample in samples {
|
||||
let pre_tokenized = bl.pre_tokenize(&sample).unwrap();
|
||||
let separated_tokens = pre_tokenized
|
||||
|
Reference in New Issue
Block a user