mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-08-22 15:05:26 +00:00
322 lines
16 KiB
Swift
322 lines
16 KiB
Swift
import Algorithms
|
|
import ArgumentParser
|
|
import Foundation
|
|
import KanaKanjiConverterModuleWithDefaultDictionary
|
|
import SwiftUtils
|
|
import Tokenizers
|
|
|
|
extension Subcommands {
|
|
struct Session: AsyncParsableCommand {
|
|
@Argument(help: "ひらがなで表記された入力")
|
|
var input: String = ""
|
|
|
|
@Option(name: [.customLong("config_n_best")], help: "The parameter n (n best parameter) for internal viterbi search.")
|
|
var configNBest: Int = 10
|
|
@Option(name: [.customShort("n"), .customLong("top_n")], help: "Display top n candidates.")
|
|
var displayTopN: Int = 1
|
|
@Option(name: [.customLong("zenz")], help: "gguf format model weight for zenz.")
|
|
var zenzWeightPath: String = ""
|
|
@Flag(name: [.customLong("mix_english_candidate")], help: "Enable mixing English Candidates.")
|
|
var mixEnglishCandidate = false
|
|
@Flag(name: [.customLong("disable_prediction")], help: "Disable producing prediction candidates.")
|
|
var disablePrediction = false
|
|
@Flag(name: [.customLong("enable_memory")], help: "Enable memory.")
|
|
var enableLearning = false
|
|
@Option(name: [.customLong("readonly_memory")], help: "Enable readonly memory.")
|
|
var readOnlyMemoryPath: String?
|
|
@Flag(name: [.customLong("only_whole_conversion")], help: "Show only whole conversion (完全一致変換).")
|
|
var onlyWholeConversion = false
|
|
@Flag(name: [.customLong("report_score")], help: "Show internal score for the candidate.")
|
|
var reportScore = false
|
|
@Flag(name: [.customLong("roman2kana")], help: "Use roman2kana input.")
|
|
var roman2kana = false
|
|
@Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
|
|
var configZenzaiInferenceLimit: Int = .max
|
|
@Flag(name: [.customLong("config_zenzai_rich_n_best")], help: "enable rich n_best generation for zenzai.")
|
|
var configRequestRichCandidates = false
|
|
@Option(name: [.customLong("config_profile")], help: "enable profile prompting for zenz-v2 and later.")
|
|
var configZenzaiProfile: String?
|
|
@Option(name: [.customLong("config_topic")], help: "enable topic prompting for zenz-v3 and later.")
|
|
var configZenzaiTopic: String?
|
|
@Flag(name: [.customLong("zenz_v1")], help: "Use zenz_v1 model.")
|
|
var zenzV1 = false
|
|
@Flag(name: [.customLong("zenz_v2")], help: "Use zenz_v2 model.")
|
|
var zenzV2 = false
|
|
@Flag(name: [.customLong("zenz_v3")], help: "Use zenz_v3 model.")
|
|
var zenzV3 = false
|
|
@Option(name: [.customLong("config_zenzai_base_lm")], help: "Marisa files for Base LM.")
|
|
var configZenzaiBaseLM: String?
|
|
@Option(name: [.customLong("config_zenzai_personal_lm")], help: "Marisa files for Personal LM.")
|
|
var configZenzaiPersonalLM: String?
|
|
@Option(name: [.customLong("config_zenzai_personalization_alpha")], help: "Strength of personalization (0.5 by default)")
|
|
var configZenzaiPersonalizationAlpha: Float = 0.5
|
|
|
|
@Option(name: [.customLong("replay")], help: "history.txt for replay.")
|
|
var replayHistory: String?
|
|
|
|
static let configuration = CommandConfiguration(commandName: "session", abstract: "Start session for incremental input.")
|
|
|
|
private func getTemporaryDirectory() -> URL? {
|
|
let fileManager = FileManager.default
|
|
let tempDirectoryURL = fileManager.temporaryDirectory.appendingPathComponent(UUID().uuidString)
|
|
|
|
do {
|
|
try fileManager.createDirectory(at: tempDirectoryURL, withIntermediateDirectories: true, attributes: nil)
|
|
print("Temporary directory created at \(tempDirectoryURL)")
|
|
return tempDirectoryURL
|
|
} catch {
|
|
print("Error creating temporary directory: \(error)")
|
|
return nil
|
|
}
|
|
}
|
|
|
|
@MainActor mutating func run() async {
|
|
if self.zenzV1 || self.zenzV2 {
|
|
print("\(bold: "We strongly recommend to use zenz-v3 models")")
|
|
}
|
|
if (self.zenzV1 || self.zenzV2 || self.zenzV3) && self.zenzWeightPath.isEmpty {
|
|
preconditionFailure("\(bold: "zenz version is specified but --zenz weight is not specified")")
|
|
}
|
|
if !self.zenzWeightPath.isEmpty && (!self.zenzV1 && !self.zenzV2 && !self.zenzV3) {
|
|
print("zenz version is not specified. By default, zenz-v3 will be used.")
|
|
}
|
|
let learningType: LearningType = if self.readOnlyMemoryPath != nil {
|
|
// 読み取りのみ
|
|
.onlyOutput
|
|
} else if self.enableLearning {
|
|
// 読み書き
|
|
.inputAndOutput
|
|
} else {
|
|
// 読み書きなし
|
|
.nothing
|
|
}
|
|
let memoryDirectory = if let readOnlyMemoryPath {
|
|
URL(fileURLWithPath: readOnlyMemoryPath)
|
|
} else if self.enableLearning {
|
|
if let dir = self.getTemporaryDirectory() {
|
|
dir
|
|
} else {
|
|
fatalError("Could not get temporary directory.")
|
|
}
|
|
} else {
|
|
URL(fileURLWithPath: "")
|
|
}
|
|
print("Working with \(learningType) mode. Memory path is \(memoryDirectory).")
|
|
|
|
let converter = KanaKanjiConverter()
|
|
converter.sendToDicdataStore(
|
|
.setRequestOptions(requestOptions(learningType: learningType, memoryDirectory: memoryDirectory, leftSideContext: nil))
|
|
)
|
|
var composingText = ComposingText()
|
|
let inputStyle: InputStyle = self.roman2kana ? .roman2kana : .direct
|
|
var lastCandidates: [Candidate] = []
|
|
var leftSideContext: String = ""
|
|
var page = 0
|
|
|
|
var histories = [String]()
|
|
|
|
var inputs = self.replayHistory.map {
|
|
try! String(contentsOfFile: $0, encoding: .utf8)
|
|
}?.split(by: "\n")
|
|
inputs?.append(":q")
|
|
|
|
while true {
|
|
print()
|
|
print("\(bold: "== Type :q to end session, type :d to delete character, type :c to stop composition. For other commands, type :h ==")")
|
|
if !leftSideContext.isEmpty {
|
|
print("\(bold: "Current Left-Side Context"): \(leftSideContext)")
|
|
}
|
|
var input = if inputs != nil {
|
|
inputs!.removeFirst()
|
|
} else {
|
|
readLine(strippingNewline: true) ?? ""
|
|
}
|
|
histories.append(input)
|
|
switch input {
|
|
case ":q", ":quit":
|
|
// 終了
|
|
return
|
|
case ":d", ":del":
|
|
if !composingText.isEmpty {
|
|
composingText.deleteBackwardFromCursorPosition(count: 1)
|
|
} else {
|
|
_ = leftSideContext.popLast()
|
|
continue
|
|
}
|
|
case ":c", ":clear":
|
|
// クリア
|
|
composingText.stopComposition()
|
|
converter.stopComposition()
|
|
leftSideContext = ""
|
|
print("composition is stopped")
|
|
continue
|
|
case ":n", ":next":
|
|
// ページ送り
|
|
page += 1
|
|
for (i, candidate) in lastCandidates[self.displayTopN * page ..< self.displayTopN * (page + 1)].indexed() {
|
|
if self.reportScore {
|
|
print("\(bold: String(i)). \(candidate.text) \(bold: "score:") \(candidate.value)")
|
|
} else {
|
|
print("\(bold: String(i)). \(candidate.text)")
|
|
}
|
|
}
|
|
continue
|
|
case ":s", ":save":
|
|
composingText.stopComposition()
|
|
converter.stopComposition()
|
|
converter.sendToDicdataStore(.closeKeyboard)
|
|
if learningType.needUpdateMemory {
|
|
print("saved")
|
|
} else {
|
|
print("anything should not be saved because the learning type is not for update memory")
|
|
}
|
|
continue
|
|
case ":p", ":pred":
|
|
// 次の文字の予測を取得する
|
|
let results = converter.predictNextCharacter(
|
|
leftSideContext: leftSideContext,
|
|
count: 10,
|
|
options: requestOptions(learningType: learningType, memoryDirectory: memoryDirectory, leftSideContext: leftSideContext)
|
|
)
|
|
if let firstCandidate = results.first {
|
|
leftSideContext.append(firstCandidate.character)
|
|
}
|
|
continue
|
|
case ":h", ":help":
|
|
// ヘルプ
|
|
print("""
|
|
\(bold: "== anco session commands ==")
|
|
\(bold: ":q, :quit") - quit session
|
|
\(bold: ":c, :clear") - clear composition
|
|
\(bold: ":d, :del") - delete one character
|
|
\(bold: ":n, :next") - see more candidates
|
|
\(bold: ":s, :save") - save memory to temporary directory
|
|
\(bold: ":p, :pred") - predict next one character
|
|
\(bold: ":%d") - select candidate at that index (like :3 to select 3rd candidate)
|
|
\(bold: ":ctx %s") - set the string as context
|
|
\(bold: ":dump %s") - dump command history to specified file name (default: history.txt).
|
|
""")
|
|
default:
|
|
if input.hasPrefix(":ctx") {
|
|
let ctx = String(input.split(by: ":ctx ").last ?? "")
|
|
leftSideContext.append(ctx)
|
|
continue
|
|
} else if input.hasPrefix(":dump") {
|
|
let fileName = if ":dump " < input {
|
|
String(input.dropFirst(6))
|
|
} else {
|
|
"history.txt"
|
|
}
|
|
histories.removeAll(where: {$0.hasPrefix(":dump")})
|
|
let content = histories.joined(separator: "\n")
|
|
try! content.write(to: URL(fileURLWithPath: fileName), atomically: true, encoding: .utf8)
|
|
continue
|
|
} else if input.hasPrefix(":"), let index = Int(input.dropFirst()) {
|
|
if !lastCandidates.indices.contains(index) {
|
|
print("\(bold: "Error"): Index \(index) is not available for current context.")
|
|
continue
|
|
}
|
|
let candidate = lastCandidates[index]
|
|
print("Submit \(candidate.text)")
|
|
converter.setCompletedData(candidate)
|
|
converter.updateLearningData(candidate)
|
|
composingText.prefixComplete(correspondingCount: candidate.correspondingCount)
|
|
if composingText.isEmpty {
|
|
composingText.stopComposition()
|
|
converter.stopComposition()
|
|
}
|
|
leftSideContext += candidate.text
|
|
} else {
|
|
input = String(input.map { (c: Character) -> Character in
|
|
[
|
|
"-": "ー",
|
|
".": "。",
|
|
",": "、"
|
|
][c, default: c]
|
|
})
|
|
composingText.insertAtCursorPosition(input, inputStyle: inputStyle)
|
|
}
|
|
}
|
|
print(composingText.convertTarget)
|
|
let start = Date()
|
|
let result = converter.requestCandidates(composingText, options: requestOptions(learningType: learningType, memoryDirectory: memoryDirectory, leftSideContext: leftSideContext))
|
|
let mainResults = result.mainResults.filter {
|
|
!self.onlyWholeConversion || $0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == input.toKatakana()
|
|
}
|
|
for (i, candidate) in mainResults.prefix(self.displayTopN).indexed() {
|
|
if self.reportScore {
|
|
print("\(bold: String(i)). \(candidate.text) \(bold: "score:") \(candidate.value)")
|
|
} else {
|
|
print("\(bold: String(i)). \(candidate.text)")
|
|
}
|
|
}
|
|
lastCandidates = mainResults
|
|
page = 0
|
|
if self.onlyWholeConversion {
|
|
// entropyを示す
|
|
let mean = mainResults.reduce(into: 0) { $0 += Double($1.value) } / Double(mainResults.count)
|
|
let expValues = mainResults.map { exp(Double($0.value) - mean) }
|
|
let sumOfExpValues = expValues.reduce(into: 0, +=)
|
|
// 確率値に補正
|
|
let probs = mainResults.map { exp(Double($0.value) - mean) / sumOfExpValues }
|
|
let entropy = -probs.reduce(into: 0) { $0 += $1 * log($1) }
|
|
print("\(bold: "Entropy:") \(entropy)")
|
|
}
|
|
print("\(bold: "Time:") \(-start.timeIntervalSinceNow)")
|
|
}
|
|
}
|
|
|
|
func requestOptions(learningType: LearningType, memoryDirectory: URL, leftSideContext: String?) -> ConvertRequestOptions {
|
|
let zenzaiVersionDependentMode: ConvertRequestOptions.ZenzaiVersionDependentMode = if self.zenzV1 {
|
|
.v1
|
|
} else if self.zenzV2 {
|
|
.v2(.init(profile: self.configZenzaiProfile, leftSideContext: leftSideContext))
|
|
} else {
|
|
.v3(.init(profile: self.configZenzaiProfile, topic: self.configZenzaiTopic, leftSideContext: leftSideContext))
|
|
}
|
|
let personalizationMode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?
|
|
if let base = self.configZenzaiBaseLM, let personal = self.configZenzaiPersonalLM {
|
|
personalizationMode = .init(
|
|
baseNgramLanguageModel: base,
|
|
personalNgramLanguageModel: personal,
|
|
n: 5,
|
|
d: 0.75,
|
|
alpha: self.configZenzaiPersonalizationAlpha
|
|
)
|
|
} else if self.configZenzaiBaseLM != nil || self.configZenzaiPersonalLM != nil {
|
|
fatalError("Both --config_zenzai_base_lm and --config_zenzai_personal_lm must be set")
|
|
} else {
|
|
personalizationMode = nil
|
|
}
|
|
var option: ConvertRequestOptions = .withDefaultDictionary(
|
|
N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
|
|
requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
|
|
requireEnglishPrediction: false,
|
|
keyboardLanguage: .ja_JP,
|
|
typographyLetterCandidate: false,
|
|
unicodeCandidate: true,
|
|
englishCandidateInRoman2KanaInput: self.mixEnglishCandidate,
|
|
fullWidthRomanCandidate: false,
|
|
halfWidthKanaCandidate: false,
|
|
learningType: learningType,
|
|
shouldResetMemory: false,
|
|
memoryDirectoryURL: memoryDirectory,
|
|
sharedContainerURL: URL(fileURLWithPath: ""),
|
|
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(
|
|
weight: URL(string: self.zenzWeightPath)!,
|
|
inferenceLimit: self.configZenzaiInferenceLimit,
|
|
requestRichCandidates: self.configRequestRichCandidates,
|
|
personalizationMode: personalizationMode,
|
|
versionDependentMode: zenzaiVersionDependentMode
|
|
),
|
|
metadata: .init(versionString: "anco for debugging")
|
|
)
|
|
if self.onlyWholeConversion {
|
|
option.requestQuery = .完全一致
|
|
}
|
|
return option
|
|
}
|
|
}
|
|
}
|
|
|