refactor: N-Best計算部分を切り出して共通化した

This commit is contained in:
Miwa / Ensan
2025-07-08 23:28:47 +09:00
parent a0f2ff71b6
commit 31a8ece9fc
3 changed files with 31 additions and 96 deletions

View File

@ -59,35 +59,37 @@ extension Kana2Kanji {
result.prevs.append(newnode)
}
} else {
// nodenextnode
for nextnode in nodes[nextIndex] {
// node.registered.isEmpty
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue: PValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue: PValue = ccValue + value
// index
let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == N_best {
continue
}
let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= N_best {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
self.updateNextNodes(with: node, nextNodes: nodes[nextIndex], nBest: N_best)
}
}
}
return (result: result, lattice: Lattice(nodes: nodes))
}
/// N-Best
func updateNextNodes(with node: LatticeNode, nextNodes: [LatticeNode], nBest: Int) {
for nextnode in nextNodes {
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue: PValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue: PValue = ccValue + value
// index
let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == nBest {
continue
}
let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= nBest {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
}
}