[Fix] 制約の適用を調整した (#96)

* zenz-v1側は制約をコードポイントレベルで喋るので、grapheme clusterに非依存の処理に書き換えた

* EOSを考慮するように変更
This commit is contained in:
Miwa
2024-05-19 19:35:28 +09:00
committed by GitHub
parent dfef9631a9
commit c8ca5b54c0
4 changed files with 84 additions and 26 deletions

View File

@ -18,10 +18,11 @@ extension Kana2Kanji {
/// (3)(1)registerresultEOS
///
/// (4)
func kana2lattice_all_with_prefix_constraint(_ inputData: ComposingText, N_best: Int, constraint: String) -> (result: LatticeNode, nodes: Nodes) {
func kana2lattice_all_with_prefix_constraint(_ inputData: ComposingText, N_best: Int, constraint: PrefixConstraint) -> (result: LatticeNode, nodes: Nodes) {
debug("新規に計算を行います。inputされた文字列は\(inputData.input.count)文字分の\(inputData.convertTarget)。制約は\(constraint)")
let count: Int = inputData.input.count
let result: LatticeNode = LatticeNode.EOSNode
let utf16Constraint = Array(constraint.constraint.utf16)
let nodes: [[LatticeNode]] = (.zero ..< count).map {dicdataStore.getFrozenLOUDSDataInRange(inputData: inputData, from: $0)}
// inodes
for (i, nodeArray) in nodes.enumerated() {
@ -49,13 +50,14 @@ extension Kana2Kanji {
for index in node.prevs.indices {
let newnode: RegisteredNode = node.getRegisteredNode(index, value: node.values[index])
let text = newnode.getCandidateData().data.reduce(into: "") { $0.append(contentsOf: $1.word)} + node.data.word
if text.hasPrefix(constraint) {
let condition = (!constraint.hasEOS && text.utf16.hasPrefix(utf16Constraint)) || (constraint.hasEOS && text == constraint.constraint)
if condition {
result.prevs.append(newnode)
}
}
} else {
let candidates = node.getCandidateData().map {
$0.data.reduce(into: "") { $0.append(contentsOf: $1.word)} + node.data.word
let candidates: [[String.UTF16View.Element]] = node.getCandidateData().map {
Array(($0.data.reduce(into: "") { $0.append(contentsOf: $1.word)} + node.data.word).utf16)
}
// nodenextnode
for nextnode in nodes[nextIndex] {
@ -72,8 +74,9 @@ extension Kana2Kanji {
// AB ABC (OK)
// AB A (OK)
// AB AC (NG)
let text = candidates[index] + nextnode.data.word
if !text.hasPrefix(constraint) && !constraint.hasPrefix(text) {
let text = candidates[index] + nextnode.data.word.utf16
let condition = (!constraint.hasEOS && (text.hasPrefix(utf16Constraint) || utf16Constraint.hasPrefix(text))) || (constraint.hasEOS && text.count < utf16Constraint.count && utf16Constraint.hasPrefix(text))
guard condition else {
continue
}
let newValue: PValue = ccValue + value

View File

@ -3,17 +3,17 @@ import SwiftUtils
extension Kana2Kanji {
struct ZenzaiCache: Sendable {
init(_ inputData: ComposingText, constraint: String, satisfyingCandidate: Candidate?) {
init(_ inputData: ComposingText, constraint: PrefixConstraint, satisfyingCandidate: Candidate?) {
self.inputData = inputData
self.prefixConstraint = constraint
self.satisfyingCandidate = satisfyingCandidate
}
private var prefixConstraint: String
private var prefixConstraint: PrefixConstraint
private var satisfyingCandidate: Candidate?
private var inputData: ComposingText
func getNewConstraint(for newInputData: ComposingText) -> String {
func getNewConstraint(for newInputData: ComposingText) -> PrefixConstraint {
if let satisfyingCandidate {
var current = newInputData.convertTarget.toKatakana()[...]
var constraint = ""
@ -23,18 +23,28 @@ extension Kana2Kanji {
current = current.dropFirst(item.ruby.count)
}
}
return constraint
return PrefixConstraint(constraint)
} else if newInputData.convertTarget.hasPrefix(inputData.convertTarget) {
return self.prefixConstraint
} else {
return ""
return PrefixConstraint("")
}
}
}
struct PrefixConstraint: Equatable {
init(_ constraint: String, hasEOS: Bool = false) {
self.constraint = constraint
self.hasEOS = hasEOS
}
var constraint: String
var hasEOS: Bool
}
/// zenzai
@MainActor func all_zenzai(_ inputData: ComposingText, zenz: Zenz, zenzaiCache: ZenzaiCache?, inferenceLimit: Int) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? ""
var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint("")
print("initial constraint", constraint)
let eosNode = LatticeNode.EOSNode
var nodes: Kana2Kanji.Nodes = []
@ -60,7 +70,7 @@ extension Kana2Kanji {
print("best was not found!")
// Empty
//
return (eosNode, nodes, ZenzaiCache(inputData, constraint: "", satisfyingCandidate: nil))
return (eosNode, nodes, ZenzaiCache(inputData, constraint: PrefixConstraint(""), satisfyingCandidate: nil))
}
print("Constrained draft modeling", -start.timeIntervalSinceNow)
reviewLoop: while true {
@ -97,7 +107,7 @@ extension Kana2Kanji {
}
private enum NextAction {
case `return`(constraint: String, satisfied: Bool)
case `return`(constraint: PrefixConstraint, satisfied: Bool)
case `continue`
case `retry`(candidateIndex: Int)
}
@ -106,7 +116,7 @@ extension Kana2Kanji {
candidateIndex: Int,
candidates: [Candidate],
reviewResult: consuming ZenzContext.CandidateEvaluationResult,
constraint: inout String
constraint: inout PrefixConstraint
) -> NextAction {
switch reviewResult {
case .error:
@ -119,13 +129,13 @@ extension Kana2Kanji {
return .return(constraint: constraint, satisfied: true)
case .fixRequired(let prefixConstraint):
// 2
if constraint == prefixConstraint {
if constraint.constraint == prefixConstraint {
print("same constraint:", prefixConstraint)
return .return(constraint: "", satisfied: false)
return .return(constraint: PrefixConstraint(""), satisfied: false)
}
//
print("update constraint:", prefixConstraint)
constraint = prefixConstraint
constraint = PrefixConstraint(prefixConstraint)
// 使
for i in candidates.indices where i != candidateIndex {
if candidates[i].text.hasPrefix(prefixConstraint) {
@ -134,6 +144,24 @@ extension Kana2Kanji {
}
}
return .continue
case .wholeResult(let wholeConstraint):
let newConstraint = PrefixConstraint(wholeConstraint, hasEOS: true)
// 2
if constraint == newConstraint {
print("same constraint:", constraint)
return .return(constraint: PrefixConstraint(""), satisfied: false)
}
//
print("update whole constraint:", wholeConstraint)
constraint = PrefixConstraint(wholeConstraint, hasEOS: true)
// 使
for i in candidates.indices where i != candidateIndex {
if candidates[i].text == wholeConstraint {
print("found \(candidates[i].text) as another retry")
return .retry(candidateIndex: i)
}
}
return .continue
}
}
}

View File

@ -124,6 +124,7 @@ class ZenzContext {
case error
case pass(score: Float)
case fixRequired(prefixConstraint: String)
case wholeResult(String)
}
func evaluate_candidate(input: String, candidate: String) -> CandidateEvaluationResult {
@ -163,15 +164,27 @@ class ZenzContext {
}
//
if max_token != token_id {
var cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
if max_token == llama_token_eos(model) {
var cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
}
// adding "\0"
cchars.append(0)
let string = String(cString: cchars)
//
let wholeResult = String(string.dropFirst(prompt.count))
return .wholeResult(wholeResult)
} else {
var cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
}
// adding "\0"
cchars += token_to_piece(token: max_token) + [0]
let string = String(cString: cchars)
//
let prefixConstraint = String(string.dropFirst(prompt.count))
return .fixRequired(prefixConstraint: prefixConstraint)
}
// adding "\0"
cchars += token_to_piece(token: max_token) + [0]
let string = String(cString: cchars)
//
let prefixConstraint = String(string.dropFirst(prompt.count))
return .fixRequired(prefixConstraint: prefixConstraint)
}
score += log(max_exp) - log(exp_sum)
}

View File

@ -86,6 +86,20 @@ public extension Collection {
}
public extension Collection where Self.Element: Equatable {
/// Returns a Bool value indicating whether the collection has the given prefix.
/// - Parameters:
/// - prefix: A collection to search for at the start of this collection.
/// - Returns: A Bool value indicating whether the collection has the given prefix.
@inlinable func hasPrefix(_ prefix: some Collection<Element>) -> Bool {
if self.count < prefix.count {
return false
}
for (u, v) in zip(self, prefix) where u != v {
return false
}
return true
}
/// Returns a Bool value indicating whether the collection has the given suffix.
/// - Parameters:
/// - suffix: A collection to search for at the end of this collection.