feat: support N-gram based personalization

Miwa / Ensan
2025-02-08 18:57:04 +09:00
parent b47aa2d90a
commit 36277064b4
10 changed files with 190 additions and 31 deletions

View File

@@ -19,6 +19,12 @@ extension Subcommands {
var configZenzaiInferenceLimit: Int = .max
@Flag(name: [.customLong("config_zenzai_ignore_left_context")], help: "ignore left_context")
var configZenzaiIgnoreLeftContext: Bool = false
@Option(name: [.customLong("config_zenzai_base_lm")], help: "Marisa files for Base LM.")
var configZenzaiBaseLM: String?
@Option(name: [.customLong("config_zenzai_personal_lm")], help: "Marisa files for Personal LM.")
var configZenzaiPersonalLM: String?
@Option(name: [.customLong("config_zenzai_personalization_alpha")], help: "Strength of personalization (0.5 by default)")
var configZenzaiPersonalizationAlpha: Float = 0.5
static let configuration = CommandConfiguration(commandName: "evaluate", abstract: "Evaluate quality of Conversion for input data.")
@@ -87,6 +93,20 @@ extension Subcommands {
}
func requestOptions(leftSideContext: String?) -> ConvertRequestOptions {
let personalizationMode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?
if let base = self.configZenzaiBaseLM, let personal = self.configZenzaiPersonalLM {
personalizationMode = .init(
baseNgramLanguageModel: base,
personalNgramLanguageModel: personal,
n: 5,
d: 0.75,
alpha: self.configZenzaiPersonalizationAlpha
)
} else if self.configZenzaiBaseLM != nil || self.configZenzaiPersonalLM != nil {
fatalError("Both --config_zenzai_base_lm and --config_zenzai_personal_lm must be set")
} else {
personalizationMode = nil
}
var option: ConvertRequestOptions = .withDefaultDictionary(
N_best: self.configNBest,
requireJapanesePrediction: false,
@@ -102,7 +122,7 @@ extension Subcommands {
shouldResetMemory: false,
memoryDirectoryURL: URL(fileURLWithPath: ""),
sharedContainerURL: URL(fileURLWithPath: ""),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit, versionDependentMode: .v2(.init(leftSideContext: self.configZenzaiIgnoreLeftContext ? nil : leftSideContext))),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit, personalizationMode: personalizationMode, versionDependentMode: .v2(.init(leftSideContext: self.configZenzaiIgnoreLeftContext ? nil : leftSideContext))),
metadata: .init(versionString: "anco for debugging")
)
option.requestQuery = .
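
Note: the personalizationMode construction above (including the fatalError fallback) is duplicated verbatim in the run and session subcommands below. A minimal sketch of how the validation could be centralized, assuming swift-argument-parser's ValidationError; the makePersonalizationMode helper is hypothetical and not part of this commit:

import ArgumentParser

extension Subcommands {
    // Hypothetical helper: maps the three CLI flags onto a PersonalizationMode,
    // rejecting the case where only one of the two LM paths is given.
    static func makePersonalizationMode(
        baseLM: String?,
        personalLM: String?,
        alpha: Float
    ) throws -> ConvertRequestOptions.ZenzaiMode.PersonalizationMode? {
        switch (baseLM, personalLM) {
        case (nil, nil):
            return nil
        case let (base?, personal?):
            // n and d fall back to the initializer defaults (5 and 0.75)
            return .init(baseNgramLanguageModel: base, personalNgramLanguageModel: personal, alpha: alpha)
        default:
            throw ValidationError("Both --config_zenzai_base_lm and --config_zenzai_personal_lm must be set")
        }
    }
}

Each subcommand could then call this from its validate() method, letting ArgumentParser print a proper usage error instead of crashing through fatalError.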

View File

@@ -38,7 +38,7 @@ extension Subcommands {
shouldResetMemory: false,
memoryDirectoryURL: URL(fileURLWithPath: ""),
sharedContainerURL: URL(fileURLWithPath: ""),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: .max, versionDependentMode: .v3(.init())),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: .max, personalizationMode: nil, versionDependentMode: .v3(.init())),
metadata: .init(versionString: "anco for debugging")
)
}

View File

@@ -15,6 +15,12 @@ extension Subcommands {
var zenzWeightPath: String = ""
@Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
var configZenzaiInferenceLimit: Int = .max
@Option(name: [.customLong("config_zenzai_base_lm")], help: "Marisa files for Base LM.")
var configZenzaiBaseLM: String?
@Option(name: [.customLong("config_zenzai_personal_lm")], help: "Marisa files for Personal LM.")
var configZenzaiPersonalLM: String?
@Option(name: [.customLong("config_zenzai_personalization_alpha")], help: "Strength of personalization (0.5 by default)")
var configZenzaiPersonalizationAlpha: Float = 0.5
@Flag(name: [.customLong("disable_prediction")], help: "Disable producing prediction candidates.")
var disablePrediction = false
@@ -55,6 +61,20 @@ extension Subcommands {
}
func requestOptions() -> ConvertRequestOptions {
let personalizationMode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?
if let base = self.configZenzaiBaseLM, let personal = self.configZenzaiPersonalLM {
personalizationMode = .init(
baseNgramLanguageModel: base,
personalNgramLanguageModel: personal,
n: 5,
d: 0.75,
alpha: self.configZenzaiPersonalizationAlpha
)
} else if self.configZenzaiBaseLM != nil || self.configZenzaiPersonalLM != nil {
fatalError("Both --config_zenzai_base_lm and --config_zenzai_personal_lm must be set")
} else {
personalizationMode = nil
}
var option: ConvertRequestOptions = .withDefaultDictionary(
N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
@@ -70,7 +90,7 @@ extension Subcommands {
shouldResetMemory: false,
memoryDirectoryURL: URL(fileURLWithPath: ""),
sharedContainerURL: URL(fileURLWithPath: ""),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit, personalizationMode: personalizationMode),
metadata: .init(versionString: "anco for debugging")
)
if self.onlyWholeConversion {

View File

@@ -35,6 +35,12 @@ extension Subcommands {
var zenzV1 = false
@Flag(name: [.customLong("zenz_v2")], help: "Use zenz_v2 model.")
var zenzV2 = false
@Option(name: [.customLong("config_zenzai_base_lm")], help: "Marisa files for Base LM.")
var configZenzaiBaseLM: String?
@Option(name: [.customLong("config_zenzai_personal_lm")], help: "Marisa files for Personal LM.")
var configZenzaiPersonalLM: String?
@Option(name: [.customLong("config_zenzai_personalization_alpha")], help: "Strength of personalization (0.5 by default)")
var configZenzaiPersonalizationAlpha: Float = 0.5
static let configuration = CommandConfiguration(commandName: "session", abstract: "Start session for incremental input.")
@@ -201,6 +207,20 @@ extension Subcommands {
} else {
.v3(.init(profile: self.configZenzaiProfile, topic: self.configZenzaiTopic, leftSideContext: leftSideContext))
}
let personalizationMode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?
if let base = self.configZenzaiBaseLM, let personal = self.configZenzaiPersonalLM {
personalizationMode = .init(
baseNgramLanguageModel: base,
personalNgramLanguageModel: personal,
n: 5,
d: 0.75,
alpha: self.configZenzaiPersonalizationAlpha
)
} else if self.configZenzaiBaseLM != nil || self.configZenzaiPersonalLM != nil {
fatalError("Both --config_zenzai_base_lm and --config_zenzai_personal_lm must be set")
} else {
personalizationMode = nil
}
var option: ConvertRequestOptions = .withDefaultDictionary(
N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
@@ -220,6 +240,7 @@ extension Subcommands {
weight: URL(string: self.zenzWeightPath)!,
inferenceLimit: self.configZenzaiInferenceLimit,
requestRichCandidates: self.configRequestRichCandidates,
personalizationMode: personalizationMode,
versionDependentMode: zenzaiVersionDependentMode
),
metadata: .init(versionString: "anco for debugging")

View File

@@ -224,6 +224,21 @@ public struct ConvertRequestOptions: Sendable {
}
public struct ZenzaiMode: Sendable, Equatable {
public struct PersonalizationMode: Sendable, Equatable {
public init(baseNgramLanguageModel: String, personalNgramLanguageModel: String, n: Int = 5, d: Double = 0.75, alpha: Float = 0.5) {
self.baseNgramLanguageModel = baseNgramLanguageModel
self.personalNgramLanguageModel = personalNgramLanguageModel
self.n = n
self.d = d
self.alpha = alpha
}
var n: Int = 5
var d: Double = 0.75
var alpha: Float = 0.5
var baseNgramLanguageModel: String
var personalNgramLanguageModel: String
}
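
The defaults n = 5 and d = 0.75 read like interpolated Kneser-Ney smoothing, where d is the absolute discount and 0.75 is the conventional value from Chen and Goodman. Assuming SwiftNGram's LM follows that scheme (the (n, d) parameterization suggests it, but the LM internals are not part of this diff), each order of the model would satisfy

p_{\mathrm{KN}}(t \mid c) = \frac{\max(\mathrm{count}(c,t) - d,\ 0)}{\sum_{t'} \mathrm{count}(c,t')} + \frac{d \cdot N_{1+}(c\,\cdot)}{\sum_{t'} \mathrm{count}(c,t')}\, p_{\mathrm{KN}}(t \mid \bar{c})

where N_{1+}(c ·) is the number of distinct tokens observed after context c and \bar{c} is the context with its oldest token dropped.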
public static let off = ZenzaiMode(
enabled: false,
weightURL: URL(fileURLWithPath: ""),
@@ -237,13 +252,15 @@ public struct ConvertRequestOptions: Sendable {
/// - weight: path for model weight (gguf)
/// - inferenceLimit: maximum number of inference steps. A smaller limit makes conversion faster, but quality will be worse. (Default: 10)
/// - requestRichCandidates: when this flag is true, the converter spends more time but generates richer N-best candidates for the candidate list view. Usually this option is not recommended for live conversion.
/// - personalizationMode: configuration for n-gram based personalization; pass nil to disable it.
/// - versionDependentMode: specify zenz model version and its configuration.
public static func on(weight: URL, inferenceLimit: Int = 10, requestRichCandidates: Bool = false, versionDependentMode: ZenzaiVersionDependentMode = .v3(.init())) -> Self {
public static func on(weight: URL, inferenceLimit: Int = 10, requestRichCandidates: Bool = false, personalizationMode: PersonalizationMode?, versionDependentMode: ZenzaiVersionDependentMode = .v3(.init())) -> Self {
ZenzaiMode(
enabled: true,
weightURL: weight,
inferenceLimit: inferenceLimit,
requestRichCandidates: requestRichCandidates,
personalizationMode: personalizationMode,
versionDependentMode: versionDependentMode
)
}
@@ -251,6 +268,7 @@ public struct ConvertRequestOptions: Sendable {
var weightURL: URL
var inferenceLimit: Int
var requestRichCandidates: Bool
var personalizationMode: PersonalizationMode?
var versionDependentMode: ZenzaiVersionDependentMode
}
}
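
A minimal usage sketch of the extended factory, with placeholder paths ("zenz.gguf", "lm_base", "lm_personal"); versionDependentMode keeps its .v3 default:

let zenzaiMode: ConvertRequestOptions.ZenzaiMode = .on(
    weight: URL(fileURLWithPath: "zenz.gguf"),
    inferenceLimit: 10,
    requestRichCandidates: false,
    personalizationMode: .init(
        baseNgramLanguageModel: "lm_base",         // path prefix of the base-LM marisa files
        personalNgramLanguageModel: "lm_personal", // path prefix of the personal-LM marisa files
        alpha: 0.5  // 0 leaves the zenz distribution untouched; larger values trust the personal LM more
    )
)

Since personalizationMode has no default value, every existing .on(...) call site must be updated, which is why the anco commands above now pass personalizationMode: nil when the feature is unused.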

View File

@@ -8,6 +8,7 @@
import Foundation
import SwiftUtils
import SwiftNGram
/// A class that manages kana-kanji conversion.
@MainActor public final class KanaKanjiConverter {
@@ -25,14 +26,16 @@ import SwiftUtils
private var nodes: [[LatticeNode]] = []
private var completedData: Candidate?
private var lastData: DicdataElement?
/// The zenz-v1 model used by Zenzai
/// The zenz model used by Zenzai
private var zenz: Zenz? = nil
private var zenzaiCache: Kana2Kanji.ZenzaiCache? = nil
private var zenzaiPersonalization: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?
public private(set) var zenzStatus: String = ""
/// Stops the current composition session and resets the converter's cached state.
public func stopComposition() {
self.zenz?.endSession()
self.zenzaiPersonalization = nil
self.zenzaiCache = nil
self.previousInputData = nil
self.nodes = []
@@ -40,6 +43,20 @@ import SwiftUtils
self.lastData = nil
}
private func getZenzaiPersonalization(mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?) -> (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)? {
guard let mode else {
return nil
}
if let zenzaiPersonalization, zenzaiPersonalization.mode == mode {
return zenzaiPersonalization
}
let tokenizer = ZenzTokenizer()
let baseModel = LM(baseFilename: mode.baseNgramLanguageModel, n: mode.n, d: mode.d, tokenizer: tokenizer)
let personalModel = LM(baseFilename: mode.personalNgramLanguageModel, n: mode.n, d: mode.d, tokenizer: tokenizer)
self.zenzaiPersonalization = (mode, baseModel, personalModel)
return (mode, baseModel, personalModel)
}
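
Loading the marisa-backed models is not free, so the (base, personal) pair is cached and keyed on PersonalizationMode equality (hence Equatable), then dropped again in stopComposition(). A sketch of the SwiftNGram calls involved, using only the API visible in this diff and placeholder paths:

import SwiftNGram

let tokenizer = ZenzTokenizer()
// Both models must share n, d, and the tokenizer so their token IDs line up
let base = LM(baseFilename: "lm_base", n: 5, d: 0.75, tokenizer: tokenizer)
let personal = LM(baseFilename: "lm_personal", n: 5, d: 0.75, tokenizer: tokenizer)
// bulkPredict takes a token-ID prefix and returns one probability per vocabulary entry
let prefix: [Int] = [123, 456]  // hypothetical token IDs
let baseRow = base.bulkPredict(prefix)
let personalRow = personal.bulkPredict(prefix)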
package func getModel(modelURL: URL) -> Zenz? {
if let model = self.zenz, model.resourceURL == modelURL {
self.zenzStatus = "load \(modelURL.absoluteString)"
@@ -594,6 +611,7 @@ import SwiftUtils
return nil
}
print("ConvertToLattice ", zenzaiMode)
// FIXME: enable cache based zenzai
if zenzaiMode.enabled, let model = self.getModel(modelURL: zenzaiMode.weightURL) {
let (result, nodes, cache) = self.converter.all_zenzai(
@@ -602,6 +620,7 @@ import SwiftUtils
zenzaiCache: self.zenzaiCache,
inferenceLimit: zenzaiMode.inferenceLimit,
requestRichCandidates: zenzaiMode.requestRichCandidates,
personalizationMode: self.getZenzaiPersonalization(mode: zenzaiMode.personalizationMode),
versionDependentConfig: zenzaiMode.versionDependentMode
)
self.zenzaiCache = cache

View File

@@ -1,5 +1,6 @@
import Foundation
import SwiftUtils
import SwiftNGram
extension Kana2Kanji {
struct ZenzaiCache: Sendable {
@@ -57,6 +58,7 @@ extension Kana2Kanji {
zenzaiCache: ZenzaiCache?,
inferenceLimit: Int,
requestRichCandidates: Bool,
personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?,
versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint([])
@@ -110,7 +112,13 @@ extension Kana2Kanji {
// When inference occurs more than maximum times, then just return result at this point
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
}
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate], requestRichCandidates: requestRichCandidates, versionDependentConfig: versionDependentConfig)
let reviewResult = zenz.candidateEvaluate(
convertTarget: inputData.convertTarget,
candidates: [candidate],
requestRichCandidates: requestRichCandidates,
personalizationMode: personalizationMode,
versionDependentConfig: versionDependentConfig
)
inferenceLimit -= 1
let nextAction = self.review(
candidateIndex: index,

View File

@@ -1,5 +1,6 @@
import Foundation
import SwiftUtils
import SwiftNGram
@MainActor package final class Zenz {
package var resourceURL: URL
@@ -31,12 +32,24 @@ import SwiftUtils
try? self.zenzContext?.reset_context()
}
func candidateEvaluate(convertTarget: String, candidates: [Candidate], requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> ZenzContext.CandidateEvaluationResult {
func candidateEvaluate(
convertTarget: String,
candidates: [Candidate],
requestRichCandidates: Bool,
personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?,
versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
) -> ZenzContext.CandidateEvaluationResult {
guard let zenzContext else {
return .error
}
for candidate in candidates {
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate, requestRichCandidates: requestRichCandidates, versionDependentConfig: versionDependentConfig)
let result = zenzContext.evaluate_candidate(
input: convertTarget.toKatakana(),
candidate: candidate,
requestRichCandidates: requestRichCandidates,
personalizationMode: personalizationMode,
versionDependentConfig: versionDependentConfig
)
return result
}
return .error

View File

@@ -7,6 +7,7 @@ import SwiftUtils
import HeapModule
import Algorithms
import Foundation
import SwiftNGram
struct FixedSizeHeap<Element: Comparable> {
private var size: Int
@@ -285,7 +286,14 @@ final class ZenzContext {
return minHeap.unordered.sorted { $0.value > $1.value }.map { ($0.character, $0.value / exp_sum) }
}
func evaluate_candidate(input: String, candidate: Candidate, requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
func evaluate_candidate(
input: String,
candidate: Candidate,
requestRichCandidates: Bool,
personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?,
versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
) -> CandidateEvaluationResult {
print("Evaluate", candidate)
// For the zenz-v1 model, \u{EE00} is the token marking 'start of query' and \u{EE01} the token marking 'start of answer'
// We assume \u{EE01}\(candidate) is always split into \u{EE01}_\(candidate) by the zenz-v1 tokenizer
var userDictionaryPrompt: String = ""
@@ -381,7 +389,7 @@ final class ZenzContext {
let n_vocab = llama_n_vocab(model)
let is_learned_token: [(isLearned: Bool, priority: Float)] = Array(repeating: (false, 0), count: prompt_tokens.count) + candidate.data.flatMap {
// priority
Array(repeating: ($0.metadata.contains(.isLearned), getLearningPriority(data: $0)), count: self.tokenize(text: $0.word, add_bos: false).count)
Array(repeating: ($0.metadata.contains(.isLearned), logf(getLearningPriority(data: $0))), count: self.tokenize(text: $0.word, add_bos: false).count)
}
var score: Float = 0
@@ -402,23 +410,49 @@ final class ZenzContext {
// Compute log-probabilities from the raw logits via softmax:
// log p = logit - logsumexp(logits)
struct TokenAndExpLogit: Comparable {
static func < (lhs: TokenAndExpLogit, rhs: TokenAndExpLogit) -> Bool {
lhs.expLogit < rhs.expLogit
struct TokenAndLogprob: Comparable {
static func < (lhs: TokenAndLogprob, rhs: TokenAndLogprob) -> Bool {
lhs.logprob < rhs.logprob
}
var token: llama_token
var expLogit: Float
var logprob: Float
}
var exp_sum: Float = 0
var sumexp: Float = 0
let startIndex = (i - 1 - startOffset) * Int(n_vocab)
let endIndex = (i - startOffset) * Int(n_vocab)
var tokenHeap = FixedSizeHeap<TokenAndExpLogit>(size: requestRichCandidates ? 3 : 1)
var tokenHeap = FixedSizeHeap<TokenAndLogprob>(size: requestRichCandidates ? 3 : 1)
for index in startIndex ..< endIndex {
let v = expf(logits[index])
exp_sum += v
tokenHeap.insertIfPossible(TokenAndExpLogit(token: llama_token(index - startIndex), expLogit: v))
sumexp += expf(logits[index])
}
let logsumexp = logf(sumexp)
if let (mode, baseLM, personalLM) = personalizationMode {
let prefix = tokens[..<i].dropFirst(prompt_tokens.count).map(Int.init)
let baseProb: [Float]
let personalProb: [Float]
// SwiftNGram's LM cannot predict from an empty prefix (the unigram case), so fall back to a zero bias
if !prefix.isEmpty {
baseProb = baseLM.bulkPredict(prefix).map { logf(Float($0) + 1e-7) }
personalProb = personalLM.bulkPredict(prefix).map { logf(Float($0) + 1e-7) }
} else {
baseProb = Array(repeating: 0, count: Int(n_vocab))
personalProb = baseProb
}
// log p = logit - logsumexp
// log p' = log p + alpha * (log p_personal - log p_base)
for (i, (lpb, lpp)) in zip(0 ..< Int(n_vocab), zip(baseProb, personalProb)) {
let logp = logits[startIndex + i] - logsumexp
let logp_ = logp + mode.alpha * (lpp - lpb) // personalized probability
tokenHeap.insertIfPossible(TokenAndLogprob(token: llama_token(i), logprob: logp_))
}
} else {
// log p = logit - logsumexp
for i in startIndex ..< endIndex {
let logp = logits[i] - logsumexp
tokenHeap.insertIfPossible(TokenAndLogprob(token: llama_token(i - startIndex), logprob: logp))
}
}
guard let maxItem = tokenHeap.max else {
print("Max Item could not be found for unknown reason")
return .error
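
The rewritten loop above moves scoring from unnormalized exp space (expLogit, exp_sum) into log space (logprob = logit - logsumexp), which simplifies the personalization arithmetic and lets the learned-token bonus, now stored as logf(getLearningPriority(...)), be added instead of multiplied. In probability space the personalized branch implements

p'(t \mid c) \;\propto\; p_{\mathrm{zenz}}(t \mid c) \left( \frac{p_{\mathrm{personal}}(t \mid c)}{p_{\mathrm{base}}(t \mid c)} \right)^{\alpha}

with alpha = 0.5 by default: tokens the personal LM prefers relative to the base corpus are up-weighted, alpha = 0 recovers plain zenz scoring, and the 1e-7 floor inside the logs keeps them finite for unseen n-grams.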
@@ -436,9 +470,9 @@ final class ZenzContext {
let wholeResult = String(string.dropFirst(prompt.count))
return .wholeResult(wholeResult)
} else {
let actual_exp: Float = expf(logits[startIndex + Int(token_id)])
let actual_logp: Float = logits[startIndex + Int(token_id)] - logsumexp
// prefer the learned token when its log-probability plus the learning bonus beats the current best token
let preferLearnedToken = is_learned_token[i].isLearned && actual_exp * is_learned_token[i].priority > maxItem.expLogit
let preferLearnedToken = is_learned_token[i].isLearned && actual_logp + is_learned_token[i].priority > maxItem.logprob
if !preferLearnedToken {
// rebuild the accepted prefix as C chars, adding "\0" at the end
let cchars = tokens[..<i].reduce(into: []) {
@@ -458,12 +492,12 @@ final class ZenzContext {
AlternativeHighProbToken(
token: item.token,
constraint: prefix.map(UInt8.init) + token_to_piece(token: item.token).map(UInt8.init),
probabilityRatioToMaxProb: item.expLogit / maxItem.expLogit
probabilityRatioToMaxProb: expf(item.logprob - maxItem.logprob)
)
)
}
}
score += logf(maxItem.expLogit) - logf(exp_sum)
score += maxItem.logprob
}
return .pass(score: score, alternativeConstraints: altTokens.unordered.sorted(by: >).map {.init(probabilityRatio: $0.probabilityRatioToMaxProb, prefixConstraint: $0.constraint)})
}