BPE can fail
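
The hunks below thread a shared `Result` type from `crate::tokenizer` through the BPE model, the BPE trainer, and the byte-level pre-tokenizer/decoder, so these operations can report failure instead of returning bare values. The alias itself is not part of this diff; judging from signatures like `Result<Self>` and the explicit `Err(Box::new(...))`, it is presumably something along these lines (an assumption, shown only for orientation):

// Presumed shape of the alias defined in crate::tokenizer; not part of this
// commit. A crate-wide Result over a boxed error object lets `?` forward any
// std error, while custom variants are boxed explicitly.
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
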
@@ -1,5 +1,5 @@
 use super::{Cache, Error, Pair, Word};
-use crate::tokenizer::{Model, Token};
+use crate::tokenizer::{Model, Result, Token};
 use serde_json::Value;
 use std::{
     collections::HashMap,

@@ -37,7 +37,7 @@ impl BPE {
         BPE::new(HashMap::new(), HashMap::new(), HashMap::new())
     }

-    pub fn from_files(vocab: &str, merges: &str) -> Result<Self, Error> {
+    pub fn from_files(vocab: &str, merges: &str) -> Result<Self> {
         // Read vocab.json
         let vocab_file = File::open(vocab)?;
         let mut vocab_file = BufReader::new(vocab_file);

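A hypothetical caller of the new `from_files` signature (not part of the diff; assumes `BPE` and the crate `Result` are in scope) can now forward I/O, JSON, and vocabulary errors through a single `?`:

// Hypothetical helper for illustration only.
fn load_bpe(vocab_path: &str, merges_path: &str) -> Result<BPE> {
    let bpe = BPE::from_files(vocab_path, merges_path)?;
    Ok(bpe)
}
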
@@ -55,7 +55,7 @@ impl BPE {
                     }
                 }
             }
-            _ => return Err(Error::BadVocabulary),
+            _ => return Err(Box::new(Error::BadVocabulary)),
         };

         // Read merges file

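The explicit `Box::new` is needed because `?` only boxes errors for us via `From`; a custom variant returned directly has to be boxed by hand. A minimal self-contained sketch, using stand-ins for the crate's `Error` and `Result` (their real definitions are not in this diff):

use std::fs::File;

// Stand-in for the crate's error enum.
#[derive(Debug)]
enum Error {
    BadVocabulary,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "bad vocabulary format")
    }
}

impl std::error::Error for Error {}

// Stand-in for the crate-wide alias.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

fn open_vocab(path: &str) -> Result<File> {
    if path.is_empty() {
        // A custom variant is not converted by `?`, so box it explicitly,
        // just like Err(Box::new(Error::BadVocabulary)) in the hunk above.
        return Err(Box::new(Error::BadVocabulary));
    }
    // `?` converts std::io::Error into Box<dyn Error> automatically via From.
    Ok(File::open(path)?)
}
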
@@ -100,9 +100,9 @@ impl Model for BPE {
         self.vocab.len()
     }

-    fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
+    fn tokenize(&self, sentence: Vec<String>) -> Result<Vec<Token>> {
         if sentence.len() == 0 {
-            return vec![];
+            return Ok(vec![]);
         }

         let mut encoded: Vec<Token> = Vec::with_capacity(sentence.len());

@@ -181,15 +181,16 @@ impl Model for BPE {
             .unzip::<_, _, Vec<String>, Vec<Word>>();
         self.cache.set_values(keys, values);

-        encoded
+        Ok(encoded)
     }

-    fn decode(&self, ids: Vec<u32>) -> Vec<String> {
-        ids.into_iter()
+    fn decode(&self, ids: Vec<u32>) -> Result<Vec<String>> {
+        Ok(ids
+            .into_iter()
             .map(|id| self.vocab_r.get(&id))
             .filter(|token| token.is_some())
             .map(|id| id.unwrap().clone())
-            .collect()
+            .collect())
     }

     fn token_to_id(&self, token: &str) -> Option<u32> {

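That is the last hunk in the BPE model; the remaining hunks touch the trainer and the byte-level pre-tokenizer. At a call site, the reworked `Model` methods now compose with `?`, roughly as in this sketch (a hypothetical helper, not part of the diff; `Model`, `Token`, and `Result` are assumed in scope):

// Hypothetical helper for illustration: both trait methods now return the
// crate Result, so failures propagate instead of forcing a panic.
fn tokenize_then_decode(
    model: &dyn Model,
    words: Vec<String>,
    ids: Vec<u32>,
) -> Result<Vec<String>> {
    let _tokens = model.tokenize(words)?; // Result<Vec<Token>>
    let decoded = model.decode(ids)?;     // Result<Vec<String>>
    Ok(decoded)
}
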
@@ -4,10 +4,9 @@
 //! In charge of training a BPE model
 //!
 use super::{Pair, Word, BPE};
-use crate::tokenizer::{Model, Trainer};
+use crate::tokenizer::{Model, Result, Trainer};
 use std::{
     collections::{HashMap, HashSet},
-    error::Error,
     time::Instant,
 };

@@ -50,10 +49,7 @@ impl BpeTrainer {

 impl Trainer for BpeTrainer {
     /// Train a BPE model
-    fn train(
-        &self,
-        word_counts: HashMap<String, u32>,
-    ) -> Result<Box<dyn Model + Sync>, Box<dyn Error>> {
+    fn train(&self, word_counts: HashMap<String, u32>) -> Result<Box<dyn Model + Sync>> {
         let mut words: Vec<Word> = vec![];
         let mut counts: Vec<i32> = vec![];
         let mut word_to_id: HashMap<String, u32> = HashMap::new();

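The trainer follows the same pattern; a hedged usage sketch (hypothetical, not part of the diff; `BpeTrainer` construction is elided and the counts are made up):

use std::collections::HashMap;

// Hypothetical usage for illustration: train now reports failure through the
// crate Result instead of its own Box<dyn Error>.
fn train_small(trainer: &BpeTrainer) -> Result<()> {
    let mut word_counts: HashMap<String, u32> = HashMap::new();
    word_counts.insert("hello".into(), 5);
    word_counts.insert("world".into(), 3);
    let _model = trainer.train(word_counts)?; // Box<dyn Model + Sync>
    Ok(())
}
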
@@ -1,4 +1,4 @@
-use crate::tokenizer::{Decoder, PreTokenizer};
+use crate::tokenizer::{Decoder, PreTokenizer, Result};
 use regex::Regex;
 use std::collections::HashMap;

@@ -32,8 +32,9 @@ lazy_static! {

 pub struct ByteLevel;
 impl PreTokenizer for ByteLevel {
-    fn pre_tokenize(&self, s: &str) -> Vec<String> {
-        RE.captures_iter(s)
+    fn pre_tokenize(&self, s: &str) -> Result<Vec<String>> {
+        Ok(RE
+            .captures_iter(s)
             .map(|capture| {
                 let capture = capture.get(0).unwrap();
                 let start = capture.start();

@@ -78,20 +79,20 @@ impl PreTokenizer for ByteLevel {
                     .map(|b| std::char::from_u32(BYTES_CHAR[b]).unwrap())
                     .collect()
             })
-            .collect()
+            .collect())
     }
 }

 impl Decoder for ByteLevel {
-    fn decode(&self, tokens: Vec<String>) -> String {
-        String::from_utf8_lossy(
+    fn decode(&self, tokens: Vec<String>) -> Result<String> {
+        Ok(String::from_utf8_lossy(
             &tokens
                 .join("")
                 .chars()
                 .map(|c| CHAR_BYTES[&(c as u32)])
                 .collect::<Vec<_>>(),
         )
-        .into_owned()
+        .into_owned())
     }
 }

@@ -104,7 +105,9 @@ mod tests {
     fn pre_tokenization() {
         let pre_tok = ByteLevel;
         assert_eq!(
-            pre_tok.pre_tokenize("Hello my friend, how is your day going?"),
+            pre_tok
+                .pre_tokenize("Hello my friend, how is your day going?")
+                .unwrap(),
             vec![
                 "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
             ]

@@ -116,14 +119,17 @@ mod tests {
         let decoder = ByteLevel;
         assert_eq!(
             "Hello my friend, how is your day going?",
-            decoder.decode(
+            decoder
+                .decode(
                 vec![
-                    "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing", "?"
+                    "Hello", "Ġmy", "Ġfriend", ",", "Ġhow", "Ġis", "Ġyour", "Ġday", "Ġgoing",
+                    "?"
                 ]
                 .into_iter()
                 .map(|s| s.into())
                 .collect::<Vec<String>>()
             )
+            .unwrap()
         );
     }

@@ -141,13 +147,13 @@ mod tests {

         let bl = ByteLevel;
         for sample in samples {
-            let pre_tokenized = bl.pre_tokenize(&sample);
+            let pre_tokenized = bl.pre_tokenize(&sample).unwrap();
             let separated_tokens = pre_tokenized
                 .into_iter()
                 .map(|token| token.split("").map(|t| t.into()).collect::<Vec<_>>())
                 .flatten()
                 .collect::<Vec<_>>();
-            assert_eq!(sample, bl.decode(separated_tokens));
+            assert_eq!(sample, bl.decode(separated_tokens).unwrap());
         }
     }
 }