mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00
Wordpiece can fail
@@ -1,9 +1,9 @@
-use crate::tokenizer::Decoder;
+use crate::tokenizer::{Decoder, Result};
 
 pub struct WordPiece;
 
 impl Decoder for WordPiece {
-    fn decode(&self, tokens: Vec<String>) -> String {
-        tokens.join(" ").replace(" ##", "")
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        Ok(tokens.join(" ").replace(" ##", ""))
     }
 }
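The decoder diff above changes `Decoder::decode` to return `Result<String>` instead of a bare `String`. As a rough illustration only (the actual `Result` alias in `crate::tokenizer` is not shown in this commit, so the boxed-error alias below is an assumption), a self-contained sketch of the same pattern:

// Assumed shape of the crate-wide alias; the diff only shows that one exists.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

trait Decoder {
    fn decode(&self, tokens: Vec<String>) -> Result<String>;
}

struct WordPiece;

impl Decoder for WordPiece {
    // Joining on spaces and then removing " ##" re-attaches continuation
    // pieces to the previous word: ["un", "##believ", "##able"] -> "unbelievable".
    fn decode(&self, tokens: Vec<String>) -> Result<String> {
        Ok(tokens.join(" ").replace(" ##", ""))
    }
}

fn main() -> Result<()> {
    let pieces = vec!["un".to_string(), "##believ".to_string(), "##able".to_string()];
    let decoded = WordPiece.decode(pieces)?;
    assert_eq!(decoded, "unbelievable");
    Ok(())
}

The decoding behavior itself is unchanged; only the signature becomes fallible so it lines up with the rest of the pipeline.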
@@ -1,10 +1,28 @@
-use crate::tokenizer::{Model, Token};
+use crate::tokenizer::{Model, Result, Token};
 use std::{
     collections::HashMap,
+    fmt,
     fs::File,
     io::{BufRead, BufReader},
 };
 
+#[derive(Debug)]
+pub enum Error {
+    MissingUnkToken,
+}
+impl std::error::Error for Error {}
+
+impl fmt::Display for Error {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            MissingUnkToken => write!(
+                fmt,
+                "WordPiece error: Missing [UNK] token from the vocabulary"
+            ),
+        }
+    }
+}
+
 pub struct WordPiece {
     unk_token: String,
     max_input_chars_per_word: usize,
@@ -52,7 +70,7 @@ impl Model for WordPiece {
         self.vocab.len()
     }
 
-    fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
+    fn tokenize(&self, sentence: Vec<String>) -> Result<Vec<Token>> {
         let mut output_tokens = vec![];
 
         let mut offset = 0usize;
@@ -61,7 +79,10 @@ impl Model for WordPiece {
             if char_len > self.max_input_chars_per_word {
                 output_tokens.push(Token {
                     value: self.unk_token.clone(),
-                    id: *self.vocab.get(&self.unk_token).unwrap_or(&0),
+                    id: *self
+                        .vocab
+                        .get(&self.unk_token)
+                        .ok_or(Error::MissingUnkToken)?,
                     offsets: (offset, offset + char_len),
                 });
                 continue;
@@ -104,7 +125,10 @@ impl Model for WordPiece {
             if is_bad {
                 output_tokens.push(Token {
                     value: self.unk_token.clone(),
-                    id: *self.vocab.get(&self.unk_token).unwrap_or(&0),
+                    id: *self
+                        .vocab
+                        .get(&self.unk_token)
+                        .ok_or(Error::MissingUnkToken)?,
                     offsets: (offset, offset + char_len),
                 });
             } else {
@@ -114,15 +138,16 @@ impl Model for WordPiece {
             offset += char_len;
         }
 
-        output_tokens
+        Ok(output_tokens)
     }
 
-    fn decode(&self, ids: Vec<u32>) -> Vec<String> {
-        ids.into_iter()
+    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>> {
+        Ok(ids
+            .into_iter()
             .map(|id| self.vocab_r.get(&id))
             .filter(|token| token.is_some())
             .map(|id| id.unwrap().clone())
-            .collect()
+            .collect())
     }
 
     fn token_to_id(&self, token: &str) -> Option<u32> {
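The model diff replaces `unwrap_or(&0)` with `ok_or(Error::MissingUnkToken)?`: a vocabulary that lacks the unknown token no longer silently maps overlong or unsplittable words to id 0, it now surfaces an error to the caller. A minimal sketch of that lookup in isolation, again assuming a boxed-error `Result` alias (the crate's real alias is not part of this diff):

use std::collections::HashMap;
use std::fmt;

// Assumed alias; only its existence is visible in the diff.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

#[derive(Debug)]
enum Error {
    MissingUnkToken,
}

impl std::error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "WordPiece error: Missing [UNK] token from the vocabulary")
    }
}

// Stand-in for the vocabulary lookup inside `tokenize`: instead of silently
// falling back to id 0 (`unwrap_or(&0)`), a missing [UNK] entry is an error,
// and `?` converts it into the boxed error carried by `Result`.
fn unk_id(vocab: &HashMap<String, u32>, unk_token: &str) -> Result<u32> {
    Ok(*vocab.get(unk_token).ok_or(Error::MissingUnkToken)?)
}

fn main() {
    let vocab: HashMap<String, u32> = HashMap::new(); // no "[UNK]" entry
    match unk_id(&vocab, "[UNK]") {
        Ok(id) => println!("unk id = {}", id),
        Err(e) => println!("{}", e), // prints the MissingUnkToken message
    }
}

The `?` conversion works because `Error` implements `std::error::Error`, so it can be boxed into the error side of the alias automatically.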
@@ -1,4 +1,4 @@
-use crate::tokenizer::PreTokenizer;
+use crate::tokenizer::{PreTokenizer, Result};
 use std::collections::HashSet;
 use unicode_categories::UnicodeCategories;
 use unicode_normalization::UnicodeNormalization;
@@ -144,7 +144,7 @@ impl BasicPreTokenizer {
 }
 
 impl PreTokenizer for BasicPreTokenizer {
-    fn pre_tokenize(&self, s: &str) -> Vec<String> {
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
         let mut text = self.clean_text(s);
 
         // This was added on November 1st, 2018 for the multilingual and Chinese
@@ -166,6 +166,6 @@ impl PreTokenizer for BasicPreTokenizer {
             split_tokens.extend(self.run_split_on_punc(&tk));
         }
 
-        split_tokens
+        Ok(split_tokens)
     }
 }
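The pre-tokenizer diff makes `pre_tokenize` fallible as well, so every stage of the pipeline now reports failure through `Result`. A hypothetical caller-side sketch (the `Whitespace` pre-tokenizer and `encode` helper below are illustrative stand-ins, not part of the crate):

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

trait PreTokenizer {
    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>>;
}

// Hypothetical stand-in for BasicPreTokenizer, splitting only on whitespace.
struct Whitespace;

impl PreTokenizer for Whitespace {
    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
        Ok(s.split_whitespace().map(|w| w.to_string()).collect())
    }
}

// Callers now thread failures with `?` instead of receiving a bare Vec<String>.
fn encode(pre: &dyn PreTokenizer, s: &str) -> Result<Vec<String>> {
    let words = pre.pre_tokenize(s)?;
    Ok(words)
}

fn main() -> Result<()> {
    let words = encode(&Whitespace, "wordpiece can fail")?;
    assert_eq!(words, vec!["wordpiece", "can", "fail"]);
    Ok(())
}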