Files
AzooKeyKanaKanjiConverter/Sources/EfficientNGram/Trainer.swift
2025-07-25 01:33:17 +09:00

302 lines
11 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import Foundation
#if canImport(SwiftyMarisa) && Zenzai
import SwiftyMarisa
/// Collects n-gram count statistics from tokenized sentences and persists them as Marisa tries.
///
/// Every table is keyed by a sequence of token IDs:
/// - `c_abc`: how many times each n-gram `abc` was counted.
/// - `c_bc`: how many times each n-gram with its first token removed (`bc`) was counted.
/// - `u_abx`: for each prefix `ab`, the number of distinct n-grams `abc` seen so far.
/// - `u_xbc`: for each suffix `bc`, the number of distinct n-grams `abc` seen so far.
/// - `r_xbx`: for each middle part `b`, the number of distinct `bc` seen so far.
final class SwiftTrainer {
    /// Byte that separates the encoded key from the encoded value inside a trie entry.
    static let keyValueDelimiter: Int8 = Int8.min
    /// Byte inserted in front of the last token of a key in tries written "for bulk get".
    static let predictiveDelimiter: Int8 = Int8.min + 1
    /// Radix of the byte encoding. Each digit is stored as `digit + 1`.
    private static let radix = Int(Int8.max - 1)

    /// Highest n-gram order that is counted.
    let n: Int
    let tokenizer: ZenzTokenizer

    private var c_abc = [[Int]: Int]()
    private var c_bc = [[Int]: Int]()
    private var u_abx = [[Int]: Int]()
    private var u_xbc = [[Int]: Int]()
    private var r_xbx = [[Int]: Int]()

    /// Creates a trainer with empty count tables.
    init(n: Int, tokenizer: ZenzTokenizer) {
        self.n = n
        self.tokenizer = tokenizer
    }

    /// Creates a trainer whose count tables are pre-populated from previously saved tries.
    ///
    /// - Parameter baseFilePattern: Path prefix; the five tries are read from
    ///   `<baseFilePattern>_c_abc.marisa`, `..._c_bc.marisa`, `..._u_abx.marisa`,
    ///   `..._u_xbc.marisa` and `..._r_xbx.marisa`.
    init(baseFilePattern: String, n: Int, tokenizer: ZenzTokenizer) {
        self.tokenizer = tokenizer
        self.n = n
        self.c_abc = Self.loadDictionary(from: "\(baseFilePattern)_c_abc.marisa")
        self.c_bc = Self.loadDictionary(from: "\(baseFilePattern)_c_bc.marisa")
        self.u_abx = Self.loadDictionary(from: "\(baseFilePattern)_u_abx.marisa")
        self.u_xbc = Self.loadDictionary(from: "\(baseFilePattern)_u_xbc.marisa")
        self.r_xbx = Self.loadDictionary(from: "\(baseFilePattern)_r_xbx.marisa")
    }

    /// Updates all count tables for a single n-gram (`abc`).
    ///
    /// N-grams shorter than two tokens are ignored.
    private func countNGram(_ ngram: some BidirectionalCollection<Int>) {
        guard ngram.count >= 2 else { return }

        let aBc = Array(ngram)                       // abc
        let aB = Array(ngram.dropLast())             // ab
        let Bc = Array(ngram.dropFirst())            // bc
        let B = Array(ngram.dropFirst().dropLast())  // b (empty for a 2-gram)

        // C(abc) and C(bc)
        c_abc[aBc, default: 0] += 1
        c_bc[Bc, default: 0] += 1

        // The "distinct" tables only grow the first time a given sequence is seen.
        if c_abc[aBc] == 1 {
            u_abx[aB, default: 0] += 1
            u_xbc[Bc, default: 0] += 1
        }
        if c_bc[Bc] == 1 {
            r_xbx[B, default: 0] += 1
        }
    }

