mirror of
https://github.com/mii443/tokenizers.git
synced 2025-09-01 14:59:20 +00:00
Python - PyWordPiece can get/set its attributes
This commit is contained in:
@ -549,6 +549,66 @@ impl PyWordPiece {
|
||||
|
||||
#[pymethods]
|
||||
impl PyWordPiece {
|
||||
#[getter]
|
||||
fn get_unk_token(self_: PyRef<Self>) -> String {
|
||||
let super_ = self_.as_ref();
|
||||
let model = super_.model.read().unwrap();
|
||||
if let ModelWrapper::WordPiece(ref wp) = *model {
|
||||
wp.unk_token.clone()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_unk_token(self_: PyRef<Self>, unk_token: String) {
|
||||
let super_ = self_.as_ref();
|
||||
let mut model = super_.model.write().unwrap();
|
||||
if let ModelWrapper::WordPiece(ref mut wp) = *model {
|
||||
wp.unk_token = unk_token;
|
||||
}
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> String {
|
||||
let super_ = self_.as_ref();
|
||||
let model = super_.model.read().unwrap();
|
||||
if let ModelWrapper::WordPiece(ref wp) = *model {
|
||||
wp.continuing_subword_prefix.clone()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_continuing_subword_prefix(self_: PyRef<Self>, continuing_subword_prefix: String) {
|
||||
let super_ = self_.as_ref();
|
||||
let mut model = super_.model.write().unwrap();
|
||||
if let ModelWrapper::WordPiece(ref mut wp) = *model {
|
||||
wp.continuing_subword_prefix = continuing_subword_prefix;
|
||||
}
|
||||
}
|
||||
|
||||
#[getter]
|
||||
fn get_max_input_chars_per_word(self_: PyRef<Self>) -> usize {
|
||||
let super_ = self_.as_ref();
|
||||
let model = super_.model.read().unwrap();
|
||||
if let ModelWrapper::WordPiece(ref wp) = *model {
|
||||
wp.max_input_chars_per_word
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
#[setter]
|
||||
fn set_max_input_chars_per_word(self_: PyRef<Self>, max: usize) {
|
||||
let super_ = self_.as_ref();
|
||||
let mut model = super_.model.write().unwrap();
|
||||
if let ModelWrapper::WordPiece(ref mut wp) = *model {
|
||||
wp.max_input_chars_per_word = max;
|
||||
}
|
||||
}
|
||||
|
||||
#[new]
|
||||
#[args(kwargs = "**")]
|
||||
fn new(vocab: Option<PyVocab>, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
|
||||
|
@ -84,6 +84,25 @@ class TestWordPiece:
|
||||
with pytest.deprecated_call():
|
||||
assert isinstance(pickle.loads(pickle.dumps(WordPiece(bert_files["vocab"]))), WordPiece)
|
||||
|
||||
def test_can_modify(self):
    """Check that WordPiece attributes can be read back and reassigned."""
    constructed_with = (
        ("unk_token", "<oov>"),
        ("continuing_subword_prefix", "__prefix__"),
        ("max_input_chars_per_word", 200),
    )
    model = WordPiece(**dict(constructed_with))

    # Every constructor argument must be visible through its attribute.
    for attribute, expected in constructed_with:
        assert getattr(model, attribute) == expected

    # Modify these: assign each attribute, then read it straight back.
    reassigned_to = (
        ("unk_token", "<unk>"),
        ("continuing_subword_prefix", "$$$"),
        ("max_input_chars_per_word", 10),
    )
    for attribute, replacement in reassigned_to:
        setattr(model, attribute, replacement)
        assert getattr(model, attribute) == replacement
|
||||
|
||||
|
||||
class TestWordLevel:
|
||||
def test_instantiate(self, roberta_files):
|
||||
|
Reference in New Issue
Block a user