mirror of
https://github.com/mii443/akaza.git
synced 2025-08-22 23:05:26 +00:00
Implemented user data related code
This commit is contained in:
1
akaza-core/Cargo.lock
generated
1
akaza-core/Cargo.lock
generated
@ -334,6 +334,7 @@ dependencies = [
|
|||||||
"marisa-sys",
|
"marisa-sys",
|
||||||
"regex",
|
"regex",
|
||||||
"sled",
|
"sled",
|
||||||
|
"tempfile",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -12,4 +12,7 @@ regex = "1"
|
|||||||
sled = "0.34.7"
|
sled = "0.34.7"
|
||||||
daachorse = "1.0.0"
|
daachorse = "1.0.0"
|
||||||
log = "0.4.17"
|
log = "0.4.17"
|
||||||
env_logger = "0.10.0"
|
env_logger = "0.10.0"
|
||||||
|
tempfile = "3"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
@ -159,6 +159,27 @@ impl GraphResolver {
|
|||||||
// 経験上、長い文字列のほうがあたり、というルールでもそこそこ変換できる。
|
// 経験上、長い文字列のほうがあたり、というルールでもそこそこ変換できる。
|
||||||
// TODO あとでちゃんと unigram のコストを使うよに変える。
|
// TODO あとでちゃんと unigram のコストを使うよに変える。
|
||||||
return node.kanji.len() as f32;
|
return node.kanji.len() as f32;
|
||||||
|
|
||||||
|
/*
|
||||||
|
if let Some(user_cost) = user_language_model.get_unigram_cost(&self.key) {
|
||||||
|
// use user's score, if it's exists.
|
||||||
|
return user_cost;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.system_word_id != UNKNOWN_WORD_ID {
|
||||||
|
self.total_cost = Some(self.system_unigram_cost);
|
||||||
|
return self.system_unigram_cost;
|
||||||
|
} else {
|
||||||
|
// 労働者災害補償保険法 のように、システム辞書には採録されているが,
|
||||||
|
// 言語モデルには採録されていない場合,漢字候補を先頭に持ってくる。
|
||||||
|
return if self.word.len() < self.yomi.len() {
|
||||||
|
// 読みのほうが短いので、漢字。
|
||||||
|
ulm.get_default_cost_for_short()
|
||||||
|
} else {
|
||||||
|
ulm.get_default_cost()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
// -1 0 1 2
|
// -1 0 1 2
|
||||||
|
@ -35,10 +35,12 @@ impl KanaTrie {
|
|||||||
self.marisa.save(file_name)
|
self.marisa.save(file_name)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn load(file_name: &String) -> KanaTrie {
|
pub(crate) fn load(file_name: &String) -> Result<KanaTrie, String> {
|
||||||
let marisa = Marisa::new();
|
let marisa = Marisa::new();
|
||||||
marisa.load(file_name).unwrap();
|
match marisa.load(file_name) {
|
||||||
KanaTrie { marisa }
|
Ok(_) => Ok(KanaTrie { marisa }),
|
||||||
|
Err(err) => Err(err),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn common_prefix_search(&self, query: &String) -> Vec<String> {
|
pub(crate) fn common_prefix_search(&self, query: &String) -> Vec<String> {
|
||||||
|
@ -4,11 +4,9 @@ mod graph_resolver;
|
|||||||
pub mod kana;
|
pub mod kana;
|
||||||
pub(crate) mod kana_trie;
|
pub(crate) mod kana_trie;
|
||||||
pub mod lm;
|
pub mod lm;
|
||||||
mod node;
|
|
||||||
mod romkan;
|
mod romkan;
|
||||||
mod tinylisp;
|
mod tinylisp;
|
||||||
pub mod trie;
|
pub mod trie;
|
||||||
mod user_data;
|
mod user_data;
|
||||||
pub mod user_language_model;
|
|
||||||
|
|
||||||
const UNKNOWN_WORD_ID: i32 = -1;
|
const UNKNOWN_WORD_ID: i32 = -1;
|
||||||
|
@ -1,118 +0,0 @@
|
|||||||
use crate::lm::system_unigram_lm::SystemUnigramLM;
|
|
||||||
use crate::user_language_model::UserLanguageModel;
|
|
||||||
use crate::UNKNOWN_WORD_ID;
|
|
||||||
|
|
||||||
pub(crate) struct Node {
|
|
||||||
start_pos: i32,
|
|
||||||
yomi: String,
|
|
||||||
word: String,
|
|
||||||
pub(crate) key: String,
|
|
||||||
is_bos: bool,
|
|
||||||
is_eos: bool,
|
|
||||||
system_word_id: i32,
|
|
||||||
system_unigram_cost: f32,
|
|
||||||
total_cost: Option<f32>, // unigram cost + bigram cost + previous cost
|
|
||||||
prev: Option<Box<Node>>,
|
|
||||||
// bigram_cache: HashMap<String, f32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Node {
|
|
||||||
fn new(
|
|
||||||
start_pos: i32,
|
|
||||||
yomi: &String,
|
|
||||||
word: &String,
|
|
||||||
key: &String,
|
|
||||||
is_bos: bool,
|
|
||||||
is_eos: bool,
|
|
||||||
system_word_id: i32,
|
|
||||||
system_unigram_cost: f32,
|
|
||||||
) -> Node {
|
|
||||||
Node {
|
|
||||||
start_pos,
|
|
||||||
yomi: yomi.clone(),
|
|
||||||
word: word.clone(),
|
|
||||||
key: key.clone(),
|
|
||||||
is_bos,
|
|
||||||
is_eos,
|
|
||||||
system_word_id,
|
|
||||||
system_unigram_cost,
|
|
||||||
total_cost: Option::None,
|
|
||||||
prev: Option::None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_bos_node() -> Node {
|
|
||||||
Node {
|
|
||||||
start_pos: -1,
|
|
||||||
yomi: "__BOS__".to_string(),
|
|
||||||
word: "__BOS__".to_string(),
|
|
||||||
key: "__BOS__/__BOS__".to_string(),
|
|
||||||
is_bos: true,
|
|
||||||
is_eos: false,
|
|
||||||
system_word_id: UNKNOWN_WORD_ID,
|
|
||||||
system_unigram_cost: 0_f32,
|
|
||||||
total_cost: None,
|
|
||||||
prev: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_eos_node(start_pos: i32) -> Node {
|
|
||||||
// 本来使うべきだが、key をわざと使わない。__EOS__ 考慮すると変換精度が落ちるので。。今は使わない。
|
|
||||||
// うまく使えることが確認できれば、__EOS__/__EOS__ にする。
|
|
||||||
Node {
|
|
||||||
start_pos,
|
|
||||||
yomi: "__EOS__".to_string(),
|
|
||||||
word: "__EOS__".to_string(),
|
|
||||||
key: "__EOS__".to_string(),
|
|
||||||
is_bos: false,
|
|
||||||
is_eos: true,
|
|
||||||
system_word_id: UNKNOWN_WORD_ID,
|
|
||||||
system_unigram_cost: 0_f32,
|
|
||||||
total_cost: None,
|
|
||||||
prev: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_node(
|
|
||||||
system_unigram_lm: &SystemUnigramLM,
|
|
||||||
start_pos: i32,
|
|
||||||
yomi: &String,
|
|
||||||
kanji: &String,
|
|
||||||
) -> Node {
|
|
||||||
let key = kanji.clone() + "/" + yomi;
|
|
||||||
let result = system_unigram_lm.find_unigram(&key).unwrap();
|
|
||||||
let (word_id, cost) = result;
|
|
||||||
Self::new(start_pos, yomi, kanji, &key, false, false, word_id, cost)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calc_node_cost(
|
|
||||||
&mut self,
|
|
||||||
user_language_model: &UserLanguageModel,
|
|
||||||
ulm: &SystemUnigramLM,
|
|
||||||
) -> f32 {
|
|
||||||
if let Some(user_cost) = user_language_model.get_unigram_cost(&self.key) {
|
|
||||||
// use user's score, if it's exists.
|
|
||||||
return user_cost;
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.system_word_id != UNKNOWN_WORD_ID {
|
|
||||||
self.total_cost = Some(self.system_unigram_cost);
|
|
||||||
return self.system_unigram_cost;
|
|
||||||
} else {
|
|
||||||
// 労働者災害補償保険法 のように、システム辞書には採録されているが,
|
|
||||||
// 言語モデルには採録されていない場合,漢字候補を先頭に持ってくる。
|
|
||||||
return if self.word.len() < self.yomi.len() {
|
|
||||||
// 読みのほうが短いので、漢字。
|
|
||||||
ulm.get_default_cost_for_short()
|
|
||||||
} else {
|
|
||||||
ulm.get_default_cost()
|
|
||||||
};
|
|
||||||
}
|
|
||||||
// calc_bigram_cost
|
|
||||||
// get_bigram_cost
|
|
||||||
// get_bigram_cost_from_cache
|
|
||||||
// set_prev
|
|
||||||
// ==
|
|
||||||
// surface
|
|
||||||
}
|
|
||||||
}
|
|
@ -0,0 +1,68 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
const ALPHA: f32 = 0.00001;
|
||||||
|
|
||||||
|
pub(crate) struct BiGramUserStats {
|
||||||
|
/// ユニーク単語数
|
||||||
|
unique_words: u32,
|
||||||
|
// C
|
||||||
|
/// 総単語出現数
|
||||||
|
total_words: u32,
|
||||||
|
// V
|
||||||
|
/// その単語の出現頻度。「漢字/漢字」がキー。
|
||||||
|
pub(crate) word_count: HashMap<String, u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BiGramUserStats {
|
||||||
|
pub(crate) fn new(
|
||||||
|
unique_words: u32,
|
||||||
|
total_words: u32,
|
||||||
|
word_count: HashMap<String, u32>,
|
||||||
|
) -> BiGramUserStats {
|
||||||
|
BiGramUserStats {
|
||||||
|
unique_words,
|
||||||
|
total_words,
|
||||||
|
word_count,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* エッジコストを計算する。
|
||||||
|
* システム言語モデルのコストよりも安くなるように調整してある。
|
||||||
|
*/
|
||||||
|
fn get_cost(&self, key1: &String, key2: &String) -> Option<f32> {
|
||||||
|
let key = key1.clone() + "\t" + key2;
|
||||||
|
let Some(count) = self.word_count.get(key.as_str()) else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
return Some(f32::log10(
|
||||||
|
((*count as f32) + ALPHA)
|
||||||
|
/ ((self.unique_words as f32) + ALPHA + (self.total_words as f32)),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn record_entries(&mut self, kanjis: &Vec<String>) {
|
||||||
|
if kanjis.len() < 2 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// bigram
|
||||||
|
for i in 1..kanjis.len() {
|
||||||
|
let Some(kanji1) = kanjis.get(i - 1) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let Some(kanji2) = kanjis.get(i) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let key = kanji1.clone() + &"\t".to_string() + &kanji2;
|
||||||
|
if let Some(cnt) = self.word_count.get(&key) {
|
||||||
|
self.word_count.insert(key, cnt + 1);
|
||||||
|
} else {
|
||||||
|
self.word_count.insert(key, 1);
|
||||||
|
self.unique_words += 1;
|
||||||
|
}
|
||||||
|
self.total_words += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
mod abstract_user_stats;
|
|
||||||
mod bigram_user_stats;
|
mod bigram_user_stats;
|
||||||
mod unigram_user_stats;
|
mod unigram_user_stats;
|
||||||
mod user_data;
|
mod user_data;
|
||||||
|
mod user_stats_utils;
|
||||||
|
@ -0,0 +1,52 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
const ALPHA: f32 = 0.00001;
|
||||||
|
|
||||||
|
pub(crate) struct UniGramUserStats {
|
||||||
|
/// ユニーク単語数
|
||||||
|
unique_words: u32, // C
|
||||||
|
/// 総単語出現数
|
||||||
|
total_words: u32, // V
|
||||||
|
/// その単語の出現頻度。「漢字」がキー。
|
||||||
|
pub(crate) word_count: HashMap<String, u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UniGramUserStats {
|
||||||
|
pub(crate) fn new(
|
||||||
|
unique_words: u32,
|
||||||
|
total_words: u32,
|
||||||
|
word_count: HashMap<String, u32>,
|
||||||
|
) -> UniGramUserStats {
|
||||||
|
UniGramUserStats {
|
||||||
|
unique_words,
|
||||||
|
total_words,
|
||||||
|
word_count,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ノードコストを計算する。
|
||||||
|
* システム言語モデルのコストよりも安くなるように調整してある。
|
||||||
|
*/
|
||||||
|
pub(crate) fn get_cost(&self, key: &String) -> Option<f32> {
|
||||||
|
let Some(count) = self.word_count.get(key) else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
return Some(f32::log10(
|
||||||
|
((*count as f32) + ALPHA)
|
||||||
|
/ ((self.unique_words as f32) + ALPHA + (self.total_words as f32)),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn record_entries(&mut self, kanjis: &Vec<String>) {
|
||||||
|
for kanji in kanjis {
|
||||||
|
if let Some(i) = self.word_count.get(kanji) {
|
||||||
|
self.word_count.insert(kanji.clone(), i + 1);
|
||||||
|
} else {
|
||||||
|
self.word_count.insert(kanji.clone(), 1);
|
||||||
|
self.unique_words += 1;
|
||||||
|
}
|
||||||
|
self.total_words += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,4 +1,12 @@
|
|||||||
use crate::kana_trie::KanaTrie;
|
use log::warn;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::io;
|
||||||
|
|
||||||
|
use crate::kana_trie::{KanaTrie, KanaTrieBuilder};
|
||||||
|
use crate::user_data::bigram_user_stats::BiGramUserStats;
|
||||||
|
use crate::user_data::unigram_user_stats::UniGramUserStats;
|
||||||
|
use crate::user_data::user_stats_utils::{read_user_stats_file, write_user_stats_file};
|
||||||
|
use marisa_sys::Marisa;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ユーザー固有データ
|
* ユーザー固有データ
|
||||||
@ -6,4 +14,91 @@ use crate::kana_trie::KanaTrie;
|
|||||||
struct UserData {
|
struct UserData {
|
||||||
/// 読み仮名のトライ。入力変換時に共通接頭辞検索するために使用。
|
/// 読み仮名のトライ。入力変換時に共通接頭辞検索するために使用。
|
||||||
kana_trie: KanaTrie,
|
kana_trie: KanaTrie,
|
||||||
|
unigram_user_stats: UniGramUserStats,
|
||||||
|
bigram_user_stats: BiGramUserStats,
|
||||||
|
|
||||||
|
unigram_path: String,
|
||||||
|
bigram_path: String,
|
||||||
|
kana_trie_path: String,
|
||||||
|
|
||||||
|
pub need_save: bool,
|
||||||
|
}
|
||||||
|
impl UserData {
|
||||||
|
fn load(unigram_path: &String, bigram_path: &String, kana_trie_path: &String) -> UserData {
|
||||||
|
// ユーザーデータが読み込めないことは fatal エラーではない。
|
||||||
|
// 初回起動時にはデータがないので。
|
||||||
|
// データがなければ初期所状態から始める
|
||||||
|
let unigram_user_stats = match read_user_stats_file(unigram_path) {
|
||||||
|
Ok(dat) => {
|
||||||
|
let unique_count = dat.len() as u32;
|
||||||
|
let total_count: u32 = dat.iter().map(|f| f.1).sum();
|
||||||
|
let mut word_count: HashMap<String, u32> = HashMap::new();
|
||||||
|
for (word, count) in dat {
|
||||||
|
word_count.insert(word, count);
|
||||||
|
}
|
||||||
|
UniGramUserStats::new(unique_count, total_count, word_count)
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!(
|
||||||
|
"Cannot load user unigram data from {}: {}",
|
||||||
|
unigram_path, err
|
||||||
|
);
|
||||||
|
|
||||||
|
let unigram_user_stats = UniGramUserStats::new(0, 0, HashMap::new());
|
||||||
|
UniGramUserStats::new(0, 0, HashMap::new())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// build bigram
|
||||||
|
let bigram_user_stats = match read_user_stats_file(bigram_path) {
|
||||||
|
Ok(dat) => {
|
||||||
|
let unique_count = dat.len() as u32;
|
||||||
|
let total_count: u32 = dat.iter().map(|f| f.1).sum();
|
||||||
|
let mut words_count: HashMap<String, u32> = HashMap::new();
|
||||||
|
for (words, count) in dat {
|
||||||
|
words_count.insert(words, count);
|
||||||
|
}
|
||||||
|
BiGramUserStats::new(unique_count, total_count, words_count)
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Cannot load user bigram data from {}: {}", bigram_path, err);
|
||||||
|
// ユーザーデータは初回起動時などにはないので、データがないものとして処理を続行する
|
||||||
|
BiGramUserStats::new(0, 0, HashMap::new())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let kana_trie = match KanaTrie::load(kana_trie_path) {
|
||||||
|
Ok(trie) => trie,
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Cannot load kana trie: {} {}", kana_trie_path, err);
|
||||||
|
KanaTrie::new(Marisa::new())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
UserData {
|
||||||
|
unigram_user_stats,
|
||||||
|
bigram_user_stats,
|
||||||
|
kana_trie,
|
||||||
|
unigram_path: unigram_path.clone(),
|
||||||
|
bigram_path: bigram_path.clone(),
|
||||||
|
kana_trie_path: kana_trie_path.clone(),
|
||||||
|
need_save: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 入力確定した漢字のリストをユーザー統計データとして記録する。
|
||||||
|
fn record_entries(&mut self, kanjis: Vec<String>, kanas: Vec<String>) {
|
||||||
|
self.unigram_user_stats.record_entries(&kanjis);
|
||||||
|
self.bigram_user_stats.record_entries(&kanjis);
|
||||||
|
|
||||||
|
// for kana in kanas {
|
||||||
|
// TODO: record kanas to trie.
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_user_stats_file(&self) -> Result<(), io::Error> {
|
||||||
|
write_user_stats_file(&self.unigram_path, &self.unigram_user_stats.word_count)?;
|
||||||
|
write_user_stats_file(&self.bigram_path, &self.bigram_user_stats.word_count)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
68
akaza-core/libakaza/src/user_data/user_stats_utils.rs
Normal file
68
akaza-core/libakaza/src/user_data/user_stats_utils.rs
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::{BufRead, BufReader, Write};
|
||||||
|
use std::{fs, io};
|
||||||
|
|
||||||
|
pub(crate) fn read_user_stats_file(path: &String) -> Result<Vec<(String, u32)>, String> {
|
||||||
|
let file = match File::open(path) {
|
||||||
|
Ok(file) => file,
|
||||||
|
Err(err) => {
|
||||||
|
return Err(err.to_string());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut result: Vec<(String, u32)> = Vec::new();
|
||||||
|
|
||||||
|
for line in BufReader::new(file).lines() {
|
||||||
|
let Ok(line) = line else {
|
||||||
|
return Err("Cannot read user language model file".to_string());
|
||||||
|
};
|
||||||
|
let tokens: Vec<&str> = line.trim().splitn(2, " ").collect();
|
||||||
|
if tokens.len() != 2 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let key = tokens[0];
|
||||||
|
let Ok(count) = tokens[1].to_string().parse::<u32>() else {
|
||||||
|
return Err("Invalid line in user language model: ".to_string() + tokens[1]);
|
||||||
|
};
|
||||||
|
|
||||||
|
result.push((key.to_string(), count));
|
||||||
|
}
|
||||||
|
|
||||||
|
return Ok(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn write_user_stats_file(
|
||||||
|
path: &String,
|
||||||
|
word_count: &HashMap<String, u32>,
|
||||||
|
) -> Result<(), io::Error> {
|
||||||
|
let mut tmpfile = File::create(path.clone() + ".tmp")?;
|
||||||
|
|
||||||
|
for (key, cnt) in word_count {
|
||||||
|
tmpfile.write(key.as_bytes())?;
|
||||||
|
tmpfile.write(" ".as_bytes())?;
|
||||||
|
tmpfile.write(cnt.to_string().as_bytes())?;
|
||||||
|
tmpfile.write("\n".as_bytes())?;
|
||||||
|
}
|
||||||
|
fs::rename(path.clone() + ".tmp", path.clone())?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::io::Read;
|
||||||
|
use tempfile::NamedTempFile;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_write() {
|
||||||
|
let tmpfile = NamedTempFile::new().unwrap();
|
||||||
|
let path = tmpfile.path().to_str().unwrap().to_string();
|
||||||
|
write_user_stats_file(&path, &HashMap::from([("渡し".to_string(), 3_u32)])).unwrap();
|
||||||
|
let mut buf = String::new();
|
||||||
|
File::open(path).unwrap().read_to_string(&mut buf).unwrap();
|
||||||
|
assert_eq!(buf, "渡し 3\n");
|
||||||
|
}
|
||||||
|
}
|
@ -1,213 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::collections::HashSet;
|
|
||||||
use std::fs::File;
|
|
||||||
use std::io::{BufRead, BufReader};
|
|
||||||
|
|
||||||
#[derive(PartialEq)]
|
|
||||||
enum GramType {
|
|
||||||
BiGram,
|
|
||||||
UniGram,
|
|
||||||
}
|
|
||||||
|
|
||||||
// unigram 用と bigram 用のロジックでコピペが増えがちで危ない。
|
|
||||||
// 完全に分離したほうが良い。
|
|
||||||
pub struct UserLanguageModel {
|
|
||||||
unigram_path: String,
|
|
||||||
bigram_path: String,
|
|
||||||
|
|
||||||
need_save: bool,
|
|
||||||
|
|
||||||
// unigram 言語モデルに登録されている、読み仮名を登録しておく。
|
|
||||||
// これにより、「ひょいー」などの漢字ではないものを、単語として IME が認識できるように
|
|
||||||
// している。
|
|
||||||
// 本質的には、user language model でやるべき処理というよりも、ユーザー単語辞書でもつく
|
|
||||||
// ってやるのが本筋だと思わなくもない
|
|
||||||
unigram_kanas: HashSet<String>,
|
|
||||||
|
|
||||||
/// ユニーク単語数
|
|
||||||
unigram_c: u32,
|
|
||||||
/// 総単語出現数
|
|
||||||
unigram_v: u32,
|
|
||||||
unigram: HashMap<String, u32>,
|
|
||||||
|
|
||||||
bigram_c: u32,
|
|
||||||
bigram_v: u32,
|
|
||||||
bigram: HashMap<String, u32>,
|
|
||||||
|
|
||||||
alpha: f32, // = 0.00001;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UserLanguageModel {
|
|
||||||
fn new(unigram_path: &String, bigram_path: &String) -> UserLanguageModel {
|
|
||||||
UserLanguageModel {
|
|
||||||
unigram_path: unigram_path.clone(),
|
|
||||||
bigram_path: bigram_path.clone(),
|
|
||||||
need_save: false,
|
|
||||||
unigram_kanas: HashSet::new(),
|
|
||||||
unigram_c: 0,
|
|
||||||
unigram_v: 0,
|
|
||||||
unigram: HashMap::new(),
|
|
||||||
bigram_c: 0,
|
|
||||||
bigram_v: 0,
|
|
||||||
bigram: HashMap::new(),
|
|
||||||
alpha: 0.00001,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read(
|
|
||||||
&mut self,
|
|
||||||
path: &String,
|
|
||||||
gram_type: GramType,
|
|
||||||
) -> Result<(u32, u32, HashMap<String, u32>), String> {
|
|
||||||
let mut c = 0;
|
|
||||||
let mut v = 0;
|
|
||||||
let mut map = HashMap::new();
|
|
||||||
|
|
||||||
// TODO : 厳密なエラー処理
|
|
||||||
let Ok(file) = File::open(path) else {
|
|
||||||
return Err("Cannot open user language model file".to_string());
|
|
||||||
};
|
|
||||||
|
|
||||||
for line in BufReader::new(file).lines() {
|
|
||||||
let Ok(line) = line else {
|
|
||||||
return Err("Cannot read user language model file".to_string());
|
|
||||||
};
|
|
||||||
let tokens: Vec<&str> = line.trim().splitn(2, " ").collect();
|
|
||||||
if tokens.len() != 2 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let key = tokens[0];
|
|
||||||
let Ok(count) = tokens[1].to_string().parse::<u32>() else {
|
|
||||||
return Err("Invalid line in user language model: ".to_string() + tokens[1]);
|
|
||||||
};
|
|
||||||
|
|
||||||
map.insert(key.to_string(), count);
|
|
||||||
|
|
||||||
// unigram 言語モデルに登録されている、ひらがなの単語を、集めて登録しておく。
|
|
||||||
// これにより、「ひょいー」などの漢字ではないものを、単語として IME が認識できるように
|
|
||||||
// している。
|
|
||||||
// 本質的には、user language model でやるべき処理というよりも、ユーザー単語辞書でもつく
|
|
||||||
// ってやるのが本筋だと思わなくもない
|
|
||||||
if gram_type == GramType::UniGram {
|
|
||||||
let tokens: Vec<&str> = line.splitn(2, "/").collect();
|
|
||||||
if tokens.len() != 2 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let kana = tokens[0];
|
|
||||||
self.unigram_kanas.insert(kana.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
c += count;
|
|
||||||
v += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return Ok((c, v, map));
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_unigram(&mut self) -> Result<(), String> {
|
|
||||||
let result = self.read(&self.unigram_path.clone(), GramType::UniGram);
|
|
||||||
let Ok((c, v, map)) = result else {
|
|
||||||
return Err(result.err().unwrap());
|
|
||||||
};
|
|
||||||
self.unigram_c = c;
|
|
||||||
self.unigram_v = v;
|
|
||||||
self.unigram = map;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_bigram(&mut self) -> Result<(), String> {
|
|
||||||
let result = self.read(&self.bigram_path.clone(), GramType::BiGram);
|
|
||||||
let Ok((c, v, map)) = result else {
|
|
||||||
return Err(result.err().unwrap());
|
|
||||||
};
|
|
||||||
self.bigram_c = c;
|
|
||||||
self.bigram_v = v;
|
|
||||||
self.bigram = map;
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn add_entry(&mut self, nodes: &Vec<crate::node::Node>) {
|
|
||||||
// unigram
|
|
||||||
for node in nodes {
|
|
||||||
let key = &node.key;
|
|
||||||
if !self.unigram.contains_key(key) {
|
|
||||||
// increment unique count
|
|
||||||
self.unigram_c += 1;
|
|
||||||
}
|
|
||||||
self.unigram_v += 1;
|
|
||||||
let tokens: Vec<&str> = key.splitn(2, "/").collect();
|
|
||||||
let kana = tokens[1];
|
|
||||||
// std::wstring kana = std::get<1>(split2(key, L'/', splitted));
|
|
||||||
self.unigram_kanas.insert(kana.to_string());
|
|
||||||
self.unigram.insert(
|
|
||||||
key.to_string(),
|
|
||||||
self.unigram.get(key.as_str()).unwrap_or(&0_u32) + 1,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// bigram
|
|
||||||
for i in 1..nodes.len() {
|
|
||||||
let Some(node1) = nodes.get(i - 1) else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
let Some(node2) = nodes.get(i) else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
let k1 = &node1.key.to_string().clone();
|
|
||||||
let key = k1.to_string() + &"\t".to_string() + &node2.key;
|
|
||||||
if self.bigram.contains_key(&key) {
|
|
||||||
self.bigram_c += 1;
|
|
||||||
}
|
|
||||||
self.bigram_v += 1;
|
|
||||||
self.bigram
|
|
||||||
.insert(key.clone(), self.bigram.get(&key).unwrap_or(&0) + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.need_save = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn get_unigram_cost(&self, key: &String) -> Option<f32> {
|
|
||||||
let Some(count) = self.unigram.get(key) else {
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
return Some(f32::log10(
|
|
||||||
((*count as f32) + self.alpha)
|
|
||||||
/ ((self.unigram_c as f32) + self.alpha + (self.unigram_v as f32)),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_bigram_cost(&self, key1: &String, key2: &String) -> Option<f32> {
|
|
||||||
let key = key1.clone() + "\t" + key2;
|
|
||||||
let Some(count) = self.bigram.get(key.as_str()) else {
|
|
||||||
return None;
|
|
||||||
};
|
|
||||||
return Some(f32::log10(
|
|
||||||
((*count as f32) + self.alpha)
|
|
||||||
/ ((self.bigram_c as f32) + self.alpha + (self.bigram_v as f32)),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO save_file
|
|
||||||
|
|
||||||
/*
|
|
||||||
|
|
||||||
void akaza::UserLanguageModel::save_file(const std::string &path, const std::unordered_map<std::wstring, int> &map) {
|
|
||||||
std::string tmppath(path + ".tmp");
|
|
||||||
std::wofstream ofs(tmppath, std::ofstream::out);
|
|
||||||
ofs.imbue(std::locale(std::locale(), new std::codecvt_utf8<wchar_t>));
|
|
||||||
|
|
||||||
for (const auto&[words, count] : map) {
|
|
||||||
ofs << words << " " << count << std::endl;
|
|
||||||
}
|
|
||||||
ofs.close();
|
|
||||||
|
|
||||||
int status = std::rename(tmppath.c_str(), path.c_str());
|
|
||||||
if (status != 0) {
|
|
||||||
std::string err = strerror(errno);
|
|
||||||
throw std::runtime_error(err + " : " + path + " (Cannot write user language model)");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
|
@ -10,6 +10,7 @@ akaza はユーザー固有のデータを保持し利用する。
|
|||||||
ユーザーが入力したデータの統計データである。
|
ユーザーが入力したデータの統計データである。
|
||||||
|
|
||||||
ユーザーが入力した単語の、unigram と bigram が統計データとして保存される。
|
ユーザーが入力した単語の、unigram と bigram が統計データとして保存される。
|
||||||
|
保存されるのは「漢字」の方。
|
||||||
|
|
||||||
## ユーザー言語モデル
|
## ユーザー言語モデル
|
||||||
|
|
||||||
@ -18,7 +19,17 @@ akaza はユーザー固有のデータを保持し利用する。
|
|||||||
|
|
||||||
- C: ユーザーが入力した単語のユニーク数
|
- C: ユーザーが入力した単語のユニーク数
|
||||||
- V: ユーザーが入力した単語の総数
|
- V: ユーザーが入力した単語の総数
|
||||||
|
- word_count: 単語ごとの漢字入力回数
|
||||||
|
|
||||||
|
これらをもとに、コストを計算する。ユーザー言語モデルから得られるコスト値は、システム辞書に記録されるコスト値よりも低く設定されている。これにより、一度入力した単語は強烈に表出するようになる。
|
||||||
|
|
||||||
## ユーザー共通接頭辞
|
## ユーザー共通接頭辞
|
||||||
|
|
||||||
ユーザーの入力統計データをもとに、入力データの「かな」部分を利用して trie を構築する。
|
入力データの「かな」部分を利用して trie を構築する。
|
||||||
|
|
||||||
|
## 目指している形
|
||||||
|
|
||||||
|
SKK では、一度入力されたデータはユーザー辞書に登録されていく。これにより強烈にパーソナライズされていくので、そうそう誤変換しなくなっていく。
|
||||||
|
|
||||||
|
これと同じようなユーザー体験を得られるようにしていきたい。
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user