refactor: N-Best計算部分を切り出して共通化した

This commit is contained in:
Miwa / Ensan
2025-07-08 23:28:47 +09:00
parent a0f2ff71b6
commit 31a8ece9fc
3 changed files with 31 additions and 96 deletions

View File

@@ -59,35 +59,37 @@ extension Kana2Kanji {
result.prevs.append(newnode)
}
} else {
// nodenextnode
for nextnode in nodes[nextIndex] {
// node.registered.isEmpty
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue: PValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue: PValue = ccValue + value
// index
let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == N_best {
continue
}
let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= N_best {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
self.updateNextNodes(with: node, nextNodes: nodes[nextIndex], nBest: N_best)
}
}
}
return (result: result, lattice: Lattice(nodes: nodes))
}
/// N-Best
func updateNextNodes(with node: LatticeNode, nextNodes: [LatticeNode], nBest: Int) {
for nextnode in nextNodes {
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue: PValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue: PValue = ccValue + value
// index
let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == nBest {
continue
}
let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= nBest {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
}
}

View File

@@ -60,31 +60,9 @@ extension Kana2Kanji {
let nextIndex = node.inputRange.endIndex
// count
if nextIndex != count {
for nextnode in lattice.nodes[nextIndex] {
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue = ccValue + value
// index
let lastindex = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == N_best {
continue
}
let newnode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= N_best {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
// count
self.updateNextNodes(with: node, nextNodes: lattice.nodes[nextIndex], nBest: N_best)
} else {
// count
for index in node.prevs.indices {
let newnode = node.getRegisteredNode(index, value: node.values[index])
result.prevs.append(newnode)

View File

@@ -57,29 +57,7 @@ extension Kana2Kanji {
}
//
let nextIndex = node.inputRange.endIndex
for nextnode in addedNodes.nodes[nextIndex] {
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue: PValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue: PValue = ccValue + value
// index
let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == N_best {
continue
}
let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= N_best {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
self.updateNextNodes(with: node, nextNodes: addedNodes.nodes[nextIndex], nBest: N_best)
}
}
lattice.merge(addedNodes)
@@ -116,30 +94,7 @@ extension Kana2Kanji {
result.prevs.append(newnode)
}
} else {
for nextnode in terminalNodes.nodes[nextIndex] {
// node.registered.isEmpty
if self.dicdataStore.shouldBeRemoved(data: nextnode.data) {
continue
}
//
let ccValue = self.dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
// nodeprevnode
for (index, value) in node.values.enumerated() {
let newValue = ccValue + value
// index
let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
if lastindex == N_best {
continue
}
let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
//
if nextnode.prevs.count >= N_best {
nextnode.prevs.removeLast()
}
// removeinsert (insertO(N))
nextnode.prevs.insert(newnode, at: lastindex)
}
}
self.updateNextNodes(with: node, nextNodes: terminalNodes.nodes[nextIndex], nBest: N_best)
}
}
}