mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-08-22 15:05:26 +00:00
feat: add next character prediction API and add it to session cli
This commit is contained in:
@ -1,5 +1,7 @@
|
||||
import llama
|
||||
import SwiftUtils
|
||||
import HeapModule
|
||||
import Algorithms
|
||||
import Foundation
|
||||
|
||||
enum ZenzError: LocalizedError {
|
||||
@ -139,6 +141,62 @@ class ZenzContext {
|
||||
}
|
||||
}
|
||||
|
||||
func predict_next_character(leftSideContext: String, count: Int) -> [(character: Character, value: Float)] {
    /// Heap entry pairing a candidate character with its (penalized) exp-logit score.
    struct NextCharacterCandidate: Comparable {
        static func < (lhs: NextCharacterCandidate, rhs: NextCharacterCandidate) -> Bool {
            lhs.value < rhs.value
        }
        var character: Character
        var value: Float
    }

    // Prompt that aims generation towards the end of the sentence.
    // \u{EE01} can be regarded as a stop token.
    let prompt_tokens = self.tokenize(text: "\u{EE00}。\u{EE02}\(leftSideContext)", add_bos: false)
    let startOffset = prompt_tokens.count - 1

    guard let logits = self.get_logits(tokens: prompt_tokens, logits_start_index: startOffset) else {
        print("logits unavailable")
        return []
    }

    let n_vocab = llama_n_vocab(model)
    var exp_sum: Float = 0
    // Only the final position's logits were requested (startOffset = count - 1),
    // so the buffer holds exactly one row of `n_vocab` values: [0, n_vocab).
    let startIndex = (prompt_tokens.count - 1 - startOffset) * Int(n_vocab)
    let endIndex = (prompt_tokens.count - startOffset) * Int(n_vocab)

    // Min-heap keeps the running n-best so the full vocabulary is never sorted.
    var minHeap: Heap<NextCharacterCandidate> = .init()
    // Repetition-penalty weight per prompt token; weight decays with distance
    // from the current position (tokens near the end are penalized harder).
    let token_to_penalty_weight: [llama_token: Float] = prompt_tokens.indexed().reduce(into: [:]) { dict, item in
        let (index, token) = item
        dict[token, default: 0] += 2 / Float(prompt_tokens.count - index)
    }

    for index in startIndex ..< endIndex {
        let token = llama_token(index - startIndex)
        let repeat_penalty = Float(1.0 + token_to_penalty_weight[token, default: 0])
        // CTRL-style repetition penalty (same convention as llama.cpp):
        // dividing a *negative* logit by a penalty >= 1 would move it toward
        // zero and INCREASE the repeated token's probability, so negative
        // logits are multiplied by the penalty instead.
        let logit = logits[index]
        let penalized = logit > 0 ? logit / repeat_penalty : logit * repeat_penalty
        let v = exp(penalized)
        // Accumulate over the entire vocabulary — including tokens skipped
        // below — so the returned values form a proper softmax distribution.
        exp_sum += v

        // Decode the token piece; tokens that are not valid UTF-8 are skipped.
        // NOTE(review): only the first character of a multi-character piece is
        // kept, and distinct tokens sharing a first character are not merged —
        // presumably acceptable for single-character prediction; confirm.
        let tokenPieceData = Data((token_to_piece(token: token)).map(UInt8.init))
        let character: Character
        if let validCharacter = String(data: tokenPieceData, encoding: .utf8), let c = validCharacter.first {
            character = c
        } else {
            continue
        }
        if minHeap.count < count {
            minHeap.insert(NextCharacterCandidate(character: character, value: v))
        } else if let min = minHeap.min, v > min.value {
            minHeap.replaceMin(with: NextCharacterCandidate(character: character, value: v))
        }
    }

    // Drain the heap best-first, normalizing each score by the softmax denominator.
    return minHeap.unordered.sorted { $0.value > $1.value }.map { ($0.character, $0.value / exp_sum) }
}
|
||||
|
||||
func evaluate_candidate(input: String, candidate: Candidate, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
|
||||
print("Evaluate", candidate)
|
||||
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
|
||||
|
Reference in New Issue
Block a user