mirror of
https://github.com/mii443/akaza.git
synced 2025-08-24 07:39:23 +00:00
use bigram based dict.
This commit is contained in:
@ -36,6 +36,10 @@ class LanguageModel:
|
||||
@functools.lru_cache
|
||||
def calc_bigram_cost(self, prev_node, next_node) -> float:
|
||||
# self → node で処理する。
|
||||
u = self.user_dict.get_bigram_cost(prev_node, next_node)
|
||||
if u:
|
||||
self.logger.info(f"Use user's bigram score: {prev_node.get_key()},{next_node.get_key()} -> {u}")
|
||||
return u
|
||||
return self.system_bigram_score.get(
|
||||
f"{prev_node.get_key()}\t{next_node.get_key()}", DEFAULT_SCORE
|
||||
)[0][0]
|
||||
|
@ -22,40 +22,74 @@ class UserDict:
|
||||
|
||||
self.unigram = {}
|
||||
if os.path.exists(self.unigram_path()):
|
||||
self.read()
|
||||
self.read_unigram()
|
||||
else:
|
||||
self.total = 0
|
||||
|
||||
self.bigram = {}
|
||||
self.bigram_total = {}
|
||||
if os.path.exists(self.bigram_path()):
|
||||
self.read_bigram()
|
||||
|
||||
def unigram_path(self):
|
||||
return os.path.join(self.path, 'unigram.txt')
|
||||
|
||||
def read(self):
|
||||
def bigram_path(self):
|
||||
return os.path.join(self.path, 'bigram.txt')
|
||||
|
||||
def read_unigram(self):
|
||||
total = 0
|
||||
with open(self.unigram_path()) as fp:
|
||||
for line in fp:
|
||||
m = line.rstrip().split(" ")
|
||||
m = line.rstrip().split("\t")
|
||||
if len(m) == 2:
|
||||
kanji_kana, count = m
|
||||
count = int(count)
|
||||
self.unigram[kanji_kana] = count
|
||||
total += count
|
||||
self.total = total
|
||||
|
||||
def read_bigram(self):
|
||||
with open(self.bigram_path()) as fp:
|
||||
for line in fp:
|
||||
m = line.rstrip().split("\t")
|
||||
if len(m) == 3:
|
||||
word1, word2, count = m
|
||||
count = int(count)
|
||||
self.bigram[f"{word1}\t{word2}"] = count
|
||||
self.bigram_total[word1] = self.bigram_total.get(word1, 0) + 1
|
||||
|
||||
def add_entry(self, nodes: List[Node]):
|
||||
# unigram
|
||||
for node in nodes:
|
||||
kanji = node.word
|
||||
kana = node.yomi
|
||||
|
||||
self.logger.info(f"add user_dict entry: kana='{kana}' kanji='{kanji}'")
|
||||
|
||||
key = f"{kanji}/{kana}"
|
||||
key = node.get_key()
|
||||
self.unigram[key] = self.unigram.get(key, 0) + 1
|
||||
self.total += 1
|
||||
|
||||
# bigram
|
||||
for i in range(1, len(nodes)):
|
||||
node1 = nodes[i - 1]
|
||||
node2 = nodes[i]
|
||||
key = node1.get_key() + "\t" + node2.get_key()
|
||||
self.bigram[key] = self.bigram.get(key, 0) + 1
|
||||
self.bigram_total[node1.get_key()] = self.bigram_total.get(node1.get_key(), 0) + 1
|
||||
|
||||
def save(self):
|
||||
with atomic_write(self.unigram_path(), overwrite=True) as f:
|
||||
for kanji_kana in sorted(self.unigram.keys()):
|
||||
count = self.unigram[kanji_kana]
|
||||
f.write(f"{kanji_kana}\t{count}\n")
|
||||
|
||||
with atomic_write(self.bigram_path(), overwrite=True) as f:
|
||||
for words in sorted(self.bigram.keys()):
|
||||
count = self.bigram[words]
|
||||
f.write(f"{words}\t{count}\n")
|
||||
|
||||
self.logger.info(f"SAVED {self.path}")
|
||||
|
||||
def get_unigram_cost(self, key: str) -> Optional[float]:
|
||||
@ -63,3 +97,12 @@ class UserDict:
|
||||
count = self.unigram[key]
|
||||
return math.log10(count / self.total)
|
||||
return None
|
||||
|
||||
def get_bigram_cost(self, node1: Node, node2: Node) -> Optional[float]:
|
||||
key1 = node1.get_key()
|
||||
key2 = node2.get_key()
|
||||
key = key1 + "\t" + key2
|
||||
if key in self.bigram:
|
||||
count = self.bigram[key]
|
||||
return math.log10(count / self.bigram_total[key1])
|
||||
return None
|
||||
|
@ -19,3 +19,27 @@ def test_read():
|
||||
d.add_entry([Node(start_pos=0, word='熟語', yomi='じゅくご')])
|
||||
assert d.unigram == {'単語/たんご': 2, '熟語/じゅくご': 1}
|
||||
assert d.total == 3
|
||||
|
||||
|
||||
def test_read2():
|
||||
tmpdir = TemporaryDirectory()
|
||||
d = UserDict(tmpdir.name + "/foobar.dict")
|
||||
d.add_entry([
|
||||
Node(start_pos=0, word='私', yomi='わたし'),
|
||||
Node(start_pos=1, word='だよ', yomi='だよ'),
|
||||
])
|
||||
d.add_entry([
|
||||
Node(start_pos=0, word='それは', yomi='それは'),
|
||||
Node(start_pos=3, word='私', yomi='わたし'),
|
||||
Node(start_pos=4, word='だよ', yomi='だよ'),
|
||||
])
|
||||
d.add_entry([
|
||||
Node(start_pos=0, word='私', yomi='わたし'),
|
||||
Node(start_pos=1, word='です', yomi='です'),
|
||||
])
|
||||
|
||||
assert d.unigram == {'それは/それは': 1, 'だよ/だよ': 2, '私/わたし': 3, 'です/です': 1}
|
||||
assert d.total == 7
|
||||
|
||||
assert d.bigram == {'それは/それは\t私/わたし': 1, '私/わたし\tだよ/だよ': 2, '私/わたし\tです/です': 1}
|
||||
assert d.bigram_total == {'それは/それは': 1, '私/わたし': 3}
|
||||
|
Reference in New Issue
Block a user