feat: enable rich n-best calculation

This commit is contained in:
Miwa / Ensan
2024-08-07 01:23:03 +09:00
parent 6f901eecdd
commit 1971f6382b
2 changed files with 148 additions and 27 deletions

View File

@ -4,6 +4,50 @@ import HeapModule
import Algorithms
import Foundation
/// A bounded max-tracking heap: holds at most `size` elements; once full, a new
/// element only enters by evicting the current minimum when it compares greater.
/// Backed by `Heap` from the swift-collections HeapModule.
struct FixedSizeHeap<Element: Comparable> {
    /// Capacity limit; never mutated after `init`, so declared `let`.
    private let size: Int
    private var heap: Heap<Element>

    /// - Parameter size: maximum element count. A size of 0 makes
    ///   `insertIfPossible(_:)` always return `false`.
    init(size: Int) {
        self.size = size
        self.heap = []
    }

    /// Removes the largest element. Safe no-op when empty
    /// (`Heap.removeMax()` would trap on an empty heap; `popMax()` does not).
    mutating func removeMax() {
        _ = self.heap.popMax()
    }

    /// Removes the smallest element. Safe no-op when empty.
    mutating func removeMin() {
        _ = self.heap.popMin()
    }

    /// Inserts `element` while under capacity, or replaces the current minimum
    /// when `element` is strictly greater than it.
    /// - Returns: `true` if the element entered the heap, `false` otherwise.
    @discardableResult
    mutating func insertIfPossible(_ element: Element) -> Bool {
        if self.heap.count < self.size {
            self.heap.insert(element)
            return true
        } else if let min = self.heap.min, element > min {
            self.heap.replaceMin(with: element)
            return true
        } else {
            return false
        }
    }

    /// All stored elements in no particular order.
    var unordered: [Element] {
        self.heap.unordered
    }

    /// Largest stored element, or `nil` when empty.
    var max: Element? {
        self.heap.max
    }

