# Source file: akaza/comb/language_model.py
# Snapshot: 2020-09-12 19:37:23 +09:00
# Size: 42 lines, 1.4 KiB (Python)

import functools
import logging
import math
import marisa_trie
from comb.node import Node
from comb.system_language_model import SystemLanguageModel
from comb.user_language_model import UserLanguageModel
class LanguageModel:
    """Cost lookup that asks the user language model before the system one.

    Both models are accessed only through ``get_unigram_cost`` and
    ``get_bigram_cost``.  Whenever the user model yields a truthy value it is
    returned as-is; otherwise the system model's answer is returned.
    """

    def __init__(self,
                 system_language_model: "SystemLanguageModel",
                 user_language_model: "UserLanguageModel",
                 logger: logging.Logger = logging.getLogger(__name__)):
        """Store the two models and the logger on the instance.

        Args:
            system_language_model: Fallback model, used when the user model
                has no (truthy) cost.
            user_language_model: Model that is consulted first.
            logger: Logger for the "user score used" message.  The default
                is created once, when the class body is executed, and is the
                logger named after this module.
        """
        self.logger = logger
        self.system_language_model = system_language_model
        self.user_language_model = user_language_model

    def calc_node_cost(self, node: "Node") -> float:
        """Return the unigram cost of *node*.

        BOS and EOS nodes always cost ``0.0``.  For any other node the user
        model is asked first, using ``node.get_key()`` as the lookup key, and
        the system model is used when the user model has no truthy answer.

        Args:
            node: Node providing ``is_bos()``, ``is_eos()`` and ``get_key()``.

        Returns:
            The cost from the user model if truthy, else whatever the system
            model's ``get_unigram_cost`` returns.
        """
        # Sentence boundary markers carry no cost of their own.
        if node.is_bos() or node.is_eos():
            return 0.0

        key = node.get_key()
        user_cost = self.user_language_model.get_unigram_cost(key)
        # NOTE(review): this is a truthiness test, so a user cost of exactly
        # 0 is treated like "no user score".  If the user model signals a
        # miss with None, `is not None` may be the intended check -- confirm
        # against UserLanguageModel before changing it.
        if user_cost:
            return user_cost
        return self.system_language_model.get_unigram_cost(key)

    # maxsize is spelled out (128 is the functools default) because the bare
    # ``@functools.lru_cache`` form is only accepted from Python 3.8 onwards.
    @functools.lru_cache(maxsize=128)
    def calc_bigram_cost(self, prev_node, next_node) -> float:
        """Return the bigram cost of the transition prev_node -> next_node.

        The user model is asked first; a truthy answer is logged at INFO
        level and returned.  Otherwise the system model's answer is returned.

        Results are memoized by ``functools.lru_cache``.  The cache lives on
        the class and is keyed on ``(self, prev_node, next_node)``, so all
        three must be hashable, and the cached instance and nodes stay
        referenced until their entries are evicted.

        NOTE(review): a memoized value is returned even if either model would
        now answer differently for the same arguments.  If the user model is
        updated while this object is alive, stale costs may be served --
        confirm whether ``calc_bigram_cost.cache_clear()`` needs to be called
        after such updates.

        Args:
            prev_node: Node on the left-hand side of the transition.
            next_node: Node on the right-hand side of the transition.

        Returns:
            The cost from the user model if truthy, else whatever the system
            model's ``get_bigram_cost`` returns.
        """
        # Translated original note: "process as self -> node"; presumably
        # this refers to the prev_node -> next_node direction used here.
        user_cost = self.user_language_model.get_bigram_cost(prev_node, next_node)
        # NOTE(review): truthiness test -- a user cost of exactly 0 falls
        # through to the system model (same caveat as in calc_node_cost).
        if user_cost:
            self.logger.info(f"Use user's bigram score: {prev_node.get_key()},{next_node.get_key()} -> {user_cost}")
            return user_cost
        return self.system_language_model.get_bigram_cost(prev_node, next_node)