mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-08-22 15:05:26 +00:00
[Fix] 制約の適用を調整した (#96)
* zenz-v1側は制約をコードポイントレベルで喋るので、grapheme clusterに非依存の処理に書き換えた * EOSを考慮するように変更
This commit is contained in:
@ -18,10 +18,11 @@ extension Kana2Kanji {
|
||||
/// (3)(1)のregisterされた結果をresultノードに追加していく。この際EOSとの連接計算を行っておく。
|
||||
///
|
||||
/// (4)ノードをアップデートした上で返却する。
|
||||
func kana2lattice_all_with_prefix_constraint(_ inputData: ComposingText, N_best: Int, constraint: String) -> (result: LatticeNode, nodes: Nodes) {
|
||||
func kana2lattice_all_with_prefix_constraint(_ inputData: ComposingText, N_best: Int, constraint: PrefixConstraint) -> (result: LatticeNode, nodes: Nodes) {
|
||||
debug("新規に計算を行います。inputされた文字列は\(inputData.input.count)文字分の\(inputData.convertTarget)。制約は\(constraint)")
|
||||
let count: Int = inputData.input.count
|
||||
let result: LatticeNode = LatticeNode.EOSNode
|
||||
let utf16Constraint = Array(constraint.constraint.utf16)
|
||||
let nodes: [[LatticeNode]] = (.zero ..< count).map {dicdataStore.getFrozenLOUDSDataInRange(inputData: inputData, from: $0)}
|
||||
// 「i文字目から始まるnodes」に対して
|
||||
for (i, nodeArray) in nodes.enumerated() {
|
||||
@ -49,13 +50,14 @@ extension Kana2Kanji {
|
||||
for index in node.prevs.indices {
|
||||
let newnode: RegisteredNode = node.getRegisteredNode(index, value: node.values[index])
|
||||
let text = newnode.getCandidateData().data.reduce(into: "") { $0.append(contentsOf: $1.word)} + node.data.word
|
||||
if text.hasPrefix(constraint) {
|
||||
let condition = (!constraint.hasEOS && text.utf16.hasPrefix(utf16Constraint)) || (constraint.hasEOS && text == constraint.constraint)
|
||||
if condition {
|
||||
result.prevs.append(newnode)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let candidates = node.getCandidateData().map {
|
||||
$0.data.reduce(into: "") { $0.append(contentsOf: $1.word)} + node.data.word
|
||||
let candidates: [[String.UTF16View.Element]] = node.getCandidateData().map {
|
||||
Array(($0.data.reduce(into: "") { $0.append(contentsOf: $1.word)} + node.data.word).utf16)
|
||||
}
|
||||
// nodeの繋がる次にあり得る全てのnextnodeに対して
|
||||
for nextnode in nodes[nextIndex] {
|
||||
@ -72,8 +74,9 @@ extension Kana2Kanji {
|
||||
// 制約 AB 単語 ABC (OK)
|
||||
// 制約 AB 単語 A (OK)
|
||||
// 制約 AB 単語 AC (NG)
|
||||
let text = candidates[index] + nextnode.data.word
|
||||
if !text.hasPrefix(constraint) && !constraint.hasPrefix(text) {
|
||||
let text = candidates[index] + nextnode.data.word.utf16
|
||||
let condition = (!constraint.hasEOS && (text.hasPrefix(utf16Constraint) || utf16Constraint.hasPrefix(text))) || (constraint.hasEOS && text.count < utf16Constraint.count && utf16Constraint.hasPrefix(text))
|
||||
guard condition else {
|
||||
continue
|
||||
}
|
||||
let newValue: PValue = ccValue + value
|
||||
|
@ -3,17 +3,17 @@ import SwiftUtils
|
||||
|
||||
extension Kana2Kanji {
|
||||
struct ZenzaiCache: Sendable {
|
||||
init(_ inputData: ComposingText, constraint: String, satisfyingCandidate: Candidate?) {
|
||||
init(_ inputData: ComposingText, constraint: PrefixConstraint, satisfyingCandidate: Candidate?) {
|
||||
self.inputData = inputData
|
||||
self.prefixConstraint = constraint
|
||||
self.satisfyingCandidate = satisfyingCandidate
|
||||
}
|
||||
|
||||
private var prefixConstraint: String
|
||||
private var prefixConstraint: PrefixConstraint
|
||||
private var satisfyingCandidate: Candidate?
|
||||
private var inputData: ComposingText
|
||||
|
||||
func getNewConstraint(for newInputData: ComposingText) -> String {
|
||||
func getNewConstraint(for newInputData: ComposingText) -> PrefixConstraint {
|
||||
if let satisfyingCandidate {
|
||||
var current = newInputData.convertTarget.toKatakana()[...]
|
||||
var constraint = ""
|
||||
@ -23,18 +23,28 @@ extension Kana2Kanji {
|
||||
current = current.dropFirst(item.ruby.count)
|
||||
}
|
||||
}
|
||||
return constraint
|
||||
return PrefixConstraint(constraint)
|
||||
} else if newInputData.convertTarget.hasPrefix(inputData.convertTarget) {
|
||||
return self.prefixConstraint
|
||||
} else {
|
||||
return ""
|
||||
return PrefixConstraint("")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PrefixConstraint: Equatable {
|
||||
init(_ constraint: String, hasEOS: Bool = false) {
|
||||
self.constraint = constraint
|
||||
self.hasEOS = hasEOS
|
||||
}
|
||||
|
||||
var constraint: String
|
||||
var hasEOS: Bool
|
||||
}
|
||||
|
||||
/// zenzaiシステムによる完全変換。
|
||||
@MainActor func all_zenzai(_ inputData: ComposingText, zenz: Zenz, zenzaiCache: ZenzaiCache?, inferenceLimit: Int) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
|
||||
var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? ""
|
||||
var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint("")
|
||||
print("initial constraint", constraint)
|
||||
let eosNode = LatticeNode.EOSNode
|
||||
var nodes: Kana2Kanji.Nodes = []
|
||||
@ -60,7 +70,7 @@ extension Kana2Kanji {
|
||||
print("best was not found!")
|
||||
// Emptyの場合
|
||||
// 制約が満たせない場合は無視する
|
||||
return (eosNode, nodes, ZenzaiCache(inputData, constraint: "", satisfyingCandidate: nil))
|
||||
return (eosNode, nodes, ZenzaiCache(inputData, constraint: PrefixConstraint(""), satisfyingCandidate: nil))
|
||||
}
|
||||
print("Constrained draft modeling", -start.timeIntervalSinceNow)
|
||||
reviewLoop: while true {
|
||||
@ -97,7 +107,7 @@ extension Kana2Kanji {
|
||||
}
|
||||
|
||||
private enum NextAction {
|
||||
case `return`(constraint: String, satisfied: Bool)
|
||||
case `return`(constraint: PrefixConstraint, satisfied: Bool)
|
||||
case `continue`
|
||||
case `retry`(candidateIndex: Int)
|
||||
}
|
||||
@ -106,7 +116,7 @@ extension Kana2Kanji {
|
||||
candidateIndex: Int,
|
||||
candidates: [Candidate],
|
||||
reviewResult: consuming ZenzContext.CandidateEvaluationResult,
|
||||
constraint: inout String
|
||||
constraint: inout PrefixConstraint
|
||||
) -> NextAction {
|
||||
switch reviewResult {
|
||||
case .error:
|
||||
@ -119,13 +129,13 @@ extension Kana2Kanji {
|
||||
return .return(constraint: constraint, satisfied: true)
|
||||
case .fixRequired(let prefixConstraint):
|
||||
// 同じ制約が2回連続で出てきたら諦める
|
||||
if constraint == prefixConstraint {
|
||||
if constraint.constraint == prefixConstraint {
|
||||
print("same constraint:", prefixConstraint)
|
||||
return .return(constraint: "", satisfied: false)
|
||||
return .return(constraint: PrefixConstraint(""), satisfied: false)
|
||||
}
|
||||
// 制約が得られたので、更新する
|
||||
print("update constraint:", prefixConstraint)
|
||||
constraint = prefixConstraint
|
||||
constraint = PrefixConstraint(prefixConstraint)
|
||||
// もし制約を満たす候補があるならそれを使って再レビューチャレンジを戦うことで、推論を減らせる
|
||||
for i in candidates.indices where i != candidateIndex {
|
||||
if candidates[i].text.hasPrefix(prefixConstraint) {
|
||||
@ -134,6 +144,24 @@ extension Kana2Kanji {
|
||||
}
|
||||
}
|
||||
return .continue
|
||||
case .wholeResult(let wholeConstraint):
|
||||
let newConstraint = PrefixConstraint(wholeConstraint, hasEOS: true)
|
||||
// 同じ制約が2回連続で出てきたら諦める
|
||||
if constraint == newConstraint {
|
||||
print("same constraint:", constraint)
|
||||
return .return(constraint: PrefixConstraint(""), satisfied: false)
|
||||
}
|
||||
// 制約が得られたので、更新する
|
||||
print("update whole constraint:", wholeConstraint)
|
||||
constraint = PrefixConstraint(wholeConstraint, hasEOS: true)
|
||||
// もし制約を満たす候補があるならそれを使って再レビューチャレンジを戦うことで、推論を減らせる
|
||||
for i in candidates.indices where i != candidateIndex {
|
||||
if candidates[i].text == wholeConstraint {
|
||||
print("found \(candidates[i].text) as another retry")
|
||||
return .retry(candidateIndex: i)
|
||||
}
|
||||
}
|
||||
return .continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -124,6 +124,7 @@ class ZenzContext {
|
||||
case error
|
||||
case pass(score: Float)
|
||||
case fixRequired(prefixConstraint: String)
|
||||
case wholeResult(String)
|
||||
}
|
||||
|
||||
func evaluate_candidate(input: String, candidate: String) -> CandidateEvaluationResult {
|
||||
@ -163,15 +164,27 @@ class ZenzContext {
|
||||
}
|
||||
// ここで最も良い候補であったかをチェックする
|
||||
if max_token != token_id {
|
||||
var cchars = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
if max_token == llama_token_eos(model) {
|
||||
var cchars = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
}
|
||||
// adding "\0"
|
||||
cchars.append(0)
|
||||
let string = String(cString: cchars)
|
||||
// 要求するべき制約を記述する
|
||||
let wholeResult = String(string.dropFirst(prompt.count))
|
||||
return .wholeResult(wholeResult)
|
||||
} else {
|
||||
var cchars = tokens[..<i].reduce(into: []) {
|
||||
$0.append(contentsOf: token_to_piece(token: $1))
|
||||
}
|
||||
// adding "\0"
|
||||
cchars += token_to_piece(token: max_token) + [0]
|
||||
let string = String(cString: cchars)
|
||||
// 要求するべき制約を記述する
|
||||
let prefixConstraint = String(string.dropFirst(prompt.count))
|
||||
return .fixRequired(prefixConstraint: prefixConstraint)
|
||||
}
|
||||
// adding "\0"
|
||||
cchars += token_to_piece(token: max_token) + [0]
|
||||
let string = String(cString: cchars)
|
||||
// 要求するべき制約を記述する
|
||||
let prefixConstraint = String(string.dropFirst(prompt.count))
|
||||
return .fixRequired(prefixConstraint: prefixConstraint)
|
||||
}
|
||||
score += log(max_exp) - log(exp_sum)
|
||||
}
|
||||
|
@ -86,6 +86,20 @@ public extension Collection {
|
||||
}
|
||||
|
||||
public extension Collection where Self.Element: Equatable {
|
||||
/// Returns a Bool value indicating whether the collection has the given prefix.
|
||||
/// - Parameters:
|
||||
/// - prefix: A collection to search for at the start of this collection.
|
||||
/// - Returns: A Bool value indicating whether the collection has the given prefix.
|
||||
@inlinable func hasPrefix(_ prefix: some Collection<Element>) -> Bool {
|
||||
if self.count < prefix.count {
|
||||
return false
|
||||
}
|
||||
for (u, v) in zip(self, prefix) where u != v {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
/// Returns a Bool value indicating whether the collection has the given suffix.
|
||||
/// - Parameters:
|
||||
/// - suffix: A collection to search for at the end of this collection.
|
||||
|
Reference in New Issue
Block a user