Adding a new convert script, that will convert all python Tokenizer code

into a proper Rust Tokenizer format and check it on a file.

- Also fuse_unks by default in `tokenizers`'s BPE.
This commit is contained in:
Nicolas Patry
2020-09-17 10:12:00 +02:00
parent c84f1d05c0
commit 2fd1d9cf06
4 changed files with 521 additions and 15 deletions

View File

@ -0,0 +1,442 @@
import transformers
from tokenizers.implementations import SentencePieceUnigramTokenizer, BaseTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.models import Unigram, BPE
from tokenizers import decoders
from tokenizers import Tokenizer
from tokenizers.normalizers import (
NFKD,
Lowercase,
Sequence,
BertNormalizer,
Precompiled,
)
from tokenizers.pre_tokenizers import (
Digits,
WhitespaceSplit,
Metaspace,
Sequence as PSequence,
)
import json
import unicodedata
import sys
import os
import datetime
sys.path.append(".")
from spm_parity_check import check_details
from sentencepiece_extractor import SentencePieceExtractor
def check_number_comma(piece: str) -> bool:
return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
def get_proto(filename: str):
try:
import sys
sys.path.append(".")
import sentencepiece_model_pb2 as model
except Exception:
raise Exception(
"You don't seem to have the required protobuf file, in order to use this function you need to run `pip install protobuf` and `wget https://raw.githubusercontent.com/google/sentencepiece/master/python/sentencepiece_model_pb2.py` for us to be able to read the intrinsics of your spm_file. `pip install sentencepiece` is not required."
)
m = model.ModelProto()
m.ParseFromString(open(filename, "rb").read())
return m
class Converter:
def __init__(self, original_tokenizer):
self.original_tokenizer = original_tokenizer
def converted(self) -> Tokenizer:
raise NotImplementedError()
class SpmConverter(Converter):
def __init__(self, *args):
super().__init__(*args)
self.proto = get_proto(self.original_tokenizer.vocab_file)
def vocab(self, proto):
return [(piece.piece, piece.score) for piece in proto.pieces]
def unk_id(self, proto):
return proto.trainer_spec.unk_id
def tokenizer(self, proto):
model_type = proto.trainer_spec.model_type
vocab = self.vocab(proto)
unk_id = self.unk_id(proto)
filename = self.original_tokenizer.vocab_file
if model_type == 1:
data = {"unk_id": unk_id, "vocab": vocab}
out_vocab_filename = f"{filename}.json"
try:
with open(out_vocab_filename, "w") as f:
json.dump(data, f, indent=4)
tokenizer = Tokenizer(Unigram(out_vocab_filename))
finally:
os.remove(out_vocab_filename)
elif model_type == 2:
vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
# Open output files and let's extract model information
out_vocab_filename = f"{filename}.vocab"
out_merge_filename = f"{filename}.merge"
try:
with open(out_vocab_filename, "w") as vocab_f:
json.dump(vocab, vocab_f)
try:
with open(out_merge_filename, "w") as merges_f:
# Save content
merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{os.linesep}", merges))
tokenizer = Tokenizer(
BPE(
out_vocab_filename,
out_merge_filename,
unk_token=proto.trainer_spec.unk_piece,
)
)
finally:
os.remove(out_merge_filename)
finally:
os.remove(out_vocab_filename)
else:
raise Exception(
"You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
)
return tokenizer
def normalizer(self, proto):
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
return Precompiled(precompiled_charsmap)
def post_processor(self, tokenizer):
return None
def converted(self):
tokenizer = self.tokenizer(self.proto)
# Tokenizer assemble
tokenizer.normalizer = self.normalizer(self.proto)
replacement = ""
add_prefix_space = True
tokenizer.pre_tokenizer = PSequence(
[
WhitespaceSplit(),
Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
]
)
tokenizer.decoder = decoders.Metaspace(
replacement=replacement, add_prefix_space=add_prefix_space
)
post_processor = self.post_processor(tokenizer)
if post_processor:
tokenizer.post_processor = post_processor
# TODO what parameters should we give ?
parameters = {}
return BaseTokenizer(tokenizer, parameters)
class AlbertConverter(SpmConverter):
def vocab(self, proto):
return [
(piece.piece, piece.score)
if check_number_comma(piece.piece)
else (piece.piece, piece.score - 100)
for piece in proto.pieces
]
def normalizer(self, proto):
normalizers = []
# TODO Missing Replace quotes
if not self.original_tokenizer.keep_accents:
normalizers.append(NFKD())
# TODO Missing strip accents
if self.original_tokenizer.do_lower_case:
normalizers.append(Lowercase())
precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
normalizers.append(Precompiled(precompiled_charsmap))
return Sequence(normalizers)
def post_processor(self, tokenizer):
return TemplateProcessing(
seq_a=["[CLS]", "$0", "[SEP]"],
seq_b=["$1", "[SEP]"],
special_tokens=[
("[CLS]", tokenizer.get_vocab()["[CLS]"]),
("[SEP]", tokenizer.get_vocab()["[SEP]"]),
],
)
class CamembertConverter(SpmConverter):
def vocab(self, proto):
vocab = [
("<s>NOTUSED", 0.0),
("<pad>", 0.0),
("</s>NOTUSED", 0.0),
("<unk>", 0.0),
]
vocab += [(piece.piece, piece.score) for piece in proto.pieces]
return vocab
def unk_id(self, proto):
# See vocab unk position
return 3
def post_processor(self, tokenizer):
return TemplateProcessing(
seq_a=["<s>", "$0", "</s>"],
seq_b=["$1", "</s>"],
special_tokens=[
("<s>", tokenizer.get_vocab()["<s>"]),
("</s>", tokenizer.get_vocab()["</s>"]),
],
)
class MBartConverter(SpmConverter):
def vocab(self, proto):
vocab = [
("<s>", 0.0),
("<pad>", 0.0),
("</s>", 0.0),
("<unk>", 0.0),
]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
vocab += [
("ar_AR", 0.0),
("cs_CZ", 0.0),
("de_DE", 0.0),
("en_XX", 0.0),
("es_XX", 0.0),
("et_EE", 0.0),
("fi_FI", 0.0),
("fr_XX", 0.0),
("gu_IN", 0.0),
("hi_IN", 0.0),
("it_IT", 0.0),
("ja_XX", 0.0),
("kk_KZ", 0.0),
("ko_KR", 0.0),
("lt_LT", 0.0),
("lv_LV", 0.0),
("my_MM", 0.0),
("ne_NP", 0.0),
("nl_XX", 0.0),
("ro_RO", 0.0),
("ru_RU", 0.0),
("si_LK", 0.0),
("tr_TR", 0.0),
("vi_VN", 0.0),
("zh_CN", 0.0),
]
return vocab
def unk_id(self, proto):
return 3
def post_processor(self, tokenizer):
return TemplateProcessing(
seq_a=["$0", "</s>", "en_XX"],
seq_b=["$1", "</s>"],
special_tokens=[
("en_XX", tokenizer.get_vocab()["en_XX"]),
("</s>", tokenizer.get_vocab()["</s>"]),
],
)
class XLMRobertaConverter(SpmConverter):
def vocab(self, proto):
vocab = [
("<s>", 0.0),
("<pad>", 0.0),
("</s>", 0.0),
("<unk>", 0.0),
]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
return vocab
def unk_id(self, proto):
unk_id = 3
return unk_id
def post_processor(self, tokenizer):
return TemplateProcessing(
seq_a=["<s>", "$0", "</s>"],
seq_b=["$1", "</s>"],
special_tokens=[
("<s>", tokenizer.get_vocab()["<s>"]),
("</s>", tokenizer.get_vocab()["</s>"]),
],
)
class XLNetConverter(SpmConverter):
def vocab(self, proto):
return [
(piece.piece, piece.score)
if check_number_comma(piece.piece)
else (piece.piece, piece.score - 100)
for piece in proto.pieces
]
def normalizer(self, proto):
# TODO Missing Replace quotes
# TODO Missing strip accents
return super().normalizer(proto)
def post_processor(self, tokenizer):
return TemplateProcessing(
seq_a=["$0", "<sep>", "<cls>"],
seq_b=["$1", "<sep>"],
special_tokens=[
("<sep>", tokenizer.get_vocab()["<sep>"]),
("<cls>", tokenizer.get_vocab()["<cls>"]),
],
)
class ReformerConverter(SpmConverter):
pass
class PegasusConverter(SpmConverter):
offset = 103
def vocab(self, proto):
vocab = [
(self.original_tokenizer.pad_token, 0),
(self.original_tokenizer.eos_token, 0),
]
vocab += [(f"unk_{i}", -100) for i in range(2, 2 + self.offset)]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
return vocab
def unk_id(self, proto):
return proto.trainer_spec.unk_id + self.offset
def post_processor(self, tokenizer):
eos = self.original_tokenizer.eos_token
return TemplateProcessing(
seq_a=["$0", eos],
seq_b=["$1", eos],
special_tokens=[(eos, tokenizer.get_vocab()[eos]),],
)
class T5Converter(SpmConverter):
def post_processor(self, tokenizer):
return TemplateProcessing(
seq_a=["$0", "</s>"],
seq_b=["$1", "</s>"],
special_tokens=[("</s>", tokenizer.get_vocab()["</s>"]),],
)
CONVERTERS = {
"AlbertTokenizer": AlbertConverter,
"CamembertTokenizer": CamembertConverter,
"XLMRobertaTokenizer": XLMRobertaConverter,
"MBartTokenizer": MBartConverter,
"XLNetTokenizer": XLNetConverter,
"ReformerTokenizer": ReformerConverter,
"PegasusTokenizer": PegasusConverter,
"T5Tokenizer": T5Converter,
}
def check(pretrained):
transformer_tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
converter_class = CONVERTERS[transformer_tokenizer.__class__.__name__]
tokenizer = converter_class(transformer_tokenizer).converted()
tokenizer.save(f"{pretrained}.json")
now = datetime.datetime.now
trans_total_time = datetime.timedelta(seconds=0)
tok_total_time = datetime.timedelta(seconds=0)
with open("/home/nicolas/data/xnli/xnli.txt", "r") as f:
for i, line in enumerate(f):
line = line.strip()
# TODO in normalizer
line = unicodedata.normalize("NFKD", line)
line = "".join([c for c in line if not unicodedata.combining(c)])
# TODO in normalizer
line = line.replace("``", '"').replace("''", '"')
start = now()
ids = transformer_tokenizer.encode(line)
trans = now()
tok_ids = tokenizer.encode(line).ids
tok = now()
trans_total_time += trans - start
tok_total_time += tok - trans
if ids != tok_ids:
if check_details(line, ids, tok_ids, tokenizer, transformer_tokenizer):
continue
assert ids == tok_ids, f"Error in line {i}: {line} {ids} != {tok_ids}"
return ("OK", trans_total_time / tok_total_time)
def main():
pretraineds = [
"albert-base-v1",
"albert-large-v1",
"albert-xlarge-v1",
"albert-xxlarge-v1",
"albert-base-v2",
"albert-large-v2",
"albert-xlarge-v2",
"albert-xxlarge-v2",
"camembert-base",
"xlm-roberta-base",
"xlm-roberta-large",
"xlm-roberta-large-finetuned-conll02-dutch",
"xlm-roberta-large-finetuned-conll02-spanish",
"xlm-roberta-large-finetuned-conll03-english",
"xlm-roberta-large-finetuned-conll03-german",
"facebook/mbart-large-en-ro",
"facebook/mbart-large-cc25",
"xlnet-base-cased",
"xlnet-large-cased",
"google/reformer-crime-and-punishment",
"t5-small",
"google/pegasus-large",
]
pretraineds = ["google/pegasus-large"]
model_len = 50
status_len = 6
speedup_len = 8
print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
for pretrained in pretraineds:
status, speedup = check(pretrained)
print(
f"|{pretrained:<{model_len}}|{status:^{status_len}}|{speedup:^{speedup_len - 1}.2f}x|"
)
if __name__ == "__main__":
main()