    /// Counts every n-gram of order `n` in one tokenized sentence.
    ///
    /// The sentence is padded with `n - 1` start tokens in front and one end token behind.
    private func countSentNGram(n: Int, sent: [Int]) {
        let padded = Array(repeating: tokenizer.startTokenID, count: n - 1) + sent + [tokenizer.endTokenID]
        for start in 0..<(padded.count - n + 1) {
            countNGram(padded[start..<start + n])
        }
    }

    /// Tokenizes `sentence` and counts its n-grams for every order from 2 up to `n`.
    ///
    /// Does nothing when `n < 2`.
    func countSent(_ sentence: String) {
        // `2...n` would trap at runtime for n < 2.
        guard n >= 2 else { return }
        let tokens = tokenizer.encode(text: sentence)
        for order in 2...n {
            countSentNGram(n: order, sent: tokens)
        }
    }

    /// Encodes each token as two bytes: `quotient + 1` and `remainder + 1` of a division by 126.
    ///
    /// Assumes non-negative token IDs; `Int8(q + 1)` traps once the quotient exceeds 126.
    static func encodeKey(key: [Int]) -> [Int8] {
        var int8s: [Int8] = []
        int8s.reserveCapacity(key.count * 2 + 1)
        for token in key {
            let (q, r) = token.quotientAndRemainder(dividingBy: radix)
            int8s.append(Int8(q + 1))
            int8s.append(Int8(r + 1))
        }
        return int8s
    }

    /// Encodes `value` as five base-126 digits, most significant first, each stored as `digit + 1`.
    static func encodeValue(value: Int) -> [Int8] {
        let div = radix
        let (q1, r1) = value.quotientAndRemainder(dividingBy: div)  // value = q1 d + r1
        let (q2, r2) = q1.quotientAndRemainder(dividingBy: div)     // value = q2 d² + r2 d + r1
        let (q3, r3) = q2.quotientAndRemainder(dividingBy: div)     // value = q3 d³ + r3 d² + r2 d + r1
        let (q4, r4) = q3.quotientAndRemainder(dividingBy: div)     // value = q4 d⁴ + r4 d³ + r3 d² + r2 d + r1
        return [Int8(q4 + 1), Int8(r4 + 1), Int8(r3 + 1), Int8(r2 + 1), Int8(r1 + 1)]
    }

    /// Decodes one token from the two bytes produced by `encodeKey(key:)`.
    static func decodeKey(v1: Int8, v2: Int8) -> Int {
        // Widen to Int before subtracting so that no Int8 input can overflow.
        return (Int(v1) - 1) * radix + (Int(v2) - 1)
    }

    /// Builds one trie entry: encoded key, key/value delimiter, encoded value.
    private func encodeKeyValue(key: [Int], value: Int) -> [Int8] {
        let key = Self.encodeKey(key: key)
        return key + [Self.keyValueDelimiter] + Self.encodeValue(value: value)
    }

    /// Like `encodeKeyValue(key:value:)`, but with `predictiveDelimiter` inserted before the last token.
    private func encodeKeyValueForBulkGet(key: [Int], value: Int) -> [Int8] {
        var key = Self.encodeKey(key: key)
        // Each token occupies two bytes, so `count - 2` is the start of the last token.
        // Clamped to 0 so that an empty key cannot produce a negative index.
        key.insert(Self.predictiveDelimiter, at: max(key.count - 2, 0))
        return key + [Self.keyValueDelimiter] + Self.encodeValue(value: value)
    }

    /// Loads a saved trie and decodes all of its entries into a dictionary.
    private static func loadDictionary(from path: String) -> [[Int]: Int] {
        let trie = Marisa()
        trie.load(path)
        var dict = [[Int]: Int]()
        // Enumerate the entries by predictive search on every possible first byte.
        // The delimiter bytes must be included: bulk-get entries with a single-token key
        // start with `predictiveDelimiter`, and an entry with an empty key starts with
        // `keyValueDelimiter`.
        for firstByte in Int8.min...Int8.max {
            for encodedEntry in trie.search([firstByte], .predictive) {
                if let (key, value) = Self.decodeEncodedEntry(encoded: encodedEntry) {
                    dict[key] = value
                }
            }
        }
        return dict
    }

