# mirror of https://github.com/mii443/tokenizers.git
# synced 2025-08-23 00:35:35 +00:00
#
# This allows testing versions not built in-place. Otherwise importing (or
# testing) in the package root fails without develop builds. Replace maturin
# with setuptools_rust, since maturin fails with the proper project structure.
#
# 114 lines | 4.2 KiB | Python
from tokenizers import Tokenizer, AddedToken, decoders, trainers
from tokenizers.models import WordPiece
from tokenizers.normalizers import BertNormalizer
from tokenizers.pre_tokenizers import BertPreTokenizer
from tokenizers.processors import BertProcessing
from .base_tokenizer import BaseTokenizer

from typing import Optional, List, Union


class BertWordPieceTokenizer(BaseTokenizer):
    """Bert WordPiece Tokenizer.

    Builds a ``Tokenizer`` configured like the original BERT pipeline:
    WordPiece model, BertNormalizer, BertPreTokenizer, BertProcessing
    post-processor (adds ``[CLS]``/``[SEP]``), and a WordPiece decoder.
    """

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        unk_token: Union[str, AddedToken] = "[UNK]",
        sep_token: Union[str, AddedToken] = "[SEP]",
        cls_token: Union[str, AddedToken] = "[CLS]",
        pad_token: Union[str, AddedToken] = "[PAD]",
        mask_token: Union[str, AddedToken] = "[MASK]",
        clean_text: bool = True,
        handle_chinese_chars: bool = True,
        strip_accents: Optional[bool] = None,
        lowercase: bool = True,
        wordpieces_prefix: str = "##",
    ):
        """Instantiate the full BERT tokenization pipeline.

        Args:
            vocab_file: Path to an existing WordPiece vocab. When ``None``,
                the model starts empty and must be trained before use.
            unk_token / sep_token / cls_token / pad_token / mask_token:
                Special tokens; registered only if present in the vocab.
            clean_text, handle_chinese_chars, strip_accents, lowercase:
                Forwarded to ``BertNormalizer``.
            wordpieces_prefix: Continuing-subword prefix used by the decoder.

        Raises:
            TypeError: If a vocab is given but ``sep_token`` or ``cls_token``
                is not found in it (the post-processor needs their ids).
        """
        if vocab_file is not None:
            tokenizer = Tokenizer(WordPiece(vocab_file, unk_token=str(unk_token)))
        else:
            tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))

        # Let the tokenizer know about special tokens if they are part of
        # the vocab. One loop instead of five copy-pasted stanzas.
        for token in (unk_token, sep_token, cls_token, pad_token, mask_token):
            if tokenizer.token_to_id(str(token)) is not None:
                tokenizer.add_special_tokens([str(token)])

        tokenizer.normalizer = BertNormalizer(
            clean_text=clean_text,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = BertPreTokenizer()

        if vocab_file is not None:
            # The post-processor needs concrete ids for [SEP]/[CLS]; fail
            # loudly if the supplied vocab doesn't contain them.
            sep_token_id = tokenizer.token_to_id(str(sep_token))
            if sep_token_id is None:
                raise TypeError("sep_token not found in the vocabulary")
            cls_token_id = tokenizer.token_to_id(str(cls_token))
            if cls_token_id is None:
                raise TypeError("cls_token not found in the vocabulary")

            tokenizer.post_processor = BertProcessing(
                (str(sep_token), sep_token_id), (str(cls_token), cls_token_id)
            )
        tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)

        # Recorded so the configuration can be inspected / serialized later.
        parameters = {
            "model": "BertWordPiece",
            "unk_token": unk_token,
            "sep_token": sep_token,
            "cls_token": cls_token,
            "pad_token": pad_token,
            "mask_token": mask_token,
            "clean_text": clean_text,
            "handle_chinese_chars": handle_chinese_chars,
            "strip_accents": strip_accents,
            "lowercase": lowercase,
            "wordpieces_prefix": wordpieces_prefix,
        }

        super().__init__(tokenizer, parameters)

    def train(
        self,
        files: Union[str, List[str]],
        vocab_size: int = 30000,
        min_frequency: int = 2,
        limit_alphabet: int = 1000,
        initial_alphabet: Optional[List[str]] = None,
        special_tokens: Optional[List[Union[str, AddedToken]]] = None,
        show_progress: bool = True,
        wordpieces_prefix: str = "##",
    ):
        """Train the model using the given files.

        Args:
            files: A path or list of paths to plain-text training files.
            vocab_size: Target vocabulary size.
            min_frequency: Minimum pair frequency to produce a merge.
            limit_alphabet: Maximum number of distinct initial characters.
            initial_alphabet: Characters to force into the alphabet.
                Defaults to an empty list.
            special_tokens: Tokens reserved in the vocab. Defaults to the
                standard BERT set ``[PAD], [UNK], [CLS], [SEP], [MASK]``.
            show_progress: Whether to display a progress bar.
            wordpieces_prefix: Continuing-subword prefix for trained pieces.
        """
        # None-sentinels instead of mutable default arguments (the previous
        # list defaults were shared across calls — a classic Python pitfall).
        if initial_alphabet is None:
            initial_alphabet = []
        if special_tokens is None:
            special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]

        trainer = trainers.WordPieceTrainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            limit_alphabet=limit_alphabet,
            initial_alphabet=initial_alphabet,
            special_tokens=special_tokens,
            show_progress=show_progress,
            continuing_subword_prefix=wordpieces_prefix,
        )
        if isinstance(files, str):
            files = [files]
        self._tokenizer.train(trainer, files)