Merge pull request #221 from azooKey/refactor/lattice_api

refactor: 変換ロジックのリファクタリング(`Lattice`構造体の導入、変換ロジックの共通化)
This commit is contained in:
Miwa
2025-07-09 00:19:07 +09:00
committed by GitHub
8 changed files with 120 additions and 178 deletions

View File

@ -10,9 +10,6 @@ import Foundation
import SwiftUtils
extension Kana2Kanji {
/// Lattice
typealias Nodes = [[LatticeNode]]
/// ,
/// - Parameters:
/// - inputData:
@ -29,13 +26,13 @@ extension Kana2Kanji {
/// (3)(1)registerresultEOS
///
/// (4)
func kana2lattice_all(_ inputData: ComposingText, N_best: Int, needTypoCorrection: Bool) -> (result: LatticeNode, nodes: Nodes) {
func kana2lattice_all(_ inputData: ComposingText, N_best: Int, needTypoCorrection: Bool) -> (result: LatticeNode, lattice: Lattice) {
debug("新規に計算を行います。inputされた文字列は\(inputData.input.count)文字分の\(inputData.convertTarget)")
let count: Int = inputData.input.count
let result: LatticeNode = LatticeNode.EOSNode
let nodes: [[LatticeNode]] = (.zero ..< count).map {dicdataStore.getLOUDSDataInRange(inputData: inputData, from: $0, needTypoCorrection: needTypoCorrection)}
let lattice: Lattice = Lattice(nodes: (.zero ..< count).map {dicdataStore.getLOUDSDataInRange(inputData: inputData, from: $0, needTypoCorrection: needTypoCorrection)})
// For each array of nodes, where i is the input index the array's nodes start from
for (i, nodeArray) in nodes.enumerated() {
for (i, nodeArray) in lattice.nodes.enumerated() {
// node
for node in nodeArray {
if node.prevs.isEmpty {
@ -57,40 +54,45 @@ extension Kana2Kanji {
let nextIndex: Int = node.inputRange.endIndex
// If this node ends exactly at the end of the input (nextIndex == count), register it on the result node
if nextIndex == count {
for index in node.prevs.indices {
let newnode: RegisteredNode = node.getRegisteredNode(index, value: node.values[index])
result.prevs.append(newnode)
}
self.updateResultNode(with: node, resultNode: result)
} else {
// Otherwise, connect node to every nextnode stored at index nextIndex
for nextnode in nodes[nextIndex] {
// node.registered.isEmpty
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue: PValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue: PValue = ccValue + value
// index
let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == N_best {
continue
}
let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= N_best {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
self.updateNextNodes(with: node, nextNodes: lattice[inputIndex: nextIndex], nBest: N_best)
}
}
}
return (result: result, nodes: nodes)
return (result: result, lattice: lattice)
}
/// Registers every path that ends at `node` as a predecessor of `resultNode`.
///
/// For each index of `node.prevs`, a `RegisteredNode` is built with
/// `getRegisteredNode(_:value:)` using `node.values[index]` as its value, and the
/// results are appended to `resultNode.prevs` in index order.
/// - Parameters:
///   - node: The node whose registered paths are forwarded. `node.values` is indexed
///     with the indices of `node.prevs`, so it must contain at least as many elements.
///   - resultNode: The node that receives the new entries in `prevs`.
func updateResultNode(with node: LatticeNode, resultNode: LatticeNode) {
    let registered: [RegisteredNode] = node.prevs.indices.map { position in
        node.getRegisteredNode(position, value: node.values[position])
    }
    resultNode.prevs.append(contentsOf: registered)
}
/// Propagates the paths ending at `node` into each of `nextNodes`, keeping at most
/// `nBest` entries in every next node's `prevs`.
///
/// Each `prevs` array is treated as ordered by descending `totalValue`: a new entry
/// is inserted right after the last existing entry whose `totalValue` is greater
/// than or equal to the new value.
/// - Parameters:
///   - node: The source node; `node.values` supplies one value per path.
///   - nextNodes: The candidate successor nodes. Nodes for which
///     `dicdataStore.shouldBeRemoved(data:)` returns `true` are skipped.
///   - nBest: The maximum number of entries kept in each successor's `prevs`.
func updateNextNodes(with node: LatticeNode, nextNodes: [LatticeNode], nBest: Int) {
    for successor in nextNodes where !self.dicdataStore.shouldBeRemoved(data: successor.data) {
        // Value from getCCValue for node's rcid and successor's lcid; shared by every path of `node`.
        let connectionValue: PValue = self.dicdataStore.getCCValue(node.data.rcid, successor.data.lcid)
        for (pathIndex, pathValue) in node.values.enumerated() {
            let candidateValue: PValue = connectionValue + pathValue
            // Insertion point: just after the last entry that is at least as good as the candidate.
            let insertionIndex: Int
            if let lastNotWorse = successor.prevs.lastIndex(where: { $0.totalValue >= candidateValue }) {
                insertionIndex = lastNotWorse + 1
            } else {
                insertionIndex = 0
            }
            // The candidate would land exactly at position nBest, i.e. outside the kept range.
            guard insertionIndex != nBest else {
                continue
            }
            let registered: RegisteredNode = node.getRegisteredNode(pathIndex, value: candidateValue)
            // Make room first so the array never grows beyond nBest entries.
            if successor.prevs.count >= nBest {
                successor.prevs.removeLast()
            }
            // Remove-then-insert; the insert itself is O(N) in the number of kept entries.
            successor.prevs.insert(registered, at: insertionIndex)
        }
    }
}
}

View File

@ -18,13 +18,13 @@ extension Kana2Kanji {
/// (3)(1)registerresultEOS
///
/// (4)
func kana2lattice_all_with_prefix_constraint(_ inputData: ComposingText, N_best: Int, constraint: PrefixConstraint) -> (result: LatticeNode, nodes: Nodes) {
func kana2lattice_all_with_prefix_constraint(_ inputData: ComposingText, N_best: Int, constraint: PrefixConstraint) -> (result: LatticeNode, lattice: Lattice) {
debug("新規に計算を行います。inputされた文字列は\(inputData.input.count)文字分の\(inputData.convertTarget)。制約は\(constraint)")
let count: Int = inputData.input.count
let result: LatticeNode = LatticeNode.EOSNode
let nodes: [[LatticeNode]] = (.zero ..< count).map {dicdataStore.getLOUDSDataInRange(inputData: inputData, from: $0, needTypoCorrection: false)}
let lattice: Lattice = Lattice(nodes: (.zero ..< count).map {dicdataStore.getLOUDSDataInRange(inputData: inputData, from: $0, needTypoCorrection: false)})
// inodes
for (i, nodeArray) in nodes.enumerated() {
for (i, nodeArray) in lattice.nodes.enumerated() {
// node
for node in nodeArray {
if node.prevs.isEmpty {
@ -61,7 +61,7 @@ extension Kana2Kanji {
Array(($0.data.reduce(into: "") { $0.append(contentsOf: $1.word)} + node.data.word).utf8)
}
// nodenextnode
for nextnode in nodes[nextIndex] {
for nextnode in lattice[inputIndex: nextIndex] {
//
let ccValue: PValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
@ -97,7 +97,7 @@ extension Kana2Kanji {
}
}
}
return (result: result, nodes: nodes)
return (result: result, lattice: lattice)
}
}

View File

@ -15,13 +15,13 @@ extension Kana2Kanji {
/// (1)noderegisteredcompletedDataBOS
///
/// (2)
func kana2lattice_afterComplete(_ inputData: ComposingText, completedData: Candidate, N_best: Int, previousResult: (inputData: ComposingText, nodes: Nodes), needTypoCorrection: Bool) -> (result: LatticeNode, nodes: Nodes) {
func kana2lattice_afterComplete(_ inputData: ComposingText, completedData: Candidate, N_best: Int, previousResult: (inputData: ComposingText, lattice: Lattice), needTypoCorrection: Bool) -> (result: LatticeNode, lattice: Lattice) {
debug("確定直後の変換、前は:", previousResult.inputData, "後は:", inputData)
let count = inputData.input.count
// (1)
let start = RegisteredNode.fromLastCandidate(completedData)
let nodes: Nodes = previousResult.nodes.suffix(count)
for (i, nodeArray) in nodes.enumerated() {
let lattice = Lattice(nodes: previousResult.lattice.nodes.suffix(count))
for (i, nodeArray) in lattice.nodes.enumerated() {
if i == .zero {
for node in nodeArray {
node.prevs = [start]
@ -39,7 +39,7 @@ extension Kana2Kanji {
// (2)
let result = LatticeNode.EOSNode
for (i, nodeArray) in nodes.enumerated() {
for (i, nodeArray) in lattice.nodes.enumerated() {
for node in nodeArray {
if node.prevs.isEmpty {
continue
@ -58,41 +58,14 @@ extension Kana2Kanji {
}
//
let nextIndex = node.inputRange.endIndex
// count
if nextIndex != count {
for nextnode in nodes[nextIndex] {
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue = ccValue + value
// index
let lastindex = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == N_best {
continue
}
let newnode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= N_best {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
// count
self.updateNextNodes(with: node, nextNodes: lattice[inputIndex: nextIndex], nBest: N_best)
} else {
for index in node.prevs.indices {
let newnode = node.getRegisteredNode(index, value: node.values[index])
result.prevs.append(newnode)
}
self.updateResultNode(with: node, resultNode: result)
}
}
}
return (result: result, nodes: nodes)
return (result: result, lattice: lattice)
}
}

View File

@ -24,35 +24,30 @@ extension Kana2Kanji {
///
/// (5)
func kana2lattice_changed(_ inputData: ComposingText, N_best: Int, counts: (deleted: Int, added: Int), previousResult: (inputData: ComposingText, nodes: Nodes), needTypoCorrection: Bool) -> (result: LatticeNode, nodes: Nodes) {
func kana2lattice_changed(_ inputData: ComposingText, N_best: Int, counts: (deleted: Int, added: Int), previousResult: (inputData: ComposingText, lattice: Lattice), needTypoCorrection: Bool) -> (result: LatticeNode, lattice: Lattice) {
// (0)
let count = inputData.input.count
let commonCount = previousResult.inputData.input.count - counts.deleted
debug("kana2lattice_changed", inputData, counts, previousResult.inputData, count, commonCount)
// (1)
var nodes = previousResult.nodes.prefix(commonCount).map {(nodes: [LatticeNode]) in
nodes.filter {$0.inputRange.endIndex <= commonCount}
}
while nodes.last?.isEmpty ?? false {
nodes.removeLast()
}
var lattice = previousResult.lattice.prefix(commonCount)
let terminalNodes: Nodes
let terminalNodes: Lattice
if counts.added == 0 {
terminalNodes = nodes.map {
terminalNodes = Lattice(nodes: lattice.nodes.map {
$0.filter {
$0.inputRange.endIndex == count
}
}
})
} else {
// (2)
let addedNodes: [[LatticeNode]] = (0..<count).map {(i: Int) in
let addedNodes: Lattice = Lattice(nodes: (0..<count).map {(i: Int) in
self.dicdataStore.getLOUDSDataInRange(inputData: inputData, from: i, toIndexRange: max(commonCount, i) ..< count, needTypoCorrection: needTypoCorrection)
}
})
// (3)
for nodeArray in nodes {
for nodeArray in lattice.nodes {
for node in nodeArray {
if node.prevs.isEmpty {
continue
@ -62,38 +57,10 @@ extension Kana2Kanji {
}
//
let nextIndex = node.inputRange.endIndex
for nextnode in addedNodes[nextIndex] {
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue: PValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue: PValue = ccValue + value
// index
let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == N_best {
continue
}
let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= N_best {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
self.updateNextNodes(with: node, nextNodes: addedNodes[inputIndex: nextIndex], nBest: N_best)
}
}
for (index, nodeArray) in addedNodes.enumerated() where index < nodes.endIndex {
nodes[index].append(contentsOf: nodeArray)
}
for nodeArray in addedNodes.suffix(counts.added) {
nodes.append(nodeArray)
}
lattice.merge(addedNodes)
terminalNodes = addedNodes
}
@ -101,7 +68,7 @@ extension Kana2Kanji {
// terminalNodes
let result = LatticeNode.EOSNode
for (i, nodes) in terminalNodes.enumerated() {
for (i, nodes) in terminalNodes.nodes.enumerated() {
for node in nodes {
if node.prevs.isEmpty {
continue
@ -121,40 +88,13 @@ extension Kana2Kanji {
}
let nextIndex = node.inputRange.endIndex
if count == nextIndex {
//
for index in node.prevs.indices {
let newnode = node.getRegisteredNode(index, value: node.values[index])
result.prevs.append(newnode)
}
self.updateResultNode(with: node, resultNode: result)
} else {
for nextnode in terminalNodes[nextIndex] {
// node.registered.isEmpty
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue = ccValue + value
// index
let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == N_best {
continue
}
let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= N_best {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
self.updateNextNodes(with: node, nextNodes: terminalNodes[inputIndex: nextIndex], nBest: N_best)
}
}
}
return (result: result, nodes: nodes)
return (result: result, lattice: lattice)
}
}

View File

@ -24,33 +24,26 @@ extension Kana2Kanji {
///
/// (2)
func kana2lattice_no_change(N_best: Int, previousResult: (inputData: ComposingText, nodes: Nodes)) -> (result: LatticeNode, nodes: Nodes) {
func kana2lattice_no_change(N_best: Int, previousResult: (inputData: ComposingText, lattice: Lattice)) -> (result: LatticeNode, lattice: Lattice) {
debug("キャッシュから復元、元の文字は:", previousResult.inputData.convertTarget)
let count = previousResult.inputData.input.count
// (1)
let result = LatticeNode.EOSNode
for nodeArray in previousResult.nodes {
for node in nodeArray {
for nodeArray in previousResult.lattice.nodes {
for node in nodeArray where node.inputRange.endIndex == count {
if node.prevs.isEmpty {
continue
}
if self.dicdataStore.shouldBeRemoved(data: node.data) {
continue
}
let nextIndex = node.inputRange.endIndex
if nextIndex == count {
//
for (index, value) in node.values.enumerated() {
let newnode = node.getRegisteredNode(index, value: value)
result.prevs.append(newnode)
}
}
self.updateResultNode(with: node, resultNode: result)
}
}
// (2)
return (result: result, nodes: previousResult.nodes)
return (result: result, lattice: previousResult.lattice)
}
}

View File

@ -0,0 +1,34 @@
/// A thin wrapper around a two-dimensional array of `LatticeNode`s.
///
/// The outer array is addressed through `subscript(inputIndex:)`; each inner array
/// holds the nodes stored at that index.
struct Lattice {
    /// The underlying node arrays. Readable from outside, mutable only within `Lattice`.
    private(set) var nodes: [[LatticeNode]]

    /// Creates a lattice from the given node arrays (empty by default).
    init(nodes: [[LatticeNode]] = []) {
        self.nodes = nodes
    }

    /// Returns a lattice restricted to the first `k` indices.
    ///
    /// Only the first `k` node arrays are kept, and within them only the nodes whose
    /// `inputRange.endIndex` is at most `k`. Trailing node arrays that end up empty
    /// are dropped.
    func prefix(_ k: Int) -> Lattice {
        var trimmed: [[LatticeNode]] = nodes.prefix(k).map { row in
            row.filter { $0.inputRange.endIndex <= k }
        }
        while let last = trimmed.last, last.isEmpty {
            trimmed.removeLast()
        }
        return Lattice(nodes: trimmed)
    }

    /// Merges `lattice` into this lattice, index by index.
    ///
    /// For indices present in both lattices the other lattice's nodes are appended to
    /// the existing array; node arrays beyond this lattice's current length are
    /// appended as new entries.
    mutating func merge(_ lattice: Lattice) {
        let ownCount = nodes.count
        for (index, row) in lattice.nodes.prefix(ownCount).enumerated() {
            nodes[index].append(contentsOf: row)
        }
        // `dropFirst` yields an empty slice when `lattice` is not longer than `self`.
        nodes.append(contentsOf: lattice.nodes.dropFirst(ownCount))
    }

    /// The nodes stored at index `i`. Traps if `i` is out of bounds.
    subscript(inputIndex i: Int) -> [LatticeNode] {
        nodes[i]
    }
}

View File

@ -61,11 +61,11 @@ extension Kana2Kanji {
requestRichCandidates: Bool,
personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: EfficientNGram, personal: EfficientNGram)?,
versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
) -> (result: LatticeNode, lattice: Lattice, cache: ZenzaiCache) {
var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint([])
debug("initial constraint", constraint)
let eosNode = LatticeNode.EOSNode
var nodes: Kana2Kanji.Nodes = []
var lattice: Lattice = Lattice(nodes: [])
var constructedCandidates: [(RegisteredNode, Candidate)] = []
var insertedCandidates: [(RegisteredNode, Candidate)] = []
defer {
@ -82,9 +82,9 @@ extension Kana2Kanji {
// N=3
self.kana2lattice_all_with_prefix_constraint(inputData, N_best: 3, constraint: constraint)
}
if nodes.isEmpty {
if lattice.nodes.isEmpty {
//
nodes = draftResult.nodes
lattice = draftResult.lattice
}
let candidates = draftResult.result.getCandidateData().map(self.processClauseCandidate)
constructedCandidates.append(contentsOf: zip(draftResult.result.prevs, candidates))
@ -100,7 +100,7 @@ extension Kana2Kanji {
debug("best was not found!")
// Empty
//
return (eosNode, nodes, ZenzaiCache(inputData, constraint: PrefixConstraint([]), satisfyingCandidate: nil))
return (eosNode, lattice, ZenzaiCache(inputData, constraint: PrefixConstraint([]), satisfyingCandidate: nil))
}
debug("Constrained draft modeling", -start.timeIntervalSinceNow)
@ -111,7 +111,7 @@ extension Kana2Kanji {
if inferenceLimit == 0 {
debug("inference limit! \(candidate.text) is used for excuse")
// When inference occurs more than maximum times, then just return result at this point
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
return (eosNode, lattice, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
}
let reviewResult = zenz.candidateEvaluate(
convertTarget: inputData.convertTarget,
@ -159,9 +159,9 @@ extension Kana2Kanji {
}
}
if satisfied {
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
return (eosNode, lattice, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: candidate))
} else {
return (eosNode, nodes, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: nil))
return (eosNode, lattice, ZenzaiCache(inputData, constraint: constraint, satisfyingCandidate: nil))
}
case .continue:
break reviewLoop

View File

@ -34,7 +34,7 @@ import EfficientNGram
//
private var previousInputData: ComposingText?
private var nodes: [[LatticeNode]] = []
private var lattice: Lattice = Lattice()
private var completedData: Candidate?
private var lastData: DicdataElement?
/// Zenzaizenz
@ -49,7 +49,7 @@ import EfficientNGram
self.zenzaiPersonalization = nil
self.zenzaiCache = nil
self.previousInputData = nil
self.nodes = []
self.lattice = .init()
self.completedData = nil
self.lastData = nil
}
@ -448,9 +448,9 @@ import EfficientNGram
///
/// - Note:
///
private func processResult(inputData: ComposingText, result: (result: LatticeNode, nodes: [[LatticeNode]]), options: ConvertRequestOptions) -> ConversionResult {
private func processResult(inputData: ComposingText, result: (result: LatticeNode, lattice: Lattice), options: ConvertRequestOptions) -> ConversionResult {
self.previousInputData = inputData
self.nodes = result.nodes
self.lattice = result.lattice
let clauseResult = result.result.getCandidateData()
if clauseResult.isEmpty {
let candidates = self.getUniqueCandidate(self.getAdditionalCandidate(inputData, options: options))
@ -538,7 +538,7 @@ import EfficientNGram
seenCandidate.formUnion(clause_candidates.map {$0.text})
//
let dicCandidates: [Candidate] = result.nodes[0]
let dicCandidates: [Candidate] = result.lattice.nodes[0]
.map {
Candidate(
text: $0.data.word,
@ -605,7 +605,7 @@ import EfficientNGram
/// - N_best:
/// - Returns:
///
private func convertToLattice(_ inputData: ComposingText, N_best: Int, zenzaiMode: ConvertRequestOptions.ZenzaiMode) -> (result: LatticeNode, nodes: [[LatticeNode]])? {
private func convertToLattice(_ inputData: ComposingText, N_best: Int, zenzaiMode: ConvertRequestOptions.ZenzaiMode) -> (result: LatticeNode, lattice: Lattice)? {
if inputData.convertTarget.isEmpty {
return nil
}
@ -642,7 +642,7 @@ import EfficientNGram
//
if previousInputData == inputData {
let result = converter.kana2lattice_no_change(N_best: N_best, previousResult: (inputData: previousInputData, nodes: nodes))
let result = converter.kana2lattice_no_change(N_best: N_best, previousResult: (inputData: previousInputData, lattice: self.lattice))
self.previousInputData = inputData
return result
}
@ -650,7 +650,7 @@ import EfficientNGram
//
if let completedData, previousInputData.inputHasSuffix(inputOf: inputData) {
debug("\(#function): 文節確定用の関数を呼びます、確定された文節は\(completedData)")
let result = converter.kana2lattice_afterComplete(inputData, completedData: completedData, N_best: N_best, previousResult: (inputData: previousInputData, nodes: nodes), needTypoCorrection: needTypoCorrection)
let result = converter.kana2lattice_afterComplete(inputData, completedData: completedData, N_best: N_best, previousResult: (inputData: previousInputData, lattice: self.lattice), needTypoCorrection: needTypoCorrection)
self.previousInputData = inputData
self.completedData = nil
return result
@ -662,7 +662,7 @@ import EfficientNGram
let diff = inputData.differenceSuffix(to: previousInputData)
debug("\(#function): 最後尾文字置換用の関数を呼びます、差分は\(diff)")
let result = converter.kana2lattice_changed(inputData, N_best: N_best, counts: (diff.deleted, diff.addedCount), previousResult: (inputData: previousInputData, nodes: nodes), needTypoCorrection: needTypoCorrection)
let result = converter.kana2lattice_changed(inputData, N_best: N_best, counts: (diff.deleted, diff.addedCount), previousResult: (inputData: previousInputData, lattice: self.lattice), needTypoCorrection: needTypoCorrection)
self.previousInputData = inputData
return result
}