Fix SentencePiece tokenizers conversion
@@ -3,7 +3,7 @@ from tokenizers.implementations import SentencePieceUnigramTokenizer, BaseTokenizer
 from tokenizers.processors import TemplateProcessing
 from tokenizers.models import Unigram, BPE
 from tokenizers import decoders
-from tokenizers import Tokenizer
+from tokenizers import Tokenizer, Regex
 from tokenizers.normalizers import (
     StripAccents,
     NFKD,
@@ -81,7 +81,7 @@ class SpmConverter(Converter):
         elif model_type == 2:
             vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
             tokenizer = Tokenizer(
-                BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True,)
+                BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True)
             )
         else:
             raise Exception(
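For context, a minimal sketch of the kind of Tokenizer the model_type == 2 branch builds, using a hypothetical toy vocab/merges pair in place of the SentencePieceExtractor output:

from tokenizers import Tokenizer
from tokenizers.models import BPE

# Hypothetical toy data standing in for SentencePieceExtractor(...).extract().
vocab = {"<unk>": 0, "l": 1, "o": 2, "w": 3, "lo": 4, "low": 5}
merges = [("l", "o"), ("lo", "w")]

# fuse_unk=True fuses consecutive unknown pieces into a single <unk> token.
tokenizer = Tokenizer(BPE(vocab, merges, unk_token="<unk>", fuse_unk=True))
print(tokenizer.encode("low").tokens)  # expected: ['low']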
@@ -92,7 +92,7 @@ class SpmConverter(Converter):
 
     def normalizer(self, proto):
         precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
-        return Precompiled(precompiled_charsmap)
+        return Sequence([Precompiled(precompiled_charsmap), Replace(Regex(" {2,}"), " ")])
 
     def post_processor(self, tokenizer):
         return None
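The new default normalizer chains the SentencePiece precompiled charsmap with a rule that collapses runs of spaces. A minimal sketch of just the added rule (Precompiled needs a real normalizer_spec charsmap, so it is omitted here):

from tokenizers import Regex
from tokenizers.normalizers import Replace, Sequence

# Collapse any run of two or more spaces down to a single space.
normalizer = Sequence([Replace(Regex(" {2,}"), " ")])
print(normalizer.normalize_str("Hello    world"))  # "Hello world"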
@@ -105,11 +105,8 @@ class SpmConverter(Converter):
 
         replacement = "▁"
         add_prefix_space = True
-        tokenizer.pre_tokenizer = PSequence(
-            [
-                WhitespaceSplit(),
-                Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
-            ]
+        tokenizer.pre_tokenizer = Metaspace(
+            replacement=replacement, add_prefix_space=add_prefix_space
         )
         tokenizer.decoder = decoders.Metaspace(
             replacement=replacement, add_prefix_space=add_prefix_space
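A quick sketch of what the simplified Metaspace pre-tokenizer and its matching decoder do, using the keyword arguments as written in this diff (newer tokenizers releases may spell them differently); the printed output is approximate:

from tokenizers import decoders
from tokenizers.pre_tokenizers import Metaspace

# Replace spaces with "▁" and prepend one to the first word.
pre = Metaspace(replacement="▁", add_prefix_space=True)
print(pre.pre_tokenize_str("Hello world"))
# roughly [('▁Hello', (0, 5)), ('▁world', (5, 11))]

# The decoder undoes the substitution when turning tokens back into text.
dec = decoders.Metaspace(replacement="▁", add_prefix_space=True)
print(dec.decode(["▁Hello", "▁world"]))  # "Hello world"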
@@ -134,7 +131,7 @@ class AlbertConverter(SpmConverter):
         ]
 
     def normalizer(self, proto):
-        normalizers = [Replace("``", '"'), Replace("''", '"')]
+        normalizers = [Replace("``", '"'), Replace("''", '"'), Replace(Regex(" {2,}"), " ")]
         if not self.original_tokenizer.keep_accents:
             normalizers.append(NFKD())
             normalizers.append(StripAccents())
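A sketch of the quote normalization these converters keep, shown separately from the space-collapsing rule added above (precompiled charsmap, lowercasing, and accent handling omitted):

from tokenizers.normalizers import Replace, Sequence

# Map LaTeX-style backtick/apostrophe quotes to straight double quotes.
norm = Sequence([Replace("``", '"'), Replace("''", '"')])
print(norm.normalize_str("``quoted''"))  # '"quoted"'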
@@ -270,7 +267,7 @@ class XLNetConverter(SpmConverter):
         ]
 
     def normalizer(self, proto):
-        normalizers = [Replace("``", '"'), Replace("''", '"')]
+        normalizers = [Replace("``", '"'), Replace("''", '"'), Replace(Regex(" {2,}"), " ")]
         if not self.original_tokenizer.keep_accents:
             normalizers.append(NFKD())
             normalizers.append(StripAccents())
@@ -316,7 +313,7 @@ class PegasusConverter(SpmConverter):
         return TemplateProcessing(
             seq_a=["$0", eos],
             seq_b=["$1", eos],
-            special_tokens=[(eos, tokenizer.get_vocab()[eos]),],
+            special_tokens=[(eos, tokenizer.get_vocab()[eos])],
         )
 
 
@@ -325,7 +322,7 @@ class T5Converter(SpmConverter):
         return TemplateProcessing(
             seq_a=["$0", "</s>"],
             seq_b=["$1", "</s>"],
-            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"]),],
+            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"])],
         )
 
 
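A sketch of the post-processor the Pegasus and T5 converters attach, using the seq_a/seq_b keywords exactly as they appear in this diff (newer tokenizers releases spell these single/pair) and a hypothetical id of 1 for "</s>"; the real converters read the id from tokenizer.get_vocab():

from tokenizers.processors import TemplateProcessing

post_processor = TemplateProcessing(
    seq_a=["$0", "</s>"],          # append </s> after a single sequence
    seq_b=["$1", "</s>"],          # and after the second sequence of a pair
    special_tokens=[("</s>", 1)],  # hypothetical id; normally tokenizer.get_vocab()["</s>"]
)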
@@ -131,12 +131,12 @@ def check_diff(spm_diff, tok_diff, sp, tok):
     if spm_diff == list(reversed(tok_diff)):
         # AAA -> AA+A vs A+AA case.
         return True
-    # elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(
-    #     tok_diff
-    # ):
-    #     # Second order OK
-    #     # Barrich -> Barr + ich vs Bar + rich
-    #     return True
+    elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(
+        tok_diff
+    ):
+        # Second order OK
+        # Barrich -> Barr + ich vs Bar + rich
+        return True
     spm_reencoded = sp.encode(sp.decode(spm_diff))
     tok_reencoded = tok.encode(tok.decode(spm_diff)).ids
     if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
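To illustrate the first early return above with hypothetical token ids: the same surface text "AAA" split as "AA" + "A" by SentencePiece and as "A" + "AA" by the converted tokenizer yields diffs that are reverses of each other, which check_diff accepts as a match:

# Hypothetical ids for the "AAA -> AA+A vs A+AA" case.
spm_diff = [71, 7]   # "AA", "A"
tok_diff = [7, 71]   # "A", "AA"
assert spm_diff == list(reversed(tok_diff))  # the first early return fires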
@@ -265,7 +265,7 @@ def check_encode(args):
         else:
             perfect += 1
 
-        assert ids == encoded.ids, f"line {i}: {line} : {ids} != {encoded.ids}"
+        assert ids == encoded.ids, f"line {i}: {line} : \n\n{ids}\n{encoded.ids}\n{list(zip(encoded.ids, encoded.tokens))}"
 
     print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
     total = perfect + imperfect + wrong