BasicPreTokenizer handles do_basic_tokenize for Bert

Author: Anthony MOI
Date: 2019-12-17 17:35:26 -05:00
parent 3f95248d6d
commit 1b66d87fd3


@@ -58,6 +58,8 @@ fn is_chinese_char(c: char) -> bool {
}

pub struct BasicPreTokenizer {
    /// Whether to do the basic tokenization
    do_basic_tokenize: bool,
    /// Whether to lower case the input.
    do_lower_case: bool,
    /// A list of token not to split.
@@ -68,11 +70,13 @@ pub struct BasicPreTokenizer {
impl BasicPreTokenizer {
    pub fn new(
        do_basic_tokenize: bool,
        do_lower_case: bool,
        never_split: HashSet<String>,
        tokenize_chinese_chars: bool,
    ) -> Self {
        BasicPreTokenizer {
            do_basic_tokenize,
            do_lower_case,
            never_split,
            tokenize_chinese_chars,
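
The constructor now takes the new do_basic_tokenize switch ahead of the existing flags. A minimal usage sketch, assuming BasicPreTokenizer is already in scope (the import path is not shown in this diff):

use std::collections::HashSet;

fn main() {
    // Argument order follows the new signature:
    // (do_basic_tokenize, do_lower_case, never_split, tokenize_chinese_chars)
    let mut never_split = HashSet::new();
    never_split.insert("[UNK]".to_string());
    never_split.insert("[SEP]".to_string());

    // Full BERT-style basic tokenization, with special tokens protected from
    // lowercasing and splitting.
    let pre_tokenizer = BasicPreTokenizer::new(true, true, never_split, true);
    let _ = pre_tokenizer;
}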
@@ -145,27 +149,34 @@ impl BasicPreTokenizer {
impl PreTokenizer for BasicPreTokenizer {
    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
        if !self.do_basic_tokenize {
            Ok(whitespace_tokenize(&s)
                .into_iter()
                .map(|s| s.to_owned())
                .collect())
        } else {
            let mut text = self.clean_text(s);
            // This was added on November 1st, 2018 for the multilingual and Chinese
            // models. This is also applied to the English models now, but it doesn't
            // matter since the English models were not trained on any Chinese data
            // and generally don't have any Chinese data in them (there are Chinese
            // characters in the vocabulary because Wikipedia does have some Chinese
            // words in the English Wikipedia.).
            if self.tokenize_chinese_chars {
                text = self.tokenize_chinese_chars(&text);
            }
            let orig_tokens = whitespace_tokenize(&text);
            let mut split_tokens = vec![];
            for token in orig_tokens {
                let mut tk = token.to_owned();
                if self.do_lower_case && !self.never_split.contains(token) {
                    tk = self.run_strip_accents(&token.to_lowercase())
                }
                split_tokens.extend(self.run_split_on_punc(&tk));
            }
            Ok(split_tokens)
        }
    }
}
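
With the flag wired through, pre_tokenize either falls back to plain whitespace splitting or runs the full basic tokenization. A rough sketch of the expected difference, assuming the helper methods behave like BERT's reference basic tokenizer (the outputs below are illustrative, not taken from this commit):

use std::collections::HashSet;

fn main() {
    // do_basic_tokenize = false: only whitespace splitting; punctuation stays
    // attached and case is preserved.
    let plain = BasicPreTokenizer::new(false, true, HashSet::new(), true);
    assert_eq!(
        plain.pre_tokenize("Hello, world!").unwrap(),
        vec!["Hello,", "world!"]
    );

    // do_basic_tokenize = true: text is cleaned, lowercased, accent-stripped,
    // and split on punctuation.
    let basic = BasicPreTokenizer::new(true, true, HashSet::new(), true);
    assert_eq!(
        basic.pre_tokenize("Hello, world!").unwrap(),
        vec!["hello", ",", "world", "!"]
    );
}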