Removing --release compat test.
- Leaving the one that checks that sampling follows the expected distribution.
- Marking the python Unigram.train(..) test as slow.
- The python Unigram.train(..) test now uses the `big.txt` file.
Committed by Anthony MOI
parent d0366529b7
commit 816632c9fa
bindings/python/conftest.py (new file, +19 lines)
@@ -0,0 +1,19 @@
import pytest


def pytest_addoption(parser):
    parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")


def pytest_configure(config):
    config.addinivalue_line("markers", "slow: mark test as slow to run")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--runslow"):
        # --runslow given in cli: do not skip slow tests
        return
    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
@@ -7,10 +7,10 @@ from ..utils import data_dir, train_files
class TestUnigram:
    @pytest.mark.slow
    def test_train(self, train_files):
        # TODO: This is super *slow* fix it before merging.
        tokenizer = SentencePieceUnigramTokenizer()
        tokenizer.train(train_files["simple"], show_progress=False)
        tokenizer.train(train_files["big"], show_progress=False)

        filename = "tests/data/unigram_trained.json"
        tokenizer.save(filename)
@@ -62,12 +62,7 @@ def openai_files(data_dir):
@pytest.fixture(scope="session")
def train_files(data_dir):
    return {
        "wagahaiwa": download(
            "https://storage.googleapis.com/tokenizers/unigram_wagahaiwa_nekodearu.txt"
        ),
        "simple": download(
            "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt"
        ),
        "big": download("https://norvig.com/big.txt"),
    }
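The `download(...)` helper used by this fixture is defined elsewhere in the test utilities and is not part of this diff. A minimal caching sketch of what such a helper might look like, purely as an assumption (cache directory and behaviour are guesses, not the project's actual helper):

import os
import urllib.request

DATA_CACHE = "tests/data"  # assumed cache location


def download(url):
    # Fetch `url` once into the cache directory and return the local path.
    os.makedirs(DATA_CACHE, exist_ok=True)
    filename = os.path.join(DATA_CACHE, url.rsplit("/", 1)[-1])
    if not os.path.exists(filename):
        urllib.request.urlretrieve(url, filename)
    return filename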
@@ -6,13 +6,9 @@ use std::collections::HashMap;
use std::fs::read_to_string;
use std::path::Path;
#[cfg(not(debug_assertions))]
use std::process::Command;
use tokenizers::models::unigram::Lattice;
use tokenizers::models::unigram::Unigram;
#[cfg(not(debug_assertions))]
use tokenizers::models::unigram::{Lattice, UnigramTrainerBuilder};
use tokenizers::tokenizer::Model;
#[cfg(not(debug_assertions))]
use unicode_normalization_alignments::UnicodeNormalization;

#[test]
fn test_unigram_from_file() {
@@ -86,246 +82,3 @@ fn test_sample() {
        }
    }
}

#[cfg(not(debug_assertions))]
#[test]
fn test_train_from_file() {
    let trainer = UnigramTrainerBuilder::default()
        .show_progress(false)
        .split_by_whitespace(true)
        .build()
        .unwrap();
    let mut word_counts: Vec<(String, u32)> = vec![];
    let file = read_to_string("data/unigram_wagahaiwa_nekodearu.txt").unwrap();
    let mut ignored = 0;
    for line in file.split("\n") {
        if line.len() > 4192 || line.len() == 0 {
            ignored += 1;
            continue;
        }
        word_counts.push((line.to_string(), 1));
    }
    println!("Kept {:?} sentences", word_counts.len());
    println!("Ignored {:?} sentences", ignored);

    // println!("Start train {:?}", word_counts);
    let (model, _) = trainer._train(word_counts).unwrap();
    // println!("Stop train {:?}", model.get_vocab());
    // println!("Vocab {}", model.get_vocab().len());
    //
    model
        .save(
            std::path::Path::new("data"),
            Some("unigram_wagahaiwa_nekodearu"),
        )
        .unwrap();

    let input = "吾輩《わがはい》は猫である。名前はまだ無い。";
    assert_eq!(
        model
            .tokenize(input)
            .unwrap()
            .iter()
            .map(|tok| tok.value.clone())
            .collect::<Vec<_>>(),
        vec![
            "吾輩",
            "《",
            "わが",
            "はい",
            "》",
            "は",
            "猫",
            "である",
            "。",
            "名前",
            "はまだ",
            "無い",
            "。"
        ]
    );
}

#[cfg(not(debug_assertions))]
#[test]
fn test_spm_compat_train() {
    println!("Starting train compat test");
    let n_sentences = 100_000;
    let train_file = "data/wikitext-103-raw/wiki.train.raw";
    let test_file = "data/wikitext-103-raw/wiki.test.raw";
    let spm_prefix = "data/wikitext-103-raw/spm_wiki_103";
    let output = Command::new("spm_train")
        .args(&[
            "--input",
            train_file,
            "--model_type",
            "unigram",
            "--model_prefix",
            spm_prefix,
            "--input_sentence_size",
            &n_sentences.to_string(),
            "--num_threads",
            "1",
            "--shuffle_input_sentence",
            "0",
            "--character_coverage",
            "1",
        ])
        .output()
        .expect("Command failed is `spm_train` installed ?");
    if !output.status.success() {
        let err_msg = std::str::from_utf8(&output.stderr).unwrap();
        assert!(output.status.success(), "Command failed {}", err_msg)
    }
    println!("train: {}", std::str::from_utf8(&output.stderr).unwrap());

    let output = Command::new("spm_encode")
        .args(&[
            "--model",
            &format!("{}.model", spm_prefix),
            "--input",
            test_file,
        ])
        .output()
        .expect("Command failed is `spm_train` installed ?");

    // println!("{}", std::str::from_utf8(output.stdout));

    let trainer = UnigramTrainerBuilder::default()
        .show_progress(true)
        .split_by_whitespace(true)
        .space_char('▁')
        .build()
        .unwrap();
    let mut word_counts: Vec<(String, u32)> = vec![];
    let file = read_to_string(train_file).unwrap();
    let mut total = 0;
    for line in file.lines() {
        if total == n_sentences {
            break;
        }

        total += 1;
        match normalize(line) {
            Ok(formatted_line) => {
                word_counts.push((formatted_line, 1));
            }
            _ => (),
        }
    }
    println!("Kept {:?} sentences", word_counts.len());

    // println!("Start train {:?}", word_counts);
    let (model, _) = trainer._train(word_counts).unwrap();
    // println!("Stop train {:?}", model.get_vocab());
    // println!("Vocab {}", model.get_vocab().len());

    model.save(Path::new("data"), Some("trained.json")).unwrap();

    let file = read_to_string(test_file).unwrap();
    let encoded = std::str::from_utf8(&output.stdout).unwrap();

    let mut correct = 0;
    let mut total = 0;
    let mut n_tokenizer_tokens = 0;
    let mut n_spm_tokens = 0;
    for (tokenizer_line, spm_line) in file.lines().zip(encoded.lines()) {
        let tokenizer_tokens = model.encode(&tokenizer_line.replace(" ", "▁"));
        let mut spm_tokens: Vec<String> = spm_line
            .split(' ')
            .map(|s| s.to_string().replace('▁', " "))
            .collect();
        // XXX : For some reason spm_encode mangles trailing spaces which exist in wiki103.
        if spm_tokens.len() == 1 {
            spm_tokens.pop();
        }
        spm_tokens.push(" ".to_string());

        n_tokenizer_tokens += tokenizer_tokens.len();
        n_spm_tokens += spm_tokens.len();
        if tokenizer_tokens == spm_tokens {
            correct += 1;
        }
        total += 1;

        // assert_eq!(tokenizer_tokens, spm_tokens, "Failed on line {}", i + 1,);
        // println!("{} vs {}", tokenizer_tokens.len(), spm_tokens.len());
        // assert!(tokenizer_tokens.len() <= spm_tokens.len());
        // if spm_tokens.len() < tokenizer_tokens.len() {
        //     println!("Tokenizer line {:?}", tokenizer_tokens.join(" "));
        //     println!("Spm line {:?}", spm_line);
        // }
    }
    let acc = (correct as f64) / (total as f64) * 100.0;
    println!("Total tokenizer tokens {}", n_tokenizer_tokens);
    println!("Total spm tokens {}", n_spm_tokens);
    println!("Total accuracy {}/{} ({:.2}%)", correct, total, acc);
    assert!(n_tokenizer_tokens < n_spm_tokens);
}

#[cfg(not(debug_assertions))]
fn normalize(s: &str) -> Result<String, ()> {
    let prefixed = format!(
        " {}",
        s.chars()
            .filter(|c| !c.is_control())
            .collect::<String>()
            .nfkc()
            .map(|(s, _)| s)
            .collect::<String>()
    );
    let mut vecs = vec![""];
    vecs.extend(prefixed.split_whitespace().collect::<Vec<_>>());

    let normalized: String = vecs.join(" ").to_string();
    let result = normalized.replace(' ', "▁");
    if result.len() > 4192 || result.is_empty() {
        return Err(());
    }
    Ok(result)
}

#[cfg(test)]
#[cfg(not(debug_assertions))]
mod tests {
    use super::*;

    #[test]
    fn test_normalize() {
        assert!(normalize("").is_err());
        assert!(normalize(" ").is_err());
        assert!(normalize("  ").is_err());

        // Sentence with heading/tailing/redundant spaces.
        assert_eq!("▁ABC", normalize("ABC").unwrap());
        assert_eq!("▁ABC", normalize(" ABC ").unwrap());
        assert_eq!("▁A▁B▁C", normalize(" A B C ").unwrap());
        assert_eq!("▁ABC", normalize(" ABC ").unwrap());
        assert_eq!("▁ABC", normalize(" ABC ").unwrap());
        assert_eq!("▁ABC", normalize(" ABC").unwrap());
        assert_eq!("▁ABC", normalize(" ABC ").unwrap());

        // NFKC char to char normalization.
        assert_eq!("▁123", normalize("①②③").unwrap());

        // NFKC char to multi-char normalization.
        assert_eq!("▁株式会社", normalize("㍿").unwrap());

        // Half width katakana, character composition happens.
        assert_eq!("▁グーグル", normalize(" グーグル ").unwrap());

        assert_eq!(
            "▁I▁saw▁a▁girl",
            normalize(" I saw a girl ").unwrap()
        );

        // Remove control chars.
        assert!(normalize(&format!("{}", std::char::from_u32(0x7F).unwrap())).is_err());
        assert!(normalize(&format!("{}", std::char::from_u32(0x8F).unwrap())).is_err());
        assert!(normalize(&format!("{}", std::char::from_u32(0x9F).unwrap())).is_err());
        assert!(normalize(&format!("{}", std::char::from_u32(0x0B).unwrap())).is_err());
        for c in 0x10..=0x1F {
            assert!(normalize(&format!("{}", std::char::from_u32(c).unwrap())).is_err());
        }
    }
}