diff --git a/bindings/python/py_src/tokenizers/implementations/bert_wordpiece.py b/bindings/python/py_src/tokenizers/implementations/bert_wordpiece.py
index e4c48545..284e8e7c 100644
--- a/bindings/python/py_src/tokenizers/implementations/bert_wordpiece.py
+++ b/bindings/python/py_src/tokenizers/implementations/bert_wordpiece.py
@@ -80,6 +80,11 @@ class BertWordPieceTokenizer(BaseTokenizer):
 
         super().__init__(tokenizer, parameters)
 
+    @staticmethod
+    def from_file(vocab: str, **kwargs):
+        vocab = WordPiece.read_file(vocab)
+        return BertWordPieceTokenizer(vocab, **kwargs)
+
     def train(
         self,
         files: Union[str, List[str]],
diff --git a/bindings/python/py_src/tokenizers/implementations/byte_level_bpe.py b/bindings/python/py_src/tokenizers/implementations/byte_level_bpe.py
index 18e61a7b..bb9ce30a 100644
--- a/bindings/python/py_src/tokenizers/implementations/byte_level_bpe.py
+++ b/bindings/python/py_src/tokenizers/implementations/byte_level_bpe.py
@@ -77,6 +77,11 @@ class ByteLevelBPETokenizer(BaseTokenizer):
 
         super().__init__(tokenizer, parameters)
 
+    @staticmethod
+    def from_files(vocab_filename: str, merges_filename: str, **kwargs):
+        vocab, merges = BPE.read_files(vocab_filename, merges_filename)
+        return ByteLevelBPETokenizer(vocab, merges, **kwargs)
+
     def train(
         self,
         files: Union[str, List[str]],
diff --git a/bindings/python/py_src/tokenizers/implementations/char_level_bpe.py b/bindings/python/py_src/tokenizers/implementations/char_level_bpe.py
index 088c8e78..e57c0c40 100644
--- a/bindings/python/py_src/tokenizers/implementations/char_level_bpe.py
+++ b/bindings/python/py_src/tokenizers/implementations/char_level_bpe.py
@@ -94,6 +94,11 @@ class CharBPETokenizer(BaseTokenizer):
 
         super().__init__(tokenizer, parameters)
 
+    @staticmethod
+    def from_files(vocab_filename: str, merges_filename: str, **kwargs):
+        vocab, merges = BPE.read_files(vocab_filename, merges_filename)
+        return CharBPETokenizer(vocab, merges, **kwargs)
+
     def train(
         self,
         files: Union[str, List[str]],
diff --git a/bindings/python/py_src/tokenizers/implementations/sentencepiece_bpe.py b/bindings/python/py_src/tokenizers/implementations/sentencepiece_bpe.py
index 1d900921..bac48291 100644
--- a/bindings/python/py_src/tokenizers/implementations/sentencepiece_bpe.py
+++ b/bindings/python/py_src/tokenizers/implementations/sentencepiece_bpe.py
@@ -47,6 +47,11 @@ class SentencePieceBPETokenizer(BaseTokenizer):
 
         super().__init__(tokenizer, parameters)
 
+    @staticmethod
+    def from_files(vocab_filename: str, merges_filename: str, **kwargs):
+        vocab, merges = BPE.read_files(vocab_filename, merges_filename)
+        return SentencePieceBPETokenizer(vocab, merges, **kwargs)
+
     def train(
         self,
         files: Union[str, List[str]],
diff --git a/bindings/python/py_src/tokenizers/models/__init__.pyi b/bindings/python/py_src/tokenizers/models/__init__.pyi
index 11fbb42d..850cdf04 100644
--- a/bindings/python/py_src/tokenizers/models/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/models/__init__.pyi
@@ -62,8 +62,6 @@ class BPE(Model):
         fuse_unk: (`optional`) bool:
             Multiple unk tokens get fused into only 1
     """
-
-    @staticmethod
     def __init__(
         self,
         vocab: Optional[Union[str, Dict[str, int]]],
@@ -77,6 +75,15 @@
     ):
         pass
 
+    @staticmethod
+    def read_files(vocab_filename: str, merges_filename: str) -> Tuple[Vocab, Merges]:
+        pass
+
+    @staticmethod
+    def from_files(vocab_filename: str, merges_filename: str, **kwargs) -> BPE:
+        vocab, merges = BPE.read_files(vocab_filename, merges_filename)
+        return BPE(vocab, merges, **kwargs)
+
 class WordPiece(Model):
     """
     WordPiece model class
@@ -101,6 +108,15 @@
     ):
         pass
 
+    @staticmethod
+    def read_file(vocab_filename: str) -> Vocab:
+        pass
+
+    @staticmethod
+    def from_files(vocab_filename: str, **kwargs) -> WordPiece:
+        vocab = WordPiece.read_file(vocab_filename)
+        return WordPiece(vocab, **kwargs)
+
 class WordLevel(Model):
     """
     Most simple tokenizer model based on mapping token from a vocab file to their corresponding id.
@@ -118,6 +134,15 @@
     def __init__(self, vocab: Optional[Union[str, Dict[str, int]]], unk_token: Optional[str]):
         pass
 
+    @staticmethod
+    def read_file(vocab_filename: str) -> Vocab:
+        pass
+
+    @staticmethod
+    def from_files(vocab_filename: str, **kwargs) -> WordLevel:
+        vocab = WordLevel.read_file(vocab_filename)
+        return WordLevel(vocab, **kwargs)
+
 class Unigram(Model):
     """UnigramEncoding model class
 
diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index 4733815b..6009b8d1 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -7,7 +7,7 @@ use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
 use serde::{Deserialize, Serialize};
-use tk::models::bpe::{BpeBuilder, BPE};
+use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
 use tk::models::unigram::Unigram;
 use tk::models::wordlevel::WordLevel;
 use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
@@ -217,16 +217,13 @@ impl PyBPE {
     }
 
     #[staticmethod]
-    #[args(kwargs = "**")]
-    fn from_files(
-        vocab_filename: String,
-        merges_filename: String,
-        kwargs: Option<&PyDict>,
-    ) -> PyResult<(Self, PyModel)> {
-        let mut builder = BPE::builder();
-        builder = builder.files(vocab_filename, merges_filename);
-
-        PyBPE::with_builder(builder, kwargs)
+    fn read_files(vocab_filename: &str, merges_filename: &str) -> PyResult<(Vocab, Merges)> {
+        BPE::read_files(vocab_filename, merges_filename).map_err(|e| {
+            exceptions::PyValueError::new_err(format!(
+                "Error while reading vocab&merges files: {}",
+                e
+            ))
+        })
     }
 }
 
@@ -292,10 +289,10 @@ impl PyWordPiece {
     }
 
     #[staticmethod]
-    fn from_file(vocab: String, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
-        let mut builder = WordPiece::builder();
-        builder = builder.files(vocab);
-        PyWordPiece::with_builder(builder, kwargs)
+    fn read_file(vocab_filename: &str) -> PyResult<Vocab> {
+        WordPiece::read_files(vocab_filename).map_err(|e| {
+            exceptions::PyValueError::new_err(format!("Error while reading WordPiece file: {}", e))
+        })
     }
 }
 
@@ -356,15 +353,10 @@ impl PyWordLevel {
     }
 
     #[staticmethod]
-    fn from_file(vocab_filename: &str, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
-        let unk_token = PyWordLevel::get_unk(kwargs)?;
-        let model = WordLevel::from_files(vocab_filename, unk_token).map_err(|e| {
-            exceptions::PyException::new_err(format!(
-                "Error while loading WordLevel from file: {}",
-                e
-            ))
-        })?;
-        Ok((PyWordLevel {}, PyModel::new(Arc::new(model.into()))))
+    fn read_file(vocab_filename: &str) -> PyResult<Vocab> {
+        WordLevel::read_files(vocab_filename).map_err(|e| {
+            exceptions::PyValueError::new_err(format!("Error while reading WordLevel file: {}", e))
+        })
     }
 }
 
diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 13df4c05..8ef6bade 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -12,9 +12,9 @@ use std::{
     path::{Path, PathBuf},
 };
 
-type Vocab = HashMap<String, u32>;
+pub type Vocab = HashMap<String, u32>;
 type VocabR = HashMap<u32, String>;
-type Merges = HashMap<Pair, (u32, u32)>;
+pub type Merges = HashMap<Pair, (u32, u32)>;
 
 struct Config {
     files: Option<(String, String)>,
diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs
index c7ba24f2..c0a979e4 100644
--- a/tokenizers/src/models/wordlevel/mod.rs
+++ b/tokenizers/src/models/wordlevel/mod.rs
@@ -9,6 +9,8 @@ use std::path::{Path, PathBuf};
 
 mod serialization;
 
+type Vocab = HashMap<String, u32>;
+
 #[derive(Debug)]
 pub enum Error {
     MissingUnkToken,
@@ -105,9 +107,7 @@ impl WordLevel {
         WordLevelBuilder::new()
     }
 
-    /// Initialize a WordLevel model from vocab and merges file.
-    pub fn from_files(vocab_path: &str, unk_token: String) -> Result<Self> {
-        // Read vocab.json
+    pub fn read_files(vocab_path: &str) -> Result<Vocab> {
         let vocab_file = File::open(vocab_path)?;
         let mut vocab_file = BufReader::new(vocab_file);
         let mut buffer = String::new();
@@ -127,7 +127,12 @@ impl WordLevel {
             }
             _ => return Err(Box::new(Error::BadVocabulary)),
         };
+        Ok(vocab)
+    }
 
+    /// Initialize a WordLevel model from vocab and merges file.
+    pub fn from_files(vocab_path: &str, unk_token: String) -> Result<Self> {
+        let vocab = WordLevel::read_files(vocab_path)?;
         Ok(Self::builder().vocab(vocab).unk_token(unk_token).build())
     }
 }
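
For reference, a minimal usage sketch of the Python API surface this diff introduces. The vocab.json / merges.txt / vocab.txt paths are placeholders for previously trained vocabulary files, not files touched by this change:

    from tokenizers import BertWordPieceTokenizer, ByteLevelBPETokenizer
    from tokenizers.models import BPE

    # Convenience helpers added to the implementation classes: read the files,
    # then build the tokenizer from the in-memory vocab/merges.
    bpe_tokenizer = ByteLevelBPETokenizer.from_files("vocab.json", "merges.txt")
    bert_tokenizer = BertWordPieceTokenizer.from_file("vocab.txt")

    # The same two steps done explicitly: read_files now returns the raw
    # vocab/merges objects instead of a constructed model, so they can be
    # inspected or tweaked before instantiating the BPE model.
    vocab, merges = BPE.read_files("vocab.json", "merges.txt")
    model = BPE(vocab, merges)

The apparent intent of the split is to keep file parsing in one place on the Rust side (read_files/read_file), while the from_file(s) helpers become plain Python static methods that feed the parsed objects straight into the model constructors.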