perf: avoid useless calculation

This commit is contained in:
Miwa / Ensan
2024-08-07 23:38:59 +09:00
parent 9d31eeee66
commit 20fe93c21d
3 changed files with 36 additions and 35 deletions

View File

@ -110,7 +110,7 @@ extension Kana2Kanji {
// When inference occurs more than maximum times, then just return result at this point
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
}
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate], versionDependentConfig: versionDependentConfig)
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate], requestRichCandidates: requestRichCandidates, versionDependentConfig: versionDependentConfig)
inferenceLimit -= 1
let nextAction = self.review(
candidateIndex: index,
@ -120,30 +120,32 @@ extension Kana2Kanji {
)
switch nextAction {
case .return(let constraint, let alternativeConstraints, let satisfied):
// alternativeConstraintsinsertedCandidates
for alternativeConstraint in alternativeConstraints.reversed() where alternativeConstraint.probabilityRatio > 0.25 {
// constructed candidatesalternativeConstraint.prefixConstraint
let mostLiklyCandidate = constructedCandidates.filter {
$0.1.text.utf8.hasPrefix(alternativeConstraint.prefixConstraint)
}.max {
$0.1.value < $1.1.value
}
if let mostLiklyCandidate {
// 0
insertedCandidates.insert(mostLiklyCandidate, at: 1)
} else if requestRichCandidates && alternativeConstraint.probabilityRatio > 0.5 {
//
let draftResult = self.kana2lattice_all_with_prefix_constraint(inputData, N_best: 3, constraint: PrefixConstraint(alternativeConstraint.prefixConstraint))
let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
let best: (Int, Candidate)? = candidates.enumerated().reduce(into: nil) { best, pair in
if let (_, c) = best, pair.1.value > c.value {
best = pair
} else if best == nil {
best = pair
}
if requestRichCandidates {
// alternativeConstraintsinsertedCandidates
for alternativeConstraint in alternativeConstraints.reversed() where alternativeConstraint.probabilityRatio > 0.25 {
// constructed candidatesalternativeConstraint.prefixConstraint
let mostLiklyCandidate = constructedCandidates.filter {
$0.1.text.utf8.hasPrefix(alternativeConstraint.prefixConstraint)
}.max {
$0.1.value < $1.1.value
}
if let (index, candidate) = best {
insertedCandidates.insert((draftResult.result.prevs[index], candidate), at: 1)
if let mostLiklyCandidate {
// 0
insertedCandidates.insert(mostLiklyCandidate, at: 1)
} else if alternativeConstraint.probabilityRatio > 0.5 {
//
let draftResult = self.kana2lattice_all_with_prefix_constraint(inputData, N_best: 3, constraint: PrefixConstraint(alternativeConstraint.prefixConstraint))
let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
let best: (Int, Candidate)? = candidates.enumerated().reduce(into: nil) { best, pair in
if let (_, c) = best, pair.1.value > c.value {
best = pair
} else if best == nil {
best = pair
}
}
if let (index, candidate) = best {
insertedCandidates.insert((draftResult.result.prevs[index], candidate), at: 1)
}
}
}
}

View File

@ -30,12 +30,12 @@ import SwiftUtils
try? self.zenzContext?.reset_context()
}
func candidateEvaluate(convertTarget: String, candidates: [Candidate], versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> ZenzContext.CandidateEvaluationResult {
func candidateEvaluate(convertTarget: String, candidates: [Candidate], requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> ZenzContext.CandidateEvaluationResult {
guard let zenzContext else {
return .error
}
for candidate in candidates {
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate, versionDependentConfig: versionDependentConfig)
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate, requestRichCandidates: requestRichCandidates, versionDependentConfig: versionDependentConfig)
return result
}
return .error

View File

@ -45,6 +45,10 @@ struct FixedSizeHeap<Element: Comparable> {
var min: Element? {
self.heap.min
}
var isEmpty: Bool {
self.heap.isEmpty
}
}
enum ZenzError: LocalizedError {
@ -241,7 +245,7 @@ class ZenzContext {
return minHeap.unordered.sorted { $0.value > $1.value }.map { ($0.character, $0.value / exp_sum) }
}
func evaluate_candidate(input: String, candidate: Candidate, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
func evaluate_candidate(input: String, candidate: Candidate, requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
print("Evaluate", candidate)
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
// We assume \u{EE01}\(candidate) is always splitted into \u{EE01}_\(candidate) by zenz-v1 tokenizer
@ -294,7 +298,7 @@ class ZenzContext {
var probabilityRatioToMaxProb: Float
}
var altTokens = FixedSizeHeap<AlternativeHighProbToken>(size: 5)
var altTokens = FixedSizeHeap<AlternativeHighProbToken>(size: requestRichCandidates ? 5 : 0)
for (i, token_id) in tokens.indexed().dropFirst(prompt_tokens.count) {
//
// softmaxmaxlogits
@ -310,7 +314,7 @@ class ZenzContext {
var exp_sum: Float = 0
let startIndex = (i - 1 - startOffset) * Int(n_vocab)
let endIndex = (i - startOffset) * Int(n_vocab)
var tokenHeap = FixedSizeHeap<TokenAndExpLogit>(size: 3)
var tokenHeap = FixedSizeHeap<TokenAndExpLogit>(size: requestRichCandidates ? 3 : 0)
for index in startIndex ..< endIndex {
let v = exp(logits[index])
exp_sum += v
@ -344,7 +348,7 @@ class ZenzContext {
return .fixRequired(prefixConstraint: cchars.dropFirst(prompt.utf8.count).map(UInt8.init))
}
}
} else {
} else if !tokenHeap.isEmpty {
tokenHeap.removeMax()
let prefix = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
@ -362,11 +366,6 @@ class ZenzContext {
}
score += log(maxItem.expLogit) - log(exp_sum)
}
for item in altTokens.unordered.sorted(by: >) {
print("Item")
print(" Constraint", String(cString: item.constraint + [0]))
print(" Probability Gain", item.probabilityRatioToMaxProb)
}
return .pass(score: score, alternativeConstraints: altTokens.unordered.sorted(by: >).map {.init(probabilityRatio: $0.probabilityRatioToMaxProb, prefixConstraint: $0.constraint)})
}