Files
AzooKeyKanaKanjiConverter/Sources/EfficientNGram/Trainer.swift

227 lines
8.5 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import Foundation
#if canImport(SwiftyMarisa)
import SwiftyMarisa
/// Accumulates n-gram statistics from tokenized sentences and writes them to Marisa tries.
///
/// Every order from 2 up to `n` is counted into the same four tables, all keyed by
/// token-ID sequences. For an n-gram written `a B c` (first token `a`, middle tokens `B`,
/// last token `c`):
/// - `c_abc[aBc]` is how many times the n-gram was counted,
/// - `u_abx[aB]` is how many distinct counted n-grams consist of `aB` plus one more token,
/// - `u_xbc[Bc]` is how many distinct counted n-grams consist of one token followed by `Bc`,
/// - `s_xbx[B]` is the set of last tokens `c` seen for n-grams whose middle part is `B`.
///
/// The original comments name Python counterparts (`count_ngram`, `count_sent_ngram`,
/// `count_sent`, `encode_key_value`) for several of the methods below.
final class SwiftTrainer {
    /// Byte placed between the encoded key and the encoded value of a trie entry.
    static let keyValueDelimiter: Int8 = Int8.min
    /// Byte inserted in front of the last encoded token in the bulk-get key layout.
    static let predictiveDelimiter: Int8 = Int8.min + 1
    /// Radix of the byte encoding (126). Each digit is stored as `digit + 1`, so for
    /// non-negative inputs every encoded byte is positive and cannot equal a delimiter.
    private static let radix = Int(Int8.max - 1)

    /// Highest n-gram order that is counted.
    let n: Int
    /// Tokenizer used to turn sentences into token IDs and to supply the padding tokens.
    let tokenizer: ZenzTokenizer

    // Counters; a missing key means zero (Python `defaultdict(int)` in the original notes).
    private var c_abc = [[Int]: Int]()
    private var u_abx = [[Int]: Int]()
    private var u_xbc = [[Int]: Int]()
    // Set-valued table; a missing key means the empty set (Python `defaultdict(set)`).
    private var s_xbx = [[Int]: Set<Int>]()

    /// Creates a trainer with empty tables.
    ///
    /// - Parameters:
    ///   - n: Highest n-gram order to count. Values below 2 leave the tables empty.
    ///   - tokenizer: Tokenizer used by `countSent(_:)`.
    init(n: Int, tokenizer: ZenzTokenizer) {
        self.n = n
        self.tokenizer = tokenizer
    }

    /// Records one n-gram occurrence in all four tables.
    ///
    /// - Parameter ngram: The token IDs of the n-gram. Sequences shorter than two tokens
    ///   are ignored.
    private func countNGram(_ ngram: some BidirectionalCollection<Int>) {
        guard ngram.count >= 2, let lastToken = ngram.last else { return }
        let whole = Array(ngram)                             // a B c
        let leading = Array(ngram.dropLast())                // a B
        let trailing = Array(ngram.dropFirst())              // B c
        let middle = Array(ngram.dropFirst().dropLast())     // B

        // C(abc): raw occurrence count.
        let occurrences = c_abc[whole, default: 0] + 1
        c_abc[whole] = occurrences

        // The unique-type counters only move the first time this exact n-gram is seen.
        if occurrences == 1 {
            u_abx[leading, default: 0] += 1
            u_xbc[trailing, default: 0] += 1
        }

        // s_xbx[B] = s_xbx[B] ∪ {c}
        s_xbx[middle, default: Set()].insert(lastToken)
    }

    /// Counts every n-gram of a single order in one tokenized sentence.
    ///
    /// The sentence is padded with `n - 1` start tokens in front and one end token behind,
    /// then every window of `n` consecutive tokens is counted.
    ///
    /// - Parameters:
    ///   - n: Order of the n-grams to count.
    ///   - sent: Token IDs of the sentence, without padding.
    private func countSentNGram(n: Int, sent: [Int]) {
        // Windows shorter than two tokens are ignored by `countNGram`, and a negative
        // repeat count below would trap, so there is nothing to do for `n < 2`.
        guard n >= 2 else { return }
        let padded = Array(repeating: tokenizer.startTokenID, count: n - 1) + sent + [tokenizer.endTokenID]
        for start in 0..<(padded.count - n + 1) {
            countNGram(padded[start..<start + n])
        }
    }

    /// Tokenizes `sentence` and counts its n-grams of every order from 2 through `n`.
    ///
    /// Does nothing when `n < 2`; building the range `2...n` would otherwise trap.
    ///
    /// - Parameter sentence: The text to tokenize and count.
    func countSent(_ sentence: String) {
        guard n >= 2 else { return }
        let tokens = tokenizer.encode(text: sentence)
        for order in 2...n {
            countSentNGram(n: order, sent: tokens)
        }
    }

