Implemented user data related code

This commit is contained in:
Tokuhiro Matsuno
2022-12-30 17:05:36 +09:00
parent a39ef060f9
commit 00ba03fa1c
14 changed files with 328 additions and 340 deletions

1
akaza-core/Cargo.lock generated
View File

@ -334,6 +334,7 @@ dependencies = [
"marisa-sys",
"regex",
"sled",
"tempfile",
]
[[package]]

View File

@ -12,4 +12,7 @@ regex = "1"
sled = "0.34.7"
daachorse = "1.0.0"
log = "0.4.17"
env_logger = "0.10.0"
env_logger = "0.10.0"
tempfile = "3"
[build-dependencies]

View File

@ -159,6 +159,27 @@ impl GraphResolver {
// 経験上、長い文字列のほうがあたり、というルールでもそこそこ変換できる。
// TODO あとでちゃんと unigram のコストを使うよに変える。
return node.kanji.len() as f32;
/*
if let Some(user_cost) = user_language_model.get_unigram_cost(&self.key) {
// use user's score, if it's exists.
return user_cost;
}
if self.system_word_id != UNKNOWN_WORD_ID {
self.total_cost = Some(self.system_unigram_cost);
return self.system_unigram_cost;
} else {
// 労働者災害補償保険法 のように、システム辞書には採録されているが,
// 言語モデルには採録されていない場合,漢字候補を先頭に持ってくる。
return if self.word.len() < self.yomi.len() {
// 読みのほうが短いので、漢字。
ulm.get_default_cost_for_short()
} else {
ulm.get_default_cost()
};
}
*/
}
// -1 0 1 2

View File

