Python - Use macro for getter/setter in models

This commit is contained in:
Anthony MOI
2020-11-20 20:51:17 -05:00
committed by Anthony MOI
parent 2feccdbbfa
commit 091287dcf5

View File

@ -274,6 +274,28 @@ impl PyBPE {
} }
} }
macro_rules! getter {
($self: ident, $variant: ident, $($name: tt)+) => {{
let super_ = $self.as_ref();
let model = super_.model.read().unwrap();
if let ModelWrapper::$variant(ref mo) = *model {
mo.$($name)+
} else {
unreachable!()
}
}};
}
macro_rules! setter {
($self: ident, $variant: ident, $name: ident, $value: expr) => {{
let super_ = $self.as_ref();
let mut model = super_.model.write().unwrap();
if let ModelWrapper::$variant(ref mut mo) = *model {
mo.$name = $value;
}
}};
}
#[derive(FromPyObject)] #[derive(FromPyObject)]
enum PyVocab<'a> { enum PyVocab<'a> {
Vocab(Vocab), Vocab(Vocab),
@ -288,54 +310,28 @@ enum PyMerges<'a> {
#[pymethods] #[pymethods]
impl PyBPE { impl PyBPE {
#[getter] #[getter]
fn get_dropout(self_: PyRef<Self>) -> Option<f64> { fn get_dropout(self_: PyRef<Self>) -> Option<f32> {
let super_ = self_.as_ref(); getter!(self_, BPE, dropout)
let model = super_.model.read().unwrap();
if let ModelWrapper::BPE(ref bpe) = *model {
bpe.dropout.map(|d| d as f64)
} else {
unreachable!()
}
} }
#[setter] #[setter]
fn set_dropout(self_: PyRef<Self>, dropout: Option<f32>) { fn set_dropout(self_: PyRef<Self>, dropout: Option<f32>) {
let super_ = self_.as_ref(); setter!(self_, BPE, dropout, dropout);
let mut model = super_.model.write().unwrap();
if let ModelWrapper::BPE(ref mut bpe) = *model {
bpe.dropout = dropout;
}
} }
#[getter] #[getter]
fn get_unk_token(self_: PyRef<Self>) -> Option<String> { fn get_unk_token(self_: PyRef<Self>) -> Option<String> {
let super_ = self_.as_ref(); getter!(self_, BPE, unk_token.clone())
let model = super_.model.read().unwrap();
if let ModelWrapper::BPE(ref bpe) = *model {
bpe.unk_token.clone()
} else {
unreachable!()
}
} }
#[setter] #[setter]
fn set_unk_token(self_: PyRef<Self>, unk_token: Option<String>) { fn set_unk_token(self_: PyRef<Self>, unk_token: Option<String>) {
let super_ = self_.as_ref(); setter!(self_, BPE, unk_token, unk_token);
let mut model = super_.model.write().unwrap();
if let ModelWrapper::BPE(ref mut bpe) = *model {
bpe.unk_token = unk_token;
}
} }
#[getter] #[getter]
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> { fn get_continuing_subword_prefix(self_: PyRef<Self>) -> Option<String> {
let super_ = self_.as_ref(); getter!(self_, BPE, continuing_subword_prefix.clone())
let model = super_.model.read().unwrap();
if let ModelWrapper::BPE(ref bpe) = *model {
bpe.continuing_subword_prefix.clone()
} else {
unreachable!()
}
} }
#[setter] #[setter]
@ -343,51 +339,32 @@ impl PyBPE {
self_: PyRef<Self>, self_: PyRef<Self>,
continuing_subword_prefix: Option<String>, continuing_subword_prefix: Option<String>,
) { ) {
let super_ = self_.as_ref(); setter!(
let mut model = super_.model.write().unwrap(); self_,
if let ModelWrapper::BPE(ref mut bpe) = *model { BPE,
bpe.continuing_subword_prefix = continuing_subword_prefix; continuing_subword_prefix,
} continuing_subword_prefix
);
} }
#[getter] #[getter]
fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> { fn get_end_of_word_suffix(self_: PyRef<Self>) -> Option<String> {
let super_ = self_.as_ref(); getter!(self_, BPE, end_of_word_suffix.clone())
let model = super_.model.read().unwrap();
if let ModelWrapper::BPE(ref bpe) = *model {
bpe.end_of_word_suffix.clone()
} else {
unreachable!()
}
} }
#[setter] #[setter]
fn set_end_of_word_suffix(self_: PyRef<Self>, end_of_word_suffix: Option<String>) { fn set_end_of_word_suffix(self_: PyRef<Self>, end_of_word_suffix: Option<String>) {
let super_ = self_.as_ref(); setter!(self_, BPE, end_of_word_suffix, end_of_word_suffix);
let mut model = super_.model.write().unwrap();
if let ModelWrapper::BPE(ref mut bpe) = *model {
bpe.end_of_word_suffix = end_of_word_suffix;
}
} }
#[getter] #[getter]
fn get_fuse_unk(self_: PyRef<Self>) -> bool { fn get_fuse_unk(self_: PyRef<Self>) -> bool {
let super_ = self_.as_ref(); getter!(self_, BPE, fuse_unk)
let model = super_.model.read().unwrap();
if let ModelWrapper::BPE(ref bpe) = *model {
bpe.fuse_unk
} else {
unreachable!()
}
} }
#[setter] #[setter]
fn set_fuse_unk(self_: PyRef<Self>, fuse_unk: bool) { fn set_fuse_unk(self_: PyRef<Self>, fuse_unk: bool) {
let super_ = self_.as_ref(); setter!(self_, BPE, fuse_unk, fuse_unk);
let mut model = super_.model.write().unwrap();
if let ModelWrapper::BPE(ref mut bpe) = *model {
bpe.fuse_unk = fuse_unk;
}
} }
#[new] #[new]
@ -551,62 +528,37 @@ impl PyWordPiece {
impl PyWordPiece { impl PyWordPiece {
#[getter] #[getter]
fn get_unk_token(self_: PyRef<Self>) -> String { fn get_unk_token(self_: PyRef<Self>) -> String {
let super_ = self_.as_ref(); getter!(self_, WordPiece, unk_token.clone())
let model = super_.model.read().unwrap();
if let ModelWrapper::WordPiece(ref wp) = *model {
wp.unk_token.clone()
} else {
unreachable!()
}
} }
#[setter] #[setter]
fn set_unk_token(self_: PyRef<Self>, unk_token: String) { fn set_unk_token(self_: PyRef<Self>, unk_token: String) {
let super_ = self_.as_ref(); setter!(self_, WordPiece, unk_token, unk_token);
let mut model = super_.model.write().unwrap();
if let ModelWrapper::WordPiece(ref mut wp) = *model {
wp.unk_token = unk_token;
}
} }
#[getter] #[getter]
fn get_continuing_subword_prefix(self_: PyRef<Self>) -> String { fn get_continuing_subword_prefix(self_: PyRef<Self>) -> String {
let super_ = self_.as_ref(); getter!(self_, WordPiece, continuing_subword_prefix.clone())
let model = super_.model.read().unwrap();
if let ModelWrapper::WordPiece(ref wp) = *model {
wp.continuing_subword_prefix.clone()
} else {
unreachable!()
}
} }
#[setter] #[setter]
fn set_continuing_subword_prefix(self_: PyRef<Self>, continuing_subword_prefix: String) { fn set_continuing_subword_prefix(self_: PyRef<Self>, continuing_subword_prefix: String) {
let super_ = self_.as_ref(); setter!(
let mut model = super_.model.write().unwrap(); self_,
if let ModelWrapper::WordPiece(ref mut wp) = *model { WordPiece,
wp.continuing_subword_prefix = continuing_subword_prefix; continuing_subword_prefix,
} continuing_subword_prefix
);
} }
#[getter] #[getter]
fn get_max_input_chars_per_word(self_: PyRef<Self>) -> usize { fn get_max_input_chars_per_word(self_: PyRef<Self>) -> usize {
let super_ = self_.as_ref(); getter!(self_, WordPiece, max_input_chars_per_word)
let model = super_.model.read().unwrap();
if let ModelWrapper::WordPiece(ref wp) = *model {
wp.max_input_chars_per_word
} else {
unreachable!()
}
} }
#[setter] #[setter]
fn set_max_input_chars_per_word(self_: PyRef<Self>, max: usize) { fn set_max_input_chars_per_word(self_: PyRef<Self>, max: usize) {
let super_ = self_.as_ref(); setter!(self_, WordPiece, max_input_chars_per_word, max);
let mut model = super_.model.write().unwrap();
if let ModelWrapper::WordPiece(ref mut wp) = *model {
wp.max_input_chars_per_word = max;
}
} }
#[new] #[new]
@ -704,22 +656,12 @@ pub struct PyWordLevel {}
impl PyWordLevel { impl PyWordLevel {
#[getter] #[getter]
fn get_unk_token(self_: PyRef<Self>) -> String { fn get_unk_token(self_: PyRef<Self>) -> String {
let super_ = self_.as_ref(); getter!(self_, WordLevel, unk_token.clone())
let model = super_.model.read().unwrap();
if let ModelWrapper::WordLevel(ref wl) = *model {
wl.unk_token.clone()
} else {
unreachable!()
}
} }
#[setter] #[setter]
fn set_unk_token(self_: PyRef<Self>, unk_token: String) { fn set_unk_token(self_: PyRef<Self>, unk_token: String) {
let super_ = self_.as_ref(); setter!(self_, WordLevel, unk_token, unk_token);
let mut model = super_.model.write().unwrap();
if let ModelWrapper::WordLevel(ref mut wl) = *model {
wl.unk_token = unk_token;
}
} }
#[new] #[new]