Merge the two system_language_model trie files (jawiki.1gram and jawiki.2gram) into a single file (system_language_model.trie)

This commit is contained in:
Tokuhiro Matsuno
2020-09-14 09:49:48 +09:00
parent aacb9e5b8d
commit 1c061b367d
6 changed files with 39 additions and 36 deletions

View File

@@ -9,25 +9,21 @@ DEFAULT_SCORE = [(math.log10(0.00000000001),)]
 class SystemLanguageModel:
-    def __init__(self, unigram_score: marisa_trie.RecordTrie, bigram_score: marisa_trie.RecordTrie):
-        self.unigram_score = unigram_score
-        self.bigram_score = bigram_score
+    def __init__(self, score: marisa_trie.RecordTrie):
+        self.score = score
     @staticmethod
     def create():
-        unigram_score = marisa_trie.RecordTrie('@f')
-        unigram_score.mmap(f"{MODEL_DIR}/jawiki.1gram")
+        score = marisa_trie.RecordTrie('@f')
+        score.mmap(f"{MODEL_DIR}/system_language_model.trie")
-        bigram_score = marisa_trie.RecordTrie('@f')
-        bigram_score.mmap(f"{MODEL_DIR}/jawiki.2gram")
-        return SystemLanguageModel(unigram_score, bigram_score)
+        return SystemLanguageModel(score)
     def get_unigram_cost(self, key: str) -> float:
-        return self.unigram_score.get(key, DEFAULT_SCORE)[0][0]
+        return self.score.get(key, DEFAULT_SCORE)[0][0]
     def get_bigram_cost(self, node1: Node, node2: Node) -> float:
         key1 = node1.get_key()
         key2 = node2.get_key()
         key = key1 + "\t" + key2
-        return self.bigram_score.get(key, DEFAULT_SCORE)[0][0]
+        return self.score.get(key, DEFAULT_SCORE)[0][0]