feat: Add an API that takes advantage of zenz-v2's context-conditioning feature

Miwa / Ensan
2024-07-31 00:46:39 +09:00
parent 596e4701f8
commit 3770633c44
6 changed files with 91 additions and 14 deletions

View File

@@ -25,7 +25,10 @@ extension Subcommands {
var roman2kana = false
@Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
var configZenzaiInferenceLimit: Int = .max
@Option(name: [.customLong("config_profile")], help: "enable profile prompting for zenz-v2.")
var configZenzV2Profile: String? = nil
@Flag(name: [.customLong("zenz_v1")], help: "Use zenz_v1 model.")
var zenzV1 = false
static var configuration = CommandConfiguration(commandName: "session", abstract: "Start session for incremental input.")
@@ -44,6 +47,9 @@ extension Subcommands {
}
@MainActor mutating func run() async {
if self.zenzV1 {
print("\(bold: "We strongly recommend to use zenz-v2 models")")
}
let memoryDirector = if self.enableLearning {
if let dir = self.getTemporaryDirectory() {
dir
@@ -58,10 +64,14 @@ extension Subcommands {
var composingText = ComposingText()
let inputStyle: InputStyle = self.roman2kana ? .roman2kana : .direct
var lastCandidates: [Candidate] = []
var leftSideContext: String = ""
var page = 0
while true {
print()
print("\(bold: "== Type :q to end session, type :d to delete character, type :c to stop composition. For other commands, type :h ==")")
if !leftSideContext.isEmpty {
print("\(bold: "Current Left-Side Context"): \(leftSideContext)")
}
var input = readLine(strippingNewline: true) ?? ""
switch input {
case ":q":
@@ -74,6 +84,7 @@
//
composingText.stopComposition()
converter.stopComposition()
leftSideContext = ""
print("composition is stopped")
continue
case ":n":
@@ -119,6 +130,7 @@
composingText.stopComposition()
converter.stopComposition()
}
leftSideContext += candidate.text
} else {
input = String(input.map { (c: Character) -> Character in
[
@@ -132,7 +144,7 @@
}
print(composingText.convertTarget)
let start = Date()
let result = converter.requestCandidates(composingText, options: requestOptions(memoryDirector: memoryDirector))
let result = converter.requestCandidates(composingText, options: requestOptions(memoryDirector: memoryDirector, leftSideContext: leftSideContext))
let mainResults = result.mainResults.filter {
!self.onlyWholeConversion || $0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == input.toKatakana()
}
@@ -159,7 +171,12 @@
}
}
func requestOptions(memoryDirector: URL) -> ConvertRequestOptions {
func requestOptions(memoryDirector: URL, leftSideContext: String) -> ConvertRequestOptions {
let zenzaiVersionDependentMode: ConvertRequestOptions.ZenzaiVersionDependentMode = if self.zenzV1 {
.v1
} else {
.v2(.init(profile: self.configZenzV2Profile, leftSideContext: leftSideContext))
}
var option: ConvertRequestOptions = .withDefaultDictionary(
N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
@@ -175,7 +192,11 @@
shouldResetMemory: false,
memoryDirectoryURL: memoryDirector,
sharedContainerURL: URL(fileURLWithPath: ""),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(
weight: URL(string: self.zenzWeightPath)!,
inferenceLimit: self.configZenzaiInferenceLimit,
versionDependentMode: zenzaiVersionDependentMode
),
metadata: .init(versionString: "anco for debugging")
)
if self.onlyWholeConversion {

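In the session loop above, leftSideContext accumulates every committed candidate and is cleared when composition stops, so each conversion request is conditioned on what the user has already typed. A minimal sketch of that lifecycle (the committed strings below are illustrative values, not from the commit):

// Sketch of the leftSideContext lifecycle implemented in the session loop above.
var leftSideContext = ""
leftSideContext += "今日は"    // a candidate is committed: its text is appended (cf. leftSideContext += candidate.text)
leftSideContext += "晴れです"  // each further commit keeps extending the context
// ... every requestCandidates call receives the current context ...
leftSideContext = ""           // ":c" stops composition and clears the context
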
View File

@@ -142,18 +142,50 @@ public struct ConvertRequestOptions: Sendable {
case
}
public struct ZenzaiV2DependentMode: Sendable, Equatable, Hashable {
public init(profile: String? = nil, leftSideContext: String? = nil) {
self.profile = profile
self.leftSideContext = leftSideContext
}
/// Profile prompt: a short description of the user (on the order of 10–20 characters)
public var profile: String?
/// Left-side context: the text immediately preceding the text being converted
public var leftSideContext: String?
}
public enum ZenzVersion: Sendable, Equatable, Hashable {
case v1
case v2
}
public enum ZenzaiVersionDependentMode: Sendable, Equatable, Hashable {
case v1
case v2(ZenzaiV2DependentMode)
public var version: ZenzVersion {
switch self {
case .v1:
return .v1
case .v2:
return .v2
}
}
}
public struct ZenzaiMode: Sendable, Equatable {
public static let off = ZenzaiMode(enabled: false, weightURL: URL(fileURLWithPath: ""), inferenceLimit: 10)
public static let off = ZenzaiMode(enabled: false, weightURL: URL(fileURLWithPath: ""), inferenceLimit: 10, versionDependentMode: .v2(.init()))
/// activate *Zenzai* - Neural Kana-Kanji Conversion Engine
/// - Parameters:
/// - weight: path for model weight (gguf)
/// - inferenceLimit: limit on the number of inference runs; a smaller limit makes conversion faster, but quality will be worse. (Default: 10)
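/// - versionDependentMode: configuration specific to the zenz model version: .v1 for zenz-v1, or .v2 with optional profile and left-side context for zenz-v2. (Default: .v2(.init()))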
public static func on(weight: URL, inferenceLimit: Int = 10) -> Self {
ZenzaiMode(enabled: true, weightURL: weight, inferenceLimit: inferenceLimit)
public static func on(weight: URL, inferenceLimit: Int = 10, versionDependentMode: ZenzaiVersionDependentMode = .v2(.init())) -> Self {
ZenzaiMode(enabled: true, weightURL: weight, inferenceLimit: inferenceLimit, versionDependentMode: versionDependentMode)
}
var enabled: Bool
var weightURL: URL
var inferenceLimit: Int
var versionDependentMode: ZenzaiVersionDependentMode
}
}

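Taken together, these types give callers a version-aware way to switch Zenzai's conditioning on. A minimal usage sketch, assuming the package's converter module is imported; the weight paths, profile, and left-side context are placeholder values:

import Foundation

// Condition zenz-v2 on a user profile and the text to the left of the cursor.
let v2Mode: ConvertRequestOptions.ZenzaiVersionDependentMode = .v2(
    .init(profile: "25歳の会社員", leftSideContext: "こんにちは。")
)
let zenzai: ConvertRequestOptions.ZenzaiMode = .on(
    weight: URL(fileURLWithPath: "/path/to/zenz-v2.gguf"),  // placeholder path
    inferenceLimit: 10,
    versionDependentMode: v2Mode
)
// For zenz-v1 weights, pass .v1; the v2-only conditioning is then unavailable.
let legacy: ConvertRequestOptions.ZenzaiMode = .on(
    weight: URL(fileURLWithPath: "/path/to/zenz-v1.gguf"),  // placeholder path
    versionDependentMode: .v1
)
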
View File

@@ -568,7 +568,7 @@ import SwiftUtils
// FIXME: enable cache based zenzai
if zenzaiMode.enabled, let model = self.getModel(modelURL: zenzaiMode.weightURL) {
let (result, nodes, cache) = self.converter.all_zenzai(inputData, zenz: model, zenzaiCache: self.zenzaiCache, inferenceLimit: zenzaiMode.inferenceLimit)
let (result, nodes, cache) = self.converter.all_zenzai(inputData, zenz: model, zenzaiCache: self.zenzaiCache, inferenceLimit: zenzaiMode.inferenceLimit, versionDependentConfig: zenzaiMode.versionDependentMode)
self.zenzaiCache = cache
self.previousInputData = inputData
return (result, nodes)

View File

@@ -47,7 +47,13 @@ extension Kana2Kanji {
}
/// Conversion using zenzai
@MainActor func all_zenzai(_ inputData: ComposingText, zenz: Zenz, zenzaiCache: ZenzaiCache?, inferenceLimit: Int) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
@MainActor func all_zenzai(
_ inputData: ComposingText,
zenz: Zenz,
zenzaiCache: ZenzaiCache?,
inferenceLimit: Int,
versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint([])
print("initial constraint", constraint)
let eosNode = LatticeNode.EOSNode
@@ -85,7 +91,7 @@
// When inference has run more than the maximum number of times, just return the result at this point
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
}
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate])
let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate], versionDependentConfig: versionDependentConfig)
inferenceLimit -= 1
let nextAction = self.review(
candidateIndex: index,

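For orientation, the hunks above sit inside a propose-review loop: the lattice proposes a candidate under the current prefix constraint, the zenz model reviews it, and the constraint is refined until the candidate is accepted or the inference budget runs out. A schematic sketch of that control flow; bestCandidate(under:) and refineConstraint(from:) are hypothetical stand-ins for logic these hunks do not show:

// Schematic only. bestCandidate(under:) and refineConstraint(from:) are
// hypothetical helpers; the real loop lives in all_zenzai above.
var remaining = inferenceLimit
while true {
    let candidate = bestCandidate(under: constraint)
    if remaining <= 0 {
        // Budget exhausted: return the best result found so far.
        return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
    }
    let reviewResult = zenz.candidateEvaluate(
        convertTarget: inputData.convertTarget,
        candidates: [candidate],
        versionDependentConfig: versionDependentConfig
    )
    remaining -= 1
    // If the review accepts the candidate, the loop returns it; otherwise the
    // prefix constraint is refined from the review and the loop re-proposes.
    constraint = refineConstraint(from: reviewResult)
}
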
View File

@@ -30,12 +30,12 @@ import SwiftUtils
try? self.zenzContext?.reset_context()
}
func candidateEvaluate(convertTarget: String, candidates: [Candidate]) -> ZenzContext.CandidateEvaluationResult {
func candidateEvaluate(convertTarget: String, candidates: [Candidate], versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> ZenzContext.CandidateEvaluationResult {
guard let zenzContext else {
return .error
}
for candidate in candidates {
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate)
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate, versionDependentConfig: versionDependentConfig)
return result
}
return .error

View File

@@ -139,11 +139,29 @@ class ZenzContext {
}
}
func evaluate_candidate(input: String, candidate: Candidate) -> CandidateEvaluationResult {
func evaluate_candidate(input: String, candidate: Candidate, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
print("Evaluate", candidate)
// For the zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
// We assume \u{EE01}\(candidate) is always split into \u{EE01}_\(candidate) by the zenz-v1 tokenizer
let prompt = "\u{EE00}\(input)\u{EE01}"
let prompt: String
if case .v2(let mode) = versionDependentConfig {
if let leftSideContext = mode.leftSideContext, !leftSideContext.isEmpty {
let lsContext = leftSideContext.suffix(40)
if let profile = mode.profile, !profile.isEmpty {
let pf = profile.suffix(25)
prompt = "\u{EE00}\(input)\u{EE02}プロフィール:\(pf)・発言:\(lsContext)\u{EE01}"
} else {
prompt = "\u{EE00}\(input)\u{EE02}\(lsContext)\u{EE01}"
}
} else if let profile = mode.profile, !profile.isEmpty {
let pf = profile.suffix(25)
prompt = "\u{EE00}\(input)\u{EE02}プロフィール:\(pf)・発言:\u{EE01}"
} else {
prompt = "\u{EE00}\(input)\u{EE01}"
}
} else {
prompt = "\u{EE00}\(input)\u{EE01}"
}
// Therefore, tokens = prompt_tokens + candidate_tokens is an appropriate operation.
let prompt_tokens = self.tokenize(text: prompt, add_bos: true, add_eos: false)
let candidate_tokens = self.tokenize(text: candidate.text, add_bos: false, add_eos: false)
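
Concretely, the v2 branch above produces four prompt shapes; leftSideContext is truncated to its last 40 characters and profile to its last 25 before prompting. The reading, profile, and context below are illustrative values, and by analogy with the documented \u{EE00}/\u{EE01} tokens, \u{EE02} appears to mark the start of the conditioning text:

// Illustrative zenz-v2 prompts for input = "キョウハイイテンキ",
// profile = "VTuberの学生", leftSideContext = "こんにちは。"
//
// profile + left-side context:
//   "\u{EE00}キョウハイイテンキ\u{EE02}プロフィール:VTuberの学生・発言:こんにちは。\u{EE01}"
// left-side context only:
//   "\u{EE00}キョウハイイテンキ\u{EE02}こんにちは。\u{EE01}"
// profile only (the 発言 part is left empty):
//   "\u{EE00}キョウハイイテンキ\u{EE02}プロフィール:VTuberの学生・発言:\u{EE01}"
// neither, or zenz-v1: the original prompt format
//   "\u{EE00}キョウハイイテンキ\u{EE01}"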