Files
akaza/libakaza/src/graph/graph_builder.rs
Tokuhiro Matsuno c5444759e5 refactor kana_kanji
2023-01-22 00:02:50 +09:00

238 lines
8.9 KiB
Rust

use std::collections::btree_map::BTreeMap;
use std::collections::HashSet;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use kelp::{hira2kata, ConvOption};
use log::trace;
use crate::graph::lattice_graph::LatticeGraph;
use crate::graph::segmenter::SegmentationResult;
use crate::graph::word_node::WordNode;
use crate::kana_kanji::base::KanaKanjiDict;
use crate::lm::base::{SystemBigramLM, SystemUnigramLM};
use crate::user_side_data::user_data::UserData;
/// Builds a [`LatticeGraph`] of conversion candidates for a kana reading.
///
/// The builder owns the dictionaries used to look up candidates and holds
/// shared handles to the language models and user data, which are handed on
/// to every graph it constructs.
pub struct GraphBuilder<U, B, KD>
where
    U: SystemUnigramLM,
    B: SystemBigramLM,
    KD: KanaKanjiDict,
{
    /// Dictionary queried with every segmented reading to obtain candidate surfaces.
    system_kana_kanji_dict: KD,
    /// Dictionary queried only when a segment is equal to the whole reading.
    system_single_term_dict: KD,
    /// Shared user data; cloned (handle only) into each constructed graph.
    user_data: Arc<Mutex<UserData>>,
    /// Unigram language model, looked up with `"<surface>/<yomi>"` keys.
    system_unigram_lm: Rc<U>,
    /// Bigram language model; not queried by the builder, only passed to the graph.
    system_bigram_lm: Rc<B>,
}
impl<U: SystemUnigramLM, B: SystemBigramLM, KD: KanaKanjiDict> GraphBuilder<U, B, KD> {
    /// Creates a builder from its dictionaries, user data and language models.
    ///
    /// * `system_kana_kanji_dict` - queried for every segmented reading.
    /// * `system_single_term_dict` - queried only for a segment that equals the whole reading.
    /// * `user_data` - shared handle that is passed on to each constructed graph.
    /// * `system_unigram_lm` / `system_bigram_lm` - shared language models.
    pub fn new(
        system_kana_kanji_dict: KD,
        system_single_term_dict: KD,
        user_data: Arc<Mutex<UserData>>,
        system_unigram_lm: Rc<U>,
        system_bigram_lm: Rc<B>,
    ) -> GraphBuilder<U, B, KD> {
        GraphBuilder {
            system_kana_kanji_dict,
            system_single_term_dict,
            user_data,
            system_unigram_lm,
            system_bigram_lm,
        }
    }

    /// Creates a builder; currently this simply forwards every argument to [`GraphBuilder::new`].
    pub fn new_with_default_score(
        system_kana_kanji_dict: KD,
        system_single_term_dict: KD,
        user_data: Arc<Mutex<UserData>>,
        system_unigram_lm: Rc<U>,
        system_bigram_lm: Rc<B>,
    ) -> GraphBuilder<U, B, KD> {
        Self::new(
            system_kana_kanji_dict,
            system_single_term_dict,
            user_data,
            system_unigram_lm,
            system_bigram_lm,
        )
    }

