mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Enable dropout = 0.0
as an equivalent to none
in BPE (#1550)
* enable dropout = 0.0 * typo * lint * formatter * enable dropout = 0.0 * formatter
This commit is contained in:
@ -69,6 +69,10 @@ class TestBPE:
|
||||
model.byte_fallback = True
|
||||
assert model.byte_fallback == True
|
||||
|
||||
def test_dropout_zero(self):
|
||||
model = BPE(dropout=0.0)
|
||||
assert model.dropout == 0.0
|
||||
|
||||
|
||||
class TestWordPiece:
|
||||
def test_instantiate(self, bert_files):
|
||||
|
@ -31,7 +31,7 @@ pub enum Error {
|
||||
#[error("Unk token `{0}` not found in the vocabulary")]
|
||||
UnkTokenOutOfVocabulary(String),
|
||||
/// Dropout not between 0 and 1.
|
||||
#[error("Dropout should be between 0 and 1")]
|
||||
#[error("Dropout should be between 0 and 1, inclusive")]
|
||||
InvalidDropout,
|
||||
}
|
||||
|
||||
|
@ -136,7 +136,7 @@ impl BpeBuilder {
|
||||
pub fn build(mut self) -> Result<BPE> {
|
||||
// Validate dropout.
|
||||
if let Some(p) = self.config.dropout {
|
||||
if p <= 0.0 || p > 1.0 {
|
||||
if !(0.0..=1.0).contains(&p) {
|
||||
return Err(Error::InvalidDropout.into());
|
||||
}
|
||||
}
|
||||
@ -214,7 +214,7 @@ pub struct BPE {
|
||||
pub(crate) merges: MergeMap,
|
||||
/// Contains the cache for optimizing the encoding step.
|
||||
cache: Option<Cache<String, Word>>,
|
||||
/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
|
||||
/// Dropout probability for merges. 0.0 = no dropout is the default. At 1.0, tokenization will
|
||||
/// perform no merges, so the result will just be characters.
|
||||
pub dropout: Option<f32>,
|
||||
/// The unknown token to be used when we encounter an unknown char
|
||||
@ -493,7 +493,7 @@ impl Model for BPE {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
if self.dropout.is_none() {
|
||||
if self.dropout.is_none() || self.dropout == Some(0.0) {
|
||||
self.tokenize_with_cache(sequence)
|
||||
} else {
|
||||
let word = self.merge_word(sequence)?;
|
||||
@ -685,6 +685,11 @@ mod tests {
|
||||
let tokens = bpe.tokenize("unrelated").unwrap();
|
||||
assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]);
|
||||
|
||||
// With dropout = 0.0 (equivalent to dropout == none)
|
||||
bpe.dropout = Some(0.0);
|
||||
let tokens = bpe.tokenize("unrelated").unwrap();
|
||||
assert_eq!(tokens, vec![Token::new(15u32, "unrelated".into(), (0, 9))]);
|
||||
|
||||
// Now set dropout to 1.0. Result should be no merges performed.
|
||||
bpe.dropout = Some(1.0);
|
||||
let tokens = bpe.tokenize("unrelated").unwrap();
|
||||
@ -739,6 +744,13 @@ mod tests {
|
||||
assert_eq!(bpe.vocab.get("ab").unwrap(), &3u32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
// Ensure BPEBuilder with dropout = 0.0 doesn't error
|
||||
fn test_bpe_with_dropout_0() {
|
||||
let bpe = BPE::builder().dropout(0.0).build().unwrap();
|
||||
assert_eq!(bpe.dropout, Some(0.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
// Ensure `BPE::from_file` works as expected.
|
||||
fn test_bpe_with_continuing_subword_prefix() {
|
||||
|
@ -229,6 +229,21 @@ fn tokenizer() {
|
||||
assert_eq!(serde_json::to_string(&de).unwrap(), ser);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bpe_with_dropout_serde() {
|
||||
let mut bpe = BPE::default();
|
||||
bpe.dropout = Some(0.1);
|
||||
let ser = serde_json::to_string(&bpe).unwrap();
|
||||
let de = serde_json::from_str(&ser).unwrap();
|
||||
assert_eq!(bpe, de);
|
||||
|
||||
// set dropout to 0.0 (which is analogous to None) and reserialize
|
||||
bpe.dropout = Some(0.0);
|
||||
let ser = serde_json::to_string(&bpe).unwrap();
|
||||
let de = serde_json::from_str(&ser).unwrap();
|
||||
assert_eq!(bpe, de);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_long_file() {
|
||||
let _tokenizer = Tokenizer::from_file("data/albert-base-v1-tokenizer.json").unwrap();
|
||||
|
Reference in New Issue
Block a user