Wordpiece can fail

Anthony MOI
2019-12-11 07:30:27 -05:00
parent a929a99e05
commit fbebbec585
3 changed files with 39 additions and 14 deletions

View File

@@ -1,9 +1,9 @@
-use crate::tokenizer::Decoder;
+use crate::tokenizer::{Decoder, Result};

 pub struct WordPiece;

 impl Decoder for WordPiece {
-    fn decode(&self, tokens: Vec<String>) -> String {
-        tokens.join(" ").replace(" ##", "")
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        Ok(tokens.join(" ").replace(" ##", ""))
     }
 }
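
With `Decoder::decode` now returning `Result`, callers handle failure explicitly. A minimal self-contained sketch of the pattern; the `Result` alias below mirrors what `crate::tokenizer` presumably defines (a boxed-error alias) and is an assumption, not something shown in this diff:

// Sketch only: this alias is an assumption about crate::tokenizer,
// not part of this commit.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

trait Decoder {
    fn decode(&self, tokens: Vec<String>) -> Result<String>;
}

struct WordPiece;

impl Decoder for WordPiece {
    // Joining and stripping the "##" continuation marker cannot fail,
    // so the value is simply wrapped in Ok to satisfy the fallible API.
    fn decode(&self, tokens: Vec<String>) -> Result<String> {
        Ok(tokens.join(" ").replace(" ##", ""))
    }
}

fn main() -> Result<()> {
    let tokens = vec!["un".to_string(), "##believ".to_string(), "##able".to_string()];
    // `?` propagates a decoder failure instead of forcing a panic.
    println!("{}", WordPiece.decode(tokens)?); // unbelievable
    Ok(())
}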

View File

@@ -1,10 +1,28 @@
-use crate::tokenizer::{Model, Token};
+use crate::tokenizer::{Model, Result, Token};
 use std::{
     collections::HashMap,
+    fmt,
     fs::File,
     io::{BufRead, BufReader},
 };

+#[derive(Debug)]
+pub enum Error {
+    MissingUnkToken,
+}
+impl std::error::Error for Error {}
+
+impl fmt::Display for Error {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Error::MissingUnkToken => write!(
+                fmt,
+                "WordPiece error: Missing [UNK] token from the vocabulary"
+            ),
+        }
+    }
+}
+
 pub struct WordPiece {
     unk_token: String,
     max_input_chars_per_word: usize,
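
Implementing `std::error::Error` (plus `Display`) is what lets `Error::MissingUnkToken` travel through `?`: the standard library has a blanket `From<E: Error + Send + Sync + 'static>` impl for `Box<dyn Error + Send + Sync>`, so the variant is boxed automatically. A minimal sketch of that mechanism, again under the assumed alias:

use std::collections::HashMap;
use std::fmt;

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

#[derive(Debug)]
enum Error {
    MissingUnkToken,
}
impl std::error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Error::MissingUnkToken => {
                write!(fmt, "WordPiece error: Missing [UNK] token from the vocabulary")
            }
        }
    }
}

// Hypothetical lookup mirroring the change below: `ok_or` turns the
// Option into a Result, and `?` boxes Error via the blanket From impl.
fn unk_id(vocab: &HashMap<String, u32>, unk_token: &str) -> Result<u32> {
    Ok(*vocab.get(unk_token).ok_or(Error::MissingUnkToken)?)
}

fn main() {
    let vocab = HashMap::new(); // empty: no "[UNK]" entry
    let err = unk_id(&vocab, "[UNK]").unwrap_err();
    assert_eq!(
        err.to_string(),
        "WordPiece error: Missing [UNK] token from the vocabulary"
    );
}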
@@ -52,7 +70,7 @@ impl Model for WordPiece {
         self.vocab.len()
     }

-    fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
+    fn tokenize(&self, sentence: Vec<String>) -> Result<Vec<Token>> {
         let mut output_tokens = vec![];

         let mut offset = 0usize;
@@ -61,7 +79,10 @@ impl Model for WordPiece {
             if char_len > self.max_input_chars_per_word {
                 output_tokens.push(Token {
                     value: self.unk_token.clone(),
-                    id: *self.vocab.get(&self.unk_token).unwrap_or(&0),
+                    id: *self
+                        .vocab
+                        .get(&self.unk_token)
+                        .ok_or(Error::MissingUnkToken)?,
                     offsets: (offset, offset + char_len),
                 });
                 continue;
@@ -104,7 +125,10 @@ impl Model for WordPiece {
             if is_bad {
                 output_tokens.push(Token {
                     value: self.unk_token.clone(),
-                    id: *self.vocab.get(&self.unk_token).unwrap_or(&0),
+                    id: *self
+                        .vocab
+                        .get(&self.unk_token)
+                        .ok_or(Error::MissingUnkToken)?,
                     offsets: (offset, offset + char_len),
                 });
             } else {
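
The same substitution appears twice, and it changes behavior, not just types: `unwrap_or(&0)` silently mapped a missing [UNK] entry to token id 0, whereas `ok_or(...)?` aborts tokenization with `Error::MissingUnkToken`. The contrast in miniature, with a hypothetical empty `vocab`:

use std::collections::HashMap;

fn main() {
    let vocab: HashMap<String, u32> = HashMap::new(); // no "[UNK]" entry

    // Old behavior: the missing token silently becomes id 0.
    let old_id: u32 = *vocab.get("[UNK]").unwrap_or(&0);
    assert_eq!(old_id, 0);

    // New behavior: the absence is reported instead of papered over.
    let new_id = vocab
        .get("[UNK]")
        .copied()
        .ok_or("WordPiece error: Missing [UNK] token from the vocabulary");
    assert!(new_id.is_err());
}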
@@ -114,15 +138,16 @@ impl Model for WordPiece {
             offset += char_len;
         }

-        output_tokens
+        Ok(output_tokens)
     }

-    fn decode(&self, ids: Vec<u32>) -> Vec<String> {
-        ids.into_iter()
+    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>> {
+        Ok(ids
+            .into_iter()
             .map(|id| self.vocab_r.get(&id))
             .filter(|token| token.is_some())
             .map(|id| id.unwrap().clone())
-            .collect()
+            .collect())
     }

     fn token_to_id(&self, token: &str) -> Option<u32> {
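
Unrelated to the error handling, the surviving `map`/`filter`/`map` chain in `decode` drops ids missing from `vocab_r`; a behavior-preserving alternative, sketched here with a hypothetical standalone `decode`, would be a single `filter_map`:

use std::collections::HashMap;

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

// Hypothetical stand-in for the reversed vocabulary lookup (id -> token).
fn decode(vocab_r: &HashMap<u32, String>, ids: Vec<u32>) -> Result<Vec<String>> {
    // filter_map skips unknown ids and clones known ones in one pass,
    // removing the intermediate is_some/unwrap dance.
    Ok(ids
        .into_iter()
        .filter_map(|id| vocab_r.get(&id).cloned())
        .collect())
}

fn main() -> Result<()> {
    let mut vocab_r = HashMap::new();
    vocab_r.insert(0u32, "hello".to_string());
    vocab_r.insert(1, "world".to_string());
    // Id 42 is absent and is silently skipped, as in the original chain.
    assert_eq!(decode(&vocab_r, vec![0, 42, 1])?, vec!["hello", "world"]);
    Ok(())
}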

View File

@@ -1,4 +1,4 @@
-use crate::tokenizer::PreTokenizer;
+use crate::tokenizer::{PreTokenizer, Result};
 use std::collections::HashSet;
 use unicode_categories::UnicodeCategories;
 use unicode_normalization::UnicodeNormalization;
@@ -144,7 +144,7 @@ impl BasicPreTokenizer {
 }

 impl PreTokenizer for BasicPreTokenizer {
-    fn pre_tokenize(&self, s: &str) -> Vec<String> {
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
         let mut text = self.clean_text(s);

         // This was added on November 1st, 2018 for the multilingual and Chinese
@@ -166,6 +166,6 @@ impl PreTokenizer for BasicPreTokenizer {
             split_tokens.extend(self.run_split_on_punc(&tk));
         }

-        split_tokens
+        Ok(split_tokens)
     }
 }
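
Taken together, the three files give the pipeline one fallible shape end to end. Roughly, with hypothetical simplified stand-ins for the traits touched in this commit (the real `Model::tokenize` returns `Token`s carrying offsets, not bare ids):

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

// Hypothetical, simplified stand-ins for the traits touched here.
trait PreTokenizer {
    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>>;
}
trait Model {
    fn tokenize(&self, words: Vec<String>) -> Result<Vec<u32>>;
}
trait Decoder {
    fn decode(&self, tokens: Vec<String>) -> Result<String>;
}

// With every stage fallible, `?` threads errors to a single caller
// instead of panicking (or substituting ids) deep inside a stage.
fn encode(pre: &dyn PreTokenizer, model: &dyn Model, input: &str) -> Result<Vec<u32>> {
    let words = pre.pre_tokenize(input)?;
    model.tokenize(words)
}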