Mirror of https://github.com/mii443/tokenizers.git (synced 2025-12-04 19:58:21 +00:00)
Addressing first pass of comments.
@@ -70,13 +70,17 @@ elif args.type == "bert":
 
     tok_r = Tokenizer(WordPiece(args.vocab, unk_token="[UNK]", max_input_chars_per_word=100))
     tok_r.normalizer = BertNormalizer(
-        clean_text=True, handle_chinese_chars=True, strip_accents=True, lowercase=True,
+        clean_text=True,
+        handle_chinese_chars=True,
+        strip_accents=True,
+        lowercase=True,
     )
     # tok_r.pre_tokenizer = pre_tokenizers.Whitespace()
     tok_r.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
     tok_r.decoder = decoders.WordPiece()
     tok_r.post_processor = BertProcessing(
-        ("[SEP]", tok_r.token_to_id("[SEP]")), ("[CLS]", tok_r.token_to_id("[CLS]")),
+        ("[SEP]", tok_r.token_to_id("[SEP]")),
+        ("[CLS]", tok_r.token_to_id("[CLS]")),
     )
 else:
     raise Exception(f"Unknown type {args.type}")
@@ -32,7 +32,10 @@ if not files:
 
 # Initialize an empty tokenizer
 tokenizer = BertWordPieceTokenizer(
-    clean_text=True, handle_chinese_chars=True, strip_accents=True, lowercase=True,
+    clean_text=True,
+    handle_chinese_chars=True,
+    strip_accents=True,
+    lowercase=True,
 )
 
 # And then train
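For context, a hedged sketch of the training call that typically follows this initialization in the example script; it reuses the script's `tokenizer` and `files` variables, and the argument values below are illustrative rather than the script's actual settings:

tokenizer.train(
    files,
    vocab_size=30000,
    min_frequency=2,
    special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
)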
@@ -76,43 +76,28 @@ class SpmConverter(Converter):
         model_type = proto.trainer_spec.model_type
         vocab = self.vocab(proto)
         unk_id = self.unk_id(proto)
-        filename = self.original_tokenizer.vocab_file
 
         if model_type == 1:
-            data = {"unk_id": unk_id, "vocab": vocab}
-
-            out_vocab_filename = f"{filename}.json"
-            try:
-                with open(out_vocab_filename, "w") as f:
-                    json.dump(data, f, indent=4)
-
-                tokenizer = Tokenizer(Unigram(out_vocab_filename))
-            finally:
-                os.remove(out_vocab_filename)
+            tokenizer = Tokenizer(Unigram(vocab, unk_id))
         elif model_type == 2:
-            vocab, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract()
+            vocab, merges = SentencePieceExtractor(
+                self.original_tokenizer.vocab_file
+            ).extract()
             # Open output files and let's extract model information
-            out_vocab_filename = f"{filename}.vocab"
-            out_merge_filename = f"{filename}.merge"
-            try:
-                with open(out_vocab_filename, "w") as vocab_f:
-                    json.dump(vocab, vocab_f)
-                try:
-                    with open(out_merge_filename, "w") as merges_f:
-                        # Save content
-                        merges_f.writelines(map(lambda x: f"{x[0]} {x[1]}{os.linesep}", merges))
-                    tokenizer = Tokenizer(
-                        BPE(
-                            out_vocab_filename,
-                            out_merge_filename,
-                            unk_token=proto.trainer_spec.unk_piece,
-                            fuse_unk=True,
-                        )
-                    )
-                finally:
-                    os.remove(out_merge_filename)
-            finally:
-                os.remove(out_vocab_filename)
+            actual_merges = {}
+            for id_merge, (a, b) in enumerate(merges):
+                id_a = vocab[a]
+                id_b = vocab[b]
+                id_ab = vocab[a + b]
+                actual_merges[(id_a, id_b)] = (id_merge, id_ab)
+            tokenizer = Tokenizer(
+                BPE(
+                    vocab,
+                    actual_merges,
+                    unk_token=proto.trainer_spec.unk_piece,
+                    fuse_unk=True,
+                )
+            )
         else:
             raise Exception(
                 "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
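For orientation, a minimal standalone sketch of the in-memory construction this hunk switches to, assuming the 0.9-era Python API used on this branch (the vocab and merges values below are toy data, not real model contents):

from tokenizers import Tokenizer
from tokenizers.models import BPE, Unigram

# Unigram now takes the (piece, score) pairs plus the unk id directly,
# instead of the path to a temporary JSON dump of the same data.
unigram_vocab = [("<unk>", 0.0), ("▁hello", -1.5), ("▁world", -2.0)]
unigram_tok = Tokenizer(Unigram(unigram_vocab, 0))

# BPE now takes an in-memory vocab (token -> id) and merges keyed by id pairs,
# (id_a, id_b) -> (rank, id_ab), instead of paths to temporary .vocab/.merge files.
bpe_vocab = {"a": 0, "b": 1, "ab": 2}
bpe_merges = {(0, 1): (0, 2)}
bpe_tok = Tokenizer(BPE(bpe_vocab, bpe_merges))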
@@ -346,7 +331,9 @@ class PegasusConverter(SpmConverter):
         return TemplateProcessing(
             seq_a=["$0", eos],
             seq_b=["$1", eos],
-            special_tokens=[(eos, tokenizer.get_vocab()[eos]),],
+            special_tokens=[
+                (eos, tokenizer.get_vocab()[eos]),
+            ],
         )
 
 
@@ -355,7 +342,9 @@ class T5Converter(SpmConverter):
         return TemplateProcessing(
             seq_a=["$0", "</s>"],
             seq_b=["$1", "</s>"],
-            special_tokens=[("</s>", tokenizer.get_vocab()["</s>"]),],
+            special_tokens=[
+                ("</s>", tokenizer.get_vocab()["</s>"]),
+            ],
         )
 
 
@@ -447,7 +436,9 @@ def main():
     model_len = 50
     status_len = 6
     speedup_len = 8
-    print(f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|")
+    print(
+        f"|{'Model':^{model_len}}|{'Status':^{status_len}}|{'Speedup':^{speedup_len}}|"
+    )
     print(f"|{'-'*model_len}|{'-'*status_len}|{'-'*speedup_len}|")
     for pretrained in args.models:
         status, speedup = check(pretrained, args.filename)
@@ -17,7 +17,11 @@ except Exception:
 def main():
     parser = ArgumentParser("SentencePiece parity checker")
     parser.add_argument(
-        "--input-file", "-i", type=str, required=True, help="Which files do you want to train from",
+        "--input-file",
+        "-i",
+        type=str,
+        required=True,
+        help="Which files do you want to train from",
     )
     parser.add_argument(
         "--model-file",
@@ -28,13 +32,22 @@ def main():
         help="Use a pretrained token file",
     )
     parser.add_argument(
-        "--model-prefix", type=str, default="spm_parity", help="Model prefix for spm_train",
+        "--model-prefix",
+        type=str,
+        default="spm_parity",
+        help="Model prefix for spm_train",
     )
     parser.add_argument(
-        "--vocab-size", "-v", type=int, default=8000, help="Vocab size for spm_train",
+        "--vocab-size",
+        "-v",
+        type=int,
+        default=8000,
+        help="Vocab size for spm_train",
    )
     parser.add_argument(
-        "--verbose", action="store_true", help="Verbosity",
+        "--verbose",
+        action="store_true",
+        help="Verbosity",
     )
     parser.add_argument(
         "--train",
@@ -160,10 +173,14 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
     spms = Counter(spm_ids[first:last])
     toks = Counter(tok_ids[first:last])
 
-    removable_tokens = {spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si}
+    removable_tokens = {
+        spm_ for (spm_, si) in spms.items() if toks.get(spm_, 0) == si
+    }
     min_width = 3
     for i in range(last - first - min_width):
-        if all(spm_ids[first + i + j] in removable_tokens for j in range(min_width)):
+        if all(
+            spm_ids[first + i + j] in removable_tokens for j in range(min_width)
+        ):
             possible_matches = [
                 k
                 for k in range(last - first - min_width)
@@ -174,7 +191,11 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
                 if check_diff(
                     spm_ids[first : first + i], tok_ids[first : first + j], sp, tok
                 ) and check_details(
-                    line, spm_ids[first + i : last], tok_ids[first + j : last], sp, tok,
+                    line,
+                    spm_ids[first + i : last],
+                    tok_ids[first + j : last],
+                    sp,
+                    tok,
                 ):
                     return True
 
@@ -189,7 +210,9 @@ def check_details(line, spm_ids, tok_ids, sp, tok):
     wrong = tok.decode(spm_ids[first:last])
     print()
     if has_color:
-        print(f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}")
+        print(
+            f"{colored(ok_start, 'grey')}{colored(wrong, 'red')}{colored(ok_end, 'grey')}"
+        )
     else:
         print(wrong)
     return False
@@ -203,17 +226,8 @@ def check_encode(args):
         tok = tokenizers.SentencePieceUnigramTokenizer.from_spm(args.model_file)
     else:
         vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
-        vocab_filename = f"{args.model_file}.json"
         unk_id = sp.unk_id()
-        data = {"unk_id": unk_id, "vocab": vocab}
-        try:
-            with open(vocab_filename, "w") as f:
-                json.dump(data, f, indent=4)
-
-            tok = tokenizers.SentencePieceUnigramTokenizer(vocab_filename)
-        finally:
-            os.remove(vocab_filename)
+        tok = tokenizers.SentencePieceUnigramTokenizer(vocab, unk_id)
 
     perfect = 0
     imperfect = 0
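The same in-memory path, sketched outside the parity script (hedged: "spm.model" is a placeholder for any trained SentencePiece model file):

import sentencepiece as spm
import tokenizers

# Pull (piece, score) pairs straight out of a trained SentencePiece model and
# hand them to the implementation, with no temporary JSON file on disk.
sp = spm.SentencePieceProcessor()
sp.Load("spm.model")  # placeholder path

vocab = [(sp.id_to_piece(i), sp.get_score(i)) for i in range(sp.piece_size())]
tok = tokenizers.SentencePieceUnigramTokenizer(vocab, sp.unk_id())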
@@ -255,7 +269,9 @@ def check_encode(args):
 
     print(f"({perfect} / {imperfect} / {wrong} ----- {perfect + imperfect + wrong})")
     total = perfect + imperfect + wrong
-    print(f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}")
+    print(
+        f"Accuracy {perfect * 100 / total:.2f} Slowdown : {tok_total_time/ spm_total_time:.2f}"
+    )
 
 
 if __name__ == "__main__":
@@ -34,3 +34,11 @@ impl<T> ToPyResult<T> {
         self.into()
     }
 }
+
+pub(crate) fn deprecation_warning(version: &str, message: &str) -> PyResult<()> {
+    let gil = pyo3::Python::acquire_gil();
+    let python = gil.python();
+    let deprecation_warning = python.import("builtins")?.get("DeprecationWarning")?;
+    let full_message = format!("Deprecated in {}: {}", version, message);
+    pyo3::PyErr::warn(python, deprecation_warning, &full_message, 0)
+}
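On the Python side this helper surfaces as an ordinary DeprecationWarning. A hedged sketch of observing it, assuming the 0.9-era file-based WordPiece constructor that is deprecated further down in this commit (the vocab file is a throwaway created only for the demo):

import os
import tempfile
import warnings

from tokenizers.models import WordPiece

# A throwaway WordPiece vocab: one token per line.
fd, vocab_path = tempfile.mkstemp(suffix=".txt")
with os.fdopen(fd, "w") as f:
    f.write("[UNK]\nmy\nname\n")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    WordPiece(vocab_path)  # deprecated path: the message suggests WordPiece.from_file
os.remove(vocab_path)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)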
@@ -15,16 +15,7 @@ use tk::models::ModelWrapper;
 use tk::{Model, Token};
 use tokenizers as tk;
 
-use super::error::ToPyResult;
+use super::error::{deprecation_warning, ToPyResult};
 
-fn deprecation_warning(version: &str, message: &str) -> PyResult<()> {
-    let gil = pyo3::Python::acquire_gil();
-    let python = gil.python();
-    let deprecation_warning = python.import("builtins")?.get("DeprecationWarning")?;
-    let full_message = format!("Deprecated in {}: {}", version, message);
-    pyo3::PyErr::warn(python, deprecation_warning, &full_message, 0)?;
-    Ok(())
-}
-
 /// A Model represents some tokenization algorithm like BPE or Word
 /// This class cannot be constructed directly. Please use one of the concrete models.
@@ -183,13 +174,24 @@ impl PyBPE {
     }
 }
 
+#[derive(FromPyObject)]
+enum PyVocab<'a> {
+    Vocab(Vocab),
+    Filename(&'a str),
+}
+#[derive(FromPyObject)]
+enum PyMerges<'a> {
+    Merges(Merges),
+    Filename(&'a str),
+}
+
 #[pymethods]
 impl PyBPE {
     #[new]
     #[args(kwargs = "**")]
     fn new(
-        vocab: Option<&PyAny>,
-        merges: Option<&PyAny>,
+        vocab: Option<PyVocab>,
+        merges: Option<PyMerges>,
         kwargs: Option<&PyDict>,
     ) -> PyResult<(Self, PyModel)> {
         if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) {
@@ -199,17 +201,24 @@ impl PyBPE {
         }
 
         let mut builder = BPE::builder();
-        if let (Some(vocab_any), Some(merges_any)) = (vocab, merges) {
-            if let (Ok(vocab), Ok(merges)) = (vocab_any.extract(), merges_any.extract()) {
-                builder = builder.vocab_and_merges(vocab, merges);
-            } else {
-                let vocab_filename: String = vocab_any.extract()?;
-                let merges_filename: String = merges_any.extract()?;
+        if let (Some(vocab), Some(merges)) = (vocab, merges) {
+            match (vocab, merges) {
+                (PyVocab::Vocab(vocab), PyMerges::Merges(merges)) => {
+                    builder = builder.vocab_and_merges(vocab, merges);
+                }
+                (PyVocab::Filename(vocab_filename), PyMerges::Filename(merges_filename)) => {
                     deprecation_warning(
                         "0.9.0",
                         "BPE.__init__ will not create from files anymore, try `BPE.from_files` instead",
                     )?;
-                builder = builder.files(vocab_filename, merges_filename);
+                    builder =
+                        builder.files(vocab_filename.to_string(), merges_filename.to_string());
+                }
+                _ => {
+                    return Err(exceptions::PyValueError::new_err(
+                        "`vocab` and `merges` must be both be from memory or both filenames",
+                    ));
+                }
             }
         }
 
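Seen from Python, a hedged sketch of what this match accepts and rejects under the 0.9-era API of this branch ("merges.txt" is a made-up path that is never opened, since the mixed case is rejected before any file I/O):

from tokenizers.models import BPE

vocab = {"a": 0, "b": 1, "ab": 2}
merges = {(0, 1): (0, 2)}

BPE(vocab, merges)  # both in memory: accepted, no deprecation warning

try:
    BPE(vocab, "merges.txt")  # one dict, one filename: rejected up front
except ValueError as err:
    print(err)  # `vocab` and `merges` must be both be from memory or both filenames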
@@ -268,20 +277,21 @@ impl PyWordPiece {
 impl PyWordPiece {
     #[new]
     #[args(kwargs = "**")]
-    fn new(vocab: Option<&PyAny>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
+    fn new(vocab: Option<PyVocab>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
         let mut builder = WordPiece::builder();
 
-        if let Some(vocab_any) = vocab {
-            #[allow(deprecated)]
-            if let Ok(vocab) = vocab_any.extract() {
+        if let Some(vocab) = vocab {
+            match vocab {
+                PyVocab::Vocab(vocab) => {
                     builder = builder.vocab(vocab);
-            } else {
-                deprecation_warning(
-                    "0.9.0",
-                    "WordPiece.__init__ will not create from files anymore, try `WordPiece.from_file` instead",
-                )?;
-                let vocab_filename: String = vocab_any.extract()?;
-                builder = builder.files(vocab_filename);
+                }
+                PyVocab::Filename(vocab_filename) => {
+                    deprecation_warning(
+                        "0.9.0",
+                        "WordPiece.__init__ will not create from files anymore, try `WordPiece.from_file` instead",
+                    )?;
+                    builder = builder.files(vocab_filename.to_string());
+                }
             }
         }
 
@@ -320,27 +330,27 @@ impl PyWordLevel {
 impl PyWordLevel {
     #[new]
     #[args(kwargs = "**")]
-    fn new(vocab: Option<&PyAny>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
+    fn new(vocab: Option<PyVocab>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
         let unk_token = PyWordLevel::get_unk(kwargs)?;
 
-        if let Some(vocab_object) = vocab {
-            let model = if let Ok(vocab) = vocab_object.extract() {
-                WordLevel::builder()
+        if let Some(vocab) = vocab {
+            let model = match vocab {
+                PyVocab::Vocab(vocab) => WordLevel::builder()
                     .vocab(vocab)
                     .unk_token(unk_token)
-                    .build()
-            } else {
-                let filename: &str = vocab_object.extract()?;
-                deprecation_warning(
+                    .build(),
+                PyVocab::Filename(vocab_filename) => {
+                    deprecation_warning(
                         "0.9.0",
                         "WordLevel.__init__ will not create from files anymore, try `WordLevel.from_file` instead",
                     )?;
-                WordLevel::from_files(filename, unk_token).map_err(|e| {
+                    WordLevel::from_files(vocab_filename, unk_token).map_err(|e| {
                         exceptions::PyException::new_err(format!(
                             "Error while loading WordLevel: {}",
                             e
                         ))
                     })?
+                }
             };
 
             Ok((PyWordLevel {}, PyModel::new(Arc::new(model.into()))))
@@ -14,22 +14,33 @@ class TestBPE:
         vocab = {"a": 0, "b": 1, "ab": 2}
         merges = {(0, 1): (0, 2)}
         assert isinstance(BPE(vocab, merges), Model)
-        with pytest.raises(ValueError, match="`vocab` and `merges` must be both specified"):
+        with pytest.raises(
+            ValueError, match="`vocab` and `merges` must be both specified"
+        ):
             BPE(vocab=vocab)
             BPE(merges=merges)
 
-        assert isinstance(pickle.loads(pickle.dumps(BPE(vocab, merges))), BPE,)
+        assert isinstance(
+            pickle.loads(pickle.dumps(BPE(vocab, merges))),
+            BPE,
+        )
 
         # Deprecated calls in 0.9
         with pytest.deprecated_call():
-            assert isinstance(BPE(roberta_files["vocab"], roberta_files["merges"]), Model)
+            assert isinstance(
+                BPE(roberta_files["vocab"], roberta_files["merges"]), Model
+            )
 
-        with pytest.raises(ValueError, match="`vocab` and `merges` must be both specified"):
+        with pytest.raises(
+            ValueError, match="`vocab` and `merges` must be both specified"
+        ):
             BPE(vocab=roberta_files["vocab"])
             BPE(merges=roberta_files["merges"])
         with pytest.deprecated_call():
             assert isinstance(
-                pickle.loads(pickle.dumps(BPE(roberta_files["vocab"], roberta_files["merges"]))),
+                pickle.loads(
+                    pickle.dumps(BPE(roberta_files["vocab"], roberta_files["merges"]))
+                ),
                 BPE,
             )
 
@@ -48,7 +59,9 @@ class TestWordPiece:
         with pytest.deprecated_call():
             assert isinstance(WordPiece(bert_files["vocab"]), Model)
         with pytest.deprecated_call():
-            assert isinstance(pickle.loads(pickle.dumps(WordPiece(bert_files["vocab"]))), WordPiece)
+            assert isinstance(
+                pickle.loads(pickle.dumps(WordPiece(bert_files["vocab"]))), WordPiece
+            )
 
 
 class TestWordLevel:
@@ -144,7 +144,9 @@ class TestTokenizer:
         assert output.tokens == ["my", "name", "is", "john"]
 
         # Can encode a batch with both a single sequence and a pair of sequences
-        output = tokenizer.encode_batch(["my name is john", ("my name is john", "pair")])
+        output = tokenizer.encode_batch(
+            ["my name is john", ("my name is john", "pair")]
+        )
         assert len(output) == 2
 
     def test_encode_formats(self, bert_files):
@@ -167,7 +169,9 @@ class TestTokenizer:
         ]
         output = tokenizer.encode(["my", "name", "is", "john"], is_pretokenized=True)
         assert output.tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]"]
-        output = tokenizer.encode(["my", "name", "is", "john"], ["pair"], is_pretokenized=True)
+        output = tokenizer.encode(
+            ["my", "name", "is", "john"], ["pair"], is_pretokenized=True
+        )
         assert output.tokens == [
             "[CLS]",
             "my",
@@ -213,13 +217,19 @@ class TestTokenizer:
 
         # Numpy
         test_single(np.array(["My name is John", "My name is Georges"]))
-        test_pair(np.array([("My name is John", "pair"), ("My name is Georges", "pair")]))
-        test_pair(np.array([["My name is John", "pair"], ["My name is Georges", "pair"]]))
+        test_pair(
+            np.array([("My name is John", "pair"), ("My name is Georges", "pair")])
+        )
+        test_pair(
+            np.array([["My name is John", "pair"], ["My name is Georges", "pair"]])
+        )
 
         # PreTokenized inputs
 
         # Lists
-        test_single([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]], True)
+        test_single(
+            [["My", "name", "is", "John"], ["My", "name", "is", "Georges"]], True
+        )
         test_pair(
             [
                 (["My", "name", "is", "John"], ["pair"]),
@@ -236,7 +246,9 @@ class TestTokenizer:
         )
 
         # Tuples
-        test_single((("My", "name", "is", "John"), ("My", "name", "is", "Georges")), True)
+        test_single(
+            (("My", "name", "is", "John"), ("My", "name", "is", "Georges")), True
+        )
         test_pair(
             (
                 (("My", "name", "is", "John"), ("pair",)),
@@ -254,10 +266,12 @@ class TestTokenizer:
 
         # Numpy
         test_single(
-            np.array([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]]), True,
+            np.array([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]]),
+            True,
         )
         test_single(
-            np.array((("My", "name", "is", "John"), ("My", "name", "is", "Georges"))), True,
+            np.array((("My", "name", "is", "John"), ("My", "name", "is", "Georges"))),
+            True,
         )
         test_pair(
             np.array(
@@ -298,11 +312,14 @@ class TestTokenizer:
 
         tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
         tokenizer.post_processor = RobertaProcessing(
-            ("</s>", tokenizer.token_to_id("</s>")), ("<s>", tokenizer.token_to_id("<s>")),
+            ("</s>", tokenizer.token_to_id("</s>")),
+            ("<s>", tokenizer.token_to_id("<s>")),
         )
 
         # Can encode with special tokens
-        output_with_specials = tokenizer.encode("My name is John", add_special_tokens=True)
+        output_with_specials = tokenizer.encode(
+            "My name is John", add_special_tokens=True
+        )
         assert output_with_specials.tokens == [
             "<s>",
             "ĠMy",
@@ -313,7 +330,9 @@ class TestTokenizer:
         ]
 
         # Can encode without special tokens
-        output_without_specials = tokenizer.encode("My name is John", add_special_tokens=False)
+        output_without_specials = tokenizer.encode(
+            "My name is John", add_special_tokens=False
+        )
         assert output_without_specials.tokens == ["ĠMy", "Ġname", "Ġis", "ĠJohn"]
 
     def test_truncation(self):
@@ -67,7 +67,10 @@ class TestByteLevelBPE:
 
     def test_lowerspace(self, roberta_files):
        tokenizer = ByteLevelBPETokenizer.from_files(
-            roberta_files["vocab"], roberta_files["merges"], add_prefix_space=True, lowercase=True,
+            roberta_files["vocab"],
+            roberta_files["merges"],
+            add_prefix_space=True,
+            lowercase=True,
         )
         output = tokenizer.encode("The Quick Brown Fox Jumps Over The Lazy Dog")
 
@@ -6,7 +6,9 @@ from tokenizers import CharBPETokenizer
 
 class TestBertWordPieceBPE:
     def test_basic_encode(self, openai_files):
-        tokenizer = CharBPETokenizer.from_files(openai_files["vocab"], openai_files["merges"])
+        tokenizer = CharBPETokenizer.from_files(
+            openai_files["vocab"], openai_files["merges"]
+        )
 
         output = tokenizer.encode("My name is John", "pair")
         assert output.ids == [0, 253, 1362, 544, 0, 7, 12662, 2688]
@@ -50,6 +52,8 @@ class TestBertWordPieceBPE:
         assert decoded == "my name is john"
 
     def test_multiprocessing_with_parallelism(self, openai_files):
-        tokenizer = CharBPETokenizer.from_files(openai_files["vocab"], openai_files["merges"])
+        tokenizer = CharBPETokenizer.from_files(
+            openai_files["vocab"], openai_files["merges"]
+        )
         multiprocessing_with_parallelism(tokenizer, False)
         multiprocessing_with_parallelism(tokenizer, True)