Update Prediction logic

This commit is contained in:
ensan-hcl
2023-09-22 23:00:14 +09:00
parent 6c2a93a26c
commit ae2c193177
4 changed files with 71 additions and 57 deletions

View File

@@ -36,4 +36,8 @@ public enum CIDData: Sendable {
case .EOS: return 1316
}
}
/// Reports whether `cid` lies in the closed interval 147...368.
///
/// NOTE(review): judging by the name, this interval is the class-ID range used for
/// joshi (Japanese particles); the table that defines it is not visible here — confirm.
/// - Parameter cid: The class ID to test.
/// - Returns: `true` when `147 <= cid <= 368`, otherwise `false`.
public static func isJoshi(cid: Int) -> Bool {
    (147...368).contains(cid)
}
}

View File

@@ -617,28 +617,53 @@ import SwiftUtils
/// Builds the prediction candidates offered after `leftSideCandidate`.
///
/// NOTE(review): this span is a rendered diff without +/- markers, so lines from before and
/// after the commit appear together (`seenCandidates`, `zeroHintResults` and
/// `predictionResults` are each declared twice). The comments below describe individual
/// lines; confirm the final shape against the repository.
/// - Parameter leftSideCandidate: The candidate handed to the converter as the preceding part,
///   and whose `data` elements drive the emoji search.
/// - Parameter options: Conversion request options; not referenced by any line visible here.
/// - Returns: Emoji, prediction and zero-hint candidates collected into one array, with each
///   group sized against a budget of 10 entries.
public func requestPredictionCandidates(leftSideCandidate: Candidate, options: ConvertRequestOptions) -> [PredictionCandidate] {
var seenCandidates: Set<String> = []
// Zero-hint predictions: requested from the converter with N_best: 15, then de-duplicated.
let zeroHintResults = self.getUniquePredictionCandidate(self.converter.getZeroHintPredictionCandidates(preparts: [leftSideCandidate], N_best: 15))
seenCandidates.formUnion(zeroHintResults.map{$0.text})
var zeroHintResults = self.getUniquePredictionCandidate(self.converter.getZeroHintPredictionCandidates(preparts: [leftSideCandidate], N_best: 15))
do {
// Keep at most 3 `.additional` candidates whose last element's rcid is a joshi CID;
// every other candidate is passed through unchanged.
var joshiCount = 0
zeroHintResults = zeroHintResults.reduce(into: []) { results, candidate in
switch candidate.type {
case .additional(data: let data):
// An empty `data` array falls back to the EOS CID for the joshi check.
if CIDData.isJoshi(cid: data.last?.rcid ?? CIDData.EOS.cid) {
if joshiCount < 3 {
results.append(candidate)
joshiCount += 1
}
} else {
results.append(candidate)
}
case .replacement:
results.append(candidate)
}
}
}
// Prediction candidates for `leftSideCandidate`, requested with N_best: 15.
let predictionResults = self.getUniquePredictionCandidate(self.converter.getPredictionCandidates(prepart: leftSideCandidate, N_best: 15), seenCandidates: seenCandidates)
seenCandidates.formUnion(predictionResults.map{$0.text})
let predictionResults = self.converter.getPredictionCandidates(prepart: leftSideCandidate, N_best: 15)
// Emoji candidates: search the replacer for emoji matching the word of each data element
// for which `DicdataStore.includeMMValueCalculation` returns true.
let replacer = TextReplacer()
var emojiCandidates: [PredictionCandidate] = []
for data in leftSideCandidate.data where DicdataStore.includeMMValueCalculation(data) {
let result = replacer.getSearchResult(query: data.word, target: [.emoji], ignoreNonBaseEmoji: true)
for emoji in result {
// NOTE(review): `CIDData..cid` / `MIDData..mid` look like member names lost in extraction — confirm against the repository.
// Every emoji candidate is given the fixed value -3.
emojiCandidates.append(.additional(.init(text: emoji.text, data: [.init(ruby: "エモジ", cid: CIDData..cid, mid: MIDData..mid, value: -3)], value: -3)))
emojiCandidates.append(PredictionCandidate(text: emoji.text, value: -3, type: .additional(data: [.init(word: emoji.text, ruby: "エモジ", cid: CIDData..cid, mid: MIDData..mid, value: -3)])))
}
}
emojiCandidates = self.getUniquePredictionCandidate(emojiCandidates, seenCandidates: seenCandidates)
emojiCandidates = self.getUniquePredictionCandidate(emojiCandidates)
// Assemble the final list: the last 3 emoji candidates come first.
var results: [PredictionCandidate] = []
var seenCandidates: Set<String> = []
results.append(contentsOf: emojiCandidates.suffix(3))
results.append(contentsOf: predictionResults.min(count: (10 - results.count) / 2, sortedBy: {$0.value > $1.value}))
results.append(contentsOf: zeroHintResults.min(count: 10 - results.count, sortedBy: {$0.value > $1.value}))
seenCandidates.formUnion(emojiCandidates.suffix(3).map{$0.text})
// Predictions get half of the remaining budget. `min(count:sortedBy:)` with a `>` comparator
// presumably selects the highest-valued entries (swift-algorithms) — confirm.
let predictions = self.getUniquePredictionCandidate(predictionResults, seenCandidates: seenCandidates).min(count: (10 - results.count) / 2, sortedBy: {$0.value > $1.value})
results.append(contentsOf: predictions)
seenCandidates.formUnion(predictions.map{$0.text})
// Zero-hint candidates whose text has not been used yet fill whatever budget is left.
let zeroHints = self.getUniquePredictionCandidate(zeroHintResults, seenCandidates: seenCandidates)
results.append(contentsOf: zeroHints.min(count: 10 - results.count, sortedBy: {$0.value > $1.value}))
return results
}
}

