mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-08-22 15:05:26 +00:00
feat: enable rich n-best calculation
This commit is contained in:
@ -62,6 +62,11 @@ extension Kana2Kanji {
|
||||
print("initial constraint", constraint)
|
||||
let eosNode = LatticeNode.EOSNode
|
||||
var nodes: Kana2Kanji.Nodes = []
|
||||
var constructedCandidates: [(RegisteredNode, Candidate)] = []
|
||||
var insertedCandidates: [(RegisteredNode, Candidate)] = []
|
||||
defer {
|
||||
eosNode.prevs = insertedCandidates.map(\.0)
|
||||
}
|
||||
var inferenceLimit = inferenceLimit
|
||||
while true {
|
||||
let start = Date()
|
||||
@ -78,6 +83,7 @@ extension Kana2Kanji {
|
||||
nodes = draftResult.nodes
|
||||
}
|
||||
let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
|
||||
constructedCandidates.append(contentsOf: zip(draftResult.result.prevs, candidates))
|
||||
var best: (Int, Candidate)? = nil
|
||||
for (i, cand) in candidates.enumerated() {
|
||||
if let (_, c) = best, cand.value > c.value {
|
||||
@ -92,10 +98,12 @@ extension Kana2Kanji {
|
||||
// 制約が満たせない場合は無視する
|
||||
return (eosNode, nodes, ZenzaiCache(inputData, constraint: PrefixConstraint([]), satisfyingCandidate: nil))
|
||||
}
|
||||
|
||||
print("Constrained draft modeling", -start.timeIntervalSinceNow)
|
||||
reviewLoop: while true {
|
||||
// resultsを更新
|
||||
eosNode.prevs.insert(draftResult.result.prevs[index], at: 0)
|
||||
// ここでN-Bestも並び変えていることになる
|
||||
insertedCandidates.insert((draftResult.result.prevs[index], candidate), at: 0)
|
||||
if inferenceLimit == 0 {
|
||||
print("inference limit! \(candidate.text) is used for excuse")
|
||||
// When inference occurs more than maximum times, then just return result at this point
|
||||
@ -110,7 +118,34 @@ extension Kana2Kanji {
|
||||
constraint: &constraint
|
||||
)
|
||||
switch nextAction {
|
||||
case .return(let constraint, let satisfied):
|
||||
case .return(let constraint, let alternativeConstraints, let satisfied):
|
||||
// alternativeConstraintsに従い、insertedCandidatesにデータを追加する
|
||||
for alternativeConstraint in alternativeConstraints.reversed() where alternativeConstraint.probabilityRatio > 0.25 {
|
||||
// constructed candidatesのうちalternativeConstraint.prefixConstraintを満たすものを列挙する
|
||||
let mostLiklyCandidate = constructedCandidates.filter {
|
||||
$0.1.text.utf8.hasPrefix(alternativeConstraint.prefixConstraint)
|
||||
}.max {
|
||||
$0.1.value < $1.1.value
|
||||
}
|
||||
if let mostLiklyCandidate {
|
||||
// 0番目は最良候補
|
||||
insertedCandidates.insert(mostLiklyCandidate, at: 1)
|
||||
} else if alternativeConstraint.probabilityRatio > 0.5 {
|
||||
// 十分に高い確率の場合、変換器を実際に呼び出して候補を作ってもらう
|
||||
let draftResult = self.kana2lattice_all_with_prefix_constraint(inputData, N_best: 3, constraint: PrefixConstraint(alternativeConstraint.prefixConstraint))
|
||||
let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
|
||||
let best: (Int, Candidate)? = candidates.enumerated().reduce(into: nil) { best, pair in
|
||||
if let (_, c) = best, pair.1.value > c.value {
|
||||
best = pair
|
||||
} else if best == nil {
|
||||
best = pair
|
||||
}
|
||||
}
|
||||
if let (index, candidate) = best {
|
||||
insertedCandidates.insert((draftResult.result.prevs[index], candidate), at: 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
if satisfied {
|
||||
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
|
||||
} else {
|
||||
@ -127,7 +162,7 @@ extension Kana2Kanji {
|
||||
}
|
||||
|
||||
/// Outcome of reviewing one candidate; consumed by the `switch nextAction` in the conversion loop.
///
/// NOTE(review): two `return` declarations appear below. This listing looks like a rendered
/// diff, so presumably the first is the pre-change form and the second is the one that
/// remains — confirm against the real file.
private enum NextAction {
|
||||
/// Stop reviewing and hand back `constraint`; `satisfied` tells whether the candidate passed.
case `return`(constraint: PrefixConstraint, satisfied: Bool)
|
||||
/// Same as above, additionally carrying the alternative prefix constraints reported by the
/// evaluator so the caller can add further n-best candidates.
case `return`(constraint: PrefixConstraint, alternativeConstraints: [ZenzContext.CandidateEvaluationResult.AlternativeConstraint], satisfied: Bool)
|
||||
/// Handling is outside the visible part of the file — presumably proceeds to another
/// draft round; TODO confirm.
case `continue`
|
||||
/// Handling is outside the visible part of the file — presumably re-reviews the candidate
/// at `candidateIndex`; TODO confirm.
case `retry`(candidateIndex: Int)
|
||||
}
|
||||
@ -142,16 +177,16 @@ extension Kana2Kanji {
|
||||
case .error:
|
||||
// 何らかのエラーが発生
|
||||
print("error")
|
||||
return .return(constraint: constraint, satisfied: false)
|
||||
case .pass(let score):
|
||||
return .return(constraint: constraint, alternativeConstraints: [], satisfied: false)
|
||||
case .pass(let score, let alternativeConstraints):
|
||||
// 合格
|
||||
print("passed:", score)
|
||||
return .return(constraint: constraint, satisfied: true)
|
||||
return .return(constraint: constraint, alternativeConstraints: alternativeConstraints, satisfied: true)
|
||||
case .fixRequired(let prefixConstraint):
|
||||
// 同じ制約が2回連続で出てきたら諦める
|
||||
if constraint.constraint == prefixConstraint {
|
||||
print("same constraint:", prefixConstraint)
|
||||
return .return(constraint: PrefixConstraint([]), satisfied: false)
|
||||
return .return(constraint: PrefixConstraint([]), alternativeConstraints: [], satisfied: false)
|
||||
}
|
||||
// 制約が得られたので、更新する
|
||||
constraint = PrefixConstraint(prefixConstraint)
|
||||
@ -169,7 +204,7 @@ extension Kana2Kanji {
|
||||
// 同じ制約が2回連続で出てきたら諦める
|
||||
if constraint == newConstraint {
|
||||
print("same constraint:", constraint)
|
||||
return .return(constraint: PrefixConstraint([]), satisfied: false)
|
||||
return .return(constraint: PrefixConstraint([]), alternativeConstraints: [], satisfied: false)
|
||||
}
|
||||
// 制約が得られたので、更新する
|
||||
print("update whole constraint:", wholeConstraint)
|
||||
|
@ -4,6 +4,50 @@ import HeapModule
|
||||
import Algorithms
|
||||
import Foundation
|
||||
|
||||
/// A bounded wrapper around `Heap` that keeps at most a fixed number of elements.
///
/// While there is room every element is accepted. Once full, a new element is
/// accepted only if it is strictly greater than the current minimum, which it
/// then replaces. The result is the largest elements seen so far.
struct FixedSizeHeap<Element: Comparable> {
    /// Maximum number of elements retained.
    private let capacity: Int
    /// Backing storage.
    private var storage: Heap<Element>

    /// Creates an empty heap that retains at most `size` elements.
    init(size: Int) {
        capacity = size
        storage = Heap()
    }

    /// The current largest element, if any.
    var max: Element? {
        storage.max
    }

    /// The current smallest element, if any.
    var min: Element? {
        storage.min
    }

    /// The retained elements, in no particular order.
    var unordered: [Element] {
        storage.unordered
    }

    /// Removes the largest element by forwarding to `Heap.removeMax()`.
    /// Callers are expected to ensure the heap is non-empty.
    mutating func removeMax() {
        _ = storage.removeMax()
    }

    /// Removes the smallest element by forwarding to `Heap.removeMin()`.
    /// Callers are expected to ensure the heap is non-empty.
    mutating func removeMin() {
        _ = storage.removeMin()
    }

    /// Offers `element` to the heap.
    ///
    /// - Returns: `true` if the element was stored (either there was room, or it
    ///   replaced a smaller minimum); `false` if it was rejected.
    @discardableResult
    mutating func insertIfPossible(_ element: Element) -> Bool {
        // Still below capacity: accept unconditionally.
        guard storage.count >= capacity else {
            storage.insert(element)
            return true
        }
        // Full: only a strictly larger element may displace the minimum.
        guard let smallest = storage.min, element > smallest else {
            return false
        }
        storage.replaceMin(with: element)
        return true
    }
}
|
||||
|
||||
|
||||
enum ZenzError: LocalizedError {
|
||||
case couldNotLoadModel(path: String)
|
||||
case couldNotLoadContext
|
||||
@ -124,9 +168,14 @@ class ZenzContext {
|
||||
|
||||
/// Result of evaluating a conversion candidate against the model's token predictions.
///
/// NOTE(review): two `pass` declarations appear below. This listing looks like a rendered
/// diff, so presumably the single-payload form is the pre-change line — confirm against
/// the real file.
enum CandidateEvaluationResult: Sendable, Equatable, Hashable {
|
||||
/// Evaluation could not be completed.
case error
|
||||
/// The candidate was accepted; `score` is the accumulated log-probability.
case pass(score: Float)
|
||||
/// The candidate was accepted; `score` is the accumulated log-probability and
/// `alternativeConstraints` lists runner-up continuations, highest ratio first.
case pass(score: Float, alternativeConstraints: [AlternativeConstraint])
|
||||
/// The candidate diverged from the model's preferred token; `prefixConstraint` holds the
/// bytes the output should start with instead.
case fixRequired(prefixConstraint: [UInt8])
|
||||
/// Carries a complete result string. Where it is produced is outside the visible part of
/// the file.
case wholeResult(String)
|
||||
|
||||
/// A runner-up continuation that scored close to the best token at some position.
struct AlternativeConstraint: Sendable, Equatable, Hashable {
|
||||
/// The alternative token's exp-logit divided by the best token's exp-logit at that position.
var probabilityRatio: Float
|
||||
/// Bytes of the output up to that position followed by the alternative token's piece.
var prefixConstraint: [UInt8]
|
||||
}
|
||||
}
|
||||
|
||||
func getLearningPriority(data: DicdataElement) -> Float {
|
||||
@ -166,7 +215,7 @@ class ZenzContext {
|
||||
let endIndex = (prompt_tokens.count - startOffset) * Int(n_vocab)
|
||||
|
||||
// Min-Heapを使用してn-bestを計算
|
||||
var minHeap: Heap<NextCharacterCandidate> = .init()
|
||||
var minHeap: FixedSizeHeap<NextCharacterCandidate> = .init(size: count)
|
||||
let token_to_penalty_weight: [llama_token: Float] = prompt_tokens.indexed().reduce(into: [:]) { dict, item in
|
||||
let (index, token) = item
|
||||
// 現在位置から遠いほど減衰させる
|
||||
@ -186,11 +235,7 @@ class ZenzContext {
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
if minHeap.count < count {
|
||||
minHeap.insert(NextCharacterCandidate(character: character, value: v))
|
||||
} else if let min = minHeap.min, v > min.value {
|
||||
minHeap.replaceMin(with: NextCharacterCandidate(character: character, value: v))
|
||||
}
|
||||
minHeap.insertIfPossible(NextCharacterCandidate(character: character, value: v))
|
||||
}
|
||||
|
||||
// Heapからソートして結果を取り出す
|
||||
@ -238,26 +283,47 @@ class ZenzContext {
|
||||
}
|
||||
|
||||
var score: Float = 0
|
||||
|
||||
/// A non-best token that still scored highly at some position, together with the prefix
/// constraint that would force it. Ordered by `probabilityRatioToMaxProb` only.
struct AlternativeHighProbToken: Comparable {
|
||||
/// Orders by probability ratio alone; `token` and `constraint` do not take part.
static func < (lhs: AlternativeHighProbToken, rhs: AlternativeHighProbToken) -> Bool {
|
||||
lhs.probabilityRatioToMaxProb < rhs.probabilityRatioToMaxProb
|
||||
}
|
||||
|
||||
/// The alternative token.
var token: llama_token
|
||||
/// Output bytes that end with this token's piece.
var constraint: [UInt8]
|
||||
// Ratio relative to the maximum probability
|
||||
var probabilityRatioToMaxProb: Float
|
||||
}
|
||||
|
||||
var altTokens = FixedSizeHeap<AlternativeHighProbToken>(size: 5)
|
||||
for (i, token_id) in tokens.indexed().dropFirst(prompt_tokens.count) {
|
||||
// それぞれのトークンが、一つ前の予測において最も確率の高いトークンであるかをチェックする
|
||||
// softmaxはmaxなので、単にlogitsの中で最も大きいものを選べば良い
|
||||
// 一方実用的にはlog_probも得ておきたい。このため、ここでは明示的にsoftmaxも計算している
|
||||
/// Pairs a token with the exponential of its logit so that the top tokens at one position
/// can be kept in a heap. Ordered by `expLogit` only.
struct TokenAndExpLogit: Comparable {
|
||||
/// Orders by `expLogit` alone; `token` does not take part.
static func < (lhs: TokenAndExpLogit, rhs: TokenAndExpLogit) -> Bool {
|
||||
lhs.expLogit < rhs.expLogit
|
||||
}
|
||||
|
||||
/// Token id, i.e. the offset within the current position's slice of the logits.
var token: llama_token
|
||||
/// `exp` of the token's logit (unnormalised probability).
var expLogit: Float
|
||||
}
|
||||
var exp_sum: Float = 0
|
||||
var max_token: llama_token = 0
|
||||
var max_exp: Float = .infinity * -1
|
||||
let startIndex = (i - 1 - startOffset) * Int(n_vocab)
|
||||
let endIndex = (i - startOffset) * Int(n_vocab)
|
||||
var tokenHeap = FixedSizeHeap<TokenAndExpLogit>(size: 3)
|
||||
for index in startIndex ..< endIndex {
|
||||
let v = exp(logits[index])
|
||||
exp_sum += v
|
||||
if max_exp < v {
|
||||
max_exp = v
|
||||
max_token = llama_token(index - startIndex)
|
||||
}
|
||||
tokenHeap.insertIfPossible(TokenAndExpLogit(token: llama_token(index - startIndex), expLogit: v))
|
||||
}
|
||||
guard let maxItem = tokenHeap.max else {
|
||||
print("Max Item could not be found for unknown reason")
|
||||
return .error
|
||||
}
|
||||
// ここで最も良い候補であったかをチェックする
|
||||
if max_token != token_id {
|
||||
if max_token == llama_token_eos(model) {
|
||||
if maxItem.token != token_id {
|
||||
if maxItem.token == llama_token_eos(model) {
|
||||
var cchars = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
}
|
||||
@ -270,19 +336,39 @@ class ZenzContext {
|
||||
} else {
|
||||
let actual_exp: Float = exp(logits[startIndex + Int(token_id)])
|
||||
// 学習されたトークンであり、なおかつactual_expのある程度大きければ、学習されたトークンを優先する
|
||||
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > max_exp
|
||||
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > maxItem.expLogit
|
||||
if !preferLearnedToken {
|
||||
// adding "\0"
|
||||
let cchars = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
} + token_to_piece(token: max_token)
|
||||
} + token_to_piece(token: maxItem.token)
|
||||
return .fixRequired(prefixConstraint: cchars.dropFirst(prompt.utf8.count).map(UInt8.init))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tokenHeap.removeMax()
|
||||
let prefix = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
}.dropFirst(prompt.utf8.count)
|
||||
|
||||
for item in tokenHeap.unordered {
|
||||
altTokens.insertIfPossible(
|
||||
AlternativeHighProbToken(
|
||||
token: item.token,
|
||||
constraint: prefix.map(UInt8.init) + token_to_piece(token: item.token).map(UInt8.init),
|
||||
probabilityRatioToMaxProb: item.expLogit / maxItem.expLogit
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
score += log(max_exp) - log(exp_sum)
|
||||
score += log(maxItem.expLogit) - log(exp_sum)
|
||||
}
|
||||
return .pass(score: score)
|
||||
for item in altTokens.unordered.sorted(by: >) {
|
||||
print("Item")
|
||||
print(" Constraint", String(cString: item.constraint + [0]))
|
||||
print(" Probability Gain", item.probabilityRatioToMaxProb)
|
||||
}
|
||||
return .pass(score: score, alternativeConstraints: altTokens.unordered.sorted(by: >).map {.init(probabilityRatio: $0.probabilityRatioToMaxProb, prefixConstraint: $0.constraint)})
|
||||
}
|
||||
|
||||
private func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], logits: Bool) {
|
||||
|
Reference in New Issue
Block a user