diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs
index 97d7fd69..e95ede67 100644
--- a/bindings/python/src/lib.rs
+++ b/bindings/python/src/lib.rs
@@ -28,6 +28,7 @@ fn models(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<models::Model>()?;
     m.add_class::<models::BPE>()?;
     m.add_class::<models::WordPiece>()?;
+    m.add_class::<models::WordLevel>()?;
     Ok(())
 }
 
diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index cb788b69..c7d3dd0e 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -153,3 +153,40 @@ impl WordPiece {
         }
     }
 }
+
+#[pyclass]
+pub struct WordLevel {}
+
+#[pymethods]
+impl WordLevel {
+    #[staticmethod]
+    #[args(kwargs = "**")]
+    fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult<Model> {
+        let mut unk_token = String::from("<unk>");
+
+        if let Some(kwargs) = kwargs {
+            for (key, val) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "unk_token" => unk_token = val.extract()?,
+                    _ => println!("Ignored unknown kwargs option {}", key),
+                }
+            }
+        }
+
+        match tk::models::wordlevel::WordLevel::from_files(
+            vocab,
+            unk_token,
+        ) {
+            Err(e) => {
+                println!("Errors: {:?}", e);
+                Err(exceptions::Exception::py_err(
+                    "Error while initializing WordLevel",
+                ))
+            }
+            Ok(model) => Ok(Model {
+                model: Container::Owned(Box::new(model)),
+            }),
+        }
+    }
+}
diff --git a/bindings/python/tokenizers/models/__init__.py b/bindings/python/tokenizers/models/__init__.py
index 5aec8ea5..4c198b5b 100644
--- a/bindings/python/tokenizers/models/__init__.py
+++ b/bindings/python/tokenizers/models/__init__.py
@@ -3,3 +3,4 @@ from .. import models
 Model = models.Model
 BPE = models.BPE
 WordPiece = models.WordPiece
+WordLevel = models.WordLevel
\ No newline at end of file
diff --git a/bindings/python/tokenizers/models/__init__.pyi b/bindings/python/tokenizers/models/__init__.pyi
index eea5cbe4..0a7182dc 100644
--- a/bindings/python/tokenizers/models/__init__.pyi
+++ b/bindings/python/tokenizers/models/__init__.pyi
@@ -17,7 +17,7 @@ class Model:
     pass
 
 
-class BPE:
+class BPE(Model):
     """ BytePairEncoding model class """
 
     @staticmethod
@@ -62,7 +62,7 @@ class BPE:
     pass
 
 
-class WordPiece:
+class WordPiece(Model):
     """ WordPiece model class """
 
     @staticmethod
@@ -87,3 +87,22 @@ class WordPiece:
     def empty() -> Model:
         """ Instantiate an empty WordPiece Model. """
         pass
+
+
+class WordLevel(Model):
+    """
+    The most simple tokenizer model, mapping tokens from a vocab file to their corresponding ids.
+    """
+
+    @staticmethod
+    def from_files(vocab: str, unk_token: str) -> Model:
+        """ Instantiate a WordLevel Model from the given vocab file.
+
+            Args:
+                vocab: str:
+                    Path to a vocabulary file.
+
+                unk_token: str:
+                    The unknown token to be used by the model.
+        """
+        pass
\ No newline at end of file
diff --git a/tokenizers/src/models/mod.rs b/tokenizers/src/models/mod.rs
index d142a3e0..e1a2ad8f 100644
--- a/tokenizers/src/models/mod.rs
+++ b/tokenizers/src/models/mod.rs
@@ -1,4 +1,5 @@
 //! Popular tokenizer models.
 
 pub mod bpe;
+pub mod wordlevel;
 pub mod wordpiece;
diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs
new file mode 100644
index 00000000..1743c43b
--- /dev/null
+++ b/tokenizers/src/models/wordlevel/mod.rs
@@ -0,0 +1,183 @@
+use crate::tokenizer::{Model, Result, Token};
+use serde_json::Value;
+use std::collections::HashMap;
+use std::fmt;
+use std::fs::File;
+use std::io::{BufReader, Read, Write};
+use std::path::{Path, PathBuf};
+
+#[derive(Debug)]
+pub enum Error {
+    MissingUnkToken,
+    BadVocabulary,
+}
+impl std::error::Error for Error {}
+
+impl fmt::Display for Error {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Error::MissingUnkToken => write!(
+                fmt,
+                "WordLevel error: Missing [UNK] token from the vocabulary"
+            ),
+            Error::BadVocabulary => write!(fmt, "Bad vocabulary json file"),
+        }
+    }
+}
+
+struct Config {
+    vocab: HashMap<String, u32>,
+    unk_token: String,
+}
+
+/// A `WordLevelBuilder` can be used to create a `WordLevel`
+/// model with a custom configuration.
+pub struct WordLevelBuilder {
+    config: Config,
+}
+
+impl Default for WordLevelBuilder {
+    fn default() -> Self {
+        Self {
+            config: Config {
+                vocab: HashMap::new(),
+                unk_token: String::from("<unk>"),
+            },
+        }
+    }
+}
+
+impl WordLevelBuilder {
+    /// Construct a new `WordLevelBuilder`.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the vocab (token -> ID) mapping.
+    pub fn vocab(mut self, vocab: HashMap<String, u32>) -> Self {
+        self.config.vocab = vocab;
+        self
+    }
+
+    /// Set the `UNK` token for the vocab.
+    pub fn unk_token(mut self, unk_token: String) -> Self {
+        self.config.unk_token = unk_token;
+        self
+    }
+
+    /// Constructs a `WordLevel` model that uses the `WordLevelBuilder`'s configuration.
+    pub fn build(self) -> WordLevel {
+        let vocab_r = self
+            .config
+            .vocab
+            .iter()
+            .map(|(key, val)| (*val, key.to_owned()))
+            .collect();
+        WordLevel {
+            vocab: self.config.vocab,
+            vocab_r,
+            unk_token: self.config.unk_token,
+        }
+    }
+}
+
+pub struct WordLevel {
+    vocab: HashMap<String, u32>,
+    vocab_r: HashMap<u32, String>,
+    unk_token: String,
+}
+
+impl WordLevel {
+    fn builder() -> WordLevelBuilder {
+        WordLevelBuilder::new()
+    }
+
+    /// Initialize a WordLevel model from a vocab file.
+    pub fn from_files(vocab_path: &str, unk_token: String) -> Result<WordLevel> {
+        // Read vocab.json
+        let vocab_file = File::open(vocab_path)?;
+        let mut vocab_file = BufReader::new(vocab_file);
+        let mut buffer = String::new();
+        let mut vocab = HashMap::new();
+
+        vocab_file.read_to_string(&mut buffer)?;
+        let json: Value = serde_json::from_str(&buffer)?;
+
+        match json {
+            Value::Object(m) => {
+                for (token, id) in m {
+                    if let Value::Number(id) = id {
+                        let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32;
+                        vocab.insert(token, id);
+                    }
+                }
+            }
+            _ => return Err(Box::new(Error::BadVocabulary)),
+        };
+
+        Ok(Self::builder().vocab(vocab).unk_token(unk_token).build())
+    }
+}
+
+impl Default for WordLevel {
+    fn default() -> Self {
+        Self {
+            vocab: HashMap::new(),
+            vocab_r: HashMap::new(),
+            unk_token: String::from("<unk>"),
+        }
+    }
+}
+
+impl Model for WordLevel {
+    fn tokenize(&self, tokens: Vec<(String, (usize, usize))>) -> Result<Vec<Token>> {
+        let mut output_tokens = vec![];
+
+        for (token, initial_offsets) in tokens {
+            let t = Token {
+                id: *self
+                    .vocab
+                    .get(&*token)
+                    .or_else(|| self.vocab.get(&*self.unk_token))
+                    .ok_or(Error::MissingUnkToken)?,
+                value: token,
+                offsets: initial_offsets,
+            };
+
+            output_tokens.push(t);
+        }
+
+        Ok(output_tokens)
+    }
+
+    fn token_to_id(&self, token: &str) -> Option<u32> {
+        self.vocab.get(token).copied()
+    }
+
+    fn id_to_token(&self, id: u32) -> Option<String> {
+        self.vocab_r.get(&id).cloned()
+    }
+
+    fn get_vocab_size(&self) -> usize {
+        self.vocab.keys().len()
+    }
+
+    fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> {
+        // Write vocab.txt
+        let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.txt", name))]
+            .iter()
+            .collect();
+        let mut vocab_file = File::create(&vocab_path)?;
+        let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect();
+        vocab.sort_unstable_by_key(|k| *k.1);
+        vocab_file.write_all(
+            &vocab
+                .into_iter()
+                .map(|(token, _)| format!("{}\n", token).as_bytes().to_owned())
+                .flatten()
+                .collect::<Vec<u8>>()[..],
+        )?;
+
+        Ok(vec![vocab_path])
+    }
+}