ByteLevel handles prefix spaces

This commit is contained in:
Anthony MOI
2019-12-17 18:35:40 -05:00
parent 6766585965
commit 4d14b08afe
4 changed files with 39 additions and 10 deletions

View File

@ -32,7 +32,7 @@ impl ByteLevel {
#[staticmethod]
fn new() -> PyResult<Decoder> {
Ok(Decoder {
decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel)),
decoder: Container::Owned(Box::new(tk::decoders::byte_level::ByteLevel::new(false))),
})
}
}

View File

@ -31,9 +31,11 @@ pub struct ByteLevel {}
#[pymethods]
impl ByteLevel {
#[staticmethod]
fn new() -> PyResult<PreTokenizer> {
fn new(add_prefix_space: bool) -> PyResult<PreTokenizer> {
Ok(PreTokenizer {
pretok: Container::Owned(Box::new(tk::pre_tokenizers::byte_level::ByteLevel)),
pretok: Container::Owned(Box::new(tk::pre_tokenizers::byte_level::ByteLevel::new(
add_prefix_space,
))),
})
}
}

View File

@ -18,8 +18,8 @@ fn shell(matches: &ArgMatches) -> Result<()> {
let bpe = BPE::from_files(vocab, merges)?;
let mut tokenizer = Tokenizer::new(Box::new(bpe));
tokenizer.with_pre_tokenizer(Box::new(ByteLevel));
tokenizer.with_decoder(Box::new(ByteLevel));
tokenizer.with_pre_tokenizer(Box::new(ByteLevel::new(true)));
tokenizer.with_decoder(Box::new(ByteLevel::new(false)));
tokenizer.add_tokens(&[
AddedToken {

View File

@ -30,11 +30,25 @@ lazy_static! {
bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
}
pub struct ByteLevel;
/// Byte-level processing step, configurable with an optional prefix space.
///
/// The only piece of state is the `add_prefix_space` flag, which is consulted
/// by `pre_tokenize`: when it is set and the input does not already start
/// with a space, a single space is prepended before the input is split.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ByteLevel {
    /// When `true`, `pre_tokenize` prepends a space to any input that does
    /// not already begin with one; when `false` the input is used as-is.
    add_prefix_space: bool,
}

impl ByteLevel {
    /// Creates a `ByteLevel` with the given prefix-space behaviour.
    ///
    /// Pass `true` to have a leading space added to inputs that lack one,
    /// or `false` to leave inputs untouched.
    pub fn new(add_prefix_space: bool) -> Self {
        Self { add_prefix_space }
    }
}
impl PreTokenizer for ByteLevel {
fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
let s = if self.add_prefix_space && !s.starts_with(' ') {
format!(" {}", s)
} else {
s.to_owned()
};
Ok(RE
.captures_iter(s)
.captures_iter(&s)
.map(|capture| {
let capture = capture.get(0).unwrap();
let start = capture.start();
@ -103,7 +117,7 @@ mod tests {
#[test]
fn pre_tokenization() {
let pre_tok = ByteLevel;
let pre_tok = ByteLevel::new(false);
assert_eq!(
pre_tok
.pre_tokenize("Hello my friend, how is your day going?")
@ -116,7 +130,7 @@ mod tests {
#[test]
fn decoding() {
let decoder = ByteLevel;
let decoder = ByteLevel::new(false);
assert_eq!(
"Hello my friend, how is your day going?",
decoder
@ -133,6 +147,19 @@ mod tests {
);
}
#[test]
fn add_prefix_space() {
    // The input has no leading space, so with the flag enabled one is
    // prepended before splitting; as a result even the first token is
    // expected to start with `Ġ`, just like every other word token.
    let input = "Hello my friend, how is your day going?";
    let expected = vec![
        "ĠHello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?",
    ];

    let tokens = ByteLevel::new(true).pre_tokenize(input).unwrap();

    assert_eq!(tokens, expected);
}
#[test]
fn decode_works_on_separated_tokens() {
let samples = vec![
@ -145,7 +172,7 @@ mod tests {
),
];
let bl = ByteLevel;
let bl = ByteLevel::new(false);
for sample in samples {
let pre_tokenized = bl.pre_tokenize(&sample).unwrap();
let separated_tokens = pre_tokenized