mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-08-22 15:05:26 +00:00
feat: enable rich n-best calculation
This commit is contained in:
@ -4,6 +4,50 @@ import HeapModule
|
||||
import Algorithms
|
||||
import Foundation
|
||||
|
||||
/// A bounded min-max heap that retains at most `size` elements, always keeping
/// the LARGEST ones seen so far: when full, a new element evicts the current
/// minimum only if it compares greater. Backed by `Heap` from the
/// swift-collections HeapModule.
struct FixedSizeHeap<Element: Comparable> {
    // Capacity is fixed at construction and never mutated, so `let` (was `var`).
    private let size: Int
    private var heap: Heap<Element>

    /// - Parameter size: Maximum number of elements to retain.
    ///   NOTE(review): a non-positive `size` makes `insertIfPossible` always
    ///   return `false`; callers appear to pass positive counts — confirm.
    init(size: Int) {
        self.size = size
        self.heap = []
    }

    /// Removes the largest element.
    /// Precondition (inherited from `Heap.removeMax`): the heap is non-empty.
    mutating func removeMax() {
        self.heap.removeMax()
    }

    /// Removes the smallest element.
    /// Precondition (inherited from `Heap.removeMin`): the heap is non-empty.
    mutating func removeMin() {
        self.heap.removeMin()
    }

    /// Inserts `element` when capacity remains, or replaces the current
    /// minimum when `element` is strictly greater than it.
    /// - Returns: `true` if the element was stored, `false` if it was rejected.
    @discardableResult
    mutating func insertIfPossible(_ element: Element) -> Bool {
        if self.heap.count < self.size {
            self.heap.insert(element)
            return true
        } else if let min = self.heap.min, element > min {
            self.heap.replaceMin(with: element)
            return true
        } else {
            return false
        }
    }

    /// The stored elements in no particular order.
    var unordered: [Element] {
        self.heap.unordered
    }

    /// The largest stored element, or `nil` when empty.
    var max: Element? {
        self.heap.max
    }

