implement dropout for BPE

This commit is contained in:
epwalsh
2019-12-30 14:14:26 -08:00
parent 5194daa0ce
commit 0be9e5a7f0
3 changed files with 164 additions and 12 deletions

View File

@ -14,6 +14,7 @@ path = "src/cli.rs"
[dependencies]
lazy_static = "1.3.0"
rand = "0.7.2"
regex = "1.3.1"
regex-syntax = "0.6.12"
rayon = "1.2.0"

View File

@ -21,6 +21,8 @@ pub enum Error {
BadMerges(usize),
/// If a token found in merges, is not in the vocab
MergeTokenOutOfVocabulary(String),
/// Dropout not between 0 and 1.
InvalidDropout,
}
impl From<io::Error> for Error {
@ -45,6 +47,7 @@ impl std::fmt::Display for Error {
Error::MergeTokenOutOfVocabulary(token) => {
write!(f, "Token {} out of vocabulary", token)
}
Error::InvalidDropout => write!(f, "Dropout should be between 0 and 1"),
}
}
}
@ -54,9 +57,7 @@ impl std::error::Error for Error {
match self {
Error::Io(e) => Some(e),
Error::JsonError(e) => Some(e),
Error::BadVocabulary => None,
Error::BadMerges(_) => None,
Error::MergeTokenOutOfVocabulary(_) => None,
_ => None,
}
}
}

View File

