mirror of
https://github.com/mii443/akaza.git
synced 2025-08-30 02:29:30 +00:00
system language model
This commit is contained in:
2
Makefile
2
Makefile
@ -54,6 +54,7 @@ install: all comb/config.py model/jawiki.1gram install-dict
|
|||||||
install -m 0644 comb/engine.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
install -m 0644 comb/engine.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
||||||
install -m 0644 comb/ui.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
install -m 0644 comb/ui.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
||||||
install -m 0644 comb/user_language_model.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
install -m 0644 comb/user_language_model.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
||||||
|
install -m 0644 comb/system_language_model.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
||||||
install -m 0644 comb/system_dict.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
install -m 0644 comb/system_dict.py $(DESTDIR)$(DATADIR)/ibus-comb/comb/
|
||||||
install -m 0644 comb.xml $(DESTDIR)$(DATADIR)/ibus/component
|
install -m 0644 comb.xml $(DESTDIR)$(DATADIR)/ibus/component
|
||||||
|
|
||||||
@ -68,6 +69,7 @@ uninstall:
|
|||||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/node.py
|
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/node.py
|
||||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/ui.py
|
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/ui.py
|
||||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/user_language_model.py
|
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/user_language_model.py
|
||||||
|
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/system_language_model.py
|
||||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/system_dict.py
|
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/comb/system_dict.py
|
||||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/ibus.py
|
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/ibus.py
|
||||||
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/model/jawiki.1gram
|
rm -f $(DESTDIR)$(DATADIR)/ibus-comb/model/jawiki.1gram
|
||||||
|
@ -8,6 +8,7 @@ import jaconv
|
|||||||
from comb import combromkan
|
from comb import combromkan
|
||||||
|
|
||||||
from comb.system_dict import SystemDict
|
from comb.system_dict import SystemDict
|
||||||
|
from comb.system_language_model import SystemLanguageModel
|
||||||
from comb.user_language_model import UserLanguageModel
|
from comb.user_language_model import UserLanguageModel
|
||||||
from comb.graph import graph_construct, viterbi, lookup
|
from comb.graph import graph_construct, viterbi, lookup
|
||||||
from comb.node import Node
|
from comb.node import Node
|
||||||
@ -29,20 +30,16 @@ class Comb:
|
|||||||
logger: Logger
|
logger: Logger
|
||||||
dictionaries: List[Any]
|
dictionaries: List[Any]
|
||||||
|
|
||||||
def __init__(self, user_dict: UserLanguageModel, system_dict: SystemDict,
|
def __init__(self, user_language_model: UserLanguageModel, system_dict: SystemDict,
|
||||||
logger: Logger = logging.getLogger(__name__)):
|
logger: Logger = logging.getLogger(__name__)):
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.dictionaries = []
|
self.dictionaries = []
|
||||||
self.user_dict = user_dict
|
self.user_language_model = user_language_model
|
||||||
self.system_dict = system_dict
|
self.system_dict = system_dict
|
||||||
|
|
||||||
unigram_score = marisa_trie.RecordTrie('@f')
|
system_language_model = SystemLanguageModel.create()
|
||||||
unigram_score.load(f"{MODEL_DIR}/jawiki.1gram")
|
|
||||||
|
|
||||||
bigram_score = marisa_trie.RecordTrie('@f')
|
self.language_model = LanguageModel(system_language_model, user_language_model)
|
||||||
bigram_score.load(f"{MODEL_DIR}/jawiki.2gram")
|
|
||||||
|
|
||||||
self.language_model = LanguageModel(unigram_score, bigram_score, user_dict)
|
|
||||||
|
|
||||||
# 連文節変換するバージョン。
|
# 連文節変換するバージョン。
|
||||||
def convert2(self, src: str, force_selected_clause: List[slice] = None) -> List[List[Node]]:
|
def convert2(self, src: str, force_selected_clause: List[slice] = None) -> List[List[Node]]:
|
||||||
@ -99,9 +96,9 @@ class Comb:
|
|||||||
|
|
||||||
candidates = [[hiragana, hiragana]]
|
candidates = [[hiragana, hiragana]]
|
||||||
|
|
||||||
for e in self.user_dict.get_candidates(src, hiragana):
|
# for e in self.user_dict.get_candidates(src, hiragana):
|
||||||
if e not in candidates:
|
# if e not in candidates:
|
||||||
candidates.append(e)
|
# candidates.append(e)
|
||||||
|
|
||||||
if hiragana == 'きょう':
|
if hiragana == 'きょう':
|
||||||
# こういう類の特別なワードは、そのまま記憶してはいけない。。。
|
# こういう類の特別なワードは、そのまま記憶してはいけない。。。
|
||||||
|
@ -5,20 +5,18 @@ import math
|
|||||||
import marisa_trie
|
import marisa_trie
|
||||||
|
|
||||||
from comb.node import Node
|
from comb.node import Node
|
||||||
|
from comb.system_language_model import SystemLanguageModel
|
||||||
from comb.user_language_model import UserLanguageModel
|
from comb.user_language_model import UserLanguageModel
|
||||||
|
|
||||||
DEFAULT_SCORE = [(math.log10(0.00000000001),)]
|
|
||||||
|
|
||||||
|
|
||||||
class LanguageModel:
|
class LanguageModel:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
system_unigram_score: marisa_trie.RecordTrie,
|
system_language_model: SystemLanguageModel,
|
||||||
system_bigram_score: marisa_trie.RecordTrie,
|
|
||||||
user_language_model: UserLanguageModel,
|
user_language_model: UserLanguageModel,
|
||||||
logger: logging.Logger = logging.getLogger(__name__)):
|
logger: logging.Logger = logging.getLogger(__name__)):
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.system_bigram_score = system_bigram_score
|
self.system_language_model = system_language_model
|
||||||
self.system_unigram_score = system_unigram_score
|
|
||||||
self.user_language_model = user_language_model
|
self.user_language_model = user_language_model
|
||||||
|
|
||||||
def calc_node_cost(self, node: Node) -> float:
|
def calc_node_cost(self, node: Node) -> float:
|
||||||
@ -31,7 +29,7 @@ class LanguageModel:
|
|||||||
if u:
|
if u:
|
||||||
# self.logger.info(f"Use user score: {node.get_key()} -> {u}")
|
# self.logger.info(f"Use user score: {node.get_key()} -> {u}")
|
||||||
return u
|
return u
|
||||||
return self.system_unigram_score.get(node.get_key(), DEFAULT_SCORE)[0][0]
|
return self.system_language_model.get_unigram_cost(node.get_key())
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def calc_bigram_cost(self, prev_node, next_node) -> float:
|
def calc_bigram_cost(self, prev_node, next_node) -> float:
|
||||||
@ -40,6 +38,4 @@ class LanguageModel:
|
|||||||
if u:
|
if u:
|
||||||
self.logger.info(f"Use user's bigram score: {prev_node.get_key()},{next_node.get_key()} -> {u}")
|
self.logger.info(f"Use user's bigram score: {prev_node.get_key()},{next_node.get_key()} -> {u}")
|
||||||
return u
|
return u
|
||||||
return self.system_bigram_score.get(
|
return self.system_language_model.get_bigram_cost(prev_node, next_node)
|
||||||
f"{prev_node.get_key()}\t{next_node.get_key()}", DEFAULT_SCORE
|
|
||||||
)[0][0]
|
|
||||||
|
33
comb/system_language_model.py
Normal file
33
comb/system_language_model.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import marisa_trie
|
||||||
|
|
||||||
|
from comb.config import MODEL_DIR
|
||||||
|
from comb.node import Node
|
||||||
|
|
||||||
|
# Fallback score for n-grams that are missing from the tries: log10 of a very
# small probability.  It is wrapped in a one-element list holding a one-element
# tuple so that the ``[0][0]`` indexing used by the lookups below works the
# same way on the fallback as on a value returned by the trie's ``get``.
DEFAULT_SCORE = [(math.log10(0.00000000001),)]


