Modify BPE with unk_token being a String

This commit is contained in:
Anthony MOI
2020-01-07 19:22:29 -05:00
parent b17f9d8872
commit 03c431c60e
4 changed files with 24 additions and 17 deletions

View File

@ -25,7 +25,7 @@ class BPE:
merges: str,
cache_capacity: Optional[int],
dropout: Optional[float],
unk_token: Optional[int],
unk_token: Optional[str],
continuing_subword_prefix: Optional[str],
end_of_word_suffix: Optional[str]) -> Model:
""" Instantiate a BPE Model from the given vocab and merges files.
@ -45,8 +45,8 @@ class BPE:
dropout: (`optional`) Optional[float] [0, 1]:
The BPE dropout to use. Must be an float between 0 and 1
unk_token: (`optional`) int:
The unknown token id to be used by the model.
unk_token: (`optional`) str:
The unknown token to be used by the model.
continuing_subword_prefix: (`optional`) str:
The prefix to attach to subword units that don't represent a beginning of word.

View File

@ -22,6 +22,8 @@ pub enum Error {
BadMerges(usize),
/// If a token found in merges, is not in the vocab
MergeTokenOutOfVocabulary(String),
/// If the provided unk token is out of vocabulary
UnkTokenOutOfVocabulary(String),
/// Dropout not between 0 and 1.
InvalidDropout,
}
@ -46,7 +48,10 @@ impl std::fmt::Display for Error {
Error::BadVocabulary => write!(f, "Bad vocabulary json file"),
Error::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line),
Error::MergeTokenOutOfVocabulary(token) => {
write!(f, "Token {} out of vocabulary", token)
write!(f, "Token `{}` out of vocabulary", token)
}
Error::UnkTokenOutOfVocabulary(token) => {
write!(f, "Unk token `{}` not found in the vocabulary", token)
}
Error::InvalidDropout => write!(f, "Dropout should be between 0 and 1"),
}

View File

@ -16,7 +16,7 @@ struct Config {
merges: Option<HashMap<Pair, (u32, u32)>>,
cache_capacity: Option<usize>,
dropout: Option<f32>,
unk_token: Option<u32>,
unk_token: Option<String>,
continuing_subword_prefix: Option<String>,
end_of_word_suffix: Option<String>,
}
@ -57,7 +57,7 @@ impl BpeBuilder {
}
/// Set the `UNK` token for the vocab.
pub fn unk_token(mut self, unk_token: u32) -> Self {
pub fn unk_token(mut self, unk_token: String) -> Self {
self.config.unk_token = Some(unk_token);
self
}
@ -130,7 +130,7 @@ pub struct BPE {
/// perform no merges, so the result will just be characters.
dropout: Option<f32>,
/// The unknown token to be used when we encounter an unknown char
unk_token: Option<u32>,
unk_token: Option<String>,
/// An optional prefix to use on any subword that exist only behind another one
continuing_subword_prefix: Option<String>,
/// An optional suffix to caracterize and end-of-word subword
@ -166,7 +166,7 @@ impl Clone for BPE {
merges: self.merges.clone(),
cache: fresh_cache,
dropout: self.dropout,
unk_token: self.unk_token,
unk_token: self.unk_token.clone(),
continuing_subword_prefix: self.continuing_subword_prefix.clone(),
end_of_word_suffix: self.end_of_word_suffix.clone(),
}
@ -254,7 +254,7 @@ impl BPE {
&self.vocab
}
pub fn get_unk_token(&self) -> &Option<u32> {
pub fn get_unk_token(&self) -> &Option<String> {
&self.unk_token
}
@ -262,7 +262,7 @@ impl BPE {
&self.continuing_subword_prefix
}
fn merge_word(&self, w: &str) -> Word {
fn merge_word(&self, w: &str) -> Result<Word> {
let mut word = Word::new();
for (is_first, is_last, c) in w.chars().with_first_and_last() {
let mut s = c.to_string();
@ -283,8 +283,12 @@ impl BPE {
if let Some(id) = self.vocab.get(&s) {
word.add(*id);
} else if let Some(unk) = &self.unk_token {
let unk_id = self
.vocab
.get(unk)
.ok_or_else(|| Error::UnkTokenOutOfVocabulary(unk.to_owned()))?;
// Handle UNK token
word.add(*unk);
word.add(*unk_id);
}
}
@ -328,7 +332,7 @@ impl BPE {
word.merge(pair.0, pair.1, *new_id);
}
word
Ok(word)
}
fn word_to_tokens(&self, word: &Word, initial_offsets: &(usize, usize)) -> Vec<Token> {
@ -370,14 +374,14 @@ impl Model for BPE {
let tokens = match cached_words {
None => {
// Not using cache since we're using dropout.
let word = self.merge_word(&w);
let word = self.merge_word(&w)?;
self.word_to_tokens(&word, initial_offsets)
}
Some(ref mut cache) => {
match cache[i].as_ref() {
None => {
// No cache hit, so re-compute merges.
let word = self.merge_word(&w);
let word = self.merge_word(&w)?;
let tokens = self.word_to_tokens(&word, initial_offsets);
// Add to cache.
cache[i] = Some(word);

View File

@ -156,9 +156,7 @@ impl WordPiece {
pub fn from_bpe(bpe: &BPE) -> Self {
let mut wp = Self::builder().vocab(bpe.get_vocab().clone()).build();
if let Some(unk) = bpe.get_unk_token() {
if let Some(unk_token) = wp.vocab_r.get(&unk) {
wp.unk_token = unk_token.to_owned();
}
wp.unk_token = unk.to_owned();
}
if let Some(prefix) = bpe.get_continuing_subword_prefix() {
wp.continuing_subword_prefix = prefix.to_owned();