Mirror of https://github.com/mii443/tokenizers.git (synced 2025-12-05 20:28:22 +00:00)
Python - Improve and Test EncodeInput extraction
@@ -177,7 +177,7 @@ impl<'s> FromPyObject<'s> for TextInputSequence<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
         let err = exceptions::ValueError::py_err("TextInputSequence must be str");
         if let Ok(s) = ob.downcast::<PyString>() {
-            Ok(Self(s.to_string()?.into()))
+            Ok(Self(s.to_string().map_err(|_| err)?.into()))
         } else {
             Err(err)
         }
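This first hunk maps the string-conversion failure onto the same ValueError used for the type check, so anything that cannot be read as a str surfaces with the one message the tests match against, instead of propagating a raw conversion error. A minimal sketch of the expected Python-side behavior; the vocab.txt path is a placeholder (the real tests use the bert_files fixture):

    from tokenizers import BertWordPieceTokenizer

    tokenizer = BertWordPieceTokenizer("vocab.txt")  # placeholder vocab path
    try:
        tokenizer.encode(42)  # not a str
    except ValueError as e:
        print(e)  # expected to read "TextInputSequence must be str"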
@@ -276,23 +276,25 @@ impl From<PyArrayStr> for tk::InputSequence<'_> {
 struct PreTokenizedInputSequence<'s>(tk::InputSequence<'s>);
 impl<'s> FromPyObject<'s> for PreTokenizedInputSequence<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
-        let err = exceptions::ValueError::py_err(
-            "PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
-        );
-
         if let Ok(seq) = ob.extract::<PyArrayUnicode>() {
-            Ok(Self(seq.into()))
-        } else if let Ok(seq) = ob.extract::<PyArrayStr>() {
-            Ok(Self(seq.into()))
-        } else if let Ok(s) = ob.downcast::<PyList>() {
-            let seq = s.extract::<Vec<&str>>().map_err(|_| err)?;
-            Ok(Self(seq.into()))
-        } else if let Ok(s) = ob.downcast::<PyTuple>() {
-            let seq = s.extract::<Vec<&str>>().map_err(|_| err)?;
-            Ok(Self(seq.into()))
-        } else {
-            Err(err)
+            return Ok(Self(seq.into()));
+        }
+        if let Ok(seq) = ob.extract::<PyArrayStr>() {
+            return Ok(Self(seq.into()));
+        }
+        if let Ok(s) = ob.downcast::<PyList>() {
+            if let Ok(seq) = s.extract::<Vec<&str>>() {
+                return Ok(Self(seq.into()));
+            }
+        }
+        if let Ok(s) = ob.downcast::<PyTuple>() {
+            if let Ok(seq) = s.extract::<Vec<&str>>() {
+                return Ok(Self(seq.into()));
+            }
+        }
+        Err(exceptions::ValueError::py_err(
+            "PreTokenizedInputSequence must be Union[List[str], Tuple[str]]",
+        ))
     }
 }
 impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> {
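The if/else chain becomes a sequence of independent checks: each candidate representation (numpy unicode array, numpy object array, list, tuple) is tried in turn, failures simply fall through, and the ValueError is built once at the end instead of up front. A short sketch of the sequence shapes this extractor accepts, reusing the tokenizer from the sketch above:

    # list and tuple of str are both valid pretokenized sequences
    tokenizer.encode(["my", "name", "is", "john"], is_pretokenized=True)
    tokenizer.encode(("my", "name", "is", "john"), is_pretokenized=True)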
@@ -304,17 +306,22 @@ impl<'s> From<PreTokenizedInputSequence<'s>> for tk::InputSequence<'s> {
 struct TextEncodeInput<'s>(tk::EncodeInput<'s>);
 impl<'s> FromPyObject<'s> for TextEncodeInput<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
-        let err = exceptions::ValueError::py_err(
-            "TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
-        );
-
         if let Ok(i) = ob.extract::<TextInputSequence>() {
-            Ok(Self(i.into()))
-        } else if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
-            Ok(Self((i1, i2).into()))
-        } else {
-            Err(err)
+            return Ok(Self(i.into()));
+        }
+        if let Ok((i1, i2)) = ob.extract::<(TextInputSequence, TextInputSequence)>() {
+            return Ok(Self((i1, i2).into()));
+        }
+        if let Ok(arr) = ob.extract::<Vec<&PyAny>>() {
+            if arr.len() == 2 {
+                let first = arr[0].extract::<TextInputSequence>()?;
+                let second = arr[1].extract::<TextInputSequence>()?;
+                return Ok(Self((first, second).into()));
+            }
+        }
+        Err(exceptions::ValueError::py_err(
+            "TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]",
+        ))
     }
 }
 impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
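Besides the early-return restructuring, the new Vec<&PyAny> branch accepts a two-element list as a (sequence, pair) input, where previously only a tuple worked. Sketch, same tokenizer as above:

    # both forms now produce the same pair encoding
    tokenizer.encode_batch([("my name is john", "pair")])  # tuple, as before
    tokenizer.encode_batch([["my name is john", "pair"]])  # list, newly accepted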
@@ -325,20 +332,24 @@ impl<'s> From<TextEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
 struct PreTokenizedEncodeInput<'s>(tk::EncodeInput<'s>);
 impl<'s> FromPyObject<'s> for PreTokenizedEncodeInput<'s> {
     fn extract(ob: &'s PyAny) -> PyResult<Self> {
-        let err = exceptions::ValueError::py_err(
+        if let Ok(i) = ob.extract::<PreTokenizedInputSequence>() {
+            return Ok(Self(i.into()));
+        }
+        if let Ok((i1, i2)) = ob.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
+        {
+            return Ok(Self((i1, i2).into()));
+        }
+        if let Ok(arr) = ob.extract::<Vec<&PyAny>>() {
+            if arr.len() == 2 {
+                let first = arr[0].extract::<PreTokenizedInputSequence>()?;
+                let second = arr[1].extract::<PreTokenizedInputSequence>()?;
+                return Ok(Self((first, second).into()));
+            }
+        }
+        Err(exceptions::ValueError::py_err(
             "PreTokenizedEncodeInput must be Union[PreTokenizedInputSequence, \
             Tuple[PreTokenizedInputSequence, PreTokenizedInputSequence]]",
-        );
-
-        if let Ok(i) = ob.extract::<PreTokenizedInputSequence>() {
-            Ok(Self(i.into()))
-        } else if let Ok((i1, i2)) =
-            ob.extract::<(PreTokenizedInputSequence, PreTokenizedInputSequence)>()
-        {
-            Ok(Self((i1, i2).into()))
-        } else {
-            Err(err)
-        }
+        ))
     }
 }
 impl<'s> From<PreTokenizedEncodeInput<'s>> for tk::tokenizer::EncodeInput<'s> {
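The same two-element-list handling is applied to pretokenized pairs. Sketch, same tokenizer as above:

    # a pair of pretokenized sequences given as a list rather than a tuple
    tokenizer.encode_batch(
        [[["my", "name", "is", "john"], ["pair"]]],
        is_pretokenized=True,
    )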
@@ -1,3 +1,4 @@
+import numpy as np
 import pickle
 import pytest
 from ..utils import data_dir, roberta_files, bert_files, multiprocessing_with_parallelism
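The numpy import backs the new tests below, which pass numpy string arrays straight to encode_batch; on the Rust side these are picked up by the PyArrayUnicode and PyArrayStr extractors. Sketch, same tokenizer as above:

    import numpy as np

    tokenizer.encode_batch(np.array(["My name is John", "My name is Georges"]))
    tokenizer.encode_batch(
        np.array([("My name is John", "pair"), ("My name is Georges", "pair")])
    )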
@@ -140,7 +141,7 @@ class TestTokenizer:
     def test_encode_formats(self, bert_files):
         tokenizer = BertWordPieceTokenizer(bert_files["vocab"])

-        # Well formed
+        # Encode
         output = tokenizer.encode("my name is john")
         assert output.tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]"]
         output = tokenizer.encode("my name is john", "pair")
@@ -150,14 +151,91 @@ class TestTokenizer:
         output = tokenizer.encode(["my", "name", "is", "john"], ["pair"], is_pretokenized=True)
         assert output.tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]", "pair", "[SEP]"]

-        output = tokenizer.encode_batch(["My name is John", "My name is Georges"])
-        assert output[0].tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]"]
-        assert output[1].tokens == ["[CLS]", "my", "name", "is", "georges", "[SEP]"]
-        output = tokenizer.encode_batch([("my name is john", "pair"), ("my name is john", "pair")])
-        assert output[0].tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]", "pair", "[SEP]"]
-        assert output[1].tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]", "pair", "[SEP]"]
-        output = tokenizer.encode_batch([["my", "name", "is", "john"]], is_pretokenized=True)
-        assert output[0].tokens == ["[CLS]", "my", "name", "is", "john", "[SEP]"]
+        # Encode batch
+        result_single = [
+            ["[CLS]", "my", "name", "is", "john", "[SEP]"],
+            ["[CLS]", "my", "name", "is", "georges", "[SEP]"],
+        ]
+        result_pair = [
+            ["[CLS]", "my", "name", "is", "john", "[SEP]", "pair", "[SEP]"],
+            ["[CLS]", "my", "name", "is", "georges", "[SEP]", "pair", "[SEP]"],
+        ]
+
+        def format(encodings):
+            return [e.tokens for e in encodings]
+
+        def test_single(input, is_pretokenized=False):
+            output = tokenizer.encode_batch(input, is_pretokenized=is_pretokenized)
+            assert format(output) == result_single
+
+        def test_pair(input, is_pretokenized=False):
+            output = tokenizer.encode_batch(input, is_pretokenized=is_pretokenized)
+            assert format(output) == result_pair
+
+        # Classic inputs
+
+        # Lists
+        test_single(["My name is John", "My name is Georges"])
+        test_pair([("my name is john", "pair"), ("my name is georges", "pair")])
+        test_pair([["my name is john", "pair"], ["my name is georges", "pair"]])
+
+        # Tuples
+        test_single(("My name is John", "My name is Georges"))
+        test_pair((("My name is John", "pair"), ("My name is Georges", "pair")))
+
+        # Numpy
+        test_single(np.array(["My name is John", "My name is Georges"]))
+        test_pair(np.array([("My name is John", "pair"), ("My name is Georges", "pair")]))
+        test_pair(np.array([["My name is John", "pair"], ["My name is Georges", "pair"]]))
+
+        # PreTokenized inputs
+
+        # Lists
+        test_single([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]], True)
+        test_pair(
+            [(["My", "name", "is", "John"], ["pair"]), (["My", "name", "is", "Georges"], ["pair"])],
+            True,
+        )
+        test_pair(
+            [[["My", "name", "is", "John"], ["pair"]], [["My", "name", "is", "Georges"], ["pair"]]],
+            True,
+        )
+
+        # Tuples
+        test_single((("My", "name", "is", "John"), ("My", "name", "is", "Georges")), True)
+        test_pair(
+            (
+                (("My", "name", "is", "John"), ("pair",)),
+                (("My", "name", "is", "Georges"), ("pair",)),
+            ),
+            True,
+        )
+        test_pair(
+            ((["My", "name", "is", "John"], ["pair"]), (["My", "name", "is", "Georges"], ["pair"])),
+            True,
+        )
+
+        # Numpy
+        test_single(np.array([["My", "name", "is", "John"], ["My", "name", "is", "Georges"]]), True)
+        test_single(np.array((("My", "name", "is", "John"), ("My", "name", "is", "Georges"))), True)
+        test_pair(
+            np.array(
+                [
+                    [["My", "name", "is", "John"], ["pair"]],
+                    [["My", "name", "is", "Georges"], ["pair"]],
+                ]
+            ),
+            True,
+        )
+        test_pair(
+            np.array(
+                (
+                    (("My", "name", "is", "John"), ("pair",)),
+                    (("My", "name", "is", "Georges"), ("pair",)),
+                )
+            ),
+            True,
+        )
+
+        # Mal formed
+        with pytest.raises(ValueError, match="InputSequence must be str"):