class SystemLanguageModel:
    """System-wide n-gram language model backed by two score tries.

    Holds one trie of unigram scores and one trie of bigram scores and
    exposes lookups that fall back to ``DEFAULT_SCORE`` for unknown keys.
    """

    def __init__(self, unigram_score: marisa_trie.RecordTrie, bigram_score: marisa_trie.RecordTrie):
        """Wrap already-loaded score tries.

        :param unigram_score: trie queried by ``get_unigram_cost``.
        :param bigram_score: trie queried by ``get_bigram_cost``; its keys are
            two node keys joined by a tab character.
        """
        self.unigram_score = unigram_score
        self.bigram_score = bigram_score

    @classmethod
    def create(cls) -> "SystemLanguageModel":
        """Build a model from the ``jawiki`` trie files under ``MODEL_DIR``.

        Implemented as a classmethod so that calling it on a subclass yields
        an instance of that subclass; ``SystemLanguageModel.create()`` keeps
        working exactly as before.

        :return: a model wrapping the loaded unigram and bigram tries.
        """
        unigram_score = marisa_trie.RecordTrie('@f')
        unigram_score.load(f"{MODEL_DIR}/jawiki.1gram")

        bigram_score = marisa_trie.RecordTrie('@f')
        bigram_score.load(f"{MODEL_DIR}/jawiki.2gram")

        return cls(unigram_score, bigram_score)

    def get_unigram_cost(self, key: str) -> float:
        """Return the score stored for ``key``, or the default score if absent.

        :param key: unigram key to look up.
        :return: first field of the first record found for ``key``.
        """
        return self.unigram_score.get(key, DEFAULT_SCORE)[0][0]

    def get_bigram_cost(self, node1: Node, node2: Node) -> float:
        """Return the score for the ``node1`` -> ``node2`` transition.

        The lookup key is ``node1.get_key()`` and ``node2.get_key()`` joined
        by a single tab.

        :param node1: preceding node.
        :param node2: following node.
        :return: first field of the first record found for the pair, or the
            default score when the pair is not in the bigram trie.
        """
        key = f"{node1.get_key()}\t{node2.get_key()}"
        return self.bigram_score.get(key, DEFAULT_SCORE)[0][0]
