[Experimental] ConvertGraphの差分ベース構築を実装 (#77)

* 差分ビルドを実装

* テストケース追加
This commit is contained in:
Miwa / Ensan
2024-04-07 23:51:58 +09:00
committed by GitHub
parent 5713548f4d
commit 7a3f1d3c3c
4 changed files with 301 additions and 24 deletions

View File

@ -11,6 +11,7 @@ import Foundation
struct ConvertGraph {
struct Node {
var value: Character
var latticeNodes: [LatticeNode]
var inputElementsRange: InputGraphRange
var correction: CorrectGraph.Correction = .none
@ -18,7 +19,7 @@ struct ConvertGraph {
var nodes: [Node] = [
// root node
Node(latticeNodes: [], inputElementsRange: .endIndex(0))
Node(value: "\0", latticeNodes: [], inputElementsRange: .endIndex(0))
]
/// NextIndex
@ -28,7 +29,12 @@ struct ConvertGraph {
init(input: LookupGraph, nodeIndex2LatticeNode: [Int: [LatticeNode]]) {
let nodes = input.nodes.enumerated().map { (index, node) in
Node(latticeNodes: nodeIndex2LatticeNode[index, default: []], inputElementsRange: node.inputElementsRange, correction: node.correction)
Node(
value: node.character,
latticeNodes: nodeIndex2LatticeNode[index, default: []],
inputElementsRange: node.inputElementsRange,
correction: node.correction
)
}
self.nodes = nodes
self.allowedPrevIndex = input.allowedPrevIndex
@ -41,8 +47,8 @@ extension ConvertGraph {
final class LatticeNode: CustomStringConvertible {
///
public let data: DicdataElement
/// ConvertGraphindex
var nextConvertNodeIndices: IndexSet = []
/// Index of the ConvertGraph node at which this DicdataElement ends (used to look up the following nodes via `allowedNextIndex`)
var endNodeIndex: Int
/// `N_best`
var prevs: [RegisteredNode] = []
/// `prevs`
@ -51,13 +57,13 @@ extension ConvertGraph {
/// `EOS`
static var EOSNode: LatticeNode {
LatticeNode(data: DicdataElement.EOSData, nextConvertNodeIndices: [], inputElementsRange: .unknown)
LatticeNode(data: DicdataElement.EOSData, endNodeIndex: 0, inputElementsRange: .unknown)
}
init(data: DicdataElement, nextConvertNodeIndices: IndexSet, inputElementsRange: InputGraphRange, prevs: [RegisteredNode] = []) {
init(data: DicdataElement, endNodeIndex: Int, inputElementsRange: InputGraphRange, prevs: [RegisteredNode] = []) {
self.data = data
self.values = [data.value()]
self.nextConvertNodeIndices = nextConvertNodeIndices
self.values = []
self.endNodeIndex = endNodeIndex
self.inputElementsRange = inputElementsRange
self.prevs = prevs
}
@ -165,13 +171,13 @@ extension ConvertGraph {
node.values = node.prevs.map {$0.totalValue + wValue}
}
//
if node.nextConvertNodeIndices.isEmpty || result.inputElementsRange.startIndex == node.inputElementsRange.endIndex {
if self.allowedNextIndex[node.endNodeIndex, default: []].isEmpty || result.inputElementsRange.startIndex == node.inputElementsRange.endIndex {
for index in node.prevs.indices {
let newnode: RegisteredNode = node.getRegisteredNode(index, value: node.values[index])
result.prevs.append(newnode)
}
} else {
for nextIndex in node.nextConvertNodeIndices {
for nextIndex in self.allowedNextIndex[node.endNodeIndex, default: []] {
// nodenextnode
for nextnode in self.nodes[nextIndex].latticeNodes {
// node.registered.isEmpty
@ -203,4 +209,175 @@ extension ConvertGraph {
}
return result
}
/// Runs the lattice conversion over this graph while reusing lattice nodes of a previous conversion.
///
/// The work is done in three phases:
/// 1. Starting from the root of both graphs, pair every node of this graph with a node of
///    `cacheConvertGraph` that has the same `value`, provided all of its predecessors were paired too.
/// 2. For every paired node, drop the lattice nodes that end at a node listed in `lookupGraphMatchInfo`
///    and take over the cached lattice nodes instead (their `endNodeIndex` is rewritten to this graph's index).
/// 3. Visit the graph nodes so that predecessors come first and register each lattice node on the
///    lattice nodes that follow it, keeping at most `option.N_best` entries in `prevs`.
///    Lattice nodes that were taken over from the cache keep their `values` and `prevs` as they are.
///
/// - Parameters:
///   - cacheConvertGraph: The graph of the previous conversion. `LatticeNode` is a class, so the
///     instances taken over from it are shared with (and mutated through) this graph.
///   - option: Only `N_best` is read here.
///   - dicdataStore: Provides `shouldBeRemoved(data:)` and `getCCValue(_:_:)`.
///   - lookupGraphMatchInfo: Maps a node index of the current lookup graph to the matching node index
///     of the cached lookup graph.
/// - Returns: The EOS lattice node whose `prevs` hold the registered final candidates.
mutating func convertAllDifferential(cacheConvertGraph: ConvertGraph, option: borrowing ConvertRequestOptions, dicdataStore: DicdataStore, lookupGraphMatchInfo: [Int: Int]) -> LatticeNode {
    #if DEBUG
    print(lookupGraphMatchInfo)
    #endif
    // Phase 1: pair nodes of this graph with nodes of the cache graph.
    typealias MatchSearchItem = (
        curNodeIndex: Int,
        cacheNodeIndex: Int
    )
    // The search starts at the root (BOS) of both graphs.
    var stack: [MatchSearchItem] = [(0, 0)]
    var curNodeToCacheNode: [Int: Int] = [:]
    do {
        var processedIndices: IndexSet = []
        while let item = stack.popLast() {
            if processedIndices.contains(item.curNodeIndex) {
                continue
            }
            let prevIndices = self.allowedPrevIndex[item.curNodeIndex, default: []]
            // Predecessors that have not been decided yet.
            let restIndices = prevIndices.subtracting(processedIndices)
            if !restIndices.isEmpty {
                // Deferring this node only helps while some other stack entry can still be decided.
                // When every remaining entry is waiting as well, nothing would ever change and the
                // node would be re-inserted forever, so it falls through and is handled as unmatched.
                let canProgress = stack.contains { entry in
                    if processedIndices.contains(entry.curNodeIndex) {
                        return false
                    }
                    let entryPrevIndices = self.allowedPrevIndex[entry.curNodeIndex, default: []]
                    return entryPrevIndices.subtracting(processedIndices).isEmpty
                }
                if canProgress {
                    // Put the node below a pending predecessor so that the predecessor is popped first.
                    let firstIndex = stack.firstIndex(where: { restIndices.contains($0.curNodeIndex) }) ?? 0
                    stack.insert(item, at: firstIndex)
                    continue
                }
            }
            let allPrevsMatched = restIndices.isEmpty && prevIndices.allSatisfy { curNodeToCacheNode.keys.contains($0) }
            if allPrevsMatched {
                // Every predecessor is paired, so this node is paired with its cache counterpart.
                curNodeToCacheNode[item.curNodeIndex] = item.cacheNodeIndex
                // Continue with the following nodes that have a same-valued counterpart in the cache graph.
                for nextNodeIndex in self.allowedNextIndex[item.curNodeIndex, default: []] {
                    let nextNode = self.nodes[nextNodeIndex]
                    if let cacheNodeIndex = cacheConvertGraph.allowedNextIndex[item.cacheNodeIndex, default: []].first(where: {
                        cacheConvertGraph.nodes[$0].value == nextNode.value
                    }) {
                        stack.append((nextNodeIndex, cacheNodeIndex))
                    }
                }
            } else {
                // This node cannot be paired; its following nodes are marked as decided (unpaired) as well.
                processedIndices.formUnion(self.allowedNextIndex[item.curNodeIndex, default: []])
            }
            processedIndices.insert(item.curNodeIndex)
        }
    }
    // Inverse view of the lookup graph match: cache node index -> current node index.
    let lookupGraphCacheNodeToCurNode = Dictionary(lookupGraphMatchInfo.map {(k, v) in (v, k)}, uniquingKeysWith: { (k1, _) in k1 })
    #if DEBUG
    print("curNodeToCacheNode", curNodeToCacheNode)
    #endif
    // Phase 2: replace the reusable lattice nodes of paired graph nodes with the cached ones.
    for (curNodeIndex, cacheNodeIndex) in curNodeToCacheNode {
        self.nodes[curNodeIndex].latticeNodes.removeAll {
            lookupGraphMatchInfo.keys.contains($0.endNodeIndex)
        }
        for cachedLatticeNode in cacheConvertGraph.nodes[cacheNodeIndex].latticeNodes {
            if let e = lookupGraphCacheNodeToCurNode[cachedLatticeNode.endNodeIndex] {
                // Re-point the cached lattice node to the node index of this graph.
                cachedLatticeNode.endNodeIndex = e
                self.nodes[curNodeIndex].latticeNodes.append(cachedLatticeNode)
            }
        }
    }
    // Phase 3: forward pass over the graph.
    let result: LatticeNode = LatticeNode.EOSNode
    result.inputElementsRange = .init(startIndex: self.nodes.compactMap {$0.inputElementsRange.endIndex}.max(), endIndex: nil)
    var processStack = Array(self.nodes.enumerated().reversed())
    var processedIndices: IndexSet = [0] // root
    var invalidIndices: IndexSet = []
    while let (i, graphNode) = processStack.popLast() {
        // Skip nodes that were already handled.
        guard !processedIndices.contains(i), !invalidIndices.contains(i) else {
            continue
        }
        // A node without any predecessor cannot be reached and is marked invalid.
        let prevIndices = self.allowedPrevIndex[i, default: []]
        guard !prevIndices.isEmpty else {
            invalidIndices.insert(i)
            continue
        }
        var unprocessedPrevs: Set<Int> = []
        for prevIndex in prevIndices {
            if !processedIndices.contains(prevIndex) && !invalidIndices.contains(prevIndex) {
                unprocessedPrevs.insert(prevIndex)
            }
        }
        // While a predecessor is still pending, put the node back below it so the predecessor is popped first.
        guard unprocessedPrevs.isEmpty else {
            let firstIndex = processStack.firstIndex(where: { unprocessedPrevs.contains($0.offset) }) ?? 0
            processStack.insert((i, graphNode), at: firstIndex)
            continue
        }
        #if DEBUG
        print(i, graphNode.inputElementsRange)
        #endif
        processedIndices.insert(i)
        let isMatchedGraphNode = curNodeToCacheNode.keys.contains(i)
        for node in graphNode.latticeNodes {
            if node.prevs.isEmpty {
                continue
            }
            if dicdataStore.shouldBeRemoved(data: node.data) {
                continue
            }
            let isMatched = isMatchedGraphNode && lookupGraphMatchInfo.keys.contains(node.endNodeIndex)
            if !isMatched {
                // Only lattice nodes that were not taken over from the cache need their values computed.
                let wValue: PValue = node.data.value()
                // NOTE(review): index 0 is pre-marked as processed above and skipped by the first guard,
                // so the `i == 0` branch is never taken in this loop — confirm the intended condition.
                node.values = if i == 0 {
                    node.prevs.map {$0.totalValue + wValue + dicdataStore.getCCValue($0.data.rcid, node.data.lcid)}
                } else {
                    node.prevs.map {$0.totalValue + wValue}
                }
            }
            // Either nothing follows this lattice node or it reaches the end of the input: register on the result.
            if self.allowedNextIndex[node.endNodeIndex, default: []].isEmpty || result.inputElementsRange.startIndex == node.inputElementsRange.endIndex {
                for index in node.prevs.indices {
                    let newnode: RegisteredNode = node.getRegisteredNode(index, value: node.values[index])
                    result.prevs.append(newnode)
                }
            } else {
                for nextIndex in self.allowedNextIndex[node.endNodeIndex, default: []] {
                    let nextMatchable = curNodeToCacheNode.keys.contains(nextIndex)
                    // Register `node` on every lattice node that starts at the following graph node.
                    for nextnode in self.nodes[nextIndex].latticeNodes {
                        // Lattice nodes taken over from the cache are left as they are.
                        if nextMatchable && lookupGraphMatchInfo.keys.contains(nextnode.endNodeIndex) {
                            continue
                        }
                        if dicdataStore.shouldBeRemoved(data: nextnode.data) {
                            continue
                        }
                        let ccValue: PValue = dicdataStore.getCCValue(node.data.rcid, nextnode.data.lcid)
                        for (index, value) in node.values.enumerated() {
                            let newValue: PValue = ccValue + value
                            // Position right after the last entry whose value is at least `newValue`.
                            let lastindex: Int = (nextnode.prevs.lastIndex(where: {$0.totalValue >= newValue}) ?? -1) + 1
                            if lastindex == option.N_best {
                                continue
                            }
                            let newnode: RegisteredNode = node.getRegisteredNode(index, value: newValue)
                            // Drop the last entry first so that at most `N_best` entries remain after the insert.
                            if nextnode.prevs.count >= option.N_best {
                                nextnode.prevs.removeLast()
                            }
                            nextnode.prevs.insert(newnode, at: lastindex)
                        }
                    }
                }
            }
        }
    }
    return result
}
}

View File

@ -85,7 +85,12 @@ struct LookupGraph {
return (indexSet, loudsNodeIndex2GraphNodeEndIndices)
}
mutating func differentialByfixSearch(in louds: LOUDS, cacheLookupGraph: LookupGraph, graphNodeIndex: (start: Int, cache: Int)) -> (IndexSet, [Int: [Int]]) {
mutating func differentialByfixSearch(
in louds: LOUDS,
cacheLookupGraph: LookupGraph,
graphNodeIndex: (start: Int, cache: Int),
lookupGraphMatch: inout [Int: Int]
) -> (IndexSet, [Int: [Int]]) {
guard var graphNodeEndIndexToLoudsNodeIndex = cacheLookupGraph.loudsNodeIndex[graphNodeIndex.cache] else {
return self.byfixNodeIndices(in: louds, startGraphNodeIndex: graphNodeIndex.start)
}
@ -110,6 +115,8 @@ struct LookupGraph {
stack.append(contentsOf: self.nextIndexWithMatch(cNodeIndex, cacheNodeIndex: cCacheNodeIndex, cacheGraph: cacheLookupGraph).map {
($0.0, $0.1, loudsNodeIndex)
})
//
lookupGraphMatch[cNodeIndex] = cCacheNodeIndex
}
//
else if let loudsNodeIndex = louds.searchCharNodeIndex(from: cLastLoudsNodeIndex, char: cNode.charId) {
@ -168,11 +175,11 @@ extension DicdataStore {
)
if graphNode.inputElementsRange.startIndex == 0 {
latticeNodes.append(contentsOf: dicdata.map {
.init(data: $0, nextConvertNodeIndices: lookupGraph.allowedNextIndex[endNodeIndex, default: []], inputElementsRange: inputElementsRange, prevs: [.BOSNode()])
.init(data: $0, endNodeIndex: endNodeIndex, inputElementsRange: inputElementsRange, prevs: [.BOSNode()])
})
} else {
latticeNodes.append(contentsOf: dicdata.map {
.init(data: $0, nextConvertNodeIndices: lookupGraph.allowedNextIndex[endNodeIndex, default: []], inputElementsRange: inputElementsRange)
.init(data: $0, endNodeIndex: endNodeIndex, inputElementsRange: inputElementsRange)
})
}
}
@ -186,7 +193,15 @@ extension DicdataStore {
return (lookupGraph, ConvertGraph(input: lookupGraph, nodeIndex2LatticeNode: graphNodeIndex2LatticeNodes))
}
func buildConvertGraphDifferential(inputGraph: consuming InputGraph, cacheLookupGraph: LookupGraph, option: ConvertRequestOptions) -> (LookupGraph, ConvertGraph) {
func buildConvertGraphDifferential(
inputGraph: consuming InputGraph,
cacheLookupGraph: LookupGraph,
option: ConvertRequestOptions
) -> (
lookupGraph: LookupGraph,
convertGraph: ConvertGraph,
lookupGraphMatch: [Int: Int]
) {
var lookupGraph = LookupGraph.build(input: consume inputGraph, character2CharId: { self.character2charId($0.toKatakana()) })
typealias StackItem = (
currentLookupGraphNodeIndex: Int,
@ -197,6 +212,8 @@ extension DicdataStore {
var stack: [StackItem] = lookupGraph.nextIndexWithMatch(0, cacheNodeIndex: 0, cacheGraph: cacheLookupGraph)
var graphNodeIndex2LatticeNodes: [Int: [ConvertGraph.LatticeNode]] = [:]
var processedIndexSet = IndexSet()
var lookupGraphMatch: [Int: Int] = [:]
while let (graphNodeIndex, cacheGraphNodeIndex) = stack.popLast() {
//
guard !processedIndexSet.contains(graphNodeIndex) else {
@ -210,7 +227,7 @@ extension DicdataStore {
/// * loudsNodeIndices: loudsloudstxt
/// * loudsNodeIndex2GraphNodeEndIndices: loudsNodeIndexgraphNodeIndex
let (indexSet, loudsNodeIndex2GraphNodeEndIndices) = if let cacheGraphNodeIndex {
lookupGraph.differentialByfixSearch(in: louds, cacheLookupGraph: cacheLookupGraph, graphNodeIndex: (graphNodeIndex, cacheGraphNodeIndex))
lookupGraph.differentialByfixSearch(in: louds, cacheLookupGraph: cacheLookupGraph, graphNodeIndex: (graphNodeIndex, cacheGraphNodeIndex), lookupGraphMatch: &lookupGraphMatch)
} else {
lookupGraph.byfixNodeIndices(in: louds, startGraphNodeIndex: graphNodeIndex)
}
@ -226,11 +243,11 @@ extension DicdataStore {
)
if graphNode.inputElementsRange.startIndex == 0 {
latticeNodes.append(contentsOf: dicdata.map {
.init(data: $0, nextConvertNodeIndices: lookupGraph.allowedNextIndex[endNodeIndex, default: []], inputElementsRange: inputElementsRange, prevs: [.BOSNode()])
.init(data: $0, endNodeIndex: endNodeIndex, inputElementsRange: inputElementsRange, prevs: [.BOSNode()])
})
} else {
latticeNodes.append(contentsOf: dicdata.map {
.init(data: $0, nextConvertNodeIndices: lookupGraph.allowedNextIndex[endNodeIndex, default: []], inputElementsRange: inputElementsRange)
.init(data: $0, endNodeIndex: endNodeIndex, inputElementsRange: inputElementsRange)
})
}
}
@ -245,7 +262,7 @@ extension DicdataStore {
stack.append(contentsOf: lookupGraph.allowedNextIndex[graphNodeIndex, default: []].map {($0, nil)})
}
}
return (lookupGraph, ConvertGraph(input: lookupGraph, nodeIndex2LatticeNode: graphNodeIndex2LatticeNodes))
return (lookupGraph, ConvertGraph(input: lookupGraph, nodeIndex2LatticeNode: graphNodeIndex2LatticeNodes), lookupGraphMatch)
}
func getDicdataFromLoudstxt3(identifier: String, indices: some Sequence<Int>, option: ConvertRequestOptions) -> [(loudsNodeIndex: Int, dicdata: [DicdataElement])] {

View File

@ -236,7 +236,8 @@ final class LookupGraphTests: XCTestCase {
var lookupGraph2 = LookupGraph.build(input: inputGraph2, character2CharId: values.character2CharId)
let startNodeIndex2 = lookupGraph2.allowedNextIndex[0, default: IndexSet()].first(where: { lookupGraph2.nodes[$0].character == "" })
XCTAssertNotNil(startNodeIndex2)
let (loudsNodeIndices2, _) = lookupGraph2.differentialByfixSearch(in: louds, cacheLookupGraph: lookupGraph1, graphNodeIndex: (startNodeIndex2 ?? 0, startNodeIndex1 ?? 0))
var matchInfo: [Int: Int] = [:]
let (loudsNodeIndices2, _) = lookupGraph2.differentialByfixSearch(in: louds, cacheLookupGraph: lookupGraph1, graphNodeIndex: (startNodeIndex2 ?? 0, startNodeIndex1 ?? 0), lookupGraphMatch: &matchInfo)
let dicdataWithIndex = values.dicdataStore.getDicdataFromLoudstxt3(identifier: "", indices: loudsNodeIndices2, option: requestOptions())
let dicdata = dicdataWithIndex.flatMapSet { $0.dicdata }
//
@ -284,7 +285,8 @@ final class LookupGraphTests: XCTestCase {
var lookupGraph2 = LookupGraph.build(input: inputGraph2, character2CharId: values.character2CharId)
let startNodeIndex2 = lookupGraph2.allowedNextIndex[0, default: IndexSet()].first(where: { lookupGraph2.nodes[$0].character == "" })
XCTAssertNotNil(startNodeIndex2)
let (loudsNodeIndices2, _) = lookupGraph2.differentialByfixSearch(in: louds, cacheLookupGraph: lookupGraph1, graphNodeIndex: (startNodeIndex2 ?? 0, startNodeIndex1 ?? 0))
var matchInfo: [Int: Int] = [:]
let (loudsNodeIndices2, _) = lookupGraph2.differentialByfixSearch(in: louds, cacheLookupGraph: lookupGraph1, graphNodeIndex: (startNodeIndex2 ?? 0, startNodeIndex1 ?? 0), lookupGraphMatch: &matchInfo)
let dicdataWithIndex = values.dicdataStore.getDicdataFromLoudstxt3(identifier: "", indices: loudsNodeIndices2, option: requestOptions())
let dicdata = dicdataWithIndex.flatMapSet { $0.dicdata }
//

View File

@ -47,10 +47,9 @@ extension Kana2Kanji {
let inputGraph = InputGraph.build(input: previousResult.correctGraph)
// convertGraph
print(#file, "lookup", previousResult.inputGraph)
let (lookupGraph, convertGraph) = self.dicdataStore.buildConvertGraphDifferential(inputGraph: inputGraph, cacheLookupGraph: previousResult.lookupGraph, option: option)
var (lookupGraph, convertGraph, matchInfo) = self.dicdataStore.buildConvertGraphDifferential(inputGraph: inputGraph, cacheLookupGraph: previousResult.lookupGraph, option: option)
print(#file, "convert")
// TODO:
let result = convertGraph.convertAll(option: option, dicdataStore: self.dicdataStore)
let result = convertGraph.convertAllDifferential(cacheConvertGraph: previousResult.convertGraph, option: option, dicdataStore: self.dicdataStore, lookupGraphMatchInfo: matchInfo)
return Result(endNode: result, correctGraph: previousResult.correctGraph, inputGraph: inputGraph, lookupGraph: lookupGraph, convertGraph: convertGraph)
}
}
@ -87,7 +86,7 @@ final class ExperimentalConversionTests: XCTestCase {
convertGraph.nodes.first {
$0.latticeNodes.contains(where: {$0.data.word == ""})
}?.latticeNodes.mapSet {$0.data.ruby}
.symmetricDifference(["", "タイ", "タイカ", "タイガ", "タイカク", "タイガク"]),
.symmetricDifference(["", "タイ", "タイカ", "タイガ", "タイカク", "タイガク"]),
[]
)
}
@ -256,4 +255,86 @@ final class ExperimentalConversionTests: XCTestCase {
)
XCTAssertTrue(thirdResult.endNode.joinedPrevs().contains("大国")) //
}
/// Types "i", "n", "t", "a", "i" one key at a time; the first key uses the full conversion and
/// every later key extends the previous result through `_experimental_additional`.
func testConversion_incremental_intai() throws {
    let store = DicdataStore(requestOptions: requestOptions())
    let converter = Kana2Kanji(dicdataStore: store)
    var composingText = ComposingText()

    // 1st key: "i" (full conversion)
    composingText.insertAtCursorPosition("i", inputStyle: .roman2kana)
    let resultAfterI = converter._experimental_all(composingText, option: requestOptions())
    XCTAssertTrue(resultAfterI.endNode.joinedPrevs().contains(""))

    // 2nd key: "n"
    composingText.insertAtCursorPosition("n", inputStyle: .roman2kana)
    let resultAfterIn = converter._experimental_additional(
        composingText: composingText,
        additionalInputsStartIndex: 1,
        previousResult: resultAfterI,
        option: requestOptions()
    )
    print(resultAfterIn.endNode.joinedPrevs())

    // 3rd key: "t"
    composingText.insertAtCursorPosition("t", inputStyle: .roman2kana)
    let resultAfterInt = converter._experimental_additional(
        composingText: composingText,
        additionalInputsStartIndex: 2,
        previousResult: resultAfterIn,
        option: requestOptions()
    )
    print(resultAfterInt.endNode.joinedPrevs())

    // 4th key: "a"
    composingText.insertAtCursorPosition("a", inputStyle: .roman2kana)
    let resultAfterInta = converter._experimental_additional(
        composingText: composingText,
        additionalInputsStartIndex: 3,
        previousResult: resultAfterInt,
        option: requestOptions()
    )
    XCTAssertTrue(resultAfterInta.endNode.joinedPrevs().contains("インタ"))

    // 5th key: "i"
    composingText.insertAtCursorPosition("i", inputStyle: .roman2kana)
    let resultAfterIntai = converter._experimental_additional(
        composingText: composingText,
        additionalInputsStartIndex: 4,
        previousResult: resultAfterInta,
        option: requestOptions()
    )
    XCTAssertTrue(resultAfterIntai.endNode.joinedPrevs().contains("引退"))
}
/// Types "i", "n", "t", "s", "i" one key at a time; the first key uses the full conversion and
/// every later key extends the previous result through `_experimental_additional`.
func testConversion_incremental_intsi() throws {
    let store = DicdataStore(requestOptions: requestOptions())
    let converter = Kana2Kanji(dicdataStore: store)
    var composingText = ComposingText()

    // 1st key: "i" (full conversion)
    composingText.insertAtCursorPosition("i", inputStyle: .roman2kana)
    let resultAfterI = converter._experimental_all(composingText, option: requestOptions())
    XCTAssertTrue(resultAfterI.endNode.joinedPrevs().contains(""))

    // 2nd key: "n"
    composingText.insertAtCursorPosition("n", inputStyle: .roman2kana)
    let resultAfterIn = converter._experimental_additional(
        composingText: composingText,
        additionalInputsStartIndex: 1,
        previousResult: resultAfterI,
        option: requestOptions()
    )
    // Assertion left disabled, as in the original test:
    // XCTAssertTrue(resultAfterIn.endNode.joinedPrevs().contains("n")) // in

    // 3rd key: "t"
    composingText.insertAtCursorPosition("t", inputStyle: .roman2kana)
    let resultAfterInt = converter._experimental_additional(
        composingText: composingText,
        additionalInputsStartIndex: 2,
        previousResult: resultAfterIn,
        option: requestOptions()
    )
    // Assertion left disabled, as in the original test:
    // XCTAssertTrue(resultAfterInt.endNode.joinedPrevs().contains("t")) // int

    // 4th key: "s"
    composingText.insertAtCursorPosition("s", inputStyle: .roman2kana)
    let resultAfterInts = converter._experimental_additional(
        composingText: composingText,
        additionalInputsStartIndex: 3,
        previousResult: resultAfterInt,
        option: requestOptions()
    )
    XCTAssertTrue(resultAfterInts.endNode.joinedPrevs().contains("インタ"))

    // 5th key: "i"
    composingText.insertAtCursorPosition("i", inputStyle: .roman2kana)
    let resultAfterIntsi = converter._experimental_additional(
        composingText: composingText,
        additionalInputsStartIndex: 4,
        previousResult: resultAfterInts,
        option: requestOptions()
    )
    XCTAssertTrue(resultAfterIntsi.endNode.joinedPrevs().contains("引退"))
}
}