mirror of https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00

implement dropout for BPE
@@ -14,6 +14,7 @@ path = "src/cli.rs"
 [dependencies]
 lazy_static = "1.3.0"
+rand = "0.7.2"
 regex = "1.3.1"
 regex-syntax = "0.6.12"
 rayon = "1.2.0"

@@ -21,6 +21,8 @@ pub enum Error {
     BadMerges(usize),
     /// If a token found in merges, is not in the vocab
     MergeTokenOutOfVocabulary(String),
+    /// Dropout not between 0 and 1.
+    InvalidDropout,
 }

 impl From<io::Error> for Error {

@@ -45,6 +47,7 @@ impl std::fmt::Display for Error {
             Error::MergeTokenOutOfVocabulary(token) => {
                 write!(f, "Token {} out of vocabulary", token)
             }
+            Error::InvalidDropout => write!(f, "Dropout should be between 0 and 1"),
         }
     }
 }

@@ -54,9 +57,7 @@ impl std::error::Error for Error {
         match self {
             Error::Io(e) => Some(e),
             Error::JsonError(e) => Some(e),
-            Error::BadVocabulary => None,
-            Error::BadMerges(_) => None,
-            Error::MergeTokenOutOfVocabulary(_) => None,
+            _ => None,
         }
     }
 }

@@ -1,5 +1,6 @@
 use super::{Cache, Error, Pair, Word};
 use crate::tokenizer::{Model, Offsets, Result, Token};
+use rand::{thread_rng, Rng};
 use serde_json::Value;
 use std::{
     collections::HashMap,

@@ -9,14 +10,17 @@ use std::{
 };

 pub struct BPE {
-    /// The vocabulary assigns a number to each token
+    /// The vocabulary assigns a number to each token.
     vocab: HashMap<String, u32>,
-    /// Reversed vocabulary, to rebuild sentences
+    /// Reversed vocabulary, to rebuild sentences.
     vocab_r: HashMap<u32, String>,
-    /// Contains the mapping between Pairs and their (rank, new_id)
+    /// Contains the mapping between Pairs and their (rank, new_id).
     merges: HashMap<Pair, (u32, u32)>,
-    /// Contains the cache for optimizing the encoding step
+    /// Contains the cache for optimizing the encoding step.
     cache: Cache<String, Word>,
+    /// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
+    /// perform no merges, so the result will just be characters.
+    dropout: Option<f32>,
 }

 impl BPE {

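Note on the new field: `None` keeps classic, deterministic BPE, while `Some(p)` means each candidate merge is independently skipped with probability `p`. A minimal sketch of that decision rule, using the `rand` dependency added in Cargo.toml above (the helper name is hypothetical, not part of this commit):

    use rand::{thread_rng, Rng};

    /// Hypothetical helper: should a single candidate merge be applied?
    fn apply_merge(dropout: Option<f32>) -> bool {
        match dropout {
            None => true,                              // default: no dropout
            Some(p) => thread_rng().gen::<f32>() >= p, // skip with probability p
        }
    }
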
@@ -30,6 +34,26 @@ impl BPE {
             vocab_r,
             merges,
             cache: Cache::new(),
+            dropout: None,
         }
     }
+
+    pub fn with_dropout(
+        vocab: HashMap<String, u32>,
+        vocab_r: HashMap<u32, String>,
+        merges: HashMap<Pair, (u32, u32)>,
+        dropout: f32,
+    ) -> Result<Self> {
+        if dropout < 0.0 || dropout > 1.0 {
+            Err(Error::InvalidDropout.into())
+        } else {
+            Ok(BPE {
+                vocab,
+                vocab_r,
+                merges,
+                cache: Cache::new(),
+                dropout: if dropout == 0.0 { None } else { Some(dropout) },
+            })
+        }
+    }

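A usage sketch for the new constructor (a hypothetical call site with empty tables, just to show the control flow; `BPE` and `Pair` are the crate's own types):

    use std::collections::HashMap;

    let (vocab, vocab_r, merges) = (HashMap::new(), HashMap::new(), HashMap::new());

    // A probability outside [0, 1] is rejected with Error::InvalidDropout.
    assert!(BPE::with_dropout(vocab.clone(), vocab_r.clone(), merges.clone(), 1.5).is_err());

    // 0.0 is accepted but normalized to `dropout: None`, i.e. plain deterministic BPE.
    let _bpe = BPE::with_dropout(vocab, vocab_r, merges, 0.0).unwrap();
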
@@ -94,6 +118,7 @@ impl BPE {
             vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(),
             merges,
             cache: Cache::new(),
+            dropout: None,
         })
     }
 }

@@ -117,7 +142,9 @@ impl Model for BPE {
         );

         for (i, (w, initial_offsets)) in sentence.iter().enumerate() {
-            if cached_words[i].is_none() {
+            // If we're using dropout or we don't have a cache hit, we have to compute
+            // merges for this word.
+            if self.dropout.is_some() || cached_words[i].is_none() {
                 let mut word = Word::new();
                 for c in w.chars() {
                     match self.vocab.get(&c.to_string()) {

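The reasoning behind the widened condition: cached words were produced by a deterministic merge pass, so reusing them would make dropout a no-op after the first time a word is seen. Condensed into a standalone predicate (a hypothetical helper, not in the commit; `Word` is the crate's type):

    /// Should word `i` get a fresh merge pass?
    fn needs_recompute(dropout: Option<f32>, cached_words: &[Option<Word>], i: usize) -> bool {
        // With dropout active, always resample; otherwise only recompute on a cache miss.
        dropout.is_some() || cached_words[i].is_none()
    }
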
@@ -142,6 +169,18 @@ impl Model for BPE {
                     let rank = self
                         .merges
                         .get(&pair)
+                        .map(|rank| {
+                            if let Some(dropout) = self.dropout {
+                                // With probability `dropout` we'll ignore this merge.
+                                if thread_rng().gen::<f32>() < dropout {
+                                    &(std::u32::MAX, std::u32::MAX)
+                                } else {
+                                    rank
+                                }
+                            } else {
+                                rank
+                            }
+                        })
                         .unwrap_or(&(std::u32::MAX, std::u32::MAX));
                     (rank, pair)
                 })

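The sentinel trick above, in isolation: a dropped pair is not removed from the merge table, its rank is just replaced by `std::u32::MAX` so it can never win the lowest-rank-first comparison in the merge loop (the same idea as BPE-dropout). A self-contained sketch under that assumption, with a hypothetical `effective_rank` helper:

    use rand::{thread_rng, Rng};
    use std::collections::HashMap;

    /// Rank/id pair meaning "this pair never merges".
    const NO_MERGE: (u32, u32) = (std::u32::MAX, std::u32::MAX);

    /// Hypothetical standalone version of the closure above: look up a pair's
    /// (rank, new_id), demoting it to NO_MERGE with probability `dropout`.
    fn effective_rank(
        merges: &HashMap<(u32, u32), (u32, u32)>,
        pair: (u32, u32),
        dropout: Option<f32>,
    ) -> (u32, u32) {
        *merges
            .get(&pair)
            .map(|rank| match dropout {
                // Dropped: pretend the pair is unknown this round.
                Some(p) if thread_rng().gen::<f32>() < p => &NO_MERGE,
                _ => rank,
            })
            .unwrap_or(&NO_MERGE)
    }
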
@@ -203,6 +242,106 @@ mod tests {
     use super::*;
     use tempfile::NamedTempFile;

+    #[test]
+    // Test tokenization. With dropout set to 0 tokenization is deterministic,
+    // so we know exactly what the result should be.
+    //
+    // To test this, we'll build a simple model to tokenize the word 'unrelated'.
+    fn test_tokenize_with_and_without_dropout() {
+        let vocab: HashMap<String, u32> = [
+            ("u".into(), 0),
+            ("n".into(), 1),
+            ("r".into(), 2),
+            ("e".into(), 3),
+            ("l".into(), 4),
+            ("a".into(), 5),
+            ("t".into(), 6),
+            ("d".into(), 7),
+            ("re".into(), 8),
+            ("at".into(), 9),
+            ("ed".into(), 10),
+            ("un".into(), 11),
+            ("ated".into(), 12),
+            ("rel".into(), 13),
+            ("related".into(), 14),
+            ("unrelated".into(), 15),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        let vocab_r: HashMap<u32, String> = vocab
+            .iter()
+            .map(|(key, val)| (val.clone(), key.to_owned()))
+            .collect();
+        let merges: HashMap<Pair, (u32, u32)> = [
+            (
+                (vocab["r"].clone(), vocab["e"].clone()),
+                (1u32, vocab["re"].clone()),
+            ), // 'r-e' -> 're'
+            (
+                (vocab["a"].clone(), vocab["t"].clone()),
+                (2u32, vocab["at"].clone()),
+            ), // 'a-t' -> 'at'
+            (
+                (vocab["e"].clone(), vocab["d"].clone()),
+                (3u32, vocab["ed"].clone()),
+            ), // 'e-d' -> 'ed'
+            (
+                (vocab["u"].clone(), vocab["n"].clone()),
+                (4u32, vocab["un"].clone()),
+            ), // 'u-n' -> 'un'
+            (
+                (vocab["at"].clone(), vocab["ed"].clone()),
+                (5u32, vocab["ated"].clone()),
+            ), // 'at-ed' -> 'ated'
+            (
+                (vocab["re"].clone(), vocab["l"].clone()),
+                (6u32, vocab["rel"].clone()),
+            ), // 're-l' -> 'rel'
+            (
+                (vocab["rel"].clone(), vocab["ated"].clone()),
+                (7u32, vocab["related"].clone()),
+            ), // 'rel-ated' -> 'related'
+            (
+                (vocab["un"].clone(), vocab["related"].clone()),
+                (8u32, vocab["unrelated"].clone()),
+            ), // 'un-related' -> 'unrelated'
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        let mut bpe = BPE::new(vocab, vocab_r, merges);
+
+        let sentence: Vec<(String, Offsets)> = vec![("unrelated".into(), (0, 9))];
+
+        // With no dropout:
+        let tokens = bpe.tokenize(sentence.clone()).unwrap();
+        assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]);
+
+        // Now set dropout to 1.0. Result should be no merges performed.
+        bpe.dropout = Some(1.0);
+        let tokens = bpe.tokenize(sentence.clone()).unwrap();
+        assert_eq!(
+            tokens,
+            vec![
+                Token::new(0u32, "u".into(), (0, 1)),
+                Token::new(1u32, "n".into(), (1, 2)),
+                Token::new(2u32, "r".into(), (2, 3)),
+                Token::new(3u32, "e".into(), (3, 4)),
+                Token::new(4u32, "l".into(), (4, 5)),
+                Token::new(5u32, "a".into(), (5, 6)),
+                Token::new(6u32, "t".into(), (6, 7)),
+                Token::new(3u32, "e".into(), (7, 8)),
+                Token::new(7u32, "d".into(), (8, 9)),
+            ]
+        );
+
+        // Now try with dropout between 0 and 1.
+        bpe.dropout = Some(0.5);
+        let tokens = bpe.tokenize(sentence.clone()).unwrap();
+        assert!(1 <= tokens.len() && tokens.len() <= 9);
+    }
+
     #[test]
     // Ensure `BPE::from_files` works as expected.
     fn test_bpe_from_files() {

@@ -219,11 +358,22 @@ mod tests {
             .unwrap();

         // Make sure we can instantiate a BPE model from the files.
-        assert!(BPE::from_files(
+        let result = BPE::from_files(
             vocab_file.path().to_str().unwrap(),
-            merges_file.path().to_str().unwrap()
-        )
-        .is_ok());
+            merges_file.path().to_str().unwrap(),
+        );
+        assert!(result.is_ok());
+
+        let bpe = result.unwrap();
+
+        // Check merges.
+        assert_eq!(bpe.merges.get(&(0u32, 1u32)).unwrap(), &(1u32, 3u32));
+
+        // Check vocab.
+        assert_eq!(bpe.vocab.get("a").unwrap(), &0u32);
+        assert_eq!(bpe.vocab.get("b").unwrap(), &1u32);
+        assert_eq!(bpe.vocab.get("c").unwrap(), &2u32);
+        assert_eq!(bpe.vocab.get("ab").unwrap(), &3u32);
     }

     #[test]