mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-12-03 02:58:27 +00:00
feat: Zenzaiで学習機能を有効化する (#108)
* feat: add metadata for DicdataElement * feat: ignore learned word in zenz evaluation * feat: improve session command to support temporal memory
This commit is contained in:
@@ -15,6 +15,8 @@ extension Subcommands {
|
||||
var zenzWeightPath: String = ""
|
||||
@Flag(name: [.customLong("disable_prediction")], help: "Disable producing prediction candidates.")
|
||||
var disablePrediction = false
|
||||
@Flag(name: [.customLong("enable_memory")], help: "Enable memory.")
|
||||
var enableLearning = false
|
||||
@Flag(name: [.customLong("only_whole_conversion")], help: "Show only whole conversion (完全一致変換).")
|
||||
var onlyWholeConversion = false
|
||||
@Flag(name: [.customLong("report_score")], help: "Show internal score for the candidate.")
|
||||
@@ -27,39 +29,108 @@ extension Subcommands {
|
||||
|
||||
static var configuration = CommandConfiguration(commandName: "session", abstract: "Start session for incremental input.")
|
||||
|
||||
private func getTemporaryDirectory() -> URL? {
|
||||
let fileManager = FileManager.default
|
||||
let tempDirectoryURL = fileManager.temporaryDirectory.appendingPathComponent(UUID().uuidString)
|
||||
|
||||
do {
|
||||
try fileManager.createDirectory(at: tempDirectoryURL, withIntermediateDirectories: true, attributes: nil)
|
||||
print("Temporary directory created at \(tempDirectoryURL)")
|
||||
return tempDirectoryURL
|
||||
} catch {
|
||||
print("Error creating temporary directory: \(error)")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
@MainActor mutating func run() async {
|
||||
let memoryDirector = if self.enableLearning {
|
||||
if let dir = self.getTemporaryDirectory() {
|
||||
dir
|
||||
} else {
|
||||
fatalError("Could not get temporary directory.")
|
||||
}
|
||||
} else {
|
||||
URL(fileURLWithPath: "")
|
||||
}
|
||||
|
||||
let converter = KanaKanjiConverter()
|
||||
var composingText = ComposingText()
|
||||
let inputStyle: InputStyle = self.roman2kana ? .roman2kana : .direct
|
||||
var lastCandidates: [Candidate] = []
|
||||
var page = 0
|
||||
while true {
|
||||
print()
|
||||
print("\(bold: "== type :q to end session, type :d to delete character, type :c to stop composition, type any other text to input ==")")
|
||||
print("\(bold: "== Type :q to end session, type :d to delete character, type :c to stop composition. For other commands, type :h ==")")
|
||||
let input = readLine(strippingNewline: true) ?? ""
|
||||
switch input {
|
||||
case ":q": return
|
||||
case ":q":
|
||||
// 終了
|
||||
return
|
||||
case ":d":
|
||||
// 削除
|
||||
composingText.deleteBackwardFromCursorPosition(count: 1)
|
||||
case ":c":
|
||||
// クリア
|
||||
composingText.stopComposition()
|
||||
converter.stopComposition()
|
||||
print("composition is stopped")
|
||||
continue
|
||||
case ":n":
|
||||
// ページ送り
|
||||
page += 1
|
||||
for (i, candidate) in lastCandidates[self.displayTopN * page ..< self.displayTopN * (page + 1)].indexed() {
|
||||
if self.reportScore {
|
||||
print("\(bold: String(i)). \(candidate.text) \(bold: "score:") \(candidate.value)")
|
||||
} else {
|
||||
print("\(bold: String(i)). \(candidate.text)")
|
||||
}
|
||||
}
|
||||
continue
|
||||
case ":h":
|
||||
// ヘルプ
|
||||
print("""
|
||||
\(bold: "== anco session commands ==")
|
||||
\(bold: ":q") - quit session
|
||||
\(bold: ":c") - clear composition
|
||||
\(bold: ":d") - delete one character
|
||||
\(bold: ":n") - see more candidates
|
||||
\(bold: ":%d") - select candidate at that index (like :3 to select 3rd candidate)
|
||||
""")
|
||||
default:
|
||||
composingText.insertAtCursorPosition(input, inputStyle: inputStyle)
|
||||
if input.hasPrefix(":"), let index = Int(input.dropFirst()) {
|
||||
if !lastCandidates.indices.contains(index) {
|
||||
print("\(bold: "Error"): Index \(index) is not available for current context.")
|
||||
continue
|
||||
}
|
||||
let candidate = lastCandidates[index]
|
||||
print("Submit \(candidate.text)")
|
||||
converter.setCompletedData(candidate)
|
||||
converter.updateLearningData(candidate)
|
||||
composingText.prefixComplete(correspondingCount: candidate.correspondingCount)
|
||||
if composingText.isEmpty {
|
||||
composingText.stopComposition()
|
||||
converter.stopComposition()
|
||||
}
|
||||
} else {
|
||||
composingText.insertAtCursorPosition(input, inputStyle: inputStyle)
|
||||
}
|
||||
}
|
||||
print(composingText.convertTarget)
|
||||
let start = Date()
|
||||
let result = converter.requestCandidates(composingText, options: requestOptions())
|
||||
let result = converter.requestCandidates(composingText, options: requestOptions(memoryDirector: memoryDirector))
|
||||
let mainResults = result.mainResults.filter {
|
||||
!self.onlyWholeConversion || $0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == input.toKatakana()
|
||||
}
|
||||
for candidate in mainResults.prefix(self.displayTopN) {
|
||||
for (i, candidate) in mainResults.prefix(self.displayTopN).indexed() {
|
||||
if self.reportScore {
|
||||
print("\(candidate.text) \(bold: "score:") \(candidate.value)")
|
||||
print("\(bold: String(i)). \(candidate.text) \(bold: "score:") \(candidate.value)")
|
||||
} else {
|
||||
print(candidate.text)
|
||||
print("\(bold: String(i)). \(candidate.text)")
|
||||
}
|
||||
}
|
||||
lastCandidates = mainResults
|
||||
page = 0
|
||||
if self.onlyWholeConversion {
|
||||
// entropyを示す
|
||||
let mean = mainResults.reduce(into: 0) { $0 += Double($1.value) } / Double(mainResults.count)
|
||||
@@ -74,7 +145,7 @@ extension Subcommands {
|
||||
}
|
||||
}
|
||||
|
||||
func requestOptions() -> ConvertRequestOptions {
|
||||
func requestOptions(memoryDirector: URL) -> ConvertRequestOptions {
|
||||
var option: ConvertRequestOptions = .withDefaultDictionary(
|
||||
N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
|
||||
requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
|
||||
@@ -85,10 +156,10 @@ extension Subcommands {
|
||||
englishCandidateInRoman2KanaInput: true,
|
||||
fullWidthRomanCandidate: false,
|
||||
halfWidthKanaCandidate: false,
|
||||
learningType: .nothing,
|
||||
learningType: enableLearning ? .inputAndOutput : .nothing,
|
||||
maxMemoryCount: 0,
|
||||
shouldResetMemory: false,
|
||||
memoryDirectoryURL: URL(fileURLWithPath: ""),
|
||||
memoryDirectoryURL: memoryDirector,
|
||||
sharedContainerURL: URL(fileURLWithPath: ""),
|
||||
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit),
|
||||
metadata: .init(versionString: "anco for debugging")
|
||||
|
||||
@@ -12,7 +12,7 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
|
||||
static let BOSData = Self(word: "", ruby: "", cid: CIDData.BOS.cid, mid: MIDData.BOS.mid, value: 0, adjust: 0)
|
||||
static let EOSData = Self(word: "", ruby: "", cid: CIDData.EOS.cid, mid: MIDData.EOS.mid, value: 0, adjust: 0)
|
||||
|
||||
public init(word: String, ruby: String, lcid: Int, rcid: Int, mid: Int, value: PValue, adjust: PValue = .zero) {
|
||||
public init(word: String, ruby: String, lcid: Int, rcid: Int, mid: Int, value: PValue, adjust: PValue = .zero, metadata: DicdataElementMetadata = .empty) {
|
||||
self.word = word
|
||||
self.ruby = ruby
|
||||
self.lcid = lcid
|
||||
@@ -20,9 +20,10 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
|
||||
self.mid = mid
|
||||
self.baseValue = value
|
||||
self.adjust = adjust
|
||||
self.metadata = metadata
|
||||
}
|
||||
|
||||
public init(word: String, ruby: String, cid: Int, mid: Int, value: PValue, adjust: PValue = .zero) {
|
||||
public init(word: String, ruby: String, cid: Int, mid: Int, value: PValue, adjust: PValue = .zero, metadata: DicdataElementMetadata = .empty) {
|
||||
self.word = word
|
||||
self.ruby = ruby
|
||||
self.lcid = cid
|
||||
@@ -30,9 +31,10 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
|
||||
self.mid = mid
|
||||
self.baseValue = value
|
||||
self.adjust = adjust
|
||||
self.metadata = metadata
|
||||
}
|
||||
|
||||
public init(ruby: String, cid: Int, mid: Int, value: PValue, adjust: PValue = .zero) {
|
||||
public init(ruby: String, cid: Int, mid: Int, value: PValue, adjust: PValue = .zero, metadata: DicdataElementMetadata = .empty) {
|
||||
self.word = ruby
|
||||
self.ruby = ruby
|
||||
self.lcid = cid
|
||||
@@ -40,6 +42,7 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
|
||||
self.mid = mid
|
||||
self.baseValue = value
|
||||
self.adjust = adjust
|
||||
self.metadata = metadata
|
||||
}
|
||||
|
||||
public consuming func adjustedData(_ adjustValue: PValue) -> Self {
|
||||
@@ -54,6 +57,7 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
|
||||
public var mid: Int
|
||||
var baseValue: PValue
|
||||
public var adjust: PValue
|
||||
public var metadata: DicdataElementMetadata
|
||||
|
||||
public func value() -> PValue {
|
||||
min(.zero, self.baseValue + self.adjust)
|
||||
@@ -76,3 +80,14 @@ extension DicdataElement: CustomDebugStringConvertible {
|
||||
"(ruby: \(self.ruby), word: \(self.word), cid: (\(self.lcid), \(self.rcid)), mid: \(self.mid), value: \(self.baseValue)+\(self.adjust)=\(self.value()))"
|
||||
}
|
||||
}
|
||||
|
||||
public struct DicdataElementMetadata: OptionSet, Sendable, Hashable, Equatable {
|
||||
public let rawValue: UInt32
|
||||
public init(rawValue: UInt32) {
|
||||
self.rawValue = rawValue
|
||||
}
|
||||
|
||||
public static let empty: Self = []
|
||||
/// 学習データから得られた候補にはこのフラグを立てる
|
||||
public static let isLearned = DicdataElementMetadata(rawValue: 1 << 0) // 1
|
||||
}
|
||||
|
||||
@@ -213,6 +213,11 @@ public final class DicdataStore {
|
||||
for (key, value) in dict {
|
||||
data.append(contentsOf: LOUDS.getDataForLoudstxt3(identifier + "\(key)", indices: value.map {$0 & 2047}, option: self.requestOptions))
|
||||
}
|
||||
if identifier == "memory" {
|
||||
data.mutatingForeach {
|
||||
$0.metadata = .isLearned
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ import SwiftUtils
|
||||
return .error
|
||||
}
|
||||
for candidate in candidates {
|
||||
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate.text)
|
||||
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate)
|
||||
return result
|
||||
}
|
||||
return .error
|
||||
|
||||
@@ -127,13 +127,29 @@ class ZenzContext {
|
||||
case wholeResult(String)
|
||||
}
|
||||
|
||||
func evaluate_candidate(input: String, candidate: String) -> CandidateEvaluationResult {
|
||||
func getNextDicdataElement(for prefix: String, of candidate: Candidate) -> DicdataElement? {
|
||||
var curPrefix = ""
|
||||
for datum in candidate.data {
|
||||
if curPrefix == prefix {
|
||||
// prefixの直後の単語を返したい
|
||||
return datum
|
||||
} else if curPrefix.hasPrefix(prefix) {
|
||||
// FIXME: 積極的な一語化の方針と相性が良くないはず
|
||||
return nil
|
||||
}
|
||||
curPrefix.append(datum.word)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func evaluate_candidate(input: String, candidate: Candidate) -> CandidateEvaluationResult {
|
||||
print("Evaluate", candidate)
|
||||
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
|
||||
// We assume \u{EE01}\(candidate) is always splitted into \u{EE01}_\(candidate) by zenz-v1 tokenizer
|
||||
let prompt = "\u{EE00}\(input)\u{EE01}"
|
||||
// Therefore, tokens = prompt_tokens + candidate_tokens is an appropriate operation.
|
||||
let prompt_tokens = self.tokenize(text: prompt, add_bos: true, add_eos: false)
|
||||
let candidate_tokens = self.tokenize(text: candidate, add_bos: false, add_eos: false)
|
||||
let candidate_tokens = self.tokenize(text: candidate.text, add_bos: false, add_eos: false)
|
||||
let tokens = prompt_tokens + candidate_tokens
|
||||
let startOffset = prompt_tokens.count - 1
|
||||
let pos_max = llama_kv_cache_seq_pos_max(self.context, 0)
|
||||
@@ -178,12 +194,19 @@ class ZenzContext {
|
||||
var cchars = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
}
|
||||
let acceptedPrefix = String(cString: cchars + [0]).dropFirst(prompt.count)
|
||||
// adding "\0"
|
||||
cchars += token_to_piece(token: max_token) + [0]
|
||||
let string = String(cString: cchars)
|
||||
// 要求するべき制約を記述する
|
||||
let prefixConstraint = String(string.dropFirst(prompt.count))
|
||||
return .fixRequired(prefixConstraint: prefixConstraint)
|
||||
|
||||
if let nextDicdataElement = getNextDicdataElement(for: String(acceptedPrefix), of: candidate), nextDicdataElement.metadata.contains(.isLearned) {
|
||||
// 学習による候補なので素通しする
|
||||
// pass
|
||||
} else {
|
||||
// 要求するべき制約を記述する
|
||||
let prefixConstraint = String(string.dropFirst(prompt.count))
|
||||
return .fixRequired(prefixConstraint: prefixConstraint)
|
||||
}
|
||||
}
|
||||
}
|
||||
score += log(max_exp) - log(exp_sum)
|
||||
|
||||
Reference in New Issue
Block a user