Merge pull request #129 from tokuhirom/dictionary-loading-refactoring

Dictionary loading refactoring
This commit is contained in:
Tokuhiro Matsuno
2023-01-16 11:18:39 +09:00
committed by GitHub
12 changed files with 112 additions and 124 deletions

View File

@ -3,8 +3,8 @@ DATADIR ?= $(PREFIX)/share
DESTDIR ?=
all: data/stats-vibrato-bigram.trie \
data/stats-vibrato-bigram.trie \
all: data/bigram.model \
data/bigram.model \
data/SKK-JISYO.akaza
# -------------------------------------------------------------------------
@ -81,7 +81,7 @@ work/stats-vibrato-bigram.raw.trie: work/stats-vibrato-unigram.raw.trie work/sta
--corpus-dirs work/aozora_bunko/vibrato-ipadic/ \
work/stats-vibrato-unigram.raw.trie work/stats-vibrato-bigram.raw.trie
data/stats-vibrato-bigram.trie: work/stats-vibrato-bigram.raw.trie work/stats-vibrato-unigram.raw.trie src/subcmd/learn_corpus.rs corpus/must.txt corpus/should.txt corpus/may.txt data/SKK-JISYO.akaza
data/bigram.model: work/stats-vibrato-bigram.raw.trie work/stats-vibrato-unigram.raw.trie src/subcmd/learn_corpus.rs corpus/must.txt corpus/should.txt corpus/may.txt data/SKK-JISYO.akaza
cargo run --release -- learn-corpus \
--delta=0.5 \
--may-epochs=10 \
@ -91,10 +91,10 @@ data/stats-vibrato-bigram.trie: work/stats-vibrato-bigram.raw.trie work/stats-vi
corpus/should.txt \
corpus/must.txt \
work/stats-vibrato-unigram.raw.trie work/stats-vibrato-bigram.raw.trie \
data/stats-vibrato-unigram.trie data/stats-vibrato-bigram.trie \
data/unigram.model data/bigram.model \
-v
data/stats-vibrato-unigram.trie: data/stats-vibrato-bigram.trie
data/unigram.model: data/bigram.model
# -------------------------------------------------------------------------

View File

@ -33,7 +33,7 @@ TODO: 書き直し
## 生成されるデータ
### stats-vibrato-bigram.trie, stats-vibrato-unigram.trie
### bigram.model, unigram.model
marisa-trie 形式のデータです。1gram, 2gram のデータが素直に格納されています。

View File

