system language model

This commit is contained in:
Tokuhiro Matsuno
2020-09-12 19:37:23 +09:00
parent ac4158ced7
commit dd83395dd3
6 changed files with 58 additions and 29 deletions

View File

@ -54,6 +54,7 @@ install: all comb/config.py model/jawiki.1gram install-dict
install -m 0644 comb/engine.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
install -m 0644 comb/ui.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
install -m 0644 comb/user_language_model.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
install -m 0644 comb/system_language_model.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
install -m 0644 comb/system_dict.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
install -m 0644 comb.xml $(DESTDIR)$(DATADIR)/ibus/component
@ -68,6 +69,7 @@ uninstall:
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/node.py
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/ui.py
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/user_language_model.py
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/system_language_model.py
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/system_dict.py
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/ibus.py
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/model/jawiki.1gram

View File

@ -8,6 +8,7 @@ import jaconv
from comb import combromkan
from comb.system_dict import SystemDict
from comb.system_language_model import SystemLanguageModel
from comb.user_language_model import UserLanguageModel
from comb.graph import graph_construct, viterbi, lookup
from comb.node import Node
@ -29,20 +30,16 @@ class Comb:
logger: Logger
dictionaries: List[Any]
def __init__(self, user_dict: UserLanguageModel, system_dict: SystemDict,
def __init__(self, user_language_model: UserLanguageModel, system_dict: SystemDict,
logger: Logger = logging.getLogger(__name__)):
self.logger = logger
self.dictionaries = []
self.user_dict = user_dict
self.user_language_model = user_language_model
self.system_dict = system_dict
unigram_score = marisa_trie.RecordTrie('@f')
unigram_score.load(f"{MODEL_DIR}/jawiki.1gram")
system_language_model = SystemLanguageModel.create()
bigram_score = marisa_trie.RecordTrie('@f')
bigram_score.load(f"{MODEL_DIR}/jawiki.2gram")
self.language_model = LanguageModel(unigram_score, bigram_score, user_dict)
self.language_model = LanguageModel(system_language_model, user_language_model)
# 連文節変換するバージョン。
def convert2(self, src: str, force_selected_clause: List[slice] = None) -> List[List[Node]]:
@ -99,9 +96,9 @@ class Comb:
candidates = [[hiragana, hiragana]]
for e in self.user_dict.get_candidates(src, hiragana):
if e not in candidates:
candidates.append(e)
# for e in self.user_dict.get_candidates(src, hiragana):
# if e not in candidates:
# candidates.append(e)
if hiragana == 'きょう':
# こういう類の特別なワードは、そのまま記憶してはいけない。。。

View File

@ -5,20 +5,18 @@ import math
import marisa_trie
from comb.node import Node
from comb.system_language_model import SystemLanguageModel
from comb.user_language_model import UserLanguageModel
DEFAULT_SCORE = [(math.log10(0.00000000001),)]
class LanguageModel:
def __init__(self,
system_unigram_score: marisa_trie.RecordTrie,
system_bigram_score: marisa_trie.RecordTrie,
system_language_model: SystemLanguageModel,
user_language_model: UserLanguageModel,
logger: logging.Logger = logging.getLogger(__name__)):
self.logger = logger
self.system_bigram_score = system_bigram_score
self.system_unigram_score = system_unigram_score
self.system_language_model = system_language_model
self.user_language_model = user_language_model
def calc_node_cost(self, node: Node) -> float:
@ -31,7 +29,7 @@ class LanguageModel:
if u:
# self.logger.info(f"Use user score: {node.get_key()} -> {u}")
return u
return self.system_unigram_score.get(node.get_key(), DEFAULT_SCORE)[0][0]
return self.system_language_model.get_unigram_cost(node.get_key())
@functools.lru_cache
def calc_bigram_cost(self, prev_node, next_node) -> float:
@ -40,6 +38,4 @@ class LanguageModel:
if u:
self.logger.info(f"Use user's bigram score: {prev_node.get_key()},{next_node.get_key()} -> {u}")
return u
return self.system_bigram_score.get(
f"{prev_node.get_key()}\t{next_node.get_key()}", DEFAULT_SCORE
)[0][0]
return self.system_language_model.get_bigram_cost(prev_node, next_node)

View File

@ -0,0 +1,33 @@
import math
from typing import Optional

import marisa_trie

from comb.config import MODEL_DIR
from comb.node import Node
# Fallback lookup result for keys that are missing from a trie: log10(1e-11),
# wrapped as a list holding one 1-tuple so that ``[0][0]`` can be applied to it
# exactly as it is applied to the result of a successful ``get``.
DEFAULT_SCORE = [(math.log10(0.00000000001),)]


class SystemLanguageModel:
    """Unigram/bigram score lookup backed by two pre-built tries.

    Both tries are only accessed through ``get(key, default)``; the first
    field of the first record found is returned as the cost.
    NOTE(review): the stored values are presumably log10 probabilities
    (the fallback is ``log10(1e-11)``) -- confirm against the model builder.
    """

    # The annotations are quoted so they are not evaluated when the class is
    # defined; the names are only needed by type checkers.
    def __init__(self, unigram_score: "marisa_trie.RecordTrie",
                 bigram_score: "marisa_trie.RecordTrie"):
        """Wrap already-loaded unigram and bigram tries.

        Args:
            unigram_score: Trie queried with a single node key.
            bigram_score: Trie queried with ``"<key1>\\t<key2>"``.
        """
        self.unigram_score = unigram_score
        self.bigram_score = bigram_score

    @classmethod
    def create(cls, model_dir: Optional[str] = None) -> "SystemLanguageModel":
        """Load ``jawiki.1gram`` and ``jawiki.2gram`` and build a model.

        Args:
            model_dir: Directory containing the two model files. When
                omitted, the configured ``MODEL_DIR`` is used, which is what
                the parameterless call has always done.

        Returns:
            A new instance of ``cls`` (so subclasses get their own type).
        """
        if model_dir is None:
            model_dir = MODEL_DIR

        # '@f' is the struct format handed to RecordTrie for each record.
        unigram_score = marisa_trie.RecordTrie('@f')
        unigram_score.load(f"{model_dir}/jawiki.1gram")

        bigram_score = marisa_trie.RecordTrie('@f')
        bigram_score.load(f"{model_dir}/jawiki.2gram")

        return cls(unigram_score, bigram_score)

    def get_unigram_cost(self, key: str) -> float:
        """Return the unigram cost stored for ``key``.

        Falls back to ``DEFAULT_SCORE`` when the key is not in the trie.
        """
        return self.unigram_score.get(key, DEFAULT_SCORE)[0][0]

    def get_bigram_cost(self, node1: "Node", node2: "Node") -> float:
        """Return the bigram cost for ``node1`` followed by ``node2``.

        The lookup key is the two node keys joined by a tab character.
        Falls back to ``DEFAULT_SCORE`` when the pair is not in the trie.
        """
        key = "\t".join((node1.get_key(), node2.get_key()))
        return self.bigram_score.get(key, DEFAULT_SCORE)[0][0]

View File

@ -1,18 +1,16 @@
from tempfile import NamedTemporaryFile
from comb.combromkan import to_hiragana
import pytest
import marisa_trie
from comb.system_dict import SystemDict
from comb.graph import lookup, graph_construct, viterbi
from comb.engine import Comb
from comb.system_dict import SystemDict
from comb.user_language_model import UserLanguageModel
tmpfile = NamedTemporaryFile(delete=False)
user_dict = UserLanguageModel(tmpfile.name)
user_language_model = UserLanguageModel(tmpfile.name)
system_dict = SystemDict()
comb = Comb(user_dict=user_dict, system_dict=system_dict)
comb = Comb(user_language_model=user_language_model, system_dict=system_dict)
@pytest.mark.parametrize('src, expected', [

View File

@ -8,6 +8,7 @@ from comb.graph import lookup, graph_construct, viterbi
from comb.language_model import LanguageModel
import logging
from comb.system_language_model import SystemLanguageModel
from comb.user_language_model import UserLanguageModel
unigram_score = marisa_trie.RecordTrie('@f')
@ -16,10 +17,12 @@ unigram_score.load('model/jawiki.1gram')
bigram_score = marisa_trie.RecordTrie('@f')
bigram_score.load('model/jawiki.2gram')
tmpdir = TemporaryDirectory()
user_dict = UserLanguageModel(tmpdir.name)
system_language_model = SystemLanguageModel(unigram_score, bigram_score)
language_model = LanguageModel(unigram_score, bigram_score, user_dict)
tmpdir = TemporaryDirectory()
user_language_model = UserLanguageModel(tmpdir.name)
language_model = LanguageModel(system_language_model, user_language_model=user_language_model)
system_dict = SystemDict()