    /// The smallest stored element, or `nil` when empty.
    var min: Element? {
        self.heap.min
    }
}
|
||||
|
||||
|
||||
enum ZenzError: LocalizedError {
|
||||
case couldNotLoadModel(path: String)
|
||||
case couldNotLoadContext
|
||||
@ -124,9 +168,14 @@ class ZenzContext {
|
||||
|
||||
/// Outcome of evaluating a conversion candidate against the zenz model.
enum CandidateEvaluationResult: Sendable, Equatable, Hashable {
    /// Evaluation failed (e.g. logits unavailable).
    case error
    /// The candidate was accepted with a log-probability `score`, together
    /// with alternative prefix constraints and their probability ratios.
    /// (The old `case pass(score: Float)` shape was superseded by this one;
    /// the stale duplicate from the rendered diff is dropped here.)
    case pass(score: Float, alternativeConstraints: [AlternativeConstraint])
    /// The candidate diverged from the model's best token; retry generation
    /// constrained to this UTF-8 byte prefix.
    case fixRequired(prefixConstraint: [UInt8])
    /// The model produced a complete alternative result string.
    case wholeResult(String)

    /// A competing prefix constraint ranked by how close its probability is
    /// to the most likely token's.
    struct AlternativeConstraint: Sendable, Equatable, Hashable {
        // Ratio of this token's probability to the maximum probability (<= 1).
        var probabilityRatio: Float
        // UTF-8 bytes of the constrained prefix.
        var prefixConstraint: [UInt8]
    }
}
|
||||
|
||||
func getLearningPriority(data: DicdataElement) -> Float {
|
||||
@ -166,7 +215,7 @@ class ZenzContext {
|
||||
let endIndex = (prompt_tokens.count - startOffset) * Int(n_vocab)
|
||||
|
||||
// Min-Heapを使用してn-bestを計算
|
||||
var minHeap: Heap<NextCharacterCandidate> = .init()
|
||||
var minHeap: FixedSizeHeap<NextCharacterCandidate> = .init(size: count)
|
||||
let token_to_penalty_weight: [llama_token: Float] = prompt_tokens.indexed().reduce(into: [:]) { dict, item in
|
||||
let (index, token) = item
|
||||
// 現在位置から遠いほど減衰させる
|
||||
@ -186,11 +235,7 @@ class ZenzContext {
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
if minHeap.count < count {
|
||||
minHeap.insert(NextCharacterCandidate(character: character, value: v))
|
||||
} else if let min = minHeap.min, v > min.value {
|
||||
minHeap.replaceMin(with: NextCharacterCandidate(character: character, value: v))
|
||||
}
|
||||
minHeap.insertIfPossible(NextCharacterCandidate(character: character, value: v))
|
||||
}
|
||||
|
||||
// Heapからソートして結果を取り出す
|
||||
@ -238,26 +283,47 @@ class ZenzContext {
|
||||
}
|
||||
|
||||
var score: Float = 0
|
||||
|
||||
struct AlternativeHighProbToken: Comparable {
|
||||
static func < (lhs: AlternativeHighProbToken, rhs: AlternativeHighProbToken) -> Bool {
|
||||
lhs.probabilityRatioToMaxProb < rhs.probabilityRatioToMaxProb
|
||||
}
|
||||
|
||||
var token: llama_token
|
||||
var constraint: [UInt8]
|
||||
// 最大probabilityに対しての割合
|
||||
var probabilityRatioToMaxProb: Float
|
||||
}
|
||||
|
||||
var altTokens = FixedSizeHeap<AlternativeHighProbToken>(size: 5)
|
||||
for (i, token_id) in tokens.indexed().dropFirst(prompt_tokens.count) {
|
||||
// それぞれのトークンが、一つ前の予測において最も確率の高いトークンであるかをチェックする
|
||||
// softmaxはmaxなので、単にlogitsの中で最も大きいものを選べば良い
|
||||
// 一方実用的にはlog_probも得ておきたい。このため、ここでは明示的にsoftmaxも計算している
|
||||
struct TokenAndExpLogit: Comparable {
|
||||
static func < (lhs: TokenAndExpLogit, rhs: TokenAndExpLogit) -> Bool {
|
||||
lhs.expLogit < rhs.expLogit
|
||||
}
|
||||
|
||||
var token: llama_token
|
||||
var expLogit: Float
|
||||
}
|
||||
var exp_sum: Float = 0
|
||||
var max_token: llama_token = 0
|
||||
var max_exp: Float = .infinity * -1
|
||||
let startIndex = (i - 1 - startOffset) * Int(n_vocab)
|
||||
let endIndex = (i - startOffset) * Int(n_vocab)
|
||||
var tokenHeap = FixedSizeHeap<TokenAndExpLogit>(size: 3)
|
||||
for index in startIndex ..< endIndex {
|
||||
let v = exp(logits[index])
|
||||
exp_sum += v
|
||||
if max_exp < v {
|
||||
max_exp = v
|
||||
max_token = llama_token(index - startIndex)
|
||||
}
|
||||
tokenHeap.insertIfPossible(TokenAndExpLogit(token: llama_token(index - startIndex), expLogit: v))
|
||||
}
|
||||
guard let maxItem = tokenHeap.max else {
|
||||
print("Max Item could not be found for unknown reason")
|
||||
return .error
|
||||
}
|
||||
// ここで最も良い候補であったかをチェックする
|
||||
if max_token != token_id {
|
||||
if max_token == llama_token_eos(model) {
|
||||
if maxItem.token != token_id {
|
||||
if maxItem.token == llama_token_eos(model) {
|
||||
var cchars = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
}
|
||||
@ -270,19 +336,39 @@ class ZenzContext {
|
||||
} else {
|
||||
let actual_exp: Float = exp(logits[startIndex + Int(token_id)])
|
||||
// 学習されたトークンであり、なおかつactual_expのある程度大きければ、学習されたトークンを優先する
|
||||
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > max_exp
|
||||
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > maxItem.expLogit
|
||||
if !preferLearnedToken {
|
||||
// adding "\0"
|
||||
let cchars = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
} + token_to_piece(token: max_token)
|
||||
} + token_to_piece(token: maxItem.token)
|
||||
return .fixRequired(prefixConstraint: cchars.dropFirst(prompt.utf8.count).map(UInt8.init))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tokenHeap.removeMax()
|
||||
let prefix = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
}.dropFirst(prompt.utf8.count)
|
||||
|
||||
for item in tokenHeap.unordered {
|
||||
altTokens.insertIfPossible(
|
||||
AlternativeHighProbToken(
|
||||
token: item.token,
|
||||
constraint: prefix.map(UInt8.init) + token_to_piece(token: item.token).map(UInt8.init),
|
||||
probabilityRatioToMaxProb: item.expLogit / maxItem.expLogit
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
score += log(max_exp) - log(exp_sum)
|
||||
score += log(maxItem.expLogit) - log(exp_sum)
|
||||
}
|
||||
return .pass(score: score)
|
||||
for item in altTokens.unordered.sorted(by: >) {
|
||||
print("Item")
|
||||
print(" Constraint", String(cString: item.constraint + [0]))
|
||||
print(" Probability Gain", item.probabilityRatioToMaxProb)
|
||||
}
|
||||
return .pass(score: score, alternativeConstraints: altTokens.unordered.sorted(by: >).map {.init(probabilityRatio: $0.probabilityRatioToMaxProb, prefixConstraint: $0.constraint)})
|
||||
}
|
||||
|
||||
private func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], logits: Bool) {
|
||||
|
Reference in New Issue
Block a user