diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 48ec68b3..47abd57f 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -245,6 +245,10 @@ impl BPE {
         &self.unk_token
     }
 
+    pub fn get_continuing_subword_prefix(&self) -> &Option<String> {
+        &self.continuing_subword_prefix
+    }
+
     fn merge_word(&self, w: &str) -> Word {
         let mut word = Word::new();
         for (is_first, is_last, c) in w.chars().with_first_and_last() {
diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs
index a4ab13a8..63e6f998 100644
--- a/tokenizers/src/models/wordpiece/mod.rs
+++ b/tokenizers/src/models/wordpiece/mod.rs
@@ -28,6 +28,7 @@ impl fmt::Display for Error {
 
 pub struct WordPiece {
     unk_token: String,
+    continuing_subword_prefix: String,
     max_input_chars_per_word: usize,
     vocab: HashMap<String, u32>,
     vocab_r: HashMap<u32, String>,
@@ -38,6 +39,7 @@ impl Default for WordPiece {
         WordPiece {
             vocab: HashMap::new(),
             vocab_r: HashMap::new(),
+            continuing_subword_prefix: String::from("##"),
             unk_token: String::from("[UNK]"),
             max_input_chars_per_word: 100,
         }
@@ -64,6 +66,7 @@ impl WordPiece {
             vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(),
             unk_token,
             max_input_chars_per_word: max_input_chars_per_word.unwrap_or(100),
+            ..Default::default()
         })
     }
 
@@ -86,6 +89,9 @@ impl WordPiece {
                 wp.unk_token = unk_token.to_owned();
             }
         }
+        if let Some(prefix) = bpe.get_continuing_subword_prefix() {
+            wp.continuing_subword_prefix = prefix.to_owned();
+        }
         wp
     }
 
@@ -125,7 +131,7 @@ impl Model for WordPiece {
             while start < end {
                 let mut substr = chars[start..end].iter().collect::<String>();
                 if start > 0 {
-                    substr = format!("##{}", substr);
+                    substr = format!("{}{}", self.continuing_subword_prefix, substr);
                 }
                 if self.vocab.contains_key(&substr) {
                     cur_str = Some(Token {
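
For context, below is a minimal, self-contained sketch of the greedy longest-match-first loop touched by the last hunk, with the continuing subword prefix passed in as a parameter instead of hard-coded as `"##"`. The `wordpiece_tokenize` helper, its signature, and the tiny vocabulary are illustrative stand-ins, not the crate's actual API:

```rust
use std::collections::HashMap;

/// Greedy longest-match-first WordPiece lookup over one word.
/// Returns `None` when some span matches nothing in the vocabulary
/// (the real model would emit the unknown token instead).
fn wordpiece_tokenize(
    word: &str,
    vocab: &HashMap<String, u32>,
    prefix: &str,
) -> Option<Vec<String>> {
    let chars: Vec<char> = word.chars().collect();
    let mut tokens = Vec::new();
    let mut start = 0;

    while start < chars.len() {
        let mut end = chars.len();
        let mut found = None;

        // Shrink the window from the right until a vocab entry matches.
        while start < end {
            let mut substr: String = chars[start..end].iter().collect();
            if start > 0 {
                // Mid-word pieces carry the configurable prefix,
                // mirroring the `format!` change in the diff.
                substr = format!("{}{}", prefix, substr);
            }
            if vocab.contains_key(&substr) {
                found = Some(substr);
                break;
            }
            end -= 1;
        }

        tokens.push(found?);
        start = end;
    }

    Some(tokens)
}

fn main() {
    let vocab: HashMap<String, u32> = ["un", "##aff", "##able"]
        .iter()
        .enumerate()
        .map(|(id, tok)| (tok.to_string(), id as u32))
        .collect();

    // With the default "##" prefix: Some(["un", "##aff", "##able"])
    println!("{:?}", wordpiece_tokenize("unaffable", &vocab, "##"));
}
```

With the default `"##"` prefix this reproduces the usual BERT-style output; the point of the diff is that a BPE model trained with a different prefix can now flow through the BPE-to-WordPiece conversion updated above without the prefix silently reverting to `"##"`.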