Files
AzooKeyKanaKanjiConverter/Sources/KanaKanjiConverterModule/ConversionAlgorithms/Zenzai/zenzai.swift

253 lines
13 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import Algorithms
import Foundation
import SwiftUtils
import EfficientNGram
extension Kana2Kanji {
    /// State carried from one `all_zenzai` call to the next.
    ///
    /// Records the input that was converted, the prefix constraint in effect when that run
    /// finished, and the candidate the run settled on (one that passed review, or the last
    /// draft when the inference budget ran out). `getNewConstraint(for:)` turns this into the
    /// initial constraint for the next input.
    struct ZenzaiCache: Sendable {
        init(_ inputData: ComposingText, constraint: PrefixConstraint, satisfyingCandidate: Candidate?) {
            self.inputData = inputData
            self.prefixConstraint = constraint
            self.satisfyingCandidate = satisfyingCandidate
        }

        private let prefixConstraint: PrefixConstraint
        private let satisfyingCandidate: Candidate?
        private let inputData: ComposingText

        /// Derives the initial prefix constraint for `newInputData`.
        ///
        /// - If a candidate was settled on, its leading words are reused for as long as their
        ///   readings (`ruby`) match the beginning of the new input's katakana form.
        /// - Otherwise, if the new input extends the previous one, the previous constraint is reused.
        /// - Otherwise no constraint is returned.
        func getNewConstraint(for newInputData: ComposingText) -> PrefixConstraint {
            if let satisfyingCandidate {
                var remaining = newInputData.convertTarget.toKatakana()[...]
                var bytes: [UInt8] = []
                for entry in satisfyingCandidate.data {
                    // Stop at the first word whose reading no longer lines up with the input.
                    // Continuing would let a later word match at the wrong position and yield
                    // a prefix that does not correspond to the input.
                    guard remaining.hasPrefix(entry.ruby) else {
                        break
                    }
                    bytes += entry.word.utf8
                    remaining = remaining.dropFirst(entry.ruby.count)
                }
                return PrefixConstraint(bytes)
            } else if newInputData.convertTarget.hasPrefix(inputData.convertTarget) {
                // The new input starts with the previous one, so the old constraint still applies.
                return self.prefixConstraint
            } else {
                return PrefixConstraint([])
            }
        }
    }

    /// A constraint on the beginning of the converted text, stored as UTF-8 bytes.
    struct PrefixConstraint: Sendable, Equatable, Hashable, CustomStringConvertible {
        init(_ constraint: [UInt8], hasEOS: Bool = false) {
            self.constraint = constraint
            self.hasEOS = hasEOS
        }

        /// UTF-8 bytes the converted text is required to start with.
        var constraint: [UInt8]
        /// Set for whole-text constraints (see `review`). Presumably it also requires the text
        /// to end right after `constraint`; enforcement is up to the constrained lattice search.
        var hasEOS: Bool

        var description: String {
            "PrefixConstraint(constraint: \"\(String(decoding: self.constraint, as: UTF8.self))\", hasEOS: \(self.hasEOS))"
        }

        /// `true` when the constraint imposes nothing: no prefix bytes and no EOS requirement.
        var isEmpty: Bool {
            self.constraint.isEmpty && !self.hasEOS
        }
    }

