mirror of
https://github.com/mii443/akaza.git
synced 2025-08-22 14:55:31 +00:00
30 lines
766 B
Python
30 lines
766 B
Python
import math
|
|
|
|
import marisa_trie
|
|
|
|
from comb.config import MODEL_DIR
|
|
from comb.node import Node
|
|
|
|
# Fallback passed as the default to the trie's ``get`` when a key is missing.
# It is a one-element list holding a one-field tuple so that callers can index
# it with ``[0][0]`` exactly as they do a real lookup result.
# The value is log10(1e-11).
DEFAULT_SCORE = [(math.log10(1e-11),)]
|
|
|
|
|
|
class SystemLanguageModel:
    """Score lookups backed by a ``marisa_trie.RecordTrie``.

    Unigram scores are looked up by a plain string key; bigram scores are
    looked up by the two node keys joined with a tab character. When a key is
    absent, the module-level ``DEFAULT_SCORE`` is used instead.
    """

    def __init__(self, score: marisa_trie.RecordTrie):
        """Wrap an already-loaded record trie.

        Args:
            score: Trie whose records carry the score as their first field.
        """
        self.score = score

    @staticmethod
    def create():
        """Return a model backed by the memory-mapped system trie.

        The trie file is ``system_language_model.trie`` under ``MODEL_DIR``
        and is opened with the record format ``'@f'``.
        """
        trie = marisa_trie.RecordTrie('@f')
        trie.mmap(f"{MODEL_DIR}/system_language_model.trie")
        return SystemLanguageModel(trie)

    def get_unigram_cost(self, key: str) -> float:
        """Return the first field of the first record stored under ``key``.

        Falls back to ``DEFAULT_SCORE`` when the trie has no such key.
        """
        records = self.score.get(key, DEFAULT_SCORE)
        return records[0][0]

    def get_bigram_cost(self, node1: Node, node2: Node) -> float:
        """Return the score stored for the pair ``(node1, node2)``.

        The lookup key is ``node1.get_key()``, a tab, then
        ``node2.get_key()``. Falls back to ``DEFAULT_SCORE`` when the trie
        has no such key.
        """
        left = node1.get_key()
        right = node2.get_key()
        records = self.score.get(left + "\t" + right, DEFAULT_SCORE)
        return records[0][0]
|