Rust - serialization fixes + loading/saving methods

Anthony MOI
2020-05-19 19:01:48 -04:00
parent 5d4dfc2340
commit cffcbb95fc
3 changed files with 41 additions and 23 deletions

View File

@@ -605,7 +605,7 @@ impl Tokenizer {
         }
     }
 
-    #[args(pretty = true)]
+    #[args(pretty = false)]
     fn save(&self, path: &str, pretty: bool) -> PyResult<()> {
         ToPyResult(self.tokenizer.save(path, pretty)).into()
     }

View File

@@ -31,7 +31,8 @@ mod serialization;
 pub use encoding::*;
 pub use normalizer::*;
 
-pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
+pub type Error = Box<dyn std::error::Error + Send + Sync>;
+pub type Result<T> = std::result::Result<T, Error>;
 pub type Offsets = (usize, usize);
 
 #[typetag::serde(tag = "type")]
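Extracting the boxed trait object into a named `Error` alias keeps `Result<T>` readable and gives downstream code one name to refer to. A self-contained sketch of why this alias composes well with `?` (the `open_config` helper is illustrative):

// Any error type that is Send + Sync can be boxed into this alias,
// so `?` converts io, serde, and custom errors alike via From.
pub type Error = Box<dyn std::error::Error + Send + Sync>;
pub type Result<T> = std::result::Result<T, Error>;

fn open_config(path: &str) -> Result<std::fs::File> {
    // std::io::Error is boxed automatically by the `?` conversion.
    Ok(std::fs::File::open(path)?)
}

fn main() {
    if let Err(e) = open_config("missing.json") {
        eprintln!("error: {}", e);
    }
}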
@@ -308,8 +309,16 @@ pub struct Tokenizer {
     padding: Option<PaddingParams>,
 }
 
+impl std::str::FromStr for Tokenizer {
+    type Err = Error;
+
+    fn from_str(s: &str) -> Result<Self> {
+        Ok(serde_json::from_str(s)?)
+    }
+}
+
 impl Tokenizer {
-    /// Instanciate a new Tokenizer, with the given Model
+    /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: Box<dyn Model>) -> Self {
         Tokenizer {
             normalizer: None,
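The new `FromStr` impl hooks deserialization into the standard library, so a JSON string can be turned into a `Tokenizer` with plain `str::parse`. The same pattern on a toy serde struct (the `Config` type is illustrative, not the crate's):

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Config {
    version: u32,
}

impl std::str::FromStr for Config {
    type Err = Box<dyn std::error::Error + Send + Sync>;

    // Delegate to serde_json, exactly as the impl above does.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(serde_json::from_str(s)?)
    }
}

fn main() {
    // FromStr is what makes `.parse()` available on string slices.
    let config: Config = r#"{"version": 1}"#.parse().unwrap();
    println!("{:?}", config);
}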
@@ -330,6 +339,32 @@ impl Tokenizer {
         }
     }
 
+    /// Instantiate a new Tokenizer from the given file
+    pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
+        let file = File::open(file)?;
+        let buf = BufReader::new(file);
+        Ok(serde_json::from_reader(buf)?)
+    }
+
+    /// Serialize the current tokenizer as a String
+    pub fn to_string(&self, pretty: bool) -> Result<String> {
+        Ok(if pretty {
+            serde_json::to_string_pretty(self)?
+        } else {
+            serde_json::to_string(self)?
+        })
+    }
+
+    /// Save the current tokenizer at the given path
+    pub fn save(&self, path: &str, pretty: bool) -> Result<()> {
+        let serialized = self.to_string(pretty)?;
+
+        let mut file = File::create(path)?;
+        file.write_all(&serialized.as_bytes())?;
+
+        Ok(())
+    }
+
     /// Set the normalizer
     pub fn with_normalizer(&mut self, normalizer: Box<dyn Normalizer>) -> &Self {
         self.normalizer = Some(normalizer);
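Together, `save` and `from_file` give a full JSON round trip through the filesystem. A self-contained sketch of that round trip on a toy serde type (the `Model` struct and `model.json` file name are illustrative):

use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, Write};

type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Model {
    vocab_size: usize,
}

// Mirrors the save method above: serialize to a String, then write the bytes.
fn save(model: &Model, path: &str, pretty: bool) -> Result<()> {
    let serialized = if pretty {
        serde_json::to_string_pretty(model)?
    } else {
        serde_json::to_string(model)?
    };
    let mut file = File::create(path)?;
    file.write_all(serialized.as_bytes())?;
    Ok(())
}

// Mirrors from_file above: buffered read straight into serde_json.
fn from_file(path: &str) -> Result<Model> {
    let file = File::open(path)?;
    let buf = BufReader::new(file);
    Ok(serde_json::from_reader(buf)?)
}

fn main() -> Result<()> {
    let model = Model { vocab_size: 30000 };
    save(&model, "model.json", true)?;
    assert_eq!(from_file("model.json")?, model);
    Ok(())
}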
@@ -987,23 +1022,4 @@ impl Tokenizer {
             .collect()
         }
     }
-
-    /// Serialize the current tokenizer as a String
-    pub fn to_string(&self, pretty: bool) -> Result<String> {
-        Ok(if pretty {
-            serde_json::to_string_pretty(self)?
-        } else {
-            serde_json::to_string(self)?
-        })
-    }
-
-    /// Save the current tokenizer at the given path
-    pub fn save(&self, path: &str, pretty: bool) -> Result<()> {
-        let serialized = self.to_string(pretty)?;
-
-        let mut file = File::create(path)?;
-        file.write_all(&serialized.as_bytes())?;
-
-        Ok(())
-    }
 }

View File

@@ -36,7 +36,7 @@ impl Serialize for Tokenizer {
tokenizer.serialize_field("padding", &self.padding)?;
// Added tokens
let added_tokens = self
let mut added_tokens = self
.added_tokens_map_r
.iter()
.map(|(id, token)| AddedTokenWithId {
@@ -45,6 +45,8 @@ impl Serialize for Tokenizer {
                 token: token.clone(),
             })
             .collect::<Vec<_>>();
+        // We need to have these added tokens ordered by ascending ID
+        added_tokens.sort_unstable_by_key(|o| o.id);
 
         tokenizer.serialize_field("added_tokens", &added_tokens)?;
 
         // Then add our parts
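Sorting by ID before serializing makes the `added_tokens` field deterministic, since `HashMap` iteration order is arbitrary. A minimal sketch of the same fix (toy types; the real `AddedTokenWithId` carries more fields):

use serde::Serialize;
use std::collections::HashMap;

#[derive(Serialize)]
struct AddedToken {
    id: u32,
    content: String,
}

fn main() {
    let mut map = HashMap::new();
    map.insert(2, "world".to_string());
    map.insert(1, "hello".to_string());

    // HashMap iteration order is arbitrary, so collect into a Vec and
    // sort by ID to guarantee stable, reproducible JSON output.
    let mut added_tokens: Vec<AddedToken> = map
        .into_iter()
        .map(|(id, content)| AddedToken { id, content })
        .collect();
    added_tokens.sort_unstable_by_key(|t| t.id);

    println!("{}", serde_json::to_string(&added_tokens).unwrap());
}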