Changed the Rust API for merges, which is now Vec<(String, String)>

This commit is contained in:
Nicolas Patry
2020-09-23 19:09:46 +02:00
parent 36832bfa12
commit 95cc8c47ad
13 changed files with 94 additions and 84 deletions

View File

@@ -36,11 +36,7 @@ class BPE(Model):
A dictionary of string keys and their ids {"am": 0,...}
merges: (`optional`) string:
A dictionary of pairs of ids as keys and their merge correspondence:
{(id_left, id_right): (importance, id_merged), .... }
with vocab: {"a": 0, "b": 1, ... "ab": 4} the merge
{(0, 1): (0, 4) ,...}
corresponds to the "ab" merge, which is the most likely merge (importance 0)
A list of pairs of tokens [("a", "b"),...]
cache_capacity: (`optional`) int:
The number of words that the BPE cache can contain. The cache allows
@@ -66,7 +62,7 @@ class BPE(Model):
def __init__(
self,
vocab: Optional[Union[str, Dict[str, int]]],
merges: Optional[Union[str, Dict[Tuple[int, int], Tuple[int, int]]]],
merges: Optional[Union[str, List[Tuple[str, str]]]],
cache_capacity: Optional[int],
dropout: Optional[float],
unk_token: Optional[str],

View File

@@ -12,7 +12,7 @@ class TestBPE:
assert isinstance(BPE(), BPE)
vocab = {"a": 0, "b": 1, "ab": 2}
merges = {(0, 1): (0, 2)}
merges = [("a", "b")]
assert isinstance(BPE(vocab, merges), Model)
assert isinstance(BPE.from_file(roberta_files["vocab"], roberta_files["merges"]), BPE)
with pytest.raises(ValueError, match="`vocab` and `merges` must be both specified"):