mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-12-03 11:08:33 +00:00
Merge branch 'develop' into feat/zenz-v3
This commit is contained in:
@@ -5,7 +5,14 @@ import ArgumentParser
|
|||||||
public struct Anco: AsyncParsableCommand {
|
public struct Anco: AsyncParsableCommand {
|
||||||
public static let configuration = CommandConfiguration(
|
public static let configuration = CommandConfiguration(
|
||||||
abstract: "Anco is A(zooKey) Kana-Ka(n)ji (co)nverter",
|
abstract: "Anco is A(zooKey) Kana-Ka(n)ji (co)nverter",
|
||||||
subcommands: [Subcommands.Run.self, Subcommands.Dict.self, Subcommands.Evaluate.self, Subcommands.Session.self, Subcommands.ExperimentalPredict.self],
|
subcommands: [
|
||||||
|
Subcommands.Run.self,
|
||||||
|
Subcommands.Dict.self,
|
||||||
|
Subcommands.Evaluate.self,
|
||||||
|
Subcommands.ZenzEvaluate.self,
|
||||||
|
Subcommands.Session.self,
|
||||||
|
Subcommands.ExperimentalPredict.self
|
||||||
|
],
|
||||||
defaultSubcommand: Subcommands.Run.self
|
defaultSubcommand: Subcommands.Run.self
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ extension Subcommands.Dict {
|
|||||||
dictionaryResourceURL: URL(fileURLWithPath: self.dictionaryDirectory),
|
dictionaryResourceURL: URL(fileURLWithPath: self.dictionaryDirectory),
|
||||||
memoryDirectoryURL: URL(fileURLWithPath: self.dictionaryDirectory),
|
memoryDirectoryURL: URL(fileURLWithPath: self.dictionaryDirectory),
|
||||||
sharedContainerURL: URL(fileURLWithPath: self.dictionaryDirectory),
|
sharedContainerURL: URL(fileURLWithPath: self.dictionaryDirectory),
|
||||||
|
textReplacer: .empty,
|
||||||
metadata: .init(versionString: "anco for debugging")
|
metadata: .init(versionString: "anco for debugging")
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,22 +17,24 @@ extension Subcommands {
|
|||||||
var zenzWeightPath: String = ""
|
var zenzWeightPath: String = ""
|
||||||
@Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
|
@Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
|
||||||
var configZenzaiInferenceLimit: Int = .max
|
var configZenzaiInferenceLimit: Int = .max
|
||||||
|
@Flag(name: [.customLong("config_zenzai_ignore_left_context")], help: "ignore left_context")
|
||||||
|
var configZenzaiIgnoreLeftContext: Bool = false
|
||||||
|
|
||||||
static let configuration = CommandConfiguration(commandName: "evaluate", abstract: "Evaluate quality of Conversion for input data.")
|
static let configuration = CommandConfiguration(commandName: "evaluate", abstract: "Evaluate quality of Conversion for input data.")
|
||||||
|
|
||||||
private func parseInputFile() throws -> [InputItem] {
|
private func parseInputFile() throws -> [EvaluationInputItem] {
|
||||||
let url = URL(fileURLWithPath: self.inputFile)
|
let url = URL(fileURLWithPath: self.inputFile)
|
||||||
let data = try Data(contentsOf: url)
|
let data = try Data(contentsOf: url)
|
||||||
return try JSONDecoder().decode([InputItem].self, from: data)
|
return try JSONDecoder().decode([EvaluationInputItem].self, from: data)
|
||||||
}
|
}
|
||||||
|
|
||||||
mutating func run() async throws {
|
mutating func run() async throws {
|
||||||
let inputItems = try parseInputFile()
|
let inputItems = try parseInputFile()
|
||||||
let requestOptions = requestOptions()
|
|
||||||
let converter = await KanaKanjiConverter()
|
let converter = await KanaKanjiConverter()
|
||||||
let start = Date()
|
var executionTime: Double = 0
|
||||||
var resultItems: [EvaluateItem] = []
|
var resultItems: [EvaluateItem] = []
|
||||||
for item in inputItems {
|
for item in inputItems {
|
||||||
|
let start = Date()
|
||||||
// セットアップ
|
// セットアップ
|
||||||
await converter.sendToDicdataStore(.importDynamicUserDict(
|
await converter.sendToDicdataStore(.importDynamicUserDict(
|
||||||
(item.user_dictionary ?? []).map {
|
(item.user_dictionary ?? []).map {
|
||||||
@@ -42,7 +44,7 @@ extension Subcommands {
|
|||||||
// 変換
|
// 変換
|
||||||
var composingText = ComposingText()
|
var composingText = ComposingText()
|
||||||
composingText.insertAtCursorPosition(item.query, inputStyle: .direct)
|
composingText.insertAtCursorPosition(item.query, inputStyle: .direct)
|
||||||
|
let requestOptions = self.requestOptions(leftSideContext: item.left_context)
|
||||||
let result = await converter.requestCandidates(composingText, options: requestOptions)
|
let result = await converter.requestCandidates(composingText, options: requestOptions)
|
||||||
let mainResults = result.mainResults.filter {
|
let mainResults = result.mainResults.filter {
|
||||||
$0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == item.query.toKatakana()
|
$0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == item.query.toKatakana()
|
||||||
@@ -51,16 +53,17 @@ extension Subcommands {
|
|||||||
EvaluateItem(
|
EvaluateItem(
|
||||||
query: item.query,
|
query: item.query,
|
||||||
answers: item.answer,
|
answers: item.answer,
|
||||||
|
left_context: item.left_context,
|
||||||
outputs: mainResults.prefix(self.configNBest).map {
|
outputs: mainResults.prefix(self.configNBest).map {
|
||||||
EvaluateItemOutput(text: $0.text, score: Double($0.value))
|
EvaluateItemOutput(text: $0.text, score: Double($0.value))
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
executionTime += Date().timeIntervalSince(start)
|
||||||
// Explictly reset state
|
// Explictly reset state
|
||||||
await converter.stopComposition()
|
await converter.stopComposition()
|
||||||
}
|
}
|
||||||
let end = Date()
|
var result = EvaluateResult(n_best: self.configNBest, execution_time: executionTime, items: resultItems)
|
||||||
var result = EvaluateResult(n_best: self.configNBest, execution_time: end.timeIntervalSince(start), items: resultItems)
|
|
||||||
if stable {
|
if stable {
|
||||||
result.execution_time = 0
|
result.execution_time = 0
|
||||||
result.timestamp = 0
|
result.timestamp = 0
|
||||||
@@ -83,7 +86,7 @@ extension Subcommands {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestOptions() -> ConvertRequestOptions {
|
func requestOptions(leftSideContext: String?) -> ConvertRequestOptions {
|
||||||
var option: ConvertRequestOptions = .withDefaultDictionary(
|
var option: ConvertRequestOptions = .withDefaultDictionary(
|
||||||
N_best: self.configNBest,
|
N_best: self.configNBest,
|
||||||
requireJapanesePrediction: false,
|
requireJapanesePrediction: false,
|
||||||
@@ -99,7 +102,7 @@ extension Subcommands {
|
|||||||
shouldResetMemory: false,
|
shouldResetMemory: false,
|
||||||
memoryDirectoryURL: URL(fileURLWithPath: ""),
|
memoryDirectoryURL: URL(fileURLWithPath: ""),
|
||||||
sharedContainerURL: URL(fileURLWithPath: ""),
|
sharedContainerURL: URL(fileURLWithPath: ""),
|
||||||
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit),
|
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit, versionDependentMode: .v2(.init(leftSideContext: self.configZenzaiIgnoreLeftContext ? nil : leftSideContext))),
|
||||||
metadata: .init(versionString: "anco for debugging")
|
metadata: .init(versionString: "anco for debugging")
|
||||||
)
|
)
|
||||||
option.requestQuery = .完全一致
|
option.requestQuery = .完全一致
|
||||||
@@ -107,7 +110,7 @@ extension Subcommands {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private struct InputItem: Codable {
|
struct EvaluationInputItem: Codable {
|
||||||
/// 入力クエリ
|
/// 入力クエリ
|
||||||
var query: String
|
var query: String
|
||||||
|
|
||||||
@@ -117,17 +120,20 @@ extension Subcommands {
|
|||||||
/// タグ
|
/// タグ
|
||||||
var tag: [String] = []
|
var tag: [String] = []
|
||||||
|
|
||||||
|
/// 左文脈
|
||||||
|
var left_context: String? = nil
|
||||||
|
|
||||||
/// ユーザ辞書
|
/// ユーザ辞書
|
||||||
var user_dictionary: [InputUserDictionaryItem]? = nil
|
var user_dictionary: [InputUserDictionaryItem]? = nil
|
||||||
}
|
|
||||||
|
|
||||||
private struct InputUserDictionaryItem: Codable {
|
struct InputUserDictionaryItem: Codable {
|
||||||
/// 漢字
|
/// 漢字
|
||||||
var word: String
|
var word: String
|
||||||
/// 読み
|
/// 読み
|
||||||
var reading: String
|
var reading: String
|
||||||
/// ヒント
|
/// ヒント
|
||||||
var hint: String? = nil
|
var hint: String? = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct EvaluateResult: Codable {
|
struct EvaluateResult: Codable {
|
||||||
@@ -166,9 +172,10 @@ extension Subcommands {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct EvaluateItem: Codable {
|
struct EvaluateItem: Codable {
|
||||||
init(query: String, answers: [String], outputs: [Subcommands.EvaluateItemOutput]) {
|
init(query: String, answers: [String], left_context: String?, outputs: [Subcommands.EvaluateItemOutput]) {
|
||||||
self.query = query
|
self.query = query
|
||||||
self.answers = answers
|
self.answers = answers
|
||||||
|
self.left_context = left_context ?? ""
|
||||||
self.outputs = outputs
|
self.outputs = outputs
|
||||||
do {
|
do {
|
||||||
// entropyを示す
|
// entropyを示す
|
||||||
@@ -195,6 +202,9 @@ extension Subcommands {
|
|||||||
/// 出力
|
/// 出力
|
||||||
var outputs: [EvaluateItemOutput]
|
var outputs: [EvaluateItemOutput]
|
||||||
|
|
||||||
|
/// 文脈
|
||||||
|
var left_context: String
|
||||||
|
|
||||||
/// エントロピー
|
/// エントロピー
|
||||||
var entropy: Double
|
var entropy: Double
|
||||||
|
|
||||||
|
|||||||
89
Sources/CliTool/Subcommands/ZenzEvaluateCommand.swift
Normal file
89
Sources/CliTool/Subcommands/ZenzEvaluateCommand.swift
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
import KanaKanjiConverterModuleWithDefaultDictionary
|
||||||
|
import ArgumentParser
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
extension Subcommands {
    /// `zenz_evaluate` subcommand: runs the zenz model alone (greedy decoding, no dictionary lookup)
    /// over every item of an evaluation file and reports the outputs as JSON.
    struct ZenzEvaluate: AsyncParsableCommand {
        @Argument(help: "query, answer, tagを備えたjsonファイルへのパス")
        var inputFile: String = ""

        @Option(name: [.customLong("output")], help: "Output file path.")
        var outputFilePath: String? = nil
        @Flag(name: [.customLong("stable")], help: "Report only stable properties; timestamps and values will not be reported.")
        var stable: Bool = false
        @Option(name: [.customLong("zenz")], help: "gguf format model weight for zenz.")
        var zenzWeightPath: String = ""

        static let configuration = CommandConfiguration(commandName: "zenz_evaluate", abstract: "Evaluate quality of pure zenz's Conversion for input data.")

        /// Writes a human-readable diagnostic line to standard error.
        ///
        /// Diagnostics are kept off standard output because the JSON report is printed there
        /// when `--output` is not given; mixing the two would make the report unparseable.
        private func reportDiagnostic(_ message: String) {
            FileHandle.standardError.write(Data((message + "\n").utf8))
        }

        /// Reads `inputFile` and decodes it as a JSON array of `EvaluationInputItem`.
        /// - Throws: Any error raised while reading the file or decoding its contents.
        private func parseInputFile() throws -> [EvaluationInputItem] {
            let data = try Data(contentsOf: URL(fileURLWithPath: self.inputFile))
            return try JSONDecoder().decode([EvaluationInputItem].self, from: data)
        }

        /// Builds the raw zenz prompt for `query` and greedily decodes an answer from it.
        ///
        /// The prompt is `U+EE00` + query + (optional `U+EE02` + last 40 characters of the left context) + `U+EE01`.
        /// NOTE(review): `U+EE02` presumably marks the left-context section of the prompt — confirm against the model's prompt format.
        /// - Parameters:
        ///   - query: The input to convert.
        ///   - leftContext: Text preceding the query, or `nil` when there is none.
        ///   - zenz: The loaded model.
        ///   - maxCount: Upper bound on the number of generated tokens.
        /// - Returns: The text generated by the model.
        private func greedyDecoding(query: String, leftContext: String?, zenz: Zenz, maxCount: Int) async -> String {
            let contextSection = leftContext.map { "\u{EE02}" + String($0.suffix(40)) } ?? ""
            let prompt = "\u{EE00}\(query)\(contextSection)\u{EE01}"
            return await zenz.pureGreedyDecoding(pureInput: prompt, maxCount: maxCount)
        }

        /// Runs the evaluation and emits the JSON report to `outputFilePath`, or to standard output when it is `nil`.
        /// - Throws: `ValidationError` when `--zenz` is missing, or any error raised while reading the input,
        ///   encoding the report or writing the output file.
        mutating func run() async throws {
            // An empty path cannot name a model; fail with a usage error instead of crashing later.
            guard !self.zenzWeightPath.isEmpty else {
                throw ValidationError("Please specify the zenz model weight with --zenz.")
            }
            let inputItems = try parseInputFile()
            let converter = await KanaKanjiConverter()
            var executionTime: Double = 0
            var resultItems: [EvaluateItem] = []

            // The option holds a file system path, so build a file URL; `URL(string:)` returns nil for
            // paths containing spaces or non-ASCII characters.
            let modelURL = URL(fileURLWithPath: self.zenzWeightPath)
            guard let zenz = await converter.getModel(modelURL: modelURL) else {
                reportDiagnostic("Failed to initialize zenz model")
                return
            }

            for item in inputItems {
                let start = Date()
                if item.user_dictionary != nil {
                    reportDiagnostic("Warning: zenz_evaluate command does not suppport user dictionary. User Dictionary Contents are just ignored.")
                }
                // Conversion. The generation budget is the UTF-8 length of the longest reference answer;
                // an item without any answer gets a budget of zero instead of trapping.
                let maxCount = item.answer.map(\.utf8.count).max() ?? 0
                let result = await self.greedyDecoding(query: item.query, leftContext: item.left_context, zenz: zenz, maxCount: maxCount)
                reportDiagnostic("Results: \(result)")
                resultItems.append(
                    EvaluateItem(
                        query: item.query,
                        answers: item.answer,
                        left_context: item.left_context,
                        outputs: [
                            EvaluateItemOutput(text: result, score: 0.0)
                        ]
                    )
                )
                executionTime += Date().timeIntervalSince(start)
                // Reset the model state so that items do not influence each other.
                await zenz.endSession()
            }

            var result = EvaluateResult(n_best: 1, execution_time: executionTime, items: resultItems)
            if stable {
                // Drop everything that varies between runs so that reports can be diffed.
                result.execution_time = 0
                result.timestamp = 0
                result.items.mutatingForeach {
                    $0.outputs.mutatingForeach {
                        $0.score = Double(Int($0.score))
                    }
                }
            }
            let encoder = JSONEncoder()
            encoder.outputFormatting = [.prettyPrinted, .sortedKeys]
            let json = try encoder.encode(result)

            if let outputFilePath {
                try json.write(to: URL(fileURLWithPath: outputFilePath))
            } else {
                print(String(decoding: json, as: UTF8.self))
            }
        }
    }
}
|
||||||
@@ -29,7 +29,7 @@ public struct ConvertRequestOptions: Sendable {
|
|||||||
/// - sharedContainerURL: ユーザ辞書など、キーボード外で書き込んだ設定データの保存されているディレクトリを指定します。
|
/// - sharedContainerURL: ユーザ辞書など、キーボード外で書き込んだ設定データの保存されているディレクトリを指定します。
|
||||||
/// - textReplacer: 予測変換のための置換機を指定します。
|
/// - textReplacer: 予測変換のための置換機を指定します。
|
||||||
/// - metadata: メタデータを指定します。詳しくは`ConvertRequestOptions.Metadata`を参照してください。
|
/// - metadata: メタデータを指定します。詳しくは`ConvertRequestOptions.Metadata`を参照してください。
|
||||||
public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer = TextReplacer(), zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?) {
|
public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer, zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?) {
|
||||||
self.N_best = N_best
|
self.N_best = N_best
|
||||||
self.requireJapanesePrediction = requireJapanesePrediction
|
self.requireJapanesePrediction = requireJapanesePrediction
|
||||||
self.requireEnglishPrediction = requireEnglishPrediction
|
self.requireEnglishPrediction = requireEnglishPrediction
|
||||||
@@ -50,25 +50,49 @@ public struct ConvertRequestOptions: Sendable {
|
|||||||
self.dictionaryResourceURL = dictionaryResourceURL
|
self.dictionaryResourceURL = dictionaryResourceURL
|
||||||
}
|
}
|
||||||
|
|
||||||
package init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer = TextReplacer(), zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?, requestQuery: RequestQuery) {
|
/// 変換リクエストに必要な設定データ
|
||||||
self.N_best = N_best
|
///
|
||||||
self.requireJapanesePrediction = requireJapanesePrediction
|
/// - parameters:
|
||||||
self.requireEnglishPrediction = requireEnglishPrediction
|
/// - N_best: 変換候補の数。上位`N`件までの言語モデル上の妥当性を保証します。大きくすると計算量が増加します。
|
||||||
self.keyboardLanguage = keyboardLanguage
|
/// - requireJapanesePrediction: 日本語の予測変換候補の必要性。`false`にすると、日本語の予測変換候補を出力しなくなります。
|
||||||
self.typographyLetterCandidate = typographyLetterCandidate
|
/// - requireEnglishPrediction: 英語の予測変換候補の必要性。`false`にすると、英語の予測変換候補を出力しなくなります。ローマ字入力を用いた日本語入力では`false`にした方が良いでしょう。
|
||||||
self.unicodeCandidate = unicodeCandidate
|
/// - keyboardLanguage: キーボードの言語を指定します。
|
||||||
self.englishCandidateInRoman2KanaInput = englishCandidateInRoman2KanaInput
|
/// - typographyLetterCandidate: `true`の場合、「おしゃれなフォント」での英数字変換候補が出力に含まれるようになります。詳しくは`KanaKanjiConverter.typographicalCandidates(_:)`を参照してください。
|
||||||
self.fullWidthRomanCandidate = fullWidthRomanCandidate
|
/// - unicodeCandidate: `true`の場合、`U+xxxx`のような入力に対してUnicodeの変換候補が出力に含まれるようになります。詳しくは`KanaKanjiConverter.unicodeCandidates(_:)`を参照してください。`
|
||||||
self.halfWidthKanaCandidate = halfWidthKanaCandidate
|
/// - englishCandidateInRoman2KanaInput: `true`の場合、日本語ローマ字入力時に英語変換候補を出力します。`false`の場合、ローマ字入力時に英語変換候補を出力しません。
|
||||||
self.learningType = learningType
|
/// - fullWidthRomanCandidate: `true`の場合、全角英数字の変換候補が出力に含まれるようになります。
|
||||||
self.maxMemoryCount = maxMemoryCount
|
/// - halfWidthKanaCandidate: `true`の場合、半角カナの変換候補が出力に含まれるようになります。
|
||||||
self.shouldResetMemory = shouldResetMemory
|
/// - learningType: 学習モードを指定します。詳しくは`LearningType`を参照してください。
|
||||||
self.memoryDirectoryURL = memoryDirectoryURL
|
/// - maxMemoryCount: 学習が有効な場合に保持するデータの最大数を指定します。`0`の場合`learningType`を`nothing`に指定する方が適切です。
|
||||||
self.sharedContainerURL = sharedContainerURL
|
/// - shouldResetMemory: `true`の場合、変換を開始する前に学習データをリセットします。
|
||||||
self.metadata = metadata
|
/// - dictionaryResourceURL: 内蔵辞書データの読み出し先を指定します。
|
||||||
self.textReplacer = textReplacer
|
/// - memoryDirectoryURL: 学習データの保存先を指定します。書き込み可能なディレクトリを指定してください。
|
||||||
self.zenzaiMode = zenzaiMode
|
/// - sharedContainerURL: ユーザ辞書など、キーボード外で書き込んだ設定データの保存されているディレクトリを指定します。
|
||||||
self.dictionaryResourceURL = dictionaryResourceURL
|
/// - textReplacer: 予測変換のための置換機を指定します。
|
||||||
|
/// - metadata: メタデータを指定します。詳しくは`ConvertRequestOptions.Metadata`を参照してください。
|
||||||
|
@available(*, deprecated, message: "it be removed in AzooKeyKanaKanjiConverter v1.0")
|
||||||
|
public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?) {
|
||||||
|
self.init(
|
||||||
|
N_best: N_best,
|
||||||
|
requireJapanesePrediction: requireJapanesePrediction,
|
||||||
|
requireEnglishPrediction: requireEnglishPrediction,
|
||||||
|
keyboardLanguage: keyboardLanguage,
|
||||||
|
typographyLetterCandidate: typographyLetterCandidate,
|
||||||
|
unicodeCandidate: unicodeCandidate,
|
||||||
|
englishCandidateInRoman2KanaInput: englishCandidateInRoman2KanaInput,
|
||||||
|
fullWidthRomanCandidate: fullWidthRomanCandidate,
|
||||||
|
halfWidthKanaCandidate: halfWidthKanaCandidate,
|
||||||
|
learningType: learningType,
|
||||||
|
maxMemoryCount: maxMemoryCount,
|
||||||
|
shouldResetMemory: shouldResetMemory,
|
||||||
|
dictionaryResourceURL: dictionaryResourceURL,
|
||||||
|
memoryDirectoryURL: memoryDirectoryURL,
|
||||||
|
sharedContainerURL: sharedContainerURL,
|
||||||
|
// MARK: using deprecated initializer here
|
||||||
|
textReplacer: TextReplacer(),
|
||||||
|
zenzaiMode: zenzaiMode,
|
||||||
|
metadata: metadata
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
public var N_best: Int
|
public var N_best: Int
|
||||||
@@ -117,6 +141,7 @@ public struct ConvertRequestOptions: Sendable {
|
|||||||
memoryDirectoryURL: (try? FileManager.default.url(for: .libraryDirectory, in: .userDomainMask, appropriateFor: nil, create: false)) ?? Bundle.main.bundleURL,
|
memoryDirectoryURL: (try? FileManager.default.url(for: .libraryDirectory, in: .userDomainMask, appropriateFor: nil, create: false)) ?? Bundle.main.bundleURL,
|
||||||
// dummy data, won't work
|
// dummy data, won't work
|
||||||
sharedContainerURL: Bundle.main.bundleURL,
|
sharedContainerURL: Bundle.main.bundleURL,
|
||||||
|
textReplacer: .empty,
|
||||||
metadata: nil
|
metadata: nil
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ import SwiftUtils
|
|||||||
self.lastData = nil
|
self.lastData = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
private func getModel(modelURL: URL) -> Zenz? {
|
package func getModel(modelURL: URL) -> Zenz? {
|
||||||
if let model = self.zenz, model.resourceURL == modelURL {
|
if let model = self.zenz, model.resourceURL == modelURL {
|
||||||
self.zenzStatus = "load \(modelURL.absoluteString)"
|
self.zenzStatus = "load \(modelURL.absoluteString)"
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ extension ComposingText {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
while var (convertTargetElements, lastElement, count) = stack.popLast() {
|
while case .some((var convertTargetElements, let lastElement, let count)) = stack.popLast() {
|
||||||
if rightIndexRange.contains(count + left - 1) {
|
if rightIndexRange.contains(count + left - 1) {
|
||||||
if let convertTarget = ComposingText.getConvertTargetIfRightSideIsValid(lastElement: lastElement, of: self.input, to: count + left, convertTargetElements: convertTargetElements)?.map({$0.toKatakana()}) {
|
if let convertTarget = ComposingText.getConvertTargetIfRightSideIsValid(lastElement: lastElement, of: self.input, to: count + left, convertTargetElements: convertTargetElements)?.map({$0.toKatakana()}) {
|
||||||
stringToInfo.append((convertTarget, (count + left - 1)))
|
stringToInfo.append((convertTarget, (count + left - 1)))
|
||||||
|
|||||||
@@ -61,6 +61,17 @@ public struct TextReplacer: Sendable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializer for building a `TextReplacer` that does not perform any replacement.
/// - parameters:
///   - isEmpty: Input flag; it is expected to be `true`, which is checked by a debug-only `assert`.
private init(isEmpty: Bool) {
    assert(isEmpty)
}
|
||||||
|
|
||||||
|
/// A `TextReplacer` created through the private `init(isEmpty:)` initializer.
public static var empty: Self {
    return .init(isEmpty: true)
}
|
||||||
|
|
||||||
@available(*, deprecated, renamed: "init(emojiDataProvider:)", message: "it be removed in AzooKeyKanaKanjiConverter v1.0")
|
@available(*, deprecated, renamed: "init(emojiDataProvider:)", message: "it be removed in AzooKeyKanaKanjiConverter v1.0")
|
||||||
public init() {
|
public init() {
|
||||||
self.init {
|
self.init {
|
||||||
|
|||||||
@@ -1,35 +1,33 @@
|
|||||||
import Foundation
|
import Foundation
|
||||||
import SwiftUtils
|
import SwiftUtils
|
||||||
|
|
||||||
@MainActor final class Zenz {
|
@MainActor package final class Zenz {
|
||||||
package var resourceURL: URL
|
package var resourceURL: URL
|
||||||
private var zenzContext: ZenzContext?
|
private var zenzContext: ZenzContext?
|
||||||
init(resourceURL: URL) throws {
|
init(resourceURL: URL) throws {
|
||||||
self.resourceURL = resourceURL
|
self.resourceURL = resourceURL
|
||||||
do {
|
do {
|
||||||
#if canImport(Darwin)
|
#if canImport(Darwin)
|
||||||
if #available(iOS 16, macOS 13, *) {
|
if #available(iOS 16, macOS 13, *) {
|
||||||
self.zenzContext = try ZenzContext.createContext(path: resourceURL.path(percentEncoded: false))
|
self.zenzContext = try ZenzContext.createContext(path: resourceURL.path(percentEncoded: false))
|
||||||
} else {
|
} else {
|
||||||
// this is not percent-encoded
|
// this is not percent-encoded
|
||||||
self.zenzContext = try ZenzContext.createContext(path: resourceURL.path)
|
self.zenzContext = try ZenzContext.createContext(path: resourceURL.path)
|
||||||
}
|
}
|
||||||
#elseif canImport(WinSDK)
|
#elseif canImport(WinSDK)
|
||||||
// remove first "/" from path (for windows)
|
// remove first "/" from path (for windows)
|
||||||
self.zenzContext = try ZenzContext.createContext(path: String(resourceURL.path.dropFirst()))
|
self.zenzContext = try ZenzContext.createContext(path: String(resourceURL.path.dropFirst()))
|
||||||
#else
|
#else
|
||||||
// this is not percent-encoded
|
// this is not percent-encoded
|
||||||
self.zenzContext = try ZenzContext.createContext(path: resourceURL.path)
|
self.zenzContext = try ZenzContext.createContext(path: resourceURL.path)
|
||||||
#endif
|
#endif
|
||||||
debug("Loaded model \(resourceURL.lastPathComponent)")
|
debug("Loaded model \(resourceURL.lastPathComponent)")
|
||||||
} catch {
|
} catch {
|
||||||
throw error
|
throw error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func startSession() {}
|
package func endSession() {
|
||||||
|
|
||||||
func endSession() {
|
|
||||||
try? self.zenzContext?.reset_context()
|
try? self.zenzContext?.reset_context()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,4 +49,8 @@ import SwiftUtils
|
|||||||
let result = zenzContext.predict_next_character(leftSideContext: leftSideContext, count: count)
|
let result = zenzContext.predict_next_character(leftSideContext: leftSideContext, count: count)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Greedily decodes a continuation of `pureInput` with the underlying zenz context.
/// - Parameters:
///   - pureInput: The prompt, passed to the context unchanged.
///   - maxCount: Upper bound on the number of generated tokens.
/// - Returns: The generated text, or an empty string when no context is available.
package func pureGreedyDecoding(pureInput: String, maxCount: Int = .max) -> String {
    guard let context = self.zenzContext else {
        return ""
    }
    return context.pure_greedy_decoding(leftSideContext: pureInput, maxCount: maxCount)
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ enum ZenzError: LocalizedError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class ZenzContext {
|
final class ZenzContext {
|
||||||
private var model: OpaquePointer
|
private var model: OpaquePointer
|
||||||
private var context: OpaquePointer
|
private var context: OpaquePointer
|
||||||
private var prevInput: [llama_token] = []
|
private var prevInput: [llama_token] = []
|
||||||
@@ -197,6 +197,42 @@ class ZenzContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates a continuation of `leftSideContext` with plain greedy decoding and returns it.
///
/// At every step the token with the highest logit is appended; generation stops at the
/// end-of-sequence token or once `maxCount` tokens have been generated.
/// - Parameters:
///   - leftSideContext: The prompt; it is tokenized without a BOS token.
///   - maxCount: Upper bound on the number of generated tokens.
/// - Returns: The text of the generated tokens only (the prompt is not included).
///   An empty string is returned when the prompt yields no tokens or when logits cannot be obtained.
func pure_greedy_decoding(leftSideContext: String, maxCount: Int = .max) -> String {
    var tokens = self.tokenize(text: leftSideContext, add_bos: false)
    // Without any prompt token there is no position to read logits from
    // (the start index below would become -1), so there is nothing to generate.
    guard !tokens.isEmpty else {
        return ""
    }
    let promptLength = tokens.count
    let eos_token = llama_token_eos(model)
    // The vocabulary size does not change during generation; query it once.
    let vocabSize = Int(llama_n_vocab(model))
    while tokens.count - promptLength < maxCount {
        // Logits are requested starting at the last token, so the first `vocabSize`
        // entries of the buffer are the scores for the next token.
        guard let logits = self.get_logits(tokens: tokens, logits_start_index: tokens.count - 1) else {
            print("logits unavailable")
            return ""
        }
        // Arg-max over the vocabulary; the strict comparison keeps the first maximum on ties.
        var bestToken: llama_token = -1
        var bestLogit = -Float.infinity
        for tokenIndex in 0..<vocabSize where bestLogit < logits[tokenIndex] {
            bestToken = llama_token(tokenIndex)
            bestLogit = logits[tokenIndex]
        }
        // `bestToken` stays negative when no logit compared greater than -infinity (e.g. all NaN);
        // such a value must never be fed back into the model as a token.
        if bestToken < 0 || bestToken == eos_token {
            break
        }
        tokens.append(bestToken)
    }

    // Convert only the generated tokens back to text, as a NUL-terminated C string.
    let cchars: [CChar] = tokens.dropFirst(promptLength).flatMap(self.token_to_piece) + [0]
    return String(cString: cchars)
}
|
||||||
|
|
||||||
func predict_next_character(leftSideContext: String, count: Int) -> [(character: Character, value: Float)] {
|
func predict_next_character(leftSideContext: String, count: Int) -> [(character: Character, value: Float)] {
|
||||||
struct NextCharacterCandidate: Comparable {
|
struct NextCharacterCandidate: Comparable {
|
||||||
static func < (lhs: NextCharacterCandidate, rhs: NextCharacterCandidate) -> Bool {
|
static func < (lhs: NextCharacterCandidate, rhs: NextCharacterCandidate) -> Bool {
|
||||||
@@ -250,7 +286,6 @@ class ZenzContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func evaluate_candidate(input: String, candidate: Candidate, requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
|
func evaluate_candidate(input: String, candidate: Candidate, requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
|
||||||
print("Evaluate", candidate)
|
|
||||||
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
|
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
|
||||||
// We assume \u{EE01}\(candidate) is always splitted into \u{EE01}_\(candidate) by zenz-v1 tokenizer
|
// We assume \u{EE01}\(candidate) is always splitted into \u{EE01}_\(candidate) by zenz-v1 tokenizer
|
||||||
var userDictionaryPrompt: String = ""
|
var userDictionaryPrompt: String = ""
|
||||||
@@ -289,7 +324,6 @@ class ZenzContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 左文脈を取得
|
// 左文脈を取得
|
||||||
// プロフィールがある場合はこれを条件に追加
|
|
||||||
let leftSideContext: String = switch versionDependentConfig {
|
let leftSideContext: String = switch versionDependentConfig {
|
||||||
case .v1: ""
|
case .v1: ""
|
||||||
case .v2(let mode):
|
case .v2(let mode):
|
||||||
|
|||||||
Reference in New Issue
Block a user