mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Improve the way BpeTrainer creates a BPE
Before these modifications, the BPE constructed by the BpeTrainer didn't account for some of its configuration. So if the Trainer was using some suffix or prefix for example, it didn't instantiate the produced BPE with these options. Also any special tokens that are used during training, should be automatically added to the Tokenizer with the Model. This is now handled properly.
This commit is contained in:
@ -316,7 +316,7 @@ impl BpeTrainer {
|
||||
(words, counts)
|
||||
}
|
||||
|
||||
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<BPE> {
|
||||
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<(BPE, Vec<String>)> {
|
||||
let mut word_to_id: HashMap<String, u32> = HashMap::new();
|
||||
let mut id_to_word: Vec<String> = vec![];
|
||||
|
||||
@ -456,22 +456,37 @@ impl BpeTrainer {
|
||||
}
|
||||
self.finalize_progress(&progress, merges.len());
|
||||
|
||||
Ok(BPE::new(
|
||||
let mut builder = BPE::builder().vocab_and_merges(
|
||||
word_to_id,
|
||||
merges
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, (pair, new_id))| (pair, (index as u32, new_id)))
|
||||
.collect(),
|
||||
);
|
||||
if let Some(prefix) = &self.continuing_subword_prefix {
|
||||
builder = builder.continuing_subword_prefix(prefix.to_owned());
|
||||
}
|
||||
if let Some(suffix) = &self.end_of_word_suffix {
|
||||
builder = builder.end_of_word_suffix(suffix.to_owned());
|
||||
}
|
||||
Ok((
|
||||
builder
|
||||
.build()
|
||||
.expect("Trainer should know how to build BPE"),
|
||||
self.special_tokens.clone(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl Trainer for BpeTrainer {
|
||||
/// Train a BPE model
|
||||
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
|
||||
let bpe = self.train(word_counts)?;
|
||||
Ok(Box::new(bpe))
|
||||
fn train(
|
||||
&self,
|
||||
word_counts: HashMap<String, u32>,
|
||||
) -> Result<(Box<dyn Model + Sync>, Vec<String>)> {
|
||||
let (bpe, tokens) = self.train(word_counts)?;
|
||||
Ok((Box::new(bpe), tokens))
|
||||
}
|
||||
|
||||
/// Process a bunch of tokens, counting them
|
||||
|
@ -89,16 +89,19 @@ impl WordPieceTrainer {
|
||||
WordPieceTrainerBuilder::default()
|
||||
}
|
||||
|
||||
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<WordPiece> {
|
||||
let bpe = self.bpe_trainer.train(word_counts)?;
|
||||
Ok(WordPiece::from_bpe(&bpe))
|
||||
pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<(WordPiece, Vec<String>)> {
|
||||
let (bpe, tokens) = self.bpe_trainer.train(word_counts)?;
|
||||
Ok((WordPiece::from_bpe(&bpe), tokens))
|
||||
}
|
||||
}
|
||||
|
||||
impl Trainer for WordPieceTrainer {
|
||||
fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
|
||||
let wp = self.train(word_counts)?;
|
||||
Ok(Box::new(wp))
|
||||
fn train(
|
||||
&self,
|
||||
word_counts: HashMap<String, u32>,
|
||||
) -> Result<(Box<dyn Model + Sync>, Vec<String>)> {
|
||||
let (wp, tokens) = self.train(word_counts)?;
|
||||
Ok((Box::new(wp), tokens))
|
||||
}
|
||||
|
||||
fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
|
||||
|
@ -62,8 +62,12 @@ pub trait Decoder {
|
||||
/// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
|
||||
/// and it returns a `Model` when done.
|
||||
pub trait Trainer: Sync {
|
||||
/// Whether we should show progress during the training.
|
||||
fn should_show_progress(&self) -> bool;
|
||||
fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
|
||||
/// The actual training method. This will return a new trained Model as well as a list
|
||||
/// of `special_tokens` to be added directly to the tokenizer along with the model.
|
||||
fn train(&self, words: HashMap<String, u32>) -> Result<(Box<dyn Model + Sync>, Vec<String>)>;
|
||||
/// Process a bunch of token, counting them as relevant.
|
||||
fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
|
||||
}
|
||||
|
||||
@ -466,7 +470,9 @@ impl Tokenizer {
|
||||
}
|
||||
}
|
||||
|
||||
self.model = trainer.train(words)?;
|
||||
let (model, special_tokens) = trainer.train(words)?;
|
||||
self.model = model;
|
||||
self.add_special_tokens(&special_tokens);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -552,16 +558,18 @@ impl Tokenizer {
|
||||
|
||||
/// Register the given tokens as special tokens. This is especially useful for removing
|
||||
/// these special tokens while decoding
|
||||
pub fn add_special_tokens(&mut self, tokens: &[&str]) -> usize {
|
||||
pub fn add_special_tokens<T: AsRef<str>>(&mut self, tokens: &[T]) -> usize {
|
||||
let added_tokens = tokens
|
||||
.iter()
|
||||
.map(|t| AddedToken::from((*t).to_owned()))
|
||||
.map(|t| AddedToken::from(t.as_ref().to_owned()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let added = self.add_tokens(&added_tokens);
|
||||
for token in tokens {
|
||||
if let Some(id) = self.token_to_id(token) {
|
||||
self.special_tokens.entry((*token).to_owned()).or_insert(id);
|
||||
if let Some(id) = self.token_to_id(token.as_ref()) {
|
||||
self.special_tokens
|
||||
.entry(token.as_ref().to_owned())
|
||||
.or_insert(id);
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user