mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-05 20:28:22 +00:00
Update python bindings
This commit is contained in:
@@ -54,6 +54,13 @@ fn get_post_processor(_name: &str) -> Option<Box<dyn tk::tokenizer::PostProcesso
|
||||
None
|
||||
}
|
||||
|
||||
fn get_decoder(name: &str) -> Option<Box<dyn tk::tokenizer::Decoder + Sync>> {
|
||||
match name {
|
||||
"ByteLevel" => Some(Box::new(tk::decoders::byte_level::ByteLevel)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct Tokenizer {
|
||||
tokenizer: tk::tokenizer::Tokenizer,
|
||||
@@ -109,6 +116,7 @@ impl Tokenizer {
|
||||
)));
|
||||
}
|
||||
}
|
||||
tokenizer.with_normalizers(normalizers);
|
||||
}
|
||||
"post_processors" => {
|
||||
let mut processors = vec![];
|
||||
@@ -124,6 +132,18 @@ impl Tokenizer {
|
||||
)));
|
||||
}
|
||||
}
|
||||
tokenizer.with_post_processors(processors);
|
||||
}
|
||||
"decoder" => {
|
||||
let value = value.to_string();
|
||||
if let Some(decoder) = get_decoder(&value) {
|
||||
tokenizer.with_decoder(decoder);
|
||||
} else {
|
||||
return Err(exceptions::Exception::py_err(format!(
|
||||
"Decoder `{}` not found",
|
||||
value
|
||||
)));
|
||||
}
|
||||
}
|
||||
_ => println!("Ignored unknown kwarg `{}`", option),
|
||||
}
|
||||
@@ -154,6 +174,14 @@ impl Tokenizer {
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn decode(&self, ids: Vec<u32>) -> String {
|
||||
self.tokenizer.decode(ids)
|
||||
}
|
||||
|
||||
fn decode_batch(&self, sentences: Vec<Vec<u32>>) -> Vec<String> {
|
||||
self.tokenizer.decode_batch(sentences)
|
||||
}
|
||||
|
||||
fn token_to_id(&self, token: &str) -> Option<u32> {
|
||||
self.tokenizer.token_to_id(token)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user