    /// Converts `inputData` by alternating lattice drafts with Zenz review ("zenzai").
    ///
    /// Each round does the following:
    /// 1. Drafts N-best candidates with the lattice, unconstrained or under the current prefix
    ///    constraint.
    /// 2. Picks the highest-valued draft and asks `zenz` to evaluate it.
    /// 3. Acts on the review: it accepts the draft, updates the constraint (triggering a new
    ///    draft), or selects another draft of the same round that already satisfies the update.
    ///
    /// Every reviewed draft is recorded, most recent first, as a predecessor of the returned EOS node.
    ///
    /// - Parameters:
    ///   - inputData: The text being converted.
    ///   - zenz: The model used to review drafts.
    ///   - zenzaiCache: State returned by the previous call; seeds the initial constraint.
    ///   - inferenceLimit: Maximum number of Zenz evaluations. Once exhausted, the current draft is
    ///     returned without review. Non-positive values allow no evaluation at all.
    ///   - requestRichCandidates: Forwarded to Zenz. When `true`, sufficiently likely alternative
    ///     constraints from the review are turned into extra candidates.
    ///   - personalizationMode: Optional personalization models forwarded to Zenz.
    ///   - versionDependentConfig: Model-version specific settings forwarded to Zenz.
    /// - Returns: The EOS node (whose `prevs` are the collected candidates), the lattice of the first
    ///   draft that produced a non-empty one, and the cache to pass to the next call.
    @MainActor func all_zenzai(
        _ inputData: ComposingText,
        zenz: Zenz,
        zenzaiCache: ZenzaiCache?,
        inferenceLimit: Int,
        requestRichCandidates: Bool,
        personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: EfficientNGram, personal: EfficientNGram)?,
        versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
    ) -> (result: LatticeNode, lattice: Lattice, cache: ZenzaiCache) {
        var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint([])
        debug("initial constraint", constraint)
        let eosNode = LatticeNode.EOSNode
        var lattice: Lattice = Lattice()
        // Every draft produced so far; searched when materialising alternative constraints.
        var constructedCandidates: [(RegisteredNode, Candidate)] = []
        // Candidates that become the EOS node's predecessors, most recently reviewed first.
        var insertedCandidates: [(RegisteredNode, Candidate)] = []
        defer {
            // Runs on every return path, so the EOS node always reflects what was collected.
            eosNode.prevs = insertedCandidates.map(\.0)
        }
        var remainingInferences = inferenceLimit
        while true {
            let start = Date()
            let draftResult = if constraint.isEmpty {
                // Unconstrained draft, 2-best.
                self.kana2lattice_all(inputData, N_best: 2, needTypoCorrection: false)
            } else {
                // Constrained draft, 3-best.
                self.kana2lattice_all_with_prefix_constraint(inputData, N_best: 3, constraint: constraint)
            }
            if lattice.isEmpty {
                // Return the lattice of the first draft that produced a non-empty one.
                lattice = draftResult.lattice
            }
            let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
            constructedCandidates.append(contentsOf: zip(draftResult.result.prevs, candidates))
            guard let best = self.zenzaiBestDraft(among: candidates) else {
                debug("best was not found!")
                // The draft produced no candidates: stop and clear the cached constraint.
                return (eosNode, lattice, ZenzaiCache(inputData, constraint: PrefixConstraint([]), satisfyingCandidate: nil))
            }
            var index = best.index
            var candidate = best.candidate
            debug("Constrained draft modeling", -start.timeIntervalSinceNow)
            reviewLoop: while true {
                // Record the draft under review ahead of everything collected earlier.
                insertedCandidates.insert((draftResult.result.prevs[index], candidate), at: 0)
                if remainingInferences <= 0 {
                    debug("inference limit! \(candidate.text) is used for excuse")
                    // Out of inference budget: return the current draft without reviewing it.
                    return (eosNode, lattice, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
                }
                let reviewResult = zenz.candidateEvaluate(
                    convertTarget: inputData.convertTarget,
                    candidates: [candidate],
                    requestRichCandidates: requestRichCandidates,
                    personalizationMode: personalizationMode,
                    versionDependentConfig: versionDependentConfig
                )
                remainingInferences -= 1
                let nextAction = self.review(
                    candidateIndex: index,
                    candidates: candidates,
                    reviewResult: reviewResult,
                    constraint: &constraint
                )
                switch nextAction {
                case .return(let finalConstraint, let alternativeConstraints, let satisfied):
                    if requestRichCandidates {
                        // Place likely alternatives right after the reviewed candidate. Iterating in
                        // reverse while inserting at index 1 keeps them in the reported order.
                        for alternative in alternativeConstraints.reversed() where alternative.probabilityRatio > 0.25 {
                            // Prefer the best already-drafted candidate matching the alternative prefix.
                            let mostLikelyCandidate = constructedCandidates.filter {
                                $0.1.text.utf8.hasPrefix(alternative.prefixConstraint)
                            }.max {
                                $0.1.value < $1.1.value
                            }
                            if let mostLikelyCandidate {
                                insertedCandidates.insert(mostLikelyCandidate, at: 1)
                            } else if alternative.probabilityRatio > 0.5 {
                                // Nothing drafted matches, but the alternative is likely enough to
                                // justify one extra constrained draft.
                                let alternativeDraft = self.kana2lattice_all_with_prefix_constraint(inputData, N_best: 3, constraint: PrefixConstraint(alternative.prefixConstraint))
                                let alternativeCandidates = alternativeDraft.result.getCandidateData().map(self.processClauseCandidate)
                                if let alternativeBest = self.zenzaiBestDraft(among: alternativeCandidates) {
                                    insertedCandidates.insert((alternativeDraft.result.prevs[alternativeBest.index], alternativeBest.candidate), at: 1)
                                }
                            }
                        }
                    }
                    // Only a candidate that passed review is remembered as satisfying.
                    let cache = ZenzaiCache(inputData, constraint: finalConstraint, satisfyingCandidate: satisfied ? candidate : nil)
                    return (eosNode, lattice, cache)
                case .continue:
                    // The constraint was updated and no draft of this round was chosen: draft again.
                    break reviewLoop
                case .retry(let candidateIndex):
                    // Another draft of this round already satisfies the updated constraint.
                    index = candidateIndex
                    candidate = candidates[candidateIndex]
                }
            }
        }
    }

    /// Returns the highest-valued element of `candidates` with its index, or `nil` if empty.
    ///
    /// `max(by:)` only replaces its running result on a strictly greater value, so the earliest
    /// candidate wins ties.
    private func zenzaiBestDraft(among candidates: [Candidate]) -> (index: Int, candidate: Candidate)? {
        guard let best = candidates.enumerated().max(by: { $0.element.value < $1.element.value }) else {
            return nil
        }
        return (best.offset, best.element)
    }

    /// What `all_zenzai` should do after a draft has been reviewed.
    private enum NextAction {
        /// Stop. `constraint` is cached for the next call, `alternativeConstraints` come from the
        /// review, and `satisfied` tells whether the reviewed candidate passed.
        case `return`(constraint: PrefixConstraint, alternativeConstraints: [ZenzContext.CandidateEvaluationResult.AlternativeConstraint], satisfied: Bool)
        /// Draft again under the (already updated) constraint.
        case `continue`
        /// Review another candidate of the current draft, identified by its index.
        case `retry`(candidateIndex: Int)
    }

    /// Maps a Zenz evaluation of `candidates[candidateIndex]` to the next step of `all_zenzai`.
    ///
    /// - Parameters:
    ///   - candidateIndex: Index of the reviewed candidate in `candidates`.
    ///   - candidates: All candidates of the current draft; searched for a retry.
    ///   - reviewResult: The evaluation returned by Zenz.
    ///   - constraint: The current constraint; replaced when the review requests a correction.
    /// - Returns: Whether to stop, draft again, or retry another existing candidate.
    private func review(
        candidateIndex: Int,
        candidates: [Candidate],
        reviewResult: consuming ZenzContext.CandidateEvaluationResult,
        constraint: inout PrefixConstraint
    ) -> NextAction {
        switch reviewResult {
        case .error:
            // Evaluation failed: stop with the current constraint; nothing was accepted.
            debug("error")
            return .return(constraint: constraint, alternativeConstraints: [], satisfied: false)
        case .pass(let score, let alternativeConstraints):
            // The candidate was accepted.
            debug("passed:", score)
            return .return(constraint: constraint, alternativeConstraints: alternativeConstraints, satisfied: true)
        case .fixRequired(let prefixConstraint):
            // The review asks for the prefix we already drafted under, so another round would not
            // make progress: stop and clear the constraint.
            if constraint.constraint == prefixConstraint {
                debug("same constraint:", prefixConstraint)
                return .return(constraint: PrefixConstraint([]), alternativeConstraints: [], satisfied: false)
            }
            let isIncrementalUpdate = prefixConstraint.hasPrefix(constraint.constraint)
            constraint = PrefixConstraint(prefixConstraint)
            debug("update constraint:", constraint)
            // Only when the new prefix extends the old one are this round's drafts searched for
            // one that already satisfies it; otherwise a fresh constrained draft is made.
            if isIncrementalUpdate {
                for (i, candidate) in candidates.enumerated() where i != candidateIndex {
                    if candidate.text.utf8.hasPrefix(prefixConstraint) && self.heuristicRetryValidation(candidate.text) {
                        debug("found \(candidate.text) as another retry")
                        return .retry(candidateIndex: i)
                    }
                }
            }
            return .continue
        case .wholeResult(let wholeConstraint):
            // The review returned a whole target text: constrain it with the EOS flag set.
            let newConstraint = PrefixConstraint(Array(wholeConstraint.utf8), hasEOS: true)
            // Same constraint as before: another round would not make progress.
            if constraint == newConstraint {
                debug("same constraint:", constraint)
                return .return(constraint: PrefixConstraint([]), alternativeConstraints: [], satisfied: false)
            }
            debug("update whole constraint:", wholeConstraint)
            let isIncrementalUpdate = wholeConstraint.utf8.hasPrefix(constraint.constraint)
            constraint = newConstraint
            // As above, reuse an existing draft only when the update extends the old constraint.
            if isIncrementalUpdate {
                for (i, candidate) in candidates.enumerated() where i != candidateIndex {
                    if candidate.text == wholeConstraint && self.heuristicRetryValidation(candidate.text) {
                        debug("found \(candidate.text) as another retry")
                        return .retry(candidateIndex: i)
                    }
                }
            }
            return .continue
        }
    }

    /// Heuristic filter applied to candidates chosen for a retry without a new draft.
    ///
    /// Returns `false` when `text` contains a combining voiced sound mark (U+3099) or a combining
    /// semi-voiced sound mark (U+309A), and `true` otherwise.
    private func heuristicRetryValidation(_ text: String) -> Bool {
        !text.unicodeScalars.contains { $0 == "\u{3099}" || $0 == "\u{309A}" }
    }
}