Improve the way BpeTrainer creates a BPE

Before these modifications, the BPE constructed by the BpeTrainer didn't
account for all of the trainer's configuration: if the trainer was set up
with a continuing subword prefix or an end-of-word suffix, for example,
the produced BPE was not instantiated with these options.
Also, any special tokens used during training should automatically be
added to the Tokenizer together with the Model. Both cases are now
handled properly.
Author: Anthony MOI
Date:   2020-01-17 17:41:35 -05:00
Parent: d1ae0bd576
Commit: 3895d14bf9

3 changed files with 43 additions and 17 deletions
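To make the new contract concrete, here is a minimal self-contained sketch with toy stand-in types (the `Model`, `Trainer`, and `Tokenizer` below are illustrative, not the crate's real types): the trainer returns the trained model together with its special tokens, and the tokenizer installs both.

```rust
use std::collections::HashMap;

// Toy stand-ins for the real Model and Trainer types.
struct Model;

struct Trainer {
    special_tokens: Vec<String>,
}

impl Trainer {
    // The trainer now returns the trained model *and* the special tokens
    // that must be registered on the tokenizer alongside it.
    fn train(&self, _word_counts: HashMap<String, u32>) -> (Model, Vec<String>) {
        // ... actual training would happen here ...
        (Model, self.special_tokens.clone())
    }
}

struct Tokenizer {
    model: Model,
    special_tokens: HashMap<String, u32>,
}

impl Tokenizer {
    fn train(&mut self, trainer: &Trainer, word_counts: HashMap<String, u32>) {
        let (model, special_tokens) = trainer.train(word_counts);
        self.model = model;
        // Register the special tokens automatically, instead of requiring
        // the caller to add them by hand after training.
        for (id, token) in special_tokens.into_iter().enumerate() {
            self.special_tokens.entry(token).or_insert(id as u32);
        }
    }
}
```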

In the BPE trainer, the inherent `train` method now also returns the special tokens:

```diff
@@ -316,7 +316,7 @@ impl BpeTrainer {
         (words, counts)
     }
 
-    pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<BPE> {
+    pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<(BPE, Vec<String>)> {
         let mut word_to_id: HashMap<String, u32> = HashMap::new();
         let mut id_to_word: Vec<String> = vec![];
 
```
```diff
@@ -456,22 +456,37 @@ impl BpeTrainer {
         }
         self.finalize_progress(&progress, merges.len());
 
-        Ok(BPE::new(
+        let mut builder = BPE::builder().vocab_and_merges(
             word_to_id,
             merges
                 .into_iter()
                 .enumerate()
                 .map(|(index, (pair, new_id))| (pair, (index as u32, new_id)))
                 .collect(),
-        ))
+        );
+        if let Some(prefix) = &self.continuing_subword_prefix {
+            builder = builder.continuing_subword_prefix(prefix.to_owned());
+        }
+        if let Some(suffix) = &self.end_of_word_suffix {
+            builder = builder.end_of_word_suffix(suffix.to_owned());
+        }
+        Ok((
+            builder
+                .build()
+                .expect("Trainer should know how to build BPE"),
+            self.special_tokens.clone(),
+        ))
     }
 }
 
 impl Trainer for BpeTrainer {
     /// Train a BPE model
-    fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
-        let bpe = self.train(word_counts)?;
-        Ok(Box::new(bpe))
+    fn train(
+        &self,
+        word_counts: HashMap<String, u32>,
+    ) -> Result<(Box<dyn Model + Sync>, Vec<String>)> {
+        let (bpe, tokens) = self.train(word_counts)?;
+        Ok((Box::new(bpe), tokens))
     }
 
     /// Process a bunch of tokens, counting them
```
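A note on the builder usage above: `build()` is fallible, but the trainer can safely `expect` success because it only forwards options it produced itself. A toy sketch of this conditional-builder idiom (all names illustrative, not the crate's exact API):

```rust
#[derive(Default)]
struct BpeBuilder {
    continuing_subword_prefix: Option<String>,
    end_of_word_suffix: Option<String>,
}

impl BpeBuilder {
    fn continuing_subword_prefix(mut self, prefix: String) -> Self {
        self.continuing_subword_prefix = Some(prefix);
        self
    }

    fn end_of_word_suffix(mut self, suffix: String) -> Self {
        self.end_of_word_suffix = Some(suffix);
        self
    }

    // Fallible build: callers that constructed their own options can expect success.
    fn build(self) -> Result<Bpe, String> {
        Ok(Bpe { config: self })
    }
}

struct Bpe {
    config: BpeBuilder,
}

// Forward each optional setting to the builder only when it is actually set.
fn make(prefix: Option<&str>, suffix: Option<&str>) -> Bpe {
    let mut builder = BpeBuilder::default();
    if let Some(p) = prefix {
        builder = builder.continuing_subword_prefix(p.to_owned());
    }
    if let Some(s) = suffix {
        builder = builder.end_of_word_suffix(s.to_owned());
    }
    builder.build().expect("builder should know how to build BPE")
}
```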

In the WordPiece trainer, which delegates to the BPE trainer:

```diff
@@ -89,16 +89,19 @@ impl WordPieceTrainer {
         WordPieceTrainerBuilder::default()
     }
 
-    pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<WordPiece> {
-        let bpe = self.bpe_trainer.train(word_counts)?;
-        Ok(WordPiece::from_bpe(&bpe))
+    pub fn train(&self, word_counts: HashMap<String, u32>) -> Result<(WordPiece, Vec<String>)> {
+        let (bpe, tokens) = self.bpe_trainer.train(word_counts)?;
+        Ok((WordPiece::from_bpe(&bpe), tokens))
     }
 }
 
 impl Trainer for WordPieceTrainer {
-    fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
-        let wp = self.train(word_counts)?;
-        Ok(Box::new(wp))
+    fn train(
+        &self,
+        word_counts: HashMap<String, u32>,
+    ) -> Result<(Box<dyn Model + Sync>, Vec<String>)> {
+        let (wp, tokens) = self.train(word_counts)?;
+        Ok((Box::new(wp), tokens))
     }
 
     fn process_tokens(&self, mut words: &mut HashMap<String, u32>, tokens: Vec<String>) {
```
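Both trainers use the same delegation pattern: an inherent `train` returns the concrete model type, and the `Trainer` trait impl of the same name boxes the result. The inner `self.train(word_counts)?` is not infinite recursion, because Rust prefers inherent methods over trait methods during resolution. A toy sketch with simplified signatures (no `Result`, no word counts):

```rust
trait Model {}

struct WordPiece;
impl Model for WordPiece {}

struct WordPieceTrainer;

impl WordPieceTrainer {
    // Inherent method: returns the concrete model type.
    fn train(&self) -> (WordPiece, Vec<String>) {
        (WordPiece, vec!["[UNK]".to_string()])
    }
}

trait Trainer {
    fn train(&self) -> (Box<dyn Model + Sync>, Vec<String>);
}

impl Trainer for WordPieceTrainer {
    fn train(&self) -> (Box<dyn Model + Sync>, Vec<String>) {
        // Resolves to the inherent method above, then boxes the model
        // for dynamic dispatch.
        let (wp, tokens) = self.train();
        (Box::new(wp), tokens)
    }
}
```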

In the `Trainer` trait and the `Tokenizer`:

```diff
@@ -62,8 +62,12 @@ pub trait Decoder {
 /// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
 /// and it returns a `Model` when done.
 pub trait Trainer: Sync {
     /// Whether we should show progress during the training.
     fn should_show_progress(&self) -> bool;
-    fn train(&self, words: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>>;
+    /// The actual training method. This will return a new trained Model as well as a list
+    /// of `special_tokens` to be added directly to the tokenizer along with the model.
+    fn train(&self, words: HashMap<String, u32>) -> Result<(Box<dyn Model + Sync>, Vec<String>)>;
     /// Process a bunch of tokens, counting them as relevant.
     fn process_tokens(&self, words: &mut HashMap<String, u32>, tokens: Vec<String>);
 }
```
```diff
@@ -466,7 +470,9 @@ impl Tokenizer {
             }
         }
 
-        self.model = trainer.train(words)?;
+        let (model, special_tokens) = trainer.train(words)?;
+        self.model = model;
+        self.add_special_tokens(&special_tokens);
 
         Ok(())
     }
```
```diff
@@ -552,16 +558,18 @@ impl Tokenizer {
 
     /// Register the given tokens as special tokens. This is especially useful for removing
     /// these special tokens while decoding
-    pub fn add_special_tokens(&mut self, tokens: &[&str]) -> usize {
+    pub fn add_special_tokens<T: AsRef<str>>(&mut self, tokens: &[T]) -> usize {
         let added_tokens = tokens
             .iter()
-            .map(|t| AddedToken::from((*t).to_owned()))
+            .map(|t| AddedToken::from(t.as_ref().to_owned()))
             .collect::<Vec<_>>();
 
         let added = self.add_tokens(&added_tokens);
         for token in tokens {
-            if let Some(id) = self.token_to_id(token) {
-                self.special_tokens.entry((*token).to_owned()).or_insert(id);
+            if let Some(id) = self.token_to_id(token.as_ref()) {
+                self.special_tokens
+                    .entry(token.as_ref().to_owned())
+                    .or_insert(id);
             }
         }
 
```
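Relaxing the parameter from `&[&str]` to a generic `T: AsRef<str>` is what lets the internal `self.add_special_tokens(&special_tokens)` call above pass a `Vec<String>` directly, while existing `&[&str]` callers keep working. A small sketch of the effect (standalone free function, hypothetical):

```rust
// Accepts any slice whose elements can be viewed as &str.
fn add_special_tokens<T: AsRef<str>>(tokens: &[T]) -> Vec<String> {
    tokens.iter().map(|t| t.as_ref().to_owned()).collect()
}

fn main() {
    // Works with string literals...
    let a = add_special_tokens(&["[CLS]", "[SEP]"]);
    // ...and with owned Strings, e.g. the Vec<String> returned by a trainer.
    let owned = vec!["[PAD]".to_string(), "[MASK]".to_string()];
    let b = add_special_tokens(&owned);
    assert_eq!(a.len() + b.len(), 4);
}
```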