Python - Tweak BPE constructor + add some tests

This commit is contained in:
Anthony MOI
2020-04-08 14:00:17 -04:00
parent be7b345bcd
commit 4cb77ca64c
6 changed files with 43 additions and 10 deletions

View File

@@ -154,16 +154,15 @@ impl BPE {
merges: Option<&str>, merges: Option<&str>,
kwargs: Option<&PyDict>, kwargs: Option<&PyDict>,
) -> PyResult<(Self, Model)> { ) -> PyResult<(Self, Model)> {
if (vocab.is_some() && merges.is_none()) || (vocab.is_none() && merges.is_some()) {
return Err(exceptions::ValueError::py_err(
"`vocab` and `merges` must be both specified",
));
}
let mut builder = tk::models::bpe::BPE::builder(); let mut builder = tk::models::bpe::BPE::builder();
if let Some(vocab) = vocab { if let (Some(vocab), Some(merges)) = (vocab, merges) {
if let Some(merges) = merges { builder = builder.files(vocab.to_owned(), merges.to_owned());
builder = builder.files(vocab.to_owned(), merges.to_owned());
} else {
return Err(exceptions::Exception::py_err(format!(
"Got vocab file ({}), but missing merges",
vocab
)));
}
} }
if let Some(kwargs) = kwargs { if let Some(kwargs) = kwargs {
for (key, value) in kwargs { for (key, value) in kwargs {

View File

@@ -7,6 +7,7 @@ class TestByteLevel:
def test_instantiate(self): def test_instantiate(self):
assert ByteLevel() is not None assert ByteLevel() is not None
assert isinstance(ByteLevel(), Decoder) assert isinstance(ByteLevel(), Decoder)
assert isinstance(ByteLevel(), ByteLevel)
def test_decoding(self): def test_decoding(self):
decoder = ByteLevel() decoder = ByteLevel()
@@ -19,6 +20,7 @@ class TestWordPiece:
assert WordPiece(prefix="__") is not None assert WordPiece(prefix="__") is not None
assert WordPiece(cleanup=True) is not None assert WordPiece(cleanup=True) is not None
assert isinstance(WordPiece(), Decoder) assert isinstance(WordPiece(), Decoder)
assert isinstance(WordPiece(), WordPiece)
def test_decoding(self): def test_decoding(self):
decoder = WordPiece() decoder = WordPiece()
@@ -37,6 +39,7 @@ class TestMetaspace:
Metaspace(replacement="") Metaspace(replacement="")
assert Metaspace(add_prefix_space=True) is not None assert Metaspace(add_prefix_space=True) is not None
assert isinstance(Metaspace(), Decoder) assert isinstance(Metaspace(), Decoder)
assert isinstance(Metaspace(), Metaspace)
def test_decoding(self): def test_decoding(self):
decoder = Metaspace() decoder = Metaspace()
@@ -50,6 +53,7 @@ class TestBPEDecoder:
assert BPEDecoder() is not None assert BPEDecoder() is not None
assert BPEDecoder(suffix="_") is not None assert BPEDecoder(suffix="_") is not None
assert isinstance(BPEDecoder(), Decoder) assert isinstance(BPEDecoder(), Decoder)
assert isinstance(BPEDecoder(), BPEDecoder)
def test_decoding(self): def test_decoding(self):
decoder = BPEDecoder() decoder = BPEDecoder()

View File

@@ -1,3 +1,5 @@
import pytest
from ..utils import data_dir, roberta_files, bert_files from ..utils import data_dir, roberta_files, bert_files
from tokenizers.models import Model, BPE, WordPiece, WordLevel from tokenizers.models import Model, BPE, WordPiece, WordLevel
@@ -8,6 +10,9 @@ class TestBPE:
assert isinstance(BPE(), Model) assert isinstance(BPE(), Model)
assert isinstance(BPE(), BPE) assert isinstance(BPE(), BPE)
assert isinstance(BPE(roberta_files["vocab"], roberta_files["merges"]), Model) assert isinstance(BPE(roberta_files["vocab"], roberta_files["merges"]), Model)
with pytest.raises(ValueError, match="`vocab` and `merges` must be both specified"):
BPE(vocab=roberta_files["vocab"])
BPE(merges=roberta_files["merges"])
class TestWordPiece: class TestWordPiece:

View File

@@ -1,9 +1,13 @@
from tokenizers import Tokenizer from tokenizers import Tokenizer
from tokenizers.models import BPE from tokenizers.models import BPE
from tokenizers.normalizers import BertNormalizer, Sequence, Lowercase, Strip from tokenizers.normalizers import Normalizer, BertNormalizer, Sequence, Lowercase, Strip
class TestBertNormalizer: class TestBertNormalizer:
def test_instantiate(self):
assert isinstance(BertNormalizer(), Normalizer)
assert isinstance(BertNormalizer(), BertNormalizer)
def test_strip_accents(self): def test_strip_accents(self):
tokenizer = Tokenizer(BPE()) tokenizer = Tokenizer(BPE())
tokenizer.normalizer = BertNormalizer( tokenizer.normalizer = BertNormalizer(
@@ -42,6 +46,10 @@ class TestBertNormalizer:
class TestSequence: class TestSequence:
def test_instantiate(self):
assert isinstance(Sequence([]), Normalizer)
assert isinstance(Sequence([]), Sequence)
def test_can_make_sequences(self): def test_can_make_sequences(self):
tokenizer = Tokenizer(BPE()) tokenizer = Tokenizer(BPE())
tokenizer.normalizer = Sequence([Lowercase(), Strip()]) tokenizer.normalizer = Sequence([Lowercase(), Strip()])
@@ -51,6 +59,10 @@ class TestSequence:
class TestLowercase: class TestLowercase:
def test_instantiate(self):
assert isinstance(Lowercase(), Normalizer)
assert isinstance(Lowercase(), Lowercase)
def test_lowercase(self): def test_lowercase(self):
tokenizer = Tokenizer(BPE()) tokenizer = Tokenizer(BPE())
tokenizer.normalizer = Lowercase() tokenizer.normalizer = Lowercase()
@@ -60,6 +72,10 @@ class TestLowercase:
class TestStrip: class TestStrip:
def test_instantiate(self):
assert isinstance(Strip(), Normalizer)
assert isinstance(Strip(), Strip)
def test_left_strip(self): def test_left_strip(self):
tokenizer = Tokenizer(BPE()) tokenizer = Tokenizer(BPE())
tokenizer.normalizer = Strip(left=True, right=False) tokenizer.normalizer = Strip(left=True, right=False)

View File

@@ -17,6 +17,7 @@ class TestByteLevel:
assert ByteLevel(add_prefix_space=True) is not None assert ByteLevel(add_prefix_space=True) is not None
assert ByteLevel(add_prefix_space=False) is not None assert ByteLevel(add_prefix_space=False) is not None
assert isinstance(ByteLevel(), PreTokenizer) assert isinstance(ByteLevel(), PreTokenizer)
assert isinstance(ByteLevel(), ByteLevel)
def test_has_alphabet(self): def test_has_alphabet(self):
assert isinstance(ByteLevel.alphabet(), list) assert isinstance(ByteLevel.alphabet(), list)
@@ -27,18 +28,21 @@ class TestWhitespace:
def test_instantiate(self): def test_instantiate(self):
assert Whitespace() is not None assert Whitespace() is not None
assert isinstance(Whitespace(), PreTokenizer) assert isinstance(Whitespace(), PreTokenizer)
assert isinstance(Whitespace(), Whitespace)
class TestWhitespaceSplit: class TestWhitespaceSplit:
def test_instantiate(self): def test_instantiate(self):
assert WhitespaceSplit() is not None assert WhitespaceSplit() is not None
assert isinstance(WhitespaceSplit(), PreTokenizer) assert isinstance(WhitespaceSplit(), PreTokenizer)
assert isinstance(WhitespaceSplit(), WhitespaceSplit)
class TestBertPreTokenizer: class TestBertPreTokenizer:
def test_instantiate(self): def test_instantiate(self):
assert BertPreTokenizer() is not None assert BertPreTokenizer() is not None
assert isinstance(BertPreTokenizer(), PreTokenizer) assert isinstance(BertPreTokenizer(), PreTokenizer)
assert isinstance(BertPreTokenizer(), BertPreTokenizer)
class TestMetaspace: class TestMetaspace:
@@ -49,6 +53,7 @@ class TestMetaspace:
Metaspace(replacement="") Metaspace(replacement="")
assert Metaspace(add_prefix_space=True) is not None assert Metaspace(add_prefix_space=True) is not None
assert isinstance(Metaspace(), PreTokenizer) assert isinstance(Metaspace(), PreTokenizer)
assert isinstance(Metaspace(), Metaspace)
class TestCharDelimiterSplit: class TestCharDelimiterSplit:
@@ -57,3 +62,4 @@ class TestCharDelimiterSplit:
with pytest.raises(Exception, match="delimiter must be a single character"): with pytest.raises(Exception, match="delimiter must be a single character"):
CharDelimiterSplit("") CharDelimiterSplit("")
assert isinstance(CharDelimiterSplit(" "), PreTokenizer) assert isinstance(CharDelimiterSplit(" "), PreTokenizer)
assert isinstance(CharDelimiterSplit(" "), CharDelimiterSplit)

View File

@@ -11,6 +11,7 @@ class TestBertProcessing:
processor = BertProcessing(("[SEP]", 0), ("[CLS]", 1)) processor = BertProcessing(("[SEP]", 0), ("[CLS]", 1))
assert processor is not None assert processor is not None
assert isinstance(processor, PostProcessor) assert isinstance(processor, PostProcessor)
assert isinstance(processor, BertProcessing)
def test_processing(self): def test_processing(self):
tokenizer = Tokenizer(BPE()) tokenizer = Tokenizer(BPE())
@@ -28,6 +29,7 @@ class TestRobertaProcessing:
processor = RobertaProcessing(("</s>", 1), ("<s>", 0)) processor = RobertaProcessing(("</s>", 1), ("<s>", 0))
assert processor is not None assert processor is not None
assert isinstance(processor, PostProcessor) assert isinstance(processor, PostProcessor)
assert isinstance(processor, RobertaProcessing)
def test_processing(self): def test_processing(self):
tokenizer = Tokenizer(BPE()) tokenizer = Tokenizer(BPE())
@@ -45,6 +47,7 @@ class TestByteLevelProcessing:
assert ByteLevel() is not None assert ByteLevel() is not None
assert ByteLevel(trim_offsets=True) is not None assert ByteLevel(trim_offsets=True) is not None
assert isinstance(ByteLevel(), PostProcessor) assert isinstance(ByteLevel(), PostProcessor)
assert isinstance(ByteLevel(), ByteLevel)
def test_processing(self, roberta_files): def test_processing(self, roberta_files):
tokenizer = Tokenizer(BPE(roberta_files["vocab"], roberta_files["merges"])) tokenizer = Tokenizer(BPE(roberta_files["vocab"], roberta_files["merges"]))