feat: 実装を改善し、単語の途中に非argmaxな文字が出現した場合の処理を安定化

This commit is contained in:
Miwa / Ensan
2024-07-04 22:00:35 +09:00
parent d130b8168f
commit a50464304c

View File

@ -159,6 +159,9 @@ class ZenzContext {
return .error
}
let n_vocab = llama_n_vocab(model)
let is_learned_token: [Bool] = Array(repeating: false, count: prompt_tokens.count) + candidate.data.flatMap {
Array(repeating: $0.metadata.contains(.isLearned), count: self.tokenize(text: $0.word, add_bos: false).count)
}
var score: Float = 0
for (i, token_id) in tokens.indexed().dropFirst(prompt_tokens.count) {
@ -200,10 +203,7 @@ class ZenzContext {
$0.append(contentsOf: token_to_piece(token: $1))
}
let acceptedPrefix = String(cString: cchars + [0]).dropFirst(prompt.count)
if let nextDicdataElement = getNextDicdataElement(for: String(acceptedPrefix), of: candidate),
nextDicdataElement.metadata.contains(.isLearned),
actual_exp * 10 > max_exp
{
if is_learned_token[i] && actual_exp * 10 > max_exp {
// zenz
// pass
} else {