From 722b61230dae6ad2b11206cc528e648948736b74 Mon Sep 17 00:00:00 2001 From: Anthony MOI Date: Wed, 1 Jan 2020 14:49:03 -0500 Subject: [PATCH] BPE handles UNK token --- bindings/python/src/models.rs | 2 +- tokenizers/src/models/bpe/model.rs | 43 ++++++++++++++++++++---------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs index 5e40fc82..65304094 100644 --- a/bindings/python/src/models.rs +++ b/bindings/python/src/models.rs @@ -70,7 +70,7 @@ impl BPE { #[staticmethod] fn empty() -> Model { Model { - model: Container::Owned(Box::new(tk::models::bpe::BPE::empty())), + model: Container::Owned(Box::new(tk::models::bpe::BPE::default())), } } } diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 416db87b..9117e3c9 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -22,15 +22,33 @@ pub struct BPE { /// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will /// perform no merges, so the result will just be characters. dropout: Option, + /// The unknown token to be used when we encounter an unknown char + unk_token: Option, +} + +impl Default for BPE { + fn default() -> Self { + Self { + vocab: HashMap::new(), + vocab_r: HashMap::new(), + merges: HashMap::new(), + cache: Cache::new(), + dropout: None, + unk_token: None, + } + } } impl Clone for BPE { fn clone(&self) -> Self { - BPE::new( - self.vocab.clone(), - self.vocab_r.clone(), - self.merges.clone(), - ) + Self { + vocab: self.vocab.clone(), + vocab_r: self.vocab_r.clone(), + merges: self.merges.clone(), + cache: Cache::new(), + dropout: self.dropout, + unk_token: self.unk_token, + } } } @@ -44,8 +62,7 @@ impl BPE { vocab, vocab_r, merges, - cache: Cache::new(), - dropout: None, + ..Default::default() } } @@ -63,16 +80,12 @@ impl BPE { vocab, vocab_r, merges, - cache: Cache::new(), dropout: if dropout == 0.0 { None } else { Some(dropout) }, + ..Default::default() }) } } - pub fn empty() -> Self { - BPE::new(HashMap::new(), HashMap::new(), HashMap::new()) - } - pub fn from_files(vocab: &str, merges: &str) -> Result { // Read vocab.json let vocab_file = File::open(vocab)?; @@ -129,8 +142,7 @@ impl BPE { vocab: vocab.clone(), vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(), merges, - cache: Cache::new(), - dropout: None, + ..Default::default() }) } @@ -139,6 +151,9 @@ impl BPE { for c in w.chars() { if let Some(id) = self.vocab.get(&c.to_string()) { word.add(*id); + } else if let Some(unk) = &self.unk_token { + // Handle UNK token + word.add(*unk); } }