mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Ability to load a BPE model from files
This commit is contained in:
30
tokenizers/Cargo.lock
generated
30
tokenizers/Cargo.lock
generated
@ -227,6 +227,11 @@ dependencies = [
|
||||
"quick-error 1.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
@ -388,6 +393,11 @@ dependencies = [
|
||||
"semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.0.0"
|
||||
@ -406,6 +416,21 @@ name = "semver-parser"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.102"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.41"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde 1.0.102 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "0.1.1"
|
||||
@ -447,6 +472,7 @@ dependencies = [
|
||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"onig 5.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -541,6 +567,7 @@ dependencies = [
|
||||
"checksum glob 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
|
||||
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
|
||||
"checksum humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
|
||||
"checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f"
|
||||
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
"checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8"
|
||||
"checksum libloading 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f2b111a074963af1d37a139918ac6d49ad1d0d5e47f72fd55388619691a7d753"
|
||||
@ -562,9 +589,12 @@ dependencies = [
|
||||
"checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716"
|
||||
"checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783"
|
||||
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
|
||||
"checksum ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8"
|
||||
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
|
||||
"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
|
||||
"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
|
||||
"checksum serde 1.0.102 (registry+https://github.com/rust-lang/crates.io-index)" = "0c4b39bd9b0b087684013a792c59e3e07a46a01d2322518d8a1104641a0b1be0"
|
||||
"checksum serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)" = "2f72eb2a68a7dc3f9a691bfda9305a1c017a6215e5a4545c258500d2099a37c2"
|
||||
"checksum shlex 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2"
|
||||
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
|
||||
"checksum termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "96d6098003bde162e4277c70665bd87c326f5a0c3f3fbfb285787fa482d54e6e"
|
||||
|
@ -11,6 +11,7 @@ path = "src/cli.rs"
|
||||
lazy_static = "1.3.0"
|
||||
onig = "5.0.0"
|
||||
rayon = "1.2.0"
|
||||
serde_json = "1.0"
|
||||
|
||||
[lib]
|
||||
name = "tokenizers"
|
||||
|
@ -1,9 +1,37 @@
|
||||
use std::{convert::From, io};
|
||||
|
||||
mod model;
|
||||
mod trainer;
|
||||
mod word;
|
||||
|
||||
pub type Pair = (u32, u32);
|
||||
|
||||
/// ## Error
|
||||
/// Errors that can be encountered while using BPE
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
/// An error encountered while reading files mainly.
|
||||
Io(std::io::Error),
|
||||
/// An error forwarded from Serde, while parsing JSON
|
||||
JsonError(serde_json::Error),
|
||||
/// When the vocab.json file is in the wrong format
|
||||
BadVocabulary,
|
||||
/// If a token found in merges, is not in the vocab
|
||||
MergeTokenOutOfVocabulary(String),
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
fn from(error: io::Error) -> Self {
|
||||
Error::Io(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(error: serde_json::Error) -> Self {
|
||||
Error::JsonError(error)
|
||||
}
|
||||
}
|
||||
|
||||
// Re-export
|
||||
pub use model::*;
|
||||
pub use trainer::*;
|
||||
|
@ -1,6 +1,12 @@
|
||||
use super::{Pair, Word};
|
||||
use super::{Error, Pair, Word};
|
||||
use crate::tokenizer::{Model, Token};
|
||||
use std::collections::HashMap;
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs::File,
|
||||
io::prelude::*,
|
||||
io::{BufRead, BufReader},
|
||||
};
|
||||
|
||||
pub struct BPE {
|
||||
/// The vocabulary assigns a number to each token
|
||||
@ -23,6 +29,62 @@ impl BPE {
|
||||
merges,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_files(vocab: &str, merges: &str) -> Result<Self, Error> {
|
||||
// Read vocab.json
|
||||
let vocab_file = File::open(vocab)?;
|
||||
let mut vocab_file = BufReader::new(vocab_file);
|
||||
|
||||
let mut buffer = String::new();
|
||||
vocab_file.read_to_string(&mut buffer)?;
|
||||
let json: Value = serde_json::from_str(&buffer)?;
|
||||
let mut vocab = HashMap::new();
|
||||
match json {
|
||||
Value::Object(m) => {
|
||||
for (token, id) in m {
|
||||
if let Value::Number(id) = id {
|
||||
let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32;
|
||||
vocab.insert(token, id);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => return Err(Error::BadVocabulary),
|
||||
};
|
||||
|
||||
// Read merges file
|
||||
let merge_file = File::open(merges)?;
|
||||
let merge_file = BufReader::new(merge_file);
|
||||
let mut merges = HashMap::<Pair, (u32, u32)>::new();
|
||||
for (rank, line) in merge_file.lines().enumerate() {
|
||||
let line = line?;
|
||||
if line.starts_with("#version") {
|
||||
// Skip line with: #version
|
||||
continue;
|
||||
}
|
||||
|
||||
let parts = line.split(" ").collect::<Vec<_>>();
|
||||
|
||||
let a = vocab
|
||||
.get(parts[0])
|
||||
.ok_or(Error::MergeTokenOutOfVocabulary(parts[0].to_owned()))?;
|
||||
let b = vocab
|
||||
.get(parts[1])
|
||||
.ok_or(Error::MergeTokenOutOfVocabulary(parts[1].to_owned()))?;
|
||||
let pair = (*a, *b);
|
||||
let new_token = format!("{}{}", parts[0], parts[1]);
|
||||
let new_id = vocab
|
||||
.get(&new_token)
|
||||
.ok_or(Error::MergeTokenOutOfVocabulary(new_token))?;
|
||||
|
||||
merges.insert(pair, (rank as u32, *new_id));
|
||||
}
|
||||
|
||||
Ok(BPE {
|
||||
vocab: vocab.clone(),
|
||||
vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(),
|
||||
merges,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Model for BPE {
|
||||
|
@ -3,7 +3,7 @@
|
||||
//!
|
||||
//! In charge of training a BPE model
|
||||
//!
|
||||
use super::{Pair, Word, BPE};
|
||||
use super::{Error, Pair, Word, BPE};
|
||||
use crate::tokenizer::PreTokenizer;
|
||||
use rayon::prelude::*;
|
||||
use std::{
|
||||
@ -37,17 +37,6 @@ impl TrainerConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Io(std::io::Error),
|
||||
}
|
||||
|
||||
impl std::convert::From<std::io::Error> for Error {
|
||||
fn from(error: std::io::Error) -> Self {
|
||||
Error::Io(error)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Trainer {
|
||||
// Training parameters
|
||||
config: TrainerConfig,
|
||||
|
Reference in New Issue
Block a user