Merge branch 'develop' into feat/zenz-v3

This commit is contained in:
Miwa
2025-02-06 23:03:07 +09:00
committed by GitHub
10 changed files with 232 additions and 53 deletions

View File

@@ -5,7 +5,14 @@ import ArgumentParser
public struct Anco: AsyncParsableCommand {
public static let configuration = CommandConfiguration(
abstract: "Anco is A(zooKey) Kana-Ka(n)ji (co)nverter",
subcommands: [Subcommands.Run.self, Subcommands.Dict.self, Subcommands.Evaluate.self, Subcommands.Session.self, Subcommands.ExperimentalPredict.self],
subcommands: [
Subcommands.Run.self,
Subcommands.Dict.self,
Subcommands.Evaluate.self,
Subcommands.ZenzEvaluate.self,
Subcommands.Session.self,
Subcommands.ExperimentalPredict.self
],
defaultSubcommand: Subcommands.Run.self
)

View File

@@ -116,6 +116,7 @@ extension Subcommands.Dict {
dictionaryResourceURL: URL(fileURLWithPath: self.dictionaryDirectory),
memoryDirectoryURL: URL(fileURLWithPath: self.dictionaryDirectory),
sharedContainerURL: URL(fileURLWithPath: self.dictionaryDirectory),
textReplacer: .empty,
metadata: .init(versionString: "anco for debugging")
)
}

View File

