Node - Improve handling of optionals

This commit is contained in:
Anthony MOI
2020-01-10 11:52:15 -05:00
parent 7e59ff8ee9
commit 0925c30997
4 changed files with 189 additions and 115 deletions

View File

@ -166,21 +166,29 @@ declare_types! {
if let Some(options) = options {
if let Ok(options) = options.downcast::<JsObject>() {
if let Ok(dir) = options.get(&mut cx, "direction") {
let dir = dir.downcast::<JsString>().or_throw(&mut cx)?.value();
match &dir[..] {
"right" => direction = PaddingDirection::Right,
"left" => direction = PaddingDirection::Left,
_ => return cx.throw_error("direction can be 'right' or 'left'"),
if let Err(_) = dir.downcast::<JsUndefined>() {
let dir = dir.downcast::<JsString>().or_throw(&mut cx)?.value();
match &dir[..] {
"right" => direction = PaddingDirection::Right,
"left" => direction = PaddingDirection::Left,
_ => return cx.throw_error("direction can be 'right' or 'left'"),
}
}
}
if let Ok(pid) = options.get(&mut cx, "padId") {
pad_id = pid.downcast::<JsNumber>().or_throw(&mut cx)?.value() as u32;
if let Err(_) = pid.downcast::<JsUndefined>() {
pad_id = pid.downcast::<JsNumber>().or_throw(&mut cx)?.value() as u32;
}
}
if let Ok(pid) = options.get(&mut cx, "padTypeId") {
pad_type_id = pid.downcast::<JsNumber>().or_throw(&mut cx)?.value() as u32;
if let Err(_) = pid.downcast::<JsUndefined>() {
pad_type_id = pid.downcast::<JsNumber>().or_throw(&mut cx)?.value() as u32;
}
}
if let Ok(token) = options.get(&mut cx, "padToken") {
pad_token = token.downcast::<JsString>().or_throw(&mut cx)?.value();
if let Err(_) = token.downcast::<JsUndefined>() {
pad_token = token.downcast::<JsString>().or_throw(&mut cx)?.value();
}
}
}
}

View File

@ -38,28 +38,38 @@ pub fn bpe_from_files(mut cx: FunctionContext) -> JsResult<JsModel> {
if let Some(options) = options {
if let Ok(options) = options.downcast::<JsObject>() {
if let Ok(cache_capacity) = options.get(&mut cx, "cache_capacity") {
let cache_capacity = cache_capacity
.downcast::<JsNumber>()
.or_throw(&mut cx)?
.value() as usize;
builder = builder.cache_capacity(cache_capacity);
if let Err(_) = cache_capacity.downcast::<JsUndefined>() {
let cache_capacity = cache_capacity
.downcast::<JsNumber>()
.or_throw(&mut cx)?
.value() as usize;
builder = builder.cache_capacity(cache_capacity);
}
}
if let Ok(dropout) = options.get(&mut cx, "dropout") {
let dropout = dropout.downcast::<JsNumber>().or_throw(&mut cx)?.value() as f32;
builder = builder.dropout(dropout);
if let Err(_) = dropout.downcast::<JsUndefined>() {
let dropout = dropout.downcast::<JsNumber>().or_throw(&mut cx)?.value() as f32;
builder = builder.dropout(dropout);
}
}
if let Ok(unk_token) = options.get(&mut cx, "unk_token") {
let unk_token =
unk_token.downcast::<JsString>().or_throw(&mut cx)?.value() as String;
builder = builder.unk_token(unk_token);
if let Err(_) = unk_token.downcast::<JsUndefined>() {
let unk_token =
unk_token.downcast::<JsString>().or_throw(&mut cx)?.value() as String;
builder = builder.unk_token(unk_token);
}
}
if let Ok(prefix) = options.get(&mut cx, "continuing_subword_prefix") {
let prefix = prefix.downcast::<JsString>().or_throw(&mut cx)?.value() as String;
builder = builder.continuing_subword_prefix(prefix);
if let Err(_) = prefix.downcast::<JsUndefined>() {
let prefix = prefix.downcast::<JsString>().or_throw(&mut cx)?.value() as String;
builder = builder.continuing_subword_prefix(prefix);
}
}
if let Ok(suffix) = options.get(&mut cx, "end_of_word_suffix") {
let suffix = suffix.downcast::<JsString>().or_throw(&mut cx)?.value() as String;
builder = builder.end_of_word_suffix(suffix);
if let Err(_) = suffix.downcast::<JsUndefined>() {
let suffix = suffix.downcast::<JsString>().or_throw(&mut cx)?.value() as String;
builder = builder.end_of_word_suffix(suffix);
}
}
}
}
@ -100,11 +110,15 @@ pub fn wordpiece_from_files(mut cx: FunctionContext) -> JsResult<JsModel> {
if let Some(options) = options {
if let Ok(options) = options.downcast::<JsObject>() {
if let Ok(unk) = options.get(&mut cx, "unkToken") {
unk_token = unk.downcast::<JsString>().or_throw(&mut cx)?.value() as String;
if let Err(_) = unk.downcast::<JsUndefined>() {
unk_token = unk.downcast::<JsString>().or_throw(&mut cx)?.value() as String;
}
}
if let Ok(max) = options.get(&mut cx, "maxInputCharsPerWord") {
max_input_chars_per_word =
Some(max.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize);
if let Err(_) = max.downcast::<JsUndefined>() {
max_input_chars_per_word =
Some(max.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize);
}
}
}
}

View File

@ -34,16 +34,24 @@ fn bert_normalizer(mut cx: FunctionContext) -> JsResult<JsNormalizer> {
if let Some(options) = cx.argument_opt(0) {
let options = options.downcast::<JsObject>().or_throw(&mut cx)?;
if let Ok(ct) = options.get(&mut cx, "cleanText") {
clean_text = ct.downcast::<JsBoolean>().or_throw(&mut cx)?.value();
if let Err(_) = ct.downcast::<JsUndefined>() {
clean_text = ct.downcast::<JsBoolean>().or_throw(&mut cx)?.value();
}
}
if let Ok(hcc) = options.get(&mut cx, "handleChineseChars") {
handle_chinese_chars = hcc.downcast::<JsBoolean>().or_throw(&mut cx)?.value();
if let Err(_) = hcc.downcast::<JsUndefined>() {
handle_chinese_chars = hcc.downcast::<JsBoolean>().or_throw(&mut cx)?.value();
}
}
if let Ok(sa) = options.get(&mut cx, "stripAccents") {
strip_accents = sa.downcast::<JsBoolean>().or_throw(&mut cx)?.value();
if let Err(_) = sa.downcast::<JsUndefined>() {
strip_accents = sa.downcast::<JsBoolean>().or_throw(&mut cx)?.value();
}
}
if let Ok(l) = options.get(&mut cx, "lowercase") {
lowercase = l.downcast::<JsBoolean>().or_throw(&mut cx)?.value();
if let Err(_) = l.downcast::<JsUndefined>() {
lowercase = l.downcast::<JsBoolean>().or_throw(&mut cx)?.value();
}
}
}

View File

@ -38,63 +38,85 @@ fn bpe_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
if let Some(options) = options {
if let Ok(options) = options.downcast::<JsObject>() {
if let Ok(size) = options.get(&mut cx, "vocabSize") {
builder = builder
.vocab_size(size.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize);
if let Err(_) = size.downcast::<JsUndefined>() {
builder =
builder.vocab_size(
size.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize
);
}
}
if let Ok(freq) = options.get(&mut cx, "minFrequency") {
builder = builder
.min_frequency(freq.downcast::<JsNumber>().or_throw(&mut cx)?.value() as u32);
if let Err(_) = freq.downcast::<JsUndefined>() {
builder = builder.min_frequency(
freq.downcast::<JsNumber>().or_throw(&mut cx)?.value() as u32,
);
}
}
if let Ok(tokens) = options.get(&mut cx, "specialTokens") {
builder = builder.special_tokens(
tokens
.downcast::<JsArray>()
.or_throw(&mut cx)?
.to_vec(&mut cx)?
.into_iter()
.map(|token| Ok(token.downcast::<JsString>().or_throw(&mut cx)?.value()))
.collect::<NeonResult<Vec<_>>>()?,
);
if let Err(_) = tokens.downcast::<JsUndefined>() {
builder = builder.special_tokens(
tokens
.downcast::<JsArray>()
.or_throw(&mut cx)?
.to_vec(&mut cx)?
.into_iter()
.map(|token| {
Ok(token.downcast::<JsString>().or_throw(&mut cx)?.value())
})
.collect::<NeonResult<Vec<_>>>()?,
);
}
}
if let Ok(limit) = options.get(&mut cx, "limitAlphabet") {
builder = builder.limit_alphabet(
limit.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize,
);
if let Err(_) = limit.downcast::<JsUndefined>() {
builder = builder.limit_alphabet(
limit.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize,
);
}
}
if let Ok(alphabet) = options.get(&mut cx, "initialAlphabet") {
builder = builder.initial_alphabet(
alphabet
.downcast::<JsArray>()
.or_throw(&mut cx)?
.to_vec(&mut cx)?
.into_iter()
.map(|tokens| {
Ok(tokens
.downcast::<JsString>()
.or_throw(&mut cx)?
.value()
.chars()
.nth(0))
})
.collect::<NeonResult<Vec<_>>>()?
.into_iter()
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect::<HashSet<_>>(),
);
if let Err(_) = alphabet.downcast::<JsUndefined>() {
builder = builder.initial_alphabet(
alphabet
.downcast::<JsArray>()
.or_throw(&mut cx)?
.to_vec(&mut cx)?
.into_iter()
.map(|tokens| {
Ok(tokens
.downcast::<JsString>()
.or_throw(&mut cx)?
.value()
.chars()
.nth(0))
})
.collect::<NeonResult<Vec<_>>>()?
.into_iter()
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect::<HashSet<_>>(),
);
}
}
if let Ok(show) = options.get(&mut cx, "showProgress") {
builder =
builder.show_progress(show.downcast::<JsBoolean>().or_throw(&mut cx)?.value());
if let Err(_) = show.downcast::<JsUndefined>() {
builder = builder
.show_progress(show.downcast::<JsBoolean>().or_throw(&mut cx)?.value());
}
}
if let Ok(prefix) = options.get(&mut cx, "continuingSubwordPrefix") {
builder = builder.continuing_subword_prefix(
prefix.downcast::<JsString>().or_throw(&mut cx)?.value(),
);
if let Err(_) = prefix.downcast::<JsUndefined>() {
builder = builder.continuing_subword_prefix(
prefix.downcast::<JsString>().or_throw(&mut cx)?.value(),
);
}
}
if let Ok(suffix) = options.get(&mut cx, "endOfWordSuffix") {
builder = builder
.end_of_word_suffix(suffix.downcast::<JsString>().or_throw(&mut cx)?.value());
if let Err(_) = suffix.downcast::<JsUndefined>() {
builder = builder.end_of_word_suffix(
suffix.downcast::<JsString>().or_throw(&mut cx)?.value(),
);
}
}
}
}
@ -125,63 +147,85 @@ fn wordpiece_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
if let Some(options) = options {
if let Ok(options) = options.downcast::<JsObject>() {
if let Ok(size) = options.get(&mut cx, "vocabSize") {
builder = builder
.vocab_size(size.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize);
if let Err(_) = size.downcast::<JsUndefined>() {
builder =
builder.vocab_size(
size.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize
);
}
}
if let Ok(freq) = options.get(&mut cx, "minFrequency") {
builder = builder
.min_frequency(freq.downcast::<JsNumber>().or_throw(&mut cx)?.value() as u32);
if let Err(_) = freq.downcast::<JsUndefined>() {
builder = builder.min_frequency(
freq.downcast::<JsNumber>().or_throw(&mut cx)?.value() as u32,
);
}
}
if let Ok(tokens) = options.get(&mut cx, "specialTokens") {
builder = builder.special_tokens(
tokens
.downcast::<JsArray>()
.or_throw(&mut cx)?
.to_vec(&mut cx)?
.into_iter()
.map(|token| Ok(token.downcast::<JsString>().or_throw(&mut cx)?.value()))
.collect::<NeonResult<Vec<_>>>()?,
);
if let Err(_) = tokens.downcast::<JsUndefined>() {
builder = builder.special_tokens(
tokens
.downcast::<JsArray>()
.or_throw(&mut cx)?
.to_vec(&mut cx)?
.into_iter()
.map(|token| {
Ok(token.downcast::<JsString>().or_throw(&mut cx)?.value())
})
.collect::<NeonResult<Vec<_>>>()?,
);
}
}
if let Ok(limit) = options.get(&mut cx, "limitAlphabet") {
builder = builder.limit_alphabet(
limit.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize,
);
if let Err(_) = limit.downcast::<JsUndefined>() {
builder = builder.limit_alphabet(
limit.downcast::<JsNumber>().or_throw(&mut cx)?.value() as usize,
);
}
}
if let Ok(alphabet) = options.get(&mut cx, "initialAlphabet") {
builder = builder.initial_alphabet(
alphabet
.downcast::<JsArray>()
.or_throw(&mut cx)?
.to_vec(&mut cx)?
.into_iter()
.map(|tokens| {
Ok(tokens
.downcast::<JsString>()
.or_throw(&mut cx)?
.value()
.chars()
.nth(0))
})
.collect::<NeonResult<Vec<_>>>()?
.into_iter()
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect::<HashSet<_>>(),
);
if let Err(_) = alphabet.downcast::<JsUndefined>() {
builder = builder.initial_alphabet(
alphabet
.downcast::<JsArray>()
.or_throw(&mut cx)?
.to_vec(&mut cx)?
.into_iter()
.map(|tokens| {
Ok(tokens
.downcast::<JsString>()
.or_throw(&mut cx)?
.value()
.chars()
.nth(0))
})
.collect::<NeonResult<Vec<_>>>()?
.into_iter()
.filter(|c| c.is_some())
.map(|c| c.unwrap())
.collect::<HashSet<_>>(),
);
}
}
if let Ok(show) = options.get(&mut cx, "showProgress") {
builder =
builder.show_progress(show.downcast::<JsBoolean>().or_throw(&mut cx)?.value());
if let Err(_) = show.downcast::<JsUndefined>() {
builder = builder
.show_progress(show.downcast::<JsBoolean>().or_throw(&mut cx)?.value());
}
}
if let Ok(prefix) = options.get(&mut cx, "continuingSubwordPrefix") {
builder = builder.continuing_subword_prefix(
prefix.downcast::<JsString>().or_throw(&mut cx)?.value(),
);
if let Err(_) = prefix.downcast::<JsUndefined>() {
builder = builder.continuing_subword_prefix(
prefix.downcast::<JsString>().or_throw(&mut cx)?.value(),
);
}
}
if let Ok(suffix) = options.get(&mut cx, "endOfWordSuffix") {
builder = builder
.end_of_word_suffix(suffix.downcast::<JsString>().or_throw(&mut cx)?.value());
if let Err(_) = suffix.downcast::<JsUndefined>() {
builder = builder.end_of_word_suffix(
suffix.downcast::<JsString>().or_throw(&mut cx)?.value(),
);
}
}
}
}