@ -1,23 +1,29 @@
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::{Arc, Mutex};
use encoding_rs::{EUC_JP, UTF_8};
use log::info;
use libakaza::dict::merge_dict::merge_dict;
use libakaza::dict::skk::read::read_skkdict;
use libakaza::config::{Config, DictConfig};
use libakaza::engine::bigram_word_viterbi_engine::BigramWordViterbiEngineBuilder;
use libakaza::user_side_data::user_data::UserData;
pub fn check(yomi: &str, expected: Option<String>, user_data: bool) -> anyhow::Result<()> {
let dict = merge_dict(vec![
read_skkdict(Path::new("skk-dev-dict/SKK-JISYO.L"), EUC_JP)?,
read_skkdict(Path::new("data/SKK-JISYO.akaza"), UTF_8)?,
]);
let mut builder = BigramWordViterbiEngineBuilder::new(Some(dict), None);
let mut builder = BigramWordViterbiEngineBuilder::new(Config {
dicts: vec![
DictConfig {
dict_type: "skk".to_string(),
encoding: Some("euc-jp".to_string()),
path: "skk-dev-dict/SKK-JISYO.L".to_string(),
},
DictConfig {
dict_type: "skk".to_string(),
encoding: Some("utf-8".to_string()),
path: "data/SKK-JISYO.akaza".to_string(),
},
],
single_term: None,
});
if user_data {
info!("Enabled user data");
let user_data = UserData::load_from_default_path()?;

View File

@ -1,14 +1,11 @@
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use std::time::SystemTime;
use anyhow::Context;
use encoding_rs::{EUC_JP, UTF_8};
use log::info;
use libakaza::dict::merge_dict::merge_dict;
use libakaza::dict::skk::read::read_skkdict;
use libakaza::config::{Config, DictConfig};
use libakaza::engine::base::HenkanEngine;
use libakaza::engine::bigram_word_viterbi_engine::BigramWordViterbiEngineBuilder;
@ -66,14 +63,23 @@ pub fn evaluate(corpus_dir: &String, load_user_config: bool) -> anyhow::Result<(
"corpus.5.txt",
];
let dicts = merge_dict(vec![
read_skkdict(Path::new("skk-dev-dict/SKK-JISYO.L"), EUC_JP)?,
read_skkdict(Path::new("data/SKK-JISYO.akaza"), UTF_8)?,
]);
let akaza = BigramWordViterbiEngineBuilder::new(Some(dicts), None)
.load_user_config(load_user_config)
.build()?;
let akaza = BigramWordViterbiEngineBuilder::new(Config {
dicts: vec![
DictConfig {
dict_type: "skk".to_string(),
encoding: Some("euc-jp".to_string()),
path: "skk-dev-dict/SKK-JISYO.L".to_string(),
},
DictConfig {
dict_type: "skk".to_string(),
encoding: Some("utf-8".to_string()),
path: "data/SKK-JISYO.akaza".to_string(),
},
],
single_term: None,
})
.load_user_config(load_user_config)
.build()?;
let mut good_cnt = 0;
let mut bad_cnt = 0;

View File

@ -14,6 +14,7 @@ use log::{error, info, warn};
use ibus_sys::core::ibus_main;
use ibus_sys::engine::IBusEngine;
use ibus_sys::glib::{gchar, guint};
use libakaza::config::Config;
use libakaza::engine::bigram_word_viterbi_engine::BigramWordViterbiEngineBuilder;
use libakaza::user_side_data::user_data::UserData;
@ -101,7 +102,7 @@ fn main() -> Result<()> {
unsafe {
let sys_time = SystemTime::now();
let user_data = load_user_data();
let akaza = BigramWordViterbiEngineBuilder::new(None, None)
let akaza = BigramWordViterbiEngineBuilder::new(Config::load()?)
.user_data(user_data.clone())
.load_user_config(true)
.build()?;

View File

@ -5,6 +5,8 @@ dicts:
encoding: euc-jp
dict_type: skk
*/
use anyhow::Result;
use log::{info, warn};
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::BufReader;
@ -16,12 +18,34 @@ pub struct Config {
}
impl Config {
pub fn load_from_file(path: &str) -> anyhow::Result<Config> {
pub fn load_from_file(path: &str) -> anyhow::Result<Self> {
let file = File::open(path)?;
let reader = BufReader::new(file);
let config: Config = serde_yaml::from_reader(reader)?;
Ok(config)
}
/// Load the configuration from the XDG config directory
/// (`$XDG_CONFIG_HOME/akaza/config.yml`, typically `~/.config/akaza/config.yml`).
///
/// This never fails on a missing or unreadable config file: in that case a
/// warning is logged and `Config::default()` is returned, so callers always
/// get a usable configuration. An `Err` is only returned when the XDG base
/// directories themselves cannot be resolved.
pub fn load() -> Result<Self> {
let basedir = xdg::BaseDirectories::with_prefix("akaza")?;
let configfile = basedir.get_config_file("config.yml");
// NOTE(review): `to_str().unwrap()` panics on a non-UTF-8 path — assumed
// acceptable here since the path is derived from XDG env vars; confirm.
let config = match Config::load_from_file(configfile.to_str().unwrap()) {
Ok(config) => config,
Err(err) => {
// Fall back to defaults rather than aborting startup.
warn!(
"Cannot load configuration file: {} {}",
configfile.to_string_lossy(),
err
);
return Ok(Config::default());
}
};
info!(
"Loaded config file: {}, {:?}",
configfile.to_string_lossy(),
config
);
Ok(config)
}
}
#[derive(Debug, PartialEq, Serialize, Deserialize, Default)]

View File

@ -20,7 +20,7 @@ pub fn load_dicts(dict_configs: &Vec<DictConfig>) -> Result<HashMap<String, Vec<
dicts.push(dict);
}
Err(err) => {
error!("Cannot load {:?}. {}", dict_config, err);
error!("Cannot load dictionary: {:?}. {}", dict_config, err);
// 一個の辞書の読み込みに失敗しても、他の辞書は読み込むべきなので
// 処理は続行する
}

View File

@ -5,11 +5,9 @@ use std::ops::Range;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::time::SystemTime;
use anyhow::{bail, Result};
use encoding_rs::UTF_8;
use log::{info, warn};
use crate::config::Config;
use crate::dict::loader::load_dicts;
@ -124,20 +122,15 @@ impl<U: SystemUnigramLM, B: SystemBigramLM> BigramWordViterbiEngine<U, B> {
pub struct BigramWordViterbiEngineBuilder {
user_data: Option<Arc<Mutex<UserData>>>,
load_user_config: bool,
dicts: Option<HashMap<String, Vec<String>>>,
single_term: Option<HashMap<String, Vec<String>>>,
pub config: Config,
}
impl BigramWordViterbiEngineBuilder {
pub fn new(
dicts: Option<HashMap<String, Vec<String>>>,
single_term: Option<HashMap<String, Vec<String>>>,
) -> BigramWordViterbiEngineBuilder {
pub fn new(config: Config) -> BigramWordViterbiEngineBuilder {
BigramWordViterbiEngineBuilder {
user_data: None,
load_user_config: false,
dicts,
single_term,
config,
}
}
@ -157,13 +150,13 @@ impl BigramWordViterbiEngineBuilder {
&self,
) -> Result<BigramWordViterbiEngine<MarisaSystemUnigramLM, MarisaSystemBigramLM>> {
let system_unigram_lm = MarisaSystemUnigramLM::load(
Self::try_load("stats-vibrato-unigram.trie")?
Self::try_load("unigram.model")?
.to_string_lossy()
.to_string()
.as_str(),
)?;
let system_bigram_lm = MarisaSystemBigramLM::load(
Self::try_load("stats-vibrato-bigram.trie")?
Self::try_load("bigram.model")?
.to_string_lossy()
.to_string()
.as_str(),
@ -176,40 +169,17 @@ impl BigramWordViterbiEngineBuilder {
Arc::new(Mutex::new(UserData::default()))
};
// TODO このへんごちゃごちゃしすぎ。
let (dict, single_term, mut kana_trie) = {
let t1 = SystemTime::now();
let config = if self.load_user_config {
self.load_config()?
} else {
Config::default()
};
let dicts = load_dicts(&config.dicts)?;
let dicts = merge_dict(vec![system_dict, dicts]);
let single_term = if let Some(st) = &config.single_term {
load_dicts(st)?
} else {
HashMap::new()
};
// 次に、辞書を元に、トライを作成していく。
let kana_trie = CedarwoodKanaTrie::default();
let t2 = SystemTime::now();
info!(
"Loaded configuration in {}msec.",
t2.duration_since(t1).unwrap().as_millis()
);
(dicts, single_term, kana_trie)
};
let dict = if let Some(dd) = &self.dicts {
merge_dict(vec![dict, dd.clone()])
let dict = load_dicts(&self.config.dicts)?;
let dict = merge_dict(vec![system_dict, dict]);
let single_term = if let Some(st) = &&self.config.single_term {
load_dicts(st)?
} else {
dict
};
let single_term = if let Some(dd) = &self.single_term {
merge_dict(vec![single_term, dd.clone()])
} else {
single_term
HashMap::new()
};
// 辞書を元に、トライを作成していく。
let mut kana_trie = CedarwoodKanaTrie::default();
for yomi in dict.keys() {
assert!(!yomi.is_empty());
kana_trie.update(yomi.as_str());
@ -245,28 +215,6 @@ impl BigramWordViterbiEngineBuilder {
})
}
fn load_config(&self) -> Result<Config> {
let basedir = xdg::BaseDirectories::with_prefix("akaza")?;
let configfile = basedir.get_config_file("config.yml");
let config = match Config::load_from_file(configfile.to_str().unwrap()) {
Ok(config) => config,
Err(err) => {
warn!(
"Cannot load configuration file: {} {}",
configfile.to_string_lossy(),
err
);
return Ok(Config::default());
}
};
info!(
"Loaded config file: {}, {:?}",
configfile.to_string_lossy(),
config
);
Ok(config)
}
pub fn try_load(file_name: &str) -> Result<PathBuf> {
if cfg!(test) {
let path = Path::new(env!("CARGO_MANIFEST_DIR"));

View File

@ -33,16 +33,16 @@ impl MarisaSystemBigramLMBuilder {
// 最大でも 8,388,608 単語までになるように vocab を制限すること。
// 現実的な線で切っても、500万単語ぐらいで十分だと思われる。
// -rw-r--r-- 1 tokuhirom tokuhirom 28M Dec 31 23:56 stats-vibrato-bigram.trie
// -rw-r--r-- 1 tokuhirom tokuhirom 28M Dec 31 23:56 bigram.model
// ↓ 1MB 節約できる。
// -rw-r--r-- 1 tokuhirom tokuhirom 27M Jan 1 02:05 stats-vibrato-bigram.trie
// -rw-r--r-- 1 tokuhirom tokuhirom 27M Jan 1 02:05 bigram.model
// 4+4+4=12バイト必要だったところが、3+3+4=10バイトになって、10/12=5/6 なので、
// 本来なら 23.3 MB ぐらいまで減ってほしいところだけど、そこまではいかない。
// TRIE 構造だからそういう感じには減らない。
// さらに、スコアを f16 にしてみたが、あまりかわらない。
// -rw-r--r-- 1 tokuhirom tokuhirom 27M Jan 1 02:14 stats-vibrato-bigram.trie
// -rw-r--r-- 1 tokuhirom tokuhirom 27M Jan 1 02:14 bigram.model
let id1_bytes = word_id1.to_le_bytes();
let id2_bytes = word_id2.to_le_bytes();

View File

@ -17,13 +17,13 @@ mod tests {
fn load_unigram() -> anyhow::Result<MarisaSystemUnigramLM> {
let datadir = datadir();
let path = datadir + "/stats-vibrato-unigram.trie";
let path = datadir + "/unigram.model";
MarisaSystemUnigramLM::load(&path)
}
fn load_bigram() -> MarisaSystemBigramLM {
let datadir = datadir();
let path = datadir + "/stats-vibrato-bigram.trie";
let path = datadir + "/bigram.model";
MarisaSystemBigramLM::load(&path).unwrap()
}

View File

@ -14,7 +14,7 @@ mod tests {
#[test]
fn test_load() {
let path = datadir() + "/stats-vibrato-unigram.trie";
let path = datadir() + "/unigram.model";
let lm = MarisaSystemUnigramLM::load(&path).unwrap();
let (id, score) = lm.find("私/わたし").unwrap();
assert!(id > 0);

View File

@ -6,10 +6,9 @@ mod tests {
use std::path::Path;
use anyhow::Result;
use encoding_rs::UTF_8;
use libakaza::dict::skk::read::read_skkdict;
use log::LevelFilter;
use libakaza::config::{Config, DictConfig};
use libakaza::engine::base::HenkanEngine;
use libakaza::engine::bigram_word_viterbi_engine::{
BigramWordViterbiEngine, BigramWordViterbiEngineBuilder,
@ -23,24 +22,28 @@ mod tests {
let datadir = env!("CARGO_MANIFEST_DIR").to_string() + "/../akaza-data/data/";
assert!(Path::new(datadir.as_str()).exists());
env::set_var("AKAZA_DATA_DIR", datadir);
BigramWordViterbiEngineBuilder::new(
Some(read_skkdict(
Path::new(
(env!("CARGO_MANIFEST_DIR").to_string()
+ "/../akaza-data/data/SKK-JISYO.akaza")
.as_str(),
),
UTF_8,
)?),
Some(read_skkdict(
Path::new(
(env!("CARGO_MANIFEST_DIR").to_string()
+ "/../akaza-data/skk-dev-dict/SKK-JISYO.emoji")
.as_str(),
),
UTF_8,
)?),
)
BigramWordViterbiEngineBuilder::new(Config {
dicts: vec![
DictConfig {
dict_type: "skk".to_string(),
encoding: Some("euc-jp".to_string()),
path: (env!("CARGO_MANIFEST_DIR").to_string()
+ "/../akaza-data/skk-dev-dict/SKK-JISYO.L"),
},
DictConfig {
dict_type: "skk".to_string(),
encoding: Some("utf-8".to_string()),
path: (env!("CARGO_MANIFEST_DIR").to_string()
+ "/../akaza-data/data/SKK-JISYO.akaza"),
},
],
single_term: Some(vec![DictConfig {
dict_type: "skk".to_string(),
encoding: Some("utf-8".to_string()),
path: (env!("CARGO_MANIFEST_DIR").to_string()
+ "/../akaza-data/skk-dev-dict/SKK-JISYO.emoji"),
}]),
})
.build()
}
@ -83,7 +86,7 @@ mod tests {
#[test]
fn test_sushi() -> Result<()> {
let _ = env_logger::builder()
.filter_level(LevelFilter::Trace)
.filter_level(LevelFilter::Info)
.is_test(true)
.try_init();