mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-08-22 15:05:26 +00:00
fix: 学習の優先度の調整を実装し、実装のミスを修正
This commit is contained in:
@ -78,7 +78,7 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
|
||||
|
||||
extension DicdataElement: CustomDebugStringConvertible {
|
||||
public var debugDescription: String {
|
||||
"(ruby: \(self.ruby), word: \(self.word), cid: (\(self.lcid), \(self.rcid)), mid: \(self.mid), value: \(self.baseValue)+\(self.adjust)=\(self.value()))"
|
||||
"(ruby: \(self.ruby), word: \(self.word), cid: (\(self.lcid), \(self.rcid)), mid: \(self.mid), value: \(self.baseValue)+\(self.adjust)=\(self.value()), metadata: (isLearned: \(self.metadata.contains(.isLearned))))"
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,8 +138,8 @@ extension Kana2Kanji {
|
||||
return .return(constraint: PrefixConstraint([]), satisfied: false)
|
||||
}
|
||||
// 制約が得られたので、更新する
|
||||
print("update constraint:", prefixConstraint)
|
||||
constraint = PrefixConstraint(prefixConstraint)
|
||||
print("update constraint:", constraint)
|
||||
// もし制約を満たす候補があるならそれを使って再レビューチャレンジを戦うことで、推論を減らせる
|
||||
for i in candidates.indices where i != candidateIndex {
|
||||
if candidates[i].text.utf8.hasPrefix(prefixConstraint) {
|
||||
|
@ -127,19 +127,16 @@ class ZenzContext {
|
||||
case wholeResult(String)
|
||||
}
|
||||
|
||||
func getNextDicdataElement(for prefix: String, of candidate: Candidate) -> DicdataElement? {
|
||||
var curPrefix = ""
|
||||
for datum in candidate.data {
|
||||
if curPrefix == prefix {
|
||||
// prefixの直後の単語を返したい
|
||||
return datum
|
||||
} else if curPrefix.hasPrefix(prefix) {
|
||||
// FIXME: 積極的な一語化の方針と相性が良くないはず
|
||||
return nil
|
||||
}
|
||||
curPrefix.append(datum.word)
|
||||
func getLearningPriority(data: DicdataElement) -> Float {
|
||||
// 文字数の長い候補ほど優先的に適用されるようにする
|
||||
// 積極的な複合語化の効果を期待
|
||||
return if 1 <= data.ruby.count && data.ruby.count <= 4 {
|
||||
Float(data.ruby.count + 2)
|
||||
} else if 5 <= data.ruby.count && data.ruby.count <= 15 {
|
||||
Float(data.ruby.count * 2)
|
||||
} else {
|
||||
30
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func evaluate_candidate(input: String, candidate: Candidate) -> CandidateEvaluationResult {
|
||||
@ -159,8 +156,9 @@ class ZenzContext {
|
||||
return .error
|
||||
}
|
||||
let n_vocab = llama_n_vocab(model)
|
||||
let is_learned_token: [Bool] = Array(repeating: false, count: prompt_tokens.count) + candidate.data.flatMap {
|
||||
Array(repeating: $0.metadata.contains(.isLearned), count: self.tokenize(text: $0.word, add_bos: false).count)
|
||||
let is_learned_token: [(isLearned: Bool, priority: Float)] = Array(repeating: (false, 0), count: prompt_tokens.count) + candidate.data.flatMap {
|
||||
// priorityは文字数にする→文字数が長いほど優先される
|
||||
Array(repeating: ($0.metadata.contains(.isLearned), getLearningPriority(data: $0)), count: self.tokenize(text: $0.word, add_bos: false).count)
|
||||
}
|
||||
|
||||
var score: Float = 0
|
||||
@ -171,8 +169,6 @@ class ZenzContext {
|
||||
var exp_sum: Float = 0
|
||||
var max_token: llama_token = 0
|
||||
var max_exp: Float = .infinity * -1
|
||||
var actual_token = token_id
|
||||
var actual_exp: Float = .infinity * -1
|
||||
let startIndex = (i - 1 - startOffset) * Int(n_vocab)
|
||||
let endIndex = (i - startOffset) * Int(n_vocab)
|
||||
for index in startIndex ..< endIndex {
|
||||
@ -182,9 +178,6 @@ class ZenzContext {
|
||||
max_exp = v
|
||||
max_token = llama_token(index - startIndex)
|
||||
}
|
||||
if index == actual_token {
|
||||
actual_exp = v
|
||||
}
|
||||
}
|
||||
// ここで最も良い候補であったかをチェックする
|
||||
if max_token != token_id {
|
||||
@ -199,16 +192,14 @@ class ZenzContext {
|
||||
let wholeResult = String(string.dropFirst(prompt.count))
|
||||
return .wholeResult(wholeResult)
|
||||
} else {
|
||||
var cchars = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
}
|
||||
let acceptedPrefix = String(cString: cchars + [0]).dropFirst(prompt.count)
|
||||
if is_learned_token[i] && actual_exp * 10 > max_exp {
|
||||
// 学習による候補であり、zenzによる評価も許容範囲なので、素通しする
|
||||
// pass
|
||||
} else {
|
||||
let actual_exp: Float = exp(logits[startIndex + Int(token_id)])
|
||||
// 学習されたトークンであり、なおかつactual_expのある程度大きければ、学習されたトークンを優先する
|
||||
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > max_exp
|
||||
if !preferLearnedToken {
|
||||
// adding "\0"
|
||||
cchars += token_to_piece(token: max_token)
|
||||
let cchars = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
} + token_to_piece(token: max_token)
|
||||
return .fixRequired(prefixConstraint: cchars.dropFirst(prompt.utf8.count).map(UInt8.init))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user