    /// Smallest stored element, or `nil` when empty.
    var min: Element? {
        self.heap.min
    }
}
enum ZenzError: LocalizedError {
case couldNotLoadModel(path: String)
case couldNotLoadContext
@ -124,9 +168,14 @@ class ZenzContext {
/// Outcome of evaluating a conversion candidate against the zenz model.
enum CandidateEvaluationResult: Sendable, Equatable, Hashable {
/// Evaluation could not be completed.
case error
// NOTE(review): this is a diff rendering — the line below is the pre-change
// `pass` case (a deletion line) and the next is its replacement; only one
// exists in the actual file. Resolve before treating this as compilable code.
case pass(score: Float)
case pass(score: Float, alternativeConstraints: [AlternativeConstraint])
/// Candidate must be corrected to start with this UTF-8 byte prefix.
case fixRequired(prefixConstraint: [UInt8])
/// Model produced a whole replacement string for the candidate.
case wholeResult(String)
/// An alternative prefix constraint, with its probability relative to the
/// most probable token (ratio in the caller is expLogit / maxExpLogit).
struct AlternativeConstraint: Sendable, Equatable, Hashable {
var probabilityRatio: Float
var prefixConstraint: [UInt8]
}
}
func getLearningPriority(data: DicdataElement) -> Float {
@ -166,7 +215,7 @@ class ZenzContext {
let endIndex = (prompt_tokens.count - startOffset) * Int(n_vocab)
// Min-Heap使n-best
var minHeap: Heap<NextCharacterCandidate> = .init()
var minHeap: FixedSizeHeap<NextCharacterCandidate> = .init(size: count)
let token_to_penalty_weight: [llama_token: Float] = prompt_tokens.indexed().reduce(into: [:]) { dict, item in
let (index, token) = item
//
@ -186,11 +235,7 @@ class ZenzContext {
} else {
continue
}
if minHeap.count < count {
minHeap.insert(NextCharacterCandidate(character: character, value: v))
} else if let min = minHeap.min, v > min.value {
minHeap.replaceMin(with: NextCharacterCandidate(character: character, value: v))
}
minHeap.insertIfPossible(NextCharacterCandidate(character: character, value: v))
}
// Heap
@ -238,26 +283,47 @@ class ZenzContext {
}
var score: Float = 0
// A runner-up token worth surfacing as an alternative: carries the token id,
// the constraint bytes it would induce, and its probability relative to the
// argmax token; ordered by that ratio so a heap keeps the strongest ones.
struct AlternativeHighProbToken: Comparable {
    static func < (a: AlternativeHighProbToken, b: AlternativeHighProbToken) -> Bool {
        b.probabilityRatioToMaxProb > a.probabilityRatioToMaxProb
    }
    var token: llama_token
    var constraint: [UInt8]
    // exp(logit) of this token divided by exp(logit) of the most probable token
    var probabilityRatioToMaxProb: Float
}
var altTokens = FixedSizeHeap<AlternativeHighProbToken>(size: 5)
for (i, token_id) in tokens.indexed().dropFirst(prompt_tokens.count) {
//
// softmaxmaxlogits
// log_probsoftmax
// Pairs a token id with exp(logit), ordered by that value so a fixed-size
// heap can retain only the highest-probability tokens at each position.
struct TokenAndExpLogit: Comparable {
    static func < (a: TokenAndExpLogit, b: TokenAndExpLogit) -> Bool {
        a.expLogit < b.expLogit
    }
    var token: llama_token
    var expLogit: Float
}
var exp_sum: Float = 0
var max_token: llama_token = 0
var max_exp: Float = .infinity * -1
let startIndex = (i - 1 - startOffset) * Int(n_vocab)
let endIndex = (i - startOffset) * Int(n_vocab)
var tokenHeap = FixedSizeHeap<TokenAndExpLogit>(size: 3)
for index in startIndex ..< endIndex {
let v = exp(logits[index])
exp_sum += v
if max_exp < v {
max_exp = v
max_token = llama_token(index - startIndex)
}
tokenHeap.insertIfPossible(TokenAndExpLogit(token: llama_token(index - startIndex), expLogit: v))
}
guard let maxItem = tokenHeap.max else {
print("Max Item could not be found for unknown reason")
return .error
}
//
if max_token != token_id {
if max_token == llama_token_eos(model) {
if maxItem.token != token_id {
if maxItem.token == llama_token_eos(model) {
var cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
}
@ -270,19 +336,39 @@ class ZenzContext {
} else {
let actual_exp: Float = exp(logits[startIndex + Int(token_id)])
// actual_exp
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > max_exp
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > maxItem.expLogit
if !preferLearnedToken {
// adding "\0"
let cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
} + token_to_piece(token: max_token)
} + token_to_piece(token: maxItem.token)
return .fixRequired(prefixConstraint: cchars.dropFirst(prompt.utf8.count).map(UInt8.init))
}
}
} else {
tokenHeap.removeMax()
let prefix = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
}.dropFirst(prompt.utf8.count)
for item in tokenHeap.unordered {
altTokens.insertIfPossible(
AlternativeHighProbToken(
token: item.token,
constraint: prefix.map(UInt8.init) + token_to_piece(token: item.token).map(UInt8.init),
probabilityRatioToMaxProb: item.expLogit / maxItem.expLogit
)
)
}
}
score += log(max_exp) - log(exp_sum)
score += log(maxItem.expLogit) - log(exp_sum)
}
return .pass(score: score)
for item in altTokens.unordered.sorted(by: >) {
print("Item")
print(" Constraint", String(cString: item.constraint + [0]))
print(" Probability Gain", item.probabilityRatioToMaxProb)
}
return .pass(score: score, alternativeConstraints: altTokens.unordered.sorted(by: >).map {.init(probabilityRatio: $0.probabilityRatioToMaxProb, prefixConstraint: $0.constraint)})
}
private func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], logits: Bool) {