Python - Handle kwargs for bert modules

Anthony MOI
2019-12-13 15:28:29 -05:00
parent 3355be89cd
commit e93cc62a71
3 changed files with 35 additions and 8 deletions


@@ -61,8 +61,8 @@ elif args.type == "bert":
     print("Running Bert tokenizer")
     tok_p = BertTokenizer.from_pretrained('bert-base-uncased')
-    tok_r = Tokenizer(models.WordPiece.from_files(args.vocab))
-    tok_r.with_pre_tokenizer(pre_tokenizers.BasicPreTokenizer.new())
+    tok_r = Tokenizer(models.WordPiece.from_files(args.vocab, unk_token="[UNK]", max_input_chars_per_word=100))
+    tok_r.with_pre_tokenizer(pre_tokenizers.BasicPreTokenizer.new(do_lower_case=True, tokenize_chinese_chars=True, never_split=[]))
     tok_r.with_decoder(decoders.WordPiece.new())
 else:
     raise Exception(f"Unknown type {args.type}")
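Since the Rust bindings shown below keep a default for every option, the keyword arguments spelled out in this example remain optional. A minimal sketch of the equivalent default calls, reusing the script's existing imports and `args`:

    # With every kwarg omitted, the defaults from the bindings apply:
    # unk_token="[UNK]", max_input_chars_per_word=100, lowercasing and
    # Chinese-char handling on, empty never_split
    tok_r = Tokenizer(models.WordPiece.from_files(args.vocab))
    tok_r.with_pre_tokenizer(pre_tokenizers.BasicPreTokenizer.new())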


@@ -4,6 +4,7 @@ use super::utils::Container;
 use pyo3::exceptions;
 use pyo3::prelude::*;
+use pyo3::types::*;
 /// A Model represents some tokenization algorithm like BPE or Word
 /// This class cannot be constructed directly. Please use one of the concrete models.
@@ -71,10 +72,21 @@ impl WordPiece {
     ///
     /// Instantiate a new WordPiece model using the provided vocabulary file
     #[staticmethod]
-    fn from_files(vocab: &str) -> PyResult<Model> {
-        // TODO: Parse kwargs for these
-        let unk_token = String::from("[UNK]");
-        let max_input_chars_per_word = Some(100);
+    #[args(kwargs = "**")]
+    fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult<Model> {
+        let mut unk_token = String::from("[UNK]");
+        let mut max_input_chars_per_word = Some(100);
+        if let Some(kwargs) = kwargs {
+            for (key, val) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "unk_token" => unk_token = val.extract()?,
+                    "max_input_chars_per_word" => max_input_chars_per_word = Some(val.extract()?),
+                    _ => println!("Ignored unknown kwargs option {}", key),
+                }
+            }
+        }
         match tk::models::wordpiece::WordPiece::from_files(
             vocab,

@@ -43,12 +43,27 @@ pub struct BasicPreTokenizer {}
 #[pymethods]
 impl BasicPreTokenizer {
     #[staticmethod]
-    fn new() -> PyResult<PreTokenizer> {
-        // TODO: Parse kwargs for these
+    #[args(kwargs = "**")]
+    fn new(kwargs: Option<&PyDict>) -> PyResult<PreTokenizer> {
         let mut do_lower_case = true;
         let mut never_split = HashSet::new();
         let mut tokenize_chinese_chars = true;
+        if let Some(kwargs) = kwargs {
+            for (key, val) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "do_lower_case" => do_lower_case = val.extract()?,
+                    "tokenize_chinese_chars" => tokenize_chinese_chars = val.extract()?,
+                    "never_split" => {
+                        let values: Vec<String> = val.extract()?;
+                        never_split = values.into_iter().collect();
+                    }
+                    _ => println!("Ignored unknown kwargs option {}", key),
+                }
+            }
+        }
         Ok(PreTokenizer {
             pretok: Container::Owned(Box::new(tk::pre_tokenizers::basic::BasicPreTokenizer::new(
                 do_lower_case,
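The pre-tokenizer follows the same pattern. A short sketch of the Python side, again assuming `from tokenizers import pre_tokenizers` and using illustrative values:

    from tokenizers import pre_tokenizers

    # Keep the original casing and protect "[MASK]" from being split (illustrative values)
    pretok = pre_tokenizers.BasicPreTokenizer.new(
        do_lower_case=False,
        tokenize_chinese_chars=True,
        never_split=["[MASK]"],
    )
    # Calling with no kwargs keeps the defaults above: lowercasing on, empty never_split
    pretok_default = pre_tokenizers.BasicPreTokenizer.new()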