refactoring for structured perceptron

This commit is contained in:
Tokuhiro Matsuno
2023-01-07 09:34:17 +09:00
parent 649e018e73
commit d5a9f445ea
8 changed files with 200 additions and 23 deletions

View File

@ -1,7 +1,7 @@
use crate::subcmd::check::check;
use crate::subcmd::evaluate::evaluate;
use clap::{Parser, Subcommand};
use crate::subcmd::check::check;
use crate::subcmd::evaluate::evaluate;
use crate::subcmd::make_system_dict::make_system_dict;
use crate::subcmd::make_system_lm::make_system_lm;
use crate::subcmd::structured_perceptron::learn_structured_perceptron;
@ -34,7 +34,6 @@ enum Commands {
Evaluate(EvaluateArgs),
#[clap(arg_required_else_help = true)]
Check(CheckArgs),
#[clap(arg_required_else_help = true)]
LearnStructuredPerceptron(LearnStructuredPerceptronArgs),
}

View File

@ -1,16 +1,144 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::ops::Range;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use libakaza::akaza_builder::AkazaBuilder;
use log::{info, warn};
use libakaza::graph::graph_builder::GraphBuilder;
use libakaza::graph::graph_resolver::GraphResolver;
use libakaza::graph::segmenter::Segmenter;
use libakaza::graph::word_node::WordNode;
use libakaza::kana_kanji_dict::KanaKanjiDict;
use libakaza::kana_trie::marisa_kana_trie::MarisaKanaTrie;
use libakaza::lm::system_bigram::SystemBigramLMBuilder;
use libakaza::lm::system_unigram_lm::SystemUnigramLMBuilder;
use libakaza::user_side_data::user_data::UserData;
/// 構造化パーセプトロンの学習を行います。
/// 構造化パーセプトロンは、シンプルな実装で、そこそこのパフォーマンスがでる(予定)
/// 構造化パーセプトロンでいい感じに動くようならば、構造化SVMなどに挑戦したい。
pub fn learn_structured_perceptron() -> anyhow::Result<()> {
let akaza = AkazaBuilder::default().build()?;
// ここでは内部クラスなどを触ってスコア調整をしていかないといけないので、AkazaBuilder は使えない。
let force = Vec::new();
let lattice = akaza.to_lattice("ほげほげ", &force)?;
let result = akaza.resolve(&lattice)?;
let mut unigram_cost: HashMap<String, f32> = HashMap::new();
for _ in 1..20 {
let system_kana_kanji_dict = KanaKanjiDict::load("data/system_dict.trie")?;
// let system_kana_kanji_dict = KanaKanjiDictBuilder::default()
// .add("せんたくもの", "洗濯物")
// .add("せんたく", "選択/洗濯")
// .add("もの", "Mono")
// .add("ほす", "干す/HOS")
// .add("めんどう", "面倒")
// .build();
let system_single_term_dict = KanaKanjiDict::load("data/single_term.trie")?;
let all_yomis = system_kana_kanji_dict.all_yomis().unwrap();
let system_kana_trie = MarisaKanaTrie::build(all_yomis);
let segmenter = Segmenter::new(vec![Box::new(system_kana_trie)]);
let force_ranges: Vec<Range<usize>> = Vec::new();
let mut unigram_lm_builder = SystemUnigramLMBuilder::default();
for (key, cost) in &unigram_cost {
warn!("SYSTEM UNIGRM LM: {} cost={}", key.as_str(), *cost);
unigram_lm_builder.add(key.as_str(), *cost);
}
let system_unigram_lm = unigram_lm_builder.build();
let system_bigram_lm = SystemBigramLMBuilder::default().build();
let teacher =
Teacher::new("洗濯物/せんたくもの を/を 干す/ほす の/の が/が 面倒/めんどう だ/だ");
let correct_nodes = teacher.correct_node_set();
let yomi = teacher.yomi();
let segmentation_result = segmenter.build(&yomi, &force_ranges);
let graph_builder = GraphBuilder::new(
system_kana_kanji_dict,
system_single_term_dict,
Arc::new(Mutex::new(UserData::default())),
Rc::new(system_unigram_lm),
Rc::new(system_bigram_lm),
0_f32,
0_f32,
);
let graph_resolver = GraphResolver::default();
let lattice = graph_builder.construct(yomi.as_str(), segmentation_result);
let got = graph_resolver.resolve(&lattice)?;
let terms: Vec<String> = got.iter().map(|f| f[0].kanji.clone()).collect();
let result = terms.join("");
if result != yomi {
// エポックのたびに作りなおさないといけないオブジェクトが多すぎてごちゃごちゃしている。
for i in 1..yomi.len() + 2 {
// いったん、全部のードのコストを1ずつ下げる
let Some(nodes) = &lattice.node_list(i as i32) else {
continue;
};
for node in *nodes {
let modifier = if correct_nodes.contains(node) {
info!("CORRECT: {:?}", node);
-1_f32
} else {
1_f32
};
let v = unigram_cost.get(&node.key().to_string()).unwrap_or(&0_f32);
unigram_cost.insert(node.key(), *v + modifier);
}
// TODO エッジコストも考慮する
}
}
// let dot = lattice.dump_cost_dot();
// BufWriter::new(File::create("/tmp/dump.dot")?).write_fmt(format_args!("{}", dot))?;
// println!("{:?}", unigram_cost);
println!("{}", result);
}
Ok(())
}
/// 教師データ
pub struct Teacher {
pub nodes: Vec<WordNode>,
}
impl Teacher {
/// 教師データをパースする。
pub fn new(src: &str) -> Teacher {
let p: Vec<&str> = src.split(' ').collect();
let mut start_pos = 0;
let mut nodes: Vec<WordNode> = Vec::new();
for x in p {
let (surface, yomi) = x.split_once("/").unwrap();
nodes.push(WordNode::new(start_pos, surface, yomi));
start_pos += yomi.len() as i32;
}
Teacher { nodes }
}
/// 教師データの「よみ」を返す。
pub fn yomi(&self) -> String {
let mut buf = String::new();
for yomi in self.nodes.iter().map(|f| f.yomi.as_str()) {
buf += yomi;
}
buf
}
/// 正解ノードを返す
pub fn correct_node_set(&self) -> HashSet<WordNode> {
HashSet::from_iter(self.nodes.iter().cloned())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn d() -> anyhow::Result<()> {
learn_structured_perceptron()
}
}

View File

@ -38,3 +38,24 @@ Google の収集した Web のデータを使えるというのがやはり Mozc
というのはちょうどよい温度感なのかなーと思う。
また、個人的には基本のデータセットは、公開されたデータセットをベースにするのが良いと思っているんだけど。
## 学習の進みが遅いケース
基本的に、コストがノード単位に振られているから、長い文節がマッチし易い傾向にある。
これはこれで良いことなのだが、
> 洗濯物/せんたくもの を/を 干す/ほす の/の が/が 面倒/めんどう だ/だ
のようなケースの場合、"逃が/のが" が辞書に登録されていると、 "の/の が/が" の2文節を通るよりもコストがやすくなりがち。
なので、めちゃくちゃコストが下がっていくのをまたないといけない感じになりがち。
このへんは、未知語のコストを統計的かな漢字変換のときに `20` とか雑にデカくつけてるのがよくない。
しかもここがハードコードされている。こういうのを調整可能にしないといけない。
こういう、統計的かな漢字変換前提でハードコードされている部分とかをばらしていかないといけない。
クラス構造とかデータの持ち方を調整しないと、ごちゃごちゃしすぎているので、調整が必要。
例えば、「よくない」の変換結果として「翼内」が一位にくるけど「良くない」がトップに出てくるべきかもしれない。
そういった調整は、識別モデルのほうがしやすい、のかな?
ユーザーの入力結果の学習についても、統一的に扱えるような気がする。

View File

@ -163,7 +163,7 @@ impl AkazaBuilder {
Arc::new(Mutex::new(UserData::default()))
};
let graph_builder = GraphBuilder::new(
let graph_builder = GraphBuilder::new_with_default_score(
system_kana_kanji_dict,
system_single_term_dict,
user_data.clone(),

View File

@ -19,6 +19,8 @@ pub struct GraphBuilder {
user_data: Arc<Mutex<UserData>>,
system_unigram_lm: Rc<SystemUnigramLM>,
system_bigram_lm: Rc<SystemBigramLM>,
default_unigram_score_for_long: f32,
default_unigram_score_for_short: f32,
}
impl GraphBuilder {
@ -28,6 +30,8 @@ impl GraphBuilder {
user_data: Arc<Mutex<UserData>>,
system_unigram_lm: Rc<SystemUnigramLM>,
system_bigram_lm: Rc<SystemBigramLM>,
default_unigram_score_for_short: f32,
default_unigram_score_for_long: f32,
) -> GraphBuilder {
GraphBuilder {
system_kana_kanji_dict,
@ -35,9 +39,29 @@ impl GraphBuilder {
user_data,
system_unigram_lm,
system_bigram_lm,
default_unigram_score_for_short,
default_unigram_score_for_long,
}
}
pub fn new_with_default_score(
system_kana_kanji_dict: KanaKanjiDict,
system_single_term_dict: KanaKanjiDict,
user_data: Arc<Mutex<UserData>>,
system_unigram_lm: Rc<SystemUnigramLM>,
system_bigram_lm: Rc<SystemBigramLM>,
) -> GraphBuilder {
Self::new(
system_kana_kanji_dict,
system_single_term_dict,
user_data,
system_unigram_lm,
system_bigram_lm,
19_f32,
20_f32,
)
}
pub fn construct(&self, yomi: &str, words_ends_at: SegmentationResult) -> LatticeGraph {
// このグラフのインデクスは単語の終了位置。
let mut graph: BTreeMap<i32, Vec<WordNode>> = BTreeMap::new();
@ -102,6 +126,8 @@ impl GraphBuilder {
user_data: self.user_data.clone(),
system_unigram_lm: self.system_unigram_lm.clone(),
system_bigram_lm: self.system_bigram_lm.clone(),
default_unigram_score_for_long: 20.0_f32,
default_unigram_score_for_short: 19.0_f32,
}
}
}
@ -116,7 +142,7 @@ mod tests {
#[test]
fn test_single_term() {
let graph_builder = GraphBuilder::new(
let graph_builder = GraphBuilder::new_with_default_score(
KanaKanjiDict::default(),
KanaKanjiDictBuilder::default().add("すし", "🍣").build(),
Arc::new(Mutex::new(UserData::default())),
@ -139,7 +165,7 @@ mod tests {
// ひらがな、カタカナのエントリーが自動的に入るようにする。
#[test]
fn test_default_terms() {
let graph_builder = GraphBuilder::new(
let graph_builder = GraphBuilder::new_with_default_score(
KanaKanjiDict::default(),
KanaKanjiDictBuilder::default().build(),
Arc::new(Mutex::new(UserData::default())),
@ -159,7 +185,7 @@ mod tests {
// ひらがな、カタカナがすでにかな漢字辞書から提供されている場合でも、重複させない。
#[test]
fn test_default_terms_duplicated() {
let graph_builder = GraphBuilder::new(
let graph_builder = GraphBuilder::new_with_default_score(
KanaKanjiDictBuilder::default().add("", "す/ス").build(),
KanaKanjiDictBuilder::default().build(),
Arc::new(Mutex::new(UserData::default())),

View File

@ -176,7 +176,7 @@ mod tests {
let system_bigram_lm_builder = SystemBigramLMBuilder::default();
let system_bigram_lm = system_bigram_lm_builder.build();
let user_data = UserData::default();
let graph_builder = GraphBuilder::new(
let graph_builder = GraphBuilder::new_with_default_score(
dict,
Default::default(),
Arc::new(Mutex::new(user_data)),
@ -225,7 +225,7 @@ mod tests {
let mut user_data = UserData::default();
// 私/わたし のスコアをガッと上げる。
user_data.record_entries(&vec!["私/わたし".to_string()]);
let graph_builder = GraphBuilder::new(
let graph_builder = GraphBuilder::new_with_default_score(
dict,
KanaKanjiDict::default(),
Arc::new(Mutex::new(user_data)),

View File

@ -3,7 +3,7 @@ use std::fmt::{Debug, Formatter};
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use log::{error, trace};
use log::{error, trace, warn};
use crate::graph::word_node::WordNode;
use crate::lm::system_bigram::SystemBigramLM;
@ -19,6 +19,10 @@ pub struct LatticeGraph {
pub(crate) user_data: Arc<Mutex<UserData>>,
pub(crate) system_unigram_lm: Rc<SystemUnigramLM>,
pub(crate) system_bigram_lm: Rc<SystemBigramLM>,
/// -log10(1e-19)=19.0
pub(crate) default_unigram_score_for_short: f32,
/// -log10(1e-20)=20.0
pub(crate) default_unigram_score_for_long: f32,
}
impl Debug for LatticeGraph {
@ -33,7 +37,7 @@ impl Debug for LatticeGraph {
impl LatticeGraph {
/// i文字目で終わるードを探す
pub(crate) fn node_list(&self, end_pos: i32) -> Option<&Vec<WordNode>> {
pub fn node_list(&self, end_pos: i32) -> Option<&Vec<WordNode>> {
self.graph.get(&end_pos)
}
@ -125,17 +129,16 @@ impl LatticeGraph {
}
return if let Some((_, system_unigram_cost)) = self.system_unigram_lm.find(key.as_str()) {
warn!("HIT!: {}, {}", node.key(), system_unigram_cost);
system_unigram_cost
} else if node.kanji.len() < node.yomi.len() {
// 労働者災害補償保険法 のように、システム辞書には wikipedia から採録されているが,
// 言語モデルには採録されていない場合,漢字候補を先頭に持ってくる。
// つまり、変換後のほうが短くなるもののほうをコストを安くしておく。
self.default_unigram_score_for_short
// -log10(1e-19)
19.0
} else {
// -log10(1e-20)
20.0
self.default_unigram_score_for_long
};
}

View File

@ -1,7 +1,7 @@
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct WordNode {
pub start_pos: i32,
/// 漢字
@ -56,7 +56,7 @@ impl WordNode {
cost: 0_f32,
}
}
pub(crate) fn new(start_pos: i32, kanji: &str, yomi: &str) -> WordNode {
pub fn new(start_pos: i32, kanji: &str, yomi: &str) -> WordNode {
WordNode {
start_pos,
kanji: kanji.to_string(),