Fix SentencePiece tokenizers conversion

Author: Anthony MOI
Date: 2021-02-03 09:57:41 -05:00
Committed by: Anthony MOI
Parent: fc0a50a272
Commit: 96b9972842
4 changed files with 33 additions and 45 deletions

CHANGELOG.md

@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [Unreleased]
+
+### Fixed
+- [#616]: Fix SentencePiece tokenizers conversion
 ## [0.10.0]
 ### Added
@@ -22,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   `tokenizer.model.dropout = 0.1`)
 - [#538]: The API Reference has been improved and is now up-to-date.
-## Fixed
+### Fixed
 - [#519]: During training, the `Model` is now trained in-place. This fixes several bugs that were
   forcing to reload the `Model` after a training.
 - [#539]: Fix `BaseTokenizer` enable_truncation docstring
@@ -293,6 +298,7 @@ delimiter (Works like `.split(delimiter)`)
 - Fix a bug that was causing crashes in Python 3.5
+[#616]: https://github.com/huggingface/tokenizers/pull/616
 [#590]: https://github.com/huggingface/tokenizers/pull/590
 [#574]: https://github.com/huggingface/tokenizers/pull/574
 [#544]: https://github.com/huggingface/tokenizers/pull/544

SentencePieceUnigramTokenizer implementation (Python bindings)

@@ -1,11 +1,4 @@
-from tokenizers import (
-    Tokenizer,
-    AddedToken,
-    pre_tokenizers,
-    decoders,
-    trainers,
-    normalizers,
-)
+from tokenizers import Tokenizer, AddedToken, pre_tokenizers, decoders, trainers, normalizers, Regex
 import os
 from tokenizers.models import Unigram
 import json
@@ -33,18 +26,10 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         tokenizer = Tokenizer(Unigram())
         tokenizer.normalizer = normalizers.Sequence(
-            [
-                normalizers.Nmt(),
-                normalizers.NFKC(),
-            ]
+            [normalizers.Nmt(), normalizers.NFKC(), normalizers.Replace(Regex(" {2,}"), " ")]
         )
-        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
-            [
-                pre_tokenizers.WhitespaceSplit(),
-                pre_tokenizers.Metaspace(
-                    replacement=replacement, add_prefix_space=add_prefix_space
-                ),
-            ]
-        )
+        tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
+            replacement=replacement, add_prefix_space=add_prefix_space
+        )
         tokenizer.decoder = decoders.Metaspace(
             replacement=replacement, add_prefix_space=add_prefix_space
         )
@@ -124,15 +109,15 @@ class SentencePieceUnigramTokenizer(BaseTokenizer):
         tokenizer = Tokenizer(Unigram(vocab, unk_id))
-        tokenizer.normalizer = normalizers.Precompiled(precompiled_charsmap)
-        tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+        tokenizer.normalizer = normalizers.Sequence(
             [
-                pre_tokenizers.WhitespaceSplit(),
-                pre_tokenizers.Metaspace(
-                    replacement=replacement, add_prefix_space=add_prefix_space
-                ),
+                normalizers.Precompiled(precompiled_charsmap),
+                normalizers.Replace(Regex(" {2,}"), " "),
             ]
         )
+        tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
+            replacement=replacement, add_prefix_space=add_prefix_space
+        )
         tokenizer.decoder = decoders.Metaspace(
             replacement=replacement, add_prefix_space=add_prefix_space
         )
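The net effect of this file's change: runs of two or more spaces are now collapsed by the normalizer (mirroring SentencePiece's own extra-whitespace removal), and the `WhitespaceSplit` pre-tokenizer is dropped so that `Metaspace` alone maps spaces to the "▁" marker. Below is a minimal sketch of the new stack using the public `tokenizers` API; the input string and printed results are illustrative only, and `normalize_str` / `pre_tokenize_str` are the library's inspection helpers, not part of this diff.

from tokenizers import Regex, normalizers, pre_tokenizers

# New normalizer chain: Nmt + NFKC + collapse runs of 2+ spaces into one.
normalizer = normalizers.Sequence(
    [normalizers.Nmt(), normalizers.NFKC(), normalizers.Replace(Regex(" {2,}"), " ")]
)

# New pre-tokenizer: Metaspace alone (no WhitespaceSplit), so spaces become "▁".
pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", add_prefix_space=True)

normalized = normalizer.normalize_str("Hello    world")
print(normalized)                                   # "Hello world"
print(pre_tokenizer.pre_tokenize_str(normalized))   # [("▁Hello", ...), ("▁world", ...)]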

SentencePiece converter script (SpmConverter and model-specific converters)

@@ -3,7 +3,7 @@ from tokenizers.implementations import SentencePieceUnigramTokenizer, BaseTokeni
 from tokenizers.processors import TemplateProcessing
 from tokenizers.models import Unigram, BPE
 from tokenizers import decoders
-from tokenizers import Tokenizer
+from tokenizers import Tokenizer, Regex
 from tokenizers.normalizers import (
     StripAccents,
     NFKD,
@@ -81,7 +81,7 @@ class SpmConverter(Converter):
         elif model_type == 2:
             vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
             tokenizer = Tokenizer(
-                BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True,)
+                BPE(vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True)
             )
         else:
             raise Exception(
@@ -92,7 +92,7 @@ class SpmConverter(Converter):
     def normalizer(self, proto):
         precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
-        return Precompiled(precompiled_charsmap)
+        return Sequence([Precompiled(precompiled_charsmap), Replace(Regex(" {2,}"), " ")])
 
     def post_processor(self, tokenizer):
         return None
@@ -105,11 +105,8 @@ class SpmConverter(Converter):
         replacement = "▁"
         add_prefix_space = True
-        tokenizer.pre_tokenizer = PSequence(
-            [
-                WhitespaceSplit(),
-                Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
-            ]
-        )
+        tokenizer.pre_tokenizer = Metaspace(
+            replacement=replacement, add_prefix_space=add_prefix_space
+        )
         tokenizer.decoder = decoders.Metaspace(
             replacement=replacement, add_prefix_space=add_prefix_space
         )
@@ -134,7 +131,7 @@ class AlbertConverter(SpmConverter):
         ]
 
     def normalizer(self, proto):
-        normalizers = [Replace("``", '"'), Replace("''", '"')]
+        normalizers = [Replace("``", '"'), Replace("''", '"'), Replace(Regex(" {2,}"), " ")]
         if not self.original_tokenizer.keep_accents:
             normalizers.append(NFKD())
             normalizers.append(StripAccents())
@@ -270,7 +267,7 @@ class XLNetConverter(SpmConverter):
         ]
 
    def normalizer(self, proto):
-        normalizers = [Replace("``", '"'), Replace("''", '"')]
+        normalizers = [Replace("``", '"'), Replace("''", '"'), Replace(Regex(" {2,}"), " ")]
         if not self.original_tokenizer.keep_accents:
             normalizers.append(NFKD())
             normalizers.append(StripAccents())
@@ -316,7 +313,7 @@ class PegasusConverter(SpmConverter):
         return TemplateProcessing(
             seq_a=["$0", eos],
             seq_b=["$1", eos],
-            special_tokens=[(eos, tokenizer.get_vocab()[eos]),],
+            special_tokens=[(eos, tokenizer.get_vocab()[eos])],
         )
@@ -325,7 +322,7 @@ class T5Converter(SpmConverter):
         return TemplateProcessing(
             seq_a=["$0", "</s>"],
             seq_b=["$1", "</s>"],
-            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"]),],
+            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"])],
         )
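The recurring change across `SpmConverter`, `AlbertConverter` and `XLNetConverter` is appending `Replace(Regex(" {2,}"), " ")` to the normalizer: the precompiled charsmap covers character-level normalization but not SentencePiece's separate extra-whitespace removal, so repeated spaces were previously left intact. Here is a small sketch of just that added step; the `Precompiled` part needs a real charsmap from a .model file and is left out, and the input strings are made up for illustration.

from tokenizers import Regex
from tokenizers.normalizers import Replace, Sequence

# Step added across the converters: collapse any run of 2+ spaces into one.
collapse_spaces = Replace(Regex(" {2,}"), " ")
print(collapse_spaces.normalize_str("Hello      world"))   # -> "Hello world"

# Albert/XLNet-style chain from this diff: quote rewriting plus the new step.
albert_like = Sequence([Replace("``", '"'), Replace("''", '"'), collapse_spaces])
print(albert_like.normalize_str("``quoted''  text"))       # -> "quoted" text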

SentencePiece parity-check script (check_diff / check_encode)

@@ -131,12 +131,12 @@ def check_diff(spm_diff, tok_diff, sp, tok):
     if spm_diff == list(reversed(tok_diff)):
         # AAA -> AA+A vs A+AA case.
         return True
-    # elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(
-    #     tok_diff
-    # ):
-    #     # Second order OK
-    #     # Barrich -> Barr + ich vs Bar + rich
-    #     return True
+    elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(
+        tok_diff
+    ):
+        # Second order OK
+        # Barrich -> Barr + ich vs Bar + rich
+        return True
     spm_reencoded = sp.encode(sp.decode(spm_diff))
     tok_reencoded = tok.encode(tok.decode(spm_diff)).ids
     if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
@@ -265,7 +265,7 @@ def check_encode(args):
         else:
             perfect += 1
-        assert ids == encoded.ids, f"line {i}: {line} : {ids} != {encoded.ids}"
+        assert ids == encoded.ids, f"line {i}: {line} : \n\n{ids}\n{encoded.ids}\n{list(zip(encoded.ids, encoded.tokens))}"
 
     print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
     total = perfect + imperfect + wrong
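The previously commented-out branch that this hunk re-enables is the "second order" tolerance in check_diff: when SentencePiece and the converted tokenizer split a span into the same number of pieces that decode back to the same text, the disagreement is only about where the split falls (e.g. "Barrich" as Barr + ich vs Bar + rich). A self-contained sketch of that condition follows, with a made-up vocabulary and decode function; the real script calls tok.decode on ids from actual models.

# Toy vocabulary and decoder, hypothetical and for illustration only.
VOCAB = {0: "Barr", 1: "ich", 2: "Bar", 3: "rich"}

def decode(ids):
    # Stand-in for tok.decode(): concatenate the pieces.
    return "".join(VOCAB[i] for i in ids)

def second_order_ok(spm_diff, tok_diff):
    # Same number of ids on both sides, and identical decoded text:
    # the two tokenizers only disagree on where the split happened.
    return len(spm_diff) == len(tok_diff) and decode(spm_diff) == decode(tok_diff)

print(second_order_ok([0, 1], [2, 3]))  # True:  "Barrich" either way
print(second_order_ok([0, 1], [2]))     # False: different number of pieces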