mirror of
https://github.com/mii443/tokenizers.git
synced 2025-12-16 17:18:43 +00:00
Fixing deserialization/serialization.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
use crate::models::unigram::lattice::Lattice;
|
||||
use crate::models::unigram::trie::{Trie, TrieBuilder};
|
||||
use crate::tokenizer::{Model, Result, Token};
|
||||
use serde::Deserialize;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryInto;
|
||||
@@ -13,18 +12,26 @@ type TokenMap = HashMap<String, u32>;
|
||||
type Vocab = Vec<String>;
|
||||
|
||||
/// A `Unigram` model to encode sentences.
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Clone)]
|
||||
pub struct Unigram {
|
||||
token_to_ids: TokenMap,
|
||||
pub(crate) vocab: Vocab,
|
||||
pub(super) scores: Vec<f64>,
|
||||
#[serde(skip_deserializing, default = "empty_trie")]
|
||||
trie: Trie<char>,
|
||||
pub min_score: f64,
|
||||
pub(super) unk_id: usize,
|
||||
pub(super) bos_id: usize,
|
||||
pub(super) eos_id: usize,
|
||||
}
|
||||
impl PartialEq for Unigram {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
let vocab: Vec<(&String, &f64)> = self.vocab.iter().zip(self.scores.iter()).collect();
|
||||
let other_vocab: Vec<(&String, &f64)> =
|
||||
other.vocab.iter().zip(other.scores.iter()).collect();
|
||||
self.unk_id == other.unk_id && vocab == other_vocab
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Unigram {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct("BPE")
|
||||
@@ -33,10 +40,6 @@ impl std::fmt::Debug for Unigram {
|
||||
}
|
||||
}
|
||||
|
||||
fn empty_trie() -> Trie<char> {
|
||||
TrieBuilder::default().build()
|
||||
}
|
||||
|
||||
static K_UNK_PENALTY: f64 = 10.0;
|
||||
|
||||
impl Default for Unigram {
|
||||
|
||||
@@ -1,16 +1,79 @@
|
||||
use super::model::Unigram;
|
||||
use serde::{ser::SerializeSeq, Serialize, Serializer};
|
||||
use serde::{
|
||||
de::{Error, MapAccess, Visitor},
|
||||
ser::SerializeStruct,
|
||||
Deserialize, Deserializer, Serialize, Serializer,
|
||||
};
|
||||
|
||||
impl Serialize for Unigram {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut seq = serializer.serialize_seq(Some(self.len()))?;
|
||||
for i in 0..self.len() {
|
||||
seq.serialize_element(&(&self.vocab[i], &self.scores[i]))?;
|
||||
}
|
||||
let mut model = serializer.serialize_struct("Unigram", 2)?;
|
||||
|
||||
seq.end()
|
||||
model.serialize_field("unk_id", &self.unk_id)?;
|
||||
|
||||
let vocab: Vec<(&String, &f64)> = self.vocab.iter().zip(self.scores.iter()).collect();
|
||||
model.serialize_field("vocab", &vocab)?;
|
||||
|
||||
model.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Unigram {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
deserializer.deserialize_struct("Unigram", &["vocab", "unk_id"], UnigramVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
struct UnigramVisitor;
|
||||
impl<'de> Visitor<'de> for UnigramVisitor {
|
||||
type Value = Unigram;
|
||||
|
||||
fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(fmt, "struct Unigram")
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut map: V) -> std::result::Result<Self::Value, V::Error>
|
||||
where
|
||||
V: MapAccess<'de>,
|
||||
{
|
||||
let mut vocab: Option<Vec<(String, f64)>> = None;
|
||||
let mut unk_id: Option<usize> = None;
|
||||
while let Some(key) = map.next_key::<String>()? {
|
||||
match key.as_ref() {
|
||||
"unk_id" => {
|
||||
unk_id = map.next_value()?;
|
||||
}
|
||||
"vocab" => vocab = Some(map.next_value()?),
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
match (vocab, unk_id) {
|
||||
(Some(vocab), Some(unk_id)) => Ok(Unigram::from(&vocab, unk_id)),
|
||||
(None, Some(_)) => Err(Error::custom("Missing vocab")),
|
||||
(None, None) => Err(Error::custom("Missing vocab and unk_id")),
|
||||
(Some(_), None) => Err(Error::custom("Missing unk_id")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_serialization() {
|
||||
let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
|
||||
let model = Unigram::from(&vocab, 0);
|
||||
|
||||
let data = serde_json::to_string(&model).unwrap();
|
||||
let reconstructed = serde_json::from_str(&data).unwrap();
|
||||
|
||||
assert_eq!(model, reconstructed);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ impl<Label: Eq + Hash + Copy> TrieBuilder<Label> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Trie<Label> {
|
||||
root: Node<Label>,
|
||||
}
|
||||
@@ -57,6 +58,7 @@ impl<Label> Default for Trie<Label> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Node<Label> {
|
||||
is_leaf: bool,
|
||||
children: HashMap<Label, Node<Label>>,
|
||||
|
||||
Reference in New Issue
Block a user