mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-23 16:49:27 +00:00
Node - Add WordLevelTrainer
This commit is contained in:
29
bindings/node/lib/bindings/trainers.d.ts
vendored
29
bindings/node/lib/bindings/trainers.d.ts
vendored
@ -63,6 +63,35 @@ export function bpeTrainer(options?: TrainerOptions): Trainer;
|
|||||||
*/
|
*/
|
||||||
export function wordPieceTrainer(options?: TrainerOptions): Trainer;
|
export function wordPieceTrainer(options?: TrainerOptions): Trainer;
|
||||||
|
|
||||||
|
export interface WordLevelTrainerOptions {
|
||||||
|
/**
|
||||||
|
* The minimum frequency a word must have in order to be part of the vocabulary.
|
||||||
|
* @default 0
|
||||||
|
*/
|
||||||
|
minFrequency?: number;
|
||||||
|
/**
|
||||||
|
* Whether to show progress bars while training.
|
||||||
|
* @default true
|
||||||
|
*/
|
||||||
|
showProgress?: boolean;
|
||||||
|
/**
|
||||||
|
* A list of special tokens the model should know of.
|
||||||
|
* @default []
|
||||||
|
*/
|
||||||
|
specialTokens?: (string | AddedToken)[];
|
||||||
|
/**
|
||||||
|
* The target size of the final vocabulary.
|
||||||
|
* @default 30000
|
||||||
|
*/
|
||||||
|
vocabSize?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Instantiate a new WordLevel Trainer
|
||||||
|
* @param [options] WordLevel Trainer options
|
||||||
|
*/
|
||||||
|
export function wordLevelTrainer(options?: WordLevelTrainerOptions): Trainer;
|
||||||
|
|
||||||
export interface UnigramTrainerOptions {
|
export interface UnigramTrainerOptions {
|
||||||
vocabSize?: number;
|
vocabSize?: number;
|
||||||
nSubIterations?: number;
|
nSubIterations?: number;
|
||||||
|
@ -3,5 +3,6 @@ const native = require("./native");
|
|||||||
module.exports = {
|
module.exports = {
|
||||||
bpeTrainer: native.trainers_BPETrainer,
|
bpeTrainer: native.trainers_BPETrainer,
|
||||||
wordPieceTrainer: native.trainers_WordPieceTrainer,
|
wordPieceTrainer: native.trainers_WordPieceTrainer,
|
||||||
|
wordLevelTrainer: native.trainers_WordLevelTrainer,
|
||||||
unigramTrainer: native.trainers_UnigramTrainer,
|
unigramTrainer: native.trainers_UnigramTrainer,
|
||||||
};
|
};
|
||||||
|
@ -8,7 +8,8 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use tk::models::{
|
use tk::models::{
|
||||||
bpe::BpeTrainer, unigram::UnigramTrainer, wordpiece::WordPieceTrainer, TrainerWrapper,
|
bpe::BpeTrainer, unigram::UnigramTrainer, wordlevel::WordLevelTrainer,
|
||||||
|
wordpiece::WordPieceTrainer, TrainerWrapper,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Trainer
|
/// Trainer
|
||||||
@ -17,6 +18,14 @@ pub struct Trainer {
|
|||||||
pub trainer: Option<Arc<TrainerWrapper>>,
|
pub trainer: Option<Arc<TrainerWrapper>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<TrainerWrapper> for Trainer {
|
||||||
|
fn from(trainer: TrainerWrapper) -> Self {
|
||||||
|
Self {
|
||||||
|
trainer: Some(Arc::new(trainer)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl tk::Trainer for Trainer {
|
impl tk::Trainer for Trainer {
|
||||||
type Model = Model;
|
type Model = Model;
|
||||||
|
|
||||||
@ -27,14 +36,26 @@ impl tk::Trainer for Trainer {
|
|||||||
.should_show_progress()
|
.should_show_progress()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn train(&self, words: HashMap<String, u32>) -> tk::Result<(Self::Model, Vec<tk::AddedToken>)> {
|
fn train(
|
||||||
let (model, special_tokens) = self
|
&self,
|
||||||
|
words: HashMap<String, u32>,
|
||||||
|
model: &mut Self::Model,
|
||||||
|
) -> tk::Result<Vec<tk::AddedToken>> {
|
||||||
|
let special_tokens = self
|
||||||
.trainer
|
.trainer
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.ok_or("Uninitialized Trainer")?
|
.ok_or("Uninitialized Trainer")?
|
||||||
.train(words)?;
|
.train(
|
||||||
|
words,
|
||||||
|
&mut model
|
||||||
|
.model
|
||||||
|
.as_ref()
|
||||||
|
.ok_or("Uninitialized Model")?
|
||||||
|
.write()
|
||||||
|
.unwrap(),
|
||||||
|
)?;
|
||||||
|
|
||||||
Ok((model.into(), special_tokens))
|
Ok(special_tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
|
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>) {
|
||||||
@ -238,6 +259,81 @@ fn wordpiece_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
|
|||||||
Ok(js_trainer)
|
Ok(js_trainer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WordLevel
|
||||||
|
|
||||||
|
struct WordLevelTrainerOptions(WordLevelTrainer);
|
||||||
|
impl From<WordLevelTrainerOptions> for WordLevelTrainer {
|
||||||
|
fn from(v: WordLevelTrainerOptions) -> Self {
|
||||||
|
v.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl FromJsValue for WordLevelTrainerOptions {
|
||||||
|
fn from_value<'c, C: Context<'c>>(from: Handle<'c, JsValue>, cx: &mut C) -> LibResult<Self> {
|
||||||
|
if let Ok(options) = from.downcast::<JsObject>() {
|
||||||
|
let mut builder = WordLevelTrainer::builder();
|
||||||
|
|
||||||
|
if let Ok(size) = options.get(cx, "vocabSize") {
|
||||||
|
if let Some(size) = Option::from_value(size, cx)? {
|
||||||
|
builder.vocab_size(size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Ok(freq) = options.get(cx, "minFrequency") {
|
||||||
|
if let Some(freq) = Option::from_value(freq, cx)? {
|
||||||
|
builder.min_frequency(freq);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Ok(tokens) = options.get(cx, "specialTokens") {
|
||||||
|
if tokens.downcast::<JsNull>().is_err() && tokens.downcast::<JsUndefined>().is_err()
|
||||||
|
{
|
||||||
|
builder.special_tokens(
|
||||||
|
tokens
|
||||||
|
.downcast::<JsArray>()
|
||||||
|
.map_err(|e| Error(format!("{}", e)))?
|
||||||
|
.to_vec(cx)?
|
||||||
|
.into_iter()
|
||||||
|
.map(|token| Ok(AddedToken::from_value(token, cx)?.into()))
|
||||||
|
.collect::<Result<Vec<_>, Error>>()?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Ok(show) = options.get(cx, "showProgress") {
|
||||||
|
if let Some(show) = Option::from_value(show, cx)? {
|
||||||
|
builder.show_progress(show);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self(
|
||||||
|
builder
|
||||||
|
.build()
|
||||||
|
.expect("WordLevelTrainerBuilder cannot fail"),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
Err(Error("Expected options type: object".into()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// wordlevel_trainer(options?: {
|
||||||
|
/// vocabSize?: number = 30000,
|
||||||
|
/// minFrequency?: number = 0,
|
||||||
|
/// specialTokens?: (string | AddedToken)[] = [],
|
||||||
|
/// showProgress?: bool = true,
|
||||||
|
/// })
|
||||||
|
fn wordlevel_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
|
||||||
|
let trainer = cx.extract_opt::<WordLevelTrainerOptions>(0)?.map_or_else(
|
||||||
|
|| WordLevelTrainer::builder().build().unwrap(),
|
||||||
|
|o| o.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut js_trainer = JsTrainer::new::<_, JsTrainer, _>(&mut cx, vec![])?;
|
||||||
|
let guard = cx.lock();
|
||||||
|
js_trainer.borrow_mut(&guard).trainer = Some(Arc::new(trainer.into()));
|
||||||
|
|
||||||
|
Ok(js_trainer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unigram
|
||||||
|
|
||||||
struct UnigramTrainerOptions(UnigramTrainer);
|
struct UnigramTrainerOptions(UnigramTrainer);
|
||||||
impl From<UnigramTrainerOptions> for UnigramTrainer {
|
impl From<UnigramTrainerOptions> for UnigramTrainer {
|
||||||
fn from(v: UnigramTrainerOptions) -> Self {
|
fn from(v: UnigramTrainerOptions) -> Self {
|
||||||
@ -337,6 +433,7 @@ fn unigram_trainer(mut cx: FunctionContext) -> JsResult<JsTrainer> {
|
|||||||
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
|
pub fn register(m: &mut ModuleContext, prefix: &str) -> NeonResult<()> {
|
||||||
m.export_function(&format!("{}_BPETrainer", prefix), bpe_trainer)?;
|
m.export_function(&format!("{}_BPETrainer", prefix), bpe_trainer)?;
|
||||||
m.export_function(&format!("{}_WordPieceTrainer", prefix), wordpiece_trainer)?;
|
m.export_function(&format!("{}_WordPieceTrainer", prefix), wordpiece_trainer)?;
|
||||||
|
m.export_function(&format!("{}_WordLevelTrainer", prefix), wordlevel_trainer)?;
|
||||||
m.export_function(&format!("{}_UnigramTrainer", prefix), unigram_trainer)?;
|
m.export_function(&format!("{}_UnigramTrainer", prefix), unigram_trainer)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -301,34 +301,41 @@ impl PyWordLevelTrainer {
|
|||||||
#[new]
|
#[new]
|
||||||
#[args(kwargs = "**")]
|
#[args(kwargs = "**")]
|
||||||
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
|
pub fn new(kwargs: Option<&PyDict>) -> PyResult<(Self, PyTrainer)> {
|
||||||
let mut trainer = tk::models::wordlevel::WordLevelTrainer::default();
|
let mut builder = tk::models::wordlevel::WordLevelTrainer::builder();
|
||||||
|
|
||||||
if let Some(kwargs) = kwargs {
|
if let Some(kwargs) = kwargs {
|
||||||
for (key, val) in kwargs {
|
for (key, val) in kwargs {
|
||||||
let key: &str = key.extract()?;
|
let key: &str = key.extract()?;
|
||||||
match key {
|
match key {
|
||||||
"vocab_size" => trainer.vocab_size = val.extract()?,
|
"vocab_size" => {
|
||||||
"min_frequency" => trainer.min_frequency = val.extract()?,
|
builder.vocab_size(val.extract()?);
|
||||||
"show_progress" => trainer.show_progress = val.extract()?,
|
}
|
||||||
|
"min_frequency" => {
|
||||||
|
builder.min_frequency(val.extract()?);
|
||||||
|
}
|
||||||
|
"show_progress" => {
|
||||||
|
builder.show_progress(val.extract()?);
|
||||||
|
}
|
||||||
"special_tokens" => {
|
"special_tokens" => {
|
||||||
trainer.special_tokens = val
|
builder.special_tokens(
|
||||||
.cast_as::<PyList>()?
|
val.cast_as::<PyList>()?
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|token| {
|
.map(|token| {
|
||||||
if let Ok(content) = token.extract::<String>() {
|
if let Ok(content) = token.extract::<String>() {
|
||||||
Ok(PyAddedToken::from(content, Some(true)).get_token())
|
Ok(PyAddedToken::from(content, Some(true)).get_token())
|
||||||
} else if let Ok(mut token) =
|
} else if let Ok(mut token) =
|
||||||
token.extract::<PyRefMut<PyAddedToken>>()
|
token.extract::<PyRefMut<PyAddedToken>>()
|
||||||
{
|
{
|
||||||
token.is_special_token = true;
|
token.is_special_token = true;
|
||||||
Ok(token.get_token())
|
Ok(token.get_token())
|
||||||
} else {
|
} else {
|
||||||
Err(exceptions::PyTypeError::new_err(
|
Err(exceptions::PyTypeError::new_err(
|
||||||
"special_tokens must be a List[Union[str, AddedToken]]",
|
"special_tokens must be a List[Union[str, AddedToken]]",
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.collect::<PyResult<Vec<_>>>()?
|
.collect::<PyResult<Vec<_>>>()?,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
_ => println!("Ignored unknown kwargs option {}", key),
|
_ => println!("Ignored unknown kwargs option {}", key),
|
||||||
}
|
}
|
||||||
@ -337,7 +344,12 @@ impl PyWordLevelTrainer {
|
|||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
PyWordLevelTrainer {},
|
PyWordLevelTrainer {},
|
||||||
PyTrainer::new(Arc::new(trainer.into())),
|
PyTrainer::new(Arc::new(
|
||||||
|
builder
|
||||||
|
.build()
|
||||||
|
.expect("WordLevelTrainerBuilder cannot fail")
|
||||||
|
.into(),
|
||||||
|
)),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,15 +2,20 @@ use super::WordLevel;
|
|||||||
use crate::{AddedToken, Result, Trainer};
|
use crate::{AddedToken, Result, Trainer};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Builder)]
|
||||||
pub struct WordLevelTrainer {
|
pub struct WordLevelTrainer {
|
||||||
/// The minimum frequency a word must have to be part of the vocabulary
|
/// The minimum frequency a word must have to be part of the vocabulary
|
||||||
pub min_frequency: u32,
|
#[builder(default)]
|
||||||
|
min_frequency: u32,
|
||||||
/// The target vocabulary size
|
/// The target vocabulary size
|
||||||
pub vocab_size: usize,
|
#[builder(default)]
|
||||||
|
vocab_size: usize,
|
||||||
/// Whether to show progress while training
|
/// Whether to show progress while training
|
||||||
pub show_progress: bool,
|
#[builder(default)]
|
||||||
|
show_progress: bool,
|
||||||
/// A list of special tokens that the model should know of
|
/// A list of special tokens that the model should know of
|
||||||
pub special_tokens: Vec<AddedToken>,
|
#[builder(default)]
|
||||||
|
special_tokens: Vec<AddedToken>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for WordLevelTrainer {
|
impl Default for WordLevelTrainer {
|
||||||
@ -25,6 +30,10 @@ impl Default for WordLevelTrainer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl WordLevelTrainer {
|
impl WordLevelTrainer {
|
||||||
|
pub fn builder() -> WordLevelTrainerBuilder {
|
||||||
|
WordLevelTrainerBuilder::default()
|
||||||
|
}
|
||||||
|
|
||||||
fn train(
|
fn train(
|
||||||
&self,
|
&self,
|
||||||
word_counts: HashMap<String, u32>,
|
word_counts: HashMap<String, u32>,
|
||||||
|
Reference in New Issue
Block a user