View File

@@ -84,7 +84,7 @@ extension Kana2Kanji {
}
//
let text = String(data.word.dropFirst(totalWord.count))
result.insert(.replacement(.init(text: text, targetData: totalData, replacementData: [data], value: newValue)), at: lastindex)
result.insert(.init(text: text, value: newValue, type: .replacement(targetData: totalData, replacementData: [data])), at: lastindex)
}
}
return result
@@ -111,15 +111,15 @@ extension Kana2Kanji {
let newValue = candidate.value + mmValue + ccValue + wValue
// index
let lastindex: Int = (result.lastIndex(where: {$0.value >= newValue}) ?? -1) + 1
if lastindex == N_best {
let lastIndex: Int = (result.lastIndex(where: {$0.value >= newValue}) ?? -1) + 1
if lastIndex == N_best {
continue
}
//
if result.count >= N_best {
result.removeLast()
}
result.insert(.additional(.init(text: data.word, data: [data], value: newValue)), at: lastindex)
result.insert(.init(text: data.word, value: newValue, type: .additional(data: [data])), at: lastIndex)
}
}
}

View File

@@ -7,63 +7,48 @@
import Foundation
// NOTE(review): this span is a rendered diff without +/- markers. The enum form of
// `PredictionCandidate` and the struct form that follows it are interleaved line by line,
// so the comments below describe individual declarations only.
/// A prediction candidate, expressed as either an addition or a replacement.
public enum PredictionCandidate: Sendable, Hashable {
case additional(AdditionalPredictionCandidate)
case replacement(ReplacementPredictionCandidate)
/// Payload of the `.additional` case: `join(to:)` appends its `data` to a candidate.
public struct AdditionalPredictionCandidate: Sendable, Hashable {
public var text: String
public var data: [DicdataElement]
public var value: PValue
}
/// Payload of the `.replacement` case: `join(to:)` swaps trailing data of a candidate.
public struct ReplacementPredictionCandidate: Sendable, Hashable {
/// Text carried by this candidate.
public var text: String
/// Elements that `join(to:)` removes from the end of the candidate's data (by count).
public var targetData: [DicdataElement]
/// Elements that `join(to:)` appends after that removal.
public var replacementData: [DicdataElement]
/// Value that `join(to:)` assigns to the joined candidate.
public var value: PValue
}
/// The value of whichever payload this case wraps.
public var value: PValue {
switch self {
case .additional(let c):
c.value
case .replacement(let c):
c.value
/// A prediction candidate with its text, value, and a `PredictionType` describing how it is joined.
public struct PredictionCandidate {
/// Stores the given fields and derives `isTerminal` from `text`.
/// - Parameter text: Text carried by the candidate.
/// - Parameter value: Value that `join(to:)` assigns to the joined candidate.
/// - Parameter type: Whether joining appends data or replaces trailing data.
public init(text: String, value: PValue, type: PredictionCandidate.PredictionType) {
self.text = text
self.value = value
self.type = type
// `isTerminal` is true exactly when `text` is one of the strings listed here.
// NOTE(review): the two empty string literals look like characters lost in extraction — confirm against the repository.
if Set(["", ".", ""]).contains(text) {
self.isTerminal = true
} else {
self.isTerminal = false
}
}
/// The text of whichever payload this case wraps.
public var text: String {
switch self {
case .additional(let c):
c.text
case .replacement(let c):
c.text
}
}
public var text: String
public var value: PValue
public var type: PredictionType
/// Set once in `init`; true when `text` matched one of the strings checked there.
public var isTerminal: Bool
/// Applies this prediction to `candidate` and returns the updated candidate.
/// - Parameter candidate: The candidate to extend or rewrite; it is consumed and mutated in place.
/// - Returns: The same candidate after its data, text, value, `lastMid` and `correspondingCount` are updated.
public func join(to candidate: consuming Candidate) -> Candidate {
switch self {
case .additional(let c):
for data in c.data {
switch self.type {
case .additional(let data):
// Append each element's word to the text and the element itself to the data.
for data in data {
candidate.text.append(contentsOf: data.word)
candidate.data.append(data)
}
// The candidate's value is overwritten, not accumulated.
candidate.value = c.value
candidate.value = self.value
// `correspondingCount` is recomputed as the total ruby length over all data.
candidate.correspondingCount = candidate.data.reduce(into: 0) { $0 += $1.ruby.count }
// Use the mid of the last appended element accepted by `includeMMValueCalculation`; keep the old `lastMid` if none is.
candidate.lastMid = c.data.last(where: DicdataStore.includeMMValueCalculation)?.mid ?? candidate.lastMid
candidate.lastMid = data.last(where: DicdataStore.includeMMValueCalculation)?.mid ?? candidate.lastMid
return candidate
case .replacement(let c):
candidate.data.removeLast(c.targetData.count)
candidate.data.append(contentsOf: c.replacementData)
case .replacement(let targetData, let replacementData):
// Drop as many trailing elements as `targetData` holds, then append the replacement elements.
candidate.data.removeLast(targetData.count)
candidate.data.append(contentsOf: replacementData)
// The text is rebuilt from scratch by concatenating every element's word.
candidate.text = candidate.data.reduce(into: "") {$0 += $1.word}
candidate.value = c.value
candidate.value = self.value
// Here `lastMid` falls back to the BOS mid when no element qualifies.
candidate.lastMid = candidate.data.last(where: DicdataStore.includeMMValueCalculation)?.mid ?? MIDData.BOS.mid
candidate.correspondingCount = candidate.data.reduce(into: 0) { $0 += $1.ruby.count }
return candidate
}
}
/// How a `PredictionCandidate` is applied by `join(to:)`.
public enum PredictionType: Sendable, Hashable {
/// `data` is appended to the candidate.
case additional(data: [DicdataElement])
/// The last `targetData.count` elements are removed and `replacementData` is appended.
case replacement(targetData: [DicdataElement], replacementData: [DicdataElement])
}
}