mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-08-22 15:05:26 +00:00
perf: avoid useless calculation
This commit is contained in:
@ -110,7 +110,7 @@ extension Kana2Kanji {
|
||||
// When inference has already run the maximum number of times, just return the result at this point
|
||||
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
|
||||
}
|
||||
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate], versionDependentConfig: versionDependentConfig)
|
||||
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate], requestRichCandidates: requestRichCandidates, versionDependentConfig: versionDependentConfig)
|
||||
inferenceLimit -= 1
|
||||
let nextAction = self.review(
|
||||
candidateIndex: index,
|
||||
@ -120,30 +120,32 @@ extension Kana2Kanji {
|
||||
)
|
||||
switch nextAction {
|
||||
case .return(let constraint, let alternativeConstraints, let satisfied):
|
||||
// alternativeConstraintsに従い、insertedCandidatesにデータを追加する
|
||||
for alternativeConstraint in alternativeConstraints.reversed() where alternativeConstraint.probabilityRatio > 0.25 {
|
||||
// constructed candidatesのうちalternativeConstraint.prefixConstraintを満たすものを列挙する
|
||||
let mostLiklyCandidate = constructedCandidates.filter {
|
||||
$0.1.text.utf8.hasPrefix(alternativeConstraint.prefixConstraint)
|
||||
}.max {
|
||||
$0.1.value < $1.1.value
|
||||
}
|
||||
if let mostLiklyCandidate {
|
||||
// 0番目は最良候補
|
||||
insertedCandidates.insert(mostLiklyCandidate, at: 1)
|
||||
} else if requestRichCandidates && alternativeConstraint.probabilityRatio > 0.5 {
|
||||
// 十分に高い確率の場合、変換器を実際に呼び出して候補を作ってもらう
|
||||
let draftResult = self.kana2lattice_all_with_prefix_constraint(inputData, N_best: 3, constraint: PrefixConstraint(alternativeConstraint.prefixConstraint))
|
||||
let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
|
||||
let best: (Int, Candidate)? = candidates.enumerated().reduce(into: nil) { best, pair in
|
||||
if let (_, c) = best, pair.1.value > c.value {
|
||||
best = pair
|
||||
} else if best == nil {
|
||||
best = pair
|
||||
}
|
||||
if requestRichCandidates {
|
||||
// alternativeConstraintsに従い、insertedCandidatesにデータを追加する
|
||||
for alternativeConstraint in alternativeConstraints.reversed() where alternativeConstraint.probabilityRatio > 0.25 {
|
||||
// constructed candidatesのうちalternativeConstraint.prefixConstraintを満たすものを列挙する
|
||||
let mostLiklyCandidate = constructedCandidates.filter {
|
||||
$0.1.text.utf8.hasPrefix(alternativeConstraint.prefixConstraint)
|
||||
}.max {
|
||||
$0.1.value < $1.1.value
|
||||
}
|
||||
if let (index, candidate) = best {
|
||||
insertedCandidates.insert((draftResult.result.prevs[index], candidate), at: 1)
|
||||
if let mostLiklyCandidate {
|
||||
// 0番目は最良候補
|
||||
insertedCandidates.insert(mostLiklyCandidate, at: 1)
|
||||
} else if alternativeConstraint.probabilityRatio > 0.5 {
|
||||
// 十分に高い確率の場合、変換器を実際に呼び出して候補を作ってもらう
|
||||
let draftResult = self.kana2lattice_all_with_prefix_constraint(inputData, N_best: 3, constraint: PrefixConstraint(alternativeConstraint.prefixConstraint))
|
||||
let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
|
||||
let best: (Int, Candidate)? = candidates.enumerated().reduce(into: nil) { best, pair in
|
||||
if let (_, c) = best, pair.1.value > c.value {
|
||||
best = pair
|
||||
} else if best == nil {
|
||||
best = pair
|
||||
}
|
||||
}
|
||||
if let (index, candidate) = best {
|
||||
insertedCandidates.insert((draftResult.result.prevs[index], candidate), at: 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -30,12 +30,12 @@ import SwiftUtils
|
||||
try? self.zenzContext?.reset_context()
|
||||
}
|
||||
|
||||
func candidateEvaluate(convertTarget: String, candidates: [Candidate], versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> ZenzContext.CandidateEvaluationResult {
|
||||
func candidateEvaluate(convertTarget: String, candidates: [Candidate], requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> ZenzContext.CandidateEvaluationResult {
|
||||
guard let zenzContext else {
|
||||
return .error
|
||||
}
|
||||
for candidate in candidates {
|
||||
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate, versionDependentConfig: versionDependentConfig)
|
||||
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate, requestRichCandidates: requestRichCandidates, versionDependentConfig: versionDependentConfig)
|
||||
return result
|
||||
}
|
||||
return .error
|
||||
|
@ -45,6 +45,10 @@ struct FixedSizeHeap<Element: Comparable> {
|
||||
var min: Element? {
|
||||
self.heap.min
|
||||
}
|
||||
|
||||
var isEmpty: Bool {
|
||||
self.heap.isEmpty
|
||||
}
|
||||
}
|
||||
|
||||
enum ZenzError: LocalizedError {
|
||||
@ -241,7 +245,7 @@ class ZenzContext {
|
||||
return minHeap.unordered.sorted { $0.value > $1.value }.map { ($0.character, $0.value / exp_sum) }
|
||||
}
|
||||
|
||||
func evaluate_candidate(input: String, candidate: Candidate, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
|
||||
func evaluate_candidate(input: String, candidate: Candidate, requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
|
||||
print("Evaluate", candidate)
|
||||
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
|
||||
// We assume \u{EE01}\(candidate) is always split into \u{EE01}_\(candidate) by the zenz-v1 tokenizer
|
||||
@ -294,7 +298,7 @@ class ZenzContext {
|
||||
var probabilityRatioToMaxProb: Float
|
||||
}
|
||||
|
||||
var altTokens = FixedSizeHeap<AlternativeHighProbToken>(size: 5)
|
||||
var altTokens = FixedSizeHeap<AlternativeHighProbToken>(size: requestRichCandidates ? 5 : 0)
|
||||
for (i, token_id) in tokens.indexed().dropFirst(prompt_tokens.count) {
|
||||
// それぞれのトークンが、一つ前の予測において最も確率の高いトークンであるかをチェックする
|
||||
// softmaxはmaxなので、単にlogitsの中で最も大きいものを選べば良い
|
||||
@ -310,7 +314,7 @@ class ZenzContext {
|
||||
var exp_sum: Float = 0
|
||||
let startIndex = (i - 1 - startOffset) * Int(n_vocab)
|
||||
let endIndex = (i - startOffset) * Int(n_vocab)
|
||||
var tokenHeap = FixedSizeHeap<TokenAndExpLogit>(size: 3)
|
||||
var tokenHeap = FixedSizeHeap<TokenAndExpLogit>(size: requestRichCandidates ? 3 : 0)
|
||||
for index in startIndex ..< endIndex {
|
||||
let v = exp(logits[index])
|
||||
exp_sum += v
|
||||
@ -344,7 +348,7 @@ class ZenzContext {
|
||||
return .fixRequired(prefixConstraint: cchars.dropFirst(prompt.utf8.count).map(UInt8.init))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
} else if !tokenHeap.isEmpty {
|
||||
tokenHeap.removeMax()
|
||||
let prefix = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
@ -362,11 +366,6 @@ class ZenzContext {
|
||||
}
|
||||
score += log(maxItem.expLogit) - log(exp_sum)
|
||||
}
|
||||
for item in altTokens.unordered.sorted(by: >) {
|
||||
print("Item")
|
||||
print(" Constraint", String(cString: item.constraint + [0]))
|
||||
print(" Probability Gain", item.probabilityRatioToMaxProb)
|
||||
}
|
||||
return .pass(score: score, alternativeConstraints: altTokens.unordered.sorted(by: >).map {.init(probabilityRatio: $0.probabilityRatioToMaxProb, prefixConstraint: $0.constraint)})
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user