mirror of
https://github.com/mii443/akaza.git
synced 2025-08-22 14:55:31 +00:00
Implemented user data related code
This commit is contained in:
1
akaza-core/Cargo.lock
generated
1
akaza-core/Cargo.lock
generated
@ -334,6 +334,7 @@ dependencies = [
|
||||
"marisa-sys",
|
||||
"regex",
|
||||
"sled",
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -12,4 +12,7 @@ regex = "1"
|
||||
sled = "0.34.7"
|
||||
daachorse = "1.0.0"
|
||||
log = "0.4.17"
|
||||
env_logger = "0.10.0"
|
||||
env_logger = "0.10.0"
|
||||
tempfile = "3"
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -159,6 +159,27 @@ impl GraphResolver {
|
||||
// 経験上、長い文字列のほうがあたり、というルールでもそこそこ変換できる。
|
||||
// TODO あとでちゃんと unigram のコストを使うよに変える。
|
||||
return node.kanji.len() as f32;
|
||||
|
||||
/*
|
||||
if let Some(user_cost) = user_language_model.get_unigram_cost(&self.key) {
|
||||
// use user's score, if it's exists.
|
||||
return user_cost;
|
||||
}
|
||||
|
||||
if self.system_word_id != UNKNOWN_WORD_ID {
|
||||
self.total_cost = Some(self.system_unigram_cost);
|
||||
return self.system_unigram_cost;
|
||||
} else {
|
||||
// 労働者災害補償保険法 のように、システム辞書には採録されているが,
|
||||
// 言語モデルには採録されていない場合,漢字候補を先頭に持ってくる。
|
||||
return if self.word.len() < self.yomi.len() {
|
||||
// 読みのほうが短いので、漢字。
|
||||
ulm.get_default_cost_for_short()
|
||||
} else {
|
||||
ulm.get_default_cost()
|
||||
};
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
// -1 0 1 2
|
||||
|
@ -35,10 +35,12 @@ impl KanaTrie {
|
||||
self.marisa.save(file_name)
|
||||
}
|
||||
|
||||
pub(crate) fn load(file_name: &String) -> KanaTrie {
|
||||
pub(crate) fn load(file_name: &String) -> Result<KanaTrie, String> {
|
||||
let marisa = Marisa::new();
|
||||
marisa.load(file_name).unwrap();
|
||||
KanaTrie { marisa }
|
||||
match marisa.load(file_name) {
|
||||
Ok(_) => Ok(KanaTrie { marisa }),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn common_prefix_search(&self, query: &String) -> Vec<String> {
|
||||
|
@ -4,11 +4,9 @@ mod graph_resolver;
|
||||
pub mod kana;
|
||||
pub(crate) mod kana_trie;
|
||||
pub mod lm;
|
||||
mod node;
|
||||
mod romkan;
|
||||
mod tinylisp;
|
||||
pub mod trie;
|
||||
mod user_data;
|
||||
pub mod user_language_model;
|
||||
|
||||
const UNKNOWN_WORD_ID: i32 = -1;
|
||||
|
@ -1,118 +0,0 @@
|
||||
use crate::lm::system_unigram_lm::SystemUnigramLM;
|
||||
use crate::user_language_model::UserLanguageModel;
|
||||
use crate::UNKNOWN_WORD_ID;
|
||||
|
||||
pub(crate) struct Node {
|
||||
start_pos: i32,
|
||||
yomi: String,
|
||||
word: String,
|
||||
pub(crate) key: String,
|
||||
is_bos: bool,
|
||||
is_eos: bool,
|
||||
system_word_id: i32,
|
||||
system_unigram_cost: f32,
|
||||
total_cost: Option<f32>, // unigram cost + bigram cost + previous cost
|
||||
prev: Option<Box<Node>>,
|
||||
// bigram_cache: HashMap<String, f32>,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
fn new(
|
||||
start_pos: i32,
|
||||
yomi: &String,
|
||||
word: &String,
|
||||
key: &String,
|
||||
is_bos: bool,
|
||||
is_eos: bool,
|
||||
system_word_id: i32,
|
||||
system_unigram_cost: f32,
|
||||
) -> Node {
|
||||
Node {
|
||||
start_pos,
|
||||
yomi: yomi.clone(),
|
||||
word: word.clone(),
|
||||
key: key.clone(),
|
||||
is_bos,
|
||||
is_eos,
|
||||
system_word_id,
|
||||
system_unigram_cost,
|
||||
total_cost: Option::None,
|
||||
prev: Option::None,
|
||||
}
|
||||
}
|
||||
|
||||
fn new_bos_node() -> Node {
|
||||
Node {
|
||||
start_pos: -1,
|
||||
yomi: "__BOS__".to_string(),
|
||||
word: "__BOS__".to_string(),
|
||||
key: "__BOS__/__BOS__".to_string(),
|
||||
is_bos: true,
|
||||
is_eos: false,
|
||||
system_word_id: UNKNOWN_WORD_ID,
|
||||
system_unigram_cost: 0_f32,
|
||||
total_cost: None,
|
||||
prev: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn new_eos_node(start_pos: i32) -> Node {
|
||||
// 本来使うべきだが、key をわざと使わない。__EOS__ 考慮すると変換精度が落ちるので。。今は使わない。
|
||||
// うまく使えることが確認できれば、__EOS__/__EOS__ にする。
|
||||
Node {
|
||||
start_pos,
|
||||
yomi: "__EOS__".to_string(),
|
||||
word: "__EOS__".to_string(),
|
||||
key: "__EOS__".to_string(),
|
||||
is_bos: false,
|
||||
is_eos: true,
|
||||
system_word_id: UNKNOWN_WORD_ID,
|
||||
system_unigram_cost: 0_f32,
|
||||
total_cost: None,
|
||||
prev: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_node(
|
||||
system_unigram_lm: &SystemUnigramLM,
|
||||
start_pos: i32,
|
||||
yomi: &String,
|
||||
kanji: &String,
|
||||
) -> Node {
|
||||
let key = kanji.clone() + "/" + yomi;
|
||||
let result = system_unigram_lm.find_unigram(&key).unwrap();
|
||||
let (word_id, cost) = result;
|
||||
Self::new(start_pos, yomi, kanji, &key, false, false, word_id, cost)
|
||||
}
|
||||
|
||||
fn calc_node_cost(
|
||||
&mut self,
|
||||
user_language_model: &UserLanguageModel,
|
||||
ulm: &SystemUnigramLM,
|
||||
) -> f32 {
|
||||
if let Some(user_cost) = user_language_model.get_unigram_cost(&self.key) {
|
||||
// use user's score, if it's exists.
|
||||
return user_cost;
|
||||
}
|
||||
|
||||
if self.system_word_id != UNKNOWN_WORD_ID {
|
||||
self.total_cost = Some(self.system_unigram_cost);
|
||||
return self.system_unigram_cost;
|
||||
} else {
|
||||
// 労働者災害補償保険法 のように、システム辞書には採録されているが,
|
||||
// 言語モデルには採録されていない場合,漢字候補を先頭に持ってくる。
|
||||
return if self.word.len() < self.yomi.len() {
|
||||
// 読みのほうが短いので、漢字。
|
||||
ulm.get_default_cost_for_short()
|
||||
} else {
|
||||
ulm.get_default_cost()
|
||||
};
|
||||
}
|
||||
// calc_bigram_cost
|
||||
// get_bigram_cost
|
||||
// get_bigram_cost_from_cache
|
||||
// set_prev
|
||||
// ==
|
||||
// surface
|
||||
}
|
||||
}
|
@ -0,0 +1,68 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
const ALPHA: f32 = 0.00001;
|
||||
|
||||
pub(crate) struct BiGramUserStats {
|
||||
/// ユニーク単語数
|
||||
unique_words: u32,
|
||||
// C
|
||||
/// 総単語出現数
|
||||
total_words: u32,
|
||||
// V
|
||||
/// その単語の出現頻度。「漢字/漢字」がキー。
|
||||
pub(crate) word_count: HashMap<String, u32>,
|
||||
}
|
||||
|
||||
impl BiGramUserStats {
|
||||
pub(crate) fn new(
|
||||
unique_words: u32,
|
||||
total_words: u32,
|
||||
word_count: HashMap<String, u32>,
|
||||
) -> BiGramUserStats {
|
||||
BiGramUserStats {
|
||||
unique_words,
|
||||
total_words,
|
||||
word_count,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* エッジコストを計算する。
|
||||
* システム言語モデルのコストよりも安くなるように調整してある。
|
||||
*/
|
||||
fn get_cost(&self, key1: &String, key2: &String) -> Option<f32> {
|
||||
let key = key1.clone() + "\t" + key2;
|
||||
let Some(count) = self.word_count.get(key.as_str()) else {
|
||||
return None;
|
||||
};
|
||||
return Some(f32::log10(
|
||||
((*count as f32) + ALPHA)
|
||||
/ ((self.unique_words as f32) + ALPHA + (self.total_words as f32)),
|
||||
));
|
||||
}
|
||||
|
||||
pub(crate) fn record_entries(&mut self, kanjis: &Vec<String>) {
|
||||
if kanjis.len() < 2 {
|
||||
return;
|
||||
}
|
||||
|
||||
// bigram
|
||||
for i in 1..kanjis.len() {
|
||||
let Some(kanji1) = kanjis.get(i - 1) else {
|
||||
continue;
|
||||
};
|
||||
let Some(kanji2) = kanjis.get(i) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let key = kanji1.clone() + &"\t".to_string() + &kanji2;
|
||||
if let Some(cnt) = self.word_count.get(&key) {
|
||||
self.word_count.insert(key, cnt + 1);
|
||||
} else {
|
||||
self.word_count.insert(key, 1);
|
||||
self.unique_words += 1;
|
||||
}
|
||||
self.total_words += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
mod abstract_user_stats;
|
||||
mod bigram_user_stats;
|
||||
mod unigram_user_stats;
|
||||
mod user_data;
|
||||
mod user_stats_utils;
|
||||
|
@ -0,0 +1,52 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
const ALPHA: f32 = 0.00001;
|
||||
|
||||
pub(crate) struct UniGramUserStats {
|
||||
/// ユニーク単語数
|
||||
unique_words: u32, // C
|
||||
/// 総単語出現数
|
||||
total_words: u32, // V
|
||||
/// その単語の出現頻度。「漢字」がキー。
|
||||
pub(crate) word_count: HashMap<String, u32>,
|
||||
}
|
||||
|
||||
impl UniGramUserStats {
|
||||
pub(crate) fn new(
|
||||
unique_words: u32,
|
||||
total_words: u32,
|
||||
word_count: HashMap<String, u32>,
|
||||
) -> UniGramUserStats {
|
||||
UniGramUserStats {
|
||||
unique_words,
|
||||
total_words,
|
||||
word_count,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* ノードコストを計算する。
|
||||
* システム言語モデルのコストよりも安くなるように調整してある。
|
||||
*/
|
||||
pub(crate) fn get_cost(&self, key: &String) -> Option<f32> {
|
||||
let Some(count) = self.word_count.get(key) else {
|
||||
return None;
|
||||
};
|
||||
return Some(f32::log10(
|
||||
((*count as f32) + ALPHA)
|
||||
/ ((self.unique_words as f32) + ALPHA + (self.total_words as f32)),
|
||||
));
|
||||
}
|
||||
|
||||
pub(crate) fn record_entries(&mut self, kanjis: &Vec<String>) {
|
||||
for kanji in kanjis {
|
||||
if let Some(i) = self.word_count.get(kanji) {
|
||||
self.word_count.insert(kanji.clone(), i + 1);
|
||||
} else {
|
||||
self.word_count.insert(kanji.clone(), 1);
|
||||
self.unique_words += 1;
|
||||
}
|
||||
self.total_words += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,12 @@
|
||||
use crate::kana_trie::KanaTrie;
|
||||
use log::warn;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
|
||||
use crate::kana_trie::{KanaTrie, KanaTrieBuilder};
|
||||
use crate::user_data::bigram_user_stats::BiGramUserStats;
|
||||
use crate::user_data::unigram_user_stats::UniGramUserStats;
|
||||
use crate::user_data::user_stats_utils::{read_user_stats_file, write_user_stats_file};
|
||||
use marisa_sys::Marisa;
|
||||
|
||||
/**
|
||||
* ユーザー固有データ
|
||||
@ -6,4 +14,91 @@ use crate::kana_trie::KanaTrie;
|
||||
struct UserData {
|
||||
/// 読み仮名のトライ。入力変換時に共通接頭辞検索するために使用。
|
||||
kana_trie: KanaTrie,
|
||||
unigram_user_stats: UniGramUserStats,
|
||||
bigram_user_stats: BiGramUserStats,
|
||||
|
||||
unigram_path: String,
|
||||
bigram_path: String,
|
||||
kana_trie_path: String,
|
||||
|
||||
pub need_save: bool,
|
||||
}
|
||||
impl UserData {
|
||||
fn load(unigram_path: &String, bigram_path: &String, kana_trie_path: &String) -> UserData {
|
||||
// ユーザーデータが読み込めないことは fatal エラーではない。
|
||||
// 初回起動時にはデータがないので。
|
||||
// データがなければ初期所状態から始める
|
||||
let unigram_user_stats = match read_user_stats_file(unigram_path) {
|
||||
Ok(dat) => {
|
||||
let unique_count = dat.len() as u32;
|
||||
let total_count: u32 = dat.iter().map(|f| f.1).sum();
|
||||
let mut word_count: HashMap<String, u32> = HashMap::new();
|
||||
for (word, count) in dat {
|
||||
word_count.insert(word, count);
|
||||
}
|
||||
UniGramUserStats::new(unique_count, total_count, word_count)
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Cannot load user unigram data from {}: {}",
|
||||
unigram_path, err
|
||||
);
|
||||
|
||||
let unigram_user_stats = UniGramUserStats::new(0, 0, HashMap::new());
|
||||
UniGramUserStats::new(0, 0, HashMap::new())
|
||||
}
|
||||
};
|
||||
|
||||
// build bigram
|
||||
let bigram_user_stats = match read_user_stats_file(bigram_path) {
|
||||
Ok(dat) => {
|
||||
let unique_count = dat.len() as u32;
|
||||
let total_count: u32 = dat.iter().map(|f| f.1).sum();
|
||||
let mut words_count: HashMap<String, u32> = HashMap::new();
|
||||
for (words, count) in dat {
|
||||
words_count.insert(words, count);
|
||||
}
|
||||
BiGramUserStats::new(unique_count, total_count, words_count)
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("Cannot load user bigram data from {}: {}", bigram_path, err);
|
||||
// ユーザーデータは初回起動時などにはないので、データがないものとして処理を続行する
|
||||
BiGramUserStats::new(0, 0, HashMap::new())
|
||||
}
|
||||
};
|
||||
|
||||
let kana_trie = match KanaTrie::load(kana_trie_path) {
|
||||
Ok(trie) => trie,
|
||||
Err(err) => {
|
||||
warn!("Cannot load kana trie: {} {}", kana_trie_path, err);
|
||||
KanaTrie::new(Marisa::new())
|
||||
}
|
||||
};
|
||||
|
||||
UserData {
|
||||
unigram_user_stats,
|
||||
bigram_user_stats,
|
||||
kana_trie,
|
||||
unigram_path: unigram_path.clone(),
|
||||
bigram_path: bigram_path.clone(),
|
||||
kana_trie_path: kana_trie_path.clone(),
|
||||
need_save: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// 入力確定した漢字のリストをユーザー統計データとして記録する。
|
||||
fn record_entries(&mut self, kanjis: Vec<String>, kanas: Vec<String>) {
|
||||
self.unigram_user_stats.record_entries(&kanjis);
|
||||
self.bigram_user_stats.record_entries(&kanjis);
|
||||
|
||||
// for kana in kanas {
|
||||
// TODO: record kanas to trie.
|
||||
// }
|
||||
}
|
||||
|
||||
fn write_user_stats_file(&self) -> Result<(), io::Error> {
|
||||
write_user_stats_file(&self.unigram_path, &self.unigram_user_stats.word_count)?;
|
||||
write_user_stats_file(&self.bigram_path, &self.bigram_user_stats.word_count)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
68
akaza-core/libakaza/src/user_data/user_stats_utils.rs
Normal file
68
akaza-core/libakaza/src/user_data/user_stats_utils.rs
Normal file
@ -0,0 +1,68 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::{fs, io};
|
||||
|
||||
pub(crate) fn read_user_stats_file(path: &String) -> Result<Vec<(String, u32)>, String> {
|
||||
let file = match File::open(path) {
|
||||
Ok(file) => file,
|
||||
Err(err) => {
|
||||
return Err(err.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
let mut result: Vec<(String, u32)> = Vec::new();
|
||||
|
||||
for line in BufReader::new(file).lines() {
|
||||
let Ok(line) = line else {
|
||||
return Err("Cannot read user language model file".to_string());
|
||||
};
|
||||
let tokens: Vec<&str> = line.trim().splitn(2, " ").collect();
|
||||
if tokens.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let key = tokens[0];
|
||||
let Ok(count) = tokens[1].to_string().parse::<u32>() else {
|
||||
return Err("Invalid line in user language model: ".to_string() + tokens[1]);
|
||||
};
|
||||
|
||||
result.push((key.to_string(), count));
|
||||
}
|
||||
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
pub(crate) fn write_user_stats_file(
|
||||
path: &String,
|
||||
word_count: &HashMap<String, u32>,
|
||||
) -> Result<(), io::Error> {
|
||||
let mut tmpfile = File::create(path.clone() + ".tmp")?;
|
||||
|
||||
for (key, cnt) in word_count {
|
||||
tmpfile.write(key.as_bytes())?;
|
||||
tmpfile.write(" ".as_bytes())?;
|
||||
tmpfile.write(cnt.to_string().as_bytes())?;
|
||||
tmpfile.write("\n".as_bytes())?;
|
||||
}
|
||||
fs::rename(path.clone() + ".tmp", path.clone())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Read;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
#[test]
|
||||
fn test_write() {
|
||||
let tmpfile = NamedTempFile::new().unwrap();
|
||||
let path = tmpfile.path().to_str().unwrap().to_string();
|
||||
write_user_stats_file(&path, &HashMap::from([("渡し".to_string(), 3_u32)])).unwrap();
|
||||
let mut buf = String::new();
|
||||
File::open(path).unwrap().read_to_string(&mut buf).unwrap();
|
||||
assert_eq!(buf, "渡し 3\n");
|
||||
}
|
||||
}
|
@ -1,213 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader};
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum GramType {
|
||||
BiGram,
|
||||
UniGram,
|
||||
}
|
||||
|
||||
// unigram 用と bigram 用のロジックでコピペが増えがちで危ない。
|
||||
// 完全に分離したほうが良い。
|
||||
pub struct UserLanguageModel {
|
||||
unigram_path: String,
|
||||
bigram_path: String,
|
||||
|
||||
need_save: bool,
|
||||
|
||||
// unigram 言語モデルに登録されている、読み仮名を登録しておく。
|
||||
// これにより、「ひょいー」などの漢字ではないものを、単語として IME が認識できるように
|
||||
// している。
|
||||
// 本質的には、user language model でやるべき処理というよりも、ユーザー単語辞書でもつく
|
||||
// ってやるのが本筋だと思わなくもない
|
||||
unigram_kanas: HashSet<String>,
|
||||
|
||||
/// ユニーク単語数
|
||||
unigram_c: u32,
|
||||
/// 総単語出現数
|
||||
unigram_v: u32,
|
||||
unigram: HashMap<String, u32>,
|
||||
|
||||
bigram_c: u32,
|
||||
bigram_v: u32,
|
||||
bigram: HashMap<String, u32>,
|
||||
|
||||
alpha: f32, // = 0.00001;
|
||||
}
|
||||
|
||||
impl UserLanguageModel {
|
||||
fn new(unigram_path: &String, bigram_path: &String) -> UserLanguageModel {
|
||||
UserLanguageModel {
|
||||
unigram_path: unigram_path.clone(),
|
||||
bigram_path: bigram_path.clone(),
|
||||
need_save: false,
|
||||
unigram_kanas: HashSet::new(),
|
||||
unigram_c: 0,
|
||||
unigram_v: 0,
|
||||
unigram: HashMap::new(),
|
||||
bigram_c: 0,
|
||||
bigram_v: 0,
|
||||
bigram: HashMap::new(),
|
||||
alpha: 0.00001,
|
||||
}
|
||||
}
|
||||
|
||||
fn read(
|
||||
&mut self,
|
||||
path: &String,
|
||||
gram_type: GramType,
|
||||
) -> Result<(u32, u32, HashMap<String, u32>), String> {
|
||||
let mut c = 0;
|
||||
let mut v = 0;
|
||||
let mut map = HashMap::new();
|
||||
|
||||
// TODO : 厳密なエラー処理
|
||||
let Ok(file) = File::open(path) else {
|
||||
return Err("Cannot open user language model file".to_string());
|
||||
};
|
||||
|
||||
for line in BufReader::new(file).lines() {
|
||||
let Ok(line) = line else {
|
||||
return Err("Cannot read user language model file".to_string());
|
||||
};
|
||||
let tokens: Vec<&str> = line.trim().splitn(2, " ").collect();
|
||||
if tokens.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let key = tokens[0];
|
||||
let Ok(count) = tokens[1].to_string().parse::<u32>() else {
|
||||
return Err("Invalid line in user language model: ".to_string() + tokens[1]);
|
||||
};
|
||||
|
||||
map.insert(key.to_string(), count);
|
||||
|
||||
// unigram 言語モデルに登録されている、ひらがなの単語を、集めて登録しておく。
|
||||
// これにより、「ひょいー」などの漢字ではないものを、単語として IME が認識できるように
|
||||
// している。
|
||||
// 本質的には、user language model でやるべき処理というよりも、ユーザー単語辞書でもつく
|
||||
// ってやるのが本筋だと思わなくもない
|
||||
if gram_type == GramType::UniGram {
|
||||
let tokens: Vec<&str> = line.splitn(2, "/").collect();
|
||||
if tokens.len() != 2 {
|
||||
continue;
|
||||
}
|
||||
let kana = tokens[0];
|
||||
self.unigram_kanas.insert(kana.to_string());
|
||||
}
|
||||
|
||||
c += count;
|
||||
v += 1;
|
||||
}
|
||||
|
||||
return Ok((c, v, map));
|
||||
}
|
||||
|
||||
fn load_unigram(&mut self) -> Result<(), String> {
|
||||
let result = self.read(&self.unigram_path.clone(), GramType::UniGram);
|
||||
let Ok((c, v, map)) = result else {
|
||||
return Err(result.err().unwrap());
|
||||
};
|
||||
self.unigram_c = c;
|
||||
self.unigram_v = v;
|
||||
self.unigram = map;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn load_bigram(&mut self) -> Result<(), String> {
|
||||
let result = self.read(&self.bigram_path.clone(), GramType::BiGram);
|
||||
let Ok((c, v, map)) = result else {
|
||||
return Err(result.err().unwrap());
|
||||
};
|
||||
self.bigram_c = c;
|
||||
self.bigram_v = v;
|
||||
self.bigram = map;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn add_entry(&mut self, nodes: &Vec<crate::node::Node>) {
|
||||
// unigram
|
||||
for node in nodes {
|
||||
let key = &node.key;
|
||||
if !self.unigram.contains_key(key) {
|
||||
// increment unique count
|
||||
self.unigram_c += 1;
|
||||
}
|
||||
self.unigram_v += 1;
|
||||
let tokens: Vec<&str> = key.splitn(2, "/").collect();
|
||||
let kana = tokens[1];
|
||||
// std::wstring kana = std::get<1>(split2(key, L'/', splitted));
|
||||
self.unigram_kanas.insert(kana.to_string());
|
||||
self.unigram.insert(
|
||||
key.to_string(),
|
||||
self.unigram.get(key.as_str()).unwrap_or(&0_u32) + 1,
|
||||
);
|
||||
}
|
||||
|
||||
// bigram
|
||||
for i in 1..nodes.len() {
|
||||
let Some(node1) = nodes.get(i - 1) else {
|
||||
continue;
|
||||
};
|
||||
let Some(node2) = nodes.get(i) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let k1 = &node1.key.to_string().clone();
|
||||
let key = k1.to_string() + &"\t".to_string() + &node2.key;
|
||||
if self.bigram.contains_key(&key) {
|
||||
self.bigram_c += 1;
|
||||
}
|
||||
self.bigram_v += 1;
|
||||
self.bigram
|
||||
.insert(key.clone(), self.bigram.get(&key).unwrap_or(&0) + 1);
|
||||
}
|
||||
|
||||
self.need_save = true;
|
||||
}
|
||||
|
||||
pub(crate) fn get_unigram_cost(&self, key: &String) -> Option<f32> {
|
||||
let Some(count) = self.unigram.get(key) else {
|
||||
return None;
|
||||
};
|
||||
return Some(f32::log10(
|
||||
((*count as f32) + self.alpha)
|
||||
/ ((self.unigram_c as f32) + self.alpha + (self.unigram_v as f32)),
|
||||
));
|
||||
}
|
||||
|
||||
fn get_bigram_cost(&self, key1: &String, key2: &String) -> Option<f32> {
|
||||
let key = key1.clone() + "\t" + key2;
|
||||
let Some(count) = self.bigram.get(key.as_str()) else {
|
||||
return None;
|
||||
};
|
||||
return Some(f32::log10(
|
||||
((*count as f32) + self.alpha)
|
||||
/ ((self.bigram_c as f32) + self.alpha + (self.bigram_v as f32)),
|
||||
));
|
||||
}
|
||||
|
||||
// TODO save_file
|
||||
|
||||
/*
|
||||
|
||||
void akaza::UserLanguageModel::save_file(const std::string &path, const std::unordered_map<std::wstring, int> &map) {
|
||||
std::string tmppath(path + ".tmp");
|
||||
std::wofstream ofs(tmppath, std::ofstream::out);
|
||||
ofs.imbue(std::locale(std::locale(), new std::codecvt_utf8<wchar_t>));
|
||||
|
||||
for (const auto&[words, count] : map) {
|
||||
ofs << words << " " << count << std::endl;
|
||||
}
|
||||
ofs.close();
|
||||
|
||||
int status = std::rename(tmppath.c_str(), path.c_str());
|
||||
if (status != 0) {
|
||||
std::string err = strerror(errno);
|
||||
throw std::runtime_error(err + " : " + path + " (Cannot write user language model)");
|
||||
}
|
||||
}
|
||||
*/
|
||||
}
|
@ -10,6 +10,7 @@ akaza はユーザー固有のデータを保持し利用する。
|
||||
ユーザーが入力したデータの統計データである。
|
||||
|
||||
ユーザーが入力した単語の、unigram と bigram が統計データとして保存される。
|
||||
保存されるのは「漢字」の方。
|
||||
|
||||
## ユーザー言語モデル
|
||||
|
||||
@ -18,7 +19,17 @@ akaza はユーザー固有のデータを保持し利用する。
|
||||
|
||||
- C: ユーザーが入力した単語のユニーク数
|
||||
- V: ユーザーが入力した単語の総数
|
||||
- word_count: 単語ごとの漢字入力回数
|
||||
|
||||
これらをもとに、コストを計算する。ユーザー言語モデルから得られるコスト値は、システム辞書に記録されるコスト値よりも低く設定されている。これにより、一度入力した単語は強烈に表出するようになる。
|
||||
|
||||
## ユーザー共通接頭辞
|
||||
|
||||
ユーザーの入力統計データをもとに、入力データの「かな」部分を利用して trie を構築する。
|
||||
入力データの「かな」部分を利用して trie を構築する。
|
||||
|
||||
## 目指している形
|
||||
|
||||
SKK では、一度入力されたデータはユーザー辞書に登録されていく。これにより強烈にパーソナライズされていくので、そうそう誤変換しなくなっていく。
|
||||
|
||||
これと同じようなユーザー体験を得られるようにしていきたい。
|
||||
|
||||
|
Reference in New Issue
Block a user