implemented

Tokuhiro Matsuno
2020-10-15 00:42:46 +09:00
parent a00909bcfc
commit 10cf283562
31 changed files with 253 additions and 140 deletions

.gitignore (vendored, 2 lines changed)
View File

@@ -8,3 +8,5 @@ __pycache__
dist/
*.lprof
*.so
perf.data
callgrind.*

View File

@@ -24,12 +24,12 @@ install: ibus_akaza/config.py akaza.xml po/ja.mo
install -m 0644 ibus_akaza/config.py $(DESTDIR)$(DATADIR)/ibus-akaza/ibus_akaza/
install -m 0644 ibus_akaza/keymap.py $(DESTDIR)$(DATADIR)/ibus-akaza/ibus_akaza/
install -m 0644 ibus_akaza/config_loader.py $(DESTDIR)$(DATADIR)/ibus-akaza/ibus_akaza/
install -m 0644 ibus_akaza/skk_file_dict.py $(DESTDIR)$(DATADIR)/ibus-akaza/ibus_akaza/
install -m 0644 ibus_akaza/__init__.py $(DESTDIR)$(DATADIR)/ibus-akaza/ibus_akaza/
install -m 0644 akaza.xml $(DESTDIR)$(DATADIR)/ibus/component
install -m 0644 po/ja.mo $(DESTDIR)$(DATADIR)/locale/ja/LC_MESSAGES/ibus-akaza.mo
ibus_akaza/config.py: ibus_akaza/config.py.in
sed -e "s:@SYSCONFDIR@:$(SYSCONFDIR):g" \
-e "s:@MODELDIR@:$(DESTDIR)/$(DATADIR)/akaza-data/:g" \
-e "s:@DICTIONARYDIR@:$(DESTDIR)/$(DATADIR)/ibus-akaza/dictionary:g" \

View File

@@ -2,4 +2,3 @@ import os
SYS_CONF_DIR = os.environ.get('AKAZA_SYSCONF_DIR', '@SYSCONFDIR@')
MODEL_DIR = os.environ.get('AKAZA_MODEL_DIR', '@MODELDIR@')
DICTIONARY_DIR = os.environ.get('AKAZA_DICTIONARY_DIR', '@DICTIONARYDIR@')

View File

@@ -1,4 +1,4 @@
from akaza_data.systemlm_loader import BinaryDict
from pyakaza.bind import BinaryDict
from skkdictutils import parse_skkdict, merge_skkdict, ari2nasi
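The loader keeps the same call surface after the import swap to pyakaza.bind; a minimal usage sketch, mirroring the assertions in the SKK dictionary test later in this commit (the path is illustrative):

from ibus_akaza.skk_file_dict import load_skk_file_dict

d = load_skk_file_dict('tests/data/SKK-JISYO.test')  # illustrative path
assert d.find_kanjis('たばた') == ['田端']
assert d.prefixes('たばた') == ['', 'たば', 'たばた']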

View File

@@ -1,5 +1,5 @@
import time
from typing import List, Dict
from typing import List, Dict, Optional
import gi
@@ -17,14 +17,11 @@ import gettext
from jaconv import jaconv
from akaza import Akaza
from akaza.romkan import RomkanConverter
from akaza.node import Node
from akaza.user_language_model import UserLanguageModel
from akaza.graph_resolver import GraphResolver
from pyakaza.bind import Akaza, GraphResolver, BinaryDict, SystemUnigramLM, SystemBigramLM, Node, UserLanguageModel, \
Slice, RomkanConverter, TinyLisp
from ibus_akaza import config_loader
from ibus_akaza.config import MODEL_DIR
from akaza_data.systemlm_loader import BinaryDict, SystemUnigramLM, SystemBigramLM, TinyLisp
from .keymap import build_default_keymap, KEY_STATE_PRECOMPOSITION, KEY_STATE_COMPOSITION, KEY_STATE_CONVERSION
from .input_mode import get_input_mode_from_prop_name, InputMode, INPUT_MODE_ALNUM, INPUT_MODE_HIRAGANA, \
@@ -41,7 +38,10 @@ def build_akaza():
user_language_model_path = configdir.joinpath('user_language_model')
user_language_model_path.mkdir(parents=True, exist_ok=True, mode=0o700)
user_language_model = UserLanguageModel(str(user_language_model_path))
user_language_model = UserLanguageModel(
str(user_language_model_path.joinpath('unigram.txt')),
str(user_language_model_path.joinpath('bigram.txt'))
)
system_dict = BinaryDict()
print(f"{MODEL_DIR + '/system_dict.trie'}")
@@ -57,18 +57,18 @@ def build_akaza():
emoji_dict.load(MODEL_DIR + "/single_term.trie")
resolver = GraphResolver(
normal_dicts=[system_dict] + user_dicts,
system_unigram_lm=system_unigram_lm,
system_bigram_lm=system_bigram_lm,
user_language_model=user_language_model,
single_term_dicts=[emoji_dict],
user_language_model,
system_unigram_lm,
system_bigram_lm,
[system_dict] + user_dicts,
[emoji_dict],
)
romkan = RomkanConverter(additional=user_settings.get('romaji'))
romkan = RomkanConverter(user_settings.get('romaji'))
lisp_evaluator = TinyLisp()
return user_language_model, Akaza(resolver=resolver, romkan=romkan), romkan, lisp_evaluator, user_settings
return user_language_model, Akaza(resolver, romkan), romkan, lisp_evaluator, user_settings
try:
@@ -76,9 +76,14 @@ try:
user_language_model, akaza, romkan, lisp_evaluator, user_settings = build_akaza()
def save_periodically():
while True:
user_language_model.save()
time.sleep(10)
user_language_model_save_thread = threading.Thread(
name='user_language_model_save_thread',
target=lambda: user_language_model.save_periodically(),
target=lambda: save_periodically(),
daemon=True,
)
user_language_model_save_thread.start()
@@ -104,7 +109,7 @@ class AkazaIBusEngine(IBus.Engine):
prop_list: IBus.PropList
akaza: Akaza
input_mode: InputMode
force_selected_clause: List[slice]
force_selected_clause: Optional[List[slice]]
__gtype_name__ = 'AkazaIBusEngine'
@@ -128,7 +133,7 @@ class AkazaIBusEngine(IBus.Engine):
self.node_selected = {}
# The clause boundaries the user re-selected.
self.force_selected_clause = []
self.force_selected_clause = None
# Flag that tells whether the cursor has just been moved.
self.cursor_moved = False
@@ -401,10 +406,10 @@ g
For use when F6 or the like is pressed.
"""
# Set the candidates
self.clauses = [[Node(start_pos=0, word=word, yomi=yomi)]]
self.clauses = [[Node(0, yomi, word)]]
self.current_clause = 0
self.node_selected = {}
self.force_selected_clause = []
self.force_selected_clause = None
# Put the candidates into the lookup table
self.lookup_table.clear()
@@ -504,7 +509,7 @@ g
if len(self.clauses) == 0:
return False
max_len = max([clause[0].start_pos + len(clause[0].yomi) for clause in self.clauses])
max_len = max([clause[0].get_start_pos() + len(clause[0].get_yomi()) for clause in self.clauses])
self.force_selected_clause = []
for i, clause in enumerate(self.clauses):
@@ -512,15 +517,15 @@ g
if self.current_clause == i:
# If this is the currently selected clause, extend it.
self.force_selected_clause.append(
slice(node.start_pos, min(node.start_pos + len(node.yomi) + 1, max_len)))
slice(node.get_start_pos(), min(node.get_start_pos() + len(node.get_yomi()) + 1, max_len)))
elif self.current_clause + 1 == i:
# Shrink the next clause by one character
self.force_selected_clause.append(
slice(node.start_pos + 1, node.start_pos + len(node.yomi)))
slice(node.get_start_pos() + 1, node.get_start_pos() + len(node.get_yomi())))
else:
# Otherwise leave the clause as currently specified
self.force_selected_clause.append(
slice(node.start_pos, node.start_pos + len(node.yomi)))
slice(node.get_start_pos(), node.get_start_pos() + len(node.get_yomi())))
self.force_selected_clause = [x for x in self.force_selected_clause if x.start != x.stop]
self._update_candidates()
@@ -542,15 +547,15 @@ g
if target_clause == i:
# If this is the currently selected clause, extend it.
self.force_selected_clause.append(
slice(node.start_pos - 1, node.start_pos + len(node.yomi)))
slice(node.get_start_pos() - 1, node.get_start_pos() + len(node.get_yomi())))
elif target_clause - 1 == i:
# Shrink the previous clause by one character
self.force_selected_clause.append(
slice(node.start_pos, node.start_pos + len(node.yomi) - 1))
slice(node.get_start_pos(), node.get_start_pos() + len(node.get_yomi()) - 1))
else:
# Otherwise leave the clause as currently specified
self.force_selected_clause.append(
slice(node.start_pos, node.start_pos + len(node.yomi)))
slice(node.get_start_pos(), node.get_start_pos() + len(node.get_yomi())))
self.force_selected_clause = [x for x in self.force_selected_clause if x.start != x.stop]
self._update_candidates()
@@ -578,7 +583,7 @@ g
self.clauses = []
self.current_clause = 0
self.node_selected = {}
self.force_selected_clause = []
self.force_selected_clause = None
self.lookup_table.clear()
self.update_lookup_table(self.lookup_table, False)
@@ -610,7 +615,12 @@ g
def _update_candidates(self):
if len(self.preedit_string) > 0:
# Run the conversion
self.clauses = self.akaza.convert(self.preedit_string, self.force_selected_clause)
print(f"-------{self.preedit_string}-----{self.force_selected_clause}----PPP")
slices = None
if self.force_selected_clause:
slices = [Slice(s.start, s.stop-s.start) for s in self.force_selected_clause]
print(f"-------{self.preedit_string}-----{self.force_selected_clause}---{slices}----PPP")
self.clauses = self.akaza.convert(self.preedit_string, slices)
else:
self.clauses = []
self.create_lookup_table()
@@ -630,7 +640,7 @@ g
current_node = current_clause[0]
# -- auxiliary text (the one shown in the popup)
first_candidate = current_node.yomi
first_candidate = current_node.get_yomi()
auxiliary_text = IBus.Text.new_from_string(first_candidate)
auxiliary_text.set_attributes(IBus.AttrList())
self.update_auxiliary_text(auxiliary_text, preedit_len > 0)
@@ -663,7 +673,7 @@ g
# If it starts with an uppercase letter,
if len(self.preedit_string) > 0 and self.preedit_string[0].isupper() \
and len(self.force_selected_clause) == 0:
and self.force_selected_clause is None:
return self.preedit_string, self.preedit_string
yomi = self.romkan.to_hiragana(self.preedit_string)
@@ -687,7 +697,7 @@ g
# Convert it to hiragana.
yomi, word = self._make_preedit_word()
self.clauses = [
[Node(word=word, yomi=yomi, start_pos=0)]
[Node(0, yomi, word)]
]
self.current_clause = 0
@@ -732,7 +742,7 @@ g
def do_reset(self):
self.logger.debug("do_reset")
self.preedit_string = ''
self.force_selected_clause = []
self.force_selected_clause = None
self.clauses = []
self.current_clause = 0
self.node_selected = {}
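Taken together, these hunks replace the pure-Python akaza pipeline with the pybind11-backed pyakaza.bind API: UserLanguageModel now takes explicit unigram/bigram file paths, GraphResolver and Akaza are built with positional arguments, RomkanConverter takes a dict, and forced clause boundaries travel as akaza::Slice(start, length) rather than Python slice(start, stop). A condensed sketch of the new wiring, assuming the trie paths used elsewhere in this commit and an illustrative user-LM directory:

from ibus_akaza.config import MODEL_DIR
from pyakaza.bind import (Akaza, BinaryDict, GraphResolver, RomkanConverter,
                          Slice, SystemBigramLM, SystemUnigramLM,
                          UserLanguageModel)

# The user language model now wants explicit unigram/bigram paths.
user_lm = UserLanguageModel('/tmp/ulm/unigram.txt', '/tmp/ulm/bigram.txt')

system_unigram_lm = SystemUnigramLM()
system_unigram_lm.load(MODEL_DIR + '/lm_v2_1gram.trie')
system_bigram_lm = SystemBigramLM()
system_bigram_lm.load(MODEL_DIR + '/lm_v2_2gram.trie')

system_dict = BinaryDict()
system_dict.load(MODEL_DIR + '/system_dict.trie')
emoji_dict = BinaryDict()
emoji_dict.load(MODEL_DIR + '/single_term.trie')

# Positional construction, matching the C++ signature:
# (user_lm, unigram LM, bigram LM, normal dicts, single-term dicts)
resolver = GraphResolver(user_lm, system_unigram_lm, system_bigram_lm,
                         [system_dict], [emoji_dict])
akaza = Akaza(resolver, RomkanConverter({}))

# Python slices carry (start, stop); akaza::Slice wants (start, length).
forced = [slice(0, 3), slice(3, 7)]
slices = [Slice(s.start, s.stop - s.start) for s in forced]
clauses = akaza.convert('たのしいじかん', slices)  # pass None when nothing is forced
print('/'.join(clause[0].get_word() for clause in clauses))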

View File

@@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2020-10-04 23:30+0900\n"
"POT-Creation-Date: 2020-10-14 22:44+0900\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
@@ -37,11 +37,11 @@ msgstr ""
msgid "Full-width Alphanumeric (C-S-l)"
msgstr ""
#: ibus_akaza/ui.py:156 ibus_akaza/ui.py:293
#: ibus_akaza/ui.py:153 ibus_akaza/ui.py:290
#, python-format
msgid "Input mode (%s)"
msgstr ""
#: ibus_akaza/ui.py:158
#: ibus_akaza/ui.py:155
msgid "Switch input mode"
msgstr ""

View File

@@ -7,5 +7,6 @@ setup(
extras_require={
},
entry_points={
}
},
zip_safe=False
)

View File

@@ -1,8 +1,7 @@
import pathlib
import sys
sys.path.insert(0, str(pathlib.Path(__file__).parent.joinpath('../../akaza-data/').absolute().resolve()))
sys.path.insert(0, str(pathlib.Path(__file__).parent.joinpath('../../akaza-core/').absolute().resolve()))
sys.path.insert(0, str(pathlib.Path(__file__).parent.joinpath('../../pyakaza/').absolute().resolve()))
from ibus_akaza.skk_file_dict import load_skk_file_dict
@@ -11,4 +10,3 @@ def test_read2():
path = str(pathlib.Path(__file__).parent.joinpath('data', 'SKK-JISYO.test'))
d = load_skk_file_dict(path)
assert d.find_kanjis('たばた') == ['田端']
assert d.prefixes('たばた') == ['', 'たば', 'たばた']

View File

@@ -3,13 +3,12 @@ import sys
import pathlib
import pytest
sys.path.insert(0, str(pathlib.Path(__file__).parent.joinpath('../../akaza-data/').absolute().resolve()))
sys.path.insert(0, str(pathlib.Path(__file__).parent.joinpath('../../akaza-core/').absolute().resolve()))
sys.path.insert(0, str(pathlib.Path(__file__).parent.joinpath('../../pyakaza/').absolute().resolve()))
sys.path.insert(0, str(pathlib.Path(__file__).parent.joinpath('../').absolute().resolve()))
os.environ['AKAZA_DICTIONARY_DIR'] = 'model/'
os.environ['AKAZA_MODEL_DIR'] = '../akaza-data/akaza_data/data/'
from akaza.node import Node
from pyakaza.bind import Node
from ibus_akaza.ui import AkazaIBusEngine
from ibus_akaza.input_mode import INPUT_MODE_KATAKANA, INPUT_MODE_HIRAGANA
@@ -41,7 +40,9 @@ def test_extend_clause_right():
assert got == 'タノシ/イジ/カン'
# The second clause is now イジ
assert '維持' in [node.word for node in ui.clauses[1]]
second_words = [node.get_word() for node in ui.clauses[1]]
print(second_words)
assert '維持' in second_words
def test_extend_clause_right_most_right():
@@ -81,7 +82,7 @@ def test_extend_clause_left():
# ↑ focus
ui.cursor_right()
print('/'.join([clause[0].yomi for clause in ui.clauses]))
print('/'.join([clause[0].get_yomi() for clause in ui.clauses]))
# タノ/シイ/ジカン
# 0 1 2 3 4 5 6
@@ -95,10 +96,10 @@ def test_extend_clause_left():
got = '/'.join(["タノシイジカン"[s] for s in ui.force_selected_clause])
assert got == 'タノ/シイ/ジカン'
print('/'.join([clause[0].yomi for clause in ui.clauses]))
print('/'.join([clause[0].get_yomi() for clause in ui.clauses]))
# The second clause is now しい
assert 'たの/しい/じかん' == '/'.join([clause[0].yomi for clause in ui.clauses])
assert 'たの/しい/じかん' == '/'.join([clause[0].get_yomi() for clause in ui.clauses])
def test_extend_clause_left_most_left():
@@ -107,7 +108,7 @@ def test_extend_clause_left_most_left():
ui.update_candidates()
# タノシ/イ/ジカン
print('/'.join([clause[0].yomi for clause in ui.clauses]))
print('/'.join([clause[0].get_yomi() for clause in ui.clauses]))
# タノ/シイ/ジカン
# 0 1 2 3 4 5 6
@@ -121,10 +122,10 @@ def test_extend_clause_left_most_left():
got = '/'.join(["タノシイジカン"[s] for s in ui.force_selected_clause])
assert got == 'タノ/シイ/ジカン'
print('/'.join([clause[0].yomi for clause in ui.clauses]))
print('/'.join([clause[0].get_yomi() for clause in ui.clauses]))
# The second clause is now しい
assert 'たの/しい/じかん' == '/'.join([clause[0].yomi for clause in ui.clauses])
assert 'たの/しい/じかん' == '/'.join([clause[0].get_yomi() for clause in ui.clauses])
@pytest.mark.skip(reason='not working at the moment')
@@ -134,7 +135,7 @@ def test_extend_clause_left_most_left_and_more():
ui.update_candidates()
# どん/だけ/とち/かん
assert '/'.join([clause[0].yomi for clause in ui.clauses]) == 'どん/だけ/と/ちかん'
assert '/'.join([clause[0].get_yomi() for clause in ui.clauses]) == 'どん/だけ/と/ちかん'
# どん/だけ/とち/かん
# 0 1 2 3 4 5 6
@@ -147,11 +148,11 @@ def test_extend_clause_left_most_left_and_more():
ui.cursor_right() # focus to とち
assert ui.current_clause == 2
ui.extend_clause_right() # とち→とちか
assert '/'.join([clause[0].yomi for clause in ui.clauses]) == 'どん/だけ/とちか/ん'
assert '/'.join([clause[0].get_yomi() for clause in ui.clauses]) == 'どん/だけ/とちか/ん'
assert '/'.join(['どんだけとちかん'[s] for s in ui.force_selected_clause]) == 'どん/だけ/とちか/ん'
assert ui.current_clause == 2
ui.extend_clause_right() # とちか→とちかん
assert '/'.join([clause[0].yomi for clause in ui.clauses]) == 'どん/だけ/とちかん'
assert '/'.join([clause[0].get_yomi() for clause in ui.clauses]) == 'どん/だけ/とちかん'
def test_update_preedit_text_before_henkan1():
@@ -161,12 +162,12 @@ def test_update_preedit_text_before_henkan1():
ui.update_preedit_text_before_henkan()
print(ui.clauses)
assert [
[Node(word='ひょい', yomi='ひょい', start_pos=0)]
[Node(0, 'ひょい', 'ひょい')]
] == [
[Node(word='ひょい', yomi='ひょい', start_pos=0)]
[Node(0, 'ひょい', 'ひょい')]
]
assert ui.clauses == [
[Node(word='ひょい', yomi='ひょい', start_pos=0)]
[Node(0, 'ひょい', 'ひょい')]
]
@@ -176,5 +177,9 @@ def test_update_preedit_text_before_henkan2():
ui.preedit_string = "hyoi-"
ui.update_preedit_text_before_henkan()
assert ui.clauses == [
[Node(word='ヒョイ', yomi='ひょい', start_pos=0)]
[Node(0, 'ひょい', 'ヒョイ')]
]
if __name__ == '__main__':
# test_extend_clause_right()
test_extend_clause_left()
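One change that runs through all of these tests: the keyword constructor Node(word=..., yomi=..., start_pos=...) is gone; the pybind signature is positional, with yomi before word. A minimal sketch of the new ordering and the accessors that replace attribute access:

from pyakaza.bind import Node

node = Node(0, 'ひょい', 'ヒョイ')  # (start_pos, yomi, word)
assert node.get_yomi() == 'ひょい'
assert node.get_word() == 'ヒョイ'
# __eq__ compares word, yomi and start_pos (see libakaza/src/node.cc below).
assert node == Node(0, 'ひょい', 'ヒョイ')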

View File

@@ -4,7 +4,7 @@ project(libakaza)
include(GNUInstallDirs)
set(CMAKE_CXX_STANDARD 17)
SET(CMAKE_C_FLAGS "-Wall -O2 -g ${CC_WARNING_FLAGS} ${CMAKE_C_FLAGS}")
SET(CMAKE_C_FLAGS "-Wall -O2 -pg -g ${CC_WARNING_FLAGS} ${CMAKE_C_FLAGS}")
# =============================================================================================
#
@@ -107,6 +107,11 @@ set_target_properties(10_integration.t PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CM
target_link_libraries(10_integration.t marisa akaza)
SET(TEST_EXES ${TEST_EXES} 10_integration.t)
add_executable(11_wnn.t t/11_wnn.cc)
set_target_properties(11_wnn.t PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/t/")
target_link_libraries(11_wnn.t marisa akaza)
SET(TEST_EXES ${TEST_EXES} 11_wnn.t)
ADD_CUSTOM_TARGET(test env BINARY_DIR=${CMAKE_CURRENT_BINARY_DIR} prove --exec '' -v t/*.t
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TEST_EXES})

View File

@@ -44,7 +44,7 @@ int main(int argc, char **argv) {
single_term_dicts
);
std::vector<std::tuple<std::string, std::string>> additional;
std::map<std::string, std::string> additional;
auto romkanConverter = std::make_shared<akaza::RomkanConverter>(additional);
akaza::Akaza akaza(graphResolver, romkanConverter);

View File

@@ -30,6 +30,10 @@ namespace akaza {
public:
BinaryDict() {}
size_t size() {
return dict_trie.size();
}
void load(const std::string& dict_path);
void save(std::string dict_path) {

View File

@@ -23,14 +23,16 @@ namespace akaza {
_len = len;
}
size_t start() const {
[[nodiscard]] size_t start() const {
return _start;
}
size_t len() const {
[[nodiscard]] size_t len() const {
return _len;
}
std::string repr();
};
/*
@@ -51,7 +53,7 @@ namespace akaza {
self.single_term_dicts = single_term_dicts
*/
private:
std::shared_ptr<UserLanguageModel> _user_language_model;
std::shared_ptr<UserLanguageModel> user_language_model_;
std::shared_ptr<SystemUnigramLM> _system_unigram_lm;
std::shared_ptr<SystemBigramLM> _system_bigram_lm;
std::vector<std::shared_ptr<BinaryDict>> _normal_dicts;
@@ -69,13 +71,7 @@ namespace akaza {
const std::shared_ptr<SystemBigramLM> &system_bigram_lm,
const std::vector<std::shared_ptr<BinaryDict>> &normal_dicts,
const std::vector<std::shared_ptr<BinaryDict>> &single_term_dicts
) {
_user_language_model = user_language_model;
_system_unigram_lm = system_unigram_lm;
_system_bigram_lm = system_bigram_lm;
_normal_dicts = normal_dicts;
_single_term_dicts = single_term_dicts;
}
);
/*
def lookup(self, s: str):

View File

@@ -90,6 +90,9 @@ namespace akaza {
}
void set_prev(std::shared_ptr<Node> &prev);
bool operator==(const Node &node);
bool operator!=(const Node &node);
};
std::shared_ptr<Node> create_bos_node();

View File

@@ -10,7 +10,7 @@ namespace akaza {
std::regex _last_char_pattern;
std::map<std::string, std::string> _map;
public:
RomkanConverter(const std::vector<std::tuple<std::string, std::string>> &additional);
RomkanConverter(const std::map<std::string, std::string> &additional);
std::string remove_last_char(const std::string & s);
std::string to_hiragana(const std::string & s);
};
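Since the constructor now takes a std::map, the Python side passes a plain dict of additional romaji mappings; a small sketch against the new binding (the extra entry is illustrative):

from pyakaza.bind import RomkanConverter

romkan = RomkanConverter({'wi': 'ゐ'})  # additional romaji -> hiragana entries
assert romkan.to_hiragana('hyoi') == 'ひょい'
print(romkan.remove_last_char('hyoi'))  # trims the trailing input, e.g. for backspace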

View File

@@ -48,6 +48,14 @@ namespace akaza {
this->bigram_path = bigram_path;
}
size_t size_unigram() {
return unigram.size();
}
size_t size_bigram() {
return bigram.size();
}
void load_unigram() {
read(unigram_path, true, unigram_C, unigram_V, unigram);
}

View File

@@ -9,6 +9,13 @@ std::string akaza::Akaza::get_version() {
std::vector<std::vector<std::shared_ptr<akaza::Node>>> akaza::Akaza::convert(
const std::string &src,
const std::optional<std::vector<akaza::Slice>> &forceSelectedClauses) {
D(std::cout << "Akaza::convert '"
<< src << "' (HASH="
<< std::hash<std::string>{}(src)
<< ")"
<< " " << __FILE__ << ":" << __LINE__ << std::endl);
assert(!forceSelectedClauses.has_value() || !forceSelectedClauses.value().empty());
if (!src.empty() && isupper(src[0]) && !forceSelectedClauses.has_value()) {
return {{std::make_shared<akaza::Node>(0, src, src)}};
}

View File

@@ -5,6 +5,10 @@
#ifndef LIBAKAZA_DEBUG_LOG_H
#define LIBAKAZA_DEBUG_LOG_H
#if 1
#define D(x) do { } while (0)
#else
#define D(x) do { (x); } while (0)
#endif
#endif //LIBAKAZA_DEBUG_LOG_H

View File

@@ -1,21 +1,35 @@
#include <memory>
#include <codecvt>
#include <locale>
#include <sstream>
#include "../include/akaza.h"
#include "debug_log.h"
#include "kana.h"
static std::string tojkata(const std::string &src) {
std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> cnv;
std::wstring wstr = cnv.from_bytes(src);
std::transform(wstr.begin(), wstr.end(), wstr.begin(), [](wchar_t c) {
return std::towctrans(c, std::wctrans("tojkata"));
});
std::string sstr = cnv.to_bytes(wstr);
D(std::cout << "TOJKATA: " << src << " -> " << sstr
<< " " << __FILE__ << ":" << __LINE__ << std::endl);
return sstr;
akaza::GraphResolver::GraphResolver(const std::shared_ptr<UserLanguageModel> &user_language_model,
const std::shared_ptr<SystemUnigramLM> &system_unigram_lm,
const std::shared_ptr<SystemBigramLM> &system_bigram_lm,
const std::vector<std::shared_ptr<BinaryDict>> &normal_dicts,
const std::vector<std::shared_ptr<BinaryDict>> &single_term_dicts) {
user_language_model_ = user_language_model;
_system_unigram_lm = system_unigram_lm;
_system_bigram_lm = system_bigram_lm;
_normal_dicts = normal_dicts;
_single_term_dicts = single_term_dicts;
D(std::cout << "GraphResolver: "
<< " ULM.uni=" << user_language_model_->size_unigram()
<< " ULM.bi=" << user_language_model->size_bigram()
<< " SystemUnigramLM.size=" << system_unigram_lm->size()
<< " SystemBigramLM.size=" << system_bigram_lm->size());
for (const auto &d: normal_dicts) {
D(std::cout << " ND=" << d->size());
}
for (const auto &d: single_term_dicts) {
D(std::cout << " STD=" << d->size());
}
D(std::cout << std::endl);
}
static inline void insert_basic_candidates(std::set<std::tuple<std::string, std::string>> &kanjiset,
@@ -71,7 +85,7 @@ akaza::GraphResolver::construct_normal_graph(const std::string &s) {
}
}
if (exist_kanjis || _user_language_model->has_unigram_cost_by_yomi(yomi)) {
if (exist_kanjis || user_language_model_->has_unigram_cost_by_yomi(yomi)) {
insert_basic_candidates(kanjiset, yomi);
}
@@ -203,7 +217,7 @@ void akaza::GraphResolver::fill_cost(akaza::Graph &graph) {
continue;
}
D(std::cout << "fill_cost: " << node->get_key() << std::endl);
auto node_cost = node->calc_node_cost(*_user_language_model, *_system_unigram_lm);
auto node_cost = node->calc_node_cost(*user_language_model_, *_system_unigram_lm);
auto cost = INT32_MIN;
auto prev_nodes = graph.get_prev_items(node);
@@ -214,7 +228,7 @@ void akaza::GraphResolver::fill_cost(akaza::Graph &graph) {
// << " " << __FILE__ << ":" << __LINE__ << std::endl);
auto bigram_cost = prev_node->get_bigram_cost(
*node,
*_user_language_model,
*user_language_model_,
*_system_bigram_lm);
auto tmp_cost = prev_node->get_cost() + bigram_cost + node_cost;
if (cost < tmp_cost) { // pick the path with the maximum cost
@@ -222,6 +236,7 @@ void akaza::GraphResolver::fill_cost(akaza::Graph &graph) {
shortest_prev = prev_node;
}
}
assert(shortest_prev);
D(std::cout << "[fill_cost] set prev: " << node->get_key() << " " << shortest_prev->get_key()
<< " " << __FILE__ << ":" << __LINE__ << std::endl);
node->set_prev(shortest_prev);
@@ -270,7 +285,7 @@ std::vector<std::vector<std::shared_ptr<akaza::Node>>> akaza::GraphResolver::fin
}
std::vector<std::shared_ptr<akaza::Node>> nodes = graph.get_items_by_start_and_length(node);
auto userLanguageModel = this->_user_language_model;
auto userLanguageModel = this->user_language_model_;
auto systemBigramLm = this->_system_bigram_lm;
std::sort(nodes.begin(), nodes.end(), [last_node, userLanguageModel, systemBigramLm](auto &a, auto &b) {
return a->get_cost() + a->get_bigram_cost(*last_node, *userLanguageModel,
@@ -300,3 +315,9 @@ akaza::GraphResolver::graph_construct(const std::string &s, std::optional<std::v
graph.build(utf32conv.from_bytes(s).size(), nodemap);
return graph;
}
std::string akaza::Slice::repr() {
std::stringstream ss;
ss << "<akaza::Slice start=" << _start << " len=" << _len << ">";
return ss.str();
}

View File

@@ -123,3 +123,11 @@ void akaza::Node::set_prev(std::shared_ptr<Node> &prev) {
assert(this->start_pos != prev->start_pos);
this->_prev = prev;
}
bool akaza::Node::operator==(akaza::Node const &node) {
return this->word == node.word && this->yomi == node.yomi && this->start_pos == node.start_pos;
}
bool akaza::Node::operator!=(akaza::Node const &node) {
return this->word != node.word || this->yomi != node.yomi || this->start_pos != node.start_pos;
}

View File

@@ -13,7 +13,7 @@ static std::string quotemeta(const std::string &input) {
return std::regex_replace(input, specialChars, R"(\$&)");
}
akaza::RomkanConverter::RomkanConverter(const std::vector<std::tuple<std::string, std::string>> &additional) {
akaza::RomkanConverter::RomkanConverter(const std::map<std::string, std::string> &additional) {
// romaji -> hiragana
for (const auto &[rom, hira]: DEFAULT_ROMKAN_H) {
_map[rom] = hira;

View File

@@ -11,15 +11,25 @@ def test_surface():
*/
void test_surface() {
static void test_surface() {
auto lisp = akaza::tinylisp::TinyLisp();
auto node = akaza::Node(0, "たしざんてすと", R"((. "a" "b"))");
ok(node.surface(lisp) == "ab");
}
static void test_eq() {
auto a = akaza::Node(0, "", "");
auto b = akaza::Node(0, "", "");
ok(a == a);
ok(b == b);
ok(b != a);
}
int main() {
test_surface();
test_eq();
done_testing();
}

View File

@@ -17,7 +17,7 @@ def test_remove_last_char(src, expected):
assert romkan.remove_last_char(src) == expected
*/
static void test_remove_last_char() {
std::vector<std::tuple<std::string, std::string>> additional = {
std::map<std::string, std::string> additional = {
};
auto romkan = akaza::RomkanConverter(additional);
@@ -45,7 +45,7 @@ def test_bar(src, expected):
assert romkan.to_hiragana(src) == expected
*/
static void test_to_hiragana() {
std::vector<std::tuple<std::string, std::string>> additional = {
std::map<std::string, std::string> additional = {
};
auto romkan = akaza::RomkanConverter(additional);
@@ -93,7 +93,7 @@ int main() {
test_remove_last_char();
test_to_hiragana();
std::vector<std::tuple<std::string, std::string>> additional = {
std::map<std::string, std::string> additional = {
};
auto romkan = akaza::RomkanConverter(additional);

View File

@@ -53,7 +53,7 @@ int main() {
single_term_dicts
);
std::vector<std::tuple<std::string, std::string>> additional = {};
std::map<std::string, std::string> additional = {};
auto romkan = std::make_shared<akaza::RomkanConverter>(additional);
akaza::Akaza akaza = akaza::Akaza(graphResolver, romkan);

View File

@@ -21,6 +21,8 @@ std::string convert_test(const std::string &src, const std::string &expected) {
}
int main() {
convert_test("tanosiijikan", "楽しい時間");
convert_test("たのしいじかん", "楽しい時間");
convert_test("zh", "");
convert_test("それなwww", "それなwww");
convert_test("watasinonamaehanakanodesu.", "私の名前は中野です。");

libakaza/t/11_wnn.cc (new file, 27 lines)
View File

@@ -0,0 +1,27 @@
#include "../include/akaza.h"
#include "../picotest/picotest.h"
#include "../picotest/picotest.c"
#include "test_akaza.h"
#include <filesystem>
std::string convert_test(const std::string &src, const std::string &expected) {
auto akaza = build_akaza();
std::vector<std::vector<std::shared_ptr<akaza::Node>>> result = akaza->convert(
src,
std::nullopt);
std::string retval;
for (const auto &nodes: result) {
retval += nodes[0]->get_word();
}
note("RESULT: src=%s got=%s expected=%s", src.c_str(), retval.c_str(), expected.c_str());
ok(expected == retval);
assert(expected == retval);
return retval;
}
int main() {
convert_test("わたしのなまえはなかのです。", "私の名前は中野です。");
convert_test("わたしのなまえはなかのです", "私の名前は中野です");
done_testing();
}

View File

@@ -41,7 +41,7 @@ static std::shared_ptr<akaza::GraphResolver> build_graph_resolver() {
static std::unique_ptr<akaza::Akaza> build_akaza() {
auto graph_resolver = build_graph_resolver();
std::vector<std::tuple<std::string, std::string>> additional;
std::map<std::string, std::string> additional;
auto romkanConverter = std::make_shared<akaza::RomkanConverter>(additional);
return std::make_unique<akaza::Akaza>(graph_resolver, romkanConverter);

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
import pathlib
from tempfile import TemporaryDirectory
import sys
@@ -9,43 +10,37 @@ print(path)
sys.path.insert(0, path)
sys.path.insert(0, pathlib.Path(__file__).parent.parent.parent.joinpath('akaza-data').absolute())
import akaza
from akaza.dictionary import Dictionary
from akaza.graph_resolver import GraphResolver
from akaza.language_model import LanguageModel
from akaza.romkan import RomkanConverter
from akaza.user_language_model import UserLanguageModel
from akaza_data import SystemLanguageModel, SystemDict
from akaza_data.emoji import EmojiDict
from pyakaza.bind import Akaza, GraphResolver, BinaryDict, SystemUnigramLM, SystemBigramLM, Node, UserLanguageModel, \
RomkanConverter
system_language_model = SystemLanguageModel.load()
system_dict = SystemDict.load()
emoji_dict = EmojiDict.load()
tmpdir = TemporaryDirectory()
user_language_model_path = pathlib.Path('/tmp/user_language_model')
user_language_model_path.mkdir(parents=True, exist_ok=True, mode=0o700)
user_language_model = UserLanguageModel(str(user_language_model_path))
language_model = LanguageModel(
system_language_model=system_language_model,
user_language_model=user_language_model,
user_language_model = UserLanguageModel(
tmpdir.name + "/uni",
tmpdir.name + "/bi"
)
dictionary = Dictionary(
system_dict=system_dict,
emoji_dict=emoji_dict,
user_dicts=[],
)
system_unigram_lm = SystemUnigramLM()
system_unigram_lm.load("../akaza-data/akaza_data/data/lm_v2_1gram.trie")
system_bigram_lm = SystemBigramLM()
system_bigram_lm.load("../akaza-data/akaza_data/data/lm_v2_2gram.trie")
system_dict = BinaryDict()
system_dict.load("../akaza-data/akaza_data/data/system_dict.trie")
single_term = BinaryDict()
single_term.load("../akaza-data/akaza_data/data/single_term.trie")
resolver = GraphResolver(
dictionary=dictionary,
language_model=language_model,
)
romkan = RomkanConverter()
akaza = akaza.Akaza(
resolver=resolver,
romkan=romkan,
user_language_model,
system_unigram_lm,
system_bigram_lm,
[system_dict],
[single_term],
)
romkanConverter = RomkanConverter({})
akaza = Akaza(resolver, romkanConverter)
# for i in range(10):
# for line in ['watasinonamaehanakanodseu', 'tonarinokyakuhayokukakikuukyakuda', 'kyounotenkihakumoridana.',
@@ -62,7 +57,7 @@ print("START")
cmp_res_dict = cmpthese(
10,
{
"term1": lambda: akaza.convert('nagakunattekuruto,henkannninyojitunijikanngakakaruyouninattekuru.'),
"term1": lambda: akaza.convert('nagakunattekuruto,henkannninyojitunijikanngakakaruyouninattekuru.', None),
},
repeat=10,
)

View File

@@ -5,14 +5,15 @@ os.system("make binary")
setup(
name="pyakaza",
version="0.0.2",
version="0.0.3",
install_requires=[],
packages=['pyakaza'],
package_data={
'akaza_data': ['*.so'],
'pyakaza': ['*.so'],
},
extras_require={
},
entry_points={
}
},
zip_safe=False
)

View File

@@ -12,11 +12,12 @@ PYBIND11_MODULE(bind, m) {
.def(py::init<std::shared_ptr<akaza::GraphResolver> &,
std::shared_ptr<akaza::RomkanConverter> &>())
.def("convert", &akaza::Akaza::convert)
.def("get_version", &akaza::Akaza::get_version)
;
.def("get_version", &akaza::Akaza::get_version);
py::class_<akaza::RomkanConverter, std::shared_ptr<akaza::RomkanConverter>>(m, "RomkanConverter")
.def(py::init<const std::vector<std::tuple<std::string, std::string>> &>());
.def(py::init<const std::map<std::string, std::string> &>())
.def("to_hiragana", &akaza::RomkanConverter::to_hiragana)
.def("remove_last_char", &akaza::RomkanConverter::remove_last_char);
py::class_<akaza::SystemUnigramLM, std::shared_ptr<akaza::SystemUnigramLM>>(m, "SystemUnigramLM")
.def(py::init())
@@ -47,6 +48,7 @@ PYBIND11_MODULE(bind, m) {
py::class_<akaza::Node, std::shared_ptr<akaza::Node>>(m, "Node")
.def(py::init<size_t, const std::string &, const std::string &>())
.def("__eq__", &akaza::Node::operator==, py::is_operator())
.def("get_key", &akaza::Node::get_key)
.def("is_bos", &akaza::Node::is_bos)
.def("is_eos", &akaza::Node::is_eos)
@@ -58,7 +60,12 @@ PYBIND11_MODULE(bind, m) {
.def("get_prev", &akaza::Node::get_prev)
.def("calc_node_cost", &akaza::Node::calc_node_cost)
.def("get_bigram_cost", &akaza::Node::get_bigram_cost)
.def("get_word_id", &akaza::Node::get_word_id);
.def("get_word_id", &akaza::Node::get_word_id)
.def("__repr__",
[](const akaza::Node &node) {
return "<akaza::Node yomi= '" + node.get_yomi() + " word=" + node.get_word() + "'>";
}
);
py::class_<akaza::GraphResolver, std::shared_ptr<akaza::GraphResolver>>(m, "GraphResolver")
.def(py::init<const std::shared_ptr<akaza::UserLanguageModel> &,
@@ -82,5 +89,6 @@ PYBIND11_MODULE(bind, m) {
.def("should_save", &akaza::UserLanguageModel::should_save);
py::class_<akaza::Slice, std::shared_ptr<akaza::Slice>>(m, "Slice")
.def(py::init<size_t, size_t>());
.def(py::init<size_t, size_t>())
.def("__repr__", &akaza::Slice::repr);
}
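With __eq__ and __repr__ now bound, Node and Slice are pleasant to inspect from a REPL; a quick sketch whose expected output follows the repr implementations in this commit:

from pyakaza.bind import Node, Slice

print(Slice(1, 2))                 # <akaza::Slice start=1 len=2>
print(Node(0, 'ひょい', 'ヒョイ'))  # <akaza::Node yomi='ひょい' word='ヒョイ'>
assert Node(0, 'あ', '亜') == Node(0, 'あ', '亜')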

View File

@@ -39,15 +39,14 @@ def test_wnn():
[system_dict],
[single_term],
)
romkanConverter = RomkanConverter([])
romkanConverter = RomkanConverter({})
akaza = Akaza(resolver, romkanConverter)
src = 'わたしのなまえはなかのです'
expected = '私の名前は中野です'
src = 'わたしのなまえはなかのです'
expected = '私の名前は中野です'
print(akaza.get_version())
got = akaza.convert(src, None)
print([c[0].get_word() for c in got])
assert got == expected
assert ''.join([c[0].get_word() for c in got]) == expected