mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-12-03 11:08:33 +00:00
Merge branch 'develop' into feat/zenz-v3
This commit is contained in:
@@ -5,7 +5,14 @@ import ArgumentParser
|
||||
public struct Anco: AsyncParsableCommand {
|
||||
public static let configuration = CommandConfiguration(
|
||||
abstract: "Anco is A(zooKey) Kana-Ka(n)ji (co)nverter",
|
||||
subcommands: [Subcommands.Run.self, Subcommands.Dict.self, Subcommands.Evaluate.self, Subcommands.Session.self, Subcommands.ExperimentalPredict.self],
|
||||
subcommands: [
|
||||
Subcommands.Run.self,
|
||||
Subcommands.Dict.self,
|
||||
Subcommands.Evaluate.self,
|
||||
Subcommands.ZenzEvaluate.self,
|
||||
Subcommands.Session.self,
|
||||
Subcommands.ExperimentalPredict.self
|
||||
],
|
||||
defaultSubcommand: Subcommands.Run.self
|
||||
)
|
||||
|
||||
|
||||
@@ -116,6 +116,7 @@ extension Subcommands.Dict {
|
||||
dictionaryResourceURL: URL(fileURLWithPath: self.dictionaryDirectory),
|
||||
memoryDirectoryURL: URL(fileURLWithPath: self.dictionaryDirectory),
|
||||
sharedContainerURL: URL(fileURLWithPath: self.dictionaryDirectory),
|
||||
textReplacer: .empty,
|
||||
metadata: .init(versionString: "anco for debugging")
|
||||
)
|
||||
}
|
||||
|
||||
@@ -17,22 +17,24 @@ extension Subcommands {
|
||||
var zenzWeightPath: String = ""
|
||||
@Option(name: [.customLong("config_zenzai_inference_limit")], help: "inference limit for zenzai.")
|
||||
var configZenzaiInferenceLimit: Int = .max
|
||||
@Flag(name: [.customLong("config_zenzai_ignore_left_context")], help: "ignore left_context")
|
||||
var configZenzaiIgnoreLeftContext: Bool = false
|
||||
|
||||
static let configuration = CommandConfiguration(commandName: "evaluate", abstract: "Evaluate quality of Conversion for input data.")
|
||||
|
||||
private func parseInputFile() throws -> [InputItem] {
|
||||
private func parseInputFile() throws -> [EvaluationInputItem] {
|
||||
let url = URL(fileURLWithPath: self.inputFile)
|
||||
let data = try Data(contentsOf: url)
|
||||
return try JSONDecoder().decode([InputItem].self, from: data)
|
||||
return try JSONDecoder().decode([EvaluationInputItem].self, from: data)
|
||||
}
|
||||
|
||||
mutating func run() async throws {
|
||||
let inputItems = try parseInputFile()
|
||||
let requestOptions = requestOptions()
|
||||
let converter = await KanaKanjiConverter()
|
||||
let start = Date()
|
||||
var executionTime: Double = 0
|
||||
var resultItems: [EvaluateItem] = []
|
||||
for item in inputItems {
|
||||
let start = Date()
|
||||
// セットアップ
|
||||
await converter.sendToDicdataStore(.importDynamicUserDict(
|
||||
(item.user_dictionary ?? []).map {
|
||||
@@ -42,7 +44,7 @@ extension Subcommands {
|
||||
// 変換
|
||||
var composingText = ComposingText()
|
||||
composingText.insertAtCursorPosition(item.query, inputStyle: .direct)
|
||||
|
||||
let requestOptions = self.requestOptions(leftSideContext: item.left_context)
|
||||
let result = await converter.requestCandidates(composingText, options: requestOptions)
|
||||
let mainResults = result.mainResults.filter {
|
||||
$0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == item.query.toKatakana()
|
||||
@@ -51,16 +53,17 @@ extension Subcommands {
|
||||
EvaluateItem(
|
||||
query: item.query,
|
||||
answers: item.answer,
|
||||
left_context: item.left_context,
|
||||
outputs: mainResults.prefix(self.configNBest).map {
|
||||
EvaluateItemOutput(text: $0.text, score: Double($0.value))
|
||||
}
|
||||
)
|
||||
)
|
||||
executionTime += Date().timeIntervalSince(start)
|
||||
// Explicitly reset state
|
||||
await converter.stopComposition()
|
||||
}
|
||||
let end = Date()
|
||||
var result = EvaluateResult(n_best: self.configNBest, execution_time: end.timeIntervalSince(start), items: resultItems)
|
||||
var result = EvaluateResult(n_best: self.configNBest, execution_time: executionTime, items: resultItems)
|
||||
if stable {
|
||||
result.execution_time = 0
|
||||
result.timestamp = 0
|
||||
@@ -83,7 +86,7 @@ extension Subcommands {
|
||||
}
|
||||
}
|
||||
|
||||
func requestOptions() -> ConvertRequestOptions {
|
||||
func requestOptions(leftSideContext: String?) -> ConvertRequestOptions {
|
||||
var option: ConvertRequestOptions = .withDefaultDictionary(
|
||||
N_best: self.configNBest,
|
||||
requireJapanesePrediction: false,
|
||||
@@ -99,7 +102,7 @@ extension Subcommands {
|
||||
shouldResetMemory: false,
|
||||
memoryDirectoryURL: URL(fileURLWithPath: ""),
|
||||
sharedContainerURL: URL(fileURLWithPath: ""),
|
||||
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit),
|
||||
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit, versionDependentMode: .v2(.init(leftSideContext: self.configZenzaiIgnoreLeftContext ? nil : leftSideContext))),
|
||||
metadata: .init(versionString: "anco for debugging")
|
||||
)
|
||||
option.requestQuery = .完全一致
|
||||
@@ -107,7 +110,7 @@ extension Subcommands {
|
||||
}
|
||||
}
|
||||
|
||||
private struct InputItem: Codable {
|
||||
struct EvaluationInputItem: Codable {
|
||||
/// 入力クエリ
|
||||
var query: String
|
||||
|
||||
@@ -117,11 +120,13 @@ extension Subcommands {
|
||||
/// タグ
|
||||
var tag: [String] = []
|
||||
|
||||
/// 左文脈
|
||||
var left_context: String? = nil
|
||||
|
||||
/// ユーザ辞書
|
||||
var user_dictionary: [InputUserDictionaryItem]? = nil
|
||||
}
|
||||
|
||||
private struct InputUserDictionaryItem: Codable {
|
||||
struct InputUserDictionaryItem: Codable {
|
||||
/// 漢字
|
||||
var word: String
|
||||
/// 読み
|
||||
@@ -129,6 +134,7 @@ extension Subcommands {
|
||||
/// ヒント
|
||||
var hint: String? = nil
|
||||
}
|
||||
}
|
||||
|
||||
struct EvaluateResult: Codable {
|
||||
internal init(n_best: Int, timestamp: TimeInterval = Date().timeIntervalSince1970, execution_time: TimeInterval, items: [Subcommands.EvaluateItem]) {
|
||||
@@ -166,9 +172,10 @@ extension Subcommands {
|
||||
}
|
||||
|
||||
struct EvaluateItem: Codable {
|
||||
init(query: String, answers: [String], outputs: [Subcommands.EvaluateItemOutput]) {
|
||||
init(query: String, answers: [String], left_context: String?, outputs: [Subcommands.EvaluateItemOutput]) {
|
||||
self.query = query
|
||||
self.answers = answers
|
||||
self.left_context = left_context ?? ""
|
||||
self.outputs = outputs
|
||||
do {
|
||||
// entropyを示す
|
||||
@@ -195,6 +202,9 @@ extension Subcommands {
|
||||
/// 出力
|
||||
var outputs: [EvaluateItemOutput]
|
||||
|
||||
/// 文脈
|
||||
var left_context: String
|
||||
|
||||
/// エントロピー
|
||||
var entropy: Double
|
||||
|
||||
|
||||
89
Sources/CliTool/Subcommands/ZenzEvaluateCommand.swift
Normal file
89
Sources/CliTool/Subcommands/ZenzEvaluateCommand.swift
Normal file
@@ -0,0 +1,89 @@
|
||||
import KanaKanjiConverterModuleWithDefaultDictionary
|
||||
import ArgumentParser
|
||||
import Foundation
|
||||
|
||||
extension Subcommands {
    /// `zenz_evaluate` subcommand: runs greedy decoding of the zenz model on every item of an
    /// evaluation JSON file and reports the outcome as an `EvaluateResult` JSON document.
    ///
    /// The JSON report is the only thing written to standard output; progress and warnings go to
    /// standard error so that the report stays machine-readable when `--output` is not given.
    struct ZenzEvaluate: AsyncParsableCommand {
        @Argument(help: "query, answer, tagを備えたjsonファイルへのパス")
        var inputFile: String = ""

        @Option(name: [.customLong("output")], help: "Output file path.")
        var outputFilePath: String? = nil
        @Flag(name: [.customLong("stable")], help: "Report only stable properties; timestamps and values will not be reported.")
        var stable: Bool = false
        @Option(name: [.customLong("zenz")], help: "gguf format model weight for zenz.")
        var zenzWeightPath: String = ""

        static let configuration = CommandConfiguration(commandName: "zenz_evaluate", abstract: "Evaluate quality of pure zenz's Conversion for input data.")

        /// Reads `inputFile` and decodes it as a JSON array of `EvaluationInputItem`.
        /// - Throws: any error raised while reading the file or decoding the JSON.
        private func parseInputFile() throws -> [EvaluationInputItem] {
            let url = URL(fileURLWithPath: self.inputFile)
            let data = try Data(contentsOf: url)
            return try JSONDecoder().decode([EvaluationInputItem].self, from: data)
        }

        /// Writes one diagnostic line to standard error.
        ///
        /// Diagnostics must not go to standard output because the JSON report is printed there.
        private func reportDiagnostic(_ message: String) {
            FileHandle.standardError.write(Data((message + "\n").utf8))
        }

        /// Builds the raw model prompt for `query` and lets the model decode greedily.
        /// - Parameters:
        ///   - query: the input to convert.
        ///   - leftContext: optional text preceding the input; only its last 40 characters are used.
        ///   - zenz: the loaded model.
        ///   - maxCount: upper bound on the number of generated tokens.
        /// - Returns: the generated text.
        private func greedyDecoding(query: String, leftContext: String?, zenz: Zenz, maxCount: Int) async -> String {
            // NOTE(review): \u{EE02} is presumably the "left context" marker token of the model — confirm.
            let contextSegment = if let leftContext {
                "\u{EE02}" + String(leftContext.suffix(40))
            } else {
                ""
            }
            // Prompt layout: \u{EE00} + query + optional context segment + \u{EE01}.
            let prompt = "\u{EE00}\(query)\(contextSegment)\u{EE01}"
            return await zenz.pureGreedyDecoding(pureInput: prompt, maxCount: maxCount)
        }

        /// Runs the evaluation and writes the JSON report to `--output`, or to standard output.
        /// - Throws: `ValidationError` when `--zenz` cannot be turned into a URL,
        ///   `ExitCode.failure` when the model cannot be loaded, and any file or JSON error.
        mutating func run() async throws {
            let inputItems = try parseInputFile()
            // `--zenz` defaults to "", for which `URL(string:)` returns nil; fail with a message instead of trapping.
            guard let modelURL = URL(string: self.zenzWeightPath) else {
                throw ValidationError("--zenz must be a valid path to a gguf format model weight (got '\(self.zenzWeightPath)')")
            }
            let converter = await KanaKanjiConverter()
            guard let zenz = await converter.getModel(modelURL: modelURL) else {
                reportDiagnostic("Failed to initialize zenz model")
                // A run that evaluated nothing must not look successful to the caller.
                throw ExitCode.failure
            }

            var executionTime: Double = 0
            var resultItems: [EvaluateItem] = []
            resultItems.reserveCapacity(inputItems.count)

            for item in inputItems {
                let start = Date()
                if item.user_dictionary != nil {
                    reportDiagnostic("Warning: zenz_evaluate command does not suppport user dictionary. User Dictionary Contents are just ignored.")
                }
                // Generation is capped by the UTF-8 length of the longest answer.
                // An item without answers gets a budget of 0 instead of crashing.
                let maxCount = item.answer.map(\.utf8.count).max() ?? 0
                let result = await self.greedyDecoding(query: item.query, leftContext: item.left_context, zenz: zenz, maxCount: maxCount)
                reportDiagnostic("Results: \(result)")
                resultItems.append(
                    EvaluateItem(
                        query: item.query,
                        answers: item.answer,
                        left_context: item.left_context,
                        outputs: [
                            EvaluateItemOutput(text: result, score: 0.0)
                        ]
                    )
                )
                executionTime += Date().timeIntervalSince(start)
                // Reset the model session between items; this is not counted in the execution time.
                await zenz.endSession()
            }

            var result = EvaluateResult(n_best: 1, execution_time: executionTime, items: resultItems)
            if stable {
                // Blank out everything that differs between runs so that reports can be diffed.
                result.execution_time = 0
                result.timestamp = 0
                result.items.mutatingForeach {
                    $0.outputs.mutatingForeach {
                        $0.score = Double(Int($0.score))
                    }
                }
            }

            let encoder = JSONEncoder()
            encoder.outputFormatting = [.prettyPrinted, .sortedKeys]
            let json = try encoder.encode(result)

            if let outputFilePath {
                try json.write(to: URL(fileURLWithPath: outputFilePath))
            } else {
                let string = String(data: json, encoding: .utf8)!
                print(string)
            }
        }
    }
}
|
||||
@@ -29,7 +29,7 @@ public struct ConvertRequestOptions: Sendable {
|
||||
/// - sharedContainerURL: ユーザ辞書など、キーボード外で書き込んだ設定データの保存されているディレクトリを指定します。
|
||||
/// - textReplacer: 予測変換のための置換機を指定します。
|
||||
/// - metadata: メタデータを指定します。詳しくは`ConvertRequestOptions.Metadata`を参照してください。
|
||||
public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer = TextReplacer(), zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?) {
|
||||
public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer, zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?) {
|
||||
self.N_best = N_best
|
||||
self.requireJapanesePrediction = requireJapanesePrediction
|
||||
self.requireEnglishPrediction = requireEnglishPrediction
|
||||
@@ -50,25 +50,49 @@ public struct ConvertRequestOptions: Sendable {
|
||||
self.dictionaryResourceURL = dictionaryResourceURL
|
||||
}
|
||||
|
||||
package init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, textReplacer: TextReplacer = TextReplacer(), zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?, requestQuery: RequestQuery) {
|
||||
self.N_best = N_best
|
||||
self.requireJapanesePrediction = requireJapanesePrediction
|
||||
self.requireEnglishPrediction = requireEnglishPrediction
|
||||
self.keyboardLanguage = keyboardLanguage
|
||||
self.typographyLetterCandidate = typographyLetterCandidate
|
||||
self.unicodeCandidate = unicodeCandidate
|
||||
self.englishCandidateInRoman2KanaInput = englishCandidateInRoman2KanaInput
|
||||
self.fullWidthRomanCandidate = fullWidthRomanCandidate
|
||||
self.halfWidthKanaCandidate = halfWidthKanaCandidate
|
||||
self.learningType = learningType
|
||||
self.maxMemoryCount = maxMemoryCount
|
||||
self.shouldResetMemory = shouldResetMemory
|
||||
self.memoryDirectoryURL = memoryDirectoryURL
|
||||
self.sharedContainerURL = sharedContainerURL
|
||||
self.metadata = metadata
|
||||
self.textReplacer = textReplacer
|
||||
self.zenzaiMode = zenzaiMode
|
||||
self.dictionaryResourceURL = dictionaryResourceURL
|
||||
/// Configuration data required for a conversion request.
|
||||
///
|
||||
/// - parameters:
|
||||
///   - N_best: Number of conversion candidates. Validity under the language model is guaranteed for the top `N` candidates. Larger values increase the amount of computation.
|
||||
///   - requireJapanesePrediction: Whether Japanese prediction candidates are needed. When `false`, Japanese prediction candidates are no longer output.
|
||||
///   - requireEnglishPrediction: Whether English prediction candidates are needed. When `false`, English prediction candidates are no longer output. `false` is recommended for Japanese input typed in romaji.
|
||||
///   - keyboardLanguage: Specifies the keyboard language.
|
||||
///   - typographyLetterCandidate: When `true`, alphanumeric candidates in "fancy fonts" are included in the output. See `KanaKanjiConverter.typographicalCandidates(_:)` for details.
|
||||
///   - unicodeCandidate: When `true`, Unicode candidates are included in the output for input such as `U+xxxx`. See `KanaKanjiConverter.unicodeCandidates(_:)` for details.
|
||||
///   - englishCandidateInRoman2KanaInput: When `true`, English candidates are output during Japanese romaji input. When `false`, they are not.
|
||||
///   - fullWidthRomanCandidate: When `true`, full-width alphanumeric candidates are included in the output.
|
||||
///   - halfWidthKanaCandidate: When `true`, half-width kana candidates are included in the output.
|
||||
///   - learningType: Specifies the learning mode. See `LearningType` for details.
|
||||
///   - maxMemoryCount: Maximum number of entries kept when learning is enabled. For `0`, setting `learningType` to `nothing` is more appropriate.
|
||||
///   - shouldResetMemory: When `true`, the learning data is reset before conversion starts.
|
||||
///   - dictionaryResourceURL: Specifies where the built-in dictionary data is read from.
|
||||
///   - memoryDirectoryURL: Specifies where the learning data is saved. Specify a writable directory.
|
||||
///   - sharedContainerURL: Specifies the directory that holds settings data written outside the keyboard, such as the user dictionary.
|
||||
///   - zenzaiMode: Forwarded unchanged to the initializer this one delegates to.
|
||||
|
|
||||
///   - metadata: Specifies the metadata. See `ConvertRequestOptions.Metadata` for details.
|
||||
///
/// This overload has no `textReplacer` parameter: it delegates to the initializer that takes one
/// and passes `TextReplacer()`.
@available(*, deprecated, message: "it be removed in AzooKeyKanaKanjiConverter v1.0")
|
||||
public init(N_best: Int = 10, requireJapanesePrediction: Bool, requireEnglishPrediction: Bool, keyboardLanguage: KeyboardLanguage, typographyLetterCandidate: Bool = false, unicodeCandidate: Bool = true, englishCandidateInRoman2KanaInput: Bool = false, fullWidthRomanCandidate: Bool = false, halfWidthKanaCandidate: Bool = false, learningType: LearningType, maxMemoryCount: Int = 65536, shouldResetMemory: Bool = false, dictionaryResourceURL: URL, memoryDirectoryURL: URL, sharedContainerURL: URL, zenzaiMode: ZenzaiMode = .off, metadata: ConvertRequestOptions.Metadata?) {
|
||||
// Every argument is forwarded as-is; only `textReplacer` is supplied here.
self.init(
|
||||
N_best: N_best,
|
||||
requireJapanesePrediction: requireJapanesePrediction,
|
||||
requireEnglishPrediction: requireEnglishPrediction,
|
||||
keyboardLanguage: keyboardLanguage,
|
||||
typographyLetterCandidate: typographyLetterCandidate,
|
||||
unicodeCandidate: unicodeCandidate,
|
||||
englishCandidateInRoman2KanaInput: englishCandidateInRoman2KanaInput,
|
||||
fullWidthRomanCandidate: fullWidthRomanCandidate,
|
||||
halfWidthKanaCandidate: halfWidthKanaCandidate,
|
||||
learningType: learningType,
|
||||
maxMemoryCount: maxMemoryCount,
|
||||
shouldResetMemory: shouldResetMemory,
|
||||
dictionaryResourceURL: dictionaryResourceURL,
|
||||
memoryDirectoryURL: memoryDirectoryURL,
|
||||
sharedContainerURL: sharedContainerURL,
|
||||
// MARK: using deprecated initializer here (the no-argument `TextReplacer()` initializer)
|
||||
textReplacer: TextReplacer(),
|
||||
zenzaiMode: zenzaiMode,
|
||||
metadata: metadata
|
||||
)
|
||||
}
|
||||
|
||||
public var N_best: Int
|
||||
@@ -117,6 +141,7 @@ public struct ConvertRequestOptions: Sendable {
|
||||
memoryDirectoryURL: (try? FileManager.default.url(for: .libraryDirectory, in: .userDomainMask, appropriateFor: nil, create: false)) ?? Bundle.main.bundleURL,
|
||||
// dummy data, won't work
|
||||
sharedContainerURL: Bundle.main.bundleURL,
|
||||
textReplacer: .empty,
|
||||
metadata: nil
|
||||
)
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ import SwiftUtils
|
||||
self.lastData = nil
|
||||
}
|
||||
|
||||
private func getModel(modelURL: URL) -> Zenz? {
|
||||
package func getModel(modelURL: URL) -> Zenz? {
|
||||
if let model = self.zenz, model.resourceURL == modelURL {
|
||||
self.zenzStatus = "load \(modelURL.absoluteString)"
|
||||
return model
|
||||
|
||||
@@ -130,7 +130,7 @@ extension ComposingText {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
while var (convertTargetElements, lastElement, count) = stack.popLast() {
|
||||
while case .some((var convertTargetElements, let lastElement, let count)) = stack.popLast() {
|
||||
if rightIndexRange.contains(count + left - 1) {
|
||||
if let convertTarget = ComposingText.getConvertTargetIfRightSideIsValid(lastElement: lastElement, of: self.input, to: count + left, convertTargetElements: convertTargetElements)?.map({$0.toKatakana()}) {
|
||||
stringToInfo.append((convertTarget, (count + left - 1)))
|
||||
|
||||
@@ -61,6 +61,17 @@ public struct TextReplacer: Sendable {
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializer for building a `TextReplacer` that does not operate.
|
||||
/// - parameters:
|
||||
///   - isEmpty: Marker argument; callers are expected to pass `true`.
|
||||
private init(isEmpty: Bool) {
|
||||
// Sanity check only: `assert` is evaluated in debug builds and compiled out of optimized builds.
// NOTE(review): no stored property is assigned here, so they presumably all have default values — confirm.
assert(isEmpty)
|
||||
}
|
||||
|
||||
/// A `TextReplacer` created through the private `init(isEmpty:)` initializer.
///
/// This is a computed property, so each access constructs a new value.
public static var empty: Self {
|
||||
Self(isEmpty: true)
|
||||
}
|
||||
|
||||
@available(*, deprecated, renamed: "init(emojiDataProvider:)", message: "it be removed in AzooKeyKanaKanjiConverter v1.0")
|
||||
public init() {
|
||||
self.init {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import Foundation
|
||||
import SwiftUtils
|
||||
|
||||
@MainActor final class Zenz {
|
||||
@MainActor package final class Zenz {
|
||||
package var resourceURL: URL
|
||||
private var zenzContext: ZenzContext?
|
||||
init(resourceURL: URL) throws {
|
||||
@@ -27,9 +27,7 @@ import SwiftUtils
|
||||
}
|
||||
}
|
||||
|
||||
func startSession() {}
|
||||
|
||||
func endSession() {
|
||||
package func endSession() {
|
||||
try? self.zenzContext?.reset_context()
|
||||
}
|
||||
|
||||
@@ -51,4 +49,8 @@ import SwiftUtils
|
||||
let result = zenzContext.predict_next_character(leftSideContext: leftSideContext, count: count)
|
||||
return result
|
||||
}
|
||||
|
||||
package func pureGreedyDecoding(pureInput: String, maxCount: Int = .max) -> String {
|
||||
return self.zenzContext?.pure_greedy_decoding(leftSideContext: pureInput, maxCount: maxCount) ?? ""
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ enum ZenzError: LocalizedError {
|
||||
}
|
||||
}
|
||||
|
||||
class ZenzContext {
|
||||
final class ZenzContext {
|
||||
private var model: OpaquePointer
|
||||
private var context: OpaquePointer
|
||||
private var prevInput: [llama_token] = []
|
||||
@@ -197,6 +197,42 @@ class ZenzContext {
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates text by plain greedy decoding and returns it.
///
/// Starting from the tokenized `leftSideContext`, the token with the highest logit is appended
/// repeatedly until the end-of-sequence token is selected or `maxCount` tokens have been generated.
/// - Parameters:
///   - leftSideContext: the prompt; it is tokenized without a BOS token.
///   - maxCount: maximum number of tokens to generate. Values `<= 0` generate nothing.
/// - Returns: the generated text only (the prompt is not included). An empty string is returned
///   when the prompt yields no tokens or when logits cannot be obtained.
func pure_greedy_decoding(leftSideContext: String, maxCount: Int = .max) -> String {
    var prompt_tokens = self.tokenize(text: leftSideContext, add_bos: false)
    // Logits are requested at the last prompt position, which does not exist for an empty prompt.
    guard !prompt_tokens.isEmpty else {
        return ""
    }
    let initial_count = prompt_tokens.count
    let eos_token = llama_token_eos(model)
    // The vocabulary size does not change during generation, so look it up once.
    let n_vocab = Int(llama_n_vocab(model))

    while prompt_tokens.count - initial_count < maxCount {
        // Only the logits of the final position are needed to choose the next token.
        let startOffset = prompt_tokens.count - 1
        guard let logits = self.get_logits(tokens: prompt_tokens, logits_start_index: startOffset) else {
            print("logits unavailable")
            return ""
        }
        // Since the requested range starts at the last position, the relevant logits are at 0..<n_vocab.
        // A linear scan for the arg-max is all greedy decoding needs.
        var max_token: llama_token = -1
        var max_value = -Float.infinity
        for index in 0..<n_vocab where max_value < logits[index] {
            max_token = llama_token(index)
            max_value = logits[index]
        }
        // Stop on end-of-sequence, and also when no token could be selected at all
        // (for example when every logit is NaN), instead of appending the invalid token -1.
        if max_token < 0 || max_token == eos_token {
            break
        }
        prompt_tokens.append(max_token)
    }

    // Convert only the generated tokens back into a NUL-terminated C string.
    let cchars: [CChar] = prompt_tokens.dropFirst(initial_count).flatMap(self.token_to_piece) + [0]
    return String(cString: cchars)
}
|
||||
|
||||
func predict_next_character(leftSideContext: String, count: Int) -> [(character: Character, value: Float)] {
|
||||
struct NextCharacterCandidate: Comparable {
|
||||
static func < (lhs: NextCharacterCandidate, rhs: NextCharacterCandidate) -> Bool {
|
||||
@@ -250,7 +286,6 @@ class ZenzContext {
|
||||
}
|
||||
|
||||
func evaluate_candidate(input: String, candidate: Candidate, requestRichCandidates: Bool, versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode) -> CandidateEvaluationResult {
|
||||
print("Evaluate", candidate)
|
||||
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
|
||||
// We assume \u{EE01}\(candidate) is always split into \u{EE01}_\(candidate) by zenz-v1 tokenizer
|
||||
var userDictionaryPrompt: String = ""
|
||||
@@ -289,7 +324,6 @@ class ZenzContext {
|
||||
}
|
||||
}
|
||||
// 左文脈を取得
|
||||
// プロフィールがある場合はこれを条件に追加
|
||||
let leftSideContext: String = switch versionDependentConfig {
|
||||
case .v1: ""
|
||||
case .v2(let mode):
|
||||
|
||||
Reference in New Issue
Block a user