fix: 学習の優先度の調整を実装し、実装のミスを修正

This commit is contained in:
Miwa / Ensan
2024-07-04 23:17:00 +09:00
parent c81f02d6a3
commit 854fb4b1cf
3 changed files with 21 additions and 30 deletions

View File

@ -78,7 +78,7 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
extension DicdataElement: CustomDebugStringConvertible {
public var debugDescription: String {
"(ruby: \(self.ruby), word: \(self.word), cid: (\(self.lcid), \(self.rcid)), mid: \(self.mid), value: \(self.baseValue)+\(self.adjust)=\(self.value()))"
"(ruby: \(self.ruby), word: \(self.word), cid: (\(self.lcid), \(self.rcid)), mid: \(self.mid), value: \(self.baseValue)+\(self.adjust)=\(self.value()), metadata: (isLearned: \(self.metadata.contains(.isLearned))))"
}
}

View File

@ -138,8 +138,8 @@ extension Kana2Kanji {
return .return(constraint: PrefixConstraint([]), satisfied: false)
}
//
print("update constraint:", prefixConstraint)
constraint = PrefixConstraint(prefixConstraint)
print("update constraint:", constraint)
// 使
for i in candidates.indices where i != candidateIndex {
if candidates[i].text.utf8.hasPrefix(prefixConstraint) {

View File

@ -127,19 +127,16 @@ class ZenzContext {
case wholeResult(String)
}
func getNextDicdataElement(for prefix: String, of candidate: Candidate) -> DicdataElement? {
var curPrefix = ""
for datum in candidate.data {
if curPrefix == prefix {
// prefix
return datum
} else if curPrefix.hasPrefix(prefix) {
// FIXME:
return nil
}
curPrefix.append(datum.word)
func getLearningPriority(data: DicdataElement) -> Float {
//
//
return if 1 <= data.ruby.count && data.ruby.count <= 4 {
Float(data.ruby.count + 2)
} else if 5 <= data.ruby.count && data.ruby.count <= 15 {
Float(data.ruby.count * 2)
} else {
30
}
return nil
}
func evaluate_candidate(input: String, candidate: Candidate) -> CandidateEvaluationResult {
@ -159,8 +156,9 @@ class ZenzContext {
return .error
}
let n_vocab = llama_n_vocab(model)
let is_learned_token: [Bool] = Array(repeating: false, count: prompt_tokens.count) + candidate.data.flatMap {
Array(repeating: $0.metadata.contains(.isLearned), count: self.tokenize(text: $0.word, add_bos: false).count)
let is_learned_token: [(isLearned: Bool, priority: Float)] = Array(repeating: (false, 0), count: prompt_tokens.count) + candidate.data.flatMap {
// priority
Array(repeating: ($0.metadata.contains(.isLearned), getLearningPriority(data: $0)), count: self.tokenize(text: $0.word, add_bos: false).count)
}
var score: Float = 0
@ -171,8 +169,6 @@ class ZenzContext {
var exp_sum: Float = 0
var max_token: llama_token = 0
var max_exp: Float = .infinity * -1
var actual_token = token_id
var actual_exp: Float = .infinity * -1
let startIndex = (i - 1 - startOffset) * Int(n_vocab)
let endIndex = (i - startOffset) * Int(n_vocab)
for index in startIndex ..< endIndex {
@ -182,9 +178,6 @@ class ZenzContext {
max_exp = v
max_token = llama_token(index - startIndex)
}
if index == actual_token {
actual_exp = v
}
}
//
if max_token != token_id {
@ -199,16 +192,14 @@ class ZenzContext {
let wholeResult = String(string.dropFirst(prompt.count))
return .wholeResult(wholeResult)
} else {
var cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
}
let acceptedPrefix = String(cString: cchars + [0]).dropFirst(prompt.count)
if is_learned_token[i] && actual_exp * 10 > max_exp {
// zenz
// pass
} else {
let actual_exp: Float = exp(logits[startIndex + Int(token_id)])
// actual_exp
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > max_exp
if !preferLearnedToken {
// adding "\0"
cchars += token_to_piece(token: max_token)
let cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
} + token_to_piece(token: max_token)
return .fixRequired(prefixConstraint: cchars.dropFirst(prompt.utf8.count).map(UInt8.init))
}
}