feat: enable rich n-best calculation

This commit is contained in:
Miwa / Ensan
2024-08-07 01:23:03 +09:00
parent 6f901eecdd
commit 1971f6382b
2 changed files with 148 additions and 27 deletions

View File

@ -62,6 +62,11 @@ extension Kana2Kanji {
print("initial constraint", constraint)
let eosNode = LatticeNode.EOSNode
var nodes: Kana2Kanji.Nodes = []
var constructedCandidates: [(RegisteredNode, Candidate)] = []
var insertedCandidates: [(RegisteredNode, Candidate)] = []
defer {
eosNode.prevs = insertedCandidates.map(\.0)
}
var inferenceLimit = inferenceLimit
while true {
let start = Date()
@ -78,6 +83,7 @@ extension Kana2Kanji {
nodes = draftResult.nodes
}
let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
constructedCandidates.append(contentsOf: zip(draftResult.result.prevs, candidates))
var best: (Int, Candidate)? = nil
for (i, cand) in candidates.enumerated() {
if let (_, c) = best, cand.value > c.value {
@ -92,10 +98,12 @@ extension Kana2Kanji {
//
return (eosNode, nodes, ZenzaiCache(inputData, constraint: PrefixConstraint([]), satisfyingCandidate: nil))
}
print("Constrained draft modeling", -start.timeIntervalSinceNow)
reviewLoop: while true {
// results
eosNode.prevs.insert(draftResult.result.prevs[index], at: 0)
// N-Best
insertedCandidates.insert((draftResult.result.prevs[index], candidate), at: 0)
if inferenceLimit == 0 {
print("inference limit! \(candidate.text) is used for excuse")
// When inference occurs more than maximum times, then just return result at this point
@ -110,7 +118,34 @@ extension Kana2Kanji {
constraint: &constraint
)
switch nextAction {
case .return(let constraint, let satisfied):
case .return(let constraint, let alternativeConstraints, let satisfied):
// Fold the alternativeConstraints into insertedCandidates as additional n-best entries (original comment garbled by scrape)
for alternativeConstraint in alternativeConstraints.reversed() where alternativeConstraint.probabilityRatio > 0.25 {
// Among the constructed candidates, pick the highest-value one whose text starts with alternativeConstraint.prefixConstraint
let mostLiklyCandidate = constructedCandidates.filter {
$0.1.text.utf8.hasPrefix(alternativeConstraint.prefixConstraint)
}.max {
$0.1.value < $1.1.value
}
if let mostLiklyCandidate {
// 0
insertedCandidates.insert(mostLiklyCandidate, at: 1)
} else if alternativeConstraint.probabilityRatio > 0.5 {
//
let draftResult = self.kana2lattice_all_with_prefix_constraint(inputData, N_best: 3, constraint: PrefixConstraint(alternativeConstraint.prefixConstraint))
let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
let best: (Int, Candidate)? = candidates.enumerated().reduce(into: nil) { best, pair in
if let (_, c) = best, pair.1.value > c.value {
best = pair
} else if best == nil {
best = pair
}
}
if let (index, candidate) = best {
insertedCandidates.insert((draftResult.result.prevs[index], candidate), at: 1)
}
}
}
if satisfied {
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
} else {
@ -127,7 +162,7 @@ extension Kana2Kanji {
}
/// Control-flow decision returned after evaluating a candidate against the Zenz model.
/// NOTE(review): the two `return` cases below are the before/after sides of the diff view,
/// not an intentional overload — only the second (with `alternativeConstraints`) survives this commit.
private enum NextAction {
case `return`(constraint: PrefixConstraint, satisfied: Bool)
case `return`(constraint: PrefixConstraint, alternativeConstraints: [ZenzContext.CandidateEvaluationResult.AlternativeConstraint], satisfied: Bool)
// Continue the review loop with the next candidate.
case `continue`
// Re-run inference starting from the candidate at the given index.
case `retry`(candidateIndex: Int)
}
@ -142,16 +177,16 @@ extension Kana2Kanji {
case .error:
//
print("error")
return .return(constraint: constraint, satisfied: false)
case .pass(let score):
return .return(constraint: constraint, alternativeConstraints: [], satisfied: false)
case .pass(let score, let alternativeConstraints):
//
print("passed:", score)
return .return(constraint: constraint, satisfied: true)
return .return(constraint: constraint, alternativeConstraints: alternativeConstraints, satisfied: true)
case .fixRequired(let prefixConstraint):
// 2
if constraint.constraint == prefixConstraint {
print("same constraint:", prefixConstraint)
return .return(constraint: PrefixConstraint([]), satisfied: false)
return .return(constraint: PrefixConstraint([]), alternativeConstraints: [], satisfied: false)
}
//
constraint = PrefixConstraint(prefixConstraint)
@ -169,7 +204,7 @@ extension Kana2Kanji {
// 2
if constraint == newConstraint {
print("same constraint:", constraint)
return .return(constraint: PrefixConstraint([]), satisfied: false)
return .return(constraint: PrefixConstraint([]), alternativeConstraints: [], satisfied: false)
}
//
print("update whole constraint:", wholeConstraint)

View File

@ -4,6 +4,50 @@ import HeapModule
import Algorithms
import Foundation
/// A capacity-bounded heap that retains at most `size` elements,
/// evicting the current minimum when a larger element is inserted.
/// Backed by `Heap` from swift-collections (HeapModule), so both the
/// minimum and maximum retained elements are available in O(1).
struct FixedSizeHeap<Element: Comparable> {
    /// Capacity bound. Never mutated after `init`, hence `let` (was `var`).
    private let size: Int
    private var heap: Heap<Element>

    /// - Parameter size: Maximum number of elements to retain; inserts beyond
    ///   this bound evict the smallest retained element (see `insertIfPossible`).
    init(size: Int) {
        self.size = size
        self.heap = []
    }

    /// Removes the largest element.
    /// Precondition (inherited from `Heap.removeMax()`): the heap is non-empty.
    mutating func removeMax() {
        self.heap.removeMax()
    }

    /// Removes the smallest element.
    /// Precondition (inherited from `Heap.removeMin()`): the heap is non-empty.
    mutating func removeMin() {
        self.heap.removeMin()
    }

    /// Inserts `element` if there is spare capacity, or if it is strictly greater
    /// than the current minimum (which is then replaced in a single sift operation).
    /// - Returns: `true` if the element was retained, `false` if it was rejected.
    @discardableResult
    mutating func insertIfPossible(_ element: Element) -> Bool {
        if self.heap.count < self.size {
            self.heap.insert(element)
            return true
        } else if let min = self.heap.min, element > min {
            // Full: keep the element only if it beats the current worst entry.
            self.heap.replaceMin(with: element)
            return true
        } else {
            return false
        }
    }

    /// All retained elements, in no particular order.
    var unordered: [Element] {
        self.heap.unordered
    }

    /// The largest retained element, or `nil` if empty.
    var max: Element? {
        self.heap.max
    }

    /// The smallest retained element, or `nil` if empty.
    var min: Element? {
        self.heap.min
    }
}
enum ZenzError: LocalizedError {
case couldNotLoadModel(path: String)
case couldNotLoadContext
@ -124,9 +168,14 @@ class ZenzContext {
/// Outcome of evaluating a conversion candidate with the Zenz language model.
/// NOTE(review): the two `pass` cases below are the before/after sides of the diff view;
/// this commit replaces the first with the second (adding `alternativeConstraints`).
enum CandidateEvaluationResult: Sendable, Equatable, Hashable {
// Inference failed; caller treats the candidate as unusable.
case error
case pass(score: Float)
case pass(score: Float, alternativeConstraints: [AlternativeConstraint])
// The model prefers a different prefix; caller retries under this byte-level constraint.
case fixRequired(prefixConstraint: [UInt8])
case wholeResult(String)
/// A lower-probability alternative prefix observed during evaluation,
/// with its probability relative to the best token (0...1).
struct AlternativeConstraint: Sendable, Equatable, Hashable {
var probabilityRatio: Float
var prefixConstraint: [UInt8]
}
}
func getLearningPriority(data: DicdataElement) -> Float {
@ -166,7 +215,7 @@ class ZenzContext {
let endIndex = (prompt_tokens.count - startOffset) * Int(n_vocab)
// Use a fixed-size min-heap to keep only the n-best next-character candidates
var minHeap: Heap<NextCharacterCandidate> = .init()
var minHeap: FixedSizeHeap<NextCharacterCandidate> = .init(size: count)
let token_to_penalty_weight: [llama_token: Float] = prompt_tokens.indexed().reduce(into: [:]) { dict, item in
let (index, token) = item
//
@ -186,11 +235,7 @@ class ZenzContext {
} else {
continue
}
if minHeap.count < count {
minHeap.insert(NextCharacterCandidate(character: character, value: v))
} else if let min = minHeap.min, v > min.value {
minHeap.replaceMin(with: NextCharacterCandidate(character: character, value: v))
}
minHeap.insertIfPossible(NextCharacterCandidate(character: character, value: v))
}
// Heap
@ -238,26 +283,47 @@ class ZenzContext {
}
var score: Float = 0
struct AlternativeHighProbToken: Comparable {
static func < (lhs: AlternativeHighProbToken, rhs: AlternativeHighProbToken) -> Bool {
lhs.probabilityRatioToMaxProb < rhs.probabilityRatioToMaxProb
}
var token: llama_token
var constraint: [UInt8]
// probability
var probabilityRatioToMaxProb: Float
}
var altTokens = FixedSizeHeap<AlternativeHighProbToken>(size: 5)
for (i, token_id) in tokens.indexed().dropFirst(prompt_tokens.count) {
//
// softmaxmaxlogits
// log_probsoftmax
struct TokenAndExpLogit: Comparable {
static func < (lhs: TokenAndExpLogit, rhs: TokenAndExpLogit) -> Bool {
lhs.expLogit < rhs.expLogit
}
var token: llama_token
var expLogit: Float
}
var exp_sum: Float = 0
var max_token: llama_token = 0
var max_exp: Float = .infinity * -1
let startIndex = (i - 1 - startOffset) * Int(n_vocab)
let endIndex = (i - startOffset) * Int(n_vocab)
var tokenHeap = FixedSizeHeap<TokenAndExpLogit>(size: 3)
for index in startIndex ..< endIndex {
let v = exp(logits[index])
exp_sum += v
if max_exp < v {
max_exp = v
max_token = llama_token(index - startIndex)
}
tokenHeap.insertIfPossible(TokenAndExpLogit(token: llama_token(index - startIndex), expLogit: v))
}
guard let maxItem = tokenHeap.max else {
print("Max Item could not be found for unknown reason")
return .error
}
//
if max_token != token_id {
if max_token == llama_token_eos(model) {
if maxItem.token != token_id {
if maxItem.token == llama_token_eos(model) {
var cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
}
@ -270,19 +336,39 @@ class ZenzContext {
} else {
let actual_exp: Float = exp(logits[startIndex + Int(token_id)])
// actual_exp
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > max_exp
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > maxItem.expLogit
if !preferLearnedToken {
// adding "\0"
let cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
} + token_to_piece(token: max_token)
} + token_to_piece(token: maxItem.token)
return .fixRequired(prefixConstraint: cchars.dropFirst(prompt.utf8.count).map(UInt8.init))
}
}
} else {
tokenHeap.removeMax()
let prefix = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
}.dropFirst(prompt.utf8.count)
for item in tokenHeap.unordered {
altTokens.insertIfPossible(
AlternativeHighProbToken(
token: item.token,
constraint: prefix.map(UInt8.init) + token_to_piece(token: item.token).map(UInt8.init),
probabilityRatioToMaxProb: item.expLogit / maxItem.expLogit
)
)
}
}
score += log(max_exp) - log(exp_sum)
score += log(maxItem.expLogit) - log(exp_sum)
}
return .pass(score: score)
for item in altTokens.unordered.sorted(by: >) {
print("Item")
print(" Constraint", String(cString: item.constraint + [0]))
print(" Probability Gain", item.probabilityRatioToMaxProb)
}
return .pass(score: score, alternativeConstraints: altTokens.unordered.sorted(by: >).map {.init(probabilityRatio: $0.probabilityRatioToMaxProb, prefixConstraint: $0.constraint)})
}
private func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], logits: Bool) {