From f88a6b40ac56181ee3fd22141e3a21167b1571ce Mon Sep 17 00:00:00 2001
From: Morgan Funtowicz
Date: Sat, 22 Feb 2020 00:01:32 +0100
Subject: [PATCH] Make parameter name on Model.save() optional.

Signed-off-by: Morgan Funtowicz
---
 bindings/python/src/models.rs                  | 17 ++++++++++++++++-
 .../implementations/base_tokenizer.py          |  6 +++---
 bindings/python/tokenizers/models/__init__.pyi |  2 +-
 tokenizers/src/models/bpe/model.rs             | 16 +++++++++++++---
 tokenizers/src/models/wordlevel/mod.rs         |  9 +++++++--
 tokenizers/src/models/wordpiece/mod.rs         |  9 +++++++--
 tokenizers/src/tokenizer/mod.rs                |  2 +-
 7 files changed, 48 insertions(+), 13 deletions(-)

diff --git a/bindings/python/src/models.rs b/bindings/python/src/models.rs
index 0fa61635..e87267f0 100644
--- a/bindings/python/src/models.rs
+++ b/bindings/python/src/models.rs
@@ -23,7 +23,22 @@ impl Model {
         ))
     }
 
-    fn save(&self, folder: &str, name: &str) -> PyResult<Vec<String>> {
+    #[args(kwargs = "**")]
+    fn save(&self, folder: &str, kwargs: Option<&PyDict>) -> PyResult<Vec<String>> {
+        let mut name: Option<&str> = None;
+        if let Some(kwargs) = kwargs {
+            for (key, value) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "name" => {
+                        name = value.extract()?
+                    }
+                    _ => println!("Ignored unknown kwarg option {}", key),
+                }
+            }
+        }
+
+
         let saved: PyResult<Vec<_>> = ToPyResult(
             self.model
                 .execute(|model| model.save(Path::new(folder), name)),
diff --git a/bindings/python/tokenizers/implementations/base_tokenizer.py b/bindings/python/tokenizers/implementations/base_tokenizer.py
index 384a6904..1e368f2e 100644
--- a/bindings/python/tokenizers/implementations/base_tokenizer.py
+++ b/bindings/python/tokenizers/implementations/base_tokenizer.py
@@ -209,14 +209,14 @@ class BaseTokenizer:
         """
         return self._tokenizer.id_to_token(id)
 
-    def save(self, directory: str, name: str):
+    def save(self, directory: str, name: str = None):
         """ Save the current model to the given directory
 
         Args:
             directory: str:
                 A path to the destination directory
 
-            name: str:
+            name: (Optional) str:
                 The name of the tokenizer, to be used in the saved files
         """
-        return self._tokenizer.model.save(directory, name)
+        return self._tokenizer.model.save(directory, name=name)
diff --git a/bindings/python/tokenizers/models/__init__.pyi b/bindings/python/tokenizers/models/__init__.pyi
index d7e85112..0fd36487 100644
--- a/bindings/python/tokenizers/models/__init__.pyi
+++ b/bindings/python/tokenizers/models/__init__.pyi
@@ -7,7 +7,7 @@ class Model:
     a Model will return a instance of this class when instantiated.
     """
 
-    def save(self, folder: str, name: str) -> List[str]:
+    def save(self, folder: str, name: Optional[str] = None) -> List[str]:
         """ Save the current model
 
        Save the current model in the given folder, using the given name for the various
diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 36613e64..b4d58cc2 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -418,9 +418,14 @@ impl Model for BPE {
         self.vocab_r.get(&id).cloned()
     }
 
-    fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> {
+    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
+        let vocab_file_name = match name {
+            Some(name) => format!("{}-vocab.json", name).to_string(),
+            None => "vocab.json".to_string()
+        };
+
         // Write vocab.json
-        let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.json", name))]
+        let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
             .iter()
             .collect();
         let mut vocab_file = File::create(&vocab_path)?;
@@ -429,7 +434,12 @@ impl Model for BPE {
         vocab_file.write_all(&serialized.as_bytes())?;
 
         // Write merges.txt
-        let merges_path: PathBuf = [folder, Path::new(&format!("{}-merges.txt", name))]
+        let merges_file_name = match name {
+            Some(name) => format!("{}-merges.txt", name).to_string(),
+            None => "merges.txt".to_string()
+        };
+
+        let merges_path: PathBuf = [folder, Path::new(merges_file_name.as_str())]
             .iter()
             .collect();
         let mut merges_file = File::create(&merges_path)?;
diff --git a/tokenizers/src/models/wordlevel/mod.rs b/tokenizers/src/models/wordlevel/mod.rs
index 1743c43b..87f8a39b 100644
--- a/tokenizers/src/models/wordlevel/mod.rs
+++ b/tokenizers/src/models/wordlevel/mod.rs
@@ -162,9 +162,14 @@ impl Model for WordLevel {
         self.vocab.keys().len()
     }
 
-    fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> {
+    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
+        let vocab_file_name = match name {
+            Some(name) => format!("{}-vocab.json", name).to_string(),
+            None => "vocab.json".to_string()
+        };
+
         // Write vocab.txt
-        let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.txt", name))]
+        let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
             .iter()
             .collect();
         let mut vocab_file = File::create(&vocab_path)?;
diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs
index 935749bc..2ee5961b 100644
--- a/tokenizers/src/models/wordpiece/mod.rs
+++ b/tokenizers/src/models/wordpiece/mod.rs
@@ -250,9 +250,14 @@ impl Model for WordPiece {
         self.vocab_r.get(&id).cloned()
     }
 
-    fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> {
+    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
+        let vocab_file_name = match name {
+            Some(name) => format!("{}-vocab.json", name).to_string(),
+            None => "vocab.json".to_string()
+        };
+
         // Write vocab.txt
-        let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.txt", name))]
+        let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
             .iter()
             .collect();
         let mut vocab_file = File::create(&vocab_path)?;
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 00a53a16..6a4f21a6 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -42,7 +42,7 @@ pub trait Model {
     fn token_to_id(&self, token: &str) -> Option<u32>;
     fn id_to_token(&self, id: u32) -> Option<String>;
     fn get_vocab_size(&self) -> usize;
-    fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>>;
+    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>>;
 }
 
 /// A `PostProcessor` has the responsibility to post process an encoded output of the `Tokenizer`.
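
Usage note (not part of the patch): a minimal sketch of how the changed Python binding behaves once `name` is optional. The "./out" directory and the "my-bpe" prefix are illustrative only, and the empty-model constructor `BPE.empty()` is assumed from the bindings of that era; the target directory must already exist, since the Rust side calls File::create inside it.

    from tokenizers import Tokenizer
    from tokenizers.models import BPE

    # An empty BPE model is enough to exercise Model.save().
    tokenizer = Tokenizer(BPE.empty())

    # Before this patch, a file-name prefix was mandatory:
    #   tokenizer.model.save("./out", "my-bpe")  # -> ./out/my-bpe-vocab.json, ./out/my-bpe-merges.txt
    # With this patch, the prefix becomes an optional keyword argument:
    files = tokenizer.model.save("./out")                 # -> ./out/vocab.json, ./out/merges.txt
    files = tokenizer.model.save("./out", name="my-bpe")  # -> ./out/my-bpe-vocab.json, ./out/my-bpe-merges.txt

    # save() returns the list of written file paths as strings.
    print(files)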