Python - PyWordPiece can get/set its attributes

This commit is contained in:
Anthony MOI
2020-11-16 12:34:55 -05:00
committed by Anthony MOI
parent c22cfc31f9
commit 760537aad3
3 changed files with 82 additions and 3 deletions

View File

@ -549,6 +549,66 @@ impl PyWordPiece {
#[pymethods]
impl PyWordPiece {
#[getter]
fn get_unk_token(self_: PyRef<Self>) -> String {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::WordPiece(ref wp) = *model {
wp.unk_token.clone()
} else {
unreachable!()
}
}
#[setter]
fn set_unk_token(self_: PyRef<Self>, unk_token: String) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::WordPiece(ref mut wp) = *model {
wp.unk_token = unk_token;
}
}
#[getter]
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> String {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::WordPiece(ref wp) = *model {
wp.continuing_subword_prefix.clone()
} else {
unreachable!()
}
}
#[setter]
fn set_continuing_subword_prefix(self_: PyRef<Self>, continuing_subword_prefix: String) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::WordPiece(ref mut wp) = *model {
wp.continuing_subword_prefix = continuing_subword_prefix;
}
}
#[getter]
fn get_max_input_chars_per_word(self_: PyRef<Self>) -> usize {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::WordPiece(ref wp) = *model {
wp.max_input_chars_per_word
} else {
unreachable!()
}
}
#[setter]
fn set_max_input_chars_per_word(self_: PyRef<Self>, max: usize) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::WordPiece(ref mut wp) = *model {
wp.max_input_chars_per_word = max;
}
}
#[new]
#[args(kwargs = "**")]
fn new(vocab: Option<PyVocab>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {

View File

@ -84,6 +84,25 @@ class TestWordPiece:
with pytest.deprecated_call():
assert isinstance(pickle.loads(pickle.dumps(WordPiece(bert_files["vocab"]))), WordPiece)
def test_can_modify(self):
model = WordPiece(
unk_token="<oov>",
continuing_subword_prefix="__prefix__",
max_input_chars_per_word=200,
)
assert model.unk_token == "<oov>"
assert model.continuing_subword_prefix == "__prefix__"
assert model.max_input_chars_per_word == 200
# Modify these
model.unk_token = "<unk>"
assert model.unk_token == "<unk>"
model.continuing_subword_prefix = "$$$"
assert model.continuing_subword_prefix == "$$$"
model.max_input_chars_per_word = 10
assert model.max_input_chars_per_word == 10
class TestWordLevel:
def test_instantiate(self, roberta_files):