mirror of
https://github.com/mii443/akaza.git
synced 2025-08-22 23:05:26 +00:00
113 lines
3.4 KiB
Python
113 lines
3.4 KiB
Python
import logging
|
|
import math
|
|
import os
|
|
from typing import List, Dict, Optional
|
|
|
|
from atomicwrites import atomic_write
|
|
|
|
from comb.node import Node
|
|
|
|
|
|
# ユーザーの言語モデル。
|
|
#
|
|
# unigram score
|
|
# bigram score
|
|
class UserLanguageModel:
|
|
unigram: Dict[str, int]
|
|
|
|
def __init__(self, path, logger=logging.getLogger(__name__)):
|
|
self.path = path
|
|
self.logger = logger
|
|
|
|
self.unigram = {}
|
|
if os.path.exists(self.unigram_path()):
|
|
self.read_unigram()
|
|
else:
|
|
self.total = 0
|
|
|
|
self.bigram = {}
|
|
self.bigram_total = {}
|
|
if os.path.exists(self.bigram_path()):
|
|
self.read_bigram()
|
|
|
|
def unigram_path(self):
|
|
return os.path.join(self.path, 'unigram.txt')
|
|
|
|
def bigram_path(self):
|
|
return os.path.join(self.path, 'bigram.txt')
|
|
|
|
def read_unigram(self):
|
|
total = 0
|
|
with open(self.unigram_path()) as fp:
|
|
for line in fp:
|
|
m = line.rstrip().split("\t")
|
|
if len(m) == 2:
|
|
kanji_kana, count = m
|
|
count = int(count)
|
|
self.unigram[kanji_kana] = count
|
|
total += count
|
|
self.total = total
|
|
|
|
def read_bigram(self):
|
|
with open(self.bigram_path()) as fp:
|
|
for line in fp:
|
|
m = line.rstrip().split("\t")
|
|
if len(m) == 3:
|
|
word1, word2, count = m
|
|
count = int(count)
|
|
self.bigram[f"{word1}\t{word2}"] = count
|
|
self.bigram_total[word1] = self.bigram_total.get(word1, 0) + 1
|
|
|
|
def add_entry(self, nodes: List[Node]):
|
|
# unigram
|
|
for node in nodes:
|
|
kanji = node.word
|
|
kana = node.yomi
|
|
|
|
self.logger.info(f"add user_dict entry: kana='{kana}' kanji='{kanji}'")
|
|
|
|
key = node.get_key()
|
|
self.unigram[key] = self.unigram.get(key, 0) + 1
|
|
self.total += 1
|
|
|
|
# bigram
|
|
for i in range(1, len(nodes)):
|
|
node1 = nodes[i - 1]
|
|
node2 = nodes[i]
|
|
key = node1.get_key() + "\t" + node2.get_key()
|
|
self.bigram[key] = self.bigram.get(key, 0) + 1
|
|
self.bigram_total[node1.get_key()] = self.bigram_total.get(node1.get_key(), 0) + 1
|
|
|
|
# TODO save in background thread...
|
|
# TODO: bulk save in every 1 minute.
|
|
def save(self):
|
|
with atomic_write(self.unigram_path(), overwrite=True) as f:
|
|
for kanji_kana in sorted(self.unigram.keys()):
|
|
count = self.unigram[kanji_kana]
|
|
f.write(f"{kanji_kana}\t{count}\n")
|
|
|
|
with atomic_write(self.bigram_path(), overwrite=True) as f:
|
|
for words in sorted(self.bigram.keys()):
|
|
count = self.bigram[words]
|
|
f.write(f"{words}\t{count}\n")
|
|
|
|
self.logger.info(f"SAVED {self.path}")
|
|
|
|
def get_unigram_cost(self, key: str) -> Optional[float]:
|
|
if key in self.unigram:
|
|
count = self.unigram[key]
|
|
return math.log10(count / self.total)
|
|
return None
|
|
|
|
def has_unigram_cost(self, key: str) -> bool:
|
|
return key in self.unigram
|
|
|
|
def get_bigram_cost(self, node1: Node, node2: Node) -> Optional[float]:
|
|
key1 = node1.get_key()
|
|
key2 = node2.get_key()
|
|
key = key1 + "\t" + key2
|
|
if key in self.bigram:
|
|
count = self.bigram[key]
|
|
return math.log10(count / self.bigram_total[key1])
|
|
return None
|