mirror of
https://github.com/mii443/akaza.git
synced 2025-08-24 15:49:25 +00:00
use bigram based dict.
This commit is contained in:
@ -36,6 +36,10 @@ class LanguageModel:
|
|||||||
@functools.lru_cache
def calc_bigram_cost(self, prev_node, next_node) -> float:
    """Return the bigram score for the transition prev_node -> next_node.

    The user dictionary is asked first. When it returns no score (None),
    the system bigram table is looked up with the key
    "<prev key>\\t<next key>", passing DEFAULT_SCORE as the lookup default,
    and the first field of the first record is returned.
    """
    # NOTE(review): lru_cache keys on (self, prev_node, next_node), so a score
    # cached before the user dictionary learns a pair is presumably served
    # until the cache is cleared -- confirm this is intended.
    user_score = self.user_dict.get_bigram_cost(prev_node, next_node)
    # Compare against None explicitly: a user score of 0.0 (log10 of a
    # probability of 1.0) is falsy, and a plain truthiness test would
    # throw away exactly the pairs the user prefers most.
    if user_score is not None:
        self.logger.info(f"Use user's bigram score: {prev_node.get_key()},{next_node.get_key()} -> {user_score}")
        return user_score
    return self.system_bigram_score.get(
        f"{prev_node.get_key()}\t{next_node.get_key()}", DEFAULT_SCORE
    )[0][0]
|
||||||
|
@ -22,40 +22,74 @@ class UserDict:
|
|||||||
|
|
||||||
self.unigram = {}
|
self.unigram = {}
|
||||||
if os.path.exists(self.unigram_path()):
|
if os.path.exists(self.unigram_path()):
|
||||||
self.read()
|
self.read_unigram()
|
||||||
else:
|
else:
|
||||||
self.total = 0
|
self.total = 0
|
||||||
|
|
||||||
|
self.bigram = {}
|
||||||
|
self.bigram_total = {}
|
||||||
|
if os.path.exists(self.bigram_path()):
|
||||||
|
self.read_bigram()
|
||||||
|
|
||||||
def unigram_path(self):
    """Return the location of the unigram file, 'unigram.txt' under self.path."""
    filename = 'unigram.txt'
    return os.path.join(self.path, filename)
|
||||||
|
|
||||||
def read(self):
|
def bigram_path(self):
    """Return the location of the bigram file, 'bigram.txt' under self.path."""
    filename = 'bigram.txt'
    return os.path.join(self.path, filename)
|
||||||
|
|
||||||
|
def read_unigram(self):
    """Load unigram counts from the unigram file.

    Each line is expected to hold a key and an integer count separated by
    a TAB; lines that do not split into exactly two fields are ignored.
    Counts are stored in self.unigram and their sum in self.total.
    """
    running_total = 0
    with open(self.unigram_path()) as fp:
        for raw in fp:
            fields = raw.rstrip().split("\t")
            if len(fields) != 2:
                continue
            key = fields[0]
            freq = int(fields[1])
            self.unigram[key] = freq
            running_total += freq
    self.total = running_total
|
||||||
|
|
||||||
|
def read_bigram(self):
    """Load bigram counts from the bigram file.

    Each line is expected to hold word1, word2 and an integer count
    separated by TABs; lines that do not split into exactly three fields
    are ignored. Pair counts go to self.bigram under the key
    "word1<TAB>word2", and self.bigram_total accumulates, per word1, the
    counts of all pairs that start with it.
    """
    with open(self.bigram_path()) as fp:
        for line in fp:
            fields = line.rstrip().split("\t")
            if len(fields) != 3:
                continue
            word1, word2, count = fields
            count = int(count)
            self.bigram[f"{word1}\t{word2}"] = count
            # Add the pair's count, not 1: add_entry bumps bigram_total once
            # per observed pair, so the total must equal the sum of the
            # counts for get_bigram_cost's count / total ratio to survive a
            # save/reload round trip.
            self.bigram_total[word1] = self.bigram_total.get(word1, 0) + count
