system language model

This commit is contained in:
Tokuhiro Matsuno
2020-09-12 19:37:23 +09:00
parent ac4158ced7
commit dd83395dd3
6 changed files with 58 additions and 29 deletions

View File

@@ -8,6 +8,7 @@ import jaconv
from comb import combromkan
from comb.system_dict import SystemDict
from comb.system_language_model import SystemLanguageModel
from comb.user_language_model import UserLanguageModel
from comb.graph import graph_construct, viterbi, lookup
from comb.node import Node
@@ -29,20 +30,16 @@ class Comb:
logger: Logger
dictionaries: List[Any]
def __init__(self, user_dict: UserLanguageModel, system_dict: SystemDict,
def __init__(self, user_language_model: UserLanguageModel, system_dict: SystemDict,
logger: Logger = logging.getLogger(__name__)):
self.logger = logger
self.dictionaries = []
self.user_dict = user_dict
self.user_language_model = user_language_model
self.system_dict = system_dict
unigram_score = marisa_trie.RecordTrie('@f')
unigram_score.load(f"{MODEL_DIR}/jawiki.1gram")
system_language_model = SystemLanguageModel.create()
bigram_score = marisa_trie.RecordTrie('@f')
bigram_score.load(f"{MODEL_DIR}/jawiki.2gram")
self.language_model = LanguageModel(unigram_score, bigram_score, user_dict)
self.language_model = LanguageModel(system_language_model, user_language_model)
# 連文節変換するバージョン。
def convert2(self, src: str, force_selected_clause: List[slice] = None) -> List[List[Node]]:
@@ -99,9 +96,9 @@ class Comb:
candidates = [[hiragana, hiragana]]
for e in self.user_dict.get_candidates(src, hiragana):
if e not in candidates:
candidates.append(e)
# for e in self.user_dict.get_candidates(src, hiragana):
# if e not in candidates:
# candidates.append(e)
if hiragana == 'きょう':
# こういう類の特別なワードは、そのまま記憶してはいけない。。。