use bigram based dict.

This commit is contained in:
Tokuhiro Matsuno
2020-09-12 11:36:22 +09:00
parent 87585f06ae
commit 4c3efffe37
3 changed files with 75 additions and 4 deletions

View File

@ -36,6 +36,10 @@ class LanguageModel:
@functools.lru_cache
def calc_bigram_cost(self, prev_node, next_node) -> float:
# self → node で処理する。
u = self.user_dict.get_bigram_cost(prev_node, next_node)
if u:
self.logger.info(f"Use user's bigram score: {prev_node.get_key()},{next_node.get_key()} -> {u}")
return u
return self.system_bigram_score.get(
f"{prev_node.get_key()}\t{next_node.get_key()}", DEFAULT_SCORE
)[0][0]

View File

@ -22,40 +22,74 @@ class UserDict:
self.unigram = {}
if os.path.exists(self.unigram_path()):
self.read()
self.read_unigram()
else:
self.total = 0
self.bigram = {}
self.bigram_total = {}
if os.path.exists(self.bigram_path()):
self.read_bigram()
def unigram_path(self):
return os.path.join(self.path, 'unigram.txt')
def read(self):
def bigram_path(self):
return os.path.join(self.path, 'bigram.txt')
def read_unigram(self):
total = 0
with open(self.unigram_path()) as fp:
for line in fp:
m = line.rstrip().split(" ")
m = line.rstrip().split("\t")
if len(m) == 2:
kanji_kana, count = m
count = int(count)
self.unigram[kanji_kana] = count
total += count
self.total = total
def read_bigram(self):
with open(self.bigram_path()) as fp:
for line in fp:
m = line.rstrip().split("\t")
if len(m) == 3:
word1, word2, count = m
count = int(count)
self.bigram[f"{word1}\t{word2}"] = count
self.bigram_total[word1] = self.bigram_total.get(word1, 0) + 1
def add_entry(self, nodes: List[Node]):
# unigram
for node in nodes:
kanji = node.word
kana = node.yomi
self.logger.info(f"add user_dict entry: kana='{kana}' kanji='{kanji}'")
key = f"{kanji}/{kana}"
key = node.get_key()
self.unigram[key] = self.unigram.get(key, 0) + 1
self.total += 1
# bigram
for i in range(1, len(nodes)):
node1 = nodes[i - 1]
node2 = nodes[i]
key = node1.get_key() + "\t" + node2.get_key()
self.bigram[key] = self.bigram.get(key, 0) + 1
self.bigram_total[node1.get_key()] = self.bigram_total.get(node1.get_key(), 0) + 1
def save(self):
with atomic_write(self.unigram_path(), overwrite=True) as f:
for kanji_kana in sorted(self.unigram.keys()):
count = self.unigram[kanji_kana]
f.write(f"{kanji_kana}\t{count}\n")
with atomic_write(self.bigram_path(), overwrite=True) as f:
for words in sorted(self.bigram.keys()):
count = self.bigram[words]
f.write(f"{words}\t{count}\n")
self.logger.info(f"SAVED {self.path}")
def get_unigram_cost(self, key: str) -> Optional[float]:
@ -63,3 +97,12 @@ class UserDict:
count = self.unigram[key]
return math.log10(count / self.total)
return None
def get_bigram_cost(self, node1: Node, node2: Node) -> Optional[float]:
key1 = node1.get_key()
key2 = node2.get_key()
key = key1 + "\t" + key2
if key in self.bigram:
count = self.bigram[key]
return math.log10(count / self.bigram_total[key1])
return None

View File

@ -19,3 +19,27 @@ def test_read():
d.add_entry([Node(start_pos=0, word='熟語', yomi='じゅくご')])
assert d.unigram == {'単語/たんご': 2, '熟語/じゅくご': 1}
assert d.total == 3
def test_read2():
tmpdir = TemporaryDirectory()
d = UserDict(tmpdir.name + "/foobar.dict")
d.add_entry([
Node(start_pos=0, word='', yomi='わたし'),
Node(start_pos=1, word='だよ', yomi='だよ'),
])
d.add_entry([
Node(start_pos=0, word='それは', yomi='それは'),
Node(start_pos=3, word='', yomi='わたし'),
Node(start_pos=4, word='だよ', yomi='だよ'),
])
d.add_entry([
Node(start_pos=0, word='', yomi='わたし'),
Node(start_pos=1, word='です', yomi='です'),
])
assert d.unigram == {'それは/それは': 1, 'だよ/だよ': 2, '私/わたし': 3, 'です/です': 1}
assert d.total == 7
assert d.bigram == {'それは/それは\t私/わたし': 1, '私/わたし\tだよ/だよ': 2, '私/わたし\tです/です': 1}
assert d.bigram_total == {'それは/それは': 1, '私/わたし': 3}