Merge pull request #44 from huggingface/fixes

Fix bug + proposal for renaming + better repr for Tokenizers
MOI Anthony, 2020-01-08 09:45:14 -05:00, committed by GitHub
5 changed files with 53 additions and 13 deletions

View File

@@ -5,8 +5,14 @@ from typing import List, Union, Tuple, Optional
 class BaseTokenizer:
     _tokenizer: Tokenizer

-    def __init__(self, tokenizer: Tokenizer):
+    def __init__(self, tokenizer: Tokenizer, parameters=None):
         self._tokenizer = tokenizer
+        self._parameters = parameters if parameters is not None else {}
+
+    def __repr__(self):
+        return "Tokenizer(vocabulary_size={}, {})".format(
+            self._tokenizer.get_vocab_size(),
+            ', '.join(k + ': ' + str(v) for k, v in self._parameters.items()))

     def with_padding(self,
                      direction: Optional[str] = "right",
@@ -68,7 +74,6 @@ class BaseTokenizer:
         """ Disable truncation """
         return self._tokenizer.without_truncation()
-

     def add_tokens(self, tokens: List[Union[str, Tuple[str, bool]]]) -> int:
         """ Add the given tokens to the vocabulary
@@ -97,7 +102,7 @@ class BaseTokenizer:
         Returns:
             The number of tokens that were added to the vocabulary
         """
-        return self._tokenizer.add_special_tokens(tokens)
+        return self._tokenizer.add_special_tokens(special_tokens)

     def encode(self, sequence: str, pair: Optional[str] = None) -> Encoding:
         """ Encode the given sequence

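To make the new __repr__ concrete, here is a minimal, self-contained sketch of the pattern the base class now follows. DummyTokenizer and SketchBaseTokenizer are hypothetical stand-ins (in the real bindings _tokenizer is the Rust-backed Tokenizer); only get_vocab_size() is mimicked, and the vocabulary size is an arbitrary example value.

class DummyTokenizer:
    # Stand-in for the Rust-backed Tokenizer; only mimics get_vocab_size().
    def get_vocab_size(self):
        return 30522

class SketchBaseTokenizer:
    def __init__(self, tokenizer, parameters=None):
        self._tokenizer = tokenizer
        # Keep the constructor parameters around so __repr__ can display them.
        self._parameters = parameters if parameters is not None else {}

    def __repr__(self):
        return "Tokenizer(vocabulary_size={}, {})".format(
            self._tokenizer.get_vocab_size(),
            ', '.join(k + ': ' + str(v) for k, v in self._parameters.items()))

print(SketchBaseTokenizer(DummyTokenizer(), {"model": "BertWordPiece", "lowercase": True}))
# -> Tokenizer(vocabulary_size=30522, model: BertWordPiece, lowercase: True)
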
View File

@@ -12,6 +12,7 @@ class BertWordPieceTokenizer(BaseTokenizer):

     def __init__(self,
                  vocab_file: Optional[str]=None,
+                 add_special_tokens: bool=True,
                  unk_token: str="[UNK]",
                  sep_token: str="[SEP]",
                  cls_token: str="[CLS]",
@@ -19,7 +20,8 @@ class BertWordPieceTokenizer(BaseTokenizer):
                  handle_chinese_chars: bool=True,
                  strip_accents: bool=True,
                  lowercase: bool=True,
-                 prefix: str="##"):
+                 wordpieces_prefix: str="##"):
+
         if vocab_file is not None:
             tokenizer = Tokenizer(WordPiece.from_files(vocab_file, unk_token=unk_token))
         else:
@@ -38,11 +40,24 @@ class BertWordPieceTokenizer(BaseTokenizer):
         if cls_token_id is None:
             raise TypeError("cls_token not found in the vocabulary")

-        tokenizer.post_processor = BertProcessing.new(
-            (sep_token, sep_token_id),
-            (cls_token, cls_token_id)
-        )
-
-        tokenizer.decoders = decoders.WordPiece.new(prefix=prefix)
+        if add_special_tokens:
+            tokenizer.post_processor = BertProcessing.new(
+                (sep_token, sep_token_id),
+                (cls_token, cls_token_id)
+            )
+        tokenizer.decoders = decoders.WordPiece.new(prefix=wordpieces_prefix)

-        super().__init__(tokenizer)
+        parameters = {
+            "model": "BertWordPiece",
+            "add_special_tokens": add_special_tokens,
+            "unk_token": unk_token,
+            "sep_token": sep_token,
+            "cls_token": cls_token,
+            "clean_text": clean_text,
+            "handle_chinese_chars": handle_chinese_chars,
+            "strip_accents": strip_accents,
+            "lowercase": lowercase,
+            "wordpieces_prefix": wordpieces_prefix,
+        }
+
+        super().__init__(tokenizer, parameters)

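With the BertWordPieceTokenizer changes, constructing a tokenizer and printing it should now report the model and its options. A hedged usage sketch: the top-level import path is assumed (it matches later releases of the bindings and may differ at this commit), "bert-base-uncased-vocab.txt" is a placeholder for a local vocab file, and the printed values depend on that file and on the arguments. Note that the keyword argument is now wordpieces_prefix rather than prefix.

from tokenizers import BertWordPieceTokenizer  # import path assumed

tokenizer = BertWordPieceTokenizer(
    "bert-base-uncased-vocab.txt",  # placeholder path to a BERT vocab file
    lowercase=True,
    wordpieces_prefix="##",         # renamed from `prefix` in this PR
)
print(tokenizer)
# Roughly (a single line in practice, wrapped here for readability):
# Tokenizer(vocabulary_size=30522, model: BertWordPiece, add_special_tokens: True,
#           unk_token: [UNK], sep_token: [SEP], cls_token: [CLS], clean_text: True,
#           handle_chinese_chars: True, strip_accents: True, lowercase: True,
#           wordpieces_prefix: ##)
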
View File

@@ -33,4 +33,11 @@ class BPETokenizer(BaseTokenizer):
         tokenizer.pre_tokenizer = pre_tokenizers.Whitespace.new()
         tokenizer.decoder = decoders.BPEDecoder.new(suffix=suffix)

-        super().__init__(tokenizer)
+        parameters = {
+            "model": "BPE",
+            "unk_token": unk_token,
+            "suffix": suffix,
+            "dropout": dropout,
+        }
+
+        super().__init__(tokenizer, parameters)

View File

@@ -24,4 +24,9 @@ class ByteLevelBPETokenizer(BaseTokenizer):
         tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel.new(add_prefix_space=add_prefix_space)
         tokenizer.decoder = decoders.ByteLevel.new()

-        super().__init__(tokenizer)
+        parameters = {
+            "model": "ByteLevelBPE",
+            "add_prefix_space": add_prefix_space,
+        }
+
+        super().__init__(tokenizer, parameters)

View File

@@ -32,4 +32,12 @@ class SentencePieceBPETokenizer(BaseTokenizer):
         tokenizer.decoder = decoders.Metaspace.new(replacement=replacement,
                                                    add_prefix_space=add_prefix_space)

-        super().__init__(tokenizer)
+        parameters = {
+            "model": "SentencePieceBPE",
+            "unk_token": unk_token,
+            "replacement": replacement,
+            "add_prefix_space": add_prefix_space,
+            "dropout": dropout,
+        }
+
+        super().__init__(tokenizer, parameters)
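The BPETokenizer, ByteLevelBPETokenizer, and SentencePieceBPETokenizer changes all follow the same shape: gather the constructor arguments into a parameters dict and forward it to BaseTokenizer.__init__, so their repr reports the model name and options as well. A hedged sketch, assuming the same top-level import path and that constructing without vocab/merges files falls back to an empty model, as the other implementations suggest:

from tokenizers import ByteLevelBPETokenizer  # import path assumed

tokenizer = ByteLevelBPETokenizer(add_prefix_space=True)  # no vocab/merges: assumed empty model
print(tokenizer)
# Roughly: Tokenizer(vocabulary_size=0, model: ByteLevelBPE, add_prefix_space: True)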