Fixing deserialization/serialization.

This commit is contained in:
Nicolas Patry
2020-08-21 11:08:02 +02:00
parent d12439be61
commit c7a84c7cc6
3 changed files with 81 additions and 13 deletions

View File

@@ -1,7 +1,6 @@
use crate::models::unigram::lattice::Lattice;
use crate::models::unigram::trie::{Trie, TrieBuilder};
use crate::tokenizer::{Model, Result, Token};
use serde::Deserialize;
use std::collections::HashMap;
use std::convert::TryInto;
@@ -13,18 +12,26 @@ type TokenMap = HashMap<String, u32>;
type Vocab = Vec<String>;
/// A `Unigram` model to encode sentences.
#[derive(Deserialize)]
#[derive(Clone)]
pub struct Unigram {
token_to_ids: TokenMap,
pub(crate) vocab: Vocab,
pub(super) scores: Vec<f64>,
#[serde(skip_deserializing, default = "empty_trie")]
trie: Trie<char>,
pub min_score: f64,
pub(super) unk_id: usize,
pub(super) bos_id: usize,
pub(super) eos_id: usize,
}
impl PartialEq for Unigram {
fn eq(&self, other: &Self) -> bool {
let vocab: Vec<(&String, &f64)> = self.vocab.iter().zip(self.scores.iter()).collect();
let other_vocab: Vec<(&String, &f64)> =
other.vocab.iter().zip(other.scores.iter()).collect();
self.unk_id == other.unk_id && vocab == other_vocab
}
}
impl std::fmt::Debug for Unigram {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("BPE")
@@ -33,10 +40,6 @@ impl std::fmt::Debug for Unigram {
}
}
fn empty_trie() -> Trie<char> {
TrieBuilder::default().build()
}
static K_UNK_PENALTY: f64 = 10.0;
impl Default for Unigram {

View File

@@ -1,16 +1,79 @@
use super::model::Unigram;
use serde::{ser::SerializeSeq, Serialize, Serializer};
use serde::{
de::{Error, MapAccess, Visitor},
ser::SerializeStruct,
Deserialize, Deserializer, Serialize, Serializer,
};
impl Serialize for Unigram {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = serializer.serialize_seq(Some(self.len()))?;
for i in 0..self.len() {
seq.serialize_element(&(&self.vocab[i], &self.scores[i]))?;
}
let mut model = serializer.serialize_struct("Unigram", 2)?;
seq.end()
model.serialize_field("unk_id", &self.unk_id)?;
let vocab: Vec<(&String, &f64)> = self.vocab.iter().zip(self.scores.iter()).collect();
model.serialize_field("vocab", &vocab)?;
model.end()
}
}
impl<'de> Deserialize<'de> for Unigram {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_struct("Unigram", &["vocab", "unk_id"], UnigramVisitor)
}
}
struct UnigramVisitor;
impl<'de> Visitor<'de> for UnigramVisitor {
type Value = Unigram;
fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(fmt, "struct Unigram")
}
fn visit_map<V>(self, mut map: V) -> std::result::Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let mut vocab: Option<Vec<(String, f64)>> = None;
let mut unk_id: Option<usize> = None;
while let Some(key) = map.next_key::<String>()? {
match key.as_ref() {
"unk_id" => {
unk_id = map.next_value()?;
}
"vocab" => vocab = Some(map.next_value()?),
_ => (),
}
}
match (vocab, unk_id) {
(Some(vocab), Some(unk_id)) => Ok(Unigram::from(&vocab, unk_id)),
(None, Some(_)) => Err(Error::custom("Missing vocab")),
(None, None) => Err(Error::custom("Missing vocab and unk_id")),
(Some(_), None) => Err(Error::custom("Missing unk_id")),
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_serialization() {
let vocab = vec![("<unk>".to_string(), 0.0), ("a".to_string(), -0.5)];
let model = Unigram::from(&vocab, 0);
let data = serde_json::to_string(&model).unwrap();
let reconstructed = serde_json::from_str(&data).unwrap();
assert_eq!(model, reconstructed);
}
}

View File

@@ -16,6 +16,7 @@ impl<Label: Eq + Hash + Copy> TrieBuilder<Label> {
}
}
#[derive(Clone)]
pub struct Trie<Label> {
root: Node<Label>,
}
@@ -57,6 +58,7 @@ impl<Label> Default for Trie<Label> {
}
}
#[derive(Clone)]
pub struct Node<Label> {
is_leaf: bool,
children: HashMap<Label, Node<Label>>,