mirror of https://github.com/mii443/AzooKeyKanaKanjiConverter.git (synced 2025-12-03 02:58:27 +00:00)
feat: support N-gram based personalization
@@ -19,6 +19,12 @@ extension Subcommands {
         var configZenzaiInferenceLimit: Int = .max
         @Flag(name: [.customLong("config_zenzai_ignore_left_context")], help: "ignore left_context")
         var configZenzaiIgnoreLeftContext: Bool = false
+        @Option(name: [.customLong("config_zenzai_base_lm")], help: "Marisa files for Base LM.")
+        var configZenzaiBaseLM: String?
+        @Option(name: [.customLong("config_zenzai_personal_lm")], help: "Marisa files for Personal LM.")
+        var configZenzaiPersonalLM: String?
+        @Option(name: [.customLong("config_zenzai_personalization_alpha")], help: "Strength of personalization (0.5 by default)")
+        var configZenzaiPersonalizationAlpha: Float = 0.5

         static let configuration = CommandConfiguration(commandName: "evaluate", abstract: "Evaluate quality of Conversion for input data.")

@@ -87,6 +93,20 @@ extension Subcommands {
         }

         func requestOptions(leftSideContext: String?) -> ConvertRequestOptions {
+            let personalizationMode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?
+            if let base = self.configZenzaiBaseLM, let personal = self.configZenzaiPersonalLM {
+                personalizationMode = .init(
+                    baseNgramLanguageModel: base,
+                    personalNgramLanguageModel: personal,
+                    n: 5,
+                    d: 0.75,
+                    alpha: self.configZenzaiPersonalizationAlpha
+                )
+            } else if self.configZenzaiBaseLM != nil || self.configZenzaiPersonalLM != nil {
+                fatalError("Both --config_zenzai_base_lm and --config_zenzai_personal_lm must be set")
+            } else {
+                personalizationMode = nil
+            }
             var option: ConvertRequestOptions = .withDefaultDictionary(
                 N_best: self.configNBest,
                 requireJapanesePrediction: false,
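This all-or-nothing check on the two LM options is reached via fatalError and is duplicated in each subcommand below. As a gentler alternative, here is a minimal sketch of the same check expressed through swift-argument-parser's validate() hook; the enclosing command type Subcommands.Evaluate is an assumption, and the property names are taken from the diff:

    import ArgumentParser

    // Sketch only: `Subcommands.Evaluate` stands for the command type this
    // hunk belongs to; validate() runs before the command executes.
    extension Subcommands.Evaluate {
        mutating func validate() throws {
            // The two LM paths are all-or-nothing.
            if (self.configZenzaiBaseLM == nil) != (self.configZenzaiPersonalLM == nil) {
                throw ValidationError("Both --config_zenzai_base_lm and --config_zenzai_personal_lm must be set")
            }
        }
    }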
@@ -102,7 +122,7 @@ extension Subcommands {
                 shouldResetMemory: false,
                 memoryDirectoryURL: URL(fileURLWithPath: ""),
                 sharedContainerURL: URL(fileURLWithPath: ""),
-                zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit, versionDependentMode: .v2(.init(leftSideContext: self.configZenzaiIgnoreLeftContext ? nil : leftSideContext))),
+                zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit, personalizationMode: personalizationMode, versionDependentMode: .v2(.init(leftSideContext: self.configZenzaiIgnoreLeftContext ? nil : leftSideContext))),
                 metadata: .init(versionString: "anco for debugging")
             )
             option.requestQuery = .完全一致

@@ -38,7 +38,7 @@ extension Subcommands {
                 shouldResetMemory: false,
                 memoryDirectoryURL: URL(fileURLWithPath: ""),
                 sharedContainerURL: URL(fileURLWithPath: ""),
-                zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: .max, versionDependentMode: .v3(.init())),
+                zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: .max, personalizationMode: nil, versionDependentMode: .v3(.init())),
                 metadata: .init(versionString: "anco for debugging")
             )
         }

@@ -15,6 +15,12 @@ extension Subcommands {
         var zenzWeightPath: String = ""
         @Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
         var configZenzaiInferenceLimit: Int = .max
+        @Option(name: [.customLong("config_zenzai_base_lm")], help: "Marisa files for Base LM.")
+        var configZenzaiBaseLM: String?
+        @Option(name: [.customLong("config_zenzai_personal_lm")], help: "Marisa files for Personal LM.")
+        var configZenzaiPersonalLM: String?
+        @Option(name: [.customLong("config_zenzai_personalization_alpha")], help: "Strength of personalization (0.5 by default)")
+        var configZenzaiPersonalizationAlpha: Float = 0.5

         @Flag(name: [.customLong("disable_prediction")], help: "Disable producing prediction candidates.")
         var disablePrediction = false
@@ -55,6 +61,20 @@ extension Subcommands {
         }

         func requestOptions() -> ConvertRequestOptions {
+            let personalizationMode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?
+            if let base = self.configZenzaiBaseLM, let personal = self.configZenzaiPersonalLM {
+                personalizationMode = .init(
+                    baseNgramLanguageModel: base,
+                    personalNgramLanguageModel: personal,
+                    n: 5,
+                    d: 0.75,
+                    alpha: self.configZenzaiPersonalizationAlpha
+                )
+            } else if self.configZenzaiBaseLM != nil || self.configZenzaiPersonalLM != nil {
+                fatalError("Both --config_zenzai_base_lm and --config_zenzai_personal_lm must be set")
+            } else {
+                personalizationMode = nil
+            }
             var option: ConvertRequestOptions = .withDefaultDictionary(
                 N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
                 requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
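The personalizationMode construction above now appears verbatim in the three anco subcommands touched by this commit (see also the session hunk below). A minimal sketch of factoring it into a shared helper, assuming a small protocol that each command's option set conforms to; the names mirror the diff, and the converter module import is omitted:

    protocol ZenzaiPersonalizationOptions {
        var configZenzaiBaseLM: String? { get }
        var configZenzaiPersonalLM: String? { get }
        var configZenzaiPersonalizationAlpha: Float { get }
    }

    extension ZenzaiPersonalizationOptions {
        // Returns nil when personalization is off; traps on a half-configured pair.
        func makePersonalizationMode() -> ConvertRequestOptions.ZenzaiMode.PersonalizationMode? {
            switch (self.configZenzaiBaseLM, self.configZenzaiPersonalLM) {
            case let (base?, personal?):
                return .init(
                    baseNgramLanguageModel: base,
                    personalNgramLanguageModel: personal,
                    n: 5,
                    d: 0.75,
                    alpha: self.configZenzaiPersonalizationAlpha
                )
            case (nil, nil):
                return nil
            default:
                fatalError("Both --config_zenzai_base_lm and --config_zenzai_personal_lm must be set")
            }
        }
    }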
@@ -70,7 +90,7 @@ extension Subcommands {
                 shouldResetMemory: false,
                 memoryDirectoryURL: URL(fileURLWithPath: ""),
                 sharedContainerURL: URL(fileURLWithPath: ""),
-                zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit),
+                zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit, personalizationMode: personalizationMode),
                 metadata: .init(versionString: "anco for debugging")
             )
             if self.onlyWholeConversion {

@@ -35,6 +35,12 @@ extension Subcommands {
         var zenzV1 = false
         @Flag(name: [.customLong("zenz_v2")], help: "Use zenz_v2 model.")
         var zenzV2 = false
+        @Option(name: [.customLong("config_zenzai_base_lm")], help: "Marisa files for Base LM.")
+        var configZenzaiBaseLM: String?
+        @Option(name: [.customLong("config_zenzai_personal_lm")], help: "Marisa files for Personal LM.")
+        var configZenzaiPersonalLM: String?
+        @Option(name: [.customLong("config_zenzai_personalization_alpha")], help: "Strength of personalization (0.5 by default)")
+        var configZenzaiPersonalizationAlpha: Float = 0.5

         static let configuration = CommandConfiguration(commandName: "session", abstract: "Start session for incremental input.")

@@ -201,6 +207,20 @@ extension Subcommands {
             } else {
                 .v3(.init(profile: self.configZenzaiProfile, topic: self.configZenzaiTopic, leftSideContext: leftSideContext))
             }
+            let personalizationMode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?
+            if let base = self.configZenzaiBaseLM, let personal = self.configZenzaiPersonalLM {
+                personalizationMode = .init(
+                    baseNgramLanguageModel: base,
+                    personalNgramLanguageModel: personal,
+                    n: 5,
+                    d: 0.75,
+                    alpha: self.configZenzaiPersonalizationAlpha
+                )
+            } else if self.configZenzaiBaseLM != nil || self.configZenzaiPersonalLM != nil {
+                fatalError("Both --config_zenzai_base_lm and --config_zenzai_personal_lm must be set")
+            } else {
+                personalizationMode = nil
+            }
             var option: ConvertRequestOptions = .withDefaultDictionary(
                 N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
                 requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
@@ -220,6 +240,7 @@ extension Subcommands {
                     weight: URL(string: self.zenzWeightPath)!,
                     inferenceLimit: self.configZenzaiInferenceLimit,
                     requestRichCandidates: self.configRequestRichCandidates,
+                    personalizationMode: personalizationMode,
                     versionDependentMode: zenzaiVersionDependentMode
                 ),
                 metadata: .init(versionString: "anco for debugging")

@@ -224,6 +224,21 @@ public struct ConvertRequestOptions: Sendable {
     }

     public struct ZenzaiMode: Sendable, Equatable {
+        public struct PersonalizationMode: Sendable, Equatable {
+            public init(baseNgramLanguageModel: String, personalNgramLanguageModel: String, n: Int = 5, d: Double = 0.75, alpha: Float = 0.5) {
+                self.baseNgramLanguageModel = baseNgramLanguageModel
+                self.personalNgramLanguageModel = personalNgramLanguageModel
+                self.n = n
+                self.d = d
+                self.alpha = alpha
+            }
+
+            var n: Int = 5
+            var d: Double = 0.75
+            var alpha: Float = 0.5
+            var baseNgramLanguageModel: String
+            var personalNgramLanguageModel: String
+        }
         public static let off = ZenzaiMode(
             enabled: false,
             weightURL: URL(fileURLWithPath: ""),
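PersonalizationMode bundles the file prefixes of two SwiftNGram models with the N-gram order n, the discount d, and the interpolation strength alpha. A hypothetical construction (the paths are placeholders; n, d, and alpha may be omitted to take the defaults shown in the init):

    let personalization = ConvertRequestOptions.ZenzaiMode.PersonalizationMode(
        baseNgramLanguageModel: "/path/to/base_lm",          // prefix of the base LM's marisa files
        personalNgramLanguageModel: "/path/to/personal_lm",  // prefix of the user's marisa files
        n: 5,       // N-gram order
        d: 0.75,    // discount parameter passed to SwiftNGram's LM
        alpha: 0.5  // interpolation strength; larger values weight the personal LM more
    )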
@@ -237,13 +252,15 @@ public struct ConvertRequestOptions: Sendable {
         /// - weight: path for model weight (gguf)
         /// - inferenceLimit: applying inference count limitation. Smaller limit makes conversion faster but quality will be worse. (Default: 10)
         /// - requestRichCandidates: when this flag is true, the converter spends more time but generates richer N-Best candidates for the candidate list view. Usually this option is not recommended for live conversion.
+        /// - personalizationMode: N-gram language models and interpolation strength used for personalization.
         /// - versionDependentMode: specify zenz model version and its configuration.
-        public static func on(weight: URL, inferenceLimit: Int = 10, requestRichCandidates: Bool = false, versionDependentMode: ZenzaiVersionDependentMode = .v3(.init())) -> Self {
+        public static func on(weight: URL, inferenceLimit: Int = 10, requestRichCandidates: Bool = false, personalizationMode: PersonalizationMode?, versionDependentMode: ZenzaiVersionDependentMode = .v3(.init())) -> Self {
             ZenzaiMode(
                 enabled: true,
                 weightURL: weight,
                 inferenceLimit: inferenceLimit,
                 requestRichCandidates: requestRichCandidates,
+                personalizationMode: personalizationMode,
                 versionDependentMode: versionDependentMode
             )
         }

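Because the new personalizationMode parameter has no default value, this change is source-breaking: every existing .on(...) call site must be updated to pass either a mode or nil, which is exactly what the anco hunks above do. A hypothetical call, reusing the personalization value sketched earlier (the weight path is a placeholder):

    let zenzaiMode: ConvertRequestOptions.ZenzaiMode = .on(
        weight: URL(fileURLWithPath: "/path/to/zenz.gguf"),
        inferenceLimit: 10,
        personalizationMode: personalization,  // pass nil to run Zenzai without personalization
        versionDependentMode: .v3(.init())
    )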
@@ -251,6 +268,7 @@ public struct ConvertRequestOptions: Sendable {
         var weightURL: URL
         var inferenceLimit: Int
         var requestRichCandidates: Bool
+        var personalizationMode: PersonalizationMode?
         var versionDependentMode: ZenzaiVersionDependentMode
     }
 }

@@ -8,6 +8,7 @@

 import Foundation
 import SwiftUtils
+import SwiftNGram

 /// Class that manages kana-kanji conversion
 @MainActor public final class KanaKanjiConverter {
@@ -25,14 +26,16 @@ import SwiftUtils
     private var nodes: [[LatticeNode]] = []
     private var completedData: Candidate?
     private var lastData: DicdataElement?
-    /// zenz-v1 model for Zenzai
+    /// zenz model for Zenzai
     private var zenz: Zenz? = nil
     private var zenzaiCache: Kana2Kanji.ZenzaiCache? = nil
+    private var zenzaiPersonalization: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?
     public private(set) var zenzStatus: String = ""

     /// Function to reset the state
     public func stopComposition() {
         self.zenz?.endSession()
+        self.zenzaiPersonalization = nil
         self.zenzaiCache = nil
         self.previousInputData = nil
         self.nodes = []
@@ -40,6 +43,20 @@ import SwiftUtils
         self.lastData = nil
     }

+    private func getZenzaiPersonalization(mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?) -> (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)? {
+        guard let mode else {
+            return nil
+        }
+        if let zenzaiPersonalization, zenzaiPersonalization.mode == mode {
+            return zenzaiPersonalization
+        }
+        let tokenizer = ZenzTokenizer()
+        let baseModel = LM(baseFilename: mode.baseNgramLanguageModel, n: mode.n, d: mode.d, tokenizer: tokenizer)
+        let personalModel = LM(baseFilename: mode.personalNgramLanguageModel, n: mode.n, d: mode.d, tokenizer: tokenizer)
+        self.zenzaiPersonalization = (mode, baseModel, personalModel)
+        return (mode, baseModel, personalModel)
+    }
+
     package func getModel(modelURL: URL) -> Zenz? {
         if let model = self.zenz, model.resourceURL == modelURL {
             self.zenzStatus = "load \(modelURL.absoluteString)"
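Since PersonalizationMode is Equatable, this helper doubles as a cache: repeated conversion requests with an unchanged mode reuse the already loaded (mode, base, personal) triple, while changing any field (paths, n, d, or alpha) forces a reload on the next call. stopComposition() above clears the cache together with the rest of the session state, so the marisa files are reopened for the next composition.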
@@ -594,6 +611,7 @@ import SwiftUtils
             return nil
         }

+        print("ConvertToLattice ", zenzaiMode)
         // FIXME: enable cache based zenzai
         if zenzaiMode.enabled, let model = self.getModel(modelURL: zenzaiMode.weightURL) {
             let (result, nodes, cache) = self.converter.all_zenzai(
@@ -602,6 +620,7 @@ import SwiftUtils
                 zenzaiCache: self.zenzaiCache,
                 inferenceLimit: zenzaiMode.inferenceLimit,
                 requestRichCandidates: zenzaiMode.requestRichCandidates,
+                personalizationMode: self.getZenzaiPersonalization(mode: zenzaiMode.personalizationMode),
                 versionDependentConfig: zenzaiMode.versionDependentMode
             )
             self.zenzaiCache = cache

@@ -1,5 +1,6 @@
 import Foundation
 import SwiftUtils
+import SwiftNGram

 extension Kana2Kanji {
     struct ZenzaiCache: Sendable {
@@ -57,6 +58,7 @@ extension Kana2Kanji {
         zenzaiCache: ZenzaiCache?,
         inferenceLimit: Int,
         requestRichCandidates: Bool,
+        personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?,
         versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
     ) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
         var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint([])
@@ -110,7 +112,13 @@ extension Kana2Kanji {
                 // If inference has already run the maximum number of times, just return the result at this point
                 return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
             }
-            let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate], requestRichCandidates: requestRichCandidates, versionDependentConfig: versionDependentConfig)
+            let reviewResult = zenz.candidateEvaluate(
+                convertTarget: inputData.convertTarget,
+                candidates: [candidate],
+                requestRichCandidates: requestRichCandidates,
+                personalizationMode: personalizationMode,
+                versionDependentConfig: versionDependentConfig
+            )
             inferenceLimit -= 1
             let nextAction = self.review(
                 candidateIndex: index,

@@ -1,5 +1,6 @@
 import Foundation
 import SwiftUtils
+import SwiftNGram

 @MainActor package final class Zenz {
     package var resourceURL: URL
@@ -31,12 +32,24 @@ import SwiftUtils
         try? self.zenzContext?.reset_context()
     }

-    func candidateEvaluate(convertTarget: String, candidates: [Candidate], requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> ZenzContext.CandidateEvaluationResult {
+    func candidateEvaluate(
+        convertTarget: String,
+        candidates: [Candidate],
+        requestRichCandidates: Bool,
+        personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?,
+        versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
+    ) -> ZenzContext.CandidateEvaluationResult {
         guard let zenzContext else {
             return .error
         }
         for candidate in candidates {
-            let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate, requestRichCandidates: requestRichCandidates, versionDependentConfig: versionDependentConfig)
+            let result = zenzContext.evaluate_candidate(
+                input: convertTarget.toKatakana(),
+                candidate: candidate,
+                requestRichCandidates: requestRichCandidates,
+                personalizationMode: personalizationMode,
+                versionDependentConfig: versionDependentConfig
+            )
             return result
         }
         return .error

@@ -7,6 +7,7 @@ import SwiftUtils
 import HeapModule
 import Algorithms
 import Foundation
+import SwiftNGram

 struct FixedSizeHeap<Element: Comparable> {
     private var size: Int
@@ -285,7 +286,14 @@ final class ZenzContext {
         return minHeap.unordered.sorted { $0.value > $1.value }.map { ($0.character, $0.value / exp_sum) }
     }

-    func evaluate_candidate(input: String, candidate: Candidate, requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
+    func evaluate_candidate(
+        input: String,
+        candidate: Candidate,
+        requestRichCandidates: Bool,
+        personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?,
+        versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
+    ) -> CandidateEvaluationResult {
         print("Evaluate", candidate)
         // For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
         // We assume \u{EE01}\(candidate) is always split into \u{EE01}_\(candidate) by the zenz-v1 tokenizer
         var userDictionaryPrompt: String = ""
@@ -381,7 +389,7 @@ final class ZenzContext {
         let n_vocab = llama_n_vocab(model)
         let is_learned_token: [(isLearned: Bool, priority: Float)] = Array(repeating: (false, 0), count: prompt_tokens.count) + candidate.data.flatMap {
             // priority is the character count → longer words take precedence
-            Array(repeating: ($0.metadata.contains(.isLearned), getLearningPriority(data: $0)), count: self.tokenize(text: $0.word, add_bos: false).count)
+            Array(repeating: ($0.metadata.contains(.isLearned), logf(getLearningPriority(data: $0))), count: self.tokenize(text: $0.word, add_bos: false).count)
         }

         var score: Float = 0
@@ -402,23 +410,49 @@ final class ZenzContext {
             // Check whether each token is the most probable one under the preceding prediction
             // Since softmax preserves the argmax, it suffices to pick the largest logit
             // In practice we also want the log-prob, so the softmax normalizer is computed explicitly here
-            struct TokenAndExpLogit: Comparable {
-                static func < (lhs: TokenAndExpLogit, rhs: TokenAndExpLogit) -> Bool {
-                    lhs.expLogit < rhs.expLogit
+            struct TokenAndLogprob: Comparable {
+                static func < (lhs: TokenAndLogprob, rhs: TokenAndLogprob) -> Bool {
+                    lhs.logprob < rhs.logprob
                 }

                 var token: llama_token
-                var expLogit: Float
+                var logprob: Float
             }
-            var exp_sum: Float = 0
+            var sumexp: Float = 0
             let startIndex = (i - 1 - startOffset) * Int(n_vocab)
             let endIndex = (i - startOffset) * Int(n_vocab)
-            var tokenHeap = FixedSizeHeap<TokenAndExpLogit>(size: requestRichCandidates ? 3 : 1)
+            var tokenHeap = FixedSizeHeap<TokenAndLogprob>(size: requestRichCandidates ? 3 : 1)
             for index in startIndex ..< endIndex {
-                let v = expf(logits[index])
-                exp_sum += v
-                tokenHeap.insertIfPossible(TokenAndExpLogit(token: llama_token(index - startIndex), expLogit: v))
+                sumexp += expf(logits[index])
             }
+            let logsumexp = logf(sumexp)
+
+            if let (mode, baseLM, personalLM) = personalizationMode {
+                let prefix = tokens[..<i].dropFirst(prompt_tokens.count).map(Int.init)
+                let baseProb: [Float]
+                let personalProb: [Float]
+                // SwiftNGram's LM fails on an empty context (unigram probabilities are not supported)
+                if !prefix.isEmpty {
+                    baseProb = baseLM.bulkPredict(prefix).map { logf(Float($0) + 1e-7) }
+                    personalProb = personalLM.bulkPredict(prefix).map { logf(Float($0) + 1e-7) }
+                } else {
+                    baseProb = Array(repeating: 0, count: Int(n_vocab))
+                    personalProb = baseProb
+                }
+                // p = probabilityBuffer / exp_sum
+                // p' = p / p_b * p_p
+                for (i, (lpb, lpp)) in zip(0 ..< Int(n_vocab), zip(baseProb, personalProb)) {
+                    let logp = logits[startIndex + i] - logsumexp
+                    let logp_ = logp + mode.alpha * (lpp - lpb) // personalized probability
+                    tokenHeap.insertIfPossible(TokenAndLogprob(token: llama_token(i), logprob: logp_))
+                }
+            } else {
+                // p = probabilityBuffer / exp_sum
+                for i in startIndex ..< endIndex {
+                    let logp = logits[i] - logsumexp
+                    tokenHeap.insertIfPossible(TokenAndLogprob(token: llama_token(i - startIndex), logprob: logp))
+                }
+            }

             guard let maxItem = tokenHeap.max else {
                 print("Max Item could not be found for unknown reason")
                 return .error
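The personalization branch computes, in log space, the reweighting sketched in the p' = p / p_b * p_p comment: the zenz probability of each candidate token is multiplied by the ratio of personal to base N-gram probabilities, raised to the strength alpha. Written out, with the 1e-7 smoothing constant from the code as epsilon:

    \log p'(t \mid c) = \bigl(\mathrm{logit}_t - \operatorname{logsumexp}\bigr) + \alpha \Bigl( \log\bigl(p_{\mathrm{personal}}(t \mid c) + \varepsilon\bigr) - \log\bigl(p_{\mathrm{base}}(t \mid c) + \varepsilon\bigr) \Bigr)

equivalently p'(t | c) ∝ p_zenz(t | c) · ((p_personal(t | c) + ε) / (p_base(t | c) + ε))^α. With α = 0 the personal LM has no effect, and the empty-prefix fallback sets both log terms to 0, which likewise leaves the zenz distribution untouched.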
@@ -436,9 +470,9 @@ final class ZenzContext {
                 let wholeResult = String(string.dropFirst(prompt.count))
                 return .wholeResult(wholeResult)
             } else {
-                let actual_exp: Float = expf(logits[startIndex + Int(token_id)])
+                let actual_logp: Float = logits[startIndex + Int(token_id)] - logsumexp
                 // If the token is learned and its log-prob is sufficiently large, prefer the learned token
-                let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > maxItem.expLogit
+                let preferLearnedToken = is_learned_token[i].isLearned && actual_logp + is_learned_token[i].priority > maxItem.logprob
                 if !preferLearnedToken {
                     // adding "\0"
                     let cchars = tokens[..<i].reduce(into: []) {
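The learned-token preference survives the move to log space unchanged in meaning: the old multiplicative test actual_exp * priority > maxItem.expLogit becomes the additive actual_logp + priority > maxItem.logprob, which is equivalent because the priorities are now stored as logf(getLearningPriority(data:)) (see the is_learned_token hunk above) and the shared logsumexp term cancels on both sides.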
@@ -458,12 +492,12 @@ final class ZenzContext {
                         AlternativeHighProbToken(
                             token: item.token,
                             constraint: prefix.map(UInt8.init) + token_to_piece(token: item.token).map(UInt8.init),
-                            probabilityRatioToMaxProb: item.expLogit / maxItem.expLogit
+                            probabilityRatioToMaxProb: expf(item.logprob - maxItem.logprob)
                         )
                     )
                 }
             }
-            score += logf(maxItem.expLogit) - logf(exp_sum)
+            score += maxItem.logprob
         }
         return .pass(score: score, alternativeConstraints: altTokens.unordered.sorted(by: >).map {.init(probabilityRatio: $0.probabilityRatioToMaxProb, prefixConstraint: $0.constraint)})
     }
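Likewise, probabilityRatioToMaxProb is preserved by the rewrite: expf(item.logprob - maxItem.logprob) equals the old item.expLogit / maxItem.expLogit when personalization is off, since the logsumexp terms cancel in the difference; with personalization on, the ratio reflects the reweighted distribution, which is presumably the intent.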