コストの計算を利用時にするように変更。 (#275)

今後の #252 をやりやすくするための施策。

モデルのフォーマットが変わるので注意。
This commit is contained in:
Tokuhiro Matsuno
2023-02-01 01:22:23 +09:00
committed by GitHub
parent 20a5d4224f
commit 9d8657316b
8 changed files with 52 additions and 80 deletions

View File

@ -67,8 +67,6 @@ impl LearningService {
}
let system_unigram_lm = Rc::new(OnMemorySystemUnigramLM::new(
Rc::new(RefCell::new(unigram_map)),
src_system_unigram_lm.get_default_cost(),
src_system_unigram_lm.get_default_cost_for_short(),
src_system_unigram_lm.total_words,
src_system_unigram_lm.unique_words,
));
@ -190,9 +188,8 @@ impl LearningService {
}
// ↓本来なら現在のデータで再調整すべきだが、一旦元のものを使う。
// TODO あとで整理する
unigram_builder.set_default_cost(self.system_unigram_lm.get_default_cost());
unigram_builder
.set_default_cost_for_short(self.system_unigram_lm.get_default_cost_for_short());
unigram_builder.set_unique_words(self.system_unigram_lm.unique_words);
unigram_builder.set_total_words(self.system_unigram_lm.total_words);
info!("Save unigram to {}", dst_unigram);
unigram_builder.save(dst_unigram)?;
Ok(())

View File

@ -48,8 +48,6 @@ impl WordcntUnigramBuilder {
pub struct WordcntUnigram {
marisa: Marisa,
default_cost: f32,
default_cost_for_short: f32,
pub(crate) total_words: u32,
pub(crate) unique_words: u32,
}
@ -88,13 +86,8 @@ impl WordcntUnigram {
// 単語の種類数
let unique_words = map.keys().count() as u32;
let default_cost = calc_cost(0, total_words, unique_words);
let default_cost_for_short = calc_cost(1, total_words, unique_words);
Ok(WordcntUnigram {
marisa,
default_cost,
default_cost_for_short,
total_words,
unique_words,
})
@ -102,12 +95,8 @@ impl WordcntUnigram {
}
impl SystemUnigramLM for WordcntUnigram {
fn get_default_cost(&self) -> f32 {
self.default_cost
}
fn get_default_cost_for_short(&self) -> f32 {
self.default_cost_for_short
fn get_cost(&self, wordcnt: u32) -> f32 {
calc_cost(wordcnt, self.total_words, self.unique_words)
}
/// @return (word_id, score)。
@ -181,8 +170,8 @@ mod tests {
);
assert_eq!(wordcnt.total_words, 45); // 単語発生数
assert_eq!(wordcnt.unique_words, 2); // ユニーク単語数
assert_eq!(wordcnt.get_default_cost(), 6.672098);
assert_eq!(wordcnt.get_default_cost_for_short(), 1.6720936);
assert_eq!(wordcnt.get_cost(0), 6.672098);
assert_eq!(wordcnt.get_cost(1), 1.6720936);
assert_eq!(wordcnt.find("私/わたし"), Some((1_i32, 1.1949753)));
assert_eq!(wordcnt.find("彼/かれ"), Some((0_i32, 0.048848562)));

View File

@ -172,8 +172,8 @@ mod tests {
Arc::new(Mutex::new(UserData::default())),
Rc::new(
MarisaSystemUnigramLMBuilder::default()
.set_default_cost(20_f32)
.set_default_cost_for_short(19_f32)
.set_unique_words(20)
.set_total_words(19)
.build(),
),
Rc::new(
@ -205,8 +205,8 @@ mod tests {
Arc::new(Mutex::new(UserData::default())),
Rc::new(
MarisaSystemUnigramLMBuilder::default()
.set_default_cost(20_f32)
.set_default_cost_for_short(19_f32)
.set_unique_words(20)
.set_total_words(19)
.build(),
),
Rc::new(
@ -238,8 +238,8 @@ mod tests {
Arc::new(Mutex::new(UserData::default())),
Rc::new(
MarisaSystemUnigramLMBuilder::default()
.set_default_cost(20_f32)
.set_default_cost_for_short(19_f32)
.set_unique_words(20)
.set_total_words(19)
.build(),
),
Rc::new(

View File

@ -319,8 +319,8 @@ mod tests {
// -1 0 1 2
// BOS a b c
let system_unigram_lm = MarisaSystemUnigramLMBuilder::default()
.set_default_cost(20_f32)
.set_default_cost_for_short(19_f32)
.set_unique_words(20)
.set_total_words(19)
.build();
let system_bigram_lm = MarisaSystemBigramLMBuilder::default()
.set_default_edge_cost(20_f32)
@ -371,8 +371,8 @@ mod tests {
let mut system_unigram_lm_builder = MarisaSystemUnigramLMBuilder::default();
let system_unigram_lm = system_unigram_lm_builder
.set_default_cost(19_f32)
.set_default_cost_for_short(20_f32)
.set_unique_words(19)
.set_total_words(20)
.build();
let system_bigram_lm = MarisaSystemBigramLMBuilder::default()
.set_default_edge_cost(20_f32)
@ -452,8 +452,8 @@ mod tests {
let mut system_unigram_lm_builder = MarisaSystemUnigramLMBuilder::default();
let system_unigram_lm = system_unigram_lm_builder
.set_default_cost(19_f32)
.set_default_cost_for_short(20_f32)
.set_unique_words(19)
.set_total_words(20)
.build();
let system_bigram_lm = MarisaSystemBigramLMBuilder::default()
.set_default_edge_cost(20_f32)

View File

@ -134,9 +134,9 @@ impl<U: SystemUnigramLM, B: SystemBigramLM> LatticeGraph<U, B> {
// 労働者災害補償保険法 のように、システム辞書には wikipedia から採録されているが,
// 言語モデルには採録されていない場合,漢字候補を先頭に持ってくる。
// つまり、変換後のほうが短くなるもののほうをコストを安くしておく。
self.system_unigram_lm.get_default_cost_for_short()
self.system_unigram_lm.get_cost(1)
} else {
self.system_unigram_lm.get_default_cost()
self.system_unigram_lm.get_cost(0)
};
}

View File

@ -7,8 +7,7 @@ pub trait SystemBigramLM {
}
pub trait SystemUnigramLM {
fn get_default_cost(&self) -> f32;
fn get_default_cost_for_short(&self) -> f32;
fn get_cost(&self, wordcnt: u32) -> f32;
fn find(&self, word: &str) -> Option<(i32, f32)>;
fn as_hash_map(&self) -> HashMap<String, (i32, f32)>;

View File

@ -8,8 +8,6 @@ use crate::lm::base::SystemUnigramLM;
pub struct OnMemorySystemUnigramLM {
// word -> (word_id, cost)
map: Rc<RefCell<HashMap<String, (i32, u32)>>>,
pub default_cost: f32,
pub default_cost_for_short: f32,
pub total_words: u32,
pub unique_words: u32,
}
@ -17,17 +15,13 @@ pub struct OnMemorySystemUnigramLM {
impl OnMemorySystemUnigramLM {
pub fn new(
map: Rc<RefCell<HashMap<String, (i32, u32)>>>,
default_cost: f32,
default_cost_for_short: f32,
c: u32,
v: u32,
total_words: u32,
unique_words: u32,
) -> Self {
OnMemorySystemUnigramLM {
map,
default_cost,
default_cost_for_short,
total_words: c,
unique_words: v,
total_words,
unique_words,
}
}
@ -57,12 +51,8 @@ impl OnMemorySystemUnigramLM {
}
impl SystemUnigramLM for OnMemorySystemUnigramLM {
fn get_default_cost(&self) -> f32 {
self.default_cost
}
fn get_default_cost_for_short(&self) -> f32 {
self.default_cost_for_short
fn get_cost(&self, wordcnt: u32) -> f32 {
calc_cost(wordcnt, self.total_words, self.unique_words)
}
fn find(&self, word: &str) -> Option<(i32, f32)> {

View File

@ -5,6 +5,7 @@ use log::info;
use marisa_sys::{Keyset, Marisa};
use crate::cost::calc_cost;
use crate::lm::base::SystemUnigramLM;
/*
@ -14,8 +15,8 @@ use crate::lm::base::SystemUnigramLM;
packed float # score: 4 bytes
*/
const DEFAULT_COST_FOR_SHORT_KEY: &str = "__DEFAULT_COST_FOR_SHORT__";
const DEFAULT_COST_KEY: &str = "__DEFAULT_COST__";
const UNIQUE_WORDS_KEY: &str = "__UNIQUE_WORDS__";
const TOTAL_WORDS_KEY: &str = "__TOTAL_WORDS__";
/**
* unigram 言語モデル。
@ -48,13 +49,13 @@ impl MarisaSystemUnigramLMBuilder {
keyset
}
pub fn set_default_cost_for_short(&mut self, cost: f32) -> &mut Self {
self.add(DEFAULT_COST_FOR_SHORT_KEY, cost);
pub fn set_total_words(&mut self, total_words: u32) -> &mut Self {
self.add(TOTAL_WORDS_KEY, total_words as f32);
self
}
pub fn set_default_cost(&mut self, cost: f32) -> &mut Self {
self.add(DEFAULT_COST_KEY, cost);
pub fn set_unique_words(&mut self, unique_words: u32) -> &mut Self {
self.add(UNIQUE_WORDS_KEY, unique_words as f32);
self
}
@ -68,22 +69,22 @@ impl MarisaSystemUnigramLMBuilder {
pub fn build(&self) -> MarisaSystemUnigramLM {
let mut marisa = Marisa::default();
marisa.build(&self.keyset());
let (_, default_cost_for_short) =
MarisaSystemUnigramLM::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY).unwrap();
let (_, default_cost) =
MarisaSystemUnigramLM::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY).unwrap();
let (_, total_words) =
MarisaSystemUnigramLM::find_from_trie(&marisa, TOTAL_WORDS_KEY).unwrap();
let (_, unique_words) =
MarisaSystemUnigramLM::find_from_trie(&marisa, UNIQUE_WORDS_KEY).unwrap();
MarisaSystemUnigramLM {
marisa,
default_cost_for_short,
default_cost,
total_words: total_words as u32,
unique_words: unique_words as u32,
}
}
}
pub struct MarisaSystemUnigramLM {
marisa: Marisa,
default_cost_for_short: f32,
default_cost: f32,
total_words: u32,
unique_words: u32,
}
impl MarisaSystemUnigramLM {
@ -95,16 +96,16 @@ impl MarisaSystemUnigramLM {
info!("Reading {}", fname);
let mut marisa = Marisa::default();
marisa.load(fname)?;
let Some((_, default_cost_for_short)) = Self::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY) else {
bail!("Missing key for {}", DEFAULT_COST_FOR_SHORT_KEY);
let Some((_, total_words)) = Self::find_from_trie(&marisa, TOTAL_WORDS_KEY) else {
bail!("Missing key for {}", TOTAL_WORDS_KEY);
};
let Some((_, default_cost)) = Self::find_from_trie(&marisa, DEFAULT_COST_FOR_SHORT_KEY) else {
bail!("Missing key for {}", DEFAULT_COST_KEY);
let Some((_, unique_words)) = Self::find_from_trie(&marisa, UNIQUE_WORDS_KEY) else {
bail!("Missing key for {}", UNIQUE_WORDS_KEY);
};
Ok(MarisaSystemUnigramLM {
marisa,
default_cost_for_short,
default_cost,
total_words: total_words as u32,
unique_words: unique_words as u32,
})
}
@ -131,12 +132,8 @@ impl MarisaSystemUnigramLM {
}
impl SystemUnigramLM for MarisaSystemUnigramLM {
fn get_default_cost(&self) -> f32 {
self.default_cost
}
fn get_default_cost_for_short(&self) -> f32 {
self.default_cost_for_short
fn get_cost(&self, wordcnt: u32) -> f32 {
calc_cost(wordcnt, self.total_words, self.unique_words)
}
/// @return (word_id, score)。
@ -173,8 +170,8 @@ mod tests {
builder.add("hello", 0.4);
builder.add("world", 0.2);
builder
.set_default_cost(20_f32)
.set_default_cost_for_short(19_f32)
.set_total_words(2)
.set_unique_words(2)
.save(&tmpfile)
.unwrap();