Files
AzooKeyKanaKanjiConverter/Sources/CliTool/Subcommands/ZenzEvaluateCommand.swift
Miwa / Ensan 4e3921cee2 impl for eval
2025-02-06 21:58:26 +09:00

90 lines
3.8 KiB
Swift

import KanaKanjiConverterModuleWithDefaultDictionary
import ArgumentParser
import Foundation
extension Subcommands {
    /// Evaluates the conversion quality of the bare zenz model (greedy decoding only, without the
    /// dictionary-based converter) over a JSON dataset.
    ///
    /// The input file is decoded as `[EvaluationInputItem]`. For each item this command reads
    /// `query`, `answer`, `left_context` and `user_dictionary`. It produces exactly one
    /// greedy-decoded output per item, so the resulting `EvaluateResult` always has `n_best` = 1.
    /// The result is encoded as pretty-printed, key-sorted JSON and written to `--output`, or to
    /// standard output when `--output` is omitted.
    ///
    /// Progress lines, warnings and errors are written to standard error, so the JSON on standard
    /// output stays machine-readable.
    struct ZenzEvaluate: AsyncParsableCommand {
        @Argument(help: "query, answer, tagを備えたjsonファイルへのパス")
        var inputFile: String = ""
        @Option(name: [.customLong("output")], help: "Output file path.")
        var outputFilePath: String? = nil
        @Flag(name: [.customLong("stable")], help: "Report only stable properties; timestamps and values will not be reported.")
        var stable: Bool = false
        @Option(name: [.customLong("zenz")], help: "gguf format model weight for zenz.")
        var zenzWeightPath: String = ""

        static let configuration = CommandConfiguration(commandName: "zenz_evaluate", abstract: "Evaluate quality of pure zenz's Conversion for input data.")

        /// Reads `inputFile` and decodes it as a JSON array of `EvaluationInputItem`.
        /// - Throws: File-reading errors from `Data(contentsOf:)` or `DecodingError` from `JSONDecoder`.
        private func parseInputFile() throws -> [EvaluationInputItem] {
            let url = URL(fileURLWithPath: self.inputFile)
            let data = try Data(contentsOf: url)
            return try JSONDecoder().decode([EvaluationInputItem].self, from: data)
        }

        /// Builds the zenz prompt for `query` and decodes it greedily.
        ///
        /// Prompt layout: `U+EE00 <query> [U+EE02 <last 40 characters of leftContext>] U+EE01`.
        /// The context segment is included whenever `leftContext` is non-nil.
        /// NOTE(review): the marker characters and their order are assumed to match the prompt
        /// format the loaded weights were trained on — confirm against the zenz model version.
        /// - Parameters:
        ///   - query: The input text to convert.
        ///   - leftContext: Text preceding the input; only its last 40 characters are used.
        ///   - zenz: The loaded model.
        ///   - maxCount: Forwarded unchanged to `pureGreedyDecoding(pureInput:maxCount:)`.
        /// - Returns: The decoded string.
        private func greedyDecoding(query: String, leftContext: String?, zenz: Zenz, maxCount: Int) async -> String {
            let contextSegment = leftContext.map { "\u{EE02}" + String($0.suffix(40)) } ?? ""
            let prompt = "\u{EE00}\(query)\(contextSegment)\u{EE01}"
            return await zenz.pureGreedyDecoding(pureInput: prompt, maxCount: maxCount)
        }

        /// Writes `message` plus a newline to standard error.
        /// Standard output is kept free for the JSON result.
        private static func printToStandardError(_ message: String) {
            FileHandle.standardError.write(Data((message + "\n").utf8))
        }

        /// Runs greedy decoding for every input item and emits the aggregated result as JSON.
        /// - Throws: Errors from reading/decoding the input file or writing the output file, and
        ///   `ExitCode.failure` when no model path is given or the model cannot be loaded.
        mutating func run() async throws {
            let inputItems = try parseInputFile()
            // An empty path can never load a model; fail with a message instead of crashing.
            guard !zenzWeightPath.isEmpty else {
                Self.printToStandardError("Failed to initialize zenz model: --zenz <path> is required")
                throw ExitCode.failure
            }
            let converter = await KanaKanjiConverter()
            // `URL(fileURLWithPath:)` is the correct constructor for a local path, matching how the
            // other paths in this command are handled. `URL(string:)!` crashed on paths it could
            // not parse as a URL string.
            guard let zenz = await converter.getModel(modelURL: URL(fileURLWithPath: zenzWeightPath)) else {
                Self.printToStandardError("Failed to initialize zenz model")
                // Exit non-zero so scripted evaluations can detect the failure.
                throw ExitCode.failure
            }

            // Accumulated wall-clock time spent on decoding the items; model loading and
            // `endSession()` are excluded.
            var executionTime: Double = 0
            var resultItems: [EvaluateItem] = []
            resultItems.reserveCapacity(inputItems.count)

            for item in inputItems {
                let start = Date()
                if item.user_dictionary != nil {
                    Self.printToStandardError("Warning: zenz_evaluate command does not support user dictionary. User Dictionary Contents are just ignored.")
                }
                // Limit decoding by the longest expected answer, measured in UTF-8 bytes.
                // Items without answers fall back to the query's length instead of crashing on `max()!`.
                let maxCount = item.answer.map(\.utf8.count).max() ?? item.query.utf8.count
                let result = await greedyDecoding(query: item.query, leftContext: item.left_context, zenz: zenz, maxCount: maxCount)
                Self.printToStandardError("Results: \(result)")
                resultItems.append(
                    EvaluateItem(
                        query: item.query,
                        answers: item.answer,
                        left_context: item.left_context,
                        outputs: [
                            EvaluateItemOutput(text: result, score: 0.0)
                        ]
                    )
                )
                executionTime += Date().timeIntervalSince(start)
                // NOTE(review): presumably clears per-session model state so items are decoded
                // independently of one another — confirm against `Zenz.endSession()`.
                await zenz.endSession()
            }

            var result = EvaluateResult(n_best: 1, execution_time: executionTime, items: resultItems)
            if stable {
                // Remove run-dependent values and truncate scores toward zero so the output is reproducible.
                result.execution_time = 0
                result.timestamp = 0
                result.items.mutatingForeach {
                    $0.outputs.mutatingForeach {
                        $0.score = Double(Int($0.score))
                    }
                }
            }

            let encoder = JSONEncoder()
            encoder.outputFormatting = [.prettyPrinted, .sortedKeys]
            let json = try encoder.encode(result)
            if let outputFilePath {
                try json.write(to: URL(fileURLWithPath: outputFilePath))
            } else {
                // JSONEncoder always emits UTF-8, so a non-optional decode is sufficient.
                print(String(decoding: json, as: UTF8.self))
            }
        }
    }
}