BPE can fail

Anthony MOI
2019-12-11 07:30:51 -05:00
parent fbebbec585
commit 4807894da6
3 changed files with 36 additions and 33 deletions

View File

@@ -1,5 +1,5 @@
 use super::{Cache, Error, Pair, Word};
-use crate::tokenizer::{Model, Token};
+use crate::tokenizer::{Model, Result, Token};
 use serde_json::Value;
 use std::{
     collections::HashMap,
@@ -37,7 +37,7 @@ impl BPE {
         BPE::new(HashMap::new(), HashMap::new(), HashMap::new())
     }

-    pub fn from_files(vocab: &str, merges: &str) -> Result<Self, Error> {
+    pub fn from_files(vocab: &str, merges: &str) -> Result<Self> {
         // Read vocab.json
         let vocab_file = File::open(vocab)?;
         let mut vocab_file = BufReader::new(vocab_file);
@@ -55,7 +55,7 @@ impl BPE {
                     }
                 }
             }
-            _ => return Err(Error::BadVocabulary),
+            _ => return Err(Box::new(Error::BadVocabulary)),
         };

         // Read merges file
@@ -100,9 +100,9 @@ impl Model for BPE {
         self.vocab.len()
     }

-    fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
+    fn tokenize(&self, sentence: Vec<String>) -> Result<Vec<Token>> {
         if sentence.len() == 0 {
-            return vec![];
+            return Ok(vec![]);
         }

         let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());
@@ -181,15 +181,16 @@ impl Model for BPE {
             .unzip::<_, _, Vec<String>, Vec<Word>>();
         self.cache.set_values(keys, values);

-        encoded
+        Ok(encoded)
     }

-    fn decode(&self, ids: Vec<u32>) -> Vec<String> {
-        ids.into_iter()
+    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>> {
+        Ok(ids
+            .into_iter()
             .map(|id| self.vocab_r.get(&id))
             .filter(|token| token.is_some())
             .map(|id| id.unwrap().clone())
-            .collect()
+            .collect())
     }

     fn token_to_id(&self, token: &str) -> Option<u32> {
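
Taken together, these hunks imply a crate-wide Result alias exported from crate::tokenizer: the explicit Result<Self, Error> shrinks to Result<Self>, and Error::BadVocabulary is now boxed before being returned. A minimal sketch of what such an alias could look like, as an assumption rather than the commit's literal definition:

use std::error::Error;

// Assumed reconstruction: a single boxed-error alias shared by the
// Model, Trainer, PreTokenizer and Decoder traits. Box<dyn Error>
// implements From<E> for any E: Error, so `?` still converts concrete
// errors such as std::io::Error from File::open automatically.
pub type Result<T> = std::result::Result<T, Box<dyn Error>>;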

View File

@@ -4,10 +4,9 @@
 //! In charge of training a BPE model
 //!

 use super::{Pair, Word, BPE};
-use crate::tokenizer::{Model, Trainer};
+use crate::tokenizer::{Model, Result, Trainer};
 use std::{
     collections::{HashMap, HashSet},
-    error::Error,
     time::Instant,
 };
@@ -50,10 +49,7 @@ impl BpeTrainer {

 impl Trainer for BpeTrainer {
     /// Train a BPE model
-    fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-    ) -> Result<Box<dyn Model + Sync>, Box<dyn Error>> {
+    fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
         let mut words: Vec<Word> = vec![];
         let mut counts: Vec<i32> = vec![];
         let mut word_to_id: HashMap<String, u32> = HashMap::new();
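
With train now returning the shared Result, a caller propagates training failures through that one alias instead of a separate Box<dyn Error> parameter. A usage sketch under that assumption (BpeTrainer::default() and build_model are hypothetical, not part of this diff):

use std::collections::HashMap;

// Assumes the Result alias plus the Model, Trainer and BpeTrainer
// types are in scope.
fn build_model(word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
    let trainer = BpeTrainer::default(); // hypothetical constructor
    // Any error raised during training is boxed into the shared
    // Result and bubbles up to the caller.
    trainer.train(word_counts)
}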

View File

@@ -1,4 +1,4 @@
-use crate::tokenizer::{Decoder, PreTokenizer};
+use crate::tokenizer::{Decoder, PreTokenizer, Result};
 use regex::Regex;
 use std::collections::HashMap;

@@ -32,8 +32,9 @@ lazy_static! {
 pub struct ByteLevel;

 impl PreTokenizer for ByteLevel {
-    fn pre_tokenize(&self, s: &str) -> Vec<String> {
-        RE.captures_iter(s)
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
+        Ok(RE
+            .captures_iter(s)
             .map(|capture| {
                 let capture = capture.get(0).unwrap();
                 let start = capture.start();
@@ -78,20 +79,20 @@ impl PreTokenizer for ByteLevel {
                     .map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap())
                     .collect()
             })
-            .collect()
+            .collect())
     }
 }

 impl Decoder for ByteLevel {
-    fn decode(&self, tokens: Vec<String>) -> String {
-        String::from_utf8_lossy(
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        Ok(String::from_utf8_lossy(
             &tokens
                 .join("")
                 .chars()
                 .map(|c| CHAR_BYTES[&(c as u32)])
                 .collect::<Vec<_>>(),
         )
-        .into_owned()
+        .into_owned())
     }
 }

@@ -104,7 +105,9 @@ mod tests {
     fn pre_tokenization() {
         let pre_tok = ByteLevel;
         assert_eq!(
-            pre_tok.pre_tokenize("Hello my friend, how is your day going?"),
+            pre_tok
+                .pre_tokenize("Hello my friend, how is your day going?")
+                .unwrap(),
             vec![
                 "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
             ]
@@ -116,14 +119,17 @@ mod tests {
         let decoder = ByteLevel;
         assert_eq!(
             "Hello my friend, how is your day going?",
-            decoder.decode(
+            decoder
+                .decode(
                     vec![
-                        "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
+                        "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing",
+                        "?"
                     ]
                     .into_iter()
                     .map(|s| s.into())
                     .collect::<Vec<String>>()
                 )
+                .unwrap()
         );
     }
@@ -141,13 +147,13 @@ mod tests {
         let bl = ByteLevel;
         for sample in samples {
-            let pre_tokenized = bl.pre_tokenize(&sample);
+            let pre_tokenized = bl.pre_tokenize(&sample).unwrap();
             let separated_tokens = pre_tokenized
                 .into_iter()
                 .map(|token| token.split("").map(|t| t.into()).collect::<Vec<_>>())
                 .flatten()
                 .collect::<Vec<_>>();

-            assert_eq!(sample, bl.decode(separated_tokens));
+            assert_eq!(sample, bl.decode(separated_tokens).unwrap());
         }
     }
 }
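
Outside the tests, where .unwrap() is acceptable, the new fallible signatures compose with ?. A round-trip sketch assuming the same Result alias is in scope (round_trip is an illustrative helper, not code from the commit):

fn round_trip(s: &str) -> Result<String> {
    let bl = ByteLevel;
    // pre_tokenize and decode both return Result now, so a failure in
    // either step propagates instead of panicking mid-pipeline.
    let tokens = bl.pre_tokenize(s)?;
    bl.decode(tokens)
}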