Make parameter name on Model.save() optional.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
Morgan Funtowicz
2020-02-22 00:01:32 +01:00
parent 11dd6c8bae
commit f88a6b40ac
7 changed files with 48 additions and 13 deletions

View File

@@ -23,7 +23,22 @@ impl Model {
)) ))
} }
fn save(&self, folder: &str, name: &str) -> PyResult<Vec<String>> { #[args(kwargs = "**")]
fn save(&self, folder: &str, kwargs: Option<&PyDict>) -> PyResult<Vec<String>> {
let mut name: Option<&str> = None;
if let Some(kwargs) = kwargs {
for (key, value) in kwargs {
let key: &str = key.extract()?;
match key {
"name" => {
name = value.extract()?
}
_ => println!("Ignored unknown kwarg option {}", key),
}
}
}
let saved: PyResult<Vec<_>> = ToPyResult( let saved: PyResult<Vec<_>> = ToPyResult(
self.model self.model
.execute(|model| model.save(Path::new(folder), name)), .execute(|model| model.save(Path::new(folder), name)),

View File

@@ -209,14 +209,14 @@ class BaseTokenizer:
""" """
return self._tokenizer.id_to_token(id) return self._tokenizer.id_to_token(id)
def save(self, directory: str, name: str): def save(self, directory: str, name: str = None):
""" Save the current model to the given directory """ Save the current model to the given directory
Args: Args:
directory: str: directory: str:
A path to the destination directory A path to the destination directory
name: str: name: (Optional) str:
The name of the tokenizer, to be used in the saved files The name of the tokenizer, to be used in the saved files
""" """
return self._tokenizer.model.save(directory, name) return self._tokenizer.model.save(directory, name=name)

View File

@@ -7,7 +7,7 @@ class Model:
a Model will return a instance of this class when instantiated. a Model will return a instance of this class when instantiated.
""" """
def save(self, folder: str, name: str) -> List[str]: def save(self, folder: str, name: Optional[str] = None) -> List[str]:
""" Save the current model """ Save the current model
Save the current model in the given folder, using the given name for the various Save the current model in the given folder, using the given name for the various

View File

@@ -418,9 +418,14 @@ impl Model for BPE {
self.vocab_r.get(&id).cloned() self.vocab_r.get(&id).cloned()
} }
fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> { fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
let vocab_file_name = match name {
Some(name) => format!("{}-vocab.json", name).to_string(),
None => "vocab.json".to_string()
};
// Write vocab.json // Write vocab.json
let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.json", name))] let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
.iter() .iter()
.collect(); .collect();
let mut vocab_file = File::create(&vocab_path)?; let mut vocab_file = File::create(&vocab_path)?;
@@ -429,7 +434,12 @@ impl Model for BPE {
vocab_file.write_all(&serialized.as_bytes())?; vocab_file.write_all(&serialized.as_bytes())?;
// Write merges.txt // Write merges.txt
let merges_path: PathBuf = [folder, Path::new(&format!("{}-merges.txt", name))] let merges_file_name = match name {
Some(name) => format!("{}-merges.txt", name).to_string(),
None => "merges.txt".to_string()
};
let merges_path: PathBuf = [folder, Path::new(merges_file_name.as_str())]
.iter() .iter()
.collect(); .collect();
let mut merges_file = File::create(&merges_path)?; let mut merges_file = File::create(&merges_path)?;

View File

@@ -162,9 +162,14 @@ impl Model for WordLevel {
self.vocab.keys().len() self.vocab.keys().len()
} }
fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> { fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
let vocab_file_name = match name {
Some(name) => format!("{}-vocab.json", name).to_string(),
None => "vocab.json".to_string()
};
// Write vocab.txt // Write vocab.txt
let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.txt", name))] let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
.iter() .iter()
.collect(); .collect();
let mut vocab_file = File::create(&vocab_path)?; let mut vocab_file = File::create(&vocab_path)?;

View File

@@ -250,9 +250,14 @@ impl Model for WordPiece {
self.vocab_r.get(&id).cloned() self.vocab_r.get(&id).cloned()
} }
fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> { fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
let vocab_file_name = match name {
Some(name) => format!("{}-vocab.json", name).to_string(),
None => "vocab.json".to_string()
};
// Write vocab.txt // Write vocab.txt
let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.txt", name))] let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
.iter() .iter()
.collect(); .collect();
let mut vocab_file = File::create(&vocab_path)?; let mut vocab_file = File::create(&vocab_path)?;

View File

@@ -42,7 +42,7 @@ pub trait Model {
fn token_to_id(&self, token: &str) -> Option<u32>; fn token_to_id(&self, token: &str) -> Option<u32>;
fn id_to_token(&self, id: u32) -> Option<String>; fn id_to_token(&self, id: u32) -> Option<String>;
fn get_vocab_size(&self) -> usize; fn get_vocab_size(&self) -> usize;
fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>>; fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>>;
} }
/// A `PostProcessor` has the responsibility to post process an encoded output of the `Tokenizer`. /// A `PostProcessor` has the responsibility to post process an encoded output of the `Tokenizer`.