mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-08-22 15:05:26 +00:00
feat: zenz-v2の文脈による条件づけ機能を活かしたAPIを追加
This commit is contained in:
@ -25,7 +25,10 @@ extension Subcommands {
|
||||
var roman2kana = false
|
||||
@Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
|
||||
var configZenzaiInferenceLimit: Int = .max
|
||||
|
||||
@Option(name: [.customLong("config_profile")], help: "enable profile prompting for zenz-v2.")
|
||||
var configZenzV2Profile: String? = nil
|
||||
@Flag(name: [.customLong("zenz_v1")], help: "Use zenz_v1 model.")
|
||||
var zenzV1 = false
|
||||
|
||||
static var configuration = CommandConfiguration(commandName: "session", abstract: "Start session for incremental input.")
|
||||
|
||||
@ -44,6 +47,9 @@ extension Subcommands {
|
||||
}
|
||||
|
||||
@MainActor mutating func run() async {
|
||||
if self.zenzV1 {
|
||||
print("\(bold: "We strongly recommend to use zenz-v2 models")")
|
||||
}
|
||||
let memoryDirector = if self.enableLearning {
|
||||
if let dir = self.getTemporaryDirectory() {
|
||||
dir
|
||||
@ -58,10 +64,14 @@ extension Subcommands {
|
||||
var composingText = ComposingText()
|
||||
let inputStyle: InputStyle = self.roman2kana ? .roman2kana : .direct
|
||||
var lastCandidates: [Candidate] = []
|
||||
var leftSideContext: String = ""
|
||||
var page = 0
|
||||
while true {
|
||||
print()
|
||||
print("\(bold: "== Type :q to end session, type :d to delete character, type :c to stop composition. For other commands, type :h ==")")
|
||||
if !leftSideContext.isEmpty {
|
||||
print("\(bold: "Current Left-Side Context"): \(leftSideContext)")
|
||||
}
|
||||
var input = readLine(strippingNewline: true) ?? ""
|
||||
switch input {
|
||||
case ":q":
|
||||
@ -74,6 +84,7 @@ extension Subcommands {
|
||||
// クリア
|
||||
composingText.stopComposition()
|
||||
converter.stopComposition()
|
||||
leftSideContext = ""
|
||||
print("composition is stopped")
|
||||
continue
|
||||
case ":n":
|
||||
@ -119,6 +130,7 @@ extension Subcommands {
|
||||
composingText.stopComposition()
|
||||
converter.stopComposition()
|
||||
}
|
||||
leftSideContext += candidate.text
|
||||
} else {
|
||||
input = String(input.map { (c: Character) -> Character in
|
||||
[
|
||||
@ -132,7 +144,7 @@ extension Subcommands {
|
||||
}
|
||||
print(composingText.convertTarget)
|
||||
let start = Date()
|
||||
let result = converter.requestCandidates(composingText, options: requestOptions(memoryDirector: memoryDirector))
|
||||
let result = converter.requestCandidates(composingText, options: requestOptions(memoryDirector: memoryDirector, leftSideContext: leftSideContext))
|
||||
let mainResults = result.mainResults.filter {
|
||||
!self.onlyWholeConversion || $0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == input.toKatakana()
|
||||
}
|
||||
@ -159,7 +171,12 @@ extension Subcommands {
|
||||
}
|
||||
}
|
||||
|
||||
func requestOptions(memoryDirector: URL) -> ConvertRequestOptions {
|
||||
func requestOptions(memoryDirector: URL, leftSideContext: String) -> ConvertRequestOptions {
|
||||
let zenzaiVersionDependentMode: ConvertRequestOptions.ZenzaiVersionDependentMode = if self.zenzV1 {
|
||||
.v1
|
||||
} else {
|
||||
.v2(.init(profile: self.configZenzV2Profile, leftSideContext: leftSideContext))
|
||||
}
|
||||
var option: ConvertRequestOptions = .withDefaultDictionary(
|
||||
N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
|
||||
requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
|
||||
@ -175,7 +192,11 @@ extension Subcommands {
|
||||
shouldResetMemory: false,
|
||||
memoryDirectoryURL: memoryDirector,
|
||||
sharedContainerURL: URL(fileURLWithPath: ""),
|
||||
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit),
|
||||
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(
|
||||
weight: URL(string: self.zenzWeightPath)!,
|
||||
inferenceLimit: self.configZenzaiInferenceLimit,
|
||||
versionDependentMode: zenzaiVersionDependentMode
|
||||
),
|
||||
metadata: .init(versionString: "anco for debugging")
|
||||
)
|
||||
if self.onlyWholeConversion {
|
||||
|
@ -142,18 +142,50 @@ public struct ConvertRequestOptions: Sendable {
|
||||
case 完全一致
|
||||
}
|
||||
|
||||
public struct ZenzaiV2DependentMode: Sendable, Equatable, Hashable {
|
||||
public init(profile: String? = nil, leftSideContext: String? = nil) {
|
||||
self.profile = profile
|
||||
self.leftSideContext = leftSideContext
|
||||
}
|
||||
|
||||
/// プロフィールコンテクストを設定した場合、プロフィールを反映したプロンプトが自動的に付与されます。プロフィールは10〜20文字程度の長さにとどめることを推奨します。
|
||||
public var profile: String?
|
||||
/// 左側の文字列を文脈として与えます。
|
||||
public var leftSideContext: String?
|
||||
}
|
||||
|
||||
public enum ZenzVersion: Sendable, Equatable, Hashable {
|
||||
case v1
|
||||
case v2
|
||||
}
|
||||
|
||||
public enum ZenzaiVersionDependentMode: Sendable, Equatable, Hashable {
|
||||
case v1
|
||||
case v2(ZenzaiV2DependentMode)
|
||||
|
||||
public var version: ZenzVersion {
|
||||
switch self {
|
||||
case .v1:
|
||||
return .v1
|
||||
case .v2(_):
|
||||
return .v2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public struct ZenzaiMode: Sendable, Equatable {
|
||||
public static let off = ZenzaiMode(enabled: false, weightURL: URL(fileURLWithPath: ""), inferenceLimit: 10)
|
||||
public static let off = ZenzaiMode(enabled: false, weightURL: URL(fileURLWithPath: ""), inferenceLimit: 10, versionDependentMode: .v2(.init()))
|
||||
|
||||
/// activate *Zenzai* - Neural Kana-Kanji Conversiion Engine
|
||||
/// - Parameters:
|
||||
/// - weight: path for model weight (gguf)
|
||||
/// - inferenceLimit: applying inference count limitation. Smaller limit makes conversion faster but quality will be worse. (Default: 10)
|
||||
public static func on(weight: URL, inferenceLimit: Int = 10) -> Self {
|
||||
ZenzaiMode(enabled: true, weightURL: weight, inferenceLimit: inferenceLimit)
|
||||
public static func on(weight: URL, inferenceLimit: Int = 10, versionDependentMode: ZenzaiVersionDependentMode = .v2(.init())) -> Self {
|
||||
ZenzaiMode(enabled: true, weightURL: weight, inferenceLimit: inferenceLimit, versionDependentMode: versionDependentMode)
|
||||
}
|
||||
var enabled: Bool
|
||||
var weightURL: URL
|
||||
var inferenceLimit: Int
|
||||
var versionDependentMode: ZenzaiVersionDependentMode
|
||||
}
|
||||
}
|
||||
|
@ -568,7 +568,7 @@ import SwiftUtils
|
||||
|
||||
// FIXME: enable cache based zenzai
|
||||
if zenzaiMode.enabled, let model = self.getModel(modelURL: zenzaiMode.weightURL) {
|
||||
let (result, nodes, cache) = self.converter.all_zenzai(inputData, zenz: model, zenzaiCache: self.zenzaiCache, inferenceLimit: zenzaiMode.inferenceLimit)
|
||||
let (result, nodes, cache) = self.converter.all_zenzai(inputData, zenz: model, zenzaiCache: self.zenzaiCache, inferenceLimit: zenzaiMode.inferenceLimit, versionDependentConfig: zenzaiMode.versionDependentMode)
|
||||
self.zenzaiCache = cache
|
||||
self.previousInputData = inputData
|
||||
return (result, nodes)
|
||||
|
@ -47,7 +47,13 @@ extension Kana2Kanji {
|
||||
}
|
||||
|
||||
/// zenzaiシステムによる完全変換。
|
||||
@MainActor func all_zenzai(_ inputData: ComposingText, zenz: Zenz, zenzaiCache: ZenzaiCache?, inferenceLimit: Int) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
|
||||
@MainActor func all_zenzai(
|
||||
_ inputData: ComposingText,
|
||||
zenz: Zenz,
|
||||
zenzaiCache: ZenzaiCache?,
|
||||
inferenceLimit: Int,
|
||||
versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
|
||||
) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
|
||||
var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint([])
|
||||
print("initial constraint", constraint)
|
||||
let eosNode = LatticeNode.EOSNode
|
||||
@ -85,7 +91,7 @@ extension Kana2Kanji {
|
||||
// When inference occurs more than maximum times, then just return result at this point
|
||||
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
|
||||
}
|
||||
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate])
|
||||
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate], versionDependentConfig: versionDependentConfig)
|
||||
inferenceLimit -= 1
|
||||
let nextAction = self.review(
|
||||
candidateIndex: index,
|
||||
|
@ -30,12 +30,12 @@ import SwiftUtils
|
||||
try? self.zenzContext?.reset_context()
|
||||
}
|
||||
|
||||
func candidateEvaluate(convertTarget: String, candidates: [Candidate]) -> ZenzContext.CandidateEvaluationResult {
|
||||
func candidateEvaluate(convertTarget: String, candidates: [Candidate], versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> ZenzContext.CandidateEvaluationResult {
|
||||
guard let zenzContext else {
|
||||
return .error
|
||||
}
|
||||
for candidate in candidates {
|
||||
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate)
|
||||
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate, versionDependentConfig: versionDependentConfig)
|
||||
return result
|
||||
}
|
||||
return .error
|
||||
|
@ -139,11 +139,29 @@ class ZenzContext {
|
||||
}
|
||||
}
|
||||
|
||||
func evaluate_candidate(input: String, candidate: Candidate) -> CandidateEvaluationResult {
|
||||
func evaluate_candidate(input: String, candidate: Candidate, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
|
||||
print("Evaluate", candidate)
|
||||
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
|
||||
// We assume \u{EE01}\(candidate) is always splitted into \u{EE01}_\(candidate) by zenz-v1 tokenizer
|
||||
let prompt = "\u{EE00}\(input)\u{EE01}"
|
||||
let prompt: String
|
||||
if case .v2(let mode) = versionDependentConfig {
|
||||
if let leftSideContext = mode.leftSideContext, !leftSideContext.isEmpty {
|
||||
let lsContext = leftSideContext.suffix(40)
|
||||
if let profile = mode.profile, !profile.isEmpty {
|
||||
let pf = profile.suffix(25)
|
||||
prompt = "\u{EE00}\(input)\u{EE02}プロフィール:\(pf)・発言:\(lsContext)\u{EE01}"
|
||||
} else {
|
||||
prompt = "\u{EE00}\(input)\u{EE02}\(lsContext)\u{EE01}"
|
||||
}
|
||||
} else if let profile = mode.profile, !profile.isEmpty {
|
||||
let pf = profile.suffix(25)
|
||||
prompt = "\u{EE00}\(input)\u{EE02}プロフィール:\(pf)・発言:\u{EE01}"
|
||||
} else {
|
||||
prompt = "\u{EE00}\(input)\u{EE01}"
|
||||
}
|
||||
} else {
|
||||
prompt = "\u{EE00}\(input)\u{EE01}"
|
||||
}
|
||||
// Therefore, tokens = prompt_tokens + candidate_tokens is an appropriate operation.
|
||||
let prompt_tokens = self.tokenize(text: prompt, add_bos: true, add_eos: false)
|
||||
let candidate_tokens = self.tokenize(text: candidate.text, add_bos: false, add_eos: false)
|
||||
|
Reference in New Issue
Block a user