    /// Encodes a token-ID sequence as two bytes per token.
    ///
    /// Each token is split into quotient and remainder by 126, and both are stored
    /// shifted by one. `Int8(_:)` traps if a shifted quotient does not fit in `Int8`.
    ///
    /// - Parameter key: Token IDs to encode.
    /// - Returns: `2 * key.count` bytes, high digit first for each token.
    static func encodeKey(key: [Int]) -> [Int8] {
        var bytes: [Int8] = []
        // Two bytes per token plus one spare slot; the bulk-get layout inserts a
        // delimiter into the returned array.
        bytes.reserveCapacity(key.count * 2 + 1)
        for token in key {
            let (high, low) = token.quotientAndRemainder(dividingBy: radix)
            bytes.append(Int8(high + 1))
            bytes.append(Int8(low + 1))
        }
        return bytes
    }

    /// Encodes a count as five base-126 digits, most significant first, each shifted by one.
    ///
    /// With `d = 126` the digits satisfy
    /// `value = q4·d⁴ + r4·d³ + r3·d² + r2·d + r1`, and the result is
    /// `[q4 + 1, r4 + 1, r3 + 1, r2 + 1, r1 + 1]`. The leading digit `q4` is not reduced
    /// further, so `Int8(_:)` traps if `q4 + 1` does not fit in `Int8`.
    ///
    /// - Parameter value: The count to encode.
    /// - Returns: Exactly five bytes.
    static func encodeValue(value: Int) -> [Int8] {
        var digits = [Int8](repeating: 0, count: 5)
        var remaining = value
        // Fill positions 4, 3, 2, 1 with r1, r2, r3, r4 respectively.
        for position in stride(from: 4, to: 0, by: -1) {
            let (quotient, remainder) = remaining.quotientAndRemainder(dividingBy: radix)
            digits[position] = Int8(remainder + 1)
            remaining = quotient
        }
        // Whatever is left after four divisions is q4.
        digits[0] = Int8(remaining + 1)
        return digits
    }

    /// Decodes the two bytes produced by `encodeKey(key:)` for one token.
    ///
    /// - Parameters:
    ///   - v1: The high byte (`quotient + 1`).
    ///   - v2: The low byte (`remainder + 1`).
    /// - Returns: `(v1 - 1) * 126 + (v2 - 1)`.
    static func decodeKey(v1: Int8, v2: Int8) -> Int {
        // Widen before subtracting: `v1 - 1` evaluated in `Int8` traps for `Int8.min`,
        // which is the value of `keyValueDelimiter`.
        return (Int(v1) - 1) * radix + (Int(v2) - 1)
    }

    /// Builds a trie entry laid out as encoded key, `keyValueDelimiter`, encoded value.
    ///
    /// - Parameters:
    ///   - key: Token IDs of the entry.
    ///   - value: Count stored for the entry.
    /// - Returns: The bytes of the entry.
    private func encodeKeyValue(key: [Int], value: Int) -> [Int8] {
        let encodedKey = Self.encodeKey(key: key)
        return encodedKey + [Self.keyValueDelimiter] + Self.encodeValue(value: value)
    }

    /// Builds a trie entry like `encodeKeyValue`, with `predictiveDelimiter` inserted
    /// before the last token of the key.
    ///
    /// One token occupies two bytes, so the delimiter goes two bytes before the end of the
    /// encoded key. `key` must contain at least one token; an empty key would make the
    /// insertion index negative.
    ///
    /// - Parameters:
    ///   - key: Token IDs of the entry (non-empty).
    ///   - value: Count stored for the entry.
    /// - Returns: The bytes of the entry.
    private func encodeKeyValueForBulkGet(key: [Int], value: Int) -> [Int8] {
        var encodedKey = Self.encodeKey(key: key)
        encodedKey.insert(Self.predictiveDelimiter, at: encodedKey.count - 2)
        return encodedKey + [Self.keyValueDelimiter] + Self.encodeValue(value: value)
    }

    /// Encodes every entry of `dict`, builds a Marisa trie from the entries and saves it.
    ///
    /// - Parameters:
    ///   - dict: Table mapping token-ID sequences to counts.
    ///   - path: Destination file path passed to the trie's `save`.
    ///   - forBulkGet: When `true`, entries use the bulk-get layout.
    private func buildAndSaveTrie(from dict: [[Int]: Int], to path: String, forBulkGet: Bool = false) {
        let encodedEntries: [[Int8]] = dict.map { key, value in
            forBulkGet
                ? encodeKeyValueForBulkGet(key: key, value: value)
                : encodeKeyValue(key: key, value: value)
        }
        let trie = Marisa()
        trie.build { builder in
            for entry in encodedEntries {
                builder(entry)
            }
        }
        trie.save(path)
        print("Saved \(path): \(encodedEntries.count) entries")
    }