|
||||||
|
|
||||||
def add_entry(self, nodes: "List[Node]"):
    """Record a converted node sequence in the in-memory user dictionary.

    Every node increments its unigram count and self.total. Every pair of
    adjacent nodes increments the count of the bigram key
    "<first key>\\t<second key>" and the bigram total of the first node's
    key. An empty or single-node sequence adds no bigrams.
    """
    # unigram
    for node in nodes:
        self.logger.info(f"add user_dict entry: kana='{node.yomi}' kanji='{node.word}'")

        key = node.get_key()
        self.unigram[key] = self.unigram.get(key, 0) + 1
        self.total += 1

    # bigram: walk adjacent pairs instead of indexing by position.
    for node1, node2 in zip(nodes, nodes[1:]):
        key1 = node1.get_key()
        key = key1 + "\t" + node2.get_key()
        self.bigram[key] = self.bigram.get(key, 0) + 1
        self.bigram_total[key1] = self.bigram_total.get(key1, 0) + 1
|
||||||
|
|
||||||
def save(self):
    """Write the unigram and bigram counts to their files.

    Both files get one "<key>\\t<count>" record per line, ordered by key,
    and are written through atomic_write with overwrite=True.
    """
    with atomic_write(self.unigram_path(), overwrite=True) as out:
        for key, freq in sorted(self.unigram.items()):
            out.write(f"{key}\t{freq}\n")

    with atomic_write(self.bigram_path(), overwrite=True) as out:
        for pair, freq in sorted(self.bigram.items()):
            out.write(f"{pair}\t{freq}\n")

    self.logger.info(f"SAVED {self.path}")
|
||||||
|
|
||||||
def get_unigram_cost(self, key: str) -> Optional[float]:
|
def get_unigram_cost(self, key: str) -> Optional[float]:
|
||||||
@ -63,3 +97,12 @@ class UserDict:
|
|||||||
count = self.unigram[key]
|
count = self.unigram[key]
|
||||||
return math.log10(count / self.total)
|
return math.log10(count / self.total)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_bigram_cost(self, node1: "Node", node2: "Node") -> Optional[float]:
    """Return the user-dictionary score for node1 followed by node2.

    The score is log10(count / total), where count is the stored count of
    the pair "<node1 key>\\t<node2 key>" and total is the bigram total of
    node1's key. Returns None when the pair is unknown or when the stored
    numbers cannot yield a logarithm.
    """
    key1 = node1.get_key()
    key = key1 + "\t" + node2.get_key()
    count = self.bigram.get(key)
    if count is None:
        return None
    total = self.bigram_total.get(key1, 0)
    # Counts are loaded from a text file as arbitrary integers; a zero or
    # negative value would make math.log10 raise, so treat it as "no
    # evidence" instead of crashing the lookup.
    if count <= 0 or total <= 0:
        return None
    return math.log10(count / total)
|
||||||
|
@ -19,3 +19,27 @@ def test_read():
|
|||||||
d.add_entry([Node(start_pos=0, word='熟語', yomi='じゅくご')])
|
d.add_entry([Node(start_pos=0, word='熟語', yomi='じゅくご')])
|
||||||
assert d.unigram == {'単語/たんご': 2, '熟語/じゅくご': 1}
|
assert d.unigram == {'単語/たんご': 2, '熟語/じゅくご': 1}
|
||||||
assert d.total == 3
|
assert d.total == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_read2():
    """add_entry accumulates unigram and bigram counts across sequences."""
    # Use the context manager so the temporary directory is removed when the
    # test finishes instead of relying on implicit cleanup at finalization.
    with TemporaryDirectory() as tmpdir:
        d = UserDict(tmpdir + "/foobar.dict")
        d.add_entry([
            Node(start_pos=0, word='私', yomi='わたし'),
            Node(start_pos=1, word='だよ', yomi='だよ'),
        ])
        d.add_entry([
            Node(start_pos=0, word='それは', yomi='それは'),
            Node(start_pos=3, word='私', yomi='わたし'),
            Node(start_pos=4, word='だよ', yomi='だよ'),
        ])
        d.add_entry([
            Node(start_pos=0, word='私', yomi='わたし'),
            Node(start_pos=1, word='です', yomi='です'),
        ])

        assert d.unigram == {'それは/それは': 1, 'だよ/だよ': 2, '私/わたし': 3, 'です/です': 1}
        assert d.total == 7

        assert d.bigram == {'それは/それは\t私/わたし': 1, '私/わたし\tだよ/だよ': 2, '私/わたし\tです/です': 1}
        assert d.bigram_total == {'それは/それは': 1, '私/わたし': 3}
|
||||||
|
Reference in New Issue
Block a user