Wordpiece can fail

Author: Anthony MOI
Date:   2019-12-11 07:30:27 -05:00
Parent: a929a99e05
Commit: fbebbec585

3 changed files with 39 additions and 14 deletions

File 1 of 3

@@ -1,9 +1,9 @@
-use crate::tokenizer::Decoder;
+use crate::tokenizer::{Decoder, Result};
 
 pub struct WordPiece;
 
 impl Decoder for WordPiece {
-    fn decode(&self, tokens: Vec<String>) -> String {
-        tokens.join(" ").replace(" ##", "")
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        Ok(tokens.join(" ").replace(" ##", ""))
     }
 }
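The `Result` pulled in from `crate::tokenizer` is not defined anywhere in this diff. A minimal sketch of what such a crate-wide alias usually looks like; the exact definition, including the `Send + Sync` bounds, is an assumption and not part of the commit:

    // Hypothetical crate-level alias: any type implementing std::error::Error
    // can be boxed into the error side, which is what lets the WordPiece code
    // below bubble its own Error enum up with `?`.
    pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;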

File 2 of 3

@@ -1,10 +1,28 @@
-use crate::tokenizer::{Model, Token};
+use crate::tokenizer::{Model, Result, Token};
 use std::{
     collections::HashMap,
+    fmt,
     fs::File,
     io::{BufRead, BufReader},
 };
 
+#[derive(Debug)]
+pub enum Error {
+    MissingUnkToken,
+}
+impl std::error::Error for Error {}
+
+impl fmt::Display for Error {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            MissingUnkToken => write!(
+                fmt,
+                "WordPiece error: Missing [UNK] token from the vocabulary"
+            ),
+        }
+    }
+}
+
 pub struct WordPiece {
     unk_token: String,
     max_input_chars_per_word: usize,
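Because `Error` implements both `fmt::Display` and `std::error::Error`, the `ok_or(Error::MissingUnkToken)?` lookups later in this file can hand it straight to the boxed error behind the `Result` alias. A minimal self-contained sketch of that mechanism; the alias and the `unk_id` helper are illustrative assumptions, not code from the commit:

    use std::{collections::HashMap, error::Error as StdError, fmt};

    #[derive(Debug)]
    pub enum Error {
        MissingUnkToken,
    }

    impl StdError for Error {}

    impl fmt::Display for Error {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "WordPiece error: Missing [UNK] token from the vocabulary")
        }
    }

    // Assumed boxed-error alias, standing in for `crate::tokenizer::Result`.
    type Result<T> = std::result::Result<T, Box<dyn StdError + Send + Sync>>;

    // `?` converts `Error` into the boxed error automatically because
    // `Box<dyn Error + Send + Sync>` implements `From<E>` for any
    // `E: Error + Send + Sync + 'static`.
    fn unk_id(vocab: &HashMap<String, u32>, unk_token: &str) -> Result<u32> {
        Ok(*vocab.get(unk_token).ok_or(Error::MissingUnkToken)?)
    }

This is why the diff only needs the `Debug` derive plus the two impls for the `?` operator in the hunks below to compile.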
@@ -52,7 +70,7 @@ impl Model for WordPiece {
         self.vocab.len()
     }
 
-    fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
+    fn tokenize(&self, sentence: Vec<String>) -> Result<Vec<Token>> {
         let mut output_tokens = vec![];
 
         let mut offset = 0usize;
@@ -61,7 +79,10 @@ impl Model for WordPiece {
             if char_len > self.max_input_chars_per_word {
                 output_tokens.push(Token {
                     value: self.unk_token.clone(),
-                    id: *self.vocab.get(&self.unk_token).unwrap_or(&0),
+                    id: *self
+                        .vocab
+                        .get(&self.unk_token)
+                        .ok_or(Error::MissingUnkToken)?,
                     offsets: (offset, offset + char_len),
                 });
                 continue;
@@ -104,7 +125,10 @@ impl Model for WordPiece {
             if is_bad {
                 output_tokens.push(Token {
                     value: self.unk_token.clone(),
-                    id: *self.vocab.get(&self.unk_token).unwrap_or(&0),
+                    id: *self
+                        .vocab
+                        .get(&self.unk_token)
+                        .ok_or(Error::MissingUnkToken)?,
                     offsets: (offset, offset + char_len),
                 });
             } else {
@@ -114,15 +138,16 @@ impl Model for WordPiece {
             offset += char_len;
         }
 
-        output_tokens
+        Ok(output_tokens)
     }
 
-    fn decode(&self, ids: Vec<u32>) -> Vec<String> {
-        ids.into_iter()
+    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>> {
+        Ok(ids
+            .into_iter()
             .map(|id| self.vocab_r.get(&id))
             .filter(|token| token.is_some())
             .map(|id| id.unwrap().clone())
-            .collect()
+            .collect())
     }
 
     fn token_to_id(&self, token: &str) -> Option<u32> {
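The reworked `decode` keeps the old `filter`/`unwrap` pair inside the new `Ok(...)`. An equivalent, slightly terser sketch, assuming `vocab_r` maps ids back to strings as the calls above suggest (this is an illustration, not what the commit ships):

    use std::collections::HashMap;

    // filter_map collapses the filter(is_some) + unwrap pair while still
    // silently skipping ids absent from the reverse vocabulary, which is
    // the behaviour the commit keeps.
    fn decode_ids(vocab_r: &HashMap<u32, String>, ids: Vec<u32>) -> Vec<String> {
        ids.into_iter()
            .filter_map(|id| vocab_r.get(&id).cloned())
            .collect()
    }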

File 3 of 3

@@ -1,4 +1,4 @@
-use crate::tokenizer::PreTokenizer;
+use crate::tokenizer::{PreTokenizer, Result};
 use std::collections::HashSet;
 use unicode_categories::UnicodeCategories;
 use unicode_normalization::UnicodeNormalization;
@@ -144,7 +144,7 @@ impl BasicPreTokenizer {
     }
 }
 
 impl PreTokenizer for BasicPreTokenizer {
-    fn pre_tokenize(&self, s: &str) -> Vec<String> {
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
         let mut text = self.clean_text(s);
         // This was added on November 1st, 2018 for the multilingual and Chinese
@@ -166,6 +166,6 @@ impl PreTokenizer for BasicPreTokenizer {
             split_tokens.extend(self.run_split_on_punc(&tk));
         }
 
-        split_tokens
+        Ok(split_tokens)
     }
 }
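Taken together, the three files switch `pre_tokenize`, `tokenize`, and both `decode` methods from infallible signatures to the shared `Result`. A hypothetical caller (the `encode` helper below is an illustration against the traits shown in this diff, not part of the crate) would now propagate a missing `[UNK]` entry instead of silently falling back to id 0:

    fn encode<P: PreTokenizer, M: Model>(pre: &P, model: &M, input: &str) -> Result<Vec<Token>> {
        // Each stage is now fallible; `?` forwards any boxed error to the caller.
        let words = pre.pre_tokenize(input)?;
        let tokens = model.tokenize(words)?;
        Ok(tokens)
    }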