diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 096f0e14..b2d69f38 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -605,7 +605,7 @@ impl Tokenizer {
         }
     }
 
-    #[args(pretty = true)]
+    #[args(pretty = false)]
     fn save(&self, path: &str, pretty: bool) -> PyResult<()> {
         ToPyResult(self.tokenizer.save(path, pretty)).into()
     }
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index dc831d4a..9b901d4e 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -31,7 +31,8 @@ mod serialization;
 pub use encoding::*;
 pub use normalizer::*;
 
-pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
+pub type Error = Box<dyn std::error::Error + Send + Sync>;
+pub type Result<T> = std::result::Result<T, Error>;
 pub type Offsets = (usize, usize);
 
 #[typetag::serde(tag = "type")]
@@ -308,8 +309,16 @@ pub struct Tokenizer {
     padding: Option<PaddingParams>,
 }
 
+impl std::str::FromStr for Tokenizer {
+    type Err = Error;
+
+    fn from_str(s: &str) -> Result<Self> {
+        Ok(serde_json::from_str(s)?)
+    }
+}
+
 impl Tokenizer {
-    /// Instanciate a new Tokenizer, with the given Model
+    /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: Box<dyn Model>) -> Self {
         Tokenizer {
             normalizer: None,
@@ -330,6 +339,32 @@ impl Tokenizer {
         }
     }
 
+    /// Instantiate a new Tokenizer from the given file
+    pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
+        let file = File::open(file)?;
+        let buf = BufReader::new(file);
+        Ok(serde_json::from_reader(buf)?)
+    }
+
+    /// Serialize the current tokenizer as a String
+    pub fn to_string(&self, pretty: bool) -> Result<String> {
+        Ok(if pretty {
+            serde_json::to_string_pretty(self)?
+        } else {
+            serde_json::to_string(self)?
+        })
+    }
+
+    /// Save the current tokenizer at the given path
+    pub fn save(&self, path: &str, pretty: bool) -> Result<()> {
+        let serialized = self.to_string(pretty)?;
+
+        let mut file = File::create(path)?;
+        file.write_all(&serialized.as_bytes())?;
+
+        Ok(())
+    }
+
     /// Set the normalizer
     pub fn with_normalizer(&mut self, normalizer: Box<dyn Normalizer>) -> &Self {
         self.normalizer = Some(normalizer);
@@ -987,23 +1022,4 @@ impl Tokenizer {
             .collect()
         }
     }
-
-    /// Serialize the current tokenizer as a String
-    pub fn to_string(&self, pretty: bool) -> Result<String> {
-        Ok(if pretty {
-            serde_json::to_string_pretty(self)?
-        } else {
-            serde_json::to_string(self)?
-        })
-    }
-
-    /// Save the current tokenizer at the given path
-    pub fn save(&self, path: &str, pretty: bool) -> Result<()> {
-        let serialized = self.to_string(pretty)?;
-
-        let mut file = File::create(path)?;
-        file.write_all(&serialized.as_bytes())?;
-
-        Ok(())
-    }
 }
diff --git a/tokenizers/src/tokenizer/serialization.rs b/tokenizers/src/tokenizer/serialization.rs
index 25bb41aa..75be73af 100644
--- a/tokenizers/src/tokenizer/serialization.rs
+++ b/tokenizers/src/tokenizer/serialization.rs
@@ -36,7 +36,7 @@ impl Serialize for Tokenizer {
         tokenizer.serialize_field("padding", &self.padding)?;
 
         // Added tokens
-        let added_tokens = self
+        let mut added_tokens = self
             .added_tokens_map_r
             .iter()
             .map(|(id, token)| AddedTokenWithId {
@@ -45,6 +45,8 @@ impl Serialize for Tokenizer {
                 token: token.clone(),
             })
             .collect::<Vec<_>>();
+        // We need to have these added tokens ordered by ascending ID
+        added_tokens.sort_unstable_by_key(|o| o.id);
         tokenizer.serialize_field("added_tokens", &added_tokens)?;
 
         // Then add our parts
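
For reference, here is a minimal usage sketch of the serialization API this diff introduces; the file paths and the `main` wrapper are illustrative, not part of the change. `Tokenizer::from_file` deserializes from disk, the `FromStr` impl makes `str::parse` work on serialized strings, and `to_string`/`save` write JSON back out. Note that the Python binding's `save` now defaults to `pretty = false`.

```rust
use tokenizers::tokenizer::{Result, Tokenizer};

fn main() -> Result<()> {
    // Deserialize a tokenizer from a file (wraps serde_json::from_reader).
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Serialize to a compact JSON string; pass `true` for pretty output.
    let json = tokenizer.to_string(false)?;

    // The FromStr impl means `parse` round-trips the string form.
    let restored: Tokenizer = json.parse()?;

    // Persist to disk; `added_tokens` are now serialized in ascending-ID
    // order, so saved files are stable across runs.
    restored.save("tokenizer-copy.json", true)?;

    Ok(())
}
```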