Python - Use macro for getter/setter in models

This commit is contained in:
Anthony MOI
2020-11-20 20:51:17 -05:00
committed by Anthony MOI
parent 2feccdbbfa
commit 091287dcf5

View File

@ -274,6 +274,28 @@ impl PyBPE {
}
}
macro_rules! getter {
($self: ident, $variant: ident, $($name: tt)+) => {{
let super_ = $self.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::$variant(ref mo) = *model {
mo.$($name)+
} else {
unreachable!()
}
}};
}
macro_rules! setter {
($self: ident, $variant: ident, $name: ident, $value: expr) => {{
let super_ = $self.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::$variant(ref mut mo) = *model {
mo.$name = $value;
}
}};
}
#[derive(FromPyObject)]
enum PyVocab<'a> {
Vocab(Vocab),
@ -288,54 +310,28 @@ enum PyMerges<'a> {
#[pymethods]
impl PyBPE {
#[getter]
fn get_dropout(self_: PyRef<Self>) -> Option<f64> {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::BPE(ref bpe) = *model {
bpe.dropout.map(|d| d as f64)
} else {
unreachable!()
}
fn get_dropout(self_: PyRef<Self>) -> Option<f32> {
getter!(self_, BPE, dropout)
}
#[setter]
fn set_dropout(self_: PyRef<Self>, dropout: Option<f32>) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::BPE(ref mut bpe) = *model {
bpe.dropout = dropout;
}
setter!(self_, BPE, dropout, dropout);
}
#[getter]
fn get_unk_token(self_: PyRef<Self>) -> Option<String> {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::BPE(ref bpe) = *model {
bpe.unk_token.clone()
} else {
unreachable!()
}
getter!(self_, BPE, unk_token.clone())
}
#[setter]
fn set_unk_token(self_: PyRef<Self>, unk_token: Option<String>) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::BPE(ref mut bpe) = *model {
bpe.unk_token = unk_token;
}
setter!(self_, BPE, unk_token, unk_token);
}
#[getter]
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::BPE(ref bpe) = *model {
bpe.continuing_subword_prefix.clone()
} else {
unreachable!()
}
getter!(self_, BPE, continuing_subword_prefix.clone())
}
#[setter]
@ -343,51 +339,32 @@ impl PyBPE {
self_: PyRef<Self>,
continuing_subword_prefix: Option<String>,
) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::BPE(ref mut bpe) = *model {
bpe.continuing_subword_prefix = continuing_subword_prefix;
}
setter!(
self_,
BPE,
continuing_subword_prefix,
continuing_subword_prefix
);
}
#[getter]
fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::BPE(ref bpe) = *model {
bpe.end_of_word_suffix.clone()
} else {
unreachable!()
}
getter!(self_, BPE, end_of_word_suffix.clone())
}
#[setter]
fn set_end_of_word_suffix(self_: PyRef<Self>, end_of_word_suffix: Option<String>) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::BPE(ref mut bpe) = *model {
bpe.end_of_word_suffix = end_of_word_suffix;
}
setter!(self_, BPE, end_of_word_suffix, end_of_word_suffix);
}
#[getter]
fn get_fuse_unk(self_: PyRef<Self>) -> bool {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::BPE(ref bpe) = *model {
bpe.fuse_unk
} else {
unreachable!()
}
getter!(self_, BPE, fuse_unk)
}
#[setter]
fn set_fuse_unk(self_: PyRef<Self>, fuse_unk: bool) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::BPE(ref mut bpe) = *model {
bpe.fuse_unk = fuse_unk;
}
setter!(self_, BPE, fuse_unk, fuse_unk);
}
#[new]
@ -551,62 +528,37 @@ impl PyWordPiece {
impl PyWordPiece {
#[getter]
fn get_unk_token(self_: PyRef<Self>) -> String {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::WordPiece(ref wp) = *model {
wp.unk_token.clone()
} else {
unreachable!()
}
getter!(self_, WordPiece, unk_token.clone())
}
#[setter]
fn set_unk_token(self_: PyRef<Self>, unk_token: String) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::WordPiece(ref mut wp) = *model {
wp.unk_token = unk_token;
}
setter!(self_, WordPiece, unk_token, unk_token);
}
#[getter]
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> String {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::WordPiece(ref wp) = *model {
wp.continuing_subword_prefix.clone()
} else {
unreachable!()
}
getter!(self_, WordPiece, continuing_subword_prefix.clone())
}
#[setter]
fn set_continuing_subword_prefix(self_: PyRef<Self>, continuing_subword_prefix: String) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::WordPiece(ref mut wp) = *model {
wp.continuing_subword_prefix = continuing_subword_prefix;
}
setter!(
self_,
WordPiece,
continuing_subword_prefix,
continuing_subword_prefix
);
}
#[getter]
fn get_max_input_chars_per_word(self_: PyRef<Self>) -> usize {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::WordPiece(ref wp) = *model {
wp.max_input_chars_per_word
} else {
unreachable!()
}
getter!(self_, WordPiece, max_input_chars_per_word)
}
#[setter]
fn set_max_input_chars_per_word(self_: PyRef<Self>, max: usize) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::WordPiece(ref mut wp) = *model {
wp.max_input_chars_per_word = max;
}
setter!(self_, WordPiece, max_input_chars_per_word, max);
}
#[new]
@ -704,22 +656,12 @@ pub struct PyWordLevel {}
impl PyWordLevel {
#[getter]
fn get_unk_token(self_: PyRef<Self>) -> String {
let super_ = self_.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::WordLevel(ref wl) = *model {
wl.unk_token.clone()
} else {
unreachable!()
}
getter!(self_, WordLevel, unk_token.clone())
}
#[setter]
fn set_unk_token(self_: PyRef<Self>, unk_token: String) {
let super_ = self_.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::WordLevel(ref mut wl) = *model {
wl.unk_token = unk_token;
}
setter!(self_, WordLevel, unk_token, unk_token);
}
#[new]