diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml index b633710..42eb1c8 100644 --- a/.github/workflows/swift.yml +++ b/.github/workflows/swift.yml @@ -9,12 +9,12 @@ on: branches: [ "main", "develop" ] jobs: - build: + macos-build: name: Swift ${{ matrix.swift-version }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [macos-latest] swift-version: ["5.9", "5.10"] steps: - uses: swift-actions/setup-swift@v2 @@ -24,6 +24,25 @@ jobs: with: submodules: true - name: Build - run: swift build -Xswiftc -strict-concurrency=complete -v + run: swift build -Xswiftc -strict-concurrency=complete -Xcxx -xobjective-c++ -v - name: Run tests - run: swift test -c release -Xswiftc -strict-concurrency=complete -v + run: swift test -c release -Xswiftc -strict-concurrency=complete -Xcxx -xobjective-c++ -v + ubuntu-build: + name: Swift ${{ matrix.swift-version }} on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + swift-version: ["5.9", "5.10"] + steps: + - uses: swift-actions/setup-swift@v2 + with: + swift-version: ${{ matrix.swift-version }} + - uses: actions/checkout@v4 + with: + submodules: true + - name: Build + run: swift build -Xswiftc -strict-concurrency=complete -v + - name: Run tests + run: swift test -c release -Xswiftc -strict-concurrency=complete -v + \ No newline at end of file diff --git a/.gitignore b/.gitignore index 5d66184..88016f5 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ Package.resolved *.pyc .docc-build .vscode +*.gguf diff --git a/Package.swift b/Package.swift index e621898..b0a08e4 100644 --- a/Package.swift +++ b/Package.swift @@ -13,9 +13,10 @@ let swiftSettings: [SwiftSetting] = [ .enableUpcomingFeature("DisableOutwardActorInference"), .enableUpcomingFeature("ImportObjcForwardDeclarations") ] + let package = Package( name: "AzooKeyKanakanjiConverter", - platforms: [.iOS(.v14), .macOS(.v11)], + platforms: [.iOS(.v14), 
.macOS(.v12)], products: [ // Products define the executables and libraries a package produces, and make them visible to other packages. .library( @@ -39,6 +40,8 @@ let package = Package( .package(url: "https://github.com/apple/swift-algorithms", from: "1.0.0"), .package(url: "https://github.com/apple/swift-collections", from: "1.0.0"), .package(url: "https://github.com/apple/swift-argument-parser", .upToNextMajor(from: "1.0.0")), + // local package + .package(url: "https://github.com/ensan-hcl/llama.cpp", branch: "9f41923"), ], targets: [ // Targets are the basic building blocks of a package. A target can define a module or a test suite. @@ -54,9 +57,9 @@ let package = Package( .target( name: "KanaKanjiConverterModule", dependencies: [ - "SwiftUtils" + "SwiftUtils", + .product(name: "llama", package: "llama.cpp") ], - resources: [], swiftSettings: swiftSettings ), .target( diff --git a/Sources/CliTool/Anco.swift b/Sources/CliTool/Anco.swift index 715511e..a449430 100644 --- a/Sources/CliTool/Anco.swift +++ b/Sources/CliTool/Anco.swift @@ -2,10 +2,10 @@ import KanaKanjiConverterModuleWithDefaultDictionary import ArgumentParser @main -public struct Anco: ParsableCommand { +public struct Anco: AsyncParsableCommand { public static var configuration = CommandConfiguration( abstract: "Anco is A(zooKey) Kana-Ka(n)ji (co)nverter", - subcommands: [Subcommands.Run.self, Subcommands.Dict.self, Subcommands.Evaluate.self], + subcommands: [Subcommands.Run.self, Subcommands.Dict.self, Subcommands.Evaluate.self, Subcommands.Session.self], defaultSubcommand: Subcommands.Run.self ) diff --git a/Sources/CliTool/Subcommands/EvaluateCommand.swift b/Sources/CliTool/Subcommands/EvaluateCommand.swift index 9c59159..11e2c76 100644 --- a/Sources/CliTool/Subcommands/EvaluateCommand.swift +++ b/Sources/CliTool/Subcommands/EvaluateCommand.swift @@ -13,10 +13,14 @@ extension Subcommands { var configNBest: Int = 10 @Flag(name: [.customLong("stable")], help: "Report only stable properties; 
timestamps and values will not be reported.") var stable: Bool = false + @Option(name: [.customLong("zenz")], help: "gguf format model weight for zenz.") + var zenzWeightPath: String = "" + @Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.") + var configZenzaiInferenceLimit: Int = .max static var configuration = CommandConfiguration(commandName: "evaluate", abstract: "Evaluate quality of Conversion for input data.") - func parseInputFile() throws -> [InputItem] { + private func parseInputFile() throws -> [InputItem] { let url = URL(fileURLWithPath: self.inputFile) let lines = (try String(contentsOf: url)).split(separator: "\n", omittingEmptySubsequences: false) return lines.enumerated().compactMap { (index, line) -> InputItem? in @@ -33,14 +37,15 @@ extension Subcommands { @MainActor mutating func run() throws { let inputItems = try parseInputFile() - + let requestOptions = requestOptions() let converter = KanaKanjiConverter() let start = Date() var resultItems: [EvaluateItem] = [] for item in inputItems { var composingText = ComposingText() composingText.insertAtCursorPosition(item.query, inputStyle: .direct) - let result = converter.requestCandidates(composingText, options: requestOptions()) + + let result = converter.requestCandidates(composingText, options: requestOptions) let mainResults = result.mainResults.filter { $0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == item.query.toKatakana() } @@ -53,6 +58,8 @@ extension Subcommands { } ) ) + // Explicitly reset state + converter.stopComposition() } let end = Date() var result = EvaluateResult(n_best: self.configNBest, execution_time: end.timeIntervalSince(start), items: resultItems) @@ -94,6 +101,7 @@ extension Subcommands { shouldResetMemory: false, memoryDirectoryURL: URL(fileURLWithPath: ""), sharedContainerURL: URL(fileURLWithPath: ""), + zenzaiMode: self.zenzWeightPath.isEmpty ? 
.off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit), metadata: .init(versionString: "anco for debugging") ) option.requestQuery = .完全一致 @@ -101,7 +109,7 @@ extension Subcommands { } } - struct InputItem { + private struct InputItem { /// 入力クエリ var query: String diff --git a/Sources/CliTool/Subcommands/RunCommand.swift b/Sources/CliTool/Subcommands/RunCommand.swift index cbe9395..7a90403 100644 --- a/Sources/CliTool/Subcommands/RunCommand.swift +++ b/Sources/CliTool/Subcommands/RunCommand.swift @@ -3,7 +3,7 @@ import ArgumentParser import Foundation extension Subcommands { - struct Run: ParsableCommand { + struct Run: AsyncParsableCommand { @Argument(help: "ひらがなで表記された入力") var input: String = "" @@ -11,6 +11,10 @@ extension Subcommands { var configNBest: Int = 10 @Option(name: [.customShort("n"), .customLong("top_n")], help: "Display top n candidates.") var displayTopN: Int = 1 + @Option(name: [.customLong("zenz")], help: "gguf format model weight for zenz.") + var zenzWeightPath: String = "" + @Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.") + var configZenzaiInferenceLimit: Int = .max @Flag(name: [.customLong("disable_prediction")], help: "Disable producing prediction candidates.") var disablePrediction = false @@ -23,7 +27,7 @@ extension Subcommands { static var configuration = CommandConfiguration(commandName: "run", abstract: "Show help for this utility.") - @MainActor mutating func run() { + @MainActor mutating func run() async { let converter = KanaKanjiConverter() var composingText = ComposingText() composingText.insertAtCursorPosition(input, inputStyle: .direct) @@ -66,6 +70,7 @@ extension Subcommands { shouldResetMemory: false, memoryDirectoryURL: URL(fileURLWithPath: ""), sharedContainerURL: URL(fileURLWithPath: ""), + zenzaiMode: self.zenzWeightPath.isEmpty ? 
.off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit), metadata: .init(versionString: "anco for debugging") ) if self.onlyWholeConversion { diff --git a/Sources/CliTool/Subcommands/SessionCommand.swift b/Sources/CliTool/Subcommands/SessionCommand.swift new file mode 100644 index 0000000..33950e4 --- /dev/null +++ b/Sources/CliTool/Subcommands/SessionCommand.swift @@ -0,0 +1,102 @@ +import KanaKanjiConverterModuleWithDefaultDictionary +import ArgumentParser +import Foundation + +extension Subcommands { + struct Session: AsyncParsableCommand { + @Argument(help: "ひらがなで表記された入力") + var input: String = "" + + @Option(name: [.customLong("config_n_best")], help: "The parameter n (n best parameter) for internal viterbi search.") + var configNBest: Int = 10 + @Option(name: [.customShort("n"), .customLong("top_n")], help: "Display top n candidates.") + var displayTopN: Int = 1 + @Option(name: [.customLong("zenz")], help: "gguf format model weight for zenz.") + var zenzWeightPath: String = "" + @Flag(name: [.customLong("disable_prediction")], help: "Disable producing prediction candidates.") + var disablePrediction = false + @Flag(name: [.customLong("only_whole_conversion")], help: "Show only whole conversion (完全一致変換).") + var onlyWholeConversion = false + @Flag(name: [.customLong("report_score")], help: "Show internal score for the candidate.") + var reportScore = false + @Flag(name: [.customLong("roman2kana")], help: "Use roman2kana input.") + var roman2kana = false + @Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.") + var configZenzaiInferenceLimit: Int = .max + + + static var configuration = CommandConfiguration(commandName: "session", abstract: "Start session for incremental input.") + + @MainActor mutating func run() async { + let converter = KanaKanjiConverter() + var composingText = ComposingText() + let inputStyle: InputStyle = self.roman2kana ? 
.roman2kana : .direct + while true { + print() + print("\(bold: "== type :q to end session, type :d to delete character, type :c to stop composition, type any other text to input ==")") + let input = readLine(strippingNewline: true) ?? "" + switch input { + case ":q": return + case ":d": + composingText.deleteBackwardFromCursorPosition(count: 1) + case ":c": + composingText.stopComposition() + converter.stopComposition() + print("composition is stopped") + continue + default: + composingText.insertAtCursorPosition(input, inputStyle: inputStyle) + } + print(composingText.convertTarget) + let start = Date() + let result = converter.requestCandidates(composingText, options: requestOptions()) + let mainResults = result.mainResults.filter { + !self.onlyWholeConversion || $0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == input.toKatakana() + } + for candidate in mainResults.prefix(self.displayTopN) { + if self.reportScore { + print("\(candidate.text) \(bold: "score:") \(candidate.value)") + } else { + print(candidate.text) + } + } + if self.onlyWholeConversion { + // entropyを示す + let mean = mainResults.reduce(into: 0) { $0 += Double($1.value) } / Double(mainResults.count) + let expValues = mainResults.map { exp(Double($0.value) - mean) } + let sumOfExpValues = expValues.reduce(into: 0, +=) + // 確率値に補正 + let probs = mainResults.map { exp(Double($0.value) - mean) / sumOfExpValues } + let entropy = -probs.reduce(into: 0) { $0 += $1 * log($1) } + print("\(bold: "Entropy:") \(entropy)") + } + print("\(bold: "Time:") \(-start.timeIntervalSinceNow)") + } + } + + func requestOptions() -> ConvertRequestOptions { + var option: ConvertRequestOptions = .withDefaultDictionary( + N_best: self.onlyWholeConversion ? 
max(self.configNBest, self.displayTopN) : self.configNBest, + requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction, + requireEnglishPrediction: false, + keyboardLanguage: .ja_JP, + typographyLetterCandidate: false, + unicodeCandidate: true, + englishCandidateInRoman2KanaInput: true, + fullWidthRomanCandidate: false, + halfWidthKanaCandidate: false, + learningType: .nothing, + maxMemoryCount: 0, + shouldResetMemory: false, + memoryDirectoryURL: URL(fileURLWithPath: ""), + sharedContainerURL: URL(fileURLWithPath: ""), + zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit), + metadata: .init(versionString: "anco for debugging") + ) + if self.onlyWholeConversion { + option.requestQuery = .完全一致 + } + return option + } + } +} diff --git a/Sources/KanaKanjiConverterModule/Converter/ConvertRequestOptions.swift b/Sources/KanaKanjiConverterModule/Converter/ConvertRequestOptions.swift index 5753002..ced60c1 100644 --- a/Sources/KanaKanjiConverterModule/Converter/ConvertRequestOptions.swift +++ b/Sources/KanaKanjiConverterModule/Converter/ConvertRequestOptions.swift @@ -29,7 +29,7 @@ public struct ConvertRequestOptions: Sendable { /// - sharedContainerURL: ユーザ辞書など、キーボード外で書き込んだ設定データの保存されているディレクトリを指定します。 /// - textReplacer: 予測変換のための置換機を指定します。 /// - metadata: メタデータを指定します。詳しくは`ConvertRequestOptions.Metadata`を参照してください。 - public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer = TextReplacer(), metadata: 
ConvertRequestOptions.Metadata?) { + public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer = TextReplacer(), zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?) { self.N_best = N_best self.requireJapanesePrediction = requireJapanesePrediction self.requireEnglishPrediction = requireEnglishPrediction @@ -46,10 +46,11 @@ public struct ConvertRequestOptions: Sendable { self.sharedContainerURL = sharedContainerURL self.metadata = metadata self.textReplacer = textReplacer + self.zenzaiMode = zenzaiMode self.dictionaryResourceURL = dictionaryResourceURL } - package init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer = TextReplacer(), metadata: ConvertRequestOptions.Metadata?, requestQuery: RequestQuery) { + package init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, 
learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer = TextReplacer(), zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?, requestQuery: RequestQuery) { self.N_best = N_best self.requireJapanesePrediction = requireJapanesePrediction self.requireEnglishPrediction = requireEnglishPrediction @@ -66,6 +67,7 @@ public struct ConvertRequestOptions: Sendable { self.sharedContainerURL = sharedContainerURL self.metadata = metadata self.textReplacer = textReplacer + self.zenzaiMode = zenzaiMode self.dictionaryResourceURL = dictionaryResourceURL } @@ -88,6 +90,7 @@ public struct ConvertRequestOptions: Sendable { public var memoryDirectoryURL: URL public var sharedContainerURL: URL public var dictionaryResourceURL: URL + public var zenzaiMode: ZenzaiMode // メタデータ public var metadata: Metadata? @@ -138,4 +141,19 @@ public struct ConvertRequestOptions: Sendable { case `default` case 完全一致 } + + public struct ZenzaiMode: Sendable, Equatable { + public static let off = ZenzaiMode(enabled: false, weightURL: URL(fileURLWithPath: ""), inferenceLimit: 10) + + /// activate *Zenzai* - Neural Kana-Kanji Conversion Engine + /// - Parameters: + /// - weight: path for model weight (gguf) + /// - inferenceLimit: applying inference count limitation. Smaller limit makes conversion faster but quality will be worse. 
(Default: 10) + public static func on(weight: URL, inferenceLimit: Int = 10) -> Self { + ZenzaiMode(enabled: true, weightURL: weight, inferenceLimit: inferenceLimit) + } + var enabled: Bool + var weightURL: URL + var inferenceLimit: Int + } } diff --git a/Sources/KanaKanjiConverterModule/Converter/KanaKanjiConverter.swift b/Sources/KanaKanjiConverterModule/Converter/KanaKanjiConverter.swift index 5ca2f49..bf3e124 100644 --- a/Sources/KanaKanjiConverterModule/Converter/KanaKanjiConverter.swift +++ b/Sources/KanaKanjiConverterModule/Converter/KanaKanjiConverter.swift @@ -25,15 +25,37 @@ import SwiftUtils private var nodes: [[LatticeNode]] = [] private var completedData: Candidate? private var lastData: DicdataElement? + /// Zenzaiのためのzenz-v1モデル + private var zenz: Zenz? = nil + private var zenzaiCache: Kana2Kanji.ZenzaiCache? = nil + public private(set) var zenzStatus: String = "" /// リセットする関数 public func stopComposition() { + self.zenz?.endSession() + self.zenzaiCache = nil self.previousInputData = nil self.nodes = [] self.completedData = nil self.lastData = nil } + private func getModel(modelURL: URL) -> Zenz? 
{ + if let model = self.zenz, model.resourceURL == modelURL { + self.zenzStatus = "load \(modelURL.absoluteString)" + return model + } else { + do { + self.zenz = try Zenz(resourceURL: modelURL) + self.zenzStatus = "load \(modelURL.absoluteString)" + return self.zenz + } catch { + self.zenzStatus = "load \(modelURL.absoluteString) " + error.localizedDescription + return nil + } + } + } + /// 入力する言語が分かったらこの関数をなるべく早い段階で呼ぶことで、SpellCheckerの初期化が行われ、変換がスムーズになる public func setKeyboardLanguage(_ language: KeyboardLanguage) { if !checkerInitialized[language, default: false] { @@ -429,10 +451,27 @@ import SwiftUtils // 文章全体を変換した場合の候補上位5件を作る let whole_sentence_unique_candidates = self.getUniqueCandidate(sums.map {$0.1}) if case .完全一致 = options.requestQuery { - // 完全一致候補のみが要求されている場合、ここで全てのデータを返してreturnする - return ConversionResult(mainResults: whole_sentence_unique_candidates.sorted(by: {$0.value > $1.value}), firstClauseResults: []) + if options.zenzaiMode.enabled { + return ConversionResult(mainResults: whole_sentence_unique_candidates, firstClauseResults: []) + } else { + return ConversionResult(mainResults: whole_sentence_unique_candidates.sorted(by: {$0.value > $1.value}), firstClauseResults: []) + } + } + // モデル重みを統合 + let sentence_candidates: [Candidate] + if options.zenzaiMode.enabled { + // FIXME: もう少し良い方法はありそうだけど、短期的にかなりハックな実装にした + // candidateのvalueをZenzaiの出力順に書き換えることで、このあとのrerank処理で騙されてくれるようになっている + // より根本的には、`Candidate`にAI評価値をもたせるなどの方法が必要そう + var first5 = Array(whole_sentence_unique_candidates.prefix(5)) + var values = first5.map(\.value).sorted(by: >) + for (i, v) in zip(first5.indices, values) { + first5[i].value = v + } + sentence_candidates = first5 + } else { + sentence_candidates = whole_sentence_unique_candidates.min(count: 5, sortedBy: {$0.value > $1.value}) } - let sentence_candidates = whole_sentence_unique_candidates.min(count: 5, sortedBy: {$0.value > $1.value}) // 予測変換を最大3件作成する let prediction_candidates: [Candidate] = options.requireJapanesePrediction 
? Array(self.getUniqueCandidate(self.getPredictionCandidate(sums, composingText: inputData, options: options)).min(count: 3, sortedBy: {$0.value > $1.value})) : [] @@ -447,7 +486,7 @@ import SwiftUtils } // 文全体変換5件と予測変換3件を混ぜてベスト8を出す - let best8 = getUniqueCandidate(sentence_candidates.chained(prediction_candidates)).sorted {$0.value > $1.value} + let best8 = getUniqueCandidate(sentence_candidates.prefix(5).chained(prediction_candidates)).sorted {$0.value > $1.value} // その他のトップレベル変換(先頭に表示されうる変換候補) let toplevel_additional_candidate = self.getTopLevelAdditionalCandidate(inputData, options: options) // best8、foreign_candidates、zeroHintPrediction_candidates、toplevel_additional_candidateを混ぜて上位5件を取得する @@ -522,11 +561,19 @@ import SwiftUtils /// - N_best: 計算途中で保存する候補数。実際に得られる候補数とは異なる。 /// - Returns: /// 結果のラティスノードと、計算済みノードの全体 - private func convertToLattice(_ inputData: ComposingText, N_best: Int) -> (result: LatticeNode, nodes: [[LatticeNode]])? { + private func convertToLattice(_ inputData: ComposingText, N_best: Int, zenzaiMode: ConvertRequestOptions.ZenzaiMode) -> (result: LatticeNode, nodes: [[LatticeNode]])? 
{ if inputData.convertTarget.isEmpty { return nil } + // FIXME: enable cache based zenzai + if zenzaiMode.enabled, let model = self.getModel(modelURL: zenzaiMode.weightURL) { + let (result, nodes, cache) = self.converter.all_zenzai(inputData, zenz: model, zenzaiCache: self.zenzaiCache, inferenceLimit: zenzaiMode.inferenceLimit) + self.zenzaiCache = cache + self.previousInputData = inputData + return (result, nodes) + } + guard let previousInputData else { debug("convertToLattice: 新規計算用の関数を呼びますA") let result = converter.kana2lattice_all(inputData, N_best: N_best) @@ -621,7 +668,7 @@ import SwiftUtils // DicdataStoreにRequestOptionを通知する self.sendToDicdataStore(.setRequestOptions(options)) - guard let result = self.convertToLattice(inputData, N_best: options.N_best) else { + guard let result = self.convertToLattice(inputData, N_best: options.N_best, zenzaiMode: options.zenzaiMode) else { return ConversionResult(mainResults: [], firstClauseResults: []) } diff --git a/Sources/KanaKanjiConverterModule/DicdataStore/DicdataStore.swift b/Sources/KanaKanjiConverterModule/DicdataStore/DicdataStore.swift index c343e38..5efb0c6 100644 --- a/Sources/KanaKanjiConverterModule/DicdataStore/DicdataStore.swift +++ b/Sources/KanaKanjiConverterModule/DicdataStore/DicdataStore.swift @@ -219,7 +219,6 @@ public final class DicdataStore { } // MARK: 誤り訂正の対象を列挙する。非常に重い処理。 var stringToInfo = inputData.getRangesWithTypos(fromIndex, rightIndexRange: toIndexLeft ..< toIndexRight) - // MARK: 検索対象を列挙していく。 let stringSet = stringToInfo.keys.map {($0, $0.map(self.character2charId))} let (minCharIDsCount, maxCharIDsCount) = stringSet.lazy.map {$0.1.count}.minAndMax() ?? (0, -1) @@ -310,6 +309,73 @@ public final class DicdataStore { } } + /// kana2latticeから参照する。 + /// - Parameters: + /// - inputData: 入力データ + /// - from: 起点 + /// - toIndexRange: `from ..< (toIndexRange)`の範囲で辞書ルックアップを行う。 + public func getFrozenLOUDSDataInRange(inputData: ComposingText, from fromIndex: Int, toIndexRange: Range? 
= nil) -> [LatticeNode] { + let toIndexLeft = toIndexRange?.startIndex ?? fromIndex + let toIndexRight = min(toIndexRange?.endIndex ?? inputData.input.count, fromIndex + self.maxlength) + debug("getLOUDSDataInRange", fromIndex, toIndexRange?.description ?? "nil", toIndexLeft, toIndexRight) + if fromIndex > toIndexLeft || toIndexLeft >= toIndexRight { + debug("getLOUDSDataInRange: index is wrong") + return [] + } + + let segments = (fromIndex ..< toIndexRight).reduce(into: []) { (segments: inout [String], rightIndex: Int) in + segments.append((segments.last ?? "") + String(inputData.input[rightIndex].character.toKatakana())) + } + let character = String(inputData.input[fromIndex].character.toKatakana()) + let characterNode = LatticeNode(data: DicdataElement(word: character, ruby: character, cid: CIDData.一般名詞.cid, mid: MIDData.一般.mid, value: -10), inputRange: fromIndex ..< fromIndex + 1) + if fromIndex == .zero { + characterNode.prevs.append(.BOSNode()) + } + + // MARK: 誤り訂正なし + var stringToEndIndex = inputData.getRanges(fromIndex, rightIndexRange: toIndexLeft ..< toIndexRight) + // MARK: 検索対象を列挙していく。 + guard let (minString, maxString) = stringToEndIndex.keys.minAndMax(by: {$0.count < $1.count}) else { + return [characterNode] + } + let maxIDs = maxString.map(self.character2charId) + var keys = [String(stringToEndIndex.keys.first!.first!), "user"] + if learningManager.enabled { + keys.append("memory") + } + // MARK: 検索によって得たindicesから辞書データを実際に取り出していく + var dicdata: [DicdataElement] = [] + let depth = minString.count - 1 ..< maxString.count + for identifier in keys { + dicdata.append(contentsOf: self.getDicdataFromLoudstxt3(identifier: identifier, indices: self.throughMatchLOUDS(identifier: identifier, charIDs: maxIDs, depth: depth))) + } + if learningManager.enabled { + // temporalな学習結果にpenaltyを加えて追加する + dicdata.append(contentsOf: self.learningManager.temporaryThroughMatch(charIDs: consume maxIDs, depth: depth)) + } + for i in toIndexLeft ..< toIndexRight { + 
dicdata.append(contentsOf: self.getWiseDicdata(convertTarget: segments[i - fromIndex], inputData: inputData, inputRange: fromIndex ..< i + 1)) + dicdata.append(contentsOf: self.getMatchOSUserDict(segments[i - fromIndex])) + } + if fromIndex == .zero { + return dicdata.compactMap { + guard let endIndex = stringToEndIndex[Array($0.ruby)] else { + return nil + } + let node = LatticeNode(data: $0, inputRange: fromIndex ..< endIndex + 1) + node.prevs.append(RegisteredNode.BOSNode()) + return node + } + [characterNode] + } else { + return dicdata.compactMap { + guard let endIndex = stringToEndIndex[Array($0.ruby)] else { + return nil + } + return LatticeNode(data: $0, inputRange: fromIndex ..< endIndex + 1) + } + [characterNode] + } + } + /// kana2latticeから参照する。louds版。 /// - Parameters: /// - inputData: 入力データ @@ -727,12 +793,8 @@ public final class DicdataStore { /// wordTypesの初期化時に使うのみ。 private static let PREPOSITION_wordIDs: Set = [1315, 6, 557, 558, 559, 560] /// wordTypesの初期化時に使うのみ。 - private static let INPOSITION_wordIDs: Set = Set(Array(561..<868) - + Array(1283..<1297) - + Array(1306..<1310) - + Array(11..<53) - + Array(555..<557) - + Array(1281..<1283) + private static let INPOSITION_wordIDs: Set = Set( + Array(561..<868).chained(1283..<1297).chained(1306..<1310).chained(11..<53).chained(555..<557).chained(1281..<1283) ).union([1314, 3, 2, 4, 5, 1, 9]) /* diff --git a/Sources/KanaKanjiConverterModule/DicdataStore/TypoCorrection.swift b/Sources/KanaKanjiConverterModule/DicdataStore/TypoCorrection.swift index a4855ed..37b8919 100644 --- a/Sources/KanaKanjiConverterModule/DicdataStore/TypoCorrection.swift +++ b/Sources/KanaKanjiConverterModule/DicdataStore/TypoCorrection.swift @@ -97,6 +97,70 @@ extension ComposingText { return Dictionary(stringToInfo, uniquingKeysWith: {$0.penalty < $1.penalty ? 
$1 : $0}) } + /// closedRangeでもらう + /// 例えば`left=4, rightIndexRange=6..<10`の場合、`4...6, 4...7, 4...8, 4...9`の範囲で計算する + /// `left <= rightIndexRange.startIndex`が常に成り立つ + func getRanges(_ left: Int, rightIndexRange: Range) -> [[Character]: Int] { + let count = rightIndexRange.endIndex - left + debug("getRangesWithTypos", left, rightIndexRange, count) + let nodes = (0.. [TypoCandidate] in + let j = i + k + if count <= j { + return [] + } + return Self.getTypo(self.input[left + i ... left + j], frozen: true) + } + } + + // Performance Tuning Note:直接Dictionaryを作るのではなく、一度Arrayを作ってから最後にDictionaryに変換する方が、高速である + var stringToInfo: [([Character], Int)] = [] + + // 深さ優先で列挙する + var stack: [(convertTargetElements: [ConvertTargetElement], lastElement: InputElement, count: Int)] = nodes[0].compactMap { typoCandidate in + guard let firstElement = typoCandidate.inputElements.first else { + return nil + } + if Self.isLeftSideValid(first: firstElement, of: self.input, from: left) { + var convertTargetElements = [ConvertTargetElement]() + for element in typoCandidate.inputElements { + ComposingText.updateConvertTargetElements(currentElements: &convertTargetElements, newElement: element) + } + return (convertTargetElements, typoCandidate.inputElements.last!, typoCandidate.inputElements.count) + } + return nil + } + while var (convertTargetElements, lastElement, count) = stack.popLast() { + if rightIndexRange.contains(count + left - 1) { + if let convertTarget = ComposingText.getConvertTargetIfRightSideIsValid(lastElement: lastElement, of: self.input, to: count + left, convertTargetElements: convertTargetElements)?.map({$0.toKatakana()}) { + stringToInfo.append((convertTarget, (count + left - 1))) + } + } + // エスケープ + if nodes.endIndex <= count { + continue + } + stack.append(contentsOf: nodes[count].compactMap { + if count + $0.inputElements.count > nodes.endIndex { + return nil + } + for element in $0.inputElements { + ComposingText.updateConvertTargetElements(currentElements: 
&convertTargetElements, newElement: element) + } + if shouldBeRemovedForDicdataStore(components: convertTargetElements) { + return nil + } + return ( + convertTargetElements: convertTargetElements, + lastElement: $0.inputElements.last!, + count: count + $0.inputElements.count + ) + }) + } + return Dictionary(stringToInfo, uniquingKeysWith: {$0 < $1 ? $1 : $0}) + } + + func getRangeWithTypos(_ left: Int, _ right: Int) -> [[Character]: PValue] { // 各iから始まる候補を列挙する // 例えばinput = [d(あ), r(s), r(i), r(t), r(s), d(は), d(は), d(れ)]の場合 @@ -178,19 +242,20 @@ extension ComposingText { return Dictionary(stringToPenalty, uniquingKeysWith: max) } - private static func getTypo(_ elements: some Collection) -> [TypoCandidate] { + private static func getTypo(_ elements: some Collection, frozen: Bool = false) -> [TypoCandidate] { let key = elements.reduce(into: "") {$0.append($1.character)}.toKatakana() if (elements.allSatisfy {$0.inputStyle == .direct}) { + let dictionary: [String: [TypoUnit]] = frozen ? [:] : Self.directPossibleTypo if key.count > 1 { - return Self.directPossibleTypo[key, default: []].map { + return dictionary[key, default: []].map { TypoCandidate( inputElements: $0.value.map {InputElement(character: $0, inputStyle: .direct)}, weight: $0.weight ) } } else if key.count == 1 { - var result = Self.directPossibleTypo[key, default: []].map { + var result = dictionary[key, default: []].map { TypoCandidate( inputElements: $0.value.map {InputElement(character: $0, inputStyle: .direct)}, weight: $0.weight @@ -202,15 +267,16 @@ extension ComposingText { } } if (elements.allSatisfy {$0.inputStyle == .roman2kana}) { + let dictionary: [String: [String]] = frozen ? 
[:] : Self.roman2KanaPossibleTypo
             if key.count > 1 {
-                return Self.roman2KanaPossibleTypo[key, default: []].map {
+                return dictionary[key, default: []].map {
                     TypoCandidate(
                         inputElements: $0.map {InputElement(character: $0, inputStyle: .roman2kana)},
                         weight: 3.5
                     )
                 }
             } else if key.count == 1 {
-                var result = Self.roman2KanaPossibleTypo[key, default: []].map {
+                var result = dictionary[key, default: []].map {
                     TypoCandidate(
                         inputElements: $0.map {InputElement(character: $0, inputStyle: .roman2kana)},
                         weight: 3.5
diff --git a/Sources/KanaKanjiConverterModule/Kana2Kanji/all_with_prefix_constraint.swift b/Sources/KanaKanjiConverterModule/Kana2Kanji/all_with_prefix_constraint.swift
new file mode 100644
index 0000000..230f362
--- /dev/null
+++ b/Sources/KanaKanjiConverterModule/Kana2Kanji/all_with_prefix_constraint.swift
@@ -0,0 +1,100 @@
+import Foundation
+import SwiftUtils
+
+extension Kana2Kanji {
+    /// Builds the conversion lattice for the whole input from scratch, keeping only paths whose
+    /// surface text is compatible with a required prefix.
+    /// - Parameters:
+    ///   - inputData: The composing text to convert.
+    ///   - N_best: Maximum number of previous paths kept on each intermediate node.
+    ///     Paths registered on the returned result node are not capped by this value.
+    ///   - constraint: Text that every completed candidate must start with.
+    /// - Returns:
+    ///   The EOS node whose `prevs` hold the completed paths that start with `constraint`,
+    ///   and the lattice nodes grouped by their start index.
+    /// ### Implementation outline
+    /// (1) Enumerate the nodes starting at each input index with `getFrozenLOUDSDataInRange`.
+    ///
+    /// (2) For every node that already has registered previous paths, compute the value of
+    /// each path ending at that node.
+    ///
+    /// (3) A node that reaches the end of the input is registered on the result node when its
+    /// text starts with `constraint`; any other node is registered on the nodes that follow it,
+    /// keeping at most `N_best` paths in descending order of value and dropping paths whose
+    /// text contradicts `constraint`.
+    ///
+    /// (4) Return the result node together with the lattice.
+    func kana2lattice_all_with_prefix_constraint(_ inputData: ComposingText, N_best: Int, constraint: String) -> (result: LatticeNode, nodes: Nodes) {
+        debug("新規に計算を行います。inputされた文字列は\(inputData.input.count)文字分の\(inputData.convertTarget)。制約は\(constraint)")
+        let count: Int = inputData.input.count
+        let result: LatticeNode = LatticeNode.EOSNode
+        let nodes: [[LatticeNode]] = (.zero ..< count).map {dicdataStore.getFrozenLOUDSDataInRange(inputData: inputData, from: $0)}
+        // For the nodes that start at the i-th input element
+        for (i, nodeArray) in nodes.enumerated() {
+            // For each of those nodes
+            for node in nodeArray {
+                // A node that no path has reached yet cannot contribute anything.
+                if node.prevs.isEmpty {
+                    continue
+                }
+                if self.dicdataStore.shouldBeRemoved(data: node.data) {
+                    continue
+                }
+                // Get the word's own score.
+                let wValue: PValue = node.data.value()
+                if i == 0 {
+                    // Update values; at index 0 the connection cost to each prev is added here,
+                    // whereas for later nodes it was already added when the path was registered below.
+                    node.values = node.prevs.map {$0.totalValue + wValue + self.dicdataStore.getCCValue($0.data.rcid, node.data.lcid)}
+                } else {
+                    // Update values
+                    node.values = node.prevs.map {$0.totalValue + wValue}
+                }
+                // Index of the input element right after this node
+                let nextIndex: Int = node.inputRange.endIndex
+                // Register on the result node when the node reaches the end of the input
+                if nextIndex == count {
+                    for index in node.prevs.indices {
+                        let newnode: RegisteredNode = node.getRegisteredNode(index, value: node.values[index])
+                        let text = newnode.getCandidateData().data.reduce(into: "") { $0.append(contentsOf: $1.word)} + node.data.word
+                        if text.hasPrefix(constraint) {
+                            result.prevs.append(newnode)
+                        }
+                    }
+                } else {
+                    // Surface text of each path ending at this node; candidates[k] pairs with node.values[k].
+                    let candidates = node.getCandidateData().map {
+                        $0.data.reduce(into: "") { $0.append(contentsOf: $1.word)} + node.data.word
+                    }
+                    // For every nextnode that can follow this node
+                    for nextnode in nodes[nextIndex] {
+                        // It is best to run this check here; nodes that end up with no registration
+                        // are finally rejected by the emptiness check when they are visited.
+                        if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
+                            continue
+                        }
+                        // Compute the class connection cost.
+                        let ccValue: PValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
+                        // For every previous path held by this node
+                        for (index, value) in node.values.enumerated() {
+                            // The path must at least not contradict the constraint:
+                            // the common prefix must equal either the text or the constraint.
+                            //   constraint AB, text ABC (OK)
+                            //   constraint AB, text A   (OK)
+                            //   constraint AB, text AC  (NG)
+                            let text = candidates[index] + nextnode.data.word
+                            if !text.hasPrefix(constraint) && !constraint.hasPrefix(text) {
+                                continue
+                            }
+                            let newValue: PValue = ccValue + value
+                            // Find the index at which the new path should be inserted
+                            let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
+                            if lastindex == N_best {
+                                continue
+                            }
+                            let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
+                            // Drop the last entry when the count would exceed N_best
+                            if nextnode.prevs.count >= N_best {
+                                nextnode.prevs.removeLast()
+                            }
+                            // Removing before inserting is faster (insert is O(N))
+                            nextnode.prevs.insert(newnode, at: lastindex)
+                        }
+                    }
+                }
+            }
+        }
+        return (result: result, nodes: nodes)
+    }
+
+}
diff --git a/Sources/KanaKanjiConverterModule/Kana2Kanji/zenzai.swift b/Sources/KanaKanjiConverterModule/Kana2Kanji/zenzai.swift
new file mode 100644
index 0000000..4525bb3
--- /dev/null
+++ b/Sources/KanaKanjiConverterModule/Kana2Kanji/zenzai.swift
@@ -0,0 +1,139 @@
+import Foundation
+import SwiftUtils
+
+extension Kana2Kanji {
+    /// Outcome of one `all_zenzai` run, kept so that the next run can start from a
+    /// non-empty prefix constraint.
+    struct ZenzaiCache: Sendable {
+        init(_ inputData: ComposingText, constraint: String, satisfyingCandidate: Candidate?) {
+            self.inputData = inputData
+            self.prefixConstraint = constraint
+            self.satisfyingCandidate = satisfyingCandidate
+        }
+
+        // Constraint that was in effect when the cache was created.
+        private var prefixConstraint: String
+        // Candidate stored together with the constraint, if any.
+        private var satisfyingCandidate: Candidate?
+        // Input the cache was created for.
+        private var inputData: ComposingText
+
+        /// Derives the initial constraint for `newInputData` from the cached state.
+        ///
+        /// - With a stored candidate: walks its data items and, for each item whose `ruby` is a
+        ///   prefix of the not-yet-consumed katakana input, appends the item's `word` and consumes
+        ///   that ruby.
+        /// - Without one: reuses the cached constraint when the new convert target starts with
+        ///   the cached convert target, and returns an empty string otherwise.
+        func getNewConstraint(for newInputData: ComposingText) -> String {
+            if let satisfyingCandidate {
+                var current = newInputData.convertTarget.toKatakana()[...]
+                var constraint = ""
+                // NOTE(review): an item whose ruby does not match is skipped rather than ending
+                // the scan, so later items can still be appended — confirm this is intended.
+                for item in satisfyingCandidate.data {
+                    if current.hasPrefix(item.ruby) {
+                        constraint += item.word
+                        current = current.dropFirst(item.ruby.count)
+                    }
+                }
+                return constraint
+            } else if newInputData.convertTarget.hasPrefix(inputData.convertTarget) {
+                return self.prefixConstraint
+            } else {
+                return ""
+            }
+        }
+    }
+
+    /// Full conversion using the zenzai system.
+    ///
+    /// Repeatedly drafts candidates with `kana2lattice_all_with_prefix_constraint`, asks `zenz`
+    /// to review the best draft, and updates the prefix constraint from the review until a
+    /// return condition is met.
+    /// - Parameters:
+    ///   - inputData: The composing text to convert.
+    ///   - zenz: The model wrapper used to review candidates.
+    ///   - zenzaiCache: Cache from a previous call, used to seed the constraint.
+    ///   - inferenceLimit: Number of `candidateEvaluate` calls allowed before the current
+    ///     candidate is returned as is.
+    /// - Returns: The EOS node holding the reviewed paths, the lattice of the first draft,
+    ///   and the cache for the next call.
+    @MainActor func all_zenzai(_ inputData: ComposingText, zenz: Zenz, zenzaiCache: ZenzaiCache?, inferenceLimit: Int) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
+        var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? ""
+        print("initial constraint", constraint)
+        let eosNode = LatticeNode.EOSNode
+        var nodes: Kana2Kanji.Nodes = []
+        var inferenceLimit = inferenceLimit
+        while true {
+            // Experiments showed that taking the 2-best here gives the best average speed, so we do so.
+            let start = Date()
+            let draftResult = self.kana2lattice_all_with_prefix_constraint(inputData, N_best: 2, constraint: constraint)
+            if nodes.isEmpty {
+                // First iteration only
+                nodes = draftResult.nodes
+            }
+            let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
+            // Pick the candidate with the highest value; on ties the earliest one wins.
+            var best: (Int, Candidate)? = nil
+            for (i, cand) in candidates.enumerated() {
+                if let (_, c) = best, cand.value > c.value {
+                    best = (i, cand)
+                } else if best == nil {
+                    best = (i, cand)
+                }
+            }
+            guard var (index, candidate) = best else {
+                print("best was not found!")
+                // No candidate was produced.
+                // When the constraint cannot be satisfied, ignore it.
+                return (eosNode, nodes, ZenzaiCache(inputData, constraint: "", satisfyingCandidate: nil))
+            }
+            print("Constrained draft modeling", -start.timeIntervalSinceNow)
+            reviewLoop: while true {
+                // Update the results
+                eosNode.prevs.insert(draftResult.result.prevs[index], at: 0)
+                if inferenceLimit == 0 {
+                    print("inference limit! 
\(candidate.text) is used for excuse")
+                    // When inference occurs more than maximum times, then just return result at this point
+                    return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
+                }
+                let reviewResult = zenz.candidateEvaluate(convertTarget: inputData.convertTarget, candidates: [candidate])
+                inferenceLimit -= 1
+                let nextAction = self.review(
+                    candidateIndex: index,
+                    candidates: candidates,
+                    reviewResult: reviewResult,
+                    constraint: &constraint
+                )
+                switch nextAction {
+                case .return(let constraint, let satisfied):
+                    // Only a candidate that passed the review is stored in the cache.
+                    if satisfied {
+                        return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
+                    } else {
+                        return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: nil))
+                    }
+                case .continue:
+                    // Leave the review loop and draft again with the updated constraint.
+                    break reviewLoop
+                case .retry(let candidateIndex):
+                    // Review another candidate of the same draft without drafting again.
+                    index = candidateIndex
+                    candidate = candidates[candidateIndex]
+                }
+            }
+        }
+    }
+
+    /// What `all_zenzai` should do after one review.
+    private enum NextAction {
+        /// Stop and return, caching `constraint`; `satisfied` tells whether the reviewed candidate passed.
+        case `return`(constraint: String, satisfied: Bool)
+        /// Draft again with the updated constraint.
+        case `continue`
+        /// Review the candidate at `candidateIndex` of the current draft.
+        case `retry`(candidateIndex: Int)
+    }
+
+    /// Translates a review result into the next action and updates `constraint` when the
+    /// review produced a new one.
+    /// - Parameters:
+    ///   - candidateIndex: Index in `candidates` of the candidate that was reviewed.
+    ///   - candidates: All candidates of the current draft.
+    ///   - reviewResult: Result returned by the model for the reviewed candidate.
+    ///   - constraint: Current prefix constraint; replaced when the review requires a fix.
+    private func review(
+        candidateIndex: Int,
+        candidates: [Candidate],
+        reviewResult: consuming ZenzContext.CandidateEvaluationResult,
+        constraint: inout String
+    ) -> NextAction {
+        switch reviewResult {
+        case .error:
+            // Some error occurred
+            print("error")
+            return .return(constraint: constraint, satisfied: false)
+        case .pass(let score):
+            // Passed
+            print("passed:", score)
+            return .return(constraint: constraint, satisfied: true)
+        case .fixRequired(let prefixConstraint):
+            // Give up when the same constraint comes back twice in a row
+            if constraint == prefixConstraint {
+                print("same constraint:", prefixConstraint)
+                return .return(constraint: "", satisfied: false)
+            }
+            // A new constraint was obtained, so update it
+            print("update constraint:", prefixConstraint)
+            constraint = prefixConstraint
+            // If another candidate already satisfies the constraint, reviewing it directly
+            // saves a round of drafting.
+            for i in candidates.indices where i != candidateIndex {
+                if candidates[i].text.hasPrefix(prefixConstraint) {
+                    print("found \(candidates[i].text) as another retry")
+                    return .retry(candidateIndex: i)
+                }
+            }
+            return .continue
+        }
+    }
+}
diff --git a/Sources/KanaKanjiConverterModule/Zenz/Zenz.swift b/Sources/KanaKanjiConverterModule/Zenz/Zenz.swift
new file mode 100644
index 0000000..310e875
--- /dev/null
+++ b/Sources/KanaKanjiConverterModule/Zenz/Zenz.swift
@@ -0,0 +1,43 @@
+import Foundation
+import SwiftUtils
+
+/// Main-actor wrapper that owns the `ZenzContext` loaded from a model file.
+@MainActor final class Zenz {
+    package var resourceURL: URL
+    private var zenzContext: ZenzContext?
+
+    /// Loads the model stored at `resourceURL`.
+    /// - Parameter resourceURL: File URL of the model weights.
+    /// - Throws: Whatever `ZenzContext.createContext(path:)` throws when loading fails.
+    init(resourceURL: URL) throws {
+        self.resourceURL = resourceURL
+        #if canImport(Darwin)
+        // `URL.path(percentEncoded:)` requires iOS 16 / macOS 13 / tvOS 16 / watchOS 9;
+        // checking for an older OS here makes the call unavailable at compile time.
+        if #available(iOS 16, macOS 13, tvOS 16, watchOS 9, *) {
+            self.zenzContext = try ZenzContext.createContext(path: resourceURL.path(percentEncoded: false))
+        } else {
+            // this is not percent-encoded
+            self.zenzContext = try ZenzContext.createContext(path: resourceURL.path)
+        }
+        #else
+        // this is not percent-encoded
+        self.zenzContext = try ZenzContext.createContext(path: resourceURL.path)
+        #endif
+        debug("Loaded model \(resourceURL.lastPathComponent)")
+    }
+
+    func startSession() {}
+
+    func endSession() {
+        try? 
self.zenzContext?.reset_context()
+    }
+
+    /// Reviews the first element of `candidates` against the model.
+    /// - Parameters:
+    ///   - convertTarget: The reading being converted; it is converted to katakana before evaluation.
+    ///   - candidates: Candidates to review. Only the first one is evaluated.
+    /// - Returns: The evaluation result, or `.error` when no model is loaded or `candidates` is empty.
+    func candidateEvaluate(convertTarget: String, candidates: [Candidate]) -> ZenzContext.CandidateEvaluationResult {
+        guard let zenzContext, let candidate = candidates.first else {
+            return .error
+        }
+        return zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate.text)
+    }
+}
diff --git a/Sources/KanaKanjiConverterModule/Zenz/ZenzContext.swift b/Sources/KanaKanjiConverterModule/Zenz/ZenzContext.swift
new file mode 100644
index 0000000..f0e3120
--- /dev/null
+++ b/Sources/KanaKanjiConverterModule/Zenz/ZenzContext.swift
@@ -0,0 +1,233 @@
+import llama
+import SwiftUtils
+import Foundation
+
+/// Errors thrown while loading the model or creating its context.
+enum ZenzError: LocalizedError {
+    case couldNotLoadModel(path: String)
+    case couldNotLoadContext
+
+    var errorDescription: String? {
+        switch self {
+        case .couldNotLoadContext: "failed to load context"
+        case .couldNotLoadModel(path: let path): "could not load model weight at \(path)"
+        }
+    }
+}
+
+/// Owns a llama model and its context and releases both when deallocated.
+class ZenzContext {
+    private var model: OpaquePointer
+    private var context: OpaquePointer
+    private var prevInput: [llama_token] = []
+
+    private let n_len: Int32 = 512
+
+    /// Takes ownership of `model` and `context`; both are freed in `deinit`.
+    init(model: OpaquePointer, context: OpaquePointer) {
+        self.model = model
+        self.context = context
+    }
+
+    deinit {
+        llama_free(context)
+        llama_free_model(model)
+        llama_backend_free()
+    }
+
+    /// Context parameters used for every context: fixed seed, 512-token context and batch,
+    /// and a thread count of `processorCount - 2` clamped to 1...8.
+    private static var ctx_params: llama_context_params {
+        let n_threads = max(1, min(8, ProcessInfo.processInfo.processorCount - 2))
+        debug("Using \(n_threads) threads")
+        var ctx_params = llama_context_default_params()
+        ctx_params.seed = 1234
+        ctx_params.n_ctx = 512
+        ctx_params.n_threads = UInt32(n_threads)
+        ctx_params.n_threads_batch = UInt32(n_threads)
+        ctx_params.n_batch = 512
+        return ctx_params
+    }
+
+    /// Loads the model at `path` and creates a context for it.
+    /// - Parameter path: File system path of the model weights (not percent-encoded).
+    /// - Throws: `ZenzError.couldNotLoadModel` when the model cannot be loaded,
+    ///   `ZenzError.couldNotLoadContext` when the context cannot be created.
+    static func createContext(path: String) throws -> ZenzContext {
+        llama_backend_init()
+        var model_params = llama_model_default_params()
+        model_params.use_mmap = true
+        let model = llama_load_model_from_file(path, model_params)
+        guard let model else {
+            debug("Could not load model at \(path)")
+            throw ZenzError.couldNotLoadModel(path: path)
+        }
+
+        let context = llama_new_context_with_model(model, ctx_params)
+        guard let context else {
+            debug("Could not load context!")
+            // No ZenzContext takes ownership on this path, so the model must be released here.
+            llama_free_model(model)
+            throw ZenzError.couldNotLoadContext
+        }
+
+        return ZenzContext(model: model, context: context)
+    }
+
+    /// Replaces the current context with a freshly created one.
+    ///
+    /// The replacement is created before the old context is freed, so a failure leaves
+    /// `self.context` pointing at a live context instead of freed memory.
+    /// - Throws: `ZenzError.couldNotLoadContext` when the new context cannot be created;
+    ///   the existing context stays in use in that case.
+    func reset_context() throws {
+        let context = llama_new_context_with_model(self.model, Self.ctx_params)
+        guard let context else {
+            debug("Could not load context!")
+            throw ZenzError.couldNotLoadContext
+        }
+        llama_free(self.context)
+        self.context = context
+        // The new context starts with an empty KV cache, so no earlier input is cached.
+        self.prevInput = []
+    }
+
+    /// Decodes `tokens` and returns the logits buffer of the context.
+    /// - Parameters:
+    ///   - tokens: Tokens to decode, placed at positions `0 ..< tokens.count` of sequence 0.
+    ///   - logits_start_index: Logits are requested for tokens at this index and later.
+    /// - Returns: The pointer returned by `llama_get_logits`, or `nil` when the tokens do not
+    ///   fit in the batch or decoding fails.
+    private func get_logits(tokens: [llama_token], logits_start_index: Int = 0) -> UnsafeMutablePointer<Float>? {
+        // manage kv_cache
+        do {
+            let commonTokens = self.prevInput.commonPrefix(with: tokens)
+            llama_kv_cache_seq_rm(context, 0, llama_pos(commonTokens.count), -1)
+        }
+        let batchCapacity: Int32 = 512
+        // The batch has room for `batchCapacity` tokens only; adding more would write past its buffers.
+        guard tokens.count <= Int(batchCapacity) else {
+            debug("error: \(tokens.count) tokens do not fit in a batch of \(batchCapacity)")
+            return nil
+        }
+        var batch = llama_batch_init(batchCapacity, 0, 1)
+        // llama_batch_init allocates; release the batch on every exit path.
+        defer {
+            llama_batch_free(batch)
+        }
+        let n_ctx = llama_n_ctx(context)
+        let n_kv_req = tokens.count + (Int(n_len) - tokens.count)
+        if n_kv_req > n_ctx {
+            debug("error: n_kv_req > n_ctx, the required KV cache size is not big enough")
+        }
+        for i in tokens.indices {
+            llama_batch_add(&batch, tokens[i], Int32(i), [0], logits: logits_start_index <= i)
+        }
+        // Run the model on the batch
+        if llama_decode(context, batch) != 0 {
+            debug("llama_decode() failed")
+            return nil
+        }
+        return llama_get_logits(context)
+    }
+
+    /// Sums the log-probabilities the model assigns to the tokens of `text`.
+    /// - Parameters:
+    ///   - text: Text to score; it is tokenized with BOS and EOS added.
+    ///   - ignorePrompt: Leading text whose tokens are excluded from the sum. When empty,
+    ///     only the first token is excluded.
+    /// - Returns: The summed log-probability, or `.nan` when logits are unavailable.
+    func evaluate(text: String, ignorePrompt: String = "") -> Float {
+        let tokens_list = self.tokenize(text: text, add_bos: true, add_eos: true)
+        guard let logits = self.get_logits(tokens: tokens_list) else {
+            debug("logits unavailable")
+            return .nan
+        }
+        let tokenizedPromptCount = ignorePrompt.isEmpty ? 
1 : tokenize(text: ignorePrompt, add_bos: true, add_eos: false).count
+        let n_vocab = llama_n_vocab(model)
+
+        var sum: Float = 0
+        // Ignore the leading prompt part
+        for (i, token_id) in tokens_list.indexed().dropFirst(tokenizedPromptCount) {
+            // FIXME: there can be more efficient implementations, possibly using Accelerate or other frameworks.
+            // Row (i - 1) of `logits` is read as the distribution for token i:
+            // log p(token_id) = logit[token_id] - log(sum(exp(logit))).
+            // NOTE(review): exp() is applied without subtracting the row maximum, so very large
+            // logits could overflow to infinity — confirm the model's logit range.
+            var log_prob: Float = 0
+            for index in ((i - 1) * Int(n_vocab)) ..< (i * Int(n_vocab)) {
+                log_prob += exp(logits[index])
+            }
+            log_prob = log(log_prob)
+            log_prob = logits[Int((i - 1) * Int(n_vocab) + Int(token_id))] - log_prob
+            sum += log_prob
+        }
+        return sum
+    }
+
+    /// Result of reviewing one candidate with the model.
+    enum CandidateEvaluationResult: Sendable, Equatable, Hashable {
+        /// The review could not be performed.
+        case error
+        /// The candidate was accepted; `score` is the value reported by the review.
+        case pass(score: Float)
+        /// The candidate was rejected; `prefixConstraint` is the prefix the next candidate must start with.
+        case fixRequired(prefixConstraint: String)
+    }
+
+    /// Reviews `candidate` as a conversion of `input` by comparing each candidate token with
+    /// the model's prediction at that position.
+    func evaluate_candidate(input: String, candidate: String) -> CandidateEvaluationResult {
+        // For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
+        // We assume \u{EE01}\(candidate) is always split into \u{EE01}_\(candidate) by zenz-v1 tokenizer
+        let prompt = "\u{EE00}\(input)\u{EE01}"
+        // Therefore, tokens = prompt_tokens + candidate_tokens is an appropriate operation.
+ let prompt_tokens = self.tokenize(text: prompt, add_bos: true, add_eos: false) + let candidate_tokens = self.tokenize(text: candidate, add_bos: false, add_eos: false) + let tokens = prompt_tokens + candidate_tokens + let startOffset = prompt_tokens.count - 1 + let pos_max = llama_kv_cache_seq_pos_max(self.context, 0) + print("pos max:", pos_max) + guard let logits = self.get_logits(tokens: tokens, logits_start_index: startOffset) else { + debug("logits unavailable") + return .error + } + let n_vocab = llama_n_vocab(model) + + var score: Float = 0 + for (i, token_id) in tokens.indexed().dropFirst(prompt_tokens.count) { + // それぞれのトークンが、一つ前の予測において最も確率の高いトークンであるかをチェックする + // softmaxはmaxなので、単にlogitsの中で最も大きいものを選べば良い + // 一方実用的にはlog_probも得ておきたい。このため、ここでは明示的にsoftmaxも計算している + var exp_sum: Float = 0 + var max_token: llama_token = 0 + var max_exp: Float = .infinity * -1 + let startIndex = (i - 1 - startOffset) * Int(n_vocab) + let endIndex = (i - startOffset) * Int(n_vocab) + for index in startIndex ..< endIndex { + let v = exp(logits[index]) + exp_sum += v + if max_exp < v { + max_exp = v + max_token = llama_token(index - startIndex) + } + } + // ここで最も良い候補であったかをチェックする + if max_token != token_id { + var cchars = tokens[.. [llama_token] { + let text = text.lowercased() + let utf8Count = text.utf8.count + let n_tokens = utf8Count + (add_bos ? 1 : 0) + let tokens = UnsafeMutablePointer.allocate(capacity: n_tokens) + let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false) + var swiftTokens: [llama_token] = if tokenCount < 0 { + [llama_token_bos(model)] + } else { + (0.. 
[CChar] { + let result = UnsafeMutablePointer.allocate(capacity: 8) + result.initialize(repeating: Int8(0), count: 8) + defer { + result.deallocate() + } + let nTokens = llama_token_to_piece(model, token, result, 8, false) + + if nTokens < 0 { + let newResult = UnsafeMutablePointer.allocate(capacity: Int(-nTokens)) + newResult.initialize(repeating: Int8(0), count: Int(-nTokens)) + defer { + newResult.deallocate() + } + let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, false) + let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens)) + return Array(bufferPointer) + } else { + let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nTokens)) + return Array(bufferPointer) + } + } +} diff --git a/Sources/KanaKanjiConverterModuleWithDefaultDictionary/KanaKanjiConverterModuleWithDefaultDictionary.swift b/Sources/KanaKanjiConverterModuleWithDefaultDictionary/KanaKanjiConverterModuleWithDefaultDictionary.swift index 43f4c74..131dabc 100644 --- a/Sources/KanaKanjiConverterModuleWithDefaultDictionary/KanaKanjiConverterModuleWithDefaultDictionary.swift +++ b/Sources/KanaKanjiConverterModuleWithDefaultDictionary/KanaKanjiConverterModuleWithDefaultDictionary.swift @@ -17,6 +17,7 @@ public extension ConvertRequestOptions { shouldResetMemory: Bool = false, memoryDirectoryURL: URL, sharedContainerURL: URL, + zenzaiMode: ZenzaiMode = .off, textReplacer: TextReplacer = TextReplacer(), metadata: ConvertRequestOptions.Metadata? 
) -> Self { @@ -44,6 +45,7 @@ public extension ConvertRequestOptions { memoryDirectoryURL: memoryDirectoryURL, sharedContainerURL: sharedContainerURL, textReplacer: textReplacer, + zenzaiMode: zenzaiMode, metadata: metadata ) } diff --git a/Sources/KanaKanjiConverterModuleWithDefaultDictionary/azooKey_dictionary_storage b/Sources/KanaKanjiConverterModuleWithDefaultDictionary/azooKey_dictionary_storage index 43b855b..2eb6781 160000 --- a/Sources/KanaKanjiConverterModuleWithDefaultDictionary/azooKey_dictionary_storage +++ b/Sources/KanaKanjiConverterModuleWithDefaultDictionary/azooKey_dictionary_storage @@ -1 +1 @@ -Subproject commit 43b855b2b6275db3b94fbb51552fc039ee5f156b +Subproject commit 2eb6781de9d66343fe5ab65ec9f282e8db065f91 diff --git a/Tests/KanaKanjiConverterModuleTests/LOUDSTests.swift b/Tests/KanaKanjiConverterModuleTests/LOUDSTests.swift index f214379..da463c6 100644 --- a/Tests/KanaKanjiConverterModuleTests/LOUDSTests.swift +++ b/Tests/KanaKanjiConverterModuleTests/LOUDSTests.swift @@ -10,7 +10,7 @@ import XCTest final class LOUDSTests: XCTestCase { - static var resourceURL = Bundle.module.resourceURL!.standardizedFileURL.appendingPathComponent("DictionaryMock", isDirectory: true) + static let resourceURL = Bundle.module.resourceURL!.standardizedFileURL.appendingPathComponent("DictionaryMock", isDirectory: true) func requestOptions() -> ConvertRequestOptions { var options: ConvertRequestOptions = .default options.dictionaryResourceURL = Self.resourceURL diff --git a/Tests/KanaKanjiConverterModuleWithDefaultDictionaryTests/ConverterTests/ConverterTests.swift b/Tests/KanaKanjiConverterModuleWithDefaultDictionaryTests/ConverterTests/ConverterTests.swift index b089f00..90cbe1e 100644 --- a/Tests/KanaKanjiConverterModuleWithDefaultDictionaryTests/ConverterTests/ConverterTests.swift +++ b/Tests/KanaKanjiConverterModuleWithDefaultDictionaryTests/ConverterTests/ConverterTests.swift @@ -19,8 +19,8 @@ final class ConverterTests: XCTestCase { func 
requestOptions() -> ConvertRequestOptions { .withDefaultDictionary( - N_best: 5, - requireJapanesePrediction: true, + N_best: 10, + requireJapanesePrediction: false, requireEnglishPrediction: false, keyboardLanguage: .ja_JP, typographyLetterCandidate: false, @@ -38,46 +38,41 @@ final class ConverterTests: XCTestCase { } func testFullConversion() async throws { - await MainActor.run { - do { - let converter = KanaKanjiConverter() - var c = ComposingText() - c.insertAtCursorPosition("あずーきーはしんじだいのきーぼーどあぷりです", inputStyle: .direct) - let results = converter.requestCandidates(c, options: requestOptions()) - XCTAssertEqual(results.mainResults.first?.text, "azooKeyは新時代のキーボードアプリです") - } - do { - let converter = KanaKanjiConverter() - var c = ComposingText() - c.insertAtCursorPosition("ようしょうきからてにすすいえいやきゅうしょうりんじけんぽうなどさまざまなすぽーつをけいけんしながらそだちしょうがっこうじだいはろさんぜるすきんこうにたいざいしておりごるふやてにすをならっていた", inputStyle: .direct) - let results = converter.requestCandidates(c, options: requestOptions()) + do { + let converter = await KanaKanjiConverter() + var c = ComposingText() + c.insertAtCursorPosition("あずーきーはしんじだいのきーぼーどあぷりです", inputStyle: .direct) + let results = await converter.requestCandidates(c, options: requestOptions()) + XCTAssertEqual(results.mainResults.first?.text, "azooKeyは新時代のキーボードアプリです") + } + do { + let converter = await KanaKanjiConverter() + var c = ComposingText() + c.insertAtCursorPosition("ようしょうきからてにすすいえいやきゅうしょうりんじけんぽうなどさまざまなすぽーつをけいけんしながらそだちしょうがっこうじだいはろさんぜるすきんこうにたいざいしておりごるふやてにすをならっていた", inputStyle: .direct) + let results = await converter.requestCandidates(c, options: requestOptions()) + XCTAssertEqual(results.mainResults.first?.text, "幼少期からテニス水泳野球少林寺拳法など様々なスポーツを経験しながら育ち小学校時代はロサンゼルス近郊に滞在しておりゴルフやテニスを習っていた") + } + } + + // 1文字ずつ変換する + // memo: 内部実装としては別のモジュールが呼ばれるのだが、それをテストする方法があまりないかもしれない + func testGradualConversion() async throws { + let converter = await KanaKanjiConverter() + var c = ComposingText() + let text = 
"ようしょうきからてにすすいえいやきゅうしょうりんじけんぽうなどさまざまなすぽーつをけいけんしながらそだちしょうがっこうじだいはろさんぜるすきんこうにたいざいしておりごるふやてにすをならっていた" + for char in text { + c.insertAtCursorPosition(String(char), inputStyle: .direct) + let results = await converter.requestCandidates(c, options: requestOptions()) + if c.input.count == text.count { XCTAssertEqual(results.mainResults.first?.text, "幼少期からテニス水泳野球少林寺拳法など様々なスポーツを経験しながら育ち小学校時代はロサンゼルス近郊に滞在しておりゴルフやテニスを習っていた") } } } - // 1文字ずつ変換する - // memo: 内部実装としては別のモジュールが呼ばれるのだが、それをテストする方法があまりないかもしれない - func testGradualConversion() async throws { - await MainActor.run { - let converter = KanaKanjiConverter() - var c = ComposingText() - let text = "ようしょうきからてにすすいえいやきゅうしょうりんじけんぽうなどさまざまなすぽーつをけいけんしながらそだちしょうがっこうじだいはろさんぜるすきんこうにたいざいしておりごるふやてにすをならっていた" - for char in text { - c.insertAtCursorPosition(String(char), inputStyle: .direct) - let results = converter.requestCandidates(c, options: requestOptions()) - if c.input.count == text.count { - XCTAssertEqual(results.mainResults.first?.text, "幼少期からテニス水泳野球少林寺拳法など様々なスポーツを経験しながら育ち小学校時代はロサンゼルス近郊に滞在しておりゴルフやテニスを習っていた") - } - } - } - } - // 1文字ずつ変換する // memo: 内部実装としては別のモジュールが呼ばれるのだが、それをテストする方法があまりないかもしれない func testRoman2KanaGradualConversion() async throws { - await MainActor.run { - let converter = KanaKanjiConverter() + let converter = await KanaKanjiConverter() var c = ComposingText() let text = "youshoukikaratenisusuieiyakyuushourinjikenpounadosamazamanasupoーtuwokeikennsinagarasodatishougakkouzidaiharosanzerusukinkounitaizaisiteorigoruhuyatenisuwonaratteita" // 許容される変換結果 @@ -87,19 +82,17 @@ final class ConverterTests: XCTestCase { ] for char in text { c.insertAtCursorPosition(String(char), inputStyle: .roman2kana) - let results = converter.requestCandidates(c, options: requestOptions()) + let results = await converter.requestCandidates(c, options: requestOptions()) if c.input.count == text.count { XCTAssertTrue(possibles.contains(results.mainResults.first!.text)) } - } } } // 2,3文字ずつ変換する // memo: 
内部実装としては別のモジュールが呼ばれるのだが、それをテストする方法があまりないかもしれない func testSemiGradualConversion() async throws { - await MainActor.run { - let converter = KanaKanjiConverter() + let converter = await KanaKanjiConverter() var c = ComposingText() let text = "ようしょうきからてにすすいえいやきゅうしょうりんじけんぽうなどさまざまなすぽーつをけいけんしながらそだちしょうがっこうじだいはろさんぜるすきんこうにたいざいしておりごるふやてにすをならっていた" var leftIndex = text.startIndex @@ -110,44 +103,40 @@ final class ConverterTests: XCTestCase { let rightIndex = text.index(leftIndex, offsetBy: count, limitedBy: text.endIndex) ?? text.endIndex let prefix = String(text[leftIndex ..< rightIndex]) c.insertAtCursorPosition(prefix, inputStyle: .direct) - let results = converter.requestCandidates(c, options: requestOptions()) + let results = await converter.requestCandidates(c, options: requestOptions()) leftIndex = rightIndex if rightIndex == text.endIndex { XCTAssertEqual(results.mainResults.first?.text, "幼少期からテニス水泳野球少林寺拳法など様々なスポーツを経験しながら育ち小学校時代はロサンゼルス近郊に滞在しておりゴルフやテニスを習っていた") } } - } } // 1文字ずつ入力するが、時折削除を行う // memo: 内部実装としてはdeleted_last_nのテストを意図している func testGradualConversionWithDelete() async throws { - await MainActor.run { - let converter = KanaKanjiConverter() + let converter = await KanaKanjiConverter() var c = ComposingText() let text = Array("ようしょうきからてにすすいえいやきゅうしょうりんじけんぽうなどさまざまなすぽーつをけいけんしながらそだちしょうがっこうじだいはろさんぜるすきんこうにたいざいしておりごるふやてにすをならっていた") let deleteIndices = [1, 4, 8, 10, 15, 18, 20, 21, 23, 25, 26, 28, 29, 33, 34, 37, 39, 40, 42, 44, 45, 49, 51, 54, 58, 60, 62, 64, 67, 69, 70, 75, 80] for (i, char) in text.enumerated() { c.insertAtCursorPosition(String(char), inputStyle: .direct) - let results = converter.requestCandidates(c, options: requestOptions()) + let results = await converter.requestCandidates(c, options: requestOptions()) if deleteIndices.contains(i) { let count = i % 3 + 1 c.deleteBackwardFromCursorPosition(count: count) - _ = converter.requestCandidates(c, options: requestOptions()) + _ = await converter.requestCandidates(c, options: requestOptions()) 
c.insertAtCursorPosition(String(text[i - count + 1 ... i]), inputStyle: .direct) - _ = converter.requestCandidates(c, options: requestOptions()) + _ = await converter.requestCandidates(c, options: requestOptions()) } if c.input.count == text.count { XCTAssertEqual(results.mainResults.first?.text, "幼少期からテニス水泳野球少林寺拳法など様々なスポーツを経験しながら育ち小学校時代はロサンゼルス近郊に滞在しておりゴルフやテニスを習っていた") } } - } } // 必ず正解すべきテストケース func testMustCases() async throws { - await MainActor.run { // ダイレクト入力 do { let cases: [(input: String, expect: String)] = [ @@ -162,19 +151,19 @@ final class ConverterTests: XCTestCase { var options = requestOptions() options.requireJapanesePrediction = false for (input, expect) in cases { - let converter = KanaKanjiConverter() + let converter = await KanaKanjiConverter() var c = ComposingText() sequentialInput(&c, sequence: input, inputStyle: .direct) - let results = converter.requestCandidates(c, options: options) + let results = await converter.requestCandidates(c, options: options) XCTAssertEqual(results.mainResults.first?.text, expect) } // gradual input for (input, expect) in cases { - let converter = KanaKanjiConverter() + let converter = await KanaKanjiConverter() var c = ComposingText() for char in input { c.insertAtCursorPosition(String(char), inputStyle: .direct) - let results = converter.requestCandidates(c, options: options) + let results = await converter.requestCandidates(c, options: options) if c.input.count == input.count { XCTAssertEqual(results.mainResults.first?.text, expect) } @@ -193,33 +182,31 @@ final class ConverterTests: XCTestCase { var options = requestOptions() options.requireJapanesePrediction = false for (input, expect) in cases { - let converter = KanaKanjiConverter() + let converter = await KanaKanjiConverter() var c = ComposingText() sequentialInput(&c, sequence: input, inputStyle: .roman2kana) - let results = converter.requestCandidates(c, options: options) + let results = await converter.requestCandidates(c, options: options) 
XCTAssertEqual(results.mainResults.first?.text, expect) } // gradual input for (input, expect) in cases { - let converter = KanaKanjiConverter() + let converter = await KanaKanjiConverter() var c = ComposingText() for char in input { c.insertAtCursorPosition(String(char), inputStyle: .roman2kana) - let results = converter.requestCandidates(c, options: options) + let results = await converter.requestCandidates(c, options: options) if c.input.count == input.count { XCTAssertEqual(results.mainResults.first?.text, expect) } } } } - } } // 変換結果が比較的一意なテストケースを無数に持ち、一定の割合を正解することを要求する // 辞書を更新した結果性能が悪化したら気付ける func testAccuracy() async throws { - await MainActor.run { let cases: [(input: String, expect: [String])] = [ ("3がつ8にち", ["3月8日"]), ("いっていのわりあい", ["一定の割合"]), @@ -275,10 +262,10 @@ final class ConverterTests: XCTestCase { var score: Double = 0 for (input, expect) in cases { - let converter = KanaKanjiConverter() + let converter = await KanaKanjiConverter() var c = ComposingText() c.insertAtCursorPosition(input, inputStyle: .direct) - let results = converter.requestCandidates(c, options: requestOptions()) + let results = await converter.requestCandidates(c, options: requestOptions()) if expect.contains(results.mainResults[0].text) { score += 1 @@ -291,14 +278,12 @@ final class ConverterTests: XCTestCase { let accuracy = score / Double(cases.count) print("\(#function) Result: accuracy \(accuracy), score \(score), count \(cases.count)") XCTAssertGreaterThan(accuracy, 0.7) // 0.7 < acuracy - } } // 変換結果が比較的一意なテストケースを無数に持ち、一定の割合を正解することを要求する // 辞書を更新した結果性能が悪化したら気付ける // 口語表現を中心にテストする func testVerbalAccuracy() async throws { - await MainActor.run { let cases: [(input: String, expect: [String])] = [ ("うわああああ、まじか", ["うわああああ、マジか", "うわああああ、まじか"]), ("は?", ["は?"]), @@ -326,10 +311,10 @@ final class ConverterTests: XCTestCase { var score: Double = 0 for (input, expect) in cases { - let converter = KanaKanjiConverter() + let converter = await KanaKanjiConverter() var c = 
ComposingText() c.insertAtCursorPosition(input, inputStyle: .direct) - let results = converter.requestCandidates(c, options: requestOptions()) + let results = await converter.requestCandidates(c, options: requestOptions()) if expect.contains(results.mainResults[0].text) { score += 1 @@ -342,12 +327,10 @@ final class ConverterTests: XCTestCase { let accuracy = score / Double(cases.count) print("\(#function) Result: accuracy \(accuracy), score \(score), count \(cases.count)") XCTAssertGreaterThan(accuracy, 0.7) // 0.7 < acuracy - } } /// MIDベースの文節単位計算でどれだけ同音異義語の判断が向上しているか確認する。 func testMeaningBasedConversionAccuracy() async throws { - await MainActor.run { let cases: [(input: String, expect: String)] = [ ("しょうぼう、しょうか、ほのお", "消防、消火、炎"), ("いえき、しょうか、こうそ", "胃液、消化、酵素"), @@ -627,12 +610,12 @@ final class ConverterTests: XCTestCase { var score: Double = 0 for (input, expect) in cases { - let converter = KanaKanjiConverter() + let converter = await KanaKanjiConverter() var c = ComposingText() c.insertAtCursorPosition(input, inputStyle: .direct) var options = requestOptions() options.requireJapanesePrediction = false - let results = converter.requestCandidates(c, options: options) + let results = await converter.requestCandidates(c, options: options) if results.mainResults[0].text == expect { score += 1 @@ -645,10 +628,9 @@ final class ConverterTests: XCTestCase { let accuracy = score / Double(cases.count) print("\(#function) Result: accuracy \(accuracy), score \(score), count \(cases.count)") XCTAssertGreaterThan(accuracy, 0.7) // 0.7 < accuracy - } } - #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) || os(visionOS) +#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) || os(visionOS) func testMozcEvaluationData() async throws { // ダウンロードするURL let urlString = "https://raw.githubusercontent.com/google/mozc/master/src/data/dictionary_oss/evaluation.tsv" @@ -729,7 +711,7 @@ final class ConverterTests: XCTestCase { XCTAssertTrue(mozcScore < azooKeyScore) } } - #endif +#endif 
enum MozcCommand: Equatable { /// 変換に`arg`が現れる diff --git a/install_cli.sh b/install_cli.sh index 389309d..4dd95dd 100755 --- a/install_cli.sh +++ b/install_cli.sh @@ -1,2 +1,6 @@ -swift build -c release +swift build -c release -Xcxx -xobjective-c++ cp -f .build/release/CliTool /usr/local/bin/anco + +# FIXME: Unfortunately, in order to use zenzai in anco, you will need to build CliTool with xcodebuild +# It is highly desirable to make it work only with `swift build` +# xcodebuild -scheme CliTool -destination "platform=macOS,name=Any Mac" -configuration Release