@@ -17,22 +17,24 @@ extension Subcommands {
var zenzWeightPath: String = ""
@Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
var configZenzaiInferenceLimit: Int = .max
@Flag(name: [.customLong("config_zenzai_ignore_left_context")], help: "ignore left_context")
var configZenzaiIgnoreLeftContext: Bool = false
static let configuration = CommandConfiguration(commandName: "evaluate", abstract: "Evaluate quality of Conversion for input data.")
private func parseInputFile() throws -> [InputItem] {
private func parseInputFile() throws -> [EvaluationInputItem] {
let url = URL(fileURLWithPath: self.inputFile)
let data = try Data(contentsOf: url)
return try JSONDecoder().decode([InputItem].self, from: data)
return try JSONDecoder().decode([EvaluationInputItem].self, from: data)
}
mutating func run() async throws {
let inputItems = try parseInputFile()
let requestOptions = requestOptions()
let converter = await KanaKanjiConverter()
let start = Date()
var executionTime: Double = 0
var resultItems: [EvaluateItem] = []
for item in inputItems {
let start = Date()
//
await converter.sendToDicdataStore(.importDynamicUserDict(
(item.user_dictionary ?? []).map {
@@ -42,7 +44,7 @@ extension Subcommands {
//
var composingText = ComposingText()
composingText.insertAtCursorPosition(item.query, inputStyle: .direct)
let requestOptions = self.requestOptions(leftSideContext: item.left_context)
let result = await converter.requestCandidates(composingText, options: requestOptions)
let mainResults = result.mainResults.filter {
$0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == item.query.toKatakana()
@@ -51,16 +53,17 @@ extension Subcommands {
EvaluateItem(
query: item.query,
answers: item.answer,
left_context: item.left_context,
outputs: mainResults.prefix(self.configNBest).map {
EvaluateItemOutput(text: $0.text, score: Double($0.value))
}
)
)
executionTime += Date().timeIntervalSince(start)
// Explicitly reset state
await converter.stopComposition()
}
let end = Date()
var result = EvaluateResult(n_best: self.configNBest, execution_time: end.timeIntervalSince(start), items: resultItems)
var result = EvaluateResult(n_best: self.configNBest, execution_time: executionTime, items: resultItems)
if stable {
result.execution_time = 0
result.timestamp = 0
@@ -83,7 +86,7 @@ extension Subcommands {
}
}
func requestOptions() -> ConvertRequestOptions {
func requestOptions(leftSideContext: String?) -> ConvertRequestOptions {
var option: ConvertRequestOptions = .withDefaultDictionary(
N_best: self.configNBest,
requireJapanesePrediction: false,
@@ -99,7 +102,7 @@ extension Subcommands {
shouldResetMemory: false,
memoryDirectoryURL: URL(fileURLWithPath: ""),
sharedContainerURL: URL(fileURLWithPath: ""),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit, versionDependentMode: .v2(.init(leftSideContext: self.configZenzaiIgnoreLeftContext ? nil : leftSideContext))),
metadata: .init(versionString: "anco for debugging")
)
option.requestQuery = .
@@ -107,7 +110,7 @@ extension Subcommands {
}
}
private struct InputItem: Codable {
struct EvaluationInputItem: Codable {
///
var query: String
@@ -117,11 +120,13 @@ extension Subcommands {
///
var tag: [String] = []
///
var left_context: String? = nil
///
var user_dictionary: [InputUserDictionaryItem]? = nil
}
private struct InputUserDictionaryItem: Codable {
struct InputUserDictionaryItem: Codable {
///
var word: String
///
@@ -129,6 +134,7 @@ extension Subcommands {
///
var hint: String? = nil
}
}
struct EvaluateResult: Codable {
internal init(n_best: Int, timestamp: TimeInterval = Date().timeIntervalSince1970, execution_time: TimeInterval, items: [Subcommands.EvaluateItem]) {
@@ -166,9 +172,10 @@ extension Subcommands {
}
struct EvaluateItem: Codable {
init(query: String, answers: [String], outputs: [Subcommands.EvaluateItemOutput]) {
init(query: String, answers: [String], left_context: String?, outputs: [Subcommands.EvaluateItemOutput]) {
self.query = query
self.answers = answers
self.left_context = left_context ?? ""
self.outputs = outputs
do {
// entropy
@@ -195,6 +202,9 @@ extension Subcommands {
///
var outputs: [EvaluateItemOutput]
///
var left_context: String
///
var entropy: Double

View File

@@ -0,0 +1,89 @@
import KanaKanjiConverterModuleWithDefaultDictionary
import ArgumentParser
import Foundation
extension Subcommands {
    /// Evaluates conversion quality of the raw zenz model via pure greedy decoding,
    /// without the kana-kanji converter's search running on top of it.
    struct ZenzEvaluate: AsyncParsableCommand {
        @Argument(help: "query, answer, tagを備えたjsonファイルへのパス")
        var inputFile: String = ""
        @Option(name: [.customLong("output")], help: "Output file path.")
        var outputFilePath: String? = nil
        @Flag(name: [.customLong("stable")], help: "Report only stable properties; timestamps and values will not be reported.")
        var stable: Bool = false
        @Option(name: [.customLong("zenz")], help: "gguf format model weight for zenz.")
        var zenzWeightPath: String = ""
        static let configuration = CommandConfiguration(commandName: "zenz_evaluate", abstract: "Evaluate quality of pure zenz's Conversion for input data.")

        /// Loads and decodes the evaluation items from `inputFile`.
        /// - Throws: file-read or JSON-decoding errors.
        private func parseInputFile() throws -> [EvaluationInputItem] {
            let url = URL(fileURLWithPath: self.inputFile)
            let data = try Data(contentsOf: url)
            return try JSONDecoder().decode([EvaluationInputItem].self, from: data)
        }

        /// Builds the zenz prompt (`\u{EE00}<query>[\u{EE02}<left context>]\u{EE01}`) and
        /// decodes greedily until EOS or `maxCount` is reached.
        /// - Note: the left context is truncated to its last 40 characters — presumably to
        ///   bound prompt length; TODO confirm against the model's training format.
        private func greedyDecoding(query: String, leftContext: String?, zenz: Zenz, maxCount: Int) async -> String {
            // Keep the context part in its own variable instead of shadowing/reusing
            // one variable for both the context and the final prompt.
            let contextPart: String = if let leftContext {
                "\u{EE02}" + String(leftContext.suffix(40))
            } else {
                ""
            }
            let prompt = "\u{EE00}\(query)\(contextPart)\u{EE01}"
            return await zenz.pureGreedyDecoding(pureInput: prompt, maxCount: maxCount)
        }

        mutating func run() async throws {
            let inputItems = try parseInputFile()
            let converter = await KanaKanjiConverter()
            var executionTime: Double = 0
            var resultItems: [EvaluateItem] = []
            // Avoid force-unwrapping URL(string:): an empty or malformed --zenz path is a
            // user error and should produce the failure message, not a crash.
            guard let modelURL = URL(string: self.zenzWeightPath),
                  let zenz = await converter.getModel(modelURL: modelURL) else {
                print("Failed to initialize zenz model")
                return
            }
            for item in inputItems {
                let start = Date()
                if item.user_dictionary != nil {
                    print("Warning: zenz_evaluate command does not support user dictionary. User Dictionary Contents are just ignored.")
                }
                // Cap decoding at the longest answer's UTF-8 length; an item with no
                // answers decodes nothing instead of crashing on `max()!`.
                let maxCount = item.answer.map(\.utf8.count).max() ?? 0
                let result = await self.greedyDecoding(query: item.query, leftContext: item.left_context, zenz: zenz, maxCount: maxCount)
                print("Results:", result)
                resultItems.append(
                    EvaluateItem(
                        query: item.query,
                        answers: item.answer,
                        left_context: item.left_context,
                        outputs: [
                            EvaluateItemOutput(text: result, score: 0.0)
                        ]
                    )
                )
                // Only decoding time is accumulated; file I/O below is excluded.
                executionTime += Date().timeIntervalSince(start)
                // Reset model state so items do not leak context into each other.
                await zenz.endSession()
            }
            var result = EvaluateResult(n_best: 1, execution_time: executionTime, items: resultItems)
            if stable {
                // Zero out non-deterministic fields so repeated runs are byte-comparable.
                result.execution_time = 0
                result.timestamp = 0
                result.items.mutatingForeach {
                    $0.outputs.mutatingForeach {
                        $0.score = Double(Int($0.score))
                    }
                }
            }
            let encoder = JSONEncoder()
            encoder.outputFormatting = [.prettyPrinted, .sortedKeys]
            let json = try encoder.encode(result)
            if let outputFilePath {
                try json.write(to: URL(fileURLWithPath: outputFilePath))
            } else {
                // JSONEncoder output is valid UTF-8, so this unwrap cannot fail.
                let string = String(data: json, encoding: .utf8)!
                print(string)
            }
        }
    }
}

View File

@@ -29,7 +29,7 @@ public struct ConvertRequestOptions: Sendable {
/// - sharedContainerURL:
/// - textReplacer:
/// - metadata: `ConvertRequestOptions.Metadata`
public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer = TextReplacer(), zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?) {
public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer, zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?) {
self.N_best = N_best
self.requireJapanesePrediction = requireJapanesePrediction
self.requireEnglishPrediction = requireEnglishPrediction
@@ -50,25 +50,49 @@ public struct ConvertRequestOptions: Sendable {
self.dictionaryResourceURL = dictionaryResourceURL
}
package init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer = TextReplacer(), zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?, requestQuery: RequestQuery) {
self.N_best = N_best
self.requireJapanesePrediction = requireJapanesePrediction
self.requireEnglishPrediction = requireEnglishPrediction
self.keyboardLanguage = keyboardLanguage
self.typographyLetterCandidate = typographyLetterCandidate
self.unicodeCandidate = unicodeCandidate
self.englishCandidateInRoman2KanaInput = englishCandidateInRoman2KanaInput
self.fullWidthRomanCandidate = fullWidthRomanCandidate
self.halfWidthKanaCandidate = halfWidthKanaCandidate
self.learningType = learningType
self.maxMemoryCount = maxMemoryCount
self.shouldResetMemory = shouldResetMemory
self.memoryDirectoryURL = memoryDirectoryURL
self.sharedContainerURL = sharedContainerURL
self.metadata = metadata
self.textReplacer = textReplacer
self.zenzaiMode = zenzaiMode
self.dictionaryResourceURL = dictionaryResourceURL
/// Deprecated initializer kept for source compatibility with callers that predate the
/// required `textReplacer:` parameter; it forwards to the designated initializer using
/// a default `TextReplacer()`.
///
/// - parameters:
///   - N_best: number of conversion candidates to request (default: 10)
///   - requireJapanesePrediction: whether Japanese prediction candidates are required
///   - requireEnglishPrediction: whether English prediction candidates are required
///   - keyboardLanguage: the active keyboard language
///   - typographyLetterCandidate: enable typographical letter candidates (default: false)
///   - unicodeCandidate: enable `U+xxxx`-style Unicode candidates (default: true)
///   - englishCandidateInRoman2KanaInput: enable English candidates during roman-to-kana input (default: false)
///   - fullWidthRomanCandidate: enable full-width roman candidates (default: false)
///   - halfWidthKanaCandidate: enable half-width kana candidates (default: false)
///   - learningType: the `LearningType` for the learning/memory feature
///   - maxMemoryCount: upper bound on memory entries (default: 65536); ignored when `learningType` is `nothing` — TODO confirm
///   - shouldResetMemory: if `true`, resets the stored memory (default: false)
///   - dictionaryResourceURL: location of the dictionary resources
///   - memoryDirectoryURL: directory used to persist learning memory
///   - sharedContainerURL: shared container directory
///   - zenzaiMode: zenzai configuration (default: `.off`)
///   - metadata: optional `ConvertRequestOptions.Metadata`
@available(*, deprecated, message: "it be removed in AzooKeyKanaKanjiConverter v1.0")
public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?) {
self.init(
N_best: N_best,
requireJapanesePrediction: requireJapanesePrediction,
requireEnglishPrediction: requireEnglishPrediction,
keyboardLanguage: keyboardLanguage,
typographyLetterCandidate: typographyLetterCandidate,
unicodeCandidate: unicodeCandidate,
englishCandidateInRoman2KanaInput: englishCandidateInRoman2KanaInput,
fullWidthRomanCandidate: fullWidthRomanCandidate,
halfWidthKanaCandidate: halfWidthKanaCandidate,
learningType: learningType,
maxMemoryCount: maxMemoryCount,
shouldResetMemory: shouldResetMemory,
dictionaryResourceURL: dictionaryResourceURL,
memoryDirectoryURL: memoryDirectoryURL,
sharedContainerURL: sharedContainerURL,
// MARK: using deprecated initializer here
textReplacer: TextReplacer(),
zenzaiMode: zenzaiMode,
metadata: metadata
)
}
public var N_best: Int
@@ -117,6 +141,7 @@ public struct ConvertRequestOptions: Sendable {
memoryDirectoryURL: (try? FileManager.default.url(for: .libraryDirectory, in: .userDomainMask, appropriateFor: nil, create: false)) ?? Bundle.main.bundleURL,
// dummy data, won't work
sharedContainerURL: Bundle.main.bundleURL,
textReplacer: .empty,
metadata: nil
)
}

View File

@@ -40,7 +40,7 @@ import SwiftUtils
self.lastData = nil
}
private func getModel(modelURL: URL) -> Zenz? {
package func getModel(modelURL: URL) -> Zenz? {
if let model = self.zenz, model.resourceURL == modelURL {
self.zenzStatus = "load \(modelURL.absoluteString)"
return model

View File

@@ -130,7 +130,7 @@ extension ComposingText {
}
return nil
}
while var (convertTargetElements, lastElement, count) = stack.popLast() {
while case .some((var convertTargetElements, let lastElement, let count)) = stack.popLast() {
if rightIndexRange.contains(count + left - 1) {
if let convertTarget = ComposingText.getConvertTargetIfRightSideIsValid(lastElement: lastElement, of: self.input, to: count + left, convertTargetElements: convertTargetElements)?.map({$0.toKatakana()}) {
stringToInfo.append((convertTarget, (count + left - 1)))

View File

@@ -61,6 +61,17 @@ public struct TextReplacer: Sendable {
}
}
/// `TextReplacer`
/// - parameters:
/// - isEmpty:
private init(isEmpty: Bool) {
assert(isEmpty)
}
public static var empty: Self {
Self(isEmpty: true)
}
@available(*, deprecated, renamed: "init(emojiDataProvider:)", message: "it be removed in AzooKeyKanaKanjiConverter v1.0")
public init() {
self.init {

View File

@@ -1,7 +1,7 @@
import Foundation
import SwiftUtils
@MainActor final class Zenz {
@MainActor package final class Zenz {
package var resourceURL: URL
private var zenzContext: ZenzContext?
init(resourceURL: URL) throws {
@@ -27,9 +27,7 @@ import SwiftUtils
}
}
func startSession() {}
func endSession() {
package func endSession() {
try? self.zenzContext?.reset_context()
}
@@ -51,4 +49,8 @@ import SwiftUtils
let result = zenzContext.predict_next_character(leftSideContext: leftSideContext, count: count)
return result
}
/// Runs raw greedy decoding on `pureInput` via the underlying context.
/// Returns an empty string when no model context is loaded.
package func pureGreedyDecoding(pureInput: String, maxCount: Int = .max) -> String {
    guard let zenzContext else {
        return ""
    }
    return zenzContext.pure_greedy_decoding(leftSideContext: pureInput, maxCount: maxCount)
}
}

View File

@@ -67,7 +67,7 @@ enum ZenzError: LocalizedError {
}
}
class ZenzContext {
final class ZenzContext {
private var model: OpaquePointer
private var context: OpaquePointer
private var prevInput: [llama_token] = []
@@ -197,6 +197,42 @@ class ZenzContext {
}
}
/// Greedy (argmax) decoding: starting from `leftSideContext`, repeatedly appends the
/// single most probable next token until the EOS token is produced or `maxCount`
/// tokens have been generated, and returns only the newly generated text.
/// - Returns: the generated suffix (prompt excluded), or "" if logits are unavailable.
func pure_greedy_decoding(leftSideContext: String, maxCount: Int = .max) -> String {
var prompt_tokens = self.tokenize(text: leftSideContext, add_bos: false)
// Remember the prompt length so only generated tokens are returned at the end.
let initial_count = prompt_tokens.count
let eos_token = llama_token_eos(model)
while prompt_tokens.count - initial_count < maxCount {
// Only the last position's logits are needed for greedy decoding.
let startOffset = prompt_tokens.count - 1
guard let logits = self.get_logits(tokens: prompt_tokens, logits_start_index: startOffset) else {
print("logits unavailable")
return ""
}
let n_vocab = llama_n_vocab(model)
// With logits starting at `startOffset`, these bounds select exactly the
// vocabulary-sized slice for the final position (0..<n_vocab).
let startIndex = (prompt_tokens.count - 1 - startOffset) * Int(n_vocab)
let endIndex = (prompt_tokens.count - startOffset) * Int(n_vocab)
// Linear argmax over the vocabulary; greedy decoding needs no n-best heap.
var max_token: llama_token = -1
var max_value: Float = Float.infinity * -1
for index in startIndex..<endIndex {
let token = llama_token(index - startIndex)
if max_value < logits[index] {
max_token = token
max_value = logits[index]
}
}
if max_token == eos_token {
break
} else {
prompt_tokens.append(max_token)
}
}
// Convert only the generated tokens back to bytes and NUL-terminate for String(cString:).
let cchars: [CChar] = prompt_tokens.dropFirst(initial_count).flatMap(self.token_to_piece) + [0]
return String(cString: cchars)
}
func predict_next_character(leftSideContext: String, count: Int) -> [(character: Character, value: Float)] {
struct NextCharacterCandidate: Comparable {
static func < (lhs: NextCharacterCandidate, rhs: NextCharacterCandidate) -> Bool {
@@ -250,7 +286,6 @@ class ZenzContext {
}
func evaluate_candidate(input: String, candidate: Candidate, requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
print("Evaluate", candidate)
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
// We assume \u{EE01}\(candidate) is always split into \u{EE01}_\(candidate) by zenz-v1 tokenizer
var userDictionaryPrompt: String = ""
@@ -289,7 +324,6 @@ class ZenzContext {
}
}
//
//
let leftSideContext: String = switch versionDependentConfig {
case .v1: ""
case .v2(let mode):