@ -35,10 +35,12 @@ impl KanaTrie {
self.marisa.save(file_name)
}
pub(crate) fn load(file_name: &String) -> KanaTrie {
pub(crate) fn load(file_name: &String) -> Result<KanaTrie, String> {
let marisa = Marisa::new();
marisa.load(file_name).unwrap();
KanaTrie { marisa }
match marisa.load(file_name) {
Ok(_) => Ok(KanaTrie { marisa }),
Err(err) => Err(err),
}
}
pub(crate) fn common_prefix_search(&self, query: &String) -> Vec<String> {

View File

@ -4,11 +4,9 @@ mod graph_resolver;
pub mod kana;
pub(crate) mod kana_trie;
pub mod lm;
mod node;
mod romkan;
mod tinylisp;
pub mod trie;
mod user_data;
pub mod user_language_model;
const UNKNOWN_WORD_ID: i32 = -1;

View File

@ -1,118 +0,0 @@
use crate::lm::system_unigram_lm::SystemUnigramLM;
use crate::user_language_model::UserLanguageModel;
use crate::UNKNOWN_WORD_ID;
pub(crate) struct Node {
start_pos: i32,
yomi: String,
word: String,
pub(crate) key: String,
is_bos: bool,
is_eos: bool,
system_word_id: i32,
system_unigram_cost: f32,
total_cost: Option<f32>, // unigram cost + bigram cost + previous cost
prev: Option<Box<Node>>,
// bigram_cache: HashMap<String, f32>,
}
impl Node {
fn new(
start_pos: i32,
yomi: &String,
word: &String,
key: &String,
is_bos: bool,
is_eos: bool,
system_word_id: i32,
system_unigram_cost: f32,
) -> Node {
Node {
start_pos,
yomi: yomi.clone(),
word: word.clone(),
key: key.clone(),
is_bos,
is_eos,
system_word_id,
system_unigram_cost,
total_cost: Option::None,
prev: Option::None,
}
}
fn new_bos_node() -> Node {
Node {
start_pos: -1,
yomi: "__BOS__".to_string(),
word: "__BOS__".to_string(),
key: "__BOS__/__BOS__".to_string(),
is_bos: true,
is_eos: false,
system_word_id: UNKNOWN_WORD_ID,
system_unigram_cost: 0_f32,
total_cost: None,
prev: None,
}
}
fn new_eos_node(start_pos: i32) -> Node {
// 本来使うべきだが、key をわざと使わない。__EOS__ 考慮すると変換精度が落ちるので。。今は使わない。
// うまく使えることが確認できれば、__EOS__/__EOS__ にする。
Node {
start_pos,
yomi: "__EOS__".to_string(),
word: "__EOS__".to_string(),
key: "__EOS__".to_string(),
is_bos: false,
is_eos: true,
system_word_id: UNKNOWN_WORD_ID,
system_unigram_cost: 0_f32,
total_cost: None,
prev: None,
}
}
fn create_node(
system_unigram_lm: &SystemUnigramLM,
start_pos: i32,
yomi: &String,
kanji: &String,
) -> Node {
let key = kanji.clone() + "/" + yomi;
let result = system_unigram_lm.find_unigram(&key).unwrap();
let (word_id, cost) = result;
Self::new(start_pos, yomi, kanji, &key, false, false, word_id, cost)
}
fn calc_node_cost(
&mut self,
user_language_model: &UserLanguageModel,
ulm: &SystemUnigramLM,
) -> f32 {
if let Some(user_cost) = user_language_model.get_unigram_cost(&self.key) {
// use user's score, if it's exists.
return user_cost;
}
if self.system_word_id != UNKNOWN_WORD_ID {
self.total_cost = Some(self.system_unigram_cost);
return self.system_unigram_cost;
} else {
// 労働者災害補償保険法 のように、システム辞書には採録されているが,
// 言語モデルには採録されていない場合,漢字候補を先頭に持ってくる。
return if self.word.len() < self.yomi.len() {
// 読みのほうが短いので、漢字。
ulm.get_default_cost_for_short()
} else {
ulm.get_default_cost()
};
}
// calc_bigram_cost
// get_bigram_cost
// get_bigram_cost_from_cache
// set_prev
// ==
// surface
}
}

View File

@ -0,0 +1,68 @@
use std::collections::HashMap;
const ALPHA: f32 = 0.00001;
pub(crate) struct BiGramUserStats {
/// ユニーク単語数
unique_words: u32,
// C
/// 総単語出現数
total_words: u32,
// V
/// その単語の出現頻度。「漢字/漢字」がキー。
pub(crate) word_count: HashMap<String, u32>,
}
impl BiGramUserStats {
pub(crate) fn new(
unique_words: u32,
total_words: u32,
word_count: HashMap<String, u32>,
) -> BiGramUserStats {
BiGramUserStats {
unique_words,
total_words,
word_count,
}
}
/**
* エッジコストを計算する。
* システム言語モデルのコストよりも安くなるように調整してある。
*/
fn get_cost(&self, key1: &String, key2: &String) -> Option<f32> {
let key = key1.clone() + "\t" + key2;
let Some(count) = self.word_count.get(key.as_str()) else {
return None;
};
return Some(f32::log10(
((*count as f32) + ALPHA)
/ ((self.unique_words as f32) + ALPHA + (self.total_words as f32)),
));
}
pub(crate) fn record_entries(&mut self, kanjis: &Vec<String>) {
if kanjis.len() < 2 {
return;
}
// bigram
for i in 1..kanjis.len() {
let Some(kanji1) = kanjis.get(i - 1) else {
continue;
};
let Some(kanji2) = kanjis.get(i) else {
continue;
};
let key = kanji1.clone() + &"\t".to_string() + &kanji2;
if let Some(cnt) = self.word_count.get(&key) {
self.word_count.insert(key, cnt + 1);
} else {
self.word_count.insert(key, 1);
self.unique_words += 1;
}
self.total_words += 1;
}
}
}

View File

@ -1,4 +1,4 @@
mod abstract_user_stats;
mod bigram_user_stats;
mod unigram_user_stats;
mod user_data;
mod user_stats_utils;

View File

@ -0,0 +1,52 @@
use std::collections::HashMap;
const ALPHA: f32 = 0.00001;
pub(crate) struct UniGramUserStats {
/// ユニーク単語数
unique_words: u32, // C
/// 総単語出現数
total_words: u32, // V
/// その単語の出現頻度。「漢字」がキー。
pub(crate) word_count: HashMap<String, u32>,
}
impl UniGramUserStats {
pub(crate) fn new(
unique_words: u32,
total_words: u32,
word_count: HashMap<String, u32>,
) -> UniGramUserStats {
UniGramUserStats {
unique_words,
total_words,
word_count,
}
}
/**
* ノードコストを計算する。
* システム言語モデルのコストよりも安くなるように調整してある。
*/
pub(crate) fn get_cost(&self, key: &String) -> Option<f32> {
let Some(count) = self.word_count.get(key) else {
return None;
};
return Some(f32::log10(
((*count as f32) + ALPHA)
/ ((self.unique_words as f32) + ALPHA + (self.total_words as f32)),
));
}
pub(crate) fn record_entries(&mut self, kanjis: &Vec<String>) {
for kanji in kanjis {
if let Some(i) = self.word_count.get(kanji) {
self.word_count.insert(kanji.clone(), i + 1);
} else {
self.word_count.insert(kanji.clone(), 1);
self.unique_words += 1;
}
self.total_words += 1;
}
}
}

View File

@ -1,4 +1,12 @@
use crate::kana_trie::KanaTrie;
use log::warn;
use std::collections::HashMap;
use std::io;
use crate::kana_trie::{KanaTrie, KanaTrieBuilder};
use crate::user_data::bigram_user_stats::BiGramUserStats;
use crate::user_data::unigram_user_stats::UniGramUserStats;
use crate::user_data::user_stats_utils::{read_user_stats_file, write_user_stats_file};
use marisa_sys::Marisa;
/**
* ユーザー固有データ
@ -6,4 +14,91 @@ use crate::kana_trie::KanaTrie;
struct UserData {
/// 読み仮名のトライ。入力変換時に共通接頭辞検索するために使用。
kana_trie: KanaTrie,
unigram_user_stats: UniGramUserStats,
bigram_user_stats: BiGramUserStats,
unigram_path: String,
bigram_path: String,
kana_trie_path: String,
pub need_save: bool,
}
impl UserData {
fn load(unigram_path: &String, bigram_path: &String, kana_trie_path: &String) -> UserData {
// ユーザーデータが読み込めないことは fatal エラーではない。
// 初回起動時にはデータがないので。
// データがなければ初期所状態から始める
let unigram_user_stats = match read_user_stats_file(unigram_path) {
Ok(dat) => {
let unique_count = dat.len() as u32;
let total_count: u32 = dat.iter().map(|f| f.1).sum();
let mut word_count: HashMap<String, u32> = HashMap::new();
for (word, count) in dat {
word_count.insert(word, count);
}
UniGramUserStats::new(unique_count, total_count, word_count)
}
Err(err) => {
warn!(
"Cannot load user unigram data from {}: {}",
unigram_path, err
);
let unigram_user_stats = UniGramUserStats::new(0, 0, HashMap::new());
UniGramUserStats::new(0, 0, HashMap::new())
}
};
// build bigram
let bigram_user_stats = match read_user_stats_file(bigram_path) {
Ok(dat) => {
let unique_count = dat.len() as u32;
let total_count: u32 = dat.iter().map(|f| f.1).sum();
let mut words_count: HashMap<String, u32> = HashMap::new();
for (words, count) in dat {
words_count.insert(words, count);
}
BiGramUserStats::new(unique_count, total_count, words_count)
}
Err(err) => {
warn!("Cannot load user bigram data from {}: {}", bigram_path, err);
// ユーザーデータは初回起動時などにはないので、データがないものとして処理を続行する
BiGramUserStats::new(0, 0, HashMap::new())
}
};
let kana_trie = match KanaTrie::load(kana_trie_path) {
Ok(trie) => trie,
Err(err) => {
warn!("Cannot load kana trie: {} {}", kana_trie_path, err);
KanaTrie::new(Marisa::new())
}
};
UserData {
unigram_user_stats,
bigram_user_stats,
kana_trie,
unigram_path: unigram_path.clone(),
bigram_path: bigram_path.clone(),
kana_trie_path: kana_trie_path.clone(),
need_save: false,
}
}
/// 入力確定した漢字のリストをユーザー統計データとして記録する。
fn record_entries(&mut self, kanjis: Vec<String>, kanas: Vec<String>) {
self.unigram_user_stats.record_entries(&kanjis);
self.bigram_user_stats.record_entries(&kanjis);
// for kana in kanas {
// TODO: record kanas to trie.
// }
}
fn write_user_stats_file(&self) -> Result<(), io::Error> {
write_user_stats_file(&self.unigram_path, &self.unigram_user_stats.word_count)?;
write_user_stats_file(&self.bigram_path, &self.bigram_user_stats.word_count)?;
Ok(())
}
}

View File

@ -0,0 +1,68 @@
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::{fs, io};
pub(crate) fn read_user_stats_file(path: &String) -> Result<Vec<(String, u32)>, String> {
let file = match File::open(path) {
Ok(file) => file,
Err(err) => {
return Err(err.to_string());
}
};
let mut result: Vec<(String, u32)> = Vec::new();
for line in BufReader::new(file).lines() {
let Ok(line) = line else {
return Err("Cannot read user language model file".to_string());
};
let tokens: Vec<&str> = line.trim().splitn(2, " ").collect();
if tokens.len() != 2 {
continue;
}
let key = tokens[0];
let Ok(count) = tokens[1].to_string().parse::<u32>() else {
return Err("Invalid line in user language model: ".to_string() + tokens[1]);
};
result.push((key.to_string(), count));
}
return Ok(result);
}
pub(crate) fn write_user_stats_file(
path: &String,
word_count: &HashMap<String, u32>,
) -> Result<(), io::Error> {
let mut tmpfile = File::create(path.clone() + ".tmp")?;
for (key, cnt) in word_count {
tmpfile.write(key.as_bytes())?;
tmpfile.write(" ".as_bytes())?;
tmpfile.write(cnt.to_string().as_bytes())?;
tmpfile.write("\n".as_bytes())?;
}
fs::rename(path.clone() + ".tmp", path.clone())?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Read;
use tempfile::NamedTempFile;
#[test]
fn test_write() {
let tmpfile = NamedTempFile::new().unwrap();
let path = tmpfile.path().to_str().unwrap().to_string();
write_user_stats_file(&path, &HashMap::from([("渡し".to_string(), 3_u32)])).unwrap();
let mut buf = String::new();
File::open(path).unwrap().read_to_string(&mut buf).unwrap();
assert_eq!(buf, "渡し 3\n");
}
}

View File

@ -1,213 +0,0 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::fs::File;
use std::io::{BufRead, BufReader};
#[derive(PartialEq)]
enum GramType {
BiGram,
UniGram,
}
// unigram 用と bigram 用のロジックでコピペが増えがちで危ない。
// 完全に分離したほうが良い。
pub struct UserLanguageModel {
unigram_path: String,
bigram_path: String,
need_save: bool,
// unigram 言語モデルに登録されている、読み仮名を登録しておく。
// これにより、「ひょいー」などの漢字ではないものを、単語として IME が認識できるように
// している。
// 本質的には、user language model でやるべき処理というよりも、ユーザー単語辞書でもつく
// ってやるのが本筋だと思わなくもない
unigram_kanas: HashSet<String>,
/// ユニーク単語数
unigram_c: u32,
/// 総単語出現数
unigram_v: u32,
unigram: HashMap<String, u32>,
bigram_c: u32,
bigram_v: u32,
bigram: HashMap<String, u32>,
alpha: f32, // = 0.00001;
}
impl UserLanguageModel {
fn new(unigram_path: &String, bigram_path: &String) -> UserLanguageModel {
UserLanguageModel {
unigram_path: unigram_path.clone(),
bigram_path: bigram_path.clone(),
need_save: false,
unigram_kanas: HashSet::new(),
unigram_c: 0,
unigram_v: 0,
unigram: HashMap::new(),
bigram_c: 0,
bigram_v: 0,
bigram: HashMap::new(),
alpha: 0.00001,
}
}
fn read(
&mut self,
path: &String,
gram_type: GramType,
) -> Result<(u32, u32, HashMap<String, u32>), String> {
let mut c = 0;
let mut v = 0;
let mut map = HashMap::new();
// TODO : 厳密なエラー処理
let Ok(file) = File::open(path) else {
return Err("Cannot open user language model file".to_string());
};
for line in BufReader::new(file).lines() {
let Ok(line) = line else {
return Err("Cannot read user language model file".to_string());
};
let tokens: Vec<&str> = line.trim().splitn(2, " ").collect();
if tokens.len() != 2 {
continue;
}
let key = tokens[0];
let Ok(count) = tokens[1].to_string().parse::<u32>() else {
return Err("Invalid line in user language model: ".to_string() + tokens[1]);
};
map.insert(key.to_string(), count);
// unigram 言語モデルに登録されている、ひらがなの単語を、集めて登録しておく。
// これにより、「ひょいー」などの漢字ではないものを、単語として IME が認識できるように
// している。
// 本質的には、user language model でやるべき処理というよりも、ユーザー単語辞書でもつく
// ってやるのが本筋だと思わなくもない
if gram_type == GramType::UniGram {
let tokens: Vec<&str> = line.splitn(2, "/").collect();
if tokens.len() != 2 {
continue;
}
let kana = tokens[0];
self.unigram_kanas.insert(kana.to_string());
}
c += count;
v += 1;
}
return Ok((c, v, map));
}
fn load_unigram(&mut self) -> Result<(), String> {
let result = self.read(&self.unigram_path.clone(), GramType::UniGram);
let Ok((c, v, map)) = result else {
return Err(result.err().unwrap());
};
self.unigram_c = c;
self.unigram_v = v;
self.unigram = map;
return Ok(());
}
fn load_bigram(&mut self) -> Result<(), String> {
let result = self.read(&self.bigram_path.clone(), GramType::BiGram);
let Ok((c, v, map)) = result else {
return Err(result.err().unwrap());
};
self.bigram_c = c;
self.bigram_v = v;
self.bigram = map;
return Ok(());
}
fn add_entry(&mut self, nodes: &Vec<crate::node::Node>) {
// unigram
for node in nodes {
let key = &node.key;
if !self.unigram.contains_key(key) {
// increment unique count
self.unigram_c += 1;
}
self.unigram_v += 1;
let tokens: Vec<&str> = key.splitn(2, "/").collect();
let kana = tokens[1];
// std::wstring kana = std::get<1>(split2(key, L'/', splitted));
self.unigram_kanas.insert(kana.to_string());
self.unigram.insert(
key.to_string(),
self.unigram.get(key.as_str()).unwrap_or(&0_u32) + 1,
);
}
// bigram
for i in 1..nodes.len() {
let Some(node1) = nodes.get(i - 1) else {
continue;
};
let Some(node2) = nodes.get(i) else {
continue;
};
let k1 = &node1.key.to_string().clone();
let key = k1.to_string() + &"\t".to_string() + &node2.key;
if self.bigram.contains_key(&key) {
self.bigram_c += 1;
}
self.bigram_v += 1;
self.bigram
.insert(key.clone(), self.bigram.get(&key).unwrap_or(&0) + 1);
}
self.need_save = true;
}
pub(crate) fn get_unigram_cost(&self, key: &String) -> Option<f32> {
let Some(count) = self.unigram.get(key) else {
return None;
};
return Some(f32::log10(
((*count as f32) + self.alpha)
/ ((self.unigram_c as f32) + self.alpha + (self.unigram_v as f32)),
));
}
fn get_bigram_cost(&self, key1: &String, key2: &String) -> Option<f32> {
let key = key1.clone() + "\t" + key2;
let Some(count) = self.bigram.get(key.as_str()) else {
return None;
};
return Some(f32::log10(
((*count as f32) + self.alpha)
/ ((self.bigram_c as f32) + self.alpha + (self.bigram_v as f32)),
));
}
// TODO save_file
/*
void akaza::UserLanguageModel::save_file(const std::string &path, const std::unordered_map<std::wstring, int> &map) {
std::string tmppath(path + ".tmp");
std::wofstream ofs(tmppath, std::ofstream::out);
ofs.imbue(std::locale(std::locale(), new std::codecvt_utf8<wchar_t>));
for (const auto&[words, count] : map) {
ofs << words << " " << count << std::endl;
}
ofs.close();
int status = std::rename(tmppath.c_str(), path.c_str());
if (status != 0) {
std::string err = strerror(errno);
throw std::runtime_error(err + " : " + path + " (Cannot write user language model)");
}
}
*/
}

View File

@ -10,6 +10,7 @@ akaza はユーザー固有のデータを保持し利用する。
ユーザーが入力したデータの統計データである。
ユーザーが入力した単語の、unigram と bigram が統計データとして保存される。
保存されるのは「漢字」の方。
## ユーザー言語モデル
@ -18,7 +19,17 @@ akaza はユーザー固有のデータを保持し利用する。
- C: ユーザーが入力した単語のユニーク数
- V: ユーザーが入力した単語の総数
- word_count: 単語ごとの漢字入力回数
これらをもとに、コストを計算する。ユーザー言語モデルから得られるコスト値は、システム辞書に記録されるコスト値よりも低く設定されている。これにより、一度入力した単語は強烈に表出するようになる。
## ユーザー共通接頭辞
ユーザーの入力統計データをもとに、入力データの「かな」部分を利用して trie を構築する。
入力データの「かな」部分を利用して trie を構築する。
## 目指している形
SKK では、一度入力されたデータはユーザー辞書に登録されていく。これにより強烈にパーソナライズされていくので、そうそう誤変換しなくなっていく。
これと同じようなユーザー体験を得られるようにしていきたい。