|
@ -1,18 +1,16 @@
|
|||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
from comb.combromkan import to_hiragana
|
|
||||||
import pytest
|
import pytest
|
||||||
import marisa_trie
|
|
||||||
from comb.system_dict import SystemDict
|
|
||||||
from comb.graph import lookup, graph_construct, viterbi
|
|
||||||
from comb.engine import Comb
|
from comb.engine import Comb
|
||||||
|
from comb.system_dict import SystemDict
|
||||||
from comb.user_language_model import UserLanguageModel
|
from comb.user_language_model import UserLanguageModel
|
||||||
|
|
||||||
tmpfile = NamedTemporaryFile(delete=False)
|
tmpfile = NamedTemporaryFile(delete=False)
|
||||||
user_dict = UserLanguageModel(tmpfile.name)
|
user_language_model = UserLanguageModel(tmpfile.name)
|
||||||
system_dict = SystemDict()
|
system_dict = SystemDict()
|
||||||
|
|
||||||
comb = Comb(user_dict=user_dict, system_dict=system_dict)
|
comb = Comb(user_language_model=user_language_model, system_dict=system_dict)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('src, expected', [
|
@pytest.mark.parametrize('src, expected', [
|
||||||
|
@ -8,6 +8,7 @@ from comb.graph import lookup, graph_construct, viterbi
|
|||||||
from comb.language_model import LanguageModel
|
from comb.language_model import LanguageModel
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from comb.system_language_model import SystemLanguageModel
|
||||||
from comb.user_language_model import UserLanguageModel
|
from comb.user_language_model import UserLanguageModel
|
||||||
|
|
||||||
unigram_score = marisa_trie.RecordTrie('@f')
|
unigram_score = marisa_trie.RecordTrie('@f')
|
||||||
@ -16,10 +17,12 @@ unigram_score.load('model/jawiki.1gram')
|
|||||||
bigram_score = marisa_trie.RecordTrie('@f')
|
bigram_score = marisa_trie.RecordTrie('@f')
|
||||||
bigram_score.load('model/jawiki.2gram')
|
bigram_score.load('model/jawiki.2gram')
|
||||||
|
|
||||||
tmpdir = TemporaryDirectory()
|
system_language_model = SystemLanguageModel(unigram_score, bigram_score)
|
||||||
user_dict = UserLanguageModel(tmpdir.name)
|
|
||||||
|
|
||||||
language_model = LanguageModel(unigram_score, bigram_score, user_dict)
|
tmpdir = TemporaryDirectory()
|
||||||
|
user_language_model = UserLanguageModel(tmpdir.name)
|
||||||
|
|
||||||
|
language_model = LanguageModel(system_language_model, user_language_model=user_language_model)
|
||||||
|
|
||||||
system_dict = SystemDict()
|
system_dict = SystemDict()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user