diff --git a/bindings/python/example.py b/bindings/python/example.py
index eb341608..cdc319cc 100644
--- a/bindings/python/example.py
+++ b/bindings/python/example.py
@@ -43,7 +43,7 @@ Namespaces are one honking great idea -- let's do more of those!
 
 tok_p = GPT2Tokenizer.from_pretrained('gpt2')
-tok_r = Tokenizer.bpe_from_files(args.vocab, args.merges, pre_tokenizer="ByteLevel")
+tok_r = Tokenizer.bpe_from_files(args.vocab, args.merges, pre_tokenizer="ByteLevel", decoder="ByteLevel")
 
 def tokenize_r():
     # return [ tok_r.encode(sentence) for sentence in text]
@@ -66,4 +66,7 @@
 encoded_p = tokenize_p()
 end = time.time()
 print(f"Transformer tokenizer took: {end - start} sec")
-assert([ [ token.id for token in sentence] for sentence in encoded_r ] == encoded_p)
+ids_r = [ [ token.id for token in sentence ] for sentence in encoded_r ]
+assert(ids_r == encoded_p)
+
+print(f"Decoded sentences: {tok_r.decode_batch(ids_r)}")
diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 20bc44ac..34e2bcce 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -54,6 +54,13 @@ fn get_post_processor(_name: &str) -> Option<Box<dyn tk::tokenizer::PostProcessor + Sync>> {
     None
 }
 
+fn get_decoder(name: &str) -> Option<Box<dyn tk::tokenizer::Decoder + Sync>> {
+    match name {
+        "ByteLevel" => Some(Box::new(tk::decoders::byte_level::ByteLevel)),
+        _ => None,
+    }
+}
+
 #[pyclass]
 struct Tokenizer {
     tokenizer: tk::tokenizer::Tokenizer,
@@ -109,6 +116,7 @@ impl Tokenizer {
                             )));
                         }
                     }
+                    tokenizer.with_normalizers(normalizers);
                 }
                 "post_processors" => {
                     let mut processors = vec![];
@@ -124,6 +132,18 @@ impl Tokenizer {
                             )));
                         }
                     }
+                    tokenizer.with_post_processors(processors);
+                }
+                "decoder" => {
+                    let value = value.to_string();
+                    if let Some(decoder) = get_decoder(&value) {
+                        tokenizer.with_decoder(decoder);
+                    } else {
+                        return Err(exceptions::Exception::py_err(format!(
+                            "Decoder `{}` not found",
+                            value
+                        )));
+                    }
                 }
                 _ => println!("Ignored unknown kwarg `{}`", option),
             }
@@ -154,6 +174,14 @@ impl Tokenizer {
             .collect()
     }
 
+    fn decode(&self, ids: Vec<u32>) -> String {
+        self.tokenizer.decode(ids)
+    }
+
+    fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Vec<String> {
+        self.tokenizer.decode_batch(sentences)
+    }
+
     fn token_to_id(&self, token: &str) -> Option<u32> {
         self.tokenizer.token_to_id(token)
     }
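
A minimal sketch of how the new `decoder` kwarg and the `decode`/`decode_batch` bindings fit together, in the spirit of `example.py`; the import path, the literal `vocab.json`/`merges.txt` paths, and the sample sentences are illustrative assumptions, not part of this diff:

```python
from tokenizers import Tokenizer  # assumed module name for these bindings

# Build the Rust-backed BPE tokenizer. The ByteLevel pre-tokenizer maps raw
# bytes to printable units before BPE; the matching ByteLevel decoder added
# by this change reverses that mapping so ids round-trip to readable text.
tok_r = Tokenizer.bpe_from_files(
    "vocab.json",   # placeholder vocab path
    "merges.txt",   # placeholder merges path
    pre_tokenizer="ByteLevel",
    decoder="ByteLevel",
)

text = ["The Zen of Python, by Tim Peters", "Readability counts."]
encoded = [tok_r.encode(sentence) for sentence in text]
ids = [[token.id for token in sentence] for sentence in encoded]

# decode() reverses one sequence of ids; decode_batch() a list of sequences,
# as used at the end of example.py.
print(tok_r.decode(ids[0]))
print(tok_r.decode_batch(ids))
```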
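
On the Rust side, two details worth calling out: `get_decoder` mirrors the string-keyed lookup already used for pre-tokenizers and post-processors, and the new `tokenizer.with_normalizers(normalizers)` / `tokenizer.with_post_processors(processors)` calls attach components that were previously collected into a `vec![]` but never applied to the tokenizer. As with the other kwargs, an unrecognized decoder name falls through to the `else` branch and surfaces in Python as an exception of the form "Decoder `...` not found".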