feat: Zenzaiで学習機能を有効化する (#108)

* feat: add metadata for DicdataElement

* feat: ignore learned word in zenz evaluation

* feat: improve session command to support temporary memory
This commit is contained in:
Miwa
2024-06-27 00:32:44 +09:00
committed by GitHub
parent fbf09a76eb
commit f58f1603d4
5 changed files with 133 additions and 19 deletions

View File

@@ -15,6 +15,8 @@ extension Subcommands {
var zenzWeightPath: String = ""
@Flag(name: [.customLong("disable_prediction")], help: "Disable producing prediction candidates.")
var disablePrediction = false
@Flag(name: [.customLong("enable_memory")], help: "Enable memory.")
var enableLearning = false
@Flag(name: [.customLong("only_whole_conversion")], help: "Show only whole conversion (完全一致変換).")
var onlyWholeConversion = false
@Flag(name: [.customLong("report_score")], help: "Show internal score for the candidate.")
@@ -27,39 +29,108 @@ extension Subcommands {
static var configuration = CommandConfiguration(commandName: "session", abstract: "Start session for incremental input.")
/// Creates a new directory inside the system temporary directory.
///
/// The directory is named with a freshly generated UUID string. On success the
/// location is printed and returned; on failure the error is printed instead.
///
/// - Returns: The URL of the created directory, or `nil` if it could not be created.
private func getTemporaryDirectory() -> URL? {
    let tempDirectoryURL = FileManager.default.temporaryDirectory
        .appendingPathComponent(UUID().uuidString)
    do {
        try FileManager.default.createDirectory(
            at: tempDirectoryURL,
            withIntermediateDirectories: true,
            attributes: nil
        )
    } catch {
        // Report the failure and let the caller decide how to react to `nil`.
        print("Error creating temporary directory: \(error)")
        return nil
    }
    print("Temporary directory created at \(tempDirectoryURL)")
    return tempDirectoryURL
}
@MainActor mutating func run() async {
let memoryDirector = if self.enableLearning {
if let dir = self.getTemporaryDirectory() {
dir
} else {
fatalError("Could not get temporary directory.")
}
} else {
URL(fileURLWithPath: "")
}
let converter = KanaKanjiConverter()
var composingText = ComposingText()
let inputStyle: InputStyle = self.roman2kana ? .roman2kana : .direct
var lastCandidates: [Candidate] = []
var page = 0
while true {
print()
print("\(bold: "== type :q to end session, type :d to delete character, type :c to stop composition, type any other text to input ==")")
print("\(bold: "== Type :q to end session, type :d to delete character, type :c to stop composition. For other commands, type :h ==")")
let input = readLine(strippingNewline: true) ?? ""
switch input {
case ":q": return
case ":q":
//
return
case ":d":
//
composingText.deleteBackwardFromCursorPosition(count: 1)
case ":c":
//
composingText.stopComposition()
converter.stopComposition()
print("composition is stopped")
continue
case ":n":
//
page += 1
for (i, candidate) in lastCandidates[self.displayTopN * page ..< self.displayTopN * (page + 1)].indexed() {
if self.reportScore {
print("\(bold: String(i)). \(candidate.text) \(bold: "score:") \(candidate.value)")
} else {
print("\(bold: String(i)). \(candidate.text)")
}
}
continue
case ":h":
//
print("""
\(bold: "== anco session commands ==")
\(bold: ":q") - quit session
\(bold: ":c") - clear composition
\(bold: ":d") - delete one character
\(bold: ":n") - see more candidates
\(bold: ":%d") - select candidate at that index (like :3 to select 3rd candidate)
""")
default:
composingText.insertAtCursorPosition(input, inputStyle: inputStyle)
if input.hasPrefix(":"), let index = Int(input.dropFirst()) {
if !lastCandidates.indices.contains(index) {
print("\(bold: "Error"): Index \(index) is not available for current context.")
continue
}
let candidate = lastCandidates[index]
print("Submit \(candidate.text)")
converter.setCompletedData(candidate)
converter.updateLearningData(candidate)
composingText.prefixComplete(correspondingCount: candidate.correspondingCount)
if composingText.isEmpty {
composingText.stopComposition()
converter.stopComposition()
}
} else {
composingText.insertAtCursorPosition(input, inputStyle: inputStyle)
}
}
print(composingText.convertTarget)
let start = Date()
let result = converter.requestCandidates(composingText, options: requestOptions())
let result = converter.requestCandidates(composingText, options: requestOptions(memoryDirector: memoryDirector))
let mainResults = result.mainResults.filter {
!self.onlyWholeConversion || $0.data.reduce(into: "", {$0.append(contentsOf: $1.ruby)}) == input.toKatakana()
}
for candidate in mainResults.prefix(self.displayTopN) {
for (i, candidate) in mainResults.prefix(self.displayTopN).indexed() {
if self.reportScore {
print("\(candidate.text) \(bold: "score:") \(candidate.value)")
print("\(bold: String(i)). \(candidate.text) \(bold: "score:") \(candidate.value)")
} else {
print(candidate.text)
print("\(bold: String(i)). \(candidate.text)")
}
}
lastCandidates = mainResults
page = 0
if self.onlyWholeConversion {
// entropy
let mean = mainResults.reduce(into: 0) { $0 += Double($1.value) } / Double(mainResults.count)
@@ -74,7 +145,7 @@ extension Subcommands {
}
}
func requestOptions() -> ConvertRequestOptions {
func requestOptions(memoryDirector: URL) -> ConvertRequestOptions {
var option: ConvertRequestOptions = .withDefaultDictionary(
N_best: self.onlyWholeConversion ? max(self.configNBest, self.displayTopN) : self.configNBest,
requireJapanesePrediction: !self.onlyWholeConversion && !self.disablePrediction,
@@ -85,10 +156,10 @@ extension Subcommands {
englishCandidateInRoman2KanaInput: true,
fullWidthRomanCandidate: false,
halfWidthKanaCandidate: false,
learningType: .nothing,
learningType: enableLearning ? .inputAndOutput : .nothing,
maxMemoryCount: 0,
shouldResetMemory: false,
memoryDirectoryURL: URL(fileURLWithPath: ""),
memoryDirectoryURL: memoryDirector,
sharedContainerURL: URL(fileURLWithPath: ""),
zenzaiMode: self.zenzWeightPath.isEmpty ? .off : .on(weight: URL(string: self.zenzWeightPath)!, inferenceLimit: self.configZenzaiInferenceLimit),
metadata: .init(versionString: "anco for debugging")

View File

@@ -12,7 +12,7 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
static let BOSData = Self(word: "", ruby: "", cid: CIDData.BOS.cid, mid: MIDData.BOS.mid, value: 0, adjust: 0)
static let EOSData = Self(word: "", ruby: "", cid: CIDData.EOS.cid, mid: MIDData.EOS.mid, value: 0, adjust: 0)
public init(word: String, ruby: String, lcid: Int, rcid: Int, mid: Int, value: PValue, adjust: PValue = .zero) {
public init(word: String, ruby: String, lcid: Int, rcid: Int, mid: Int, value: PValue, adjust: PValue = .zero, metadata: DicdataElementMetadata = .empty) {
self.word = word
self.ruby = ruby
self.lcid = lcid
@@ -20,9 +20,10 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
self.mid = mid
self.baseValue = value
self.adjust = adjust
self.metadata = metadata
}
public init(word: String, ruby: String, cid: Int, mid: Int, value: PValue, adjust: PValue = .zero) {
public init(word: String, ruby: String, cid: Int, mid: Int, value: PValue, adjust: PValue = .zero, metadata: DicdataElementMetadata = .empty) {
self.word = word
self.ruby = ruby
self.lcid = cid
@@ -30,9 +31,10 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
self.mid = mid
self.baseValue = value
self.adjust = adjust
self.metadata = metadata
}
public init(ruby: String, cid: Int, mid: Int, value: PValue, adjust: PValue = .zero) {
public init(ruby: String, cid: Int, mid: Int, value: PValue, adjust: PValue = .zero, metadata: DicdataElementMetadata = .empty) {
self.word = ruby
self.ruby = ruby
self.lcid = cid
@@ -40,6 +42,7 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
self.mid = mid
self.baseValue = value
self.adjust = adjust
self.metadata = metadata
}
public consuming func adjustedData(_ adjustValue: PValue) -> Self {
@@ -54,6 +57,7 @@ public struct DicdataElement: Equatable, Hashable, Sendable {
public var mid: Int
var baseValue: PValue
public var adjust: PValue
public var metadata: DicdataElementMetadata
public func value() -> PValue {
min(.zero, self.baseValue + self.adjust)
@@ -76,3 +80,14 @@ extension DicdataElement: CustomDebugStringConvertible {
"(ruby: \(self.ruby), word: \(self.word), cid: (\(self.lcid), \(self.rcid)), mid: \(self.mid), value: \(self.baseValue)+\(self.adjust)=\(self.value()))"
}
}
/// Bit flags carrying extra information about a `DicdataElement`.
///
/// Each flag occupies one bit of `rawValue`, so several flags can be combined
/// in a single value.
public struct DicdataElementMetadata: OptionSet, Hashable, Equatable, Sendable {
    /// The underlying bit pattern.
    public let rawValue: UInt32

    /// Creates a metadata value from a raw bit pattern.
    public init(rawValue: UInt32) {
        self.rawValue = rawValue
    }

    /// No flags set.
    public static let empty = Self(rawValue: 0)

    /// Marks an element as coming from learned data (bit 0).
    public static let isLearned = Self(rawValue: 1 << 0)
}

View File

@@ -213,6 +213,11 @@ public final class DicdataStore {
for (key, value) in dict {
data.append(contentsOf: LOUDS.getDataForLoudstxt3(identifier + "\(key)", indices: value.map {$0 & 2047}, option: self.requestOptions))
}
if identifier == "memory" {
data.mutatingForeach {
$0.metadata = .isLearned
}
}
return data
}

View File

@@ -35,7 +35,7 @@ import SwiftUtils
return .error
}
for candidate in candidates {
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate.text)
let result = zenzContext.evaluate_candidate(input: convertTarget.toKatakana(), candidate: candidate)
return result
}
return .error

View File

@@ -127,13 +127,29 @@ class ZenzContext {
case wholeResult(String)
}
func evaluate_candidate(input: String, candidate: String) -> CandidateEvaluationResult {
/// Finds the element of `candidate.data` that starts exactly where `prefix` ends.
///
/// The words of `candidate.data` are concatenated in order. As soon as the text
/// accumulated so far equals `prefix`, the element about to be appended is returned.
///
/// - Parameters:
///   - prefix: The leading text to match against the concatenated words.
///   - candidate: The candidate whose `data` elements are scanned.
/// - Returns: The element immediately following `prefix`, or `nil` when `prefix`
///   ends in the middle of a word, is never reached, or is only reached after the
///   last element.
func getNextDicdataElement(for prefix: String, of candidate: Candidate) -> DicdataElement? {
    // Concatenation of the words of every element visited so far.
    var consumed = ""
    for element in candidate.data {
        guard consumed != prefix else {
            // `prefix` ends on a word boundary: this element is the one right after it.
            return element
        }
        guard !consumed.hasPrefix(prefix) else {
            // FIXME (carried over from the original): the accumulated text has already
            // grown past `prefix`, i.e. `prefix` ends inside a word. This is treated as
            // "no next element" for now.
            return nil
        }
        consumed += element.word
    }
    return nil
}
func evaluate_candidate(input: String, candidate: Candidate) -> CandidateEvaluationResult {
print("Evaluate", candidate)
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
// We assume \u{EE01}\(candidate) is always splitted into \u{EE01}_\(candidate) by zenz-v1 tokenizer
let prompt = "\u{EE00}\(input)\u{EE01}"
// Therefore, tokens = prompt_tokens + candidate_tokens is an appropriate operation.
let prompt_tokens = self.tokenize(text: prompt, add_bos: true, add_eos: false)
let candidate_tokens = self.tokenize(text: candidate, add_bos: false, add_eos: false)
let candidate_tokens = self.tokenize(text: candidate.text, add_bos: false, add_eos: false)
let tokens = prompt_tokens + candidate_tokens
let startOffset = prompt_tokens.count - 1
let pos_max = llama_kv_cache_seq_pos_max(self.context, 0)
@@ -178,12 +194,19 @@ class ZenzContext {
var cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
}
let acceptedPrefix = String(cString: cchars + [0]).dropFirst(prompt.count)
// adding "\0"
cchars += token_to_piece(token: max_token) + [0]
let string = String(cString: cchars)
//
let prefixConstraint = String(string.dropFirst(prompt.count))
return .fixRequired(prefixConstraint: prefixConstraint)
if let nextDicdataElement = getNextDicdataElement(for: String(acceptedPrefix), of: candidate), nextDicdataElement.metadata.contains(.isLearned) {
//
// pass
} else {
//
let prefixConstraint = String(string.dropFirst(prompt.count))
return .fixRequired(prefixConstraint: prefixConstraint)
}
}
}
score += log(max_exp) - log(exp_sum)