mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-08-22 15:05:26 +00:00
feat: zenz-v2の文脈による条件づけ機能を活かしたAPIを追加
This commit is contained in:
@ -25,7 +25,10 @@ extension Subcommands {
|
|||||||
var roman2kana = false
|
var roman2kana = false
|
||||||
@Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
|
@Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
|
||||||
var configZenzaiInferenceLimit: Int = .max
|
var configZenzaiInferenceLimit: Int = .max
|
||||||
|
@Option(name: [.customLong("config_profile")], help: "enable profile prompting for zenz-v2.")
|
||||||
|
var configZenzV2Profile: String? = nil
|
||||||
|
@Flag(name: [.customLong("zenz_v1")], help: "Use zenz_v1 model.")
|
||||||
|
var zenzV1 = false
|
||||||
|
|
||||||
static var configuration = CommandConfiguration(commandName: "session", abstract: "Start session for incremental input.")
|
static var configuration = CommandConfiguration(commandName: "session", abstract: "Start session for incremental input.")
|
||||||
|
|
||||||
@ -44,6 +47,9 @@ extension Subcommands {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@MainActor mutating func run() async {
|
@MainActor mutating func run() async {
|
||||||
|
if self.zenzV1 {
|
||||||
|
print("\(bold: "We strongly recommend to use zenz-v2 models")")
|
||||||
|
}
|
||||||
let memoryDirector = if self.enableLearning {
|
let memoryDirector = if self.enableLearning {
|
||||||
if let dir = self.getTemporaryDirectory() {
|
if let dir = self.getTemporaryDirectory() {
|
||||||
dir
|
dir
|
||||||
@ -58,10 +64,14 @@ extension Subcommands {
|
|||||||
var composingText = ComposingText()
|
var composingText = ComposingText()
|
||||||
let inputStyle: InputStyle = self.roman2kana ? .roman2kana : .direct
|
let inputStyle: InputStyle = self.roman2kana ? .roman2kana : .direct
|
||||||
var lastCandidates: [Candidate] = []
|
var lastCandidates: [Candidate] = []
|
||||||
|
var leftSideContext: String = ""
|
||||||
var page = 0
|
var page = 0
|
||||||
while true {
|
while true {
|
||||||
print()
|
print()
|
||||||
print("\(bold: "== Type :q to end session, type :d to delete character, type :c to stop composition. For other commands, type :h ==")")
|
print("\(bold: "== Type :q to end session, type :d to delete character, type :c to stop composition. For other commands, type :h ==")")
|
||||||
|
if !leftSideContext.isEmpty {
|
||||||
|
print("\(bold: "Current Left-Side Context"): \(leftSideContext)")
|
||||||
|
}
|
||||||
var input = readLine(strippingNewline: true) ?? ""
|
var input = readLine(strippingNewline: true) ?? ""
|
||||||
switch input {
|
switch input {
|
||||||
case ":q":
|
case ":q":
|
||||||
@ -74,6 +84,7 @@ extension Subcommands {
|
|||||||
// クリア
|
// クリア
|
||||||
composingText.stopComposition()
|
composingText.stopComposition()
|
||||||
converter.stopComposition()
|
converter.stopComposition()
|
||||||
|
leftSideContext = ""
|
||||||
print("composition is stopped")
|
print("composition is stopped")
|
||||||
continue
|
continue
|
||||||
case ":n":
|
case ":n":
|
||||||
@ -119,6 +130,7 @@ extension Subcommands {
|
|||||||
composingText.stopComposition()
|
composingText.stopComposition()
|
||||||
converter.stopComposition()
|
converter.stopComposition()
|
||||||
}
|
}
|
||||||
|
leftSideContext += candidate.text
|
||||||
} else {
|
} else {
|
||||||
input = String(input.map { (c: Character) -> Character in
|
input = String(input.map { (c: Character) -> Character in
|
||||||
[
|
[
|
||||||
@ -132,7 +144,7 @@ extension Subcommands {
|
|||||||
}
|
}
|
||||||
print(composingText.convertTarget)
|
print(composingText.convertTarget)
|
||||||
let start = Date()
|
let start = Date()
|
||||||
let result = converter.requestCandidates(composingText, options: requestOptions(memoryDirector: memoryDirector))
|
let result = converter.requestCandidates(composingText, options: requestOptions(memoryDirector: memoryDirector, leftSideContext: leftSideContext))
|
||||||
let mainResults = result.mainResults.filter {
|
let mainResults = result.mainResults.filter {
|
||||||
!self.onlyWholeConversion || $0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == input.toKatakana()
|
!self.onlyWholeConversion || $0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == input.toKatakana()
|
||||||
}
|
}
|
||||||
@ -159,7 +171,12 @@ extension Subcommands {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestOptions(memoryDirector: URL) -> ConvertRequestOptions {
|
func requestOptions(memoryDirector: URL, leftSideContext: String) -> ConvertRequestOptions {
|
||||||
|
let zenzaiVersionDependentMode: ConvertRequestOptions.ZenzaiVersionDependentMode = if self.zenzV1 {
|
||||||
|
.v1
|
||||||
|
} else {
|
||||||
|
.v2(.init(profile: self.configZenzV2Profile, leftSideContext: leftSideContext))
|
||||||
|
}
|
||||||
var option: ConvertRequestOptions = .withDefaultDictionary(
|
var option: ConvertRequestOptions = .withDefaultDictionary(
|
||||||
N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
|
N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
|
||||||
requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
|
requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
|
||||||
@ -175,7 +192,11 @@ extension Subcommands {
|
|||||||
shouldResetMemory: false,
|
shouldResetMemory: false,
|
||||||
memoryDirectoryURL: memoryDirector,
|
memoryDirectoryURL: memoryDirector,
|
||||||
sharedContainerURL: URL(fileURLWithPath: ""),
|
sharedContainerURL: URL(fileURLWithPath: ""),
|
||||||
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit),
|
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(
|
||||||
|
weight: URL(string: self.zenzWeightPath)!,
|
||||||
|
inferenceLimit: self.configZenzaiInferenceLimit,
|
||||||
|
versionDependentMode: zenzaiVersionDependentMode
|
||||||
|
),
|
||||||
metadata: .init(versionString: "anco for debugging")
|
metadata: .init(versionString: "anco for debugging")
|
||||||
)
|
)
|
||||||
if self.onlyWholeConversion {
|
if self.onlyWholeConversion {
|
||||||
|
@ -142,18 +142,50 @@ public struct ConvertRequestOptions: Sendable {
|
|||||||
case 完全一致
|
case 完全一致
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public struct ZenzaiV2DependentMode: Sendable, Equatable, Hashable {
|
||||||
|
public init(profile: String? = nil, leftSideContext: String? = nil) {
|
||||||
|
self.profile = profile
|
||||||
|
self.leftSideContext = leftSideContext
|
||||||
|
}
|
||||||
|
|
||||||
|
/// プロフィールコンテクストを設定した場合、プロフィールを反映したプロンプトが自動的に付与されます。プロフィールは10〜20文字程度の長さにとどめることを推奨します。
|
||||||
|
public var profile: String?
|
||||||
|
/// 左側の文字列を文脈として与えます。
|
||||||
|
public var leftSideContext: String?
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum ZenzVersion: Sendable, Equatable, Hashable {
|
||||||
|
case v1
|
||||||
|
case v2
|
||||||
|
}
|
||||||
|
|
||||||
|
public enum ZenzaiVersionDependentMode: Sendable, Equatable, Hashable {
|
||||||
|
case v1
|
||||||
|
case v2(ZenzaiV2DependentMode)
|
||||||
|
|
||||||
|
public var version: ZenzVersion {
|
||||||
|
switch self {
|
||||||
|
case .v1:
|
||||||
|
return .v1
|
||||||
|
case .v2(_):
|
||||||
|
return .v2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public struct ZenzaiMode: Sendable, Equatable {
|
public struct ZenzaiMode: Sendable, Equatable {
|
||||||
public static let off = ZenzaiMode(enabled: false, weightURL: URL(fileURLWithPath: ""), inferenceLimit: 10)
|
public static let off = ZenzaiMode(enabled: false, weightURL: URL(fileURLWithPath: ""), inferenceLimit: 10, versionDependentMode: .v2(.init()))
|
||||||
|
|
||||||
/// activate *Zenzai* - Neural Kana-Kanji Conversiion Engine
|
/// activate *Zenzai* - Neural Kana-Kanji Conversiion Engine
|
||||||
/// - Parameters:
|
/// - Parameters:
|
||||||
/// - weight: path for model weight (gguf)
|
/// - weight: path for model weight (gguf)
|
||||||
/// - inferenceLimit: applying inference count limitation. Smaller limit makes conversion faster but quality will be worse. (Default: 10)
|
/// - inferenceLimit: applying inference count limitation. Smaller limit makes conversion faster but quality will be worse. (Default: 10)
|
||||||
public static func on(weight: URL, inferenceLimit: Int = 10) -> Self {
|
public static func on(weight: URL, inferenceLimit: Int = 10, versionDependentMode: ZenzaiVersionDependentMode = .v2(.init())) -> Self {
|
||||||
ZenzaiMode(enabled: true, weightURL: weight, inferenceLimit: inferenceLimit)
|
ZenzaiMode(enabled: true, weightURL: weight, inferenceLimit: inferenceLimit, versionDependentMode: versionDependentMode)
|
||||||
}
|
}
|
||||||
var enabled: Bool
|
var enabled: Bool
|
||||||
var weightURL: URL
|
var weightURL: URL
|
||||||
var inferenceLimit: Int
|
var inferenceLimit: Int
|
||||||
|
var versionDependentMode: ZenzaiVersionDependentMode
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -568,7 +568,7 @@ import SwiftUtils
|
|||||||
|
|
||||||
// FIXME: enable cache based zenzai
|
// FIXME: enable cache based zenzai
|
||||||
if zenzaiMode.enabled, let model = self.getModel(modelURL: zenzaiMode.weightURL) {
|
if zenzaiMode.enabled, let model = self.getModel(modelURL: zenzaiMode.weightURL) {
|
||||||
let (result, nodes, cache) = self.converter.all_zenzai(inputData, zenz: model, zenzaiCache: self.zenzaiCache, inferenceLimit: zenzaiMode.inferenceLimit)
|
let (result, nodes, cache) = self.converter.all_zenzai(inputData, zenz: model, zenzaiCache: self.zenzaiCache, inferenceLimit: zenzaiMode.inferenceLimit, versionDependentConfig: zenzaiMode.versionDependentMode)
|
||||||
self.zenzaiCache = cache
|
self.zenzaiCache = cache
|
||||||
self.previousInputData = inputData
|
self.previousInputData = inputData
|
||||||
return (result, nodes)
|
return (result, nodes)
|
||||||
|
@ -47,7 +47,13 @@ extension Kana2Kanji {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// zenzaiシステムによる完全変換。
|
/// zenzaiシステムによる完全変換。
|
||||||
@MainActor func all_zenzai(_ inputData: ComposingText, zenz: Zenz, zenzaiCache: ZenzaiCache?, inferenceLimit: Int) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
|
@MainActor func all_zenzai(
|
||||||
|
_ inputData: ComposingText,
|
||||||
|
zenz: Zenz,
|
||||||
|
zenzaiCache: ZenzaiCache?,
|
||||||
|
inferenceLimit: Int,
|
||||||
|
versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
|
||||||
|
) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
|
||||||
var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint([])
|
var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint([])
|
||||||
print("initial constraint", constraint)
|
print("initial constraint", constraint)
|
||||||
let eosNode = LatticeNode.EOSNode
|
let eosNode = LatticeNode.EOSNode
|
||||||
@ -85,7 +91,7 @@ extension Kana2Kanji {
|
|||||||
// When inference occurs more than maximum times, then just return result at this point
|
// When inference occurs more than maximum times, then just return result at this point
|
||||||
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
|
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
|
||||||
}
|
}
|
||||||
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate])
|
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate], versionDependentConfig: versionDependentConfig)
|
||||||
inferenceLimit -= 1
|
inferenceLimit -= 1
|
||||||
let nextAction = self.review(
|
let nextAction = self.review(
|
||||||
candidateIndex: index,
|
candidateIndex: index,
|
||||||
|
@ -30,12 +30,12 @@ import SwiftUtils
|
|||||||
try? self.zenzContext?.reset_context()
|
try? self.zenzContext?.reset_context()
|
||||||
}
|
}
|
||||||
|
|
||||||
func candidateEvaluate(convertTarget: String, candidates: [Candidate]) -> ZenzContext.CandidateEvaluationResult {
|
func candidateEvaluate(convertTarget: String, candidates: [Candidate], versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> ZenzContext.CandidateEvaluationResult {
|
||||||
guard let zenzContext else {
|
guard let zenzContext else {
|
||||||
return .error
|
return .error
|
||||||
}
|
}
|
||||||
for candidate in candidates {
|
for candidate in candidates {
|
||||||
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate)
|
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate, versionDependentConfig: versionDependentConfig)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
return .error
|
return .error
|
||||||
|
@ -139,11 +139,29 @@ class ZenzContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func evaluate_candidate(input: String, candidate: Candidate) -> CandidateEvaluationResult {
|
func evaluate_candidate(input: String, candidate: Candidate, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
|
||||||
print("Evaluate", candidate)
|
print("Evaluate", candidate)
|
||||||
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
|
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
|
||||||
// We assume \u{EE01}\(candidate) is always splitted into \u{EE01}_\(candidate) by zenz-v1 tokenizer
|
// We assume \u{EE01}\(candidate) is always splitted into \u{EE01}_\(candidate) by zenz-v1 tokenizer
|
||||||
let prompt = "\u{EE00}\(input)\u{EE01}"
|
let prompt: String
|
||||||
|
if case .v2(let mode) = versionDependentConfig {
|
||||||
|
if let leftSideContext = mode.leftSideContext, !leftSideContext.isEmpty {
|
||||||
|
let lsContext = leftSideContext.suffix(40)
|
||||||
|
if let profile = mode.profile, !profile.isEmpty {
|
||||||
|
let pf = profile.suffix(25)
|
||||||
|
prompt = "\u{EE00}\(input)\u{EE02}プロフィール:\(pf)・発言:\(lsContext)\u{EE01}"
|
||||||
|
} else {
|
||||||
|
prompt = "\u{EE00}\(input)\u{EE02}\(lsContext)\u{EE01}"
|
||||||
|
}
|
||||||
|
} else if let profile = mode.profile, !profile.isEmpty {
|
||||||
|
let pf = profile.suffix(25)
|
||||||
|
prompt = "\u{EE00}\(input)\u{EE02}プロフィール:\(pf)・発言:\u{EE01}"
|
||||||
|
} else {
|
||||||
|
prompt = "\u{EE00}\(input)\u{EE01}"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
prompt = "\u{EE00}\(input)\u{EE01}"
|
||||||
|
}
|
||||||
// Therefore, tokens = prompt_tokens + candidate_tokens is an appropriate operation.
|
// Therefore, tokens = prompt_tokens + candidate_tokens is an appropriate operation.
|
||||||
let prompt_tokens = self.tokenize(text: prompt, add_bos: true, add_eos: false)
|
let prompt_tokens = self.tokenize(text: prompt, add_bos: true, add_eos: false)
|
||||||
let candidate_tokens = self.tokenize(text: candidate.text, add_bos: false, add_eos: false)
|
let candidate_tokens = self.tokenize(text: candidate.text, add_bos: false, add_eos: false)
|
||||||
|
Reference in New Issue
Block a user