mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Modify BPE with unk_token being a String
This commit is contained in:
@ -25,7 +25,7 @@ class BPE:
|
||||
merges: str,
|
||||
cache_capacity: Optional[int],
|
||||
dropout: Optional[float],
|
||||
unk_token: Optional[int],
|
||||
unk_token: Optional[str],
|
||||
continuing_subword_prefix: Optional[str],
|
||||
end_of_word_suffix: Optional[str]) -> Model:
|
||||
""" Instantiate a BPE Model from the given vocab and merges files.
|
||||
@ -45,8 +45,8 @@ class BPE:
|
||||
dropout: (`optional`) Optional[float] [0, 1]:
|
||||
The BPE dropout to use. Must be an float between 0 and 1
|
||||
|
||||
unk_token: (`optional`) int:
|
||||
The unknown token id to be used by the model.
|
||||
unk_token: (`optional`) str:
|
||||
The unknown token to be used by the model.
|
||||
|
||||
continuing_subword_prefix: (`optional`) str:
|
||||
The prefix to attach to subword units that don't represent a beginning of word.
|
||||
|
@ -22,6 +22,8 @@ pub enum Error {
|
||||
BadMerges(usize),
|
||||
/// If a token found in merges, is not in the vocab
|
||||
MergeTokenOutOfVocabulary(String),
|
||||
/// If the provided unk token is out of vocabulary
|
||||
UnkTokenOutOfVocabulary(String),
|
||||
/// Dropout not between 0 and 1.
|
||||
InvalidDropout,
|
||||
}
|
||||
@ -46,7 +48,10 @@ impl std::fmt::Display for Error {
|
||||
Error::BadVocabulary => write!(f, "Bad vocabulary json file"),
|
||||
Error::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line),
|
||||
Error::MergeTokenOutOfVocabulary(token) => {
|
||||
write!(f, "Token {} out of vocabulary", token)
|
||||
write!(f, "Token `{}` out of vocabulary", token)
|
||||
}
|
||||
Error::UnkTokenOutOfVocabulary(token) => {
|
||||
write!(f, "Unk token `{}` not found in the vocabulary", token)
|
||||
}
|
||||
Error::InvalidDropout => write!(f, "Dropout should be between 0 and 1"),
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ struct Config {
|
||||
merges: Option<HashMap<Pair, (u32, u32)>>,
|
||||
cache_capacity: Option<usize>,
|
||||
dropout: Option<f32>,
|
||||
unk_token: Option<u32>,
|
||||
unk_token: Option<String>,
|
||||
continuing_subword_prefix: Option<String>,
|
||||
end_of_word_suffix: Option<String>,
|
||||
}
|
||||
@ -57,7 +57,7 @@ impl BpeBuilder {
|
||||
}
|
||||
|
||||
/// Set the `UNK` token for the vocab.
|
||||
pub fn unk_token(mut self, unk_token: u32) -> Self {
|
||||
pub fn unk_token(mut self, unk_token: String) -> Self {
|
||||
self.config.unk_token = Some(unk_token);
|
||||
self
|
||||
}
|
||||
@ -130,7 +130,7 @@ pub struct BPE {
|
||||
/// perform no merges, so the result will just be characters.
|
||||
dropout: Option<f32>,
|
||||
/// The unknown token to be used when we encounter an unknown char
|
||||
unk_token: Option<u32>,
|
||||
unk_token: Option<String>,
|
||||
/// An optional prefix to use on any subword that exist only behind another one
|
||||
continuing_subword_prefix: Option<String>,
|
||||
/// An optional suffix to caracterize and end-of-word subword
|
||||
@ -166,7 +166,7 @@ impl Clone for BPE {
|
||||
merges: self.merges.clone(),
|
||||
cache: fresh_cache,
|
||||
dropout: self.dropout,
|
||||
unk_token: self.unk_token,
|
||||
unk_token: self.unk_token.clone(),
|
||||
continuing_subword_prefix: self.continuing_subword_prefix.clone(),
|
||||
end_of_word_suffix: self.end_of_word_suffix.clone(),
|
||||
}
|
||||
@ -254,7 +254,7 @@ impl BPE {
|
||||
&self.vocab
|
||||
}
|
||||
|
||||
pub fn get_unk_token(&self) -> &Option<u32> {
|
||||
pub fn get_unk_token(&self) -> &Option<String> {
|
||||
&self.unk_token
|
||||
}
|
||||
|
||||
@ -262,7 +262,7 @@ impl BPE {
|
||||
&self.continuing_subword_prefix
|
||||
}
|
||||
|
||||
fn merge_word(&self, w: &str) -> Word {
|
||||
fn merge_word(&self, w: &str) -> Result<Word> {
|
||||
let mut word = Word::new();
|
||||
for (is_first, is_last, c) in w.chars().with_first_and_last() {
|
||||
let mut s = c.to_string();
|
||||
@ -283,8 +283,12 @@ impl BPE {
|
||||
if let Some(id) = self.vocab.get(&s) {
|
||||
word.add(*id);
|
||||
} else if let Some(unk) = &self.unk_token {
|
||||
let unk_id = self
|
||||
.vocab
|
||||
.get(unk)
|
||||
.ok_or_else(|| Error::UnkTokenOutOfVocabulary(unk.to_owned()))?;
|
||||
// Handle UNK token
|
||||
word.add(*unk);
|
||||
word.add(*unk_id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -328,7 +332,7 @@ impl BPE {
|
||||
word.merge(pair.0, pair.1, *new_id);
|
||||
}
|
||||
|
||||
word
|
||||
Ok(word)
|
||||
}
|
||||
|
||||
fn word_to_tokens(&self, word: &Word, initial_offsets: &(usize, usize)) -> Vec<Token> {
|
||||
@ -370,14 +374,14 @@ impl Model for BPE {
|
||||
let tokens = match cached_words {
|
||||
None => {
|
||||
// Not using cache since we're using dropout.
|
||||
let word = self.merge_word(&w);
|
||||
let word = self.merge_word(&w)?;
|
||||
self.word_to_tokens(&word, initial_offsets)
|
||||
}
|
||||
Some(ref mut cache) => {
|
||||
match cache[i].as_ref() {
|
||||
None => {
|
||||
// No cache hit, so re-compute merges.
|
||||
let word = self.merge_word(&w);
|
||||
let word = self.merge_word(&w)?;
|
||||
let tokens = self.word_to_tokens(&word, initial_offsets);
|
||||
// Add to cache.
|
||||
cache[i] = Some(word);
|
||||
|
@ -156,9 +156,7 @@ impl WordPiece {
|
||||
pub fn from_bpe(bpe: &BPE) -> Self {
|
||||
let mut wp = Self::builder().vocab(bpe.get_vocab().clone()).build();
|
||||
if let Some(unk) = bpe.get_unk_token() {
|
||||
if let Some(unk_token) = wp.vocab_r.get(&unk) {
|
||||
wp.unk_token = unk_token.to_owned();
|
||||
}
|
||||
wp.unk_token = unk.to_owned();
|
||||
}
|
||||
if let Some(prefix) = bpe.get_continuing_subword_prefix() {
|
||||
wp.continuing_subword_prefix = prefix.to_owned();
|
||||
|
Reference in New Issue
Block a user