feat: implement resume support for n-gram training (#156)

* feat: implement resume support for n-gram training

* fix: mock signature for Android
Author: Miwa
Date: 2025-02-13 16:29:25 +09:00
Committed by: GitHub
Parent: 0e8e3e7ba3
Commit: 24f506b5b5
2 changed files with 109 additions and 25 deletions


@@ -13,6 +13,9 @@ extension Subcommands.NGram {
@Option(name: [.customShort("n")], help: "n-gram's n")
var n: Int = 5
@Option(name: [.customLong("resume")], help: "Resume from these lm data")
var resumeFilePattern: String?
static let configuration = CommandConfiguration(
commandName: "train",
abstract: "Train ngram and write the data"
@@ -21,7 +24,13 @@ extension Subcommands.NGram {
mutating func run() throws {
let pattern = URL(fileURLWithPath: self.outputDirectory).path() + "lm_"
print("Saving for \(pattern)")
trainNGramFromFile(filePath: self.target, n: self.n, baseFilename: "lm", outputDir: self.outputDirectory)
trainNGramFromFile(
filePath: self.target,
n: self.n,
baseFilePattern: "lm",
outputDir: self.outputDirectory,
resumeFilePattern: self.resumeFilePattern
)
}
}
}
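For reference, the value passed to --resume is a path prefix rather than a single file. Given, say, --resume /path/to/out/lm (a hypothetical path), the trainer reloads the marisa files a previous run wrote for that prefix:

/path/to/out/lm_c_abc.marisa
/path/to/out/lm_c_bc.marisa
/path/to/out/lm_u_abx.marisa
/path/to/out/lm_u_xbc.marisa
/path/to/out/lm_r_xbx.marisa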


@@ -8,18 +8,27 @@ final class SwiftTrainer {
let n: Int
let tokenizer: ZenzTokenizer
/// Equivalent of Python's defaultdict(int)
private var c_abc = [[Int]: Int]()
private var c_bc = [[Int]: Int]()
private var u_abx = [[Int]: Int]()
private var u_xbc = [[Int]: Int]()
/// Equivalent of Python's defaultdict(set)
private var s_xbx = [[Int]: Set<Int>]()
private var r_xbx = [[Int]: Int]()
init(n: Int, tokenizer: ZenzTokenizer) {
self.n = n
self.tokenizer = tokenizer
}
init(baseFilePattern: String, n: Int, tokenizer: ZenzTokenizer) {
self.tokenizer = tokenizer
self.n = n
self.c_abc = Self.loadDictionary(from: "\(baseFilePattern)_c_abc.marisa")
self.c_bc = Self.loadDictionary(from: "\(baseFilePattern)_c_bc.marisa")
self.u_abx = Self.loadDictionary(from: "\(baseFilePattern)_u_abx.marisa")
self.u_xbc = Self.loadDictionary(from: "\(baseFilePattern)_u_xbc.marisa")
self.r_xbx = Self.loadDictionary(from: "\(baseFilePattern)_r_xbx.marisa")
}
/// Count a single n-gram (abc)
/// Corresponds to count_ngram in the Python implementation
private func countNGram(_ ngram: some BidirectionalCollection<Int>) {
@@ -31,10 +40,10 @@ final class SwiftTrainer {
let Bc = Array(ngram.dropFirst()) // bc
// B, c
let B = Array(ngram.dropFirst().dropLast())
let c = ngram.last!
// C(abc)
c_abc[aBc, default: 0] += 1
// C(bc)
c_bc[Bc, default: 0] += 1
// U(...)
if c_abc[aBc] == 1 {
@@ -43,8 +52,11 @@ final class SwiftTrainer {
// U(bc)
u_xbc[Bc, default: 0] += 1
}
// s_xbx[B] = s_xbx[B] ∪ {c}
s_xbx[B, default: Set()].insert(c)
if c_bc[Bc] == 1 {
// s_xbx[B] = s_xbx[B] ∪ {c}
r_xbx[B, default: 0] += 1
}
}
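To make the bookkeeping concrete, here is a minimal standalone sketch of the counting scheme (a hypothetical toy example, not part of the commit; u_abx is omitted and plain trigram arrays stand in for tokenized input):

```swift
// Toy version of countNGram: integer token IDs, trigram model (n = 3).
var c_abc = [[Int]: Int]()   // C(abc): full n-gram counts
var c_bc  = [[Int]: Int]()   // C(bc): lower-order counts
var u_xbc = [[Int]: Int]()   // number of distinct predecessors a seen before each bc
var r_xbx = [[Int]: Int]()   // number of distinct followers c seen after each b

func count(_ ngram: [Int]) {
    let aBc = ngram
    let Bc  = Array(ngram.dropFirst())
    let B   = Array(ngram.dropFirst().dropLast())
    c_abc[aBc, default: 0] += 1
    c_bc[Bc, default: 0] += 1
    if c_abc[aBc] == 1 { u_xbc[Bc, default: 0] += 1 }  // first occurrence of this exact abc
    if c_bc[Bc] == 1 { r_xbx[B, default: 0] += 1 }     // first occurrence of this bc
}

// Two trigrams sharing the suffix (2, 3) but with different predecessors:
count([1, 2, 3])
count([4, 2, 3])
// c_abc == [[1, 2, 3]: 1, [4, 2, 3]: 1], c_bc == [[2, 3]: 2]
// u_xbc[[2, 3]] == 2 (two distinct predecessors), r_xbx[[2]] == 1 (one distinct follower)
```

Note that the commit replaces the s_xbx sets with the integer counter r_xbx, so every piece of state that must survive a restart is a [[Int]: Int] dictionary, which is what the marisa-based resume below relies on.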
/// n-gram
@@ -102,6 +114,58 @@ final class SwiftTrainer {
return key + [Self.keyValueDelimiter] + Self.encodeValue(value: value)
}
private static func loadDictionary(from path: String) -> [[Int]: Int] {
let trie = Marisa()
trie.load(path)
// enumerate every stored entry via predictive search from each possible leading byte
var dict = [[Int]: Int]()
for i in Int8(0) ..< Int8.max {
for encodedEntry in trie.search([i], .predictive) {
if let (key, value) = Self.decodeEncodedEntry(encoded: encodedEntry) {
dict[key] = value
}
}
}
return dict
}
/// Decode an encoded trie entry into its key and value
private static func decodeEncodedEntry(encoded: [Int8]) -> ([Int], Int)? {
guard let delimiterIndex = encoded.firstIndex(of: keyValueDelimiter) else {
return nil
}
let keyEncoded = encoded[..<delimiterIndex]
let valueEncoded = encoded[(delimiterIndex + 1)...]
// remove the bulk-get (predictive) delimiter from the key part
let filteredKeyEncoded = keyEncoded.filter { $0 != predictiveDelimiter }
// the key is encoded as (v1, v2) byte pairs
guard filteredKeyEncoded.count % 2 == 0 else {
return nil
}
var key: [Int] = []
var index = filteredKeyEncoded.startIndex
while index < filteredKeyEncoded.endIndex {
let token = decodeKey(
v1: filteredKeyEncoded[index],
v2: filteredKeyEncoded[filteredKeyEncoded.index(after: index)]
)
key.append(token)
index = filteredKeyEncoded.index(index, offsetBy: 2)
}
// the value is encoded in 5 bytes
guard valueEncoded.count == 5 else { return nil }
let d = Int(Int8.max - 1)
var value = 0
for item in valueEncoded {
value = value * d + (Int(item) - 1)
}
return (key, value)
}
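The value format implied by the loop above is five base-(Int8.max - 1) digits, most significant first, each stored with an offset of +1 (presumably so that no encoded byte is zero). A matching encoder might look like the following sketch; the real encodeValue in this file is not shown in the diff, so treat this as an assumption:

```swift
// Sketch of the value codec implied by decodeEncodedEntry (assumed, not from the commit).
let d = Int(Int8.max - 1)  // base 126

func encodeValueSketch(_ value: Int) -> [Int8] {
    // Five base-126 digits, most significant first; each digit is stored as digit + 1.
    var digits = [Int8](repeating: 1, count: 5)
    var v = value
    for i in stride(from: 4, through: 0, by: -1) {
        digits[i] = Int8(v % d + 1)
        v /= d
    }
    return digits
}

func decodeValueSketch(_ bytes: [Int8]) -> Int {
    // Same accumulation as the loop in decodeEncodedEntry above.
    bytes.reduce(0) { $0 * d + (Int($1) - 1) }
}

// decodeValueSketch(encodeValueSketch(12345)) == 12345
// The largest representable value is 126^5 - 1 (about 3.2e10).
```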
/// Build a trie from a [[Int]: Int] dictionary and save it to disk
private func buildAndSaveTrie(from dict: [[Int]: Int], to path: String, forBulkGet: Bool = false) {
let encode = forBulkGet ? encodeKeyValueForBulkGet : encodeKeyValue
@@ -118,7 +182,7 @@ final class SwiftTrainer {
/// Save all dictionaries as marisa tries
func saveToMarisaTrie(baseFilename: String, outputDir: String? = nil) {
func saveToMarisaTrie(baseFilePattern: String, outputDir: String? = nil) {
let fileManager = FileManager.default
// Default save location: ~/Library/Application Support/SwiftNGram/marisa/
@@ -144,21 +208,21 @@ final class SwiftTrainer {
// paths of the marisa files to write
let paths = [
"\(baseFilename)_c_abc.marisa",
"\(baseFilename)_u_abx.marisa",
"\(baseFilename)_u_xbc.marisa",
"\(baseFilename)_r_xbx.marisa",
"\(baseFilePattern)_c_abc.marisa",
"\(baseFilePattern)_c_bc.marisa",
"\(baseFilePattern)_u_abx.marisa",
"\(baseFilePattern)_u_xbc.marisa",
"\(baseFilePattern)_r_xbx.marisa",
].map { file in
marisaDir.appendingPathComponent(file).path
}
// build and save each trie
buildAndSaveTrie(from: c_abc, to: paths[0], forBulkGet: true)
buildAndSaveTrie(from: u_abx, to: paths[1])
buildAndSaveTrie(from: u_xbc, to: paths[2], forBulkGet: true)
let r_xbx: [[Int]: Int] = s_xbx.mapValues { $0.count }
buildAndSaveTrie(from: r_xbx, to: paths[3])
buildAndSaveTrie(from: c_bc, to: paths[1])
buildAndSaveTrie(from: u_abx, to: paths[2])
buildAndSaveTrie(from: u_xbc, to: paths[3], forBulkGet: true)
buildAndSaveTrie(from: r_xbx, to: paths[4])
// log the saved files
print("All saved files (absolute paths):")
@@ -192,11 +256,16 @@ public func readLinesFromFile(filePath: String) -> [String]? {
public func trainNGram(
lines: [String],
n: Int,
baseFilename: String,
outputDir: String? = nil
baseFilePattern: String,
outputDir: String? = nil,
resumeFilePattern: String? = nil
) {
let tokenizer = ZenzTokenizer()
let trainer = SwiftTrainer(n: n, tokenizer: tokenizer)
let trainer = if let resumeFilePattern {
SwiftTrainer(baseFilePattern: resumeFilePattern, n: n, tokenizer: tokenizer)
} else {
SwiftTrainer(n: n, tokenizer: tokenizer)
}
for (i, line) in lines.enumerated() {
if i % 100 == 0 {
@@ -209,18 +278,24 @@ public func trainNGram(
}
// save the tries
trainer.saveToMarisaTrie(baseFilename: baseFilename, outputDir: outputDir)
trainer.saveToMarisaTrie(baseFilePattern: baseFilePattern, outputDir: outputDir)
}
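For context, resuming at the library level might look like this (hypothetical corpora and paths; the signatures are the ones introduced above, and "lm" is the prefix the train subcommand uses):

```swift
// Hypothetical corpus lines split into two batches.
let firstBatch  = ["今日はいい天気です", "明日は雨が降りそうです"]
let secondBatch = ["今日は雨です", "明日は晴れるでしょう"]

// First pass: train from scratch and write lm_*.marisa into /path/to/out.
trainNGram(lines: firstBatch, n: 5, baseFilePattern: "lm", outputDir: "/path/to/out")

// Second pass: reload the saved counts and keep training on new data.
trainNGram(
    lines: secondBatch,
    n: 5,
    baseFilePattern: "lm",
    outputDir: "/path/to/out",
    resumeFilePattern: "/path/to/out/lm"
)
```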
/// Train n-grams from a file path
public func trainNGramFromFile(filePath: String, n: Int, baseFilename: String, outputDir: String? = nil) {
public func trainNGramFromFile(
filePath: String,
n: Int,
baseFilePattern: String,
outputDir: String? = nil,
resumeFilePattern: String? = nil
) {
guard let lines = readLinesFromFile(filePath: filePath) else {
return
}
trainNGram(lines: lines, n: n, baseFilename: baseFilename, outputDir: outputDir)
trainNGram(lines: lines, n: n, baseFilePattern: baseFilePattern, outputDir: outputDir, resumeFilePattern: resumeFilePattern)
}
#else
public func trainNGramFromFile(filePath: String, n: Int, baseFilename: String, outputDir: String? = nil) {
public func trainNGramFromFile(filePath: String, n: Int, baseFilePattern: String, outputDir: String? = nil, resumeFilePattern: String? = nil) {
fatalError("[Error] trainNGramFromFile is unsupported.")
}
#endif