@ -1,5 +1,6 @@
use super::{Cache, Error, Pair, Word};
use crate::tokenizer::{Model, Offsets, Result, Token};
use rand::{thread_rng, Rng};
use serde_json::Value;
use std::{
collections::HashMap,
@ -9,14 +10,17 @@ use std::{
};
pub struct BPE {
/// The vocabulary assigns a number to each token
/// The vocabulary assigns a number to each token.
vocab: HashMap<String, u32>,
/// Reversed vocabulary, to rebuild sentences
/// Reversed vocabulary, to rebuild sentences.
vocab_r: HashMap<u32, String>,
/// Contains the mapping between Pairs and their (rank, new_id)
/// Contains the mapping between Pairs and their (rank, new_id).
merges: HashMap<Pair, (u32, u32)>,
/// Contains the cache for optimizing the encoding step
/// Contains the cache for optimizing the encoding step.
cache: Cache<String, Word>,
/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
/// perform no merges, so the result will just be characters.
dropout: Option<f32>,
}
impl BPE {
@ -30,6 +34,26 @@ impl BPE {
vocab_r,
merges,
cache: Cache::new(),
dropout: None,
}
}
/// Builds a `BPE` model that applies merge dropout during tokenization.
///
/// `dropout` must lie in the closed range `[0.0, 1.0]`; `0.0` is stored as
/// `None` ("no dropout", same as `BPE::new`) so the encoding hot path can
/// skip the RNG entirely. Returns `Error::InvalidDropout` for any value
/// outside the range.
pub fn with_dropout(
    vocab: HashMap<String, u32>,
    vocab_r: HashMap<u32, String>,
    merges: HashMap<Pair, (u32, u32)>,
    dropout: f32,
) -> Result<Self> {
    // `RangeInclusive::contains` is false for NaN, so NaN is rejected here.
    // The previous `dropout < 0.0 || dropout > 1.0` check silently accepted
    // NaN (both comparisons are false for NaN), which would later make every
    // `rng.gen::<f32>() < dropout` comparison false in a confusing way.
    if !(0.0..=1.0).contains(&dropout) {
        Err(Error::InvalidDropout.into())
    } else {
        Ok(BPE {
            vocab,
            vocab_r,
            merges,
            cache: Cache::new(),
            // Normalize 0.0 to None so "no dropout" has a single representation.
            dropout: if dropout == 0.0 { None } else { Some(dropout) },
        })
    }
}
@ -94,6 +118,7 @@ impl BPE {
vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(),
merges,
cache: Cache::new(),
dropout: None,
})
}
}
@ -117,7 +142,9 @@ impl Model for BPE {
);
for (i, (w, initial_offsets)) in sentence.iter().enumerate() {
if cached_words[i].is_none() {
// If we're using dropout or we don't have a cache hit, we have to compute
// merges for this word.
if self.dropout.is_some() || cached_words[i].is_none() {
let mut word = Word::new();
for c in w.chars() {
match self.vocab.get(&c.to_string()) {
@ -142,6 +169,18 @@ impl Model for BPE {
let rank = self
.merges
.get(&pair)
.map(|rank| {
if let Some(dropout) = self.dropout {
// With probability `dropout`, ignore this merge so the pair is left unmerged.
if thread_rng().gen::<f32>() < dropout {
&(std::u32::MAX, std::u32::MAX)
} else {
rank
}
} else {
rank
}
})
.unwrap_or(&(std::u32::MAX, std::u32::MAX));
(rank, pair)
})
@ -203,6 +242,106 @@ mod tests {
use super::*;
use tempfile::NamedTempFile;
#[test]
// Test tokenization. With dropout unset, tokenization is deterministic,
// so we know exactly what the result should be.
//
// To test this, we'll build a simple model to tokenize the word 'unrelated'.
fn test_tokenize_with_and_without_dropout() {
    let vocab: HashMap<String, u32> = [
        ("u".into(), 0),
        ("n".into(), 1),
        ("r".into(), 2),
        ("e".into(), 3),
        ("l".into(), 4),
        ("a".into(), 5),
        ("t".into(), 6),
        ("d".into(), 7),
        ("re".into(), 8),
        ("at".into(), 9),
        ("ed".into(), 10),
        ("un".into(), 11),
        ("ated".into(), 12),
        ("rel".into(), 13),
        ("related".into(), 14),
        ("unrelated".into(), 15),
    ]
    .iter()
    .cloned()
    .collect();
    let vocab_r: HashMap<u32, String> = vocab
        .iter()
        .map(|(key, val)| (*val, key.to_owned()))
        .collect();
    // `u32` is `Copy`, so the vocabulary ids are used directly rather than
    // `.clone()`d (clippy::clone_on_copy).
    let merges: HashMap<Pair, (u32, u32)> = [
        ((vocab["r"], vocab["e"]), (1u32, vocab["re"])), // 'r-e' -> 're'
        ((vocab["a"], vocab["t"]), (2u32, vocab["at"])), // 'a-t' -> 'at'
        ((vocab["e"], vocab["d"]), (3u32, vocab["ed"])), // 'e-d' -> 'ed'
        ((vocab["u"], vocab["n"]), (4u32, vocab["un"])), // 'u-n' -> 'un'
        ((vocab["at"], vocab["ed"]), (5u32, vocab["ated"])), // 'at-ed' -> 'ated'
        ((vocab["re"], vocab["l"]), (6u32, vocab["rel"])), // 're-l' -> 'rel'
        ((vocab["rel"], vocab["ated"]), (7u32, vocab["related"])), // 'rel-ated' -> 'related'
        ((vocab["un"], vocab["related"]), (8u32, vocab["unrelated"])), // 'un-related' -> 'unrelated'
    ]
    .iter()
    .cloned()
    .collect();
    let mut bpe = BPE::new(vocab, vocab_r, merges);
    let sentence: Vec<(String, Offsets)> = vec![("unrelated".into(), (0, 9))];

    // With no dropout: all merges apply, so we get the full word back.
    let tokens = bpe.tokenize(sentence.clone()).unwrap();
    assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]);

    // Now set dropout to 1.0. Result should be no merges performed,
    // i.e. the word decomposes into single characters.
    bpe.dropout = Some(1.0);
    let tokens = bpe.tokenize(sentence.clone()).unwrap();
    assert_eq!(
        tokens,
        vec![
            Token::new(0u32, "u".into(), (0, 1)),
            Token::new(1u32, "n".into(), (1, 2)),
            Token::new(2u32, "r".into(), (2, 3)),
            Token::new(3u32, "e".into(), (3, 4)),
            Token::new(4u32, "l".into(), (4, 5)),
            Token::new(5u32, "a".into(), (5, 6)),
            Token::new(6u32, "t".into(), (6, 7)),
            Token::new(3u32, "e".into(), (7, 8)),
            Token::new(7u32, "d".into(), (8, 9)),
        ]
    );

    // Now try with dropout between 0 and 1. The result is random, but the
    // token count must lie between "fully merged" (1) and "no merges" (9).
    bpe.dropout = Some(0.5);
    // Last use of `sentence`, so no clone is needed.
    let tokens = bpe.tokenize(sentence).unwrap();
    assert!(1 <= tokens.len() && tokens.len() <= 9);
}
#[test]
// Ensure `BPE::from_files` works as expected.
fn test_bpe_from_files() {
@ -219,11 +358,22 @@ mod tests {
.unwrap();
// Make sure we can instantiate a BPE model from the files.
assert!(BPE::from_files(
let result = BPE::from_files(
vocab_file.path().to_str().unwrap(),
merges_file.path().to_str().unwrap()
)
.is_ok());
merges_file.path().to_str().unwrap(),
);
assert!(result.is_ok());
let bpe = result.unwrap();
// Check merges.
assert_eq!(bpe.merges.get(&(0u32, 1u32)).unwrap(), &(1u32, 3u32));
// Check vocab.
assert_eq!(bpe.vocab.get("a").unwrap(), &0u32);
assert_eq!(bpe.vocab.get("b").unwrap(), &1u32);
assert_eq!(bpe.vocab.get("c").unwrap(), &2u32);
assert_eq!(bpe.vocab.get("ab").unwrap(), &3u32);
}
#[test]