mirror of
https://github.com/mii443/akaza.git
synced 2025-08-22 14:55:31 +00:00
@ -67,8 +67,6 @@ impl LearningService {
|
||||
}
|
||||
let system_unigram_lm = Rc::new(OnMemorySystemUnigramLM::new(
|
||||
Rc::new(RefCell::new(unigram_map)),
|
||||
src_system_unigram_lm.get_default_cost(),
|
||||
src_system_unigram_lm.get_default_cost_for_short(),
|
||||
src_system_unigram_lm.total_words,
|
||||
src_system_unigram_lm.unique_words,
|
||||
));
|
||||
@ -190,9 +188,8 @@ impl LearningService {
|
||||
}
|
||||
// ↓本来なら現在のデータで再調整すべきだが、一旦元のものを使う。
|
||||
// TODO あとで整理する
|
||||
unigram_builder.set_default_cost(self.system_unigram_lm.get_default_cost());
|
||||
unigram_builder
|
||||
.set_default_cost_for_short(self.system_unigram_lm.get_default_cost_for_short());
|
||||
unigram_builder.set_unique_words(self.system_unigram_lm.unique_words);
|
||||
unigram_builder.set_total_words(self.system_unigram_lm.total_words);
|
||||
info!("Save unigram to {}", dst_unigram);
|
||||
unigram_builder.save(dst_unigram)?;
|
||||
Ok(())
|
||||
|
@ -48,8 +48,6 @@ impl WordcntUnigramBuilder {
|
||||
|
||||
pub struct WordcntUnigram {
|
||||
marisa: Marisa,
|
||||
default_cost: f32,
|
||||
default_cost_for_short: f32,
|
||||
pub(crate) total_words: u32,
|
||||
pub(crate) unique_words: u32,
|
||||
}
|
||||
@ -88,13 +86,8 @@ impl WordcntUnigram {
|
||||
// 単語の種類数
|
||||
let unique_words = map.keys().count() as u32;
|
||||
|
||||
let default_cost = calc_cost(0, total_words, unique_words);
|
||||
let default_cost_for_short = calc_cost(1, total_words, unique_words);
|
||||
|
||||
Ok(WordcntUnigram {
|
||||
marisa,
|
||||
default_cost,
|
||||
default_cost_for_short,
|
||||
total_words,
|
||||
unique_words,
|
||||
})
|
||||
@ -102,12 +95,8 @@ impl WordcntUnigram {
|
||||
}
|
||||
|
||||
impl SystemUnigramLM for WordcntUnigram {
|
||||
fn get_default_cost(&self) -> f32 {
|
||||
self.default_cost
|
||||
}
|
||||
|
||||
fn get_default_cost_for_short(&self) -> f32 {
|
||||
self.default_cost_for_short
|
||||
fn get_cost(&self, wordcnt: u32) -> f32 {
|
||||
calc_cost(wordcnt, self.total_words, self.unique_words)
|
||||
}
|
||||
|
||||
/// @return (word_id, score)。
|
||||
@ -181,8 +170,8 @@ mod tests {
|
||||
);
|
||||
assert_eq!(wordcnt.total_words, 45); // 単語発生数
|
||||
assert_eq!(wordcnt.unique_words, 2); // ユニーク単語数
|
||||
assert_eq!(wordcnt.get_default_cost(), 6.672098);
|
||||
assert_eq!(wordcnt.get_default_cost_for_short(), 1.6720936);
|
||||
assert_eq!(wordcnt.get_cost(0), 6.672098);
|
||||
assert_eq!(wordcnt.get_cost(1), 1.6720936);
|
||||
|
||||
assert_eq!(wordcnt.find("私/わたし"), Some((1_i32, 1.1949753)));
|
||||
assert_eq!(wordcnt.find("彼/かれ"), Some((0_i32, 0.048848562)));
|
||||
|
@ -172,8 +172,8 @@ mod tests {
|
||||
Arc::new(Mutex::new(UserData::default())),
|
||||
Rc::new(
|
||||
MarisaSystemUnigramLMBuilder::default()
|
||||
.set_default_cost(20_f32)
|
||||
.set_default_cost_for_short(19_f32)
|
||||
.set_unique_words(20)
|
||||
.set_total_words(19)
|
||||
.build(),
|
||||
),
|
||||
Rc::new(
|
||||
@ -205,8 +205,8 @@ mod tests {
|
||||
Arc::new(Mutex::new(UserData::default())),
|
||||
Rc::new(
|
||||
MarisaSystemUnigramLMBuilder::default()
|
||||
.set_default_cost(20_f32)
|
||||
.set_default_cost_for_short(19_f32)
|
||||
.set_unique_words(20)
|
||||
.set_total_words(19)
|
||||
.build(),
|
||||
),
|
||||
Rc::new(
|
||||
@ -238,8 +238,8 @@ mod tests {
|
||||
Arc::new(Mutex::new(UserData::default())),
|
||||
Rc::new(
|
||||
MarisaSystemUnigramLMBuilder::default()
|
||||
.set_default_cost(20_f32)
|
||||
.set_default_cost_for_short(19_f32)
|
||||
.set_unique_words(20)
|
||||
.set_total_words(19)
|
||||
.build(),
|
||||
),
|
||||
Rc::new(
|
||||
|
@ -319,8 +319,8 @@ mod tests {
|
||||
// -1 0 1 2
|
||||
// BOS a b c
|
||||
let system_unigram_lm = MarisaSystemUnigramLMBuilder::default()
|
||||
.set_default_cost(20_f32)
|
||||
.set_default_cost_for_short(19_f32)
|
||||
.set_unique_words(20)
|
||||
.set_total_words(19)
|
||||
.build();
|
||||
let system_bigram_lm = MarisaSystemBigramLMBuilder::default()
|
||||
.set_default_edge_cost(20_f32)
|
||||
@ -371,8 +371,8 @@ mod tests {
|
||||
|
||||
let mut system_unigram_lm_builder = MarisaSystemUnigramLMBuilder::default();
|
||||
let system_unigram_lm = system_unigram_lm_builder
|
||||
.set_default_cost(19_f32)
|
||||
.set_default_cost_for_short(20_f32)
|
||||
.set_unique_words(19)
|
||||
.set_total_words(20)
|
||||
.build();
|
||||
let system_bigram_lm = MarisaSystemBigramLMBuilder::default()
|
||||
.set_default_edge_cost(20_f32)
|
||||
@ -452,8 +452,8 @@ mod tests {
|
||||
|
||||
let mut system_unigram_lm_builder = MarisaSystemUnigramLMBuilder::default();
|
||||
let system_unigram_lm = system_unigram_lm_builder
|
||||
.set_default_cost(19_f32)
|
||||
.set_default_cost_for_short(20_f32)
|
||||
.set_unique_words(19)
|
||||
.set_total_words(20)
|
||||
.build();
|
||||
let system_bigram_lm = MarisaSystemBigramLMBuilder::default()
|
||||
.set_default_edge_cost(20_f32)
|
||||
|
@ -134,9 +134,9 @@ impl<U: SystemUnigramLM, B: SystemBigramLM> LatticeGraph<U, B> {
|
||||
// 労働者災害補償保険法 のように、システム辞書には wikipedia から採録されているが,
|
||||
// 言語モデルには採録されていない場合,漢字候補を先頭に持ってくる。
|
||||
// つまり、変換後のほうが短くなるもののほうをコストを安くしておく。
|
||||
self.system_unigram_lm.get_default_cost_for_short()
|
||||
self.system_unigram_lm.get_cost(1)
|
||||
} else {
|
||||
self.system_unigram_lm.get_default_cost()
|
||||
self.system_unigram_lm.get_cost(0)
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -7,8 +7,7 @@ pub trait SystemBigramLM {
|
||||
}
|
||||
|
||||
pub trait SystemUnigramLM {
|
||||
fn get_default_cost(&self) -> f32;
|
||||
fn get_default_cost_for_short(&self) -> f32;
|
||||
fn get_cost(&self, wordcnt: u32) -> f32;
|
||||
|
||||
fn find(&self, word: &str) -> Option<(i32, f32)>;
|
||||
fn as_hash_map(&self) -> HashMap<String, (i32, f32)>;
|
||||
|
@ -8,8 +8,6 @@ use crate::lm::base::SystemUnigramLM;
|
||||
pub struct OnMemorySystemUnigramLM {
|
||||
// word -> (word_id, cost)
|
||||
map: Rc<RefCell<HashMap<String, (i32, u32)>>>,
|
||||
pub default_cost: f32,
|
||||
pub default_cost_for_short: f32,
|
||||
pub total_words: u32,
|
||||
pub unique_words: u32,
|
||||
}
|
||||
@ -17,17 +15,13 @@ pub struct OnMemorySystemUnigramLM {
|
||||
impl OnMemorySystemUnigramLM {
|
||||
pub fn new(
|
||||
map: Rc<RefCell<HashMap<String, (i32, u32)>>>,
|
||||
default_cost: f32,
|
||||
default_cost_for_short: f32,
|
||||
c: u32,
|
||||
v: u32,
|
||||
total_words: u32,
|
||||
unique_words: u32,
|
||||
) -> Self {
|
||||
OnMemorySystemUnigramLM {
|
||||
map,
|
||||
default_cost,
|
||||
default_cost_for_short,
|
||||
total_words: c,
|
||||
unique_words: v,
|
||||
total_words,
|
||||
unique_words,
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,12 +51,8 @@ impl OnMemorySystemUnigramLM {
|
||||
}
|
||||
|
||||
impl SystemUnigramLM for OnMemorySystemUnigramLM {
|
||||
fn get_default_cost(&self) -> f32 {
|
||||
self.default_cost
|
||||
}
|
||||
|
||||
fn get_default_cost_for_short(&self) -> f32 {
|
||||
self.default_cost_for_short
|
||||
fn get_cost(&self, wordcnt: u32) -> f32 {
|
||||
calc_cost(wordcnt, self.total_words, self.unique_words)
|
||||
}
|
||||
|
||||
fn find(&self, word: &str) -> Option<(i32, f32)> {
|
||||
|
@ -5,6 +5,7 @@ use log::info;
|
||||
|
||||
use marisa_sys::{Keyset, Marisa};
|
||||
|
||||
use crate::cost::calc_cost;
|
||||
use crate::lm::base::SystemUnigramLM;
|
||||
|
||||
/*
|
||||
@ -14,8 +15,8 @@ use crate::lm::base::SystemUnigramLM;
|
||||
packed float # score: 4 bytes
|
||||
*/
|
||||
|
||||
const DEFAULT_COST_FOR_SHORT_KEY: &str = "__DEFAULT_COST_FOR_SHORT__";
|
||||
const DEFAULT_COST_KEY: &str = "__DEFAULT_COST__";
|
||||
const UNIQUE_WORDS_KEY: &str = "__UNIQUE_WORDS__";
|
||||
const TOTAL_WORDS_KEY: &str = "__TOTAL_WORDS__";
|
||||
|
||||
/**
|
||||
* unigram 言語モデル。
|
||||
@ -48,13 +49,13 @@ impl MarisaSystemUnigramLMBuilder {
|
||||
keyset
|
||||
}
|
||||
|
||||
pub fn set_default_cost_for_short(&mut self, cost: f32) -> &mut Self {
|
||||
self.add(DEFAULT_COST_FOR_SHORT_KEY, cost);
|
||||
pub fn set_total_words(&mut self, total_words: u32) -> &mut Self {
|
||||
self.add(TOTAL_WORDS_KEY, total_words as f32);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_default_cost(&mut self, cost: f32) -> &mut Self {
|
||||
self.add(DEFAULT_COST_KEY, cost);
|
||||
pub fn set_unique_words(&mut self, unique_words: u32) -> &mut Self {
|
||||
self.add(UNIQUE_WORDS_KEY, unique_words as f32);
|
||||
self
|
||||
}
|
||||
|
||||
@ -68,22 +69,22 @@ impl MarisaSystemUnigramLMBuilder {
|
||||
pub fn build(&self) -> MarisaSystemUnigramLM {
|
||||
let mut marisa = Marisa::default();
|
||||
marisa.build(&self.keyset());
|
||||
let (_, default_cost_for_short) =
|
||||
MarisaSystemUnigramLM::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY).unwrap();
|
||||
let (_, default_cost) =
|
||||
MarisaSystemUnigramLM::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY).unwrap();
|
||||
let (_, total_words) =
|
||||
MarisaSystemUnigramLM::find_from_trie(&marisa, TOTAL_WORDS_KEY).unwrap();
|
||||
let (_, unique_words) =
|
||||
MarisaSystemUnigramLM::find_from_trie(&marisa, UNIQUE_WORDS_KEY).unwrap();
|
||||
MarisaSystemUnigramLM {
|
||||
marisa,
|
||||
default_cost_for_short,
|
||||
default_cost,
|
||||
total_words: total_words as u32,
|
||||
unique_words: unique_words as u32,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MarisaSystemUnigramLM {
|
||||
marisa: Marisa,
|
||||
default_cost_for_short: f32,
|
||||
default_cost: f32,
|
||||
total_words: u32,
|
||||
unique_words: u32,
|
||||
}
|
||||
|
||||
impl MarisaSystemUnigramLM {
|
||||
@ -95,16 +96,16 @@ impl MarisaSystemUnigramLM {
|
||||
info!("Reading {}", fname);
|
||||
let mut marisa = Marisa::default();
|
||||
marisa.load(fname)?;
|
||||
let Some((_, default_cost_for_short)) = Self::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY) else {
|
||||
bail!("Missing key for {}", DEFAULT_COST_FOR_SHORT_KEY);
|
||||
let Some((_, total_words)) = Self::find_from_trie(&marisa, TOTAL_WORDS_KEY) else {
|
||||
bail!("Missing key for {}", TOTAL_WORDS_KEY);
|
||||
};
|
||||
let Some((_, default_cost)) = Self::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY) else {
|
||||
bail!("Missing key for {}", DEFAULT_COST_KEY);
|
||||
let Some((_, unique_words)) = Self::find_from_trie(&marisa, UNIQUE_WORDS_KEY) else {
|
||||
bail!("Missing key for {}", UNIQUE_WORDS_KEY);
|
||||
};
|
||||
Ok(MarisaSystemUnigramLM {
|
||||
marisa,
|
||||
default_cost_for_short,
|
||||
default_cost,
|
||||
total_words: total_words as u32,
|
||||
unique_words: unique_words as u32,
|
||||
})
|
||||
}
|
||||
|
||||
@ -131,12 +132,8 @@ impl MarisaSystemUnigramLM {
|
||||
}
|
||||
|
||||
impl SystemUnigramLM for MarisaSystemUnigramLM {
|
||||
fn get_default_cost(&self) -> f32 {
|
||||
self.default_cost
|
||||
}
|
||||
|
||||
fn get_default_cost_for_short(&self) -> f32 {
|
||||
self.default_cost_for_short
|
||||
fn get_cost(&self, wordcnt: u32) -> f32 {
|
||||
calc_cost(wordcnt, self.total_words, self.unique_words)
|
||||
}
|
||||
|
||||
/// @return (word_id, score)。
|
||||
@ -173,8 +170,8 @@ mod tests {
|
||||
builder.add("hello", 0.4);
|
||||
builder.add("world", 0.2);
|
||||
builder
|
||||
.set_default_cost(20_f32)
|
||||
.set_default_cost_for_short(19_f32)
|
||||
.set_total_words(2)
|
||||
.set_unique_words(2)
|
||||
.save(&tmpfile)
|
||||
.unwrap();
|
||||
|
||||
|
Reference in New Issue
Block a user