Removing --release compat test.

- Leaving the one that checks that sampling follows the expected distribution.
- Marking the python Unigram.train(..) test as slow.
- The python Unigram.train(..) test now uses the `big.txt` file.
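
As background for the first point: the kept Rust test itself is not part of this diff, but the general shape of a "sampling follows the expected distribution" check is to draw many samples and compare the empirical frequencies against the known probabilities. A minimal Python sketch of that idea (the token strings and probabilities below are purely illustrative, not taken from the test suite):

```python
import random
from collections import Counter


def sampling_matches(expected: dict, n_samples: int = 100_000, tol: float = 0.01) -> bool:
    """Draw n_samples according to `expected` and compare empirical frequencies."""
    population = list(expected.keys())
    weights = list(expected.values())
    counts = Counter(random.choices(population, weights=weights, k=n_samples))
    # Every outcome's observed frequency should be within `tol` of its probability.
    return all(abs(counts[tok] / n_samples - p) < tol for tok, p in expected.items())


# e.g. three hypothetical tokenizations of the same string with known probabilities
assert sampling_matches({"A B C": 0.5, "AB C": 0.3, "A BC": 0.2})
```
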
Nicolas Patry
2020-09-02 18:47:02 +02:00
committed by Anthony MOI
parent d0366529b7
commit 816632c9fa
4 changed files with 23 additions and 256 deletions


@@ -0,0 +1,19 @@
import pytest


def pytest_addoption(parser):
    parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")


def pytest_configure(config):
    config.addinivalue_line("markers", "slow: mark test as slow to run")


def pytest_collection_modifyitems(config, items):
    if config.getoption("--runslow"):
        # --runslow given in cli: do not skip slow tests
        return
    skip_slow = pytest.mark.skip(reason="need --runslow option to run")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
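
With these hooks in place, any test decorated with `@pytest.mark.slow` is skipped by default and only runs when `pytest --runslow` is passed. A minimal, hypothetical example of how a test opts in (the test names below are illustrative, not part of this commit):

```python
import pytest


@pytest.mark.slow
def test_expensive_training():
    # Skipped by default; runs only with `pytest --runslow`.
    assert sum(range(10_000)) == 49_995_000


def test_cheap_check():
    # Unmarked tests always run.
    assert 1 + 1 == 2
```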


@@ -7,10 +7,10 @@ from ..utils import data_dir, train_files
class TestUnigram:
    @pytest.mark.slow
    def test_train(self, train_files):
        # TODO: This is super *slow* fix it before merging.
        tokenizer = SentencePieceUnigramTokenizer()
        tokenizer.train(train_files["simple"], show_progress=False)
        tokenizer.train(train_files["big"], show_progress=False)
        filename = "tests/data/unigram_trained.json"
        tokenizer.save(filename)


@@ -62,12 +62,7 @@ def openai_files(data_dir):
@pytest.fixture(scope="session")
def train_files(data_dir):
    return {
        "wagahaiwa": download(
            "https://storage.googleapis.com/tokenizers/unigram_wagahaiwa_nekodearu.txt"
        ),
        "simple": download(
            "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt"
        ),
        "big": download("https://norvig.com/big.txt"),
    }
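
The fixture relies on a `download` helper that is not shown in this diff. Roughly, it fetches each URL once and caches it on disk; a sketch of that behaviour (the `tests/data` cache path below is an assumption, not the real implementation):

```python
import os
import urllib.request

DATA_DIR = "tests/data"  # hypothetical cache location


def download(url: str) -> str:
    """Download `url` into DATA_DIR (once) and return the local path."""
    filename = url.rsplit("/", 1)[-1]
    path = os.path.join(DATA_DIR, filename)
    if not os.path.exists(path):
        os.makedirs(DATA_DIR, exist_ok=True)
        urllib.request.urlretrieve(url, path)
    return path
```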


@@ -6,13 +6,9 @@ use std::collections::HashMap;
use std::fs::read_to_string;
use std::path::Path;
#[cfg(not(debug_assertions))]
use std::process::Command;
use tokenizers::models::unigram::Lattice;
use tokenizers::models::unigram::Unigram;
#[cfg(not(debug_assertions))]
use tokenizers::models::unigram::{Lattice, UnigramTrainerBuilder};
use tokenizers::tokenizer::Model;
#[cfg(not(debug_assertions))]
use unicode_normalization_alignments::UnicodeNormalization;
#[test]
fn test_unigram_from_file() {
@@ -86,246 +82,3 @@ fn test_sample() {
}
}
}
#[cfg(not(debug_assertions))]
#[test]
fn test_train_from_file() {
    let trainer = UnigramTrainerBuilder::default()
        .show_progress(false)
        .split_by_whitespace(true)
        .build()
        .unwrap();
    let mut word_counts: Vec<(String, u32)> = vec![];
    let file = read_to_string("data/unigram_wagahaiwa_nekodearu.txt").unwrap();
    let mut ignored = 0;
    for line in file.split("\n") {
        if line.len() > 4192 || line.len() == 0 {
            ignored += 1;
            continue;
        }
        word_counts.push((line.to_string(), 1));
    }
    println!("Kept {:?} sentences", word_counts.len());
    println!("Ignored {:?} sentences", ignored);
    // println!("Start train {:?}", word_counts);
    let (model, _) = trainer._train(word_counts).unwrap();
    // println!("Stop train {:?}", model.get_vocab());
    // println!("Vocab {}", model.get_vocab().len());
    //
    model
        .save(
            std::path::Path::new("data"),
            Some("unigram_wagahaiwa_nekodearu"),
        )
        .unwrap();
    let input = "吾輩《わがはい》は猫である。名前はまだ無い。";
    assert_eq!(
        model
            .tokenize(input)
            .unwrap()
            .iter()
            .map(|tok| tok.value.clone())
            .collect::<Vec<_>>(),
        vec![
            "吾輩",
            "《",
            "わが",
            "はい",
            "》",
            "は",
            "猫",
            "である",
            "。",
            "名前",
            "はまだ",
            "無い",
            "。"
        ]
    );
}
#[cfg(not(debug_assertions))]
#[test]
fn test_spm_compat_train() {
    println!("Starting train compat test");
    let n_sentences = 100_000;
    let train_file = "data/wikitext-103-raw/wiki.train.raw";
    let test_file = "data/wikitext-103-raw/wiki.test.raw";
    let spm_prefix = "data/wikitext-103-raw/spm_wiki_103";
    let output = Command::new("spm_train")
        .args(&[
            "--input",
            train_file,
            "--model_type",
            "unigram",
            "--model_prefix",
            spm_prefix,
            "--input_sentence_size",
            &n_sentences.to_string(),
            "--num_threads",
            "1",
            "--shuffle_input_sentence",
            "0",
            "--character_coverage",
            "1",
        ])
        .output()
        .expect("Command failed is `spm_train` installed ?");
    if !output.status.success() {
        let err_msg = std::str::from_utf8(&output.stderr).unwrap();
        assert!(output.status.success(), "Command failed {}", err_msg)
    }
    println!("train: {}", std::str::from_utf8(&output.stderr).unwrap());
    let output = Command::new("spm_encode")
        .args(&[
            "--model",
            &format!("{}.model", spm_prefix),
            "--input",
            test_file,
        ])
        .output()
        .expect("Command failed is `spm_train` installed ?");
    // println!("{}", std::str::from_utf8(output.stdout));
    let trainer = UnigramTrainerBuilder::default()
        .show_progress(true)
        .split_by_whitespace(true)
        .space_char('▁')
        .build()
        .unwrap();
    let mut word_counts: Vec<(String, u32)> = vec![];
    let file = read_to_string(train_file).unwrap();
    let mut total = 0;
    for line in file.lines() {
        if total == n_sentences {
            break;
        }
        total += 1;
        match normalize(line) {
            Ok(formatted_line) => {
                word_counts.push((formatted_line, 1));
            }
            _ => (),
        }
    }
    println!("Kept {:?} sentences", word_counts.len());
    // println!("Start train {:?}", word_counts);
    let (model, _) = trainer._train(word_counts).unwrap();
    // println!("Stop train {:?}", model.get_vocab());
    // println!("Vocab {}", model.get_vocab().len());
    model.save(Path::new("data"), Some("trained.json")).unwrap();
    let file = read_to_string(test_file).unwrap();
    let encoded = std::str::from_utf8(&output.stdout).unwrap();
    let mut correct = 0;
    let mut total = 0;
    let mut n_tokenizer_tokens = 0;
    let mut n_spm_tokens = 0;
    for (tokenizer_line, spm_line) in file.lines().zip(encoded.lines()) {
        let tokenizer_tokens = model.encode(&tokenizer_line.replace(" ", ""));
        let mut spm_tokens: Vec<String> = spm_line
            .split(' ')
            .map(|s| s.to_string().replace('▁', " "))
            .collect();
        // XXX : For some reason spm_encode mangles trailing spaces which exist in wiki103.
        if spm_tokens.len() == 1 {
            spm_tokens.pop();
        }
        spm_tokens.push(" ".to_string());
        n_tokenizer_tokens += tokenizer_tokens.len();
        n_spm_tokens += spm_tokens.len();
        if tokenizer_tokens == spm_tokens {
            correct += 1;
        }
        total += 1;
        // assert_eq!(tokenizer_tokens, spm_tokens, "Failed on line {}", i + 1,);
        // println!("{} vs {}", tokenizer_tokens.len(), spm_tokens.len());
        // assert!(tokenizer_tokens.len() <= spm_tokens.len());
        // if spm_tokens.len() < tokenizer_tokens.len() {
        //     println!("Tokenizer line {:?}", tokenizer_tokens.join(" "));
        //     println!("Spm line {:?}", spm_line);
        // }
    }
    let acc = (correct as f64) / (total as f64) * 100.0;
    println!("Total tokenizer tokens {}", n_tokenizer_tokens);
    println!("Total spm tokens {}", n_spm_tokens);
    println!("Total accuracy {}/{} ({:.2}%)", correct, total, acc);
    assert!(n_tokenizer_tokens < n_spm_tokens);
}
#[cfg(not(debug_assertions))]
fn normalize(s: &str) -> Result<String, ()> {
    let prefixed = format!(
        " {}",
        s.chars()
            .filter(|c| !c.is_control())
            .collect::<String>()
            .nfkc()
            .map(|(s, _)| s)
            .collect::<String>()
    );
    let mut vecs = vec![""];
    vecs.extend(prefixed.split_whitespace().collect::<Vec<_>>());
    let normalized: String = vecs.join(" ").to_string();
    let result = normalized.replace(' ', "▁");
    if result.len() > 4192 || result.is_empty() {
        return Err(());
    }
    Ok(result)
}
#[cfg(test)]
#[cfg(not(debug_assertions))]
mod tests {
    use super::*;
    #[test]
    fn test_normalize() {
        assert!(normalize("").is_err());
        assert!(normalize(" ").is_err());
        assert!(normalize(" ").is_err());
        // Sentence with heading/tailing/redundant spaces.
        assert_eq!("▁ABC", normalize("ABC").unwrap());
        assert_eq!("▁ABC", normalize(" ABC ").unwrap());
        assert_eq!("▁A▁B▁C", normalize(" A B C ").unwrap());
        assert_eq!("▁ABC", normalize(" ABC ").unwrap());
assert_eq!("▁ABC", normalize(" ").unwrap());
assert_eq!("▁ABC", normalize("  ABC").unwrap());
assert_eq!("▁ABC", normalize("  ABC  ").unwrap());
// NFKC char to char normalization.
assert_eq!("▁123", normalize("①②③").unwrap());
// NFKC char to multi-char normalization.
assert_eq!("▁株式会社", normalize("").unwrap());
// Half width katakana, character composition happens.
assert_eq!("▁グーグル", normalize(" グーグル ").unwrap());
        assert_eq!(
            "▁I▁saw▁a▁girl",
            normalize(" I saw a   girl  ").unwrap()
        );
        // Remove control chars.
        assert!(normalize(&format!("{}", std::char::from_u32(0x7F).unwrap())).is_err());
        assert!(normalize(&format!("{}", std::char::from_u32(0x8F).unwrap())).is_err());
        assert!(normalize(&format!("{}", std::char::from_u32(0x9F).unwrap())).is_err());
        assert!(normalize(&format!("{}", std::char::from_u32(0x0B).unwrap())).is_err());
        for c in 0x10..=0x1F {
            assert!(normalize(&format!("{}", std::char::from_u32(c).unwrap())).is_err());
        }
    }
}