Python - Black auto formatting

This commit is contained in:
Anthony MOI
2020-02-18 10:45:36 -05:00
parent 4706151c32
commit 81be207819
16 changed files with 179 additions and 211 deletions

View File

@ -7,28 +7,33 @@ parser.add_argument("--vocab", default=None, type=str, required=True, help="The
parser.add_argument("--merges", default=None, type=str, required=True, help="The merges.txt file")
args = parser.parse_args()
class GoodCustom:
"""GoodCustom
This class represents a good custom PreTokenizer that will be called
by `tokenizers` when needed
"""
def pre_tokenize(self, sentence):
return sentence.split(" ")
def decode(self, tokens):
return ", ".join(tokens)
class BadCustom:
"""Bad Pretok
This class represents a bad custom PreTokenizer that will trigger an exception
when called by `tokenizers`
"""
def pre_tokenize(self, sentence):
return None
def decode(self, tokens):
return None
def tokenize(sentence):
output = tokenizer.encode(sentence).tokens
print(f"`{sentence}` tokenized to {output}")
@ -66,4 +71,3 @@ try:
encoding = tokenizer.encode("Hey friend!")
except:
print("Bad tokenizer didn't work")

View File

@ -3,8 +3,9 @@ import argparse
from tqdm import tqdm
import logging
logging.getLogger('transformers').disabled = True
logging.getLogger('transformers.tokenization_utils').disabled = True
logging.getLogger("transformers").disabled = True
logging.getLogger("transformers.tokenization_utils").disabled = True
from tokenizers import Tokenizer, pre_tokenizers, decoders
from tokenizers.models import BPE, WordPiece
@ -18,7 +19,7 @@ parser.add_argument("--type", default="gpt2", type=str, help="The type of tokeni
parser.add_argument("--file", default=None, type=str, help="The file to encode")
parser.add_argument("--vocab", default=None, type=str, required=True, help="The vocab file")
parser.add_argument("--merges", default=None, type=str, help="The merges.txt file")
parser.add_argument("--debug", action='store_true', help="Verbose output")
parser.add_argument("--debug", action="store_true", help="Verbose output")
args = parser.parse_args()
if args.type == "gpt2" and args.merges is None:
@ -26,7 +27,7 @@ if args.type == "gpt2" and args.merges is None:
if args.file is not None:
with open(args.file, "r") as fp:
text = [ line.strip() for line in fp ]
text = [line.strip() for line in fp]
else:
text = """
The Zen of Python, by Tim Peters
@ -49,11 +50,13 @@ Although never is often better than *right* now.
If the implementation is hard to explain, it's a bad idea.
If the implementation is easy to explain, it may be a good idea.
Namespaces are one honking great idea -- let's do more of those!
""".split("\n")
""".split(
"\n"
)
if args.type == "gpt2":
print("Running GPT-2 tokenizer")
tok_p = GPT2Tokenizer.from_pretrained('gpt2')
tok_p = GPT2Tokenizer.from_pretrained("gpt2")
# Create a Tokenizer using BPE
tok_r = Tokenizer(BPE.from_files(args.vocab, args.merges))
@ -65,33 +68,30 @@ elif args.type == "bert":
print("Running Bert tokenizer")
tok_p = BertTokenizer.from_pretrained(args.vocab)
tok_r = Tokenizer(WordPiece.from_files(
args.vocab,
unk_token="[UNK]",
max_input_chars_per_word=100)
tok_r = Tokenizer(
WordPiece.from_files(args.vocab, unk_token="[UNK]", max_input_chars_per_word=100)
)
tok_r.normalizer = BertNormalizer(
clean_text=True,
handle_chinese_chars=True,
strip_accents=True,
lowercase=True,
clean_text=True, handle_chinese_chars=True, strip_accents=True, lowercase=True,
)
# tok_r.pre_tokenizer = pre_tokenizers.Whitespace()
tok_r.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
tok_r.decoder = decoders.WordPiece()
tok_r.post_processor = BertProcessing(
("[SEP]", tok_r.token_to_id("[SEP]")),
("[CLS]", tok_r.token_to_id("[CLS]")),
("[SEP]", tok_r.token_to_id("[SEP]")), ("[CLS]", tok_r.token_to_id("[CLS]")),
)
else:
raise Exception(f"Unknown type {args.type}")
def tokenize_r():
return tok_r.encode_batch(text);
return tok_r.encode_batch(text)
def tokenize_p():
return [tok_p.encode(sentence, add_special_tokens=True) for sentence in tqdm(text)]
print(f"Tokenizing {len(text)} lines")
# Rust version
@ -110,7 +110,7 @@ print(f"Transformer tokenizer took: {time_p} sec")
print(f"SpeedUp Ratio: {time_p / time_r}")
ids_r = [ sentence.ids for sentence in encoded_r ]
ids_r = [sentence.ids for sentence in encoded_r]
diff_ids = 0
for i in range(0, len(encoded_r)):
if encoded_r[i].ids != encoded_p[i]:
@ -124,8 +124,8 @@ for i in range(0, len(encoded_r)):
print("")
print(f"Ids differences: {diff_ids}")
decoded_r = tok_r.decode_batch([ sentence.ids for sentence in encoded_r ], False)
decoded_p = [ tok_p.decode(en) for en in encoded_p ]
decoded_r = tok_r.decode_batch([sentence.ids for sentence in encoded_r], False)
decoded_p = [tok_p.decode(en) for en in encoded_p]
diff_decoded = 0
for i in range(0, len(text)):
if decoded_r[i] != decoded_p[i]:

View File

@ -4,21 +4,24 @@ import glob
from tokenizers import BertWordPieceTokenizer
parser = argparse.ArgumentParser()
parser.add_argument("--files",
default=None,
metavar="path",
type=str,
required=True,
help="The files to use as training; accept '**/*.txt' type of patterns \
if enclosed in quotes")
parser.add_argument("--out",
default="./",
type=str,
help="Path to the output directory, where the files will be saved")
parser.add_argument("--name",
default="bert-wordpiece",
type=str,
help="The name of the output vocab files")
parser.add_argument(
"--files",
default=None,
metavar="path",
type=str,
required=True,
help="The files to use as training; accept '**/*.txt' type of patterns \
if enclosed in quotes",
)
parser.add_argument(
"--out",
default="./",
type=str,
help="Path to the output directory, where the files will be saved",
)
parser.add_argument(
"--name", default="bert-wordpiece", type=str, help="The name of the output vocab files"
)
args = parser.parse_args()
files = glob.glob(args.files)
@ -29,11 +32,7 @@ if not files:
# Initialize an empty tokenizer
tokenizer = BertWordPieceTokenizer(
clean_text=True,
handle_chinese_chars=True,
strip_accents=True,
lowercase=True,
clean_text=True, handle_chinese_chars=True, strip_accents=True, lowercase=True,
)
# And then train
@ -44,7 +43,7 @@ trainer = tokenizer.train(
show_progress=True,
special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
limit_alphabet=1000,
wordpieces_prefix="##"
wordpieces_prefix="##",
)
# Save the files

View File

@ -5,21 +5,24 @@ from os.path import join
from tokenizers import ByteLevelBPETokenizer
parser = argparse.ArgumentParser()
parser.add_argument("--files",
default=None,
metavar="path",
type=str,
required=True,
help="The files to use as training; accept '**/*.txt' type of patterns \
if enclosed in quotes")
parser.add_argument("--out",
default="./",
type=str,
help="Path to the output directory, where the files will be saved")
parser.add_argument("--name",
default="bpe-bytelevel",
type=str,
help="The name of the output vocab files")
parser.add_argument(
"--files",
default=None,
metavar="path",
type=str,
required=True,
help="The files to use as training; accept '**/*.txt' type of patterns \
if enclosed in quotes",
)
parser.add_argument(
"--out",
default="./",
type=str,
help="Path to the output directory, where the files will be saved",
)
parser.add_argument(
"--name", default="bpe-bytelevel", type=str, help="The name of the output vocab files"
)
args = parser.parse_args()
files = glob.glob(args.files)
@ -47,7 +50,7 @@ tokenizer.save(args.out, args.name)
tokenizer = ByteLevelBPETokenizer(
join(args.out, "{}-vocab.json".format(args.name)),
join(args.out, "{}-merges.txt".format(args.name)),
add_prefix_space=True
add_prefix_space=True,
)
# Test encoding

View File

@ -11,5 +11,5 @@ from .implementations import (
ByteLevelBPETokenizer,
CharBPETokenizer,
SentencePieceBPETokenizer,
BertWordPieceTokenizer
BertWordPieceTokenizer,
)

View File

@ -9,7 +9,7 @@ from .implementations import (
ByteLevelBPETokenizer as ByteLevelBPETokenizer,
BPETokenizer as BPETokenizer,
SentencePieceBPETokenizer as SentencePieceBPETokenizer,
BertWordPieceTokenizer as BertWordPieceTokenizer
BertWordPieceTokenizer as BertWordPieceTokenizer,
)
from typing import Optional, Union, List, Tuple
@ -38,27 +38,22 @@ class Encoding:
def normalized_str(self) -> IndexableString:
""" The normalized string """
pass
@property
def original_str(self) -> IndexableString:
""" The original string """
pass
@property
def ids(self) -> List[int]:
""" The tokenized ids """
pass
@property
def tokens(self) -> List[str]:
""" The tokenized strings """
pass
@property
def type_ids(self) -> List[int]:
""" The type ids """
pass
@property
def offsets(self) -> List[Offsets]:
""" The offsets.
@ -67,28 +62,26 @@ class Encoding:
method on the `original_str`.
"""
pass
@property
def special_tokens_mask(self) -> List[int]:
""" The special tokens mask """
pass
@property
def attention_mask(self) -> List[int]:
""" The attention mask """
pass
@property
def overflowing(self) -> Optional[Encoding]:
""" The overflowing encoding, after truncation """
pass
def pad(self,
length: int,
pad_id: Optional[int] = 0,
pad_type_id: Optional[int] = 0,
pad_token: Optional[str] = "[PAD]",
direction: Optional[str] = "right"):
def pad(
self,
length: int,
pad_id: Optional[int] = 0,
pad_type_id: Optional[int] = 0,
pad_token: Optional[str] = "[PAD]",
direction: Optional[str] = "right",
):
""" Pad the current Encoding at the given length
Args:
@ -108,7 +101,6 @@ class Encoding:
The pad token to be used when padding
"""
pass
def truncate(self, max_length: int, stride: Optional[int] = 0):
""" Truncate the current Encoding at the given max_length
@ -122,7 +114,6 @@ class Encoding:
"""
pass
class Tokenizer:
""" Tokenizer
@ -151,56 +142,44 @@ class Tokenizer:
Tokenizer
"""
pass
@property
def model(self) -> Model:
""" Get the model in use with this Tokenizer """
pass
@model.setter
def model(self, model: models.Model):
""" Change the model to use with this Tokenizer """
pass
@property
def pre_tokenizer(self) -> Optional[PreTokenizer]:
""" Get the pre-tokenizer in use with this model """
pass
@pre_tokenizer.setter
def pre_tokenizer(self, pre_tokenizer: pre_tokenizers.PreTokenizer):
""" Change the pre tokenizer to use with this Tokenizer """
pass
@property
def decoder(self) -> Optional[Decoder]:
""" Get the decoder in use with this model """
pass
@decoder.setter
def decoder(self, decoder: decoders.Decoder):
""" Change the decoder to use with this Tokenizer """
pass
@property
def post_processor(self) -> Optional[PostProcessor]:
""" Get the post-processor in use with this Tokenizer """
pass
@post_processor.setter
def post_processor(self, processor: processors.PostProcessor):
""" Change the post processor to use with this Tokenizer """
@property
def normalizer(self) -> Optional[Normalizer]:
""" Get the normalizer in use with this Tokenizer """
pass
@normalizer.setter
def normalizer(self, normalizer: normalizers.Normalizer):
""" Change the normalizer to use with this Tokenizer """
def num_special_tokens_to_add(self, is_pair: bool) -> int:
"""
Return the number of special tokens that would be added for single/pair sentences.
@ -208,8 +187,6 @@ class Tokenizer:
:return:
"""
pass
def get_vocab_size(self, with_added_tokens: Optional[bool]) -> int:
""" Returns the size of the vocabulary
@ -218,11 +195,7 @@ class Tokenizer:
Whether to include the added tokens in the vocabulary's size
"""
pass
def enable_truncation(self,
max_length: int,
stride: Optional[int],
strategy: Optional[str]):
def enable_truncation(self, max_length: int, stride: Optional[int], strategy: Optional[str]):
""" Enable the truncation
Args:
@ -237,17 +210,17 @@ class Tokenizer:
Can be one of `longest_first`, `only_first` or `only_second`
"""
pass
def no_truncation(self):
""" Disable truncation """
pass
def enable_padding(self,
direction: Optional[str] = "right",
pad_id: Optional[int] = 0,
pad_type_id: Optional[int] = 0,
pad_token: Optional[str] = "[PAD]",
max_length: Optional[int] = None):
def enable_padding(
self,
direction: Optional[str] = "right",
pad_id: Optional[int] = 0,
pad_type_id: Optional[int] = 0,
pad_token: Optional[str] = "[PAD]",
max_length: Optional[int] = None,
):
""" Enable the padding
Args:
@ -268,11 +241,9 @@ class Tokenizer:
we pad using the size of the longest sequence in a batch
"""
pass
def no_padding(self):
""" Disable padding """
pass
def encode(self, sequence: str, pair: Optional[str] = None) -> Encoding:
""" Encode the given sequence
@ -287,7 +258,6 @@ class Tokenizer:
An Encoding
"""
pass
def encode_batch(self, sequences: List[Union[str, Tuple[str, str]]]) -> List[Encoding]:
""" Encode the given sequences or pair of sequences
@ -300,7 +270,6 @@ class Tokenizer:
A list of Encoding
"""
pass
def decode(self, ids: List[int], skip_special_tokens: Optional[bool] = True) -> str:
""" Decode the given list of ids to a string sequence
@ -315,10 +284,9 @@ class Tokenizer:
The decoded string
"""
pass
def decode_batch(self,
sequences: List[List[int]],
skip_special_tokens: Optional[bool] = True) -> str:
def decode_batch(
self, sequences: List[List[int]], skip_special_tokens: Optional[bool] = True
) -> str:
""" Decode the list of sequences to a list of string sequences
Args:
@ -332,7 +300,6 @@ class Tokenizer:
A list of decoded strings
"""
pass
def token_to_id(self, token: str) -> Optional[int]:
""" Convert the given token to its corresponding id
@ -344,7 +311,6 @@ class Tokenizer:
The corresponding id if it exists, None otherwise
"""
pass
def id_to_token(self, id: int) -> Optional[str]:
""" Convert the given token id to its corresponding string
@ -356,7 +322,6 @@ class Tokenizer:
The corresponding string if it exists, None otherwise
"""
pass
def add_tokens(self, tokens: List[Union[str, Tuple[str, bool]]]) -> int:
""" Add the given tokens to the vocabulary
@ -371,7 +336,6 @@ class Tokenizer:
The number of tokens that were added to the vocabulary
"""
pass
def add_special_tokens(self, tokens: List[str]) -> int:
""" Add the given special tokens to the vocabulary, and treat them as special tokens.

View File

@ -2,8 +2,8 @@ from .. import Tokenizer, Encoding
from typing import List, Union, Tuple, Optional
class BaseTokenizer:
class BaseTokenizer:
def __init__(self, tokenizer: Tokenizer, parameters=None):
self._tokenizer = tokenizer
self._parameters = parameters if parameters is not None else {}
@ -11,7 +11,8 @@ class BaseTokenizer:
def __repr__(self):
return "Tokenizer(vocabulary_size={}, {})".format(
self._tokenizer.get_vocab_size(),
', '.join(k + '=' + str(v) for k, v in self._parameters.items()))
", ".join(k + "=" + str(v) for k, v in self._parameters.items()),
)
def num_special_tokens_to_add(self, is_pair: bool) -> int:
"""
@ -33,12 +34,14 @@ class BaseTokenizer:
"""
return self._tokenizer.get_vocab_size(with_added_tokens=with_added_tokens)
def enable_padding(self,
direction: Optional[str] = "right",
pad_id: Optional[int] = 0,
pad_type_id: Optional[int] = 0,
pad_token: Optional[str] = "[PAD]",
max_length: Optional[int] = None):
def enable_padding(
self,
direction: Optional[str] = "right",
pad_id: Optional[int] = 0,
pad_type_id: Optional[int] = 0,
pad_token: Optional[str] = "[PAD]",
max_length: Optional[int] = None,
):
""" Change the padding strategy
Args:
@ -58,20 +61,21 @@ class BaseTokenizer:
If specified, the length at which to pad. If not specified
we pad using the size of the longest sequence in a batch
"""
return self._tokenizer.enable_padding(direction=direction,
pad_id=pad_id,
pad_type_id=pad_type_id,
pad_token=pad_token,
max_length=max_length)
return self._tokenizer.enable_padding(
direction=direction,
pad_id=pad_id,
pad_type_id=pad_type_id,
pad_token=pad_token,
max_length=max_length,
)
def no_padding(self):
""" Disable padding """
return self._tokenizer.no_padding()
def enable_truncation(self,
max_length: int,
stride: Optional[int]=0,
strategy: Optional[str]='longest_first'):
def enable_truncation(
self, max_length: int, stride: Optional[int] = 0, strategy: Optional[str] = "longest_first"
):
""" Change the truncation options
Args:
@ -85,9 +89,7 @@ class BaseTokenizer:
strategy: (`optional) str:
Can be one of `longest_first`, `only_first` or `only_second`
"""
return self._tokenizer.enable_truncation(max_length,
stride=stride,
strategy=strategy)
return self._tokenizer.enable_truncation(max_length, stride=stride, strategy=strategy)
def no_truncation(self):
""" Disable truncation """
@ -166,9 +168,9 @@ class BaseTokenizer:
"""
return self._tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)
def decode_batch(self,
sequences: List[List[int]],
skip_special_tokens: Optional[bool] = True) -> str:
def decode_batch(
self, sequences: List[List[int]], skip_special_tokens: Optional[bool] = True
) -> str:
""" Decode the list of sequences to a list of string sequences
Args:

View File

@ -5,29 +5,34 @@ from .base_tokenizer import BaseTokenizer
from typing import Optional, List, Union
class ByteLevelBPETokenizer(BaseTokenizer):
""" ByteLevelBPETokenizer
Represents a Byte-level BPE as introduced by OpenAI with their GPT-2 model
"""
def __init__(self,
vocab_file: Optional[str]=None,
merges_file: Optional[str]=None,
add_prefix_space: bool=False,
lowercase: bool=False,
dropout: Optional[float]=None,
unicode_normalizer: Optional[str]=None,
continuing_subword_prefix: Optional[str]=None,
end_of_word_suffix: Optional[str]=None
):
def __init__(
self,
vocab_file: Optional[str] = None,
merges_file: Optional[str] = None,
add_prefix_space: bool = False,
lowercase: bool = False,
dropout: Optional[float] = None,
unicode_normalizer: Optional[str] = None,
continuing_subword_prefix: Optional[str] = None,
end_of_word_suffix: Optional[str] = None,
):
if vocab_file is not None and merges_file is not None:
tokenizer = Tokenizer(BPE.from_files(
vocab_file, merges_file,
dropout=dropout,
continuing_subword_prefix=continuing_subword_prefix or "",
end_of_word_suffix=end_of_word_suffix or "",
))
tokenizer = Tokenizer(
BPE.from_files(
vocab_file,
merges_file,
dropout=dropout,
continuing_subword_prefix=continuing_subword_prefix or "",
end_of_word_suffix=end_of_word_suffix or "",
)
)
else:
tokenizer = Tokenizer(BPE.empty())
@ -47,9 +52,7 @@ class ByteLevelBPETokenizer(BaseTokenizer):
else:
tokenizer.normalizer = normalizers[0]
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(
add_prefix_space=add_prefix_space
)
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
tokenizer.decoder = decoders.ByteLevel()
parameters = {

View File

@ -12,14 +12,16 @@ class CharBPETokenizer(BaseTokenizer):
Represents the BPE algorithm, as introduced by Rico Sennrich (https://arxiv.org/abs/1508.07909)
"""
def __init__(self,
vocab_file: Optional[str]=None,
merges_file: Optional[str]=None,
unk_token: Optional[str]="<unk>",
suffix: Optional[str]="</w>",
dropout: Optional[float]=None,
lowercase: bool = False,
unicode_normalizer: Optional[str] = None):
def __init__(
self,
vocab_file: Optional[str] = None,
merges_file: Optional[str] = None,
unk_token: Optional[str] = "<unk>",
suffix: Optional[str] = "</w>",
dropout: Optional[float] = None,
lowercase: bool = False,
unicode_normalizer: Optional[str] = None,
):
if vocab_file is not None and merges_file is not None:
tokenizer = Tokenizer(
BPE.from_files(
@ -27,13 +29,13 @@ class CharBPETokenizer(BaseTokenizer):
merges_file,
dropout=dropout,
unk_token=unk_token,
end_of_word_suffix=suffix
end_of_word_suffix=suffix,
)
)
else:
tokenizer = Tokenizer(BPE.empty())
tokenizer.add_special_tokens([ unk_token ])
tokenizer.add_special_tokens([unk_token])
# Check for Unicode normalization first (before everything else)
normalizers = []

View File

@ -5,29 +5,30 @@ from .base_tokenizer import BaseTokenizer
from typing import Optional, List, Union
class SentencePieceBPETokenizer(BaseTokenizer):
""" SentencePiece BPE Tokenizer
Represents the BPE algorithm, with the pretokenization used by SentencePiece
"""
def __init__(self,
vocab_file: Optional[str]=None,
merges_file: Optional[str]=None,
unk_token: str="<unk>",
replacement: str="",
add_prefix_space: bool=True,
dropout: Optional[float]=None):
def __init__(
self,
vocab_file: Optional[str] = None,
merges_file: Optional[str] = None,
unk_token: str = "<unk>",
replacement: str = "",
add_prefix_space: bool = True,
dropout: Optional[float] = None,
):
if vocab_file is not None and merges_file is not None:
tokenizer = Tokenizer(BPE.from_files(vocab_file,
merges_file,
dropout=dropout,
unk_token=unk_token))
tokenizer = Tokenizer(
BPE.from_files(vocab_file, merges_file, dropout=dropout, unk_token=unk_token)
)
else:
tokenizer = Tokenizer(BPE.empty())
tokenizer.add_special_tokens([ unk_token ])
tokenizer.add_special_tokens([unk_token])
tokenizer.normalizer = NFKC()
tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(

View File

@ -3,4 +3,4 @@ from .. import models
Model = models.Model
BPE = models.BPE
WordPiece = models.WordPiece
WordLevel = models.WordLevel
WordLevel = models.WordLevel

View File

@ -16,18 +16,19 @@ class Model:
"""
pass
class BPE(Model):
""" BytePairEncoding model class """
@staticmethod
def from_files(vocab: str,
merges: str,
cache_capacity: Optional[int],
dropout: Optional[float],
unk_token: Optional[str],
continuing_subword_prefix: Optional[str],
end_of_word_suffix: Optional[str]) -> Model:
def from_files(
vocab: str,
merges: str,
cache_capacity: Optional[int],
dropout: Optional[float],
unk_token: Optional[str],
continuing_subword_prefix: Optional[str],
end_of_word_suffix: Optional[str],
) -> Model:
""" Instantiate a BPE Model from the given vocab and merges files.
Args:
@ -55,20 +56,18 @@ class BPE(Model):
The suffix to attach to subword units that represent an end of word.
"""
pass
@staticmethod
def empty() -> Model:
""" Instantiate an empty BPE Model. """
pass
class WordPiece(Model):
""" WordPiece model class """
@staticmethod
def from_files(vocab: str,
unk_token: Optional[str],
max_input_chars_per_word: Optional[int]) -> Model:
def from_files(
vocab: str, unk_token: Optional[str], max_input_chars_per_word: Optional[int]
) -> Model:
""" Instantiate a WordPiece Model from the given vocab file.
Args:
@ -82,13 +81,11 @@ class WordPiece(Model):
The maximum number of characters to authorize in a single word.
"""
pass
@staticmethod
def empty() -> Model:
""" Instantiate an empty WordPiece Model. """
pass
class WordLevel(Model):
"""
Most simple tokenizer model based on mapping token from a vocab file to their corresponding id.
@ -105,4 +102,4 @@ class WordLevel(Model):
unk_token: str:
The unknown token to be used by the model.
"""
pass
pass

View File

@ -11,19 +11,15 @@ Lowercase = normalizers.Lowercase
Strip = normalizers.Strip
NORMALIZERS = {
"nfc": NFC,
"nfd": NFD,
"nfkc": NFKC,
"nfkd": NFKD
}
NORMALIZERS = {"nfc": NFC, "nfd": NFD, "nfkc": NFKC, "nfkd": NFKD}
def unicode_normalizer_from_str(normalizer: str) -> Normalizer:
if normalizer not in NORMALIZERS:
raise ValueError(
"{} is not a known unicode normalizer. Available are {}"
.format(normalizer, NORMALIZERS.keys())
"{} is not a known unicode normalizer. Available are {}".format(
normalizer, NORMALIZERS.keys()
)
)
return NORMALIZERS[normalizer]()
return NORMALIZERS[normalizer]()

View File

@ -98,7 +98,6 @@ class Strip(Normalizer):
def __init__(self, left: bool = True, right: bool = True) -> Normalizer:
pass
def unicode_normalizer_from_str(normalizer: str) -> Normalizer:
"""
Instanciate unicode normalizer from the normalizer name

View File

@ -33,7 +33,6 @@ class ByteLevel(PreTokenizer):
PreTokenizer
"""
pass
@staticmethod
def alphabet() -> List[str]:
""" Returns the alphabet used by this PreTokenizer.
@ -96,7 +95,6 @@ class Metaspace(PreTokenizer):
"""
pass
class CharDelimiterSplit(PreTokenizer):
""" CharDelimiterSplit PreTokenizer

View File

@ -2,4 +2,4 @@ from .. import processors
PostProcessor = processors.PostProcessor
BertProcessing = processors.BertProcessing
RobertaProcessing = processors.RobertaProcessing
RobertaProcessing = processors.RobertaProcessing