Update python bindings

Author: Anthony MOI
Date:   2019-11-21 11:55:07 -05:00
parent 6853e6c904
commit c28a83cdc4
2 changed files with 33 additions and 2 deletions


@@ -43,7 +43,7 @@ Namespaces are one honking great idea -- let's do more of those!
 tok_p = GPT2Tokenizer.from_pretrained('gpt2')
-tok_r = Tokenizer.bpe_from_files(args.vocab, args.merges, pre_tokenizer="ByteLevel")
+tok_r = Tokenizer.bpe_from_files(args.vocab, args.merges, pre_tokenizer="ByteLevel", decoder="ByteLevel")
 def tokenize_r():
     # return [ tok_r.encode(sentence) for sentence in text]
@@ -66,4 +66,7 @@ encoded_p = tokenize_p()
 end = time.time()
 print(f"Transformer tokenizer took: {end - start} sec")
-assert([ [ token.id for token in sentence] for sentence in encoded_r ] == encoded_p)
+ids_r = [ [ token.id for token in sentence ] for sentence in encoded_r ]
+assert(ids_r == encoded_p)
+print(f"Decoded sentences: {tok_r.decode_batch(ids_r)}")


@@ -54,6 +54,13 @@ fn get_post_processor(_name: &str) -> Option<Box<dyn tk::tokenizer::PostProcessor + Sync>> {
     None
 }
+
+fn get_decoder(name: &str) -> Option<Box<dyn tk::tokenizer::Decoder + Sync>> {
+    match name {
+        "ByteLevel" => Some(Box::new(tk::decoders::byte_level::ByteLevel)),
+        _ => None,
+    }
+}
 
 #[pyclass]
 struct Tokenizer {
     tokenizer: tk::tokenizer::Tokenizer,
@@ -109,6 +116,7 @@ impl Tokenizer {
                             )));
                         }
                     }
+                    tokenizer.with_normalizers(normalizers);
                 }
                 "post_processors" => {
                     let mut processors = vec![];
@@ -124,6 +132,18 @@ impl Tokenizer {
                             )));
                         }
                     }
+                    tokenizer.with_post_processors(processors);
                 }
+                "decoder" => {
+                    let value = value.to_string();
+                    if let Some(decoder) = get_decoder(&value) {
+                        tokenizer.with_decoder(decoder);
+                    } else {
+                        return Err(exceptions::Exception::py_err(format!(
+                            "Decoder `{}` not found",
+                            value
+                        )));
+                    }
+                }
                 _ => println!("Ignored unknown kwarg `{}`", option),
             }
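When the decoder name is not one that get_decoder recognizes, construction fails with the error built above. A hedged sketch of how that surfaces on the Python side ("NotADecoder" is a deliberately unknown name, and the file paths are placeholders):

# Sketch only: the raised error is expected to carry the message formatted above.
try:
    Tokenizer.bpe_from_files("vocab.json", "merges.txt", decoder="NotADecoder")
except Exception as err:
    print(err)  # e.g. Decoder `NotADecoder` not found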
@@ -154,6 +174,14 @@ impl Tokenizer {
             .collect()
     }
+
+    fn decode(&self, ids: Vec<u32>) -> String {
+        self.tokenizer.decode(ids)
+    }
+
+    fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Vec<String> {
+        self.tokenizer.decode_batch(sentences)
+    }
 
     fn token_to_id(&self, token: &str) -> Option<u32> {
         self.tokenizer.token_to_id(token)
     }
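The two new methods mirror the Rust signatures: decode maps one list of ids to a single string, while decode_batch maps a list of id lists to a list of strings. A small usage sketch, assuming tok_r was constructed as in the benchmark script above:

# Sketch only: single-sentence versus batched decoding.
encoding = tok_r.encode("Hello there!")
ids = [token.id for token in encoding]
print(tok_r.decode(ids))           # one decoded string
print(tok_r.decode_batch([ids]))   # list with one decoded string
print(tok_r.token_to_id("hello"))  # the token's id, or None if unknown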