mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Ability to load a BPE model from files
This commit is contained in:
30
tokenizers/Cargo.lock
generated
30
tokenizers/Cargo.lock
generated
@ -227,6 +227,11 @@ dependencies = [
|
|||||||
"quick-error 1.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
"quick-error 1.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itoa"
|
||||||
|
version = "0.4.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "lazy_static"
|
name = "lazy_static"
|
||||||
version = "1.4.0"
|
version = "1.4.0"
|
||||||
@ -388,6 +393,11 @@ dependencies = [
|
|||||||
"semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ryu"
|
||||||
|
version = "1.0.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "scopeguard"
|
name = "scopeguard"
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
@ -406,6 +416,21 @@ name = "semver-parser"
|
|||||||
version = "0.7.0"
|
version = "0.7.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde"
|
||||||
|
version = "1.0.102"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_json"
|
||||||
|
version = "1.0.41"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
dependencies = [
|
||||||
|
"itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"serde 1.0.102 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "shlex"
|
name = "shlex"
|
||||||
version = "0.1.1"
|
version = "0.1.1"
|
||||||
@ -447,6 +472,7 @@ dependencies = [
|
|||||||
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"onig 5.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"onig 5.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
"rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
"rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -541,6 +567,7 @@ dependencies = [
|
|||||||
"checksum glob 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
|
"checksum glob 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574"
|
||||||
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
|
"checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120"
|
||||||
"checksum humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
|
"checksum humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
|
||||||
|
"checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f"
|
||||||
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||||
"checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8"
|
"checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8"
|
||||||
"checksum libloading 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f2b111a074963af1d37a139918ac6d49ad1d0d5e47f72fd55388619691a7d753"
|
"checksum libloading 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f2b111a074963af1d37a139918ac6d49ad1d0d5e47f72fd55388619691a7d753"
|
||||||
@ -562,9 +589,12 @@ dependencies = [
|
|||||||
"checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716"
|
"checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716"
|
||||||
"checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783"
|
"checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783"
|
||||||
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
|
"checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
|
||||||
|
"checksum ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8"
|
||||||
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
|
"checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d"
|
||||||
"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
|
"checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
|
||||||
"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
|
"checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
|
||||||
|
"checksum serde 1.0.102 (registry+https://github.com/rust-lang/crates.io-index)" = "0c4b39bd9b0b087684013a792c59e3e07a46a01d2322518d8a1104641a0b1be0"
|
||||||
|
"checksum serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)" = "2f72eb2a68a7dc3f9a691bfda9305a1c017a6215e5a4545c258500d2099a37c2"
|
||||||
"checksum shlex 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2"
|
"checksum shlex 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2"
|
||||||
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
|
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
|
||||||
"checksum termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "96d6098003bde162e4277c70665bd87c326f5a0c3f3fbfb285787fa482d54e6e"
|
"checksum termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "96d6098003bde162e4277c70665bd87c326f5a0c3f3fbfb285787fa482d54e6e"
|
||||||
|
@ -11,6 +11,7 @@ path = "src/cli.rs"
|
|||||||
lazy_static = "1.3.0"
|
lazy_static = "1.3.0"
|
||||||
onig = "5.0.0"
|
onig = "5.0.0"
|
||||||
rayon = "1.2.0"
|
rayon = "1.2.0"
|
||||||
|
serde_json = "1.0"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
name = "tokenizers"
|
name = "tokenizers"
|
||||||
|
@ -1,9 +1,37 @@
|
|||||||
|
use std::{convert::From, io};
|
||||||
|
|
||||||
mod model;
|
mod model;
|
||||||
mod trainer;
|
mod trainer;
|
||||||
mod word;
|
mod word;
|
||||||
|
|
||||||
pub type Pair = (u32, u32);
|
pub type Pair = (u32, u32);
|
||||||
|
|
||||||
|
/// ## Error
|
||||||
|
/// Errors that can be encountered while using BPE
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum Error {
|
||||||
|
/// An error encountered while reading files mainly.
|
||||||
|
Io(std::io::Error),
|
||||||
|
/// An error forwarded from Serde, while parsing JSON
|
||||||
|
JsonError(serde_json::Error),
|
||||||
|
/// When the vocab.json file is in the wrong format
|
||||||
|
BadVocabulary,
|
||||||
|
/// If a token found in merges, is not in the vocab
|
||||||
|
MergeTokenOutOfVocabulary(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<io::Error> for Error {
|
||||||
|
fn from(error: io::Error) -> Self {
|
||||||
|
Error::Io(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<serde_json::Error> for Error {
|
||||||
|
fn from(error: serde_json::Error) -> Self {
|
||||||
|
Error::JsonError(error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Re-export
|
// Re-export
|
||||||
pub use model::*;
|
pub use model::*;
|
||||||
pub use trainer::*;
|
pub use trainer::*;
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
use super::{Pair, Word};
|
use super::{Error, Pair, Word};
|
||||||
use crate::tokenizer::{Model, Token};
|
use crate::tokenizer::{Model, Token};
|
||||||
use std::collections::HashMap;
|
use serde_json::Value;
|
||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
fs::File,
|
||||||
|
io::prelude::*,
|
||||||
|
io::{BufRead, BufReader},
|
||||||
|
};
|
||||||
|
|
||||||
pub struct BPE {
|
pub struct BPE {
|
||||||
/// The vocabulary assigns a number to each token
|
/// The vocabulary assigns a number to each token
|
||||||
@ -23,6 +29,62 @@ impl BPE {
|
|||||||
merges,
|
merges,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn from_files(vocab: &str, merges: &str) -> Result<Self, Error> {
|
||||||
|
// Read vocab.json
|
||||||
|
let vocab_file = File::open(vocab)?;
|
||||||
|
let mut vocab_file = BufReader::new(vocab_file);
|
||||||
|
|
||||||
|
let mut buffer = String::new();
|
||||||
|
vocab_file.read_to_string(&mut buffer)?;
|
||||||
|
let json: Value = serde_json::from_str(&buffer)?;
|
||||||
|
let mut vocab = HashMap::new();
|
||||||
|
match json {
|
||||||
|
Value::Object(m) => {
|
||||||
|
for (token, id) in m {
|
||||||
|
if let Value::Number(id) = id {
|
||||||
|
let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32;
|
||||||
|
vocab.insert(token, id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => return Err(Error::BadVocabulary),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Read merges file
|
||||||
|
let merge_file = File::open(merges)?;
|
||||||
|
let merge_file = BufReader::new(merge_file);
|
||||||
|
let mut merges = HashMap::<Pair, (u32, u32)>::new();
|
||||||
|
for (rank, line) in merge_file.lines().enumerate() {
|
||||||
|
let line = line?;
|
||||||
|
if line.starts_with("#version") {
|
||||||
|
// Skip line with: #version
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let parts = line.split(" ").collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let a = vocab
|
||||||
|
.get(parts[0])
|
||||||
|
.ok_or(Error::MergeTokenOutOfVocabulary(parts[0].to_owned()))?;
|
||||||
|
let b = vocab
|
||||||
|
.get(parts[1])
|
||||||
|
.ok_or(Error::MergeTokenOutOfVocabulary(parts[1].to_owned()))?;
|
||||||
|
let pair = (*a, *b);
|
||||||
|
let new_token = format!("{}{}", parts[0], parts[1]);
|
||||||
|
let new_id = vocab
|
||||||
|
.get(&new_token)
|
||||||
|
.ok_or(Error::MergeTokenOutOfVocabulary(new_token))?;
|
||||||
|
|
||||||
|
merges.insert(pair, (rank as u32, *new_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(BPE {
|
||||||
|
vocab: vocab.clone(),
|
||||||
|
vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(),
|
||||||
|
merges,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model for BPE {
|
impl Model for BPE {
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
//!
|
//!
|
||||||
//! In charge of training a BPE model
|
//! In charge of training a BPE model
|
||||||
//!
|
//!
|
||||||
use super::{Pair, Word, BPE};
|
use super::{Error, Pair, Word, BPE};
|
||||||
use crate::tokenizer::PreTokenizer;
|
use crate::tokenizer::PreTokenizer;
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
use std::{
|
use std::{
|
||||||
@ -37,17 +37,6 @@ impl TrainerConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub enum Error {
|
|
||||||
Io(std::io::Error),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::convert::From<std::io::Error> for Error {
|
|
||||||
fn from(error: std::io::Error) -> Self {
|
|
||||||
Error::Io(error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct Trainer {
|
pub struct Trainer {
|
||||||
// Training parameters
|
// Training parameters
|
||||||
config: TrainerConfig,
|
config: TrainerConfig,
|
||||||
|
Reference in New Issue
Block a user