    /// Writes the four tables as Marisa trie files and prints the resulting paths.
    ///
    /// The files are named `<baseFilename>_c_abc.marisa`, `_u_abx.marisa`, `_u_xbc.marisa`
    /// and `_r_xbx.marisa`; the last one stores the size of each `s_xbx` set. `c_abc` and
    /// `u_xbc` are written in the bulk-get layout. On a directory error a message is
    /// printed and nothing is written.
    ///
    /// - Parameters:
    ///   - baseFilename: Prefix of the generated file names.
    ///   - outputDir: Destination directory. When `nil`, `SwiftNGram/marisa` inside the
    ///     user's Application Support directory is used.
    func saveToMarisaTrie(baseFilename: String, outputDir: String? = nil) {
        let fileManager = FileManager.default

        // Default: ~/Library/Application Support/SwiftNGram/marisa/
        let marisaDir: URL
        if let outputDir = outputDir {
            marisaDir = URL(fileURLWithPath: outputDir)
        } else if let supportDir = fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first {
            marisaDir = supportDir.appendingPathComponent("SwiftNGram/marisa", isDirectory: true)
        } else {
            print("[Error] Application Support directory could not be located")
            return
        }

        // Create the destination, including any missing parent directories.
        do {
            try fileManager.createDirectory(
                at: marisaDir,
                withIntermediateDirectories: true,
                attributes: nil
            )
        } catch {
            print("ディレクトリ作成エラー: \(error)")
            return
        }

        // One output file per table.
        let paths = [
            "\(baseFilename)_c_abc.marisa",
            "\(baseFilename)_u_abx.marisa",
            "\(baseFilename)_u_xbc.marisa",
            "\(baseFilename)_r_xbx.marisa",
        ].map { file in
            marisaDir.appendingPathComponent(file).path
        }

        buildAndSaveTrie(from: c_abc, to: paths[0], forBulkGet: true)
        buildAndSaveTrie(from: u_abx, to: paths[1])
        buildAndSaveTrie(from: u_xbc, to: paths[2], forBulkGet: true)
        // Only the number of distinct last tokens is stored, not the set itself.
        let r_xbx: [[Int]: Int] = s_xbx.mapValues { $0.count }
        buildAndSaveTrie(from: r_xbx, to: paths[3])

        print("All saved files (absolute paths):")
        for path in paths {
            print(path)
        }
    }
}
/// Reads a UTF-8 text file and returns its non-empty lines.
///
/// - Parameter filePath: Path of the file to read.
/// - Returns: The lines of the file with empty lines removed, or `nil` (after printing a
///   message) when the file cannot be read or its contents are not valid UTF-8.
public func readLinesFromFile(filePath: String) -> [String]? {
    // `Data(contentsOf:)` reports a failed read as a Swift error. Reading through
    // `FileHandle.readDataToEndOfFile()` instead raises an exception that Swift cannot
    // catch, for example when `filePath` names a directory.
    guard let data = try? Data(contentsOf: URL(fileURLWithPath: filePath)) else {
        print("[Error] ファイルを開けませんでした: \(filePath)")
        return nil
    }

    // Decode the whole file as UTF-8.
    guard let text = String(data: data, encoding: .utf8) else {
        print("[Error] UTF-8 で読み込めませんでした: \(filePath)")
        return nil
    }

    // Split on newline characters and drop the empty pieces.
    return text.components(separatedBy: .newlines).filter { !$0.isEmpty }
}
/// Counts n-grams over the given lines and saves the result as Marisa trie files.
///
/// Each line is trimmed of surrounding whitespace and newlines; lines that are empty
/// after trimming are skipped. A progress line is printed every 100 input lines.
///
/// - Parameters:
///   - lines: Training sentences, one per element.
///   - n: Highest n-gram order handed to the trainer.
///   - baseFilename: Prefix of the generated trie file names.
///   - outputDir: Destination directory, forwarded to `saveToMarisaTrie`.
public func trainNGram(
    lines: [String],
    n: Int,
    baseFilename: String,
    outputDir: String? = nil
) {
    let trainer = SwiftTrainer(n: n, tokenizer: ZenzTokenizer())
    let total = lines.count

    for index in lines.indices {
        // Progress report.
        if index.isMultiple(of: 100) {
            print(index, "/", total)
        }
        let sentence = lines[index].trimmingCharacters(in: .whitespacesAndNewlines)
        guard !sentence.isEmpty else { continue }
        trainer.countSent(sentence)
    }

    // Serialize the accumulated tables.
    trainer.saveToMarisaTrie(baseFilename: baseFilename, outputDir: outputDir)
}
/// Convenience entry point: reads training lines from a file and trains on them.
///
/// Returns without training when `readLinesFromFile` yields `nil`.
///
/// - Parameters:
///   - filePath: Path of the text file holding the training sentences.
///   - n: Highest n-gram order.
///   - baseFilename: Prefix of the generated trie file names.
///   - outputDir: Destination directory, forwarded unchanged to `trainNGram`.
public func trainNGramFromFile(filePath: String, n: Int, baseFilename: String, outputDir: String? = nil) {
    if let corpus = readLinesFromFile(filePath: filePath) {
        trainNGram(lines: corpus, n: n, baseFilename: baseFilename, outputDir: outputDir)
    }
}
#else
/// Stand-in compiled when SwiftyMarisa cannot be imported; calling it always traps.
///
/// The parameters mirror the SwiftyMarisa-backed variant and are not used.
public func trainNGramFromFile(filePath: String, n: Int, baseFilename: String, outputDir: String? = nil) {
    let message = "[Error] trainNGramFromFile is unsupported."
    fatalError(message)
}
#endif