feat: add next character prediction API and add it to session cli

ensan-hcl
2024-08-01 15:57:56 +09:00
parent 60e4578045
commit 445361d6dc
7 changed files with 151 additions and 8 deletions


@@ -1,5 +1,7 @@
import llama
import SwiftUtils
import HeapModule
import Algorithms
import Foundation
enum ZenzError: LocalizedError {
@@ -139,6 +141,62 @@ class ZenzContext {
        }
    }

    func predict_next_character(leftSideContext: String, count: Int) -> [(character: Character, value: Float)] {
        struct NextCharacterCandidate: Comparable {
            static func < (lhs: NextCharacterCandidate, rhs: NextCharacterCandidate) -> Bool {
                lhs.value < rhs.value
            }
            var character: Character
            var value: Float
        }

        // Build the prompt from the left-side context: \u{EE00} marks the start of the query and
        // \u{EE02} introduces the left-side context. \u{EE01} ('start answer') is not appended here.
        let prompt_tokens = self.tokenize(text: "\u{EE00}。\u{EE02}\(leftSideContext)", add_bos: false)
        let startOffset = prompt_tokens.count - 1
        guard let logits = self.get_logits(tokens: prompt_tokens, logits_start_index: startOffset) else {
            print("logits unavailable")
            return []
        }

        let n_vocab = llama_n_vocab(model)
        var exp_sum: Float = 0
        let startIndex = (prompt_tokens.count - 1 - startOffset) * Int(n_vocab)
        let endIndex = (prompt_tokens.count - startOffset) * Int(n_vocab)

        // Keep the n-best candidates in a min-heap.
        var minHeap: Heap<NextCharacterCandidate> = .init()
        let token_to_penalty_weight: [llama_token: Float] = prompt_tokens.indexed().reduce(into: [:]) { dict, item in
            let (index, token) = item
            // Tokens appearing closer to the end of the prompt receive a larger repetition-penalty weight.
            dict[token, default: 0] += 2 / Float(prompt_tokens.count - index)
        }

        for index in startIndex..<endIndex {
            let token = llama_token(index - startIndex)
            let repeat_penalty = Float(1.0 + token_to_penalty_weight[token, default: 0])
            let v = exp(logits[index] / repeat_penalty)
            exp_sum += v
            let tokenPieceData = Data((token_to_piece(token: token)).map(UInt8.init))
            let character: Character
            if let validCharacter = String(data: tokenPieceData, encoding: .utf8), let c = validCharacter.first {
                character = c
            } else {
                continue
            }
            if minHeap.count < count {
                minHeap.insert(NextCharacterCandidate(character: character, value: v))
            } else if let min = minHeap.min, v > min.value {
                minHeap.replaceMin(with: NextCharacterCandidate(character: character, value: v))
            }
        }

        // Drain the heap, sort by score in descending order, and normalize with the softmax denominator.
        return minHeap.unordered.sorted { $0.value > $1.value }.map { ($0.character, $0.value / exp_sum) }
    }
    func evaluate_candidate(input: String, candidate: Candidate, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
        print("Evaluate", candidate)
        // For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
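
For reference, a minimal sketch of how the session CLI added in this commit might call the new API. The setup helper makeZenzContext(modelPath:) is hypothetical and stands in for the initialization code in the other changed files, which are not shown in this hunk; only predict_next_character(leftSideContext:count:) comes from the diff above.

// Hypothetical usage sketch; makeZenzContext(modelPath:) is a placeholder, not part of this diff.
let context = try makeZenzContext(modelPath: "path/to/zenz-model.gguf")

// Ask for the five most likely next characters given the text to the left of the cursor.
let predictions = context.predict_next_character(leftSideContext: "ここで次の文字を予測", count: 5)

// Scores are normalized by the softmax denominator, so they can be read as probabilities.
for (character, value) in predictions {
    print("\(character)\t\(value)")
}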