mirror of
https://github.com/mii443/akaza.git
synced 2025-08-23 15:22:21 +00:00
42 lines
1.4 KiB
Python
42 lines
1.4 KiB
Python
import functools
|
|
import logging
|
|
import math
|
|
|
|
import marisa_trie
|
|
|
|
from comb.node import Node
|
|
from comb.system_language_model import SystemLanguageModel
|
|
from comb.user_language_model import UserLanguageModel
|
|
|
|
|
|
|
|
class LanguageModel:
|
|
def __init__(self,
|
|
system_language_model: SystemLanguageModel,
|
|
user_language_model: UserLanguageModel,
|
|
logger: logging.Logger = logging.getLogger(__name__)):
|
|
self.logger = logger
|
|
self.system_language_model = system_language_model
|
|
self.user_language_model = user_language_model
|
|
|
|
def calc_node_cost(self, node: Node) -> float:
|
|
if node.is_bos():
|
|
return 0
|
|
elif node.is_eos():
|
|
return 0
|
|
else:
|
|
u = self.user_language_model.get_unigram_cost(node.get_key())
|
|
if u:
|
|
# self.logger.info(f"Use user score: {node.get_key()} -> {u}")
|
|
return u
|
|
return self.system_language_model.get_unigram_cost(node.get_key())
|
|
|
|
@functools.lru_cache
|
|
def calc_bigram_cost(self, prev_node, next_node) -> float:
|
|
# self → node で処理する。
|
|
u = self.user_language_model.get_bigram_cost(prev_node, next_node)
|
|
if u:
|
|
self.logger.info(f"Use user's bigram score: {prev_node.get_key()},{next_node.get_key()} -> {u}")
|
|
return u
|
|
return self.system_language_model.get_bigram_cost(prev_node, next_node)
|