    /// Splits an encoded entry into its token-ID key and integer value.
    ///
    /// - Returns: `nil` when the entry has no key/value delimiter, the key has an odd number of
    ///   bytes, or the value is not exactly five bytes long.
    private static func decodeEncodedEntry(encoded: [Int8]) -> ([Int], Int)? {
        guard let delimiterIndex = encoded.firstIndex(of: keyValueDelimiter) else {
            return nil
        }
        // Drop the bulk-get delimiter so both entry flavours decode to the same key.
        let keyBytes = encoded[..<delimiterIndex].filter { $0 != predictiveDelimiter }
        let valueBytes = encoded[(delimiterIndex + 1)...]

        // Every token is a (v1, v2) byte pair.
        guard keyBytes.count % 2 == 0 else {
            return nil
        }
        let key = stride(from: 0, to: keyBytes.count, by: 2).map { offset in
            Self.decodeKey(v1: keyBytes[offset], v2: keyBytes[offset + 1])
        }

        guard valueBytes.count == 5 else { return nil }
        let value = valueBytes.reduce(0) { partial, byte in
            partial * radix + (Int(byte) - 1)
        }
        return (key, value)
    }

    /// Encodes every entry of `dict`, builds a trie from the encoded entries and saves it to `path`.
    private func buildAndSaveTrie(from dict: [[Int]: Int], to path: String, forBulkGet: Bool = false) {
        let encode = forBulkGet ? encodeKeyValueForBulkGet : encodeKeyValue
        let encodedStrings: [[Int8]] = dict.map { encode($0.key, $0.value) }
        let trie = Marisa()
        trie.build { builder in
            for entry in encodedStrings {
                builder(entry)
            }
        }
        trie.save(path)
    }

    /// Writes the five count tables as `<baseFilePattern>_<table>.marisa` files.
    ///
    /// - Parameters:
    ///   - baseFilePattern: File-name prefix shared by the five output files.
    ///   - outputDir: Destination directory. When `nil`, `SwiftNGram/marisa` inside the user's
    ///     Application Support directory is used.
    ///
    /// If the destination directory cannot be determined or created, an error is written to
    /// standard error and nothing is saved.
    func saveToMarisaTrie(baseFilePattern: String, outputDir: String? = nil) {
        let fileManager = FileManager.default

        let marisaDir: URL
        if let outputDir {
            marisaDir = URL(fileURLWithPath: outputDir)
        } else if let supportDir = fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first {
            marisaDir = supportDir.appendingPathComponent("SwiftNGram/marisa", isDirectory: true)
        } else {
            Self.logError("[Error] Could not locate the Application Support directory.")
            return
        }

        do {
            try fileManager.createDirectory(
                at: marisaDir,
                withIntermediateDirectories: true,
                attributes: nil
            )
        } catch {
            Self.logError("[Error] Failed to create directory \(marisaDir.path): \(error)")
            return
        }

        // `c_abc` and `u_xbc` are stored in the bulk-get layout.
        let tables: [(suffix: String, counts: [[Int]: Int], forBulkGet: Bool)] = [
            (suffix: "c_abc", counts: c_abc, forBulkGet: true),
            (suffix: "c_bc", counts: c_bc, forBulkGet: false),
            (suffix: "u_abx", counts: u_abx, forBulkGet: false),
            (suffix: "u_xbc", counts: u_xbc, forBulkGet: true),
            (suffix: "r_xbx", counts: r_xbx, forBulkGet: false),
        ]
        for table in tables {
            let path = marisaDir.appendingPathComponent("\(baseFilePattern)_\(table.suffix).marisa").path
            buildAndSaveTrie(from: table.counts, to: path, forBulkGet: table.forBulkGet)
        }
    }

    /// Writes a diagnostic line to standard error.
    private static func logError(_ message: String) {
        FileHandle.standardError.write(Data((message + "\n").utf8))
    }
}
/// Loads a UTF-8 text file and returns its non-empty lines.
///
/// - Parameter filePath: Path of the file to read.
/// - Returns: The lines of the file with empty lines removed, or `nil` when the file cannot be
///   opened or its contents are not valid UTF-8.
public func readLinesFromFile(filePath: String) -> [String]? {
    guard let handle = FileHandle(forReadingAtPath: filePath) else { return nil }
    defer { try? handle.close() }

    // Read the whole file and interpret it as UTF-8.
    guard let contents = String(data: handle.readDataToEndOfFile(), encoding: .utf8) else {
        return nil
    }

    // Split on newline characters, keeping only lines that have content.
    var lines: [String] = []
    for line in contents.components(separatedBy: .newlines) where !line.isEmpty {
        lines.append(line)
    }
    return lines
}
/// Counts n-grams over the given text lines and saves the result as Marisa tries.
///
/// - Parameters:
///   - lines: Training sentences; each is trimmed of surrounding whitespace and skipped if empty.
///   - n: Highest n-gram order handed to the trainer.
///   - baseFilePattern: File-name prefix for the saved tries.
///   - outputDir: Directory passed through to `saveToMarisaTrie`; `nil` selects its default.
///   - resumeFilePattern: When given, the trainer is initialised from tries saved under this
///     pattern instead of starting empty.
public func trainNGram(
    lines: [String],
    n: Int,
    baseFilePattern: String,
    outputDir: String? = nil,
    resumeFilePattern: String? = nil
) {
    let tokenizer = ZenzTokenizer()

    let trainer: SwiftTrainer
    if let pattern = resumeFilePattern {
        trainer = SwiftTrainer(baseFilePattern: pattern, n: n, tokenizer: tokenizer)
    } else {
        trainer = SwiftTrainer(n: n, tokenizer: tokenizer)
    }

    for line in lines {
        let sentence = line.trimmingCharacters(in: .whitespacesAndNewlines)
        guard !sentence.isEmpty else { continue }
        trainer.countSent(sentence)
    }

    // Persist the accumulated counts.
    trainer.saveToMarisaTrie(baseFilePattern: baseFilePattern, outputDir: outputDir)
}
/// Reads training sentences from a text file and trains an n-gram model on them.
///
/// - Parameters:
///   - filePath: Path of the UTF-8 text file holding one sentence per line.
///   - n: Highest n-gram order to count.
///   - baseFilePattern: File-name prefix for the saved tries.
///   - outputDir: Output directory forwarded to `trainNGram`; `nil` selects its default.
///   - resumeFilePattern: Optional pattern of previously saved tries to resume from.
///
/// If the file cannot be read, an error is written to standard error and nothing is trained.
public func trainNGramFromFile(
    filePath: String,
    n: Int,
    baseFilePattern: String,
    outputDir: String? = nil,
    resumeFilePattern: String? = nil
) {
    guard let lines = readLinesFromFile(filePath: filePath) else {
        // Report the failure instead of returning silently, so a wrong path is noticeable.
        FileHandle.standardError.write(Data("[Error] Failed to read training file: \(filePath)\n".utf8))
        return
    }
    trainNGram(
        lines: lines,
        n: n,
        baseFilePattern: baseFilePattern,
        outputDir: outputDir,
        resumeFilePattern: resumeFilePattern
    )
}
#else
/// Placeholder compiled when the `canImport(SwiftyMarisa) && Zenzai` condition is false.
///
/// Training is unavailable in this configuration, so calling this function always
/// terminates the program.
public func trainNGramFromFile(
    filePath: String,
    n: Int,
    baseFilePattern: String,
    outputDir: String? = nil,
    resumeFilePattern: String? = nil
) {
    fatalError("[Error] trainNGramFromFile is unsupported.")
}
#endif