mirror of
https://github.com/mii443/akaza.git
synced 2025-08-30 10:39:27 +00:00
system language model
This commit is contained in:
2
Makefile
2
Makefile
@ -54,6 +54,7 @@ install: all comb/config.py model/jawiki.1gram install-dict
|
||||
install -m 0644 comb/engine.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
||||
install -m 0644 comb/ui.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
||||
install -m 0644 comb/user_language_model.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
||||
install -m 0644 comb/system_language_model.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
||||
install -m 0644 comb/system_dict.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
||||
install -m 0644 comb.xml $(DESTDIR)$(DATADIR)/ibus/component
|
||||
|
||||
@ -68,6 +69,7 @@ uninstall:
|
||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/node.py
|
||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/ui.py
|
||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/user_language_model.py
|
||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/system_language_model.py
|
||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/system_dict.py
|
||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/ibus.py
|
||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/model/jawiki.1gram
|
||||
|
@ -8,6 +8,7 @@ import jaconv
|
||||
from comb import combromkan
|
||||
|
||||
from comb.system_dict import SystemDict
|
||||
from comb.system_language_model import SystemLanguageModel
|
||||
from comb.user_language_model import UserLanguageModel
|
||||
from comb.graph import graph_construct, viterbi, lookup
|
||||
from comb.node import Node
|
||||
@ -29,20 +30,16 @@ class Comb:
|
||||
logger: Logger
|
||||
dictionaries: List[Any]
|
||||
|
||||
def __init__(self, user_dict: UserLanguageModel, system_dict: SystemDict,
|
||||
def __init__(self, user_language_model: UserLanguageModel, system_dict: SystemDict,
|
||||
logger: Logger = logging.getLogger(__name__)):
|
||||
self.logger = logger
|
||||
self.dictionaries = []
|
||||
self.user_dict = user_dict
|
||||
self.user_language_model = user_language_model
|
||||
self.system_dict = system_dict
|
||||
|
||||
unigram_score = marisa_trie.RecordTrie('@f')
|
||||
unigram_score.load(f"{MODEL_DIR}/jawiki.1gram")
|
||||
system_language_model = SystemLanguageModel.create()
|
||||
|
||||
bigram_score = marisa_trie.RecordTrie('@f')
|
||||
bigram_score.load(f"{MODEL_DIR}/jawiki.2gram")
|
||||
|
||||
self.language_model = LanguageModel(unigram_score, bigram_score, user_dict)
|
||||
self.language_model = LanguageModel(system_language_model, user_language_model)
|
||||
|
||||
# 連文節変換するバージョン。
|
||||
def convert2(self, src: str, force_selected_clause: List[slice] = None) -> List[List[Node]]:
|
||||
@ -99,9 +96,9 @@ class Comb:
|
||||
|
||||
candidates = [[hiragana, hiragana]]
|
||||
|
||||
for e in self.user_dict.get_candidates(src, hiragana):
|
||||
if e not in candidates:
|
||||
candidates.append(e)
|
||||
# for e in self.user_dict.get_candidates(src, hiragana):
|
||||
# if e not in candidates:
|
||||
# candidates.append(e)
|
||||
|
||||
if hiragana == 'きょう':
|
||||
# こういう類の特別なワードは、そのまま記憶してはいけない。。。
|
||||
|
@ -5,20 +5,18 @@ import math
|
||||
import marisa_trie
|
||||
|
||||
from comb.node import Node
|
||||
from comb.system_language_model import SystemLanguageModel
|
||||
from comb.user_language_model import UserLanguageModel
|
||||
|
||||
DEFAULT_SCORE = [(math.log10(0.00000000001),)]
|
||||
|
||||
|
||||
class LanguageModel:
|
||||
def __init__(self,
|
||||
system_unigram_score: marisa_trie.RecordTrie,
|
||||
system_bigram_score: marisa_trie.RecordTrie,
|
||||
system_language_model: SystemLanguageModel,
|
||||
user_language_model: UserLanguageModel,
|
||||
logger: logging.Logger = logging.getLogger(__name__)):
|
||||
self.logger = logger
|
||||
self.system_bigram_score = system_bigram_score
|
||||
self.system_unigram_score = system_unigram_score
|
||||
self.system_language_model = system_language_model
|
||||
self.user_language_model = user_language_model
|
||||
|
||||
def calc_node_cost(self, node: Node) -> float:
|
||||
@ -31,7 +29,7 @@ class LanguageModel:
|
||||
if u:
|
||||
# self.logger.info(f"Use user score: {node.get_key()} -> {u}")
|
||||
return u
|
||||
return self.system_unigram_score.get(node.get_key(), DEFAULT_SCORE)[0][0]
|
||||
return self.system_language_model.get_unigram_cost(node.get_key())
|
||||
|
||||
@functools.lru_cache
|
||||
def calc_bigram_cost(self, prev_node, next_node) -> float:
|
||||
@ -40,6 +38,4 @@ class LanguageModel:
|
||||
if u:
|
||||
self.logger.info(f"Use user's bigram score: {prev_node.get_key()},{next_node.get_key()} -> {u}")
|
||||
return u
|
||||
return self.system_bigram_score.get(
|
||||
f"{prev_node.get_key()}\t{next_node.get_key()}", DEFAULT_SCORE
|
||||
)[0][0]
|
||||
return self.system_language_model.get_bigram_cost(prev_node, next_node)
|
||||
|
33
comb/system_language_model.py
Normal file
33
comb/system_language_model.py
Normal file
@ -0,0 +1,33 @@
|
||||
import math
|
||||
|
||||
import marisa_trie
|
||||
|
||||
from comb.config import MODEL_DIR
|
||||
from comb.node import Node
|
||||
|
||||
DEFAULT_SCORE = [(math.log10(0.00000000001),)]
|
||||
|
||||
|
||||
class SystemLanguageModel:
    """Score lookups backed by a unigram trie and a bigram trie.

    Both lookups fall back to the module-level ``DEFAULT_SCORE`` when the
    requested key is absent from the corresponding trie.
    """

    def __init__(self, unigram_score: marisa_trie.RecordTrie, bigram_score: marisa_trie.RecordTrie):
        """Keep references to the two score tries.

        :param unigram_score: trie queried by ``get_unigram_cost``.
        :param bigram_score: trie queried by ``get_bigram_cost``.
        """
        self.unigram_score = unigram_score
        self.bigram_score = bigram_score

    @staticmethod
    def create():
        """Build an instance from the model files found under ``MODEL_DIR``.

        Loads ``jawiki.1gram`` and ``jawiki.2gram`` into two ``'@f'`` record
        tries and wraps them in a new ``SystemLanguageModel``.
        """
        unigrams = marisa_trie.RecordTrie('@f')
        unigrams.load(f"{MODEL_DIR}/jawiki.1gram")

        bigrams = marisa_trie.RecordTrie('@f')
        bigrams.load(f"{MODEL_DIR}/jawiki.2gram")

        return SystemLanguageModel(unigram_score=unigrams, bigram_score=bigrams)

    def get_unigram_cost(self, key: str) -> float:
        """Return the first field of the first unigram record stored for ``key``.

        ``DEFAULT_SCORE`` supplies the record when ``key`` is not in the trie.
        """
        records = self.unigram_score.get(key, DEFAULT_SCORE)
        first_record = records[0]
        return first_record[0]

    def get_bigram_cost(self, node1: Node, node2: Node) -> float:
        """Return the first field of the first bigram record for the node pair.

        The lookup key is the two nodes' ``get_key()`` values joined by a tab
        character; ``DEFAULT_SCORE`` supplies the record when it is missing.
        """
        pair_key = "\t".join((node1.get_key(), node2.get_key()))
        records = self.bigram_score.get(pair_key, DEFAULT_SCORE)
        return records[0][0]
|
@ -1,18 +1,16 @@
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from comb.combromkan import to_hiragana
|
||||
import pytest
|
||||
import marisa_trie
|
||||
from comb.system_dict import SystemDict
|
||||
from comb.graph import lookup, graph_construct, viterbi
|
||||
|
||||
from comb.engine import Comb
|
||||
from comb.system_dict import SystemDict
|
||||
from comb.user_language_model import UserLanguageModel
|
||||
|
||||
tmpfile = NamedTemporaryFile(delete=False)
|
||||
user_dict = UserLanguageModel(tmpfile.name)
|
||||
user_language_model = UserLanguageModel(tmpfile.name)
|
||||
system_dict = SystemDict()
|
||||
|
||||
comb = Comb(user_dict=user_dict, system_dict=system_dict)
|
||||
comb = Comb(user_language_model=user_language_model, system_dict=system_dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('src, expected', [
|
||||
|
@ -8,6 +8,7 @@ from comb.graph import lookup, graph_construct, viterbi
|
||||
from comb.language_model import LanguageModel
|
||||
import logging
|
||||
|
||||
from comb.system_language_model import SystemLanguageModel
|
||||
from comb.user_language_model import UserLanguageModel
|
||||
|
||||
unigram_score = marisa_trie.RecordTrie('@f')
|
||||
@ -16,10 +17,12 @@ unigram_score.load('model/jawiki.1gram')
|
||||
bigram_score = marisa_trie.RecordTrie('@f')
|
||||
bigram_score.load('model/jawiki.2gram')
|
||||
|
||||
tmpdir = TemporaryDirectory()
|
||||
user_dict = UserLanguageModel(tmpdir.name)
|
||||
system_language_model = SystemLanguageModel(unigram_score, bigram_score)
|
||||
|
||||
language_model = LanguageModel(unigram_score, bigram_score, user_dict)
|
||||
tmpdir = TemporaryDirectory()
|
||||
user_language_model = UserLanguageModel(tmpdir.name)
|
||||
|
||||
language_model = LanguageModel(system_language_model, user_language_model=user_language_model)
|
||||
|
||||
system_dict = SystemDict()
|
||||
|
||||
|
Reference in New Issue
Block a user