BPE handles UNK token

This commit is contained in:
Anthony MOI
2020-01-01 14:49:03 -05:00
parent 75713ce809
commit 722b61230d
2 changed files with 30 additions and 15 deletions

View File

@ -70,7 +70,7 @@ impl BPE {
/// Builds a `Model` that owns a default-constructed BPE model.
///
/// The `model` field is initialised exactly once; a struct literal that names the
/// same field twice is rejected by the compiler.
#[staticmethod]
fn empty() -> Model {
    Model {
        model: Container::Owned(Box::new(tk::models::bpe::BPE::default())),
    }
}
}

View File

@ -22,15 +22,33 @@ pub struct BPE {
/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
/// perform no merges, so the result will just be characters.
dropout: Option<f32>,
/// The unknown token to be used when we encounter an unknown char
unk_token: Option<u32>,
}
impl Default for BPE {
fn default() -> Self {
Self {
vocab: HashMap::new(),
vocab_r: HashMap::new(),
merges: HashMap::new(),
cache: Cache::new(),
dropout: None,
unk_token: None,
}
}
}
impl Clone for BPE {
    /// Returns a copy of this model.
    ///
    /// `vocab`, `vocab_r` and `merges` are cloned, and `dropout` and `unk_token`
    /// are carried over unchanged. The cache is not copied: the clone starts with
    /// a new one from `Cache::new()`.
    fn clone(&self) -> Self {
        Self {
            vocab: self.vocab.clone(),
            vocab_r: self.vocab_r.clone(),
            merges: self.merges.clone(),
            // Fresh cache for the clone rather than a copy of this one.
            cache: Cache::new(),
            dropout: self.dropout,
            unk_token: self.unk_token,
        }
    }
}
@ -44,8 +62,7 @@ impl BPE {
vocab,
vocab_r,
merges,
cache: Cache::new(),
dropout: None,
..Default::default()
}
}
@ -63,16 +80,12 @@ impl BPE {
vocab,
vocab_r,
merges,
cache: Cache::new(),
dropout: if dropout == 0.0 { None } else { Some(dropout) },
..Default::default()
})
}
}
pub fn empty() -> Self {
BPE::new(HashMap::new(), HashMap::new(), HashMap::new())
}
pub fn from_files(vocab: &str, merges: &str) -> Result<Self> {
// Read vocab.json
let vocab_file = File::open(vocab)?;
@ -129,8 +142,7 @@ impl BPE {
vocab: vocab.clone(),
vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(),
merges,
cache: Cache::new(),
dropout: None,
..Default::default()
})
}
@ -139,6 +151,9 @@ impl BPE {
for c in w.chars() {
if let Some(id) = self.vocab.get(&c.to_string()) {
word.add(*id);
} else if let Some(unk) = &self.unk_token {
// Handle UNK token
word.add(*unk);
}
}