View File

@ -118,10 +118,12 @@ def check_diff(spm_diff, tok_diff, sp, tok):
if spm_diff == list(reversed(tok_diff)):
# AAA -> AA+A vs A+AA case.
return True
elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(tok_diff):
# Second order OK
# Barrich -> Barr + ich vs Bar + rich
return True
# elif len(spm_diff) == len(tok_diff) and tok.decode(spm_diff) == tok.decode(
# tok_diff
# ):
# # Second order OK
# # Barrich -> Barr + ich vs Bar + rich
# return True
spm_reencoded = sp.encode(sp.decode(spm_diff))
tok_reencoded = tok.encode(tok.decode(spm_diff)).ids
if spm_reencoded != spm_diff and spm_reencoded == tok_reencoded:
@ -173,6 +175,13 @@ def check_details(line, spm_ids, tok_ids, tok, sp):
spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
) and check_diff(spm_ids[first + i : last], tok_ids[first + j : last], sp, tok):
return True
print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}")
try:
print(f"Tok: {[tok.decode([tok_ids[i]]) for i in range(first, last)]}")
except Exception:
pass
ok_start = tok.decode(spm_ids[:first])
ok_end = tok.decode(spm_ids[last:])
wrong = tok.decode(spm_ids[first:last])
@ -181,9 +190,6 @@ def check_details(line, spm_ids, tok_ids, tok, sp):
print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
else:
print(wrong)
print(f"Spm: {[tok.decode([spm_ids[i]]) for i in range(first, last)]}")
print(f"Tok: {[tok.decode([tok_ids[i]]) for i in range(first, last)]}")
return False

View File

@ -303,6 +303,7 @@ impl BPE {
fn merge_word(&self, w: &str) -> Result<Word> {
let mut indices = w.char_indices().map(|(idx, _)| idx).peekable();
let mut word = Word::with_capacity(w.len());
let mut unk: Option<(u32, usize)> = None;
while let Some(i) = indices.next() {
let (s, byte_len) = if let Some(&end) = indices.peek() {
match (i, self.continuing_subword_prefix.as_ref()) {
@ -323,16 +324,29 @@ impl BPE {
};
if let Some(id) = self.vocab.get(s.as_ref()) {
if let Some((unk_id, unk_len)) = unk {
word.add(unk_id, unk_len);
unk = None;
}
word.add(*id, byte_len);
} else if let Some(unk) = &self.unk_token {
let unk_id = self
.vocab
.get(unk)
.ok_or_else(|| Error::UnkTokenOutOfVocabulary(unk.to_owned()))?;
// Handle UNK token
word.add(*unk_id, byte_len);
} else if let Some(unk_token) = &self.unk_token {
unk = if let Some((unk_id, unk_len)) = unk {
// Fuse unk
Some((unk_id, unk_len + byte_len))
} else {
Some((
*self
.vocab
.get(unk_token)
.ok_or_else(|| Error::UnkTokenOutOfVocabulary(unk_token.to_owned()))?,
byte_len,
))
};
}
}
if let Some((unk_id, unk_len)) = unk {
word.add(unk_id, unk_len);
}
word.merge_all(&self.merges, self.dropout);
@ -455,6 +469,34 @@ mod tests {
assert_eq!(serialized, "{\"a\":0,\"b\":1,\"c\":2,\"ab\":3}");
}
#[test]
fn test_unk_get_fused() {
let vocab: Vocab = [("<unk>".into(), 0), ("a".into(), 1), ("b".into(), 2)]
.iter()
.cloned()
.collect();
let bpe = BpeBuilder::default()
.vocab_and_merges(vocab, HashMap::new())
.unk_token("<unk>".to_string())
.build()
.unwrap();
let tokens = bpe.tokenize("c").unwrap();
assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 1)),]);
let tokens = bpe.tokenize("cc").unwrap();
assert_eq!(tokens, vec![Token::new(0u32, "<unk>".into(), (0, 2)),]);
let tokens = bpe.tokenize("accb").unwrap();
assert_eq!(
tokens,
vec![
Token::new(1u32, "a".into(), (0, 1)),
Token::new(0u32, "<unk>".into(), (1, 3)),
Token::new(2u32, "b".into(), (3, 4)),
]
);
}
#[test]
// Test tokenization. With dropout set to 0 tokenization is deterministic,
// so we know exactly what the result should be.

View File

@ -15,9 +15,16 @@ impl Normalizer for Precompiled {
// break a single test.
// You don't pass.
normalized.get().graphemes(true).for_each(|grapheme| {
let old_count = grapheme.chars().count() as isize;
if grapheme.len() < 6 {
if let Some(norm) = self.transform(grapheme) {
// debug!(
// "Replacing {:?}({:?}) by {:?}({:?})",
// grapheme,
// grapheme.chars().count(),
// norm,
// norm.chars().count()
// );
let old_count = grapheme.chars().count() as isize;
let new_count = norm.chars().count() as isize;
for (i, c) in norm.chars().enumerate() {
let n = if i == 0 {
@ -33,7 +40,15 @@ impl Normalizer for Precompiled {
for (char_index, c) in grapheme.char_indices() {
let part = &grapheme[char_index..char_index + c.len_utf8()];
if let Some(norm) = self.transform(part) {
let old_count = part.chars().count() as isize;
let new_count = norm.chars().count() as isize;
// debug!(
// "Replacing {:?}({:?}) by {:?}({:?})",
// part,
// part.chars().count(),
// norm,
// norm.chars().count()
// );
for (i, c) in norm.chars().enumerate() {
let n = if i == 0 {
new_count - old_count
@ -47,6 +62,7 @@ impl Normalizer for Precompiled {
}
}
});
// debug!("Normalized {:?}", normalized);
normalized.transform(transformations.into_iter(), 0);
Ok(())
}