mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)

Node - Add WordLevelTrainer

bindings/node/lib/bindings/trainers.d.ts (vendored): +29
@@ -63,6 +63,35 @@ export function bpeTrainer(options?: TrainerOptions): Trainer;
  */
 export function wordPieceTrainer(options?: TrainerOptions): Trainer;
 
+export interface WordLevelTrainerOptions {
+  /**
+   * The minimum frequency a word should have in order to be kept in the vocabulary.
+   * @default 0
+   */
+  minFrequency?: number;
+  /**
+   * Whether to show progress bars while training.
+   * @default true
+   */
+  showProgress?: boolean;
+  /**
+   * A list of special tokens the model should know of.
+   * @default []
+   */
+  specialTokens?: (string | AddedToken)[];
+  /**
+   * The size of the final vocabulary, including all tokens and alphabet.
+   * @default 30000
+   */
+  vocabSize?: number;
+}
+
+/**
+ * Instantiate a new WordLevel Trainer
+ * @param [options] WordLevel Trainer options
+ */
+export function wordLevelTrainer(options?: WordLevelTrainerOptions): Trainer;
+
 export interface UnigramTrainerOptions {
   vocabSize?: number;
   nSubIterations?: number;
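
For context, this is how the new trainer is meant to be called from the JavaScript side; a minimal sketch in TypeScript, assuming the bindings are importable as "tokenizers/bindings/trainers" (an assumed path, the real one depends on how the package is bundled):

import { wordLevelTrainer } from "tokenizers/bindings/trainers"; // assumed import path

// Every option is optional; anything omitted falls back to the trainer defaults.
const trainer = wordLevelTrainer({
  vocabSize: 10000,                  // cap on the final vocabulary size
  minFrequency: 3,                   // drop words seen fewer than 3 times
  specialTokens: ["[UNK]", "[PAD]"], // always included in the vocabulary
  showProgress: false,
});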
bindings/node/lib/bindings/trainers.js (vendored)

@@ -3,5 +3,6 @@ const native = require("./native");
 module.exports = {
   bpeTrainer: native.trainers_BPETrainer,
   wordPieceTrainer: native.trainers_WordPieceTrainer,
+  wordLevelTrainer: native.trainers_WordLevelTrainer,
   unigramTrainer: native.trainers_UnigramTrainer,
 };
bindings/node/native/src/trainers.rs

@@ -8,7 +8,8 @@ use std::collections::HashMap;
 use std::sync::Arc;
 
 use tk::models::{
-    bpe::BpeTrainer, unigram::UnigramTrainer, wordpiece::WordPieceTrainer, TrainerWrapper,
+    bpe::BpeTrainer, unigram::UnigramTrainer, wordlevel::WordLevelTrainer,
+    wordpiece::WordPieceTrainer, TrainerWrapper,
 };
 
 /// Trainer
@@ -17,6 +18,14 @@ pub struct Trainer {
     pub trainer: Option<Arc<TrainerWrapper>>,
 }
 
+impl From<TrainerWrapper> for Trainer {
+    fn from(trainer: TrainerWrapper) -> Self {
+        Self {
+            trainer: Some(Arc::new(trainer)),
+        }
+    }
+}
+
 impl tk::Trainer for Trainer {
     type Model = Model;
 
@@ -27,14 +36,26 @@ impl tk::Trainer for Trainer {
             .should_show_progress()
     }
 
-    fn train(&self, words: HashMap<String, u32>) -> tk::Result<(Self::Model, Vec<tk::AddedToken>)> {
-        let (model, special_tokens) = self
+    fn train(
+        &self,
+        words: HashMap<String, u32>,
+        model: &mut Self::Model,
+    ) -> tk::Result<Vec<tk::AddedToken>> {
+        let special_tokens = self
             .trainer
             .as_ref()
             .ok_or("Uninitialized Trainer")?
-            .train(words)?;
+            .train(
+                words,
+                &mut model
+                    .model
+                    .as_ref()
+                    .ok_or("Uninitialized Model")?
+                    .write()
+                    .unwrap(),
+            )?;
 
-        Ok((model.into(), special_tokens))
+        Ok(special_tokens)
     }
 
     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
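
This hunk tracks a change to the `tk::Trainer` trait: `train` no longer constructs and returns a fresh model, it receives `&mut Self::Model` and fills it in place, returning only the special tokens. Since the Node `Model` keeps its inner model behind a lock, the binding takes a write guard (`.write().unwrap()`) before delegating to the wrapped trainer.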
@@ -238,6 +259,81 @@ fn wordpiece_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
     Ok(js_trainer)
 }
 
+// WordLevel
+
+struct WordLevelTrainerOptions(WordLevelTrainer);
+impl From<WordLevelTrainerOptions> for WordLevelTrainer {
+    fn from(v: WordLevelTrainerOptions) -> Self {
+        v.0
+    }
+}
+impl FromJsValue for WordLevelTrainerOptions {
+    fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult<Self> {
+        if let Ok(options) = from.downcast::<JsObject>() {
+            let mut builder = WordLevelTrainer::builder();
+
+            if let Ok(size) = options.get(cx, "vocabSize") {
+                if let Some(size) = Option::from_value(size, cx)? {
+                    builder.vocab_size(size);
+                }
+            }
+            if let Ok(freq) = options.get(cx, "minFrequency") {
+                if let Some(freq) = Option::from_value(freq, cx)? {
+                    builder.min_frequency(freq);
+                }
+            }
+            if let Ok(tokens) = options.get(cx, "specialTokens") {
+                if tokens.downcast::<JsNull>().is_err() && tokens.downcast::<JsUndefined>().is_err()
+                {
+                    builder.special_tokens(
+                        tokens
+                            .downcast::<JsArray>()
+                            .map_err(|e| Error(format!("{}", e)))?
+                            .to_vec(cx)?
+                            .into_iter()
+                            .map(|token| Ok(AddedToken::from_value(token, cx)?.into()))
+                            .collect::<Result<Vec<_>, Error>>()?,
+                    );
+                }
+            }
+            if let Ok(show) = options.get(cx, "showProgress") {
+                if let Some(show) = Option::from_value(show, cx)? {
+                    builder.show_progress(show);
+                }
+            }
+
+            Ok(Self(
+                builder
+                    .build()
+                    .expect("WordLevelTrainerBuilder cannot fail"),
+            ))
+        } else {
+            Err(Error("Expected options type: object".into()))
+        }
+    }
+}
+
+/// wordlevel_trainer(options?: {
+///   vocabSize?: number = 30000,
+///   minFrequency?: number = 0,
+///   specialTokens?: string[] = [],
+///   showProgress?: bool = true,
+/// })
+fn wordlevel_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
+    let trainer = cx.extract_opt::<WordLevelTrainerOptions>(0)?.map_or_else(
+        || WordLevelTrainer::builder().build().unwrap(),
+        |o| o.into(),
+    );
+
+    let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?;
+    let guard = cx.lock();
+    js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into()));
+
+    Ok(js_trainer)
+}
+
 // Unigram
 
 struct UnigramTrainerOptions(UnigramTrainer);
 impl From<UnigramTrainerOptions> for UnigramTrainer {
     fn from(v: UnigramTrainerOptions) -> Self {
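
Note what the option parsing above buys on the JavaScript side: each `specialTokens` element is converted through `AddedToken::from_value`, so plain strings and AddedToken instances can be mixed, while an explicit null or undefined simply leaves the builder default untouched. A hypothetical usage sketch (import paths assumed, not part of this commit):

import { wordLevelTrainer } from "tokenizers/bindings/trainers"; // assumed path
import { AddedToken } from "tokenizers/bindings/tokenizer";      // assumed path

// Strings and AddedToken instances may be mixed freely in specialTokens.
const trainer = wordLevelTrainer({
  specialTokens: ["[UNK]", new AddedToken("[CLS]", true)],
});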
@@ -337,6 +433,7 @@ fn unigram_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
 pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
     m.export_function(&format!("{}_BPETrainer", prefix), bpe_trainer)?;
     m.export_function(&format!("{}_WordPieceTrainer", prefix), wordpiece_trainer)?;
+    m.export_function(&format!("{}_WordLevelTrainer", prefix), wordlevel_trainer)?;
     m.export_function(&format!("{}_UnigramTrainer", prefix), unigram_trainer)?;
     Ok(())
 }
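
The export name is assembled as `{prefix}_WordLevelTrainer`; judging by the `native.trainers_*` names consumed in trainers.js above, the module is registered with the `trainers` prefix, so this line is what makes `native.trainers_WordLevelTrainer` exist for the JavaScript wrapper to re-export as `wordLevelTrainer`.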
bindings/python/src/trainers.rs

@@ -301,34 +301,41 @@ impl PyWordLevelTrainer {
     #[new]
     #[args(kwargs = "**")]
     pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
-        let mut trainer = tk::models::wordlevel::WordLevelTrainer::default();
+        let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
 
         if let Some(kwargs) = kwargs {
             for (key, val) in kwargs {
                 let key: &str = key.extract()?;
                 match key {
-                    "vocab_size" => trainer.vocab_size = val.extract()?,
-                    "min_frequency" => trainer.min_frequency = val.extract()?,
-                    "show_progress" => trainer.show_progress = val.extract()?,
+                    "vocab_size" => {
+                        builder.vocab_size(val.extract()?);
+                    }
+                    "min_frequency" => {
+                        builder.min_frequency(val.extract()?);
+                    }
+                    "show_progress" => {
+                        builder.show_progress(val.extract()?);
+                    }
                     "special_tokens" => {
-                        trainer.special_tokens = val
-                            .cast_as::<PyList>()?
-                            .into_iter()
-                            .map(|token| {
-                                if let Ok(content) = token.extract::<String>() {
-                                    Ok(PyAddedToken::from(content, Some(true)).get_token())
-                                } else if let Ok(mut token) =
-                                    token.extract::<PyRefMut<PyAddedToken>>()
-                                {
-                                    token.is_special_token = true;
-                                    Ok(token.get_token())
-                                } else {
-                                    Err(exceptions::PyTypeError::new_err(
-                                        "special_tokens must be a List[Union[str, AddedToken]]",
-                                    ))
-                                }
-                            })
-                            .collect::<PyResult<Vec<_>>>()?
+                        builder.special_tokens(
+                            val.cast_as::<PyList>()?
+                                .into_iter()
+                                .map(|token| {
+                                    if let Ok(content) = token.extract::<String>() {
+                                        Ok(PyAddedToken::from(content, Some(true)).get_token())
+                                    } else if let Ok(mut token) =
+                                        token.extract::<PyRefMut<PyAddedToken>>()
+                                    {
+                                        token.is_special_token = true;
+                                        Ok(token.get_token())
+                                    } else {
+                                        Err(exceptions::PyTypeError::new_err(
+                                            "special_tokens must be a List[Union[str, AddedToken]]",
+                                        ))
+                                    }
+                                })
+                                .collect::<PyResult<Vec<_>>>()?,
+                        );
                     }
                     _ => println!("Ignored unknown kwargs option {}", key),
                 }

@@ -337,7 +344,12 @@ impl PyWordLevelTrainer {
 
         Ok((
             PyWordLevelTrainer {},
-            PyTrainer::new(Arc::new(trainer.into())),
+            PyTrainer::new(Arc::new(
+                builder
+                    .build()
+                    .expect("WordLevelTrainerBuilder cannot fail")
+                    .into(),
+            )),
         ))
     }
 }
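
The Python binding receives the matching refactor: kwargs are now funneled through the same WordLevelTrainer builder instead of being assigned to public fields on a default-constructed trainer, which is what lets the core crate make those fields private in the next file.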
tokenizers/src/models/wordlevel/trainer.rs

@@ -2,15 +2,20 @@ use super::WordLevel;
 use crate::{AddedToken, Result, Trainer};
 use std::collections::HashMap;
 
 #[derive(Debug, Clone, Builder)]
 pub struct WordLevelTrainer {
     /// The minimum frequency a word must have to be part of the vocabulary
-    pub min_frequency: u32,
+    #[builder(default)]
+    min_frequency: u32,
     /// The target vocabulary size
-    pub vocab_size: usize,
+    #[builder(default)]
+    vocab_size: usize,
     /// Whether to show progress while training
-    pub show_progress: bool,
+    #[builder(default)]
+    show_progress: bool,
     /// A list of special tokens that the model should know of
-    pub special_tokens: Vec<AddedToken>,
+    #[builder(default)]
+    special_tokens: Vec<AddedToken>,
 }
 
 impl Default for WordLevelTrainer {

@@ -25,6 +30,10 @@ impl Default for WordLevelTrainer {
 }
 
 impl WordLevelTrainer {
+    pub fn builder() -> WordLevelTrainerBuilder {
+        WordLevelTrainerBuilder::default()
+    }
+
     fn train(
         &self,
         word_counts: HashMap<String, u32>,
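
The `Builder` derive and the `#[builder(default)]` attributes presumably come from the derive_builder crate, which the tokenizers core already uses for its other trainers. One subtlety: `#[builder(default)]` falls back to each field type's `Default` (0, false, an empty Vec), not to the struct's own `impl Default`, so a builder built with no setters does not automatically reproduce documented defaults such as vocabSize = 30000.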