    /// Builds the candidate lattice for `yomi`.
    ///
    /// `words_ends_at` is iterated as `(end position, segmented readings ending there)`.
    /// For every segmented reading the following nodes are added under its end position:
    ///
    /// 1. every surface the kana-kanji dictionary returns for the reading,
    /// 2. the reading itself and its katakana form, unless the dictionary already
    ///    supplied that surface,
    /// 3. if the segment equals the whole `yomi`, every single-term dictionary
    ///    surface that has not already been added for this segment.
    ///
    /// A BOS node is stored at key `0` and an EOS node at key `yomi.len() + 1`.
    /// Node start positions are computed as `end position - reading.len()`, i.e. with
    /// byte lengths; this assumes a segment is never longer than its end position.
    pub fn construct(&self, yomi: &str, words_ends_at: &SegmentationResult) -> LatticeGraph<U, B> {
        // The graph is keyed by the position at which a word ends.
        let mut graph: BTreeMap<i32, Vec<WordNode>> = BTreeMap::new();
        graph.insert(0, vec![WordNode::create_bos()]);
        graph.insert(
            (yomi.len() + 1) as i32,
            vec![WordNode::create_eos(yomi.len() as i32)],
        );

        for (end_pos, segmented_yomis) in words_ends_at.iter() {
            for segmented_yomi in segmented_yomis {
                let vec = graph.entry(*end_pos as i32).or_default();
                // Surfaces already emitted for this segmented reading.
                let mut seen: HashSet<String> = HashSet::new();

                // Insert the kanji conversions, if the dictionary has any.
                if let Some(kanjis) = self.system_kana_kanji_dict.get(segmented_yomi) {
                    for kanji in kanjis {
                        let node = WordNode::new(
                            (end_pos - segmented_yomi.len()) as i32,
                            &kanji,
                            segmented_yomi,
                            self.system_unigram_lm
                                .find((kanji.to_string() + "/" + segmented_yomi).as_str()),
                        );
                        trace!("WordIDScore: {:?}", node.word_id_and_score);
                        vec.push(node);
                        seen.insert(kanji.to_string());
                    }
                }

                // Always offer the hiragana reading itself and its katakana form,
                // unless the dictionary already provided that surface.
                let katakana = hira2kata(segmented_yomi, ConvOption::default());
                for surface in [segmented_yomi, katakana.as_str()] {
                    if seen.contains(surface) {
                        continue;
                    }
                    let node = WordNode::new(
                        (end_pos - segmented_yomi.len()) as i32,
                        surface,
                        segmented_yomi,
                        None,
                    );
                    vec.push(node);
                }
                // Recorded only after the loop so that the kana candidates above are
                // emitted exactly as before; from here on they count as already present.
                seen.insert(segmented_yomi.to_string());
                seen.insert(katakana);

                // When the segment covers the whole reading, also consult the
                // single-term dictionary.
                if segmented_yomi == yomi {
                    if let Some(surfaces) = self.system_single_term_dict.get(yomi) {
                        for surface in surfaces {
                            // `insert` returns false when the surface was already emitted
                            // for this segment (by the kana-kanji dictionary, as a kana
                            // form, or earlier in this list); skip it so the lattice does
                            // not carry duplicate candidates.
                            if !seen.insert(surface.to_string()) {
                                continue;
                            }
                            let node = WordNode::new(
                                (end_pos - segmented_yomi.len()) as i32,
                                &surface,
                                segmented_yomi,
                                self.system_unigram_lm
                                    .find((surface.to_string() + "/" + segmented_yomi).as_str()),
                            );
                            vec.push(node);
                        }
                    }
                }
            }
        }

        LatticeGraph {
            graph,
            yomi: yomi.to_string(),
            user_data: self.user_data.clone(),
            system_unigram_lm: self.system_unigram_lm.clone(),
            system_bigram_lm: self.system_bigram_lm.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::kana_kanji::hashmap_vec::HashmapVecKanaKanjiDict;
    use crate::lm::system_bigram::MarisaSystemBigramLMBuilder;
    use crate::lm::system_unigram_lm::MarisaSystemUnigramLMBuilder;
    use std::collections::HashMap;

    use super::*;

    // A segment that covers the whole reading also gets the single-term dictionary entries.
    #[test]
    fn test_single_term() -> anyhow::Result<()> {
        let graph_builder = GraphBuilder::new_with_default_score(
            HashmapVecKanaKanjiDict::new(HashMap::new()),
            HashmapVecKanaKanjiDict::new(HashMap::from([(
                "すし".to_string(),
                vec!["🍣".to_string()],
            )])),
            Arc::new(Mutex::new(UserData::default())),
            Rc::new(
                MarisaSystemUnigramLMBuilder::default()
                    .set_default_cost(20_f32)
                    .set_default_cost_for_short(19_f32)
                    .build(),
            ),
            Rc::new(
                MarisaSystemBigramLMBuilder::default()
                    .set_default_edge_cost(20_f32)
                    .build()?,
            ),
        );
        let yomi = "すし";
        let got = graph_builder.construct(
            yomi,
            &SegmentationResult::new(BTreeMap::from([(6, vec!["すし".to_string()])])),
        );
        let nodes = got.node_list(6).unwrap();
        let got_surfaces: Vec<String> = nodes.iter().map(|f| f.surface.to_string()).collect();
        assert_eq!(
            got_surfaces,
            vec!["すし".to_string(), "スシ".to_string(), "🍣".to_string()]
        );
        Ok(())
    }

    // The hiragana and katakana entries are added automatically.
    // The reading is a single hiragana character (3 bytes in UTF-8), matching end position 3.
    #[test]
    fn test_default_terms() -> anyhow::Result<()> {
        let graph_builder = GraphBuilder::new_with_default_score(
            HashmapVecKanaKanjiDict::new(HashMap::new()),
            HashmapVecKanaKanjiDict::new(HashMap::new()),
            Arc::new(Mutex::new(UserData::default())),
            Rc::new(
                MarisaSystemUnigramLMBuilder::default()
                    .set_default_cost(20_f32)
                    .set_default_cost_for_short(19_f32)
                    .build(),
            ),
            Rc::new(
                MarisaSystemBigramLMBuilder::default()
                    .set_default_edge_cost(20_f32)
                    .build()?,
            ),
        );
        let yomi = "す";
        let got = graph_builder.construct(
            yomi,
            &SegmentationResult::new(BTreeMap::from([(3, vec!["す".to_string()])])),
        );
        let nodes = got.node_list(3).unwrap();
        let got_surfaces: Vec<String> = nodes.iter().map(|f| f.surface.to_string()).collect();
        assert_eq!(got_surfaces, vec!["す".to_string(), "ス".to_string()]);
        Ok(())
    }

    // Even when the kana-kanji dictionary already provides the hiragana and katakana
    // surfaces, they must not be added a second time.
    #[test]
    fn test_default_terms_duplicated() -> anyhow::Result<()> {
        let graph_builder = GraphBuilder::new_with_default_score(
            HashmapVecKanaKanjiDict::new(HashMap::from([(
                "す".to_string(),
                vec!["す".to_string(), "ス".to_string()],
            )])),
            HashmapVecKanaKanjiDict::new(HashMap::new()),
            Arc::new(Mutex::new(UserData::default())),
            Rc::new(
                MarisaSystemUnigramLMBuilder::default()
                    .set_default_cost(20_f32)
                    .set_default_cost_for_short(19_f32)
                    .build(),
            ),
            Rc::new(
                MarisaSystemBigramLMBuilder::default()
                    .set_default_edge_cost(20_f32)
                    .build()?,
            ),
        );
        let yomi = "す";
        let got = graph_builder.construct(
            yomi,
            &SegmentationResult::new(BTreeMap::from([(3, vec!["す".to_string()])])),
        );
        let nodes = got.node_list(3).unwrap();
        let got_surfaces: Vec<String> = nodes.iter().map(|f| f.surface.to_string()).collect();
        assert_eq!(got_surfaces, vec!["す".to_string(), "ス".to_string()]);
        Ok(())
    }
}