Addressing first pass of comments.

Nicolas Patry
2020-09-23 11:29:17 +02:00
parent 1cd4824273
commit 8f8156fd2c
10 changed files with 196 additions and 125 deletions

View File

@@ -70,13 +70,17 @@ elif args.type == "bert":
     tok_r = Tokenizer(WordPiece(args.vocab, unk_token="[UNK]", max_input_chars_per_word=100))
     tok_r.normalizer = BertNormalizer(
-        clean_text=True, handle_chinese_chars=True, strip_accents=True, lowercase=True,
+        clean_text=True,
+        handle_chinese_chars=True,
+        strip_accents=True,
+        lowercase=True,
     )
     # tok_r.pre_tokenizer = pre_tokenizers.Whitespace()
     tok_r.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
     tok_r.decoder = decoders.WordPiece()
     tok_r.post_processor = BertProcessing(
-        ("[SEP]", tok_r.token_to_id("[SEP]")), ("[CLS]", tok_r.token_to_id("[CLS]")),
+        ("[SEP]", tok_r.token_to_id("[SEP]")),
+        ("[CLS]", tok_r.token_to_id("[CLS]")),
     )
 else:
     raise Exception(f"Unknown type {args.type}")

View File

@@ -32,7 +32,10 @@ if not files:
 # Initialize an empty tokenizer
 tokenizer = BertWordPieceTokenizer(
-    clean_text=True, handle_chinese_chars=True, strip_accents=True, lowercase=True,
+    clean_text=True,
+    handle_chinese_chars=True,
+    strip_accents=True,
+    lowercase=True,
 )

 # And then train

View File

@@ -76,43 +76,28 @@ class SpmConverter(Converter):
         model_type = proto.trainer_spec.model_type
         vocab = self.vocab(proto)
         unk_id = self.unk_id(proto)
-        filename = self.original_tokenizer.vocab_file

         if model_type == 1:
-            data = {"unk_id": unk_id, "vocab": vocab}
-            out_vocab_filename = f"{filename}.json"
-            try:
-                with open(out_vocab_filename, "w") as f:
-                    json.dump(data, f, indent=4)
-                tokenizer = Tokenizer(Unigram(out_vocab_filename))
-            finally:
-                os.remove(out_vocab_filename)
+            tokenizer = Tokenizer(Unigram(vocab, unk_id))
         elif model_type == 2:
-            vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
+            vocab, merges = SentencePieceExtractor(
+                self.original_tokenizer.vocab_file
+            ).extract()
             # Open output files and let's extract model information
-            out_vocab_filename = f"{filename}.vocab"
-            out_merge_filename = f"{filename}.merge"
-            try:
-                with open(out_vocab_filename, "w") as vocab_f:
-                    json.dump(vocab, vocab_f)
-                try:
-                    with open(out_merge_filename, "w") as merges_f:
-                        # Save content
-                        merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{os.linesep}", merges))
-                    tokenizer = Tokenizer(
-                        BPE(
-                            out_vocab_filename,
-                            out_merge_filename,
-                            unk_token=proto.trainer_spec.unk_piece,
-                            fuse_unk=True,
-                        )
-                    )
-                finally:
-                    os.remove(out_merge_filename)
-            finally:
-                os.remove(out_vocab_filename)
+            actual_merges = {}
+            for id_merge, (a, b) in enumerate(merges):
+                id_a = vocab[a]
+                id_b = vocab[b]
+                id_ab = vocab[a + b]
+                actual_merges[(id_a, id_b)] = (id_merge, id_ab)
+            tokenizer = Tokenizer(
+                BPE(
+                    vocab,
+                    actual_merges,
+                    unk_token=proto.trainer_spec.unk_piece,
+                    fuse_unk=True,
+                )
+            )
         else:
             raise Exception(
                 "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
@@ -346,7 +331,9 @@ class PegasusConverter(SpmConverter):
         return TemplateProcessing(
             seq_a=["$0", eos],
             seq_b=["$1", eos],
-            special_tokens=[(eos, tokenizer.get_vocab()[eos]),],
+            special_tokens=[
+                (eos, tokenizer.get_vocab()[eos]),
+            ],
         )
@@ -355,7 +342,9 @@ class T5Converter(SpmConverter):
         return TemplateProcessing(
             seq_a=["$0", "</s>"],
             seq_b=["$1", "</s>"],
-            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"]),],
+            special_tokens=[
+                ("</s>", tokenizer.get_vocab()["</s>"]),
+            ],
         )
@@ -447,7 +436,9 @@ def main():
     model_len = 50
     status_len = 6
     speedup_len = 8
-    print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
+    print(
+        f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|"
+    )
     print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
     for pretrained in args.models:
         status, speedup = check(pretrained, args.filename)

View File

@@ -17,7 +17,11 @@ except Exception:
 def main():
     parser = ArgumentParser("SentencePiece parity checker")
     parser.add_argument(
-        "--input-file", "-i", type=str, required=True, help="Which files do you want to train from",
+        "--input-file",
+        "-i",
+        type=str,
+        required=True,
+        help="Which files do you want to train from",
     )
     parser.add_argument(
         "--model-file",
@@ -28,13 +32,22 @@ def main():
         help="Use a pretrained token file",
     )
     parser.add_argument(
-        "--model-prefix", type=str, default="spm_parity", help="Model prefix for spm_train",
+        "--model-prefix",
+        type=str,
+        default="spm_parity",
+        help="Model prefix for spm_train",
     )
     parser.add_argument(
-        "--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
+        "--vocab-size",
+        "-v",
+        type=int,
+        default=8000,
+        help="Vocab size for spm_train",
     )
     parser.add_argument(
-        "--verbose", action="store_true", help="Verbosity",
+        "--verbose",
+        action="store_true",
+        help="Verbosity",
     )
     parser.add_argument(
         "--train",
@@ -160,10 +173,14 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
     spms = Counter(spm_ids[first:last])
     toks = Counter(tok_ids[first:last])

-    removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
+    removable_tokens = {
+        spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si
+    }
     min_width = 3
     for i in range(last - first - min_width):
-        if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
+        if all(
+            spm_ids[first + i + j] in removable_tokens for j in range(min_width)
+        ):
             possible_matches = [
                 k
                 for k in range(last - first - min_width)
@@ -174,7 +191,11 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
             if check_diff(
                 spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
             ) and check_details(
-                line, spm_ids[first + i : last], tok_ids[first + j : last], sp, tok,
+                line,
+                spm_ids[first + i : last],
+                tok_ids[first + j : last],
+                sp,
+                tok,
             ):
                 return True
@@ -189,7 +210,9 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
     wrong = tok.decode(spm_ids[first:last])
     print()
     if has_color:
-        print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
+        print(
+            f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}"
+        )
     else:
         print(wrong)
     return False
@@ -203,17 +226,8 @@ def check_encode(args):
         tok = tokenizers.SentencePieceUnigramTokenizer.from_spm(args.model_file)
     else:
         vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
-        vocab_filename = f"{args.model_file}.json"
         unk_id = sp.unk_id()
-
-        data = {"unk_id": unk_id, "vocab": vocab}
-        try:
-            with open(vocab_filename, "w") as f:
-                json.dump(data, f, indent=4)
-            tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
-        finally:
-            os.remove(vocab_filename)
+        tok = tokenizers.SentencePieceUnigramTokenizer(vocab, unk_id)

     perfect = 0
     imperfect = 0
@@ -255,7 +269,9 @@ def check_encode(args):
     print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
     total = perfect + imperfect + wrong
-    print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}")
+    print(
+        f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}"
+    )


 if __name__ == "__main__":

View File

@@ -34,3 +34,11 @@ impl<T> ToPyResult<T> {
         self.into()
     }
 }
+
+pub(crate) fn deprecation_warning(version: &str, message: &str) -> PyResult<()> {
+    let gil = pyo3::Python::acquire_gil();
+    let python = gil.python();
+    let deprecation_warning = python.import("builtins")?.get("DeprecationWarning")?;
+    let full_message = format!("Deprecated in {}: {}", version, message);
+    pyo3::PyErr::warn(python, deprecation_warning, &full_message, 0)
+}
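Note that Python silences `DeprecationWarning` by default outside of test runners, so users of the bindings may not see these new warnings unless they opt in. A small, hedged sketch of how a caller could surface them (the filter call is standard-library behaviour; the example output simply mirrors the `Deprecated in {}: {}` format string above):

```python
import warnings

# Show deprecation warnings emitted by the bindings instead of silently dropping them.
warnings.simplefilter("default", DeprecationWarning)

# A deprecated call such as BPE("vocab.json", "merges.txt") (placeholder paths) would
# then print something like:
# DeprecationWarning: Deprecated in 0.9.0: BPE.__init__ will not create from files anymore, try `BPE.from_files` instead
```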

View File

@@ -15,16 +15,7 @@ use tk::models::ModelWrapper;
 use tk::{Model, Token};
 use tokenizers as tk;

-use super::error::ToPyResult;
-
-fn deprecation_warning(version: &str, message: &str) -> PyResult<()> {
-    let gil = pyo3::Python::acquire_gil();
-    let python = gil.python();
-    let deprecation_warning = python.import("builtins")?.get("DeprecationWarning")?;
-    let full_message = format!("Deprecated in {}: {}", version, message);
-    pyo3::PyErr::warn(python, deprecation_warning, &full_message, 0)?;
-    Ok(())
-}
+use super::error::{deprecation_warning, ToPyResult};

 /// A Model represents some tokenization algorithm like BPE or Word
 /// This class cannot be constructed directly. Please use one of the concrete models.
@@ -183,13 +174,24 @@ impl PyBPE {
     }
 }

+#[derive(FromPyObject)]
+enum PyVocab<'a> {
+    Vocab(Vocab),
+    Filename(&'a str),
+}
+
+#[derive(FromPyObject)]
+enum PyMerges<'a> {
+    Merges(Merges),
+    Filename(&'a str),
+}
+
 #[pymethods]
 impl PyBPE {
     #[new]
     #[args(kwargs = "**")]
     fn new(
-        vocab: Option<&PyAny>,
-        merges: Option<&PyAny>,
+        vocab: Option<PyVocab>,
+        merges: Option<PyMerges>,
         kwargs: Option<&PyDict>,
     ) -> PyResult<(Self, PyModel)> {
         if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) {
@@ -199,17 +201,24 @@ impl PyBPE {
         }

         let mut builder = BPE::builder();
-        if let (Some(vocab_any), Some(merges_any)) = (vocab, merges) {
-            if let (Ok(vocab), Ok(merges)) = (vocab_any.extract(), merges_any.extract()) {
-                builder = builder.vocab_and_merges(vocab, merges);
-            } else {
-                let vocab_filename: String = vocab_any.extract()?;
-                let merges_filename: String = merges_any.extract()?;
-                deprecation_warning(
-                    "0.9.0",
-                    "BPE.__init__ will not create from files anymore, try `BPE.from_files` instead",
-                )?;
-                builder = builder.files(vocab_filename, merges_filename);
+        if let (Some(vocab), Some(merges)) = (vocab, merges) {
+            match (vocab, merges) {
+                (PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => {
+                    builder = builder.vocab_and_merges(vocab, merges);
+                }
+                (PyVocab::Filename(vocab_filename), PyMerges::Filename(merges_filename)) => {
+                    deprecation_warning(
+                        "0.9.0",
+                        "BPE.__init__ will not create from files anymore, try `BPE.from_files` instead",
+                    )?;
+                    builder =
+                        builder.files(vocab_filename.to_string(), merges_filename.to_string());
+                }
+                _ => {
+                    return Err(exceptions::PyValueError::new_err(
+                        "`vocab` and `merges` must be both be from memory or both filenames",
+                    ));
+                }
             }
         }
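In Python terms, the binding now accepts either in-memory objects or filenames for `vocab`/`merges`, with the file form deprecated in favour of `BPE.from_files`. A hedged sketch of both paths, using the vocab/merges format from the updated tests ("vocab.json" and "merges.txt" are placeholder paths, not files from the repository):

```python
from tokenizers.models import BPE

# In-memory construction: token -> id vocab, and (id_a, id_b) -> (rank, id_ab) merges.
vocab = {"a": 0, "b": 1, "ab": 2}
merges = {(0, 1): (0, 2)}
bpe = BPE(vocab, merges)

# File-based construction still works through __init__ but now emits a
# DeprecationWarning pointing at the dedicated constructor:
# bpe = BPE("vocab.json", "merges.txt")          # deprecated in 0.9.0
# bpe = BPE.from_files("vocab.json", "merges.txt")
```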
@@ -268,20 +277,21 @@ impl PyWordPiece {
 impl PyWordPiece {
     #[new]
     #[args(kwargs = "**")]
-    fn new(vocab: Option<&PyAny>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
+    fn new(vocab: Option<PyVocab>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
         let mut builder = WordPiece::builder();

-        if let Some(vocab_any) = vocab {
-            #[allow(deprecated)]
-            if let Ok(vocab) = vocab_any.extract() {
-                builder = builder.vocab(vocab);
-            } else {
-                deprecation_warning(
-                    "0.9.0",
-                    "WordPiece.__init__ will not create from files anymore, try `WordPiece.from_file` instead",
-                )?;
-                let vocab_filename: String = vocab_any.extract()?;
-                builder = builder.files(vocab_filename);
+        if let Some(vocab) = vocab {
+            match vocab {
+                PyVocab::Vocab(vocab) => {
+                    builder = builder.vocab(vocab);
+                }
+                PyVocab::Filename(vocab_filename) => {
+                    deprecation_warning(
+                        "0.9.0",
+                        "WordPiece.__init__ will not create from files anymore, try `WordPiece.from_file` instead",
+                    )?;
+                    builder = builder.files(vocab_filename.to_string());
+                }
             }
         }
@@ -320,27 +330,27 @@ impl PyWordLevel {
 impl PyWordLevel {
     #[new]
     #[args(kwargs = "**")]
-    fn new(vocab: Option<&PyAny>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
+    fn new(vocab: Option<PyVocab>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
         let unk_token = PyWordLevel::get_unk(kwargs)?;

-        if let Some(vocab_object) = vocab {
-            let model = if let Ok(vocab) = vocab_object.extract() {
-                WordLevel::builder()
-                    .vocab(vocab)
-                    .unk_token(unk_token)
-                    .build()
-            } else {
-                let filename: &str = vocab_object.extract()?;
-                deprecation_warning(
-                    "0.9.0",
-                    "WordLevel.__init__ will not create from files anymore, try `WordLevel.from_file` instead",
-                )?;
-                WordLevel::from_files(filename, unk_token).map_err(|e| {
-                    exceptions::PyException::new_err(format!(
-                        "Error while loading WordLevel: {}",
-                        e
-                    ))
-                })?
+        if let Some(vocab) = vocab {
+            let model = match vocab {
+                PyVocab::Vocab(vocab) => WordLevel::builder()
+                    .vocab(vocab)
+                    .unk_token(unk_token)
+                    .build(),
+                PyVocab::Filename(vocab_filename) => {
+                    deprecation_warning(
+                        "0.9.0",
+                        "WordLevel.__init__ will not create from files anymore, try `WordLevel.from_file` instead",
+                    )?;
+                    WordLevel::from_files(vocab_filename, unk_token).map_err(|e| {
+                        exceptions::PyException::new_err(format!(
+                            "Error while loading WordLevel: {}",
+                            e
+                        ))
+                    })?
+                }
             };

             Ok((PyWordLevel {}, PyModel::new(Arc::new(model.into()))))
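The same two-path pattern applies to the other models. A brief, hedged Python sketch (in-memory vocabs are plain token-to-id dicts per the bindings above; the `unk_token` keyword mirrors its use elsewhere in this commit's scripts, and the file path in the comment is a placeholder):

```python
from tokenizers.models import WordPiece, WordLevel

# In-memory construction, no deprecation warning.
wp = WordPiece({"[UNK]": 0, "my": 1, "##name": 2}, unk_token="[UNK]")
wl = WordLevel({"<unk>": 0, "hello": 1, "world": 2}, unk_token="<unk>")

# Passing a filename still works but is deprecated in 0.9.0 in favour of
# WordPiece.from_file / WordLevel.from_file:
# wp = WordPiece("vocab.txt")
```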

View File

@@ -14,22 +14,33 @@ class TestBPE:
         vocab = {"a": 0, "b": 1, "ab": 2}
         merges = {(0, 1): (0, 2)}
         assert isinstance(BPE(vocab, merges), Model)

-        with pytest.raises(ValueError, match="`vocab` and `merges` must be both specified"):
+        with pytest.raises(
+            ValueError, match="`vocab` and `merges` must be both specified"
+        ):
             BPE(vocab=vocab)
             BPE(merges=merges)

-        assert isinstance(pickle.loads(pickle.dumps(BPE(vocab, merges))), BPE,)
+        assert isinstance(
+            pickle.loads(pickle.dumps(BPE(vocab, merges))),
+            BPE,
+        )

         # Deprecated calls in 0.9
         with pytest.deprecated_call():
-            assert isinstance(BPE(roberta_files["vocab"], roberta_files["merges"]), Model)
+            assert isinstance(
+                BPE(roberta_files["vocab"], roberta_files["merges"]), Model
+            )

-        with pytest.raises(ValueError, match="`vocab` and `merges` must be both specified"):
+        with pytest.raises(
+            ValueError, match="`vocab` and `merges` must be both specified"
+        ):
             BPE(vocab=roberta_files["vocab"])
             BPE(merges=roberta_files["merges"])
         with pytest.deprecated_call():
             assert isinstance(
-                pickle.loads(pickle.dumps(BPE(roberta_files["vocab"], roberta_files["merges"]))),
+                pickle.loads(
+                    pickle.dumps(BPE(roberta_files["vocab"], roberta_files["merges"]))
+                ),
                 BPE,
             )
@@ -48,7 +59,9 @@ class TestWordPiece:
         with pytest.deprecated_call():
             assert isinstance(WordPiece(bert_files["vocab"]), Model)
         with pytest.deprecated_call():
-            assert isinstance(pickle.loads(pickle.dumps(WordPiece(bert_files["vocab"]))), WordPiece)
+            assert isinstance(
+                pickle.loads(pickle.dumps(WordPiece(bert_files["vocab"]))), WordPiece
+            )


 class TestWordLevel:

View File

@@ -144,7 +144,9 @@ class TestTokenizer:
         assert output.tokens == ["my", "name", "is", "john"]

         # Can encode a batch with both a single sequence and a pair of sequences
-        output = tokenizer.encode_batch(["my name is john", ("my name is john", "pair")])
+        output = tokenizer.encode_batch(
+            ["my name is john", ("my name is john", "pair")]
+        )
         assert len(output) == 2

     def test_encode_formats(self, bert_files):
@@ -167,7 +169,9 @@ class TestTokenizer:
         ]
         output = tokenizer.encode(["my", "name", "is", "john"], is_pretokenized=True)
         assert output.tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]"]
-        output = tokenizer.encode(["my", "name", "is", "john"], ["pair"], is_pretokenized=True)
+        output = tokenizer.encode(
+            ["my", "name", "is", "john"], ["pair"], is_pretokenized=True
+        )
         assert output.tokens == [
             "[CLS]",
             "my",
@@ -213,13 +217,19 @@ class TestTokenizer:
         # Numpy
         test_single(np.array(["My name is John", "My name is Georges"]))
-        test_pair(np.array([("My name is John", "pair"), ("My name is Georges", "pair")]))
-        test_pair(np.array([["My name is John", "pair"], ["My name is Georges", "pair"]]))
+        test_pair(
+            np.array([("My name is John", "pair"), ("My name is Georges", "pair")])
+        )
+        test_pair(
+            np.array([["My name is John", "pair"], ["My name is Georges", "pair"]])
+        )

         # PreTokenized inputs

         # Lists
-        test_single([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]], True)
+        test_single(
+            [["My", "name", "is", "John"], ["My", "name", "is", "Georges"]], True
+        )
         test_pair(
             [
                 (["My", "name", "is", "John"], ["pair"]),
@@ -236,7 +246,9 @@ class TestTokenizer:
         )

         # Tuples
-        test_single((("My", "name", "is", "John"), ("My", "name", "is", "Georges")), True)
+        test_single(
+            (("My", "name", "is", "John"), ("My", "name", "is", "Georges")), True
+        )
         test_pair(
             (
                 (("My", "name", "is", "John"), ("pair",)),
@@ -254,10 +266,12 @@ class TestTokenizer:
         # Numpy
         test_single(
-            np.array([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]]), True,
+            np.array([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]]),
+            True,
         )
         test_single(
-            np.array((("My", "name", "is", "John"), ("My", "name", "is", "Georges"))), True,
+            np.array((("My", "name", "is", "John"), ("My", "name", "is", "Georges"))),
+            True,
         )
         test_pair(
             np.array(
@@ -298,11 +312,14 @@ class TestTokenizer:
         tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
         tokenizer.post_processor = RobertaProcessing(
-            ("</s>", tokenizer.token_to_id("</s>")), ("<s>", tokenizer.token_to_id("<s>")),
+            ("</s>", tokenizer.token_to_id("</s>")),
+            ("<s>", tokenizer.token_to_id("<s>")),
         )

         # Can encode with special tokens
-        output_with_specials = tokenizer.encode("My name is John", add_special_tokens=True)
+        output_with_specials = tokenizer.encode(
+            "My name is John", add_special_tokens=True
+        )
         assert output_with_specials.tokens == [
             "<s>",
             "ĠMy",
@@ -313,7 +330,9 @@ class TestTokenizer:
         ]

         # Can encode without special tokens
-        output_without_specials = tokenizer.encode("My name is John", add_special_tokens=False)
+        output_without_specials = tokenizer.encode(
+            "My name is John", add_special_tokens=False
+        )
         assert output_without_specials.tokens == ["ĠMy", "Ġname", "Ġis", "ĠJohn"]

     def test_truncation(self):

View File

@@ -67,7 +67,10 @@ class TestByteLevelBPE:
     def test_lowerspace(self, roberta_files):
         tokenizer = ByteLevelBPETokenizer.from_files(
-            roberta_files["vocab"], roberta_files["merges"], add_prefix_space=True, lowercase=True,
+            roberta_files["vocab"],
+            roberta_files["merges"],
+            add_prefix_space=True,
+            lowercase=True,
         )
         output = tokenizer.encode("The Quick Brown Fox Jumps Over The Lazy Dog")

View File

@@ -6,7 +6,9 @@ from tokenizers import CharBPETokenizer
 class TestBertWordPieceBPE:
     def test_basic_encode(self, openai_files):
-        tokenizer = CharBPETokenizer.from_files(openai_files["vocab"], openai_files["merges"])
+        tokenizer = CharBPETokenizer.from_files(
+            openai_files["vocab"], openai_files["merges"]
+        )
         output = tokenizer.encode("My name is John", "pair")
         assert output.ids == [0, 253, 1362, 544, 0, 7, 12662, 2688]
@@ -50,6 +52,8 @@ class TestBertWordPieceBPE:
         assert decoded == "my name is john"

     def test_multiprocessing_with_parallelism(self, openai_files):
-        tokenizer = CharBPETokenizer.from_files(openai_files["vocab"], openai_files["merges"])
+        tokenizer = CharBPETokenizer.from_files(
+            openai_files["vocab"], openai_files["merges"]
+        )
        multiprocessing_with_parallelism(tokenizer, False)
        multiprocessing_with_parallelism(tokenizer, True)