mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-12-03 02:58:27 +00:00
Update Prediction logic
This commit is contained in:
@@ -36,4 +36,8 @@ public enum CIDData: Sendable {
|
||||
case .EOS: return 1316
|
||||
}
|
||||
}
|
||||
|
||||
/// Reports whether `cid` lies in the closed range 147...368, the class-ID range
/// this type treats as particles (joshi).
///
/// - Parameter cid: The class ID to test.
/// - Returns: `true` when `147 <= cid <= 368`, otherwise `false`.
public static func isJoshi(cid: Int) -> Bool {
    let joshiRange = 147...368
    return joshiRange.contains(cid)
}
|
||||
}
|
||||
|
||||
@@ -617,28 +617,53 @@ import SwiftUtils
|
||||
|
||||
/// Requests prediction candidates to offer after a conversion has been committed.
///
/// Candidates from three sources are appended in this order:
/// 1. emoji found by searching with the words of `leftSideCandidate` (the last 3 unique results),
/// 2. regular prediction results (up to half of the slots still free),
/// 3. zero-hint prediction results (up to the number of slots still free).
///
/// Before sources 2 and 3 are appended, they are passed through
/// `getUniquePredictionCandidate` together with the texts already chosen.
///
/// - Parameters:
///   - leftSideCandidate: The candidate that was just committed; predictions continue from it.
///   - options: Conversion request options. Not read by this method.
/// - Returns: At most 10 prediction candidates.
public func requestPredictionCandidates(leftSideCandidate: Candidate, options: ConvertRequestOptions) -> [PredictionCandidate] {
    // Enumerate candidates based on zero-hint prediction.
    // Particles (joshi) are limited to at most 3; every other candidate is kept.
    let maxJoshiCount = 3
    var joshiCount = 0
    let zeroHintResults = self.getUniquePredictionCandidate(
        self.converter.getZeroHintPredictionCandidates(preparts: [leftSideCandidate], N_best: 15)
    ).filter { candidate in
        switch candidate.type {
        case .additional(data: let data):
            // A candidate with no data falls back to the EOS class ID.
            guard CIDData.isJoshi(cid: data.last?.rcid ?? CIDData.EOS.cid) else {
                return true
            }
            guard joshiCount < maxJoshiCount else {
                return false
            }
            joshiCount += 1
            return true
        case .replacement:
            return true
        }
    }

    // Enumerate candidates based on regular prediction.
    let predictionResults = self.converter.getPredictionCandidates(prepart: leftSideCandidate, N_best: 15)

    // Add emoji.
    let replacer = TextReplacer()
    var emojiCandidates: [PredictionCandidate] = []
    for data in leftSideCandidate.data where DicdataStore.includeMMValueCalculation(data) {
        let result = replacer.getSearchResult(query: data.word, target: [.emoji], ignoreNonBaseEmoji: true)
        for emoji in result {
            emojiCandidates.append(PredictionCandidate(text: emoji.text, value: -3, type: .additional(data: [.init(word: emoji.text, ruby: "エモジ", cid: CIDData.記号.cid, mid: MIDData.一般.mid, value: -3)])))
        }
    }
    emojiCandidates = self.getUniquePredictionCandidate(emojiCandidates)

    var results: [PredictionCandidate] = []
    var seenCandidates: Set<String> = []

    // Emoji go first: only the last three of the unique emoji candidates.
    let emojis = emojiCandidates.suffix(3)
    results.append(contentsOf: emojis)
    seenCandidates.formUnion(emojis.map {$0.text})

    // Regular predictions take up to half of the remaining slots, highest value first.
    let predictions = self.getUniquePredictionCandidate(predictionResults, seenCandidates: seenCandidates).min(count: (10 - results.count) / 2, sortedBy: {$0.value > $1.value})
    results.append(contentsOf: predictions)
    seenCandidates.formUnion(predictions.map {$0.text})

    // Zero-hint predictions fill whatever is left, highest value first.
    let zeroHints = self.getUniquePredictionCandidate(zeroHintResults, seenCandidates: seenCandidates)
    results.append(contentsOf: zeroHints.min(count: 10 - results.count, sortedBy: {$0.value > $1.value}))
    return results
}
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ extension Kana2Kanji {
|
||||
}
|
||||
// 共通接頭辞を切り落とす
|
||||
let text = String(data.word.dropFirst(totalWord.count))
|
||||
result.insert(.replacement(.init(text: text, targetData: totalData, replacementData: [data], value: newValue)), at: lastindex)
|
||||
result.insert(.init(text: text, value: newValue, type: .replacement(targetData: totalData, replacementData: [data])), at: lastindex)
|
||||
}
|
||||
}
|
||||
return result
|
||||
@@ -111,15 +111,15 @@ extension Kana2Kanji {
|
||||
let newValue = candidate.value + mmValue + ccValue + wValue
|
||||
|
||||
// 追加すべきindexを取得する
|
||||
let lastindex: Int = (result.lastIndex(where: {$0.value >= newValue}) ?? -1) + 1
|
||||
if lastindex == N_best {
|
||||
let lastIndex: Int = (result.lastIndex(where: {$0.value >= newValue}) ?? -1) + 1
|
||||
if lastIndex == N_best {
|
||||
continue
|
||||
}
|
||||
// カウントがオーバーしている場合は除去する
|
||||
if result.count >= N_best {
|
||||
result.removeLast()
|
||||
}
|
||||
result.insert(.additional(.init(text: data.word, data: [data], value: newValue)), at: lastindex)
|
||||
result.insert(.init(text: data.word, value: newValue, type: .additional(data: [data])), at: lastIndex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,63 +7,48 @@
|
||||
|
||||
import Foundation
|
||||
|
||||
/// A prediction candidate offered after a conversion has been committed.
///
/// The candidate either appends data to the committed candidate (`.additional`)
/// or replaces its trailing data (`.replacement`); see `PredictionType`.
public struct PredictionCandidate {
    /// Creates a prediction candidate.
    ///
    /// - Parameters:
    ///   - text: The text shown as the prediction.
    ///   - value: The weight of the candidate.
    ///   - type: How the candidate is applied to an existing `Candidate`.
    public init(text: String, value: PValue, type: PredictionCandidate.PredictionType) {
        self.text = text
        self.value = value
        self.type = type
        // The candidate is terminal when its text is exactly one of the listed sentence-ending marks.
        // NOTE(review): "." appears twice in this list; one entry may originally have been a
        // different (e.g. full-width) character — confirm against the upstream source.
        self.isTerminal = Set(["。", ".", "."]).contains(text)
    }

    /// The text shown as the prediction.
    public var text: String
    /// The weight of the candidate.
    public var value: PValue
    /// How the candidate is applied to an existing `Candidate`.
    public var type: PredictionType
    /// `true` when `text` is one of the sentence-ending marks checked in `init`.
    public var isTerminal: Bool

    /// Applies this prediction to `candidate` and returns the updated candidate.
    ///
    /// - Parameter candidate: The candidate to extend or partially replace.
    /// - Returns: `candidate` with its `text`, `data`, `value`, `correspondingCount`
    ///   and `lastMid` updated according to `type`.
    public func join(to candidate: consuming Candidate) -> Candidate {
        switch self.type {
        case .additional(let data):
            // Append every element's word and data to the end of the candidate.
            for element in data {
                candidate.text.append(contentsOf: element.word)
                candidate.data.append(element)
            }
            candidate.value = self.value
            candidate.correspondingCount = candidate.data.reduce(into: 0) { $0 += $1.ruby.count }
            // Keep the previous `lastMid` when none of the appended elements qualifies.
            candidate.lastMid = data.last(where: DicdataStore.includeMMValueCalculation)?.mid ?? candidate.lastMid
            return candidate
        case .replacement(let targetData, let replacementData):
            // Drop as many trailing elements as `targetData` has, then append the replacement.
            candidate.data.removeLast(targetData.count)
            candidate.data.append(contentsOf: replacementData)
            // The text is rebuilt from the data because trailing elements changed.
            candidate.text = candidate.data.reduce(into: "") {$0 += $1.word}
            candidate.value = self.value
            candidate.lastMid = candidate.data.last(where: DicdataStore.includeMMValueCalculation)?.mid ?? MIDData.BOS.mid
            candidate.correspondingCount = candidate.data.reduce(into: 0) { $0 += $1.ruby.count }
            return candidate
        }
    }

    /// How a prediction candidate is applied to an existing candidate.
    public enum PredictionType: Sendable, Hashable {
        /// Appends `data` to the end of the candidate.
        case additional(data: [DicdataElement])
        /// Replaces the trailing `targetData` (the data to be replaced) with
        /// `replacementData` (the data after replacement).
        case replacement(targetData: [DicdataElement], replacementData: [DicdataElement])
    }
}
|
||||
|
||||
Reference in New Issue
Block a user