mirror of
https://github.com/mii443/AzooKeyKanaKanjiConverter.git
synced 2025-12-03 02:58:27 +00:00
feat: ngramの学習のresume処理を実装 (#156)
* feat: ngramの学習のresume処理を実装 * fix: mock signature for Android
This commit is contained in:
@@ -13,6 +13,9 @@ extension Subcommands.NGram {
|
||||
@Option(name: [.customShort("n")], help: "n-gram's n")
|
||||
var n: Int = 5
|
||||
|
||||
@Option(name: [.customLong("resume")], help: "Resume from these lm data")
|
||||
var resumeFilePattern: String?
|
||||
|
||||
static let configuration = CommandConfiguration(
|
||||
commandName: "train",
|
||||
abstract: "Train ngram and write the data"
|
||||
@@ -21,7 +24,13 @@ extension Subcommands.NGram {
|
||||
mutating func run() throws {
|
||||
let pattern = URL(fileURLWithPath: self.outputDirectory).path() + "lm_"
|
||||
print("Saving for \(pattern)")
|
||||
trainNGramFromFile(filePath: self.target, n: self.n, baseFilename: "lm", outputDir: self.outputDirectory)
|
||||
trainNGramFromFile(
|
||||
filePath: self.target,
|
||||
n: self.n,
|
||||
baseFilePattern: "lm",
|
||||
outputDir: self.outputDirectory,
|
||||
resumeFilePattern: self.resumeFilePattern
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,18 +8,27 @@ final class SwiftTrainer {
|
||||
let n: Int
|
||||
let tokenizer: ZenzTokenizer
|
||||
|
||||
/// Python の defaultdict(int) 相当
|
||||
private var c_abc = [[Int]: Int]()
|
||||
private var c_bc = [[Int]: Int]()
|
||||
private var u_abx = [[Int]: Int]()
|
||||
private var u_xbc = [[Int]: Int]()
|
||||
/// Python の defaultdict(set) 相当
|
||||
private var s_xbx = [[Int]: Set<Int>]()
|
||||
private var r_xbx = [[Int]: Int]()
|
||||
|
||||
/// Creates a trainer that starts from empty count tables.
///
/// - Parameters:
///   - n: The n-gram order stored on the trainer.
///   - tokenizer: The tokenizer stored on the trainer.
init(n: Int, tokenizer: ZenzTokenizer) {
    self.tokenizer = tokenizer
    self.n = n
}
|
||||
|
||||
/// Creates a trainer whose count tables are restored from previously saved trie files.
///
/// Five files are read, each named `<baseFilePattern>_<table>.marisa`, for the tables
/// `c_abc`, `c_bc`, `u_abx`, `u_xbc` and `r_xbx`.
///
/// - Parameters:
///   - baseFilePattern: Path prefix shared by the five `.marisa` files to load.
///   - n: The n-gram order stored on the trainer.
///   - tokenizer: The tokenizer stored on the trainer.
init(baseFilePattern: String, n: Int, tokenizer: ZenzTokenizer) {
    self.n = n
    self.tokenizer = tokenizer

    // Spell out every input path first so the set of files read is visible at a glance.
    let cABCPath = "\(baseFilePattern)_c_abc.marisa"
    let cBCPath = "\(baseFilePattern)_c_bc.marisa"
    let uABXPath = "\(baseFilePattern)_u_abx.marisa"
    let uXBCPath = "\(baseFilePattern)_u_xbc.marisa"
    let rXBXPath = "\(baseFilePattern)_r_xbx.marisa"

    // Load in the same order as the tables are declared.
    self.c_abc = Self.loadDictionary(from: cABCPath)
    self.c_bc = Self.loadDictionary(from: cBCPath)
    self.u_abx = Self.loadDictionary(from: uABXPath)
    self.u_xbc = Self.loadDictionary(from: uXBCPath)
    self.r_xbx = Self.loadDictionary(from: rXBXPath)
}
|
||||
|
||||
/// 単一 n-gram (abc など) をカウント
|
||||
/// Python の count_ngram に対応
|
||||
private func countNGram(_ ngram: some BidirectionalCollection<Int>) {
|
||||
@@ -31,10 +40,10 @@ final class SwiftTrainer {
|
||||
let Bc = Array(ngram.dropFirst()) // bc
|
||||
// 中央部分 B, 末尾単語 c
|
||||
let B = Array(ngram.dropFirst().dropLast())
|
||||
let c = ngram.last!
|
||||
|
||||
// C(abc)
|
||||
c_abc[aBc, default: 0] += 1
|
||||
// C(bc)
|
||||
c_bc[Bc, default: 0] += 1
|
||||
|
||||
// 初回登場なら U(...) を更新
|
||||
if c_abc[aBc] == 1 {
|
||||
@@ -43,8 +52,11 @@ final class SwiftTrainer {
|
||||
// U(・bc)
|
||||
u_xbc[Bc, default: 0] += 1
|
||||
}
|
||||
// s_xbx[B] = s_xbx[B] ∪ {c}
|
||||
s_xbx[B, default: Set()].insert(c)
|
||||
|
||||
if c_bc[Bc] == 1 {
|
||||
// R(・B・): Bc was seen for the first time, so B gains one more distinct continuation c
|
||||
r_xbx[B, default: 0] += 1
|
||||
}
|
||||
}
|
||||
|
||||
/// 文から n-gram をカウント
|
||||
@@ -102,6 +114,58 @@ final class SwiftTrainer {
|
||||
return key + [Self.keyValueDelimiter] + Self.encodeValue(value: value)
|
||||
}
|
||||
|
||||
/// Loads a saved marisa trie and decodes its entries back into a `[key: count]` dictionary.
///
/// - Parameter path: Path of the `.marisa` file to load.
/// - Returns: Every entry that `decodeEncodedEntry(encoded:)` could decode. Entries it
///   rejects are skipped.
private static func loadDictionary(from path: String) -> [[Int]: Int] {
    let trie = Marisa()
    trie.load(path)

    // A predictive search with an empty query does not work well, so the trie is
    // enumerated by issuing one predictive search per possible first byte.
    //
    // The whole Int8 range is scanned, not just 0..<Int8.max. An entry with an empty
    // key (e.g. the middle part `B` of a 2-gram) starts directly with
    // `keyValueDelimiter`, and a bulk-get entry may start with `predictiveDelimiter`.
    // If those bytes fall outside 0..<Int8.max, a narrower scan would silently drop
    // such entries when resuming. Scanning every byte is correct whatever values the
    // delimiters have, and costs only a few extra (mostly empty) lookups.
    var dict = [[Int]: Int]()
    for firstByte in Int8.min ... Int8.max {
        for encodedEntry in trie.search([firstByte], .predictive) {
            guard let (key, value) = Self.decodeEncodedEntry(encoded: encodedEntry) else {
                continue
            }
            dict[key] = value
        }
    }
    return dict
}
|
||||
|
||||
/// Restores the `(key, value)` pair from one encoded trie entry.
///
/// The entry is split at the first `keyValueDelimiter`: the bytes before it are the
/// key and the bytes after it are the value.
///
/// - Parameter encoded: The raw bytes of a single trie entry.
/// - Returns: The decoded token-ID key and its count, or `nil` when the entry has no
///   key/value delimiter, an odd number of key bytes, or a value that is not 5 bytes.
private static func decodeEncodedEntry(encoded: [Int8]) -> ([Int], Int)? {
    guard let separator = encoded.firstIndex(of: keyValueDelimiter) else {
        return nil
    }
    let valueBytes = encoded[(separator + 1)...]

    // Drop the delimiter inserted for bulk get (a no-op when it is absent).
    let keyBytes: [Int8] = encoded[..<separator].filter { $0 != predictiveDelimiter }

    // Each token of the key is stored as a (v1, v2) byte pair.
    guard keyBytes.count.isMultiple(of: 2) else {
        return nil
    }
    let key: [Int] = stride(from: 0, to: keyBytes.count, by: 2).map { offset in
        decodeKey(v1: keyBytes[offset], v2: keyBytes[offset + 1])
    }

    // The value always occupies exactly 5 bytes.
    guard valueBytes.count == 5 else { return nil }

    // Each value byte holds one digit (stored with a +1 offset) of a base-(Int8.max - 1) number,
    // most significant digit first.
    let radix = Int(Int8.max - 1)
    let value = valueBytes.reduce(0) { partial, byte in
        partial * radix + (Int(byte) - 1)
    }

    return (key, value)
}
|
||||
|
||||
/// 指定した [[Int]: Int] を Trie に登録して保存
|
||||
private func buildAndSaveTrie(from dict: [[Int]: Int], to path: String, forBulkGet: Bool = false) {
|
||||
let encode = forBulkGet ? encodeKeyValueForBulkGet : encodeKeyValue
|
||||
@@ -118,7 +182,7 @@ final class SwiftTrainer {
|
||||
|
||||
|
||||
/// 上記のカウント結果を marisa ファイルとして保存
|
||||
func saveToMarisaTrie(baseFilename: String, outputDir: String? = nil) {
|
||||
func saveToMarisaTrie(baseFilePattern: String, outputDir: String? = nil) {
|
||||
let fileManager = FileManager.default
|
||||
|
||||
// 出力フォルダの設定(デフォルト: ~/Library/Application Support/SwiftNGram/marisa/)
|
||||
@@ -144,21 +208,21 @@ final class SwiftTrainer {
|
||||
|
||||
// ファイルパスの生成(marisa ディレクトリ内に配置)
|
||||
let paths = [
|
||||
"\(baseFilename)_c_abc.marisa",
|
||||
"\(baseFilename)_u_abx.marisa",
|
||||
"\(baseFilename)_u_xbc.marisa",
|
||||
"\(baseFilename)_r_xbx.marisa",
|
||||
"\(baseFilePattern)_c_abc.marisa",
|
||||
"\(baseFilePattern)_c_bc.marisa",
|
||||
"\(baseFilePattern)_u_abx.marisa",
|
||||
"\(baseFilePattern)_u_xbc.marisa",
|
||||
"\(baseFilePattern)_r_xbx.marisa",
|
||||
].map { file in
|
||||
marisaDir.appendingPathComponent(file).path
|
||||
}
|
||||
|
||||
// 各 Trie ファイルを保存
|
||||
buildAndSaveTrie(from: c_abc, to: paths[0], forBulkGet: true)
|
||||
buildAndSaveTrie(from: u_abx, to: paths[1])
|
||||
buildAndSaveTrie(from: u_xbc, to: paths[2], forBulkGet: true)
|
||||
|
||||
let r_xbx: [[Int]: Int] = s_xbx.mapValues { $0.count }
|
||||
buildAndSaveTrie(from: r_xbx, to: paths[3])
|
||||
buildAndSaveTrie(from: c_bc, to: paths[1])
|
||||
buildAndSaveTrie(from: u_abx, to: paths[2])
|
||||
buildAndSaveTrie(from: u_xbc, to: paths[3], forBulkGet: true)
|
||||
buildAndSaveTrie(from: r_xbx, to: paths[4])
|
||||
|
||||
// **絶対パスでの出力**
|
||||
print("All saved files (absolute paths):")
|
||||
@@ -192,11 +256,16 @@ public func readLinesFromFile(filePath: String) -> [String]? {
|
||||
public func trainNGram(
|
||||
lines: [String],
|
||||
n: Int,
|
||||
baseFilename: String,
|
||||
outputDir: String? = nil
|
||||
baseFilePattern: String,
|
||||
outputDir: String? = nil,
|
||||
resumeFilePattern: String? = nil
|
||||
) {
|
||||
let tokenizer = ZenzTokenizer()
|
||||
let trainer = SwiftTrainer(n: n, tokenizer: tokenizer)
|
||||
let trainer = if let resumeFilePattern {
|
||||
SwiftTrainer(baseFilePattern: resumeFilePattern, n: n, tokenizer: tokenizer)
|
||||
} else {
|
||||
SwiftTrainer(n: n, tokenizer: tokenizer)
|
||||
}
|
||||
|
||||
for (i, line) in lines.enumerated() {
|
||||
if i % 100 == 0 {
|
||||
@@ -209,18 +278,24 @@ public func trainNGram(
|
||||
}
|
||||
|
||||
// Trie ファイルを保存(出力フォルダを渡す)
|
||||
trainer.saveToMarisaTrie(baseFilename: baseFilename, outputDir: outputDir)
|
||||
trainer.saveToMarisaTrie(baseFilePattern: baseFilePattern, outputDir: outputDir)
|
||||
}
|
||||
|
||||
/// 実行例: ファイルを読み込み、n-gram を学習して保存
|
||||
public func trainNGramFromFile(filePath: String, n: Int, baseFilename: String, outputDir: String? = nil) {
|
||||
public func trainNGramFromFile(
|
||||
filePath: String,
|
||||
n: Int,
|
||||
baseFilePattern: String,
|
||||
outputDir: String? = nil,
|
||||
resumeFilePattern: String? = nil
|
||||
) {
|
||||
guard let lines = readLinesFromFile(filePath: filePath) else {
|
||||
return
|
||||
}
|
||||
trainNGram(lines: lines, n: n, baseFilename: baseFilename, outputDir: outputDir)
|
||||
trainNGram(lines: lines, n: n, baseFilePattern: baseFilePattern, outputDir: outputDir, resumeFilePattern: resumeFilePattern)
|
||||
}
|
||||
#else
|
||||
public func trainNGramFromFile(filePath: String, n: Int, baseFilename: String, outputDir: String? = nil) {
|
||||
public func trainNGramFromFile(filePath: String, n: Int, baseFilePattern: String, outputDir: String? = nil, resumeFilePattern: String? = nil) {
|
||||
fatalError("[Error] trainNGramFromFile is unsupported.")
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user