Make the `name` parameter of Model.save() optional.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
Morgan Funtowicz
2020-02-22 00:01:32 +01:00
parent 11dd6c8bae
commit f88a6b40ac
7 changed files with 48 additions and 13 deletions

View File

@@ -23,7 +23,22 @@ impl Model {
))
}
fn save(&self, folder: &str, name: &str) -> PyResult<Vec<String>> {
#[args(kwargs = "**")]
fn save(&self, folder: &str, kwargs: Option<&PyDict>) -> PyResult<Vec<String>> {
let mut name: Option<&str> = None;
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
match key {
"name" => {
name = value.extract()?
}
_ => println!("Ignored unknown kwarg option {}", key),
}
}
}
let saved: PyResult<Vec<_>> = ToPyResult(
self.model
.execute(|model| model.save(Path::new(folder), name)),

View File

@@ -209,14 +209,14 @@ class BaseTokenizer:
"""
return self._tokenizer.id_to_token(id)
def save(self, directory: str, name: str):
def save(self, directory: str, name: str = None):
""" Save the current model to the given directory
Args:
directory: str:
A path to the destination directory
name: str:
name: (Optional) str:
The name of the tokenizer, to be used in the saved files
"""
return self._tokenizer.model.save(directory, name)
return self._tokenizer.model.save(directory, name=name)

View File

@@ -7,7 +7,7 @@ class Model:
a Model will return an instance of this class when instantiated.
"""
def save(self, folder: str, name: str) -> List[str]:
def save(self, folder: str, name: Optional[str] = None) -> List[str]:
""" Save the current model
Save the current model in the given folder, using the given name for the various

View File

@@ -418,9 +418,14 @@ impl Model for BPE {
self.vocab_r.get(&id).cloned()
}
fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> {
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
let vocab_file_name = match name {
Some(name) => format!("{}-vocab.json", name).to_string(),
None => "vocab.json".to_string()
};
// Write vocab.json
let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.json", name))]
let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
.iter()
.collect();
let mut vocab_file = File::create(&vocab_path)?;
@@ -429,7 +434,12 @@ impl Model for BPE {
vocab_file.write_all(&serialized.as_bytes())?;
// Write merges.txt
let merges_path: PathBuf = [folder, Path::new(&format!("{}-merges.txt", name))]
let merges_file_name = match name {
Some(name) => format!("{}-merges.txt", name).to_string(),
None => "merges.txt".to_string()
};
let merges_path: PathBuf = [folder, Path::new(merges_file_name.as_str())]
.iter()
.collect();
let mut merges_file = File::create(&merges_path)?;

View File

@@ -162,9 +162,14 @@ impl Model for WordLevel {
self.vocab.keys().len()
}
fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> {
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
let vocab_file_name = match name {
Some(name) => format!("{}-vocab.json", name).to_string(),
None => "vocab.json".to_string()
};
// Write vocab.txt
let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.txt", name))]
let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
.iter()
.collect();
let mut vocab_file = File::create(&vocab_path)?;

View File

@@ -250,9 +250,14 @@ impl Model for WordPiece {
self.vocab_r.get(&id).cloned()
}
fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> {
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
let vocab_file_name = match name {
Some(name) => format!("{}-vocab.json", name).to_string(),
None => "vocab.json".to_string()
};
// Write vocab.txt
let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.txt", name))]
let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
.iter()
.collect();
let mut vocab_file = File::create(&vocab_path)?;

View File

@@ -42,7 +42,7 @@ pub trait Model {
fn token_to_id(&self, token: &str) -> Option<u32>;
fn id_to_token(&self, id: u32) -> Option<String>;
fn get_vocab_size(&self) -> usize;
fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>>;
fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>>;
}
/// A `PostProcessor` has the responsibility to post process an encoded output of the `Tokenizer`.