Rust - serialization fixes + loading/saving methods
@@ -605,7 +605,7 @@ impl Tokenizer {
         }
     }
 
-    #[args(pretty = true)]
+    #[args(pretty = false)]
     fn save(&self, path: &str, pretty: bool) -> PyResult<()> {
         ToPyResult(self.tokenizer.save(path, pretty)).into()
     }

@@ -31,7 +31,8 @@ mod serialization;
 pub use encoding::*;
 pub use normalizer::*;
 
-pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
+pub type Error = Box<dyn std::error::Error + Send + Sync>;
+pub type Result<T> = std::result::Result<T, Error>;
 pub type Offsets = (usize, usize);
 
 #[typetag::serde(tag = "type")]

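Splitting the boxed error type into its own Error alias lets other items (such as the FromStr impl added below) name it directly. A minimal sketch of how the two aliases compose; the aliases are copied from the diff, while MissingVocab and load_vocab are hypothetical and not part of the crate:

    use std::fmt;

    type Error = Box<dyn std::error::Error + Send + Sync>;
    type Result<T> = std::result::Result<T, Error>;

    // Hypothetical error type: anything implementing std::error::Error + Send + Sync
    // can be boxed into the crate-wide Error alias.
    #[derive(Debug)]
    struct MissingVocab;

    impl fmt::Display for MissingVocab {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "vocabulary file is missing")
        }
    }

    impl std::error::Error for MissingVocab {}

    fn load_vocab(path: &str) -> Result<String> {
        if path.is_empty() {
            return Err(Box::new(MissingVocab));
        }
        // io::Error converts into Box<dyn Error + Send + Sync> automatically via `?`.
        Ok(std::fs::read_to_string(path)?)
    }
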
@@ -308,8 +309,16 @@ pub struct Tokenizer {
     padding: Option<PaddingParams>,
 }
 
+impl std::str::FromStr for Tokenizer {
+    type Err = Error;
+
+    fn from_str(s: &str) -> Result<Self> {
+        Ok(serde_json::from_str(s)?)
+    }
+}
+
 impl Tokenizer {
-    /// Instanciate a new Tokenizer, with the given Model
+    /// Instantiate a new Tokenizer, with the given Model
     pub fn new(model: Box<dyn Model>) -> Self {
         Tokenizer {
             normalizer: None,

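With FromStr implemented, a serialized tokenizer can be rebuilt straight from a JSON string via str::parse. A minimal sketch, assuming it lives inside the crate with Tokenizer and the Result alias in scope; the helper name is illustrative:

    use std::str::FromStr;

    // Hypothetical helper: parse a tokenizer out of a JSON string.
    fn tokenizer_from_json(json: &str) -> Result<Tokenizer> {
        // Goes through the FromStr impl above, i.e. serde_json::from_str.
        let tokenizer: Tokenizer = json.parse()?;

        // Equivalent explicit call:
        let _same = Tokenizer::from_str(json)?;

        Ok(tokenizer)
    }
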
@@ -330,6 +339,32 @@ impl Tokenizer {
         }
     }
 
+    /// Instantiate a new Tokenizer from the given file
+    pub fn from_file<P: AsRef<Path>>(file: P) -> Result<Self> {
+        let file = File::open(file)?;
+        let buf = BufReader::new(file);
+        Ok(serde_json::from_reader(buf)?)
+    }
+
+    /// Serialize the current tokenizer as a String
+    pub fn to_string(&self, pretty: bool) -> Result<String> {
+        Ok(if pretty {
+            serde_json::to_string_pretty(self)?
+        } else {
+            serde_json::to_string(self)?
+        })
+    }
+
+    /// Save the current tokenizer at the given path
+    pub fn save(&self, path: &str, pretty: bool) -> Result<()> {
+        let serialized = self.to_string(pretty)?;
+
+        let mut file = File::create(path)?;
+        file.write_all(&serialized.as_bytes())?;
+
+        Ok(())
+    }
+
     /// Set the normalizer
     pub fn with_normalizer(&mut self, normalizer: Box<dyn Normalizer>) -> &Self {
         self.normalizer = Some(normalizer);

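Together, save and from_file give a full round trip to disk. A small usage sketch built only on the methods added above; the file name is arbitrary, the helper is hypothetical, and it assumes the crate's Tokenizer and Result are in scope:

    // Hypothetical helper: write the tokenizer to disk and read it back.
    fn save_and_reload(tokenizer: &Tokenizer) -> Result<Tokenizer> {
        // pretty = false writes compact JSON; pass true for indented output.
        tokenizer.save("tokenizer.json", false)?;

        // from_file opens the file behind a BufReader and deserializes it.
        Tokenizer::from_file("tokenizer.json")
    }
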
@@ -987,23 +1022,4 @@ impl Tokenizer {
             .collect()
         }
     }
-
-    /// Serialize the current tokenizer as a String
-    pub fn to_string(&self, pretty: bool) -> Result<String> {
-        Ok(if pretty {
-            serde_json::to_string_pretty(self)?
-        } else {
-            serde_json::to_string(self)?
-        })
-    }
-
-    /// Save the current tokenizer at the given path
-    pub fn save(&self, path: &str, pretty: bool) -> Result<()> {
-        let serialized = self.to_string(pretty)?;
-
-        let mut file = File::create(path)?;
-        file.write_all(&serialized.as_bytes())?;
-
-        Ok(())
-    }
 }

@@ -36,7 +36,7 @@ impl Serialize for Tokenizer {
         tokenizer.serialize_field("padding", &self.padding)?;
 
         // Added tokens
-        let added_tokens = self
+        let mut added_tokens = self
             .added_tokens_map_r
             .iter()
             .map(|(id, token)| AddedTokenWithId {

@@ -45,6 +45,8 @@ impl Serialize for Tokenizer {
                 token: token.clone(),
             })
             .collect::<Vec<_>>();
+        // We need to have these added tokens ordered by ascending ID
+        added_tokens.sort_unstable_by_key(|o| o.id);
         tokenizer.serialize_field("added_tokens", &added_tokens)?;
 
         // Then add our parts
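
The serialization fix makes the added tokens come out in a stable order: they are collected from a map whose iteration order is arbitrary, so the entries are sorted by ascending id before being written. A self-contained illustration of the same pattern; Entry stands in for AddedTokenWithId and the token names are only examples:

    #[derive(Debug)]
    struct Entry {
        id: u32,
        token: String,
    }

    fn main() {
        // Collected from a HashMap, so the initial order is arbitrary.
        let mut added_tokens = vec![
            Entry { id: 2, token: "[SEP]".into() },
            Entry { id: 0, token: "[PAD]".into() },
            Entry { id: 1, token: "[CLS]".into() },
        ];

        // Same call as in the diff: order by ascending id for deterministic output.
        added_tokens.sort_unstable_by_key(|o| o.id);

        assert_eq!(added_tokens.first().map(|e| e.id), Some(0));
        println!("{:?}", added_tokens);
    }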