diff --git a/tokenizers/Cargo.lock b/tokenizers/Cargo.lock index 79c664a5..2e025672 100644 --- a/tokenizers/Cargo.lock +++ b/tokenizers/Cargo.lock @@ -227,6 +227,11 @@ dependencies = [ "quick-error 1.2.2 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "itoa" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "lazy_static" version = "1.4.0" @@ -388,6 +393,11 @@ dependencies = [ "semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "ryu" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "scopeguard" version = "1.0.0" @@ -406,6 +416,21 @@ name = "semver-parser" version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "serde" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "serde_json" +version = "1.0.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)", + "ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", + "serde 1.0.102 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "shlex" version = "0.1.1" @@ -447,6 +472,7 @@ dependencies = [ "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "onig 5.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "rayon 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -541,6 +567,7 @@ dependencies = [ "checksum glob 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" "checksum hermit-abi 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = 
"307c3c9f937f38e3534b1d6447ecf090cafcc9744e4a6360e8b037b2cf5af120" "checksum humantime 1.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f" +"checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f" "checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" "checksum libc 0.2.65 (registry+https://github.com/rust-lang/crates.io-index)" = "1a31a0627fdf1f6a39ec0dd577e101440b7db22672c0901fe00a9a6fbb5c24e8" "checksum libloading 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f2b111a074963af1d37a139918ac6d49ad1d0d5e47f72fd55388619691a7d753" @@ -562,9 +589,12 @@ dependencies = [ "checksum regex-syntax 0.6.12 (registry+https://github.com/rust-lang/crates.io-index)" = "11a7e20d1cce64ef2fed88b66d347f88bd9babb82845b2b858f3edbf59a4f716" "checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783" "checksum rustc_version 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" +"checksum ryu 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "bfa8506c1de11c9c4e4c38863ccbe02a305c8188e85a05a784c9e11e1c3910c8" "checksum scopeguard 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b42e15e59b18a828bbf5c58ea01debb36b9b096346de35d941dcb89009f24a0d" "checksum semver 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)" = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" "checksum semver-parser 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" +"checksum serde 1.0.102 
(registry+https://github.com/rust-lang/crates.io-index)" = "0c4b39bd9b0b087684013a792c59e3e07a46a01d2322518d8a1104641a0b1be0" +"checksum serde_json 1.0.41 (registry+https://github.com/rust-lang/crates.io-index)" = "2f72eb2a68a7dc3f9a691bfda9305a1c017a6215e5a4545c258500d2099a37c2" "checksum shlex 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7fdf1b9db47230893d76faad238fd6097fd6d6a9245cd7a4d90dbd639536bbd2" "checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" "checksum termcolor 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "96d6098003bde162e4277c70665bd87c326f5a0c3f3fbfb285787fa482d54e6e" diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 39613ca4..ec667da9 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -11,6 +11,7 @@ path = "src/cli.rs" lazy_static = "1.3.0" onig = "5.0.0" rayon = "1.2.0" +serde_json = "1.0" [lib] name = "tokenizers" diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index 4377cdd7..e3bccb8a 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -1,9 +1,37 @@ +use std::{convert::From, io}; + mod model; mod trainer; mod word; pub type Pair = (u32, u32); +/// ## Error +/// Errors that can be encountered while using BPE +#[derive(Debug)] +pub enum Error { + /// An error encountered while reading files mainly. 
+ Io(std::io::Error), + /// An error forwarded from Serde, while parsing JSON + JsonError(serde_json::Error), + /// When the vocab.json file is in the wrong format + BadVocabulary, + /// If a token found in merges, is not in the vocab + MergeTokenOutOfVocabulary(String), +} + +impl From<io::Error> for Error { + fn from(error: io::Error) -> Self { + Error::Io(error) + } +} + +impl From<serde_json::Error> for Error { + fn from(error: serde_json::Error) -> Self { + Error::JsonError(error) + } +} + // Re-export pub use model::*; pub use trainer::*; diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 009349e0..3ab37a2a 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -1,6 +1,12 @@ -use super::{Pair, Word}; +use super::{Error, Pair, Word}; use crate::tokenizer::{Model, Token}; -use std::collections::HashMap; +use serde_json::Value; +use std::{ + collections::HashMap, + fs::File, + io::prelude::*, + io::{BufRead, BufReader}, +}; pub struct BPE { /// The vocabulary assigns a number to each token @@ -23,6 +29,62 @@ impl BPE { merges, } } + + pub fn from_files(vocab: &str, merges: &str) -> Result<Self, Error> { + // Read vocab.json + let vocab_file = File::open(vocab)?; + let mut vocab_file = BufReader::new(vocab_file); + + let mut buffer = String::new(); + vocab_file.read_to_string(&mut buffer)?; + let json: Value = serde_json::from_str(&buffer)?; + let mut vocab = HashMap::new(); + match json { + Value::Object(m) => { + for (token, id) in m { + if let Value::Number(id) = id { + let id = id.as_u64().ok_or(Error::BadVocabulary)?
as u32; + vocab.insert(token, id); + } + } + } + _ => return Err(Error::BadVocabulary), + }; + + // Read merges file + let merge_file = File::open(merges)?; + let merge_file = BufReader::new(merge_file); + let mut merges = HashMap::<Pair, (u32, u32)>::new(); + for (rank, line) in merge_file.lines().enumerate() { + let line = line?; + if line.starts_with("#version") { + // Skip line with: #version + continue; + } + + let parts = line.split(" ").collect::<Vec<_>>(); + + let a = vocab + .get(parts[0]) + .ok_or(Error::MergeTokenOutOfVocabulary(parts[0].to_owned()))?; + let b = vocab + .get(parts[1]) + .ok_or(Error::MergeTokenOutOfVocabulary(parts[1].to_owned()))?; + let pair = (*a, *b); + let new_token = format!("{}{}", parts[0], parts[1]); + let new_id = vocab + .get(&new_token) + .ok_or(Error::MergeTokenOutOfVocabulary(new_token))?; + + merges.insert(pair, (rank as u32, *new_id)); + } + + Ok(BPE { + vocab: vocab.clone(), + vocab_r: vocab.into_iter().map(|(token, id)| (id, token)).collect(), + merges, + }) + } } impl Model for BPE { diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index 1d3d4852..2721c5de 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -3,7 +3,7 @@ //! //! In charge of training a BPE model //! -use super::{Pair, Word, BPE}; +use super::{Error, Pair, Word, BPE}; use crate::tokenizer::PreTokenizer; use rayon::prelude::*; use std::{ @@ -37,17 +37,6 @@ impl TrainerConfig { } } -#[derive(Debug)] -pub enum Error { - Io(std::io::Error), -} - -impl std::convert::From<std::io::Error> for Error { - fn from(error: std::io::Error) -> Self { - Error::Io(error) - } -} - pub struct Trainer { // Training parameters config: TrainerConfig,