mirror of
https://github.com/mii443/akaza.git
synced 2025-08-28 17:49:28 +00:00
system language model
This commit is contained in:
33
comb/system_language_model.py
Normal file
33
comb/system_language_model.py
Normal file
@ -0,0 +1,33 @@
|
||||
import math
|
||||
|
||||
import marisa_trie
|
||||
|
||||
from comb.config import MODEL_DIR
|
||||
from comb.node import Node
|
||||
|
||||
DEFAULT_SCORE = [(math.log10(0.00000000001),)]
|
||||
|
||||
|
||||
class SystemLanguageModel:
|
||||
def __init__(self, unigram_score: marisa_trie.RecordTrie, bigram_score: marisa_trie.RecordTrie):
|
||||
self.unigram_score = unigram_score
|
||||
self.bigram_score = bigram_score
|
||||
|
||||
@staticmethod
|
||||
def create():
|
||||
unigram_score = marisa_trie.RecordTrie('@f')
|
||||
unigram_score.load(f"{MODEL_DIR}/jawiki.1gram")
|
||||
|
||||
bigram_score = marisa_trie.RecordTrie('@f')
|
||||
bigram_score.load(f"{MODEL_DIR}/jawiki.2gram")
|
||||
|
||||
return SystemLanguageModel(unigram_score, bigram_score)
|
||||
|
||||
def get_unigram_cost(self, key: str) -> float:
|
||||
return self.unigram_score.get(key, DEFAULT_SCORE)[0][0]
|
||||
|
||||
def get_bigram_cost(self, node1: Node, node2: Node) -> float:
|
||||
key1 = node1.get_key()
|
||||
key2 = node2.get_key()
|
||||
key = key1 + "\t" + key2
|
||||
return self.bigram_score.get(key, DEFAULT_SCORE)[0][0]
|
Reference in New Issue
Block a user