mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-31 04:29:21 +00:00
Enabling training parity check for tokenizers.UnigramTrainer
This commit is contained in:
committed by
Anthony MOI
parent
558e76f18e
commit
ee3860c029
@ -15,14 +15,71 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
|
"--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--train",
|
||||||
|
action="store_true",
|
||||||
|
help="Instead of checking the encoder part, we check the trainer part",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
spm.SentencePieceTrainer.Train(
|
spm.SentencePieceTrainer.Train(
|
||||||
f"--input={args.input_file} --model_prefix={args.model_prefix}"
|
f"--input={args.input_file} --model_prefix={args.model_prefix}"
|
||||||
|
f" --character_coverage=1.0"
|
||||||
|
f" --max_sentence_length=40000"
|
||||||
|
f" --num_threads=1"
|
||||||
f" --vocab_size={args.vocab_size}"
|
f" --vocab_size={args.vocab_size}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.train:
|
||||||
|
check_train(args)
|
||||||
|
else:
|
||||||
|
check_encode(args)
|
||||||
|
|
||||||
|
|
||||||
|
def check_train(args):
    """Check the `tokenizers` Unigram trainer against the SentencePiece one.

    Loads the SentencePiece model found at ``{args.model_prefix}.model``,
    trains a ``tokenizers.SentencePieceUnigramTokenizer`` on
    ``args.input_file``, then encodes every line of that file with both
    models and compares the total number of tokens each one produced. The
    first vocabulary entry on which the two models disagree is printed for
    debugging.

    Args:
        args: Parsed command-line namespace. Reads ``model_prefix``,
            ``input_file`` and ``vocab_size``.

    Raises:
        AssertionError: If the `tokenizers` model does not use strictly
            fewer tokens than the SentencePiece model on the input file.
    """
    sp = spm.SentencePieceProcessor()
    model_filename = f"{args.model_prefix}.model"
    sp.Load(model_filename)

    tokenizer = tokenizers.SentencePieceUnigramTokenizer()
    # Train to the same target size that was requested from spm_train;
    # otherwise a non-default --vocab-size compares two differently sized
    # models.
    tokenizer.train(
        args.input_file, vocab_size=args.vocab_size, show_progress=False
    )

    spm_tokens = 0
    tokenizer_tokens = 0

    # Read the corpus as UTF-8 explicitly rather than relying on the
    # platform's locale encoding.
    with open(args.input_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            spm_tokens += len(sp.EncodeAsIds(line))
            tokenizer_tokens += len(tokenizer.encode(line).ids)

    # Index -> token table for the tokenizers model. It is sized to hold
    # every id actually present so the assignment below cannot go out of
    # range; unfilled slots keep the placeholder 0, as before.
    tokenizer_vocab = tokenizer.get_vocab()
    table_size = max(
        args.vocab_size, max(tokenizer_vocab.values(), default=-1) + 1
    )
    vocab = [0] * table_size
    for token, index in tokenizer_vocab.items():
        vocab[index] = token

    spm_vocab = [sp.id_to_piece(i) for i in range(args.vocab_size)]

    # 0 is unk in tokenizers, 0, 1, 2 are unk bos, eos in spm by default.
    for i, (token, spm_token) in enumerate(zip(vocab[1:], spm_vocab[3:])):
        if token != spm_token:
            print(f"First different token is token {i} ({token} != {spm_token})")
            break

    print(f"Tokenizer used {tokenizer_tokens}, where spm used {spm_tokens}")
    assert (
        tokenizer_tokens < spm_tokens
    ), "Our trainer should be at least more efficient than the SPM one"
|
||||||
|
|
||||||
|
|
||||||
|
def check_encode(args):
|
||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
model_filename = f"{args.model_prefix}.model"
|
model_filename = f"{args.model_prefix}.model"
|
||||||
sp.Load(model_filename)
|
sp.Load(model_filename)
|
||||||
|
@ -226,7 +226,7 @@ impl UnigramTrainer {
|
|||||||
let mut all_chars: HashMap<char, u32> = HashMap::new();
|
let mut all_chars: HashMap<char, u32> = HashMap::new();
|
||||||
let c_sentence_boundary = '\0';
|
let c_sentence_boundary = '\0';
|
||||||
let k_sentence_boundary = '\0'.to_string();
|
let k_sentence_boundary = '\0'.to_string();
|
||||||
for (string, _) in sentences {
|
for (string, n) in sentences {
|
||||||
flat_string.push_str(&string);
|
flat_string.push_str(&string);
|
||||||
// XXX
|
// XXX
|
||||||
// Comment suggests we add sentence boundary, but it seems to be missing from actual
|
// Comment suggests we add sentence boundary, but it seems to be missing from actual
|
||||||
@ -234,7 +234,7 @@ impl UnigramTrainer {
|
|||||||
flat_string.push_str(&k_sentence_boundary);
|
flat_string.push_str(&k_sentence_boundary);
|
||||||
for c in string.chars() {
|
for c in string.chars() {
|
||||||
if c != c_sentence_boundary {
|
if c != c_sentence_boundary {
|
||||||
*all_chars.entry(c).or_insert(0) += 1;
|
*all_chars.entry(c).or_insert(0) += n;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user