Files
akaza/comb/language_model.py
2020-09-12 11:36:22 +09:00

46 lines
1.5 KiB
Python

import functools
import logging
import math
import marisa_trie
from comb.node import Node
from comb.user_dict import UserDict
# Fallback handed to the score tries' ``get`` when a key is absent.  It has the
# same list-of-tuples shape that the lookups below index with ``[0][0]``, so a
# miss and a hit are handled by the same expression.  The value is log10(1e-11).
DEFAULT_SCORE = [(math.log10(1e-11),)]


class LanguageModel:
    """Score nodes and node-to-node transitions.

    Every lookup asks the user dictionary first and falls back to the system
    unigram / bigram tries, and finally to ``DEFAULT_SCORE`` when the system
    trie has no entry either.
    """

    def __init__(self,
                 system_unigram_score: marisa_trie.RecordTrie,
                 system_bigram_score: marisa_trie.RecordTrie,
                 user_dict: UserDict,
                 logger: logging.Logger = logging.getLogger(__name__)):
        """Store the score sources.

        Args:
            system_unigram_score: Trie queried with a node key.
            system_bigram_score: Trie queried with
                ``"<prev key>\\t<next key>"``.
            user_dict: Source of user scores; consulted before the system
                tries.
            logger: Logger used to report when a user bigram score is used.
        """
        self.logger = logger
        self.system_bigram_score = system_bigram_score
        self.system_unigram_score = system_unigram_score
        self.user_dict = user_dict

    def calc_node_cost(self, node: Node) -> float:
        """Return the unigram score of ``node``.

        BOS and EOS nodes always score ``0.0``.  For any other node the user
        dictionary score is returned when it is truthy; otherwise the system
        unigram trie is consulted, with ``DEFAULT_SCORE`` as the fallback.
        """
        if node.is_bos() or node.is_eos():
            return 0.0

        key = node.get_key()
        user_score = self.user_dict.get_unigram_cost(key)
        # Truthiness check on purpose: a falsy answer from the user dictionary
        # (no entry, or a score of exactly 0) falls through to the system model.
        if user_score:
            return user_score
        return self.system_unigram_score.get(key, DEFAULT_SCORE)[0][0]

    def calc_bigram_cost(self, prev_node: Node, next_node: Node) -> float:
        """Return the score of the transition ``prev_node`` -> ``next_node``.

        The user dictionary is asked first; when it has no truthy score the
        system bigram trie is looked up under ``"<prev key>\\t<next key>"``,
        with ``DEFAULT_SCORE`` as the fallback.

        This method is intentionally not wrapped in ``functools.lru_cache``:
        a cache on an instance method is stored on the class, keeps ``self``
        and every cached node pair alive, and would keep returning a score
        obtained from the user dictionary even after the user dictionary
        starts returning a different one.
        """
        user_score = self.user_dict.get_bigram_cost(prev_node, next_node)
        # Same truthiness rule as in calc_node_cost.
        if user_score:
            # Lazy %-style arguments: the message is only built when the
            # record is actually going to be emitted.
            self.logger.info(
                "Use user's bigram score: %s,%s -> %s",
                prev_node.get_key(), next_node.get_key(), user_score,
            )
            return user_score

        bigram_key = f"{prev_node.get_key()}\t{next_node.get_key()}"
        return self.system_bigram_score.get(bigram_key, DEFAULT_SCORE)[0][0]