Files
akaza/comb/user_language_model.py
2020-09-13 01:54:54 +09:00

113 lines
3.4 KiB
Python

import logging
import math
import os
from typing import List, Dict, Optional
from atomicwrites import atomic_write
from comb.node import Node
# Per-user language model.
#
# unigram score
# bigram score
class UserLanguageModel:
    """Per-user language model built from unigram and bigram counts.

    Counts are kept in memory and persisted as tab-separated UTF-8 text
    files (``unigram.txt`` and ``bigram.txt``) inside the directory ``path``.
    Keys are the strings returned by ``Node.get_key()``.
    """

    # key -> number of times the key was added.
    unigram: Dict[str, int]
    # Sum of all unigram counts (denominator of the unigram score).
    total: int
    # "key1\tkey2" -> number of times key2 was added right after key1.
    bigram: Dict[str, int]
    # key1 -> sum of the counts of every bigram that starts with key1
    # (denominator of the bigram score).
    bigram_total: Dict[str, int]

    def __init__(self, path, logger=logging.getLogger(__name__)):
        """Load existing counts from ``path`` if the data files are present.

        Args:
            path: Directory that holds ``unigram.txt`` and ``bigram.txt``.
            logger: Logger used for informational messages.
        """
        self.path = path
        self.logger = logger

        self.unigram = {}
        self.total = 0
        if os.path.exists(self.unigram_path()):
            self.read_unigram()

        self.bigram = {}
        self.bigram_total = {}
        if os.path.exists(self.bigram_path()):
            self.read_bigram()

    def unigram_path(self):
        """Return the path of the unigram count file."""
        return os.path.join(self.path, 'unigram.txt')

    def bigram_path(self):
        """Return the path of the bigram count file."""
        return os.path.join(self.path, 'bigram.txt')

    def read_unigram(self):
        """Read ``key<TAB>count`` lines into ``self.unigram``.

        Lines that do not have exactly two tab-separated fields are ignored.
        ``self.total`` is set to the sum of the counts that were read.

        Raises:
            ValueError: If a count field is not an integer.
        """
        total = 0
        # The keys contain Japanese text, so do not depend on the locale's
        # default encoding.
        with open(self.unigram_path(), encoding='utf-8') as fp:
            for line in fp:
                fields = line.rstrip().split("\t")
                if len(fields) != 2:
                    continue
                kanji_kana, count = fields
                count = int(count)
                self.unigram[kanji_kana] = count
                total += count
        self.total = total

    def read_bigram(self):
        """Read ``key1<TAB>key2<TAB>count`` lines into ``self.bigram``.

        Lines that do not have exactly three tab-separated fields are
        ignored. ``self.bigram_total`` is updated alongside.

        Raises:
            ValueError: If a count field is not an integer.
        """
        with open(self.bigram_path(), encoding='utf-8') as fp:
            for line in fp:
                fields = line.rstrip().split("\t")
                if len(fields) != 3:
                    continue
                word1, word2, count = fields
                count = int(count)
                self.bigram[f"{word1}\t{word2}"] = count
                # Accumulate the stored count, not 1: add_entry() raises
                # bigram_total by one per observed occurrence, so the
                # reloaded total must be the sum of the counts. Otherwise
                # count / bigram_total can exceed 1 after a reload.
                self.bigram_total[word1] = self.bigram_total.get(word1, 0) + count

    def add_entry(self, nodes: List["Node"]):
        """Count every node, and every adjacent pair of nodes, once.

        Args:
            nodes: Sequence of nodes in order; each must provide ``word``,
                ``yomi`` and ``get_key()``.
        """
        # unigram
        for node in nodes:
            kanji = node.word
            kana = node.yomi
            self.logger.info(f"add user_dict entry: kana='{kana}' kanji='{kanji}'")
            key = node.get_key()
            self.unigram[key] = self.unigram.get(key, 0) + 1
            self.total += 1

        # bigram: walk over adjacent (previous, next) pairs.
        for prev_node, next_node in zip(nodes, nodes[1:]):
            prev_key = prev_node.get_key()
            pair_key = prev_key + "\t" + next_node.get_key()
            self.bigram[pair_key] = self.bigram.get(pair_key, 0) + 1
            self.bigram_total[prev_key] = self.bigram_total.get(prev_key, 0) + 1

    # TODO save in background thread...
    # TODO: bulk save in every 1 minute.
    def save(self):
        """Write both count files atomically, with entries sorted by key."""
        with atomic_write(self.unigram_path(), overwrite=True, encoding='utf-8') as f:
            for kanji_kana in sorted(self.unigram):
                count = self.unigram[kanji_kana]
                f.write(f"{kanji_kana}\t{count}\n")

        with atomic_write(self.bigram_path(), overwrite=True, encoding='utf-8') as f:
            for words in sorted(self.bigram):
                count = self.bigram[words]
                f.write(f"{words}\t{count}\n")

        self.logger.info(f"SAVED {self.path}")

    def get_unigram_cost(self, key: str) -> Optional[float]:
        """Return ``log10(count / total)`` for ``key``, or None if unknown."""
        count = self.unigram.get(key)
        if count is None:
            return None
        return math.log10(count / self.total)

    def has_unigram_cost(self, key: str) -> bool:
        """Return True if ``key`` has a unigram count."""
        return key in self.unigram

    def get_bigram_cost(self, node1: "Node", node2: "Node") -> Optional[float]:
        """Return the log10 relative frequency of ``node2`` following ``node1``.

        The value is ``log10(count(node1, node2) / bigram_total[node1])``.
        Returns None when the pair has never been counted.
        """
        key1 = node1.get_key()
        key2 = node2.get_key()
        count = self.bigram.get(key1 + "\t" + key2)
        if count is None:
            return None
        return math.log10(count / self.bigram_total[key1])