Python - Fix cases where str expected instead of AddedToken

Anthony MOI
2020-03-25 19:22:37 -04:00
parent d25eb075c8
commit f8d54edcdd
4 changed files with 33 additions and 19 deletions

View File

@@ -3,6 +3,7 @@ extern crate tokenizers as tk;
 use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
+use pyo3::PyObjectProtocol;

 use super::decoders::Decoder;
 use super::encoding::Encoding;
@@ -45,6 +46,19 @@ impl AddedToken {
         Ok(())
     }
 }

+#[pyproto]
+impl PyObjectProtocol for AddedToken {
+    fn __str__(&'p self) -> PyResult<&'p str> {
+        Ok(&self.token.content)
+    }
+
+    fn __repr__(&self) -> PyResult<String> {
+        Ok(format!(
+            "AddedToken(\"{}\", rstrip={}, lstrip={}, single_word={})",
+            self.token.content, self.token.rstrip, self.token.lstrip, self.token.single_word
+        ))
+    }
+}
 #[pyclass(dict)]
 pub struct Tokenizer {
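
The `#[pyproto]` block above gives the `AddedToken` pyclass a `__str__` that returns the token content and a `__repr__` that spells out its options, so the Python side can always recover a plain string from an `AddedToken`. A minimal sketch of the expected behaviour, assuming `AddedToken` is importable from the package root and that its constructor accepts these keyword arguments:

```python
# Sketch only: the import path and constructor keywords are assumptions, not part of this diff.
from tokenizers import AddedToken

mask = AddedToken("[MASK]", lstrip=True)

print(str(mask))   # "[MASK]" -- __str__ returns the underlying content
print(repr(mask))  # e.g. AddedToken("[MASK]", rstrip=false, lstrip=true, single_word=false)
```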

View File

@@ -28,21 +28,21 @@ class BertWordPieceTokenizer(BaseTokenizer):
     ):
         if vocab_file is not None:
-            tokenizer = Tokenizer(WordPiece.from_files(vocab_file, unk_token=unk_token))
+            tokenizer = Tokenizer(WordPiece.from_files(vocab_file, unk_token=str(unk_token)))
         else:
             tokenizer = Tokenizer(WordPiece.empty())

         # Let the tokenizer know about special tokens if they are part of the vocab
-        if tokenizer.token_to_id(unk_token) is not None:
-            tokenizer.add_special_tokens([unk_token])
-        if tokenizer.token_to_id(sep_token) is not None:
-            tokenizer.add_special_tokens([sep_token])
-        if tokenizer.token_to_id(cls_token) is not None:
-            tokenizer.add_special_tokens([cls_token])
-        if tokenizer.token_to_id(pad_token) is not None:
-            tokenizer.add_special_tokens([pad_token])
-        if tokenizer.token_to_id(mask_token) is not None:
-            tokenizer.add_special_tokens([mask_token])
+        if tokenizer.token_to_id(str(unk_token)) is not None:
+            tokenizer.add_special_tokens([str(unk_token)])
+        if tokenizer.token_to_id(str(sep_token)) is not None:
+            tokenizer.add_special_tokens([str(sep_token)])
+        if tokenizer.token_to_id(str(cls_token)) is not None:
+            tokenizer.add_special_tokens([str(cls_token)])
+        if tokenizer.token_to_id(str(pad_token)) is not None:
+            tokenizer.add_special_tokens([str(pad_token)])
+        if tokenizer.token_to_id(str(mask_token)) is not None:
+            tokenizer.add_special_tokens([str(mask_token)])

         tokenizer.normalizer = BertNormalizer(
             clean_text=clean_text,
@@ -53,15 +53,15 @@ class BertWordPieceTokenizer(BaseTokenizer):
         tokenizer.pre_tokenizer = BertPreTokenizer()

         if add_special_tokens and vocab_file is not None:
-            sep_token_id = tokenizer.token_to_id(sep_token)
+            sep_token_id = tokenizer.token_to_id(str(sep_token))
             if sep_token_id is None:
                 raise TypeError("sep_token not found in the vocabulary")
-            cls_token_id = tokenizer.token_to_id(cls_token)
+            cls_token_id = tokenizer.token_to_id(str(cls_token))
             if cls_token_id is None:
                 raise TypeError("cls_token not found in the vocabulary")

             tokenizer.post_processor = BertProcessing(
-                (sep_token, sep_token_id), (cls_token, cls_token_id)
+                (str(sep_token), sep_token_id), (str(cls_token), cls_token_id)
             )
         tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)
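
With every special token routed through `str(...)`, `BertWordPieceTokenizer` should accept either plain strings or `AddedToken` objects, since `WordPiece.from_files`, `token_to_id`, `add_special_tokens` and `BertProcessing` all receive the coerced string. A hedged usage sketch (the vocab path is a placeholder and the re-export from the package root is assumed):

```python
# Sketch only: the vocab path is a placeholder; import path and keyword names are assumed.
from tokenizers import AddedToken, BertWordPieceTokenizer

tokenizer = BertWordPieceTokenizer(
    "bert-base-uncased-vocab.txt",                 # placeholder local vocab file
    unk_token=AddedToken("[UNK]"),                 # AddedToken instead of a bare string
    mask_token=AddedToken("[MASK]", lstrip=True),
)
# The constructor now calls str(unk_token), str(mask_token), ... before looking the
# tokens up with token_to_id and registering them with add_special_tokens.
```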

View File

@@ -28,15 +28,15 @@ class CharBPETokenizer(BaseTokenizer):
                     vocab_file,
                     merges_file,
                     dropout=dropout,
-                    unk_token=unk_token,
+                    unk_token=str(unk_token),
                     end_of_word_suffix=suffix,
                 )
             )
         else:
             tokenizer = Tokenizer(BPE.empty())

-        if tokenizer.token_to_id(unk_token) is not None:
-            tokenizer.add_special_tokens([unk_token])
+        if tokenizer.token_to_id(str(unk_token)) is not None:
+            tokenizer.add_special_tokens([str(unk_token)])

         # Check for Unicode normalization first (before everything else)
         normalizers = []
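
The same coercion applies to `CharBPETokenizer`: `unk_token` is converted with `str(...)` before it reaches `BPE.from_files` and `token_to_id`, both of which expect a plain string. A short sketch under the same assumptions as above:

```python
# Sketch only: file paths are placeholders.
from tokenizers import AddedToken, CharBPETokenizer

tokenizer = CharBPETokenizer(
    "vocab.json",                    # placeholder vocab file
    "merges.txt",                    # placeholder merges file
    unk_token=AddedToken("<unk>"),   # coerced to "<unk>" before reaching BPE.from_files
)
```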

View File

@@ -28,8 +28,8 @@ class SentencePieceBPETokenizer(BaseTokenizer):
         else:
             tokenizer = Tokenizer(BPE.empty())

-        if tokenizer.token_to_id(unk_token) is not None:
-            tokenizer.add_special_tokens([unk_token])
+        if tokenizer.token_to_id(str(unk_token)) is not None:
+            tokenizer.add_special_tokens([str(unk_token)])

         tokenizer.normalizer = NFKC()
         tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
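
`SentencePieceBPETokenizer` gets the identical guard: the unknown token is looked up and registered via its string form, and it is only added as a special token when it is actually present in the vocabulary. A final end-to-end sketch (paths are placeholders; the `encode` call and `tokens` attribute come from the public `Tokenizer` API, not from this diff):

```python
# Sketch only: paths are placeholders; import path and parameter names are assumed.
from tokenizers import AddedToken, SentencePieceBPETokenizer

tokenizer = SentencePieceBPETokenizer(
    "vocab.json", "merges.txt",
    unk_token=AddedToken("<unk>"),
)
encoding = tokenizer.encode("Hello world")
print(encoding.tokens)
```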