Mirror of https://github.com/mii443/AzooKeyKanaKanjiConverter.git (synced 2025-08-22 15:05:26 +00:00)

feat: Remove the dependency on SwiftNGramWithMarisaTrie and add the internal implementation as a target (#153)
.github/workflows/swift.yml (vendored, 4 changed lines)

@@ -51,8 +51,8 @@ jobs:
        os: [windows-latest]
        swift-version:
          [{
-            branch: "swift-6.0.3-release",
-            tag: "6.0.3-RELEASE"
+            branch: "swift-6.0.2-release",
+            tag: "6.0.2-RELEASE"
          }]
    steps:
      - uses: compnerd/gha-setup-swift@main
Package.swift

@@ -4,7 +4,7 @@
 import PackageDescription
 import Foundation

-let swiftSettings: [SwiftSetting] = [
+var swiftSettings: [SwiftSetting] = [
     .enableUpcomingFeature("BareSlashRegexLiterals"),
     .enableUpcomingFeature("ConciseMagicFile"),
     .enableUpcomingFeature("ExistentialAny"),
@@ -13,7 +13,6 @@ let swiftSettings: [SwiftSetting] = [
     .enableUpcomingFeature("StrictConcurrency"),
     .enableUpcomingFeature("DisableOutwardActorInference"),
     .enableUpcomingFeature("ImportObjcForwardDeclarations"),
-    .interoperabilityMode(.Cxx),
 ]

 var dependencies: [Package.Dependency] = [
@@ -22,9 +21,23 @@ var dependencies: [Package.Dependency] = [
     .package(url: "https://github.com/apple/swift-algorithms", from: "1.0.0"),
     .package(url: "https://github.com/apple/swift-collections", from: "1.0.0"),
     .package(url: "https://github.com/apple/swift-argument-parser", .upToNextMajor(from: "1.0.0")),
-    .package(url: "https://github.com/nyanko3141592/SwiftNGramWiithMarisaTrie", branch: "0171169c1eb8f7bacceb8236ce7d324050d893c2")
+    .package(url: "https://github.com/ensan-hcl/swift-tokenizers", branch: "feat/minimum")
 ]

+var efficientNGramDependencies: [Target.Dependency] = [.product(name: "Transformers", package: "swift-tokenizers")]
+#if (!os(Linux) || !canImport(Android)) && !os(Windows)
+// SwiftyMarisa is not available on Android and Windows, so it is excluded there.
+// Consequently, EfficientNGram is not supported on those platforms.
+if let envValue = ProcessInfo.processInfo.environment["LLAMA_MOCK"], envValue == "1" {
+    // Also unsupported when LLAMA_MOCK=1
+} else {
+    dependencies.append(.package(url: "https://github.com/ensan-hcl/SwiftyMarisa", branch: "6e145aef5583aac96dd7ff8f9fbb9944d893128e"))
+    efficientNGramDependencies.append("SwiftyMarisa")
+    swiftSettings.append(.interoperabilityMode(.Cxx))
+}
+#endif
+
 var targets: [Target] = [
     // Targets are the basic building blocks of a package. A target can define a module or a test suite.
     // Targets can depend on other targets in this package, and on products in packages this package depends on.
@@ -36,6 +49,12 @@ var targets: [Target] = [
         resources: [],
         swiftSettings: swiftSettings
     ),
+    .target(
+        name: "EfficientNGram",
+        dependencies: efficientNGramDependencies,
+        resources: [.copy("tokenizer")],
+        swiftSettings: swiftSettings
+    ),
     .target(
         name: "KanaKanjiConverterModuleWithDefaultDictionary",
         dependencies: [
@@ -69,6 +88,12 @@ var targets: [Target] = [
         resources: [],
         swiftSettings: swiftSettings
     ),
+    .testTarget(
+        name: "EfficientNGramTests",
+        dependencies: ["EfficientNGram"],
+        resources: [],
+        swiftSettings: swiftSettings
+    ),
     .testTarget(
         name: "KanaKanjiConverterModuleTests",
         dependencies: ["KanaKanjiConverterModule"],
@@ -131,8 +156,8 @@ targets.append(contentsOf: [
         dependencies: [
             "SwiftUtils",
             "llama.cpp",
+            "EfficientNGram",
             .product(name: "Collections", package: "swift-collections"),
-            .product(name: "SwiftNGram", package: "SwiftNGramWiithMarisaTrie")
         ],
         swiftSettings: swiftSettings
     )
@@ -146,8 +171,8 @@ if let envValue = ProcessInfo.processInfo.environment["LLAMA_MOCK"], envValue ==
         dependencies: [
             "SwiftUtils",
             "llama-mock",
+            "EfficientNGram",
             .product(name: "Collections", package: "swift-collections"),
-            .product(name: "SwiftNGram", package: "SwiftNGramWiithMarisaTrie")
         ],
         swiftSettings: swiftSettings
     )
@@ -162,9 +187,9 @@ if let envValue = ProcessInfo.processInfo.environment["LLAMA_MOCK"], envValue ==
         name: "KanaKanjiConverterModule",
         dependencies: [
             "SwiftUtils",
+            "EfficientNGram",
             .product(name: "llama", package: "llama.cpp"),
             .product(name: "Collections", package: "swift-collections"),
-            .product(name: "SwiftNGram", package: "SwiftNGramWiithMarisaTrie")
         ],
         swiftSettings: swiftSettings
     )
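Because SwiftyMarisa (and the C++ interop setting) is only added on some platforms, sources that need the trie backend guard on its availability. A minimal illustrative sketch of that pattern, not the package's own code (the helper name is made up; the real guards appear in Sources/EfficientNGram/Inference.swift and Trainer.swift later in this diff):

#if canImport(SwiftyMarisa)
import SwiftyMarisa
/// Hypothetical helper: true when the marisa-trie-backed implementation is compiled in.
func efficientNGramIsAvailable() -> Bool { true }
#else
/// On Android, Windows, and LLAMA_MOCK=1 builds the mock implementation is used instead.
func efficientNGramIsAvailable() -> Bool { false }
#endif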

@@ -11,7 +11,8 @@ public struct Anco: AsyncParsableCommand {
             Subcommands.Evaluate.self,
             Subcommands.ZenzEvaluate.self,
             Subcommands.Session.self,
-            Subcommands.ExperimentalPredict.self
+            Subcommands.ExperimentalPredict.self,
+            Subcommands.NGram.self
         ],
         defaultSubcommand: Subcommands.Run.self
     )

New file (62 lines): the `inference` subcommand of the new `ngram` CLI command

import Foundation
import EfficientNGram
import ArgumentParser

extension Subcommands.NGram {
    struct Inference: ParsableCommand {
        @Argument(help: "学習済みのLM")
        var lmPattern: String = ""

        @Option(name: [.customLong("another_lm")], help: "Another lm for flavored decoding")
        var anotherLMPattern: String?

        @Option(name: [.customLong("alpha")], help: "alpha for flavored decoding")
        var alpha: Double = 0.5

        @Option(name: [.customLong("prompt"), .customShort("p")], help: "The prompt for inference.")
        var prompt: String = "これは"

        @Option(name: [.customShort("n")], help: "n-gram's n")
        var n: Int = 5

        @Option(name: [.customLong("length"), .customShort("l")], help: "token length for generation")
        var length: Int = 100

        static let configuration = CommandConfiguration(
            commandName: "inference",
            abstract: "Inference using ngram"
        )

        private func measureExecutionTime(block: () -> String) -> (String, Double) {
            let start = DispatchTime.now()
            let result = block()
            let end = DispatchTime.now()
            let nanoTime = end.uptimeNanoseconds - start.uptimeNanoseconds
            let milliTime = Double(nanoTime) / 1_000_000 // in milliseconds
            return (result, milliTime)
        }

        mutating func run() throws {
            print("Loading LM base: \(self.lmPattern)")
            let tokenizer = ZenzTokenizer()
            let lmBase = EfficientNGram(baseFilename: self.lmPattern, n: self.n, d: 0.75, tokenizer: tokenizer)
            let lmPerson = if let anotherLMPattern {
                EfficientNGram(baseFilename: anotherLMPattern, n: self.n, d: 0.75, tokenizer: tokenizer)
            } else {
                lmBase
            }
            let (generatedText, elapsedTime) = measureExecutionTime {
                generateText(
                    inputText: self.prompt,
                    mixAlpha: self.alpha,
                    lmBase: lmBase,
                    lmPerson: lmPerson,
                    tokenizer: tokenizer,
                    maxCount: self.length
                )
            }
            print("\(bold: "Generated"): \(generatedText)")
            print("\(bold: "Execution Time"): \(elapsedTime) ms")
        }
    }
}
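For reference, the `--alpha` / `--another_lm` options feed the log-linear mixture that `generateText` (defined in Sources/EfficientNGram/Inference.swift, shown later in this diff) uses to pick each next token. Written out in LaTeX, with \alpha being the `alpha` option:

w^{\ast} = \arg\max_{w}\ (1-\alpha)\,\log_2 P_{\text{base}}(w \mid \text{context}) \;+\; \alpha\,\log_2 P_{\text{person}}(w \mid \text{context})

Candidates for which either model assigns probability 0 are skipped, and generation stops at the EOS token or once the token sequence reaches `--length` tokens.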
Sources/CliTool/Subcommands/NGramCommands/NGramCommand.swift (new file, 12 lines)

import Foundation
import ArgumentParser

extension Subcommands {
    struct NGram: ParsableCommand {
        static let configuration = CommandConfiguration(
            commandName: "ngram",
            abstract: "Use EfficientNGram Implementation",
            subcommands: [Self.Train.self, Self.Inference.self]
        )
    }
}
Sources/CliTool/Subcommands/NGramCommands/TrainCommand.swift (new file, 27 lines)

import Foundation
import EfficientNGram
import ArgumentParser

extension Subcommands.NGram {
    struct Train: ParsableCommand {
        @Argument(help: "学習テキストデータのfilename")
        var target: String = ""

        @Option(name: [.customLong("output_dir"), .customShort("o")], help: "The directory for output lm data.")
        var outputDirectory: String = "./"

        @Option(name: [.customShort("n")], help: "n-gram's n")
        var n: Int = 5

        static let configuration = CommandConfiguration(
            commandName: "train",
            abstract: "Train ngram and write the data"
        )

        mutating func run() throws {
            let pattern = URL(fileURLWithPath: self.outputDirectory).path() + "lm_"
            print("Saving for \(pattern)")
            trainNGramFromFile(filePath: self.target, n: self.n, baseFilename: "lm", outputDir: self.outputDirectory)
        }
    }
}
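The `train` and `inference` subcommands are connected only through the files written next to the base filename. A minimal sketch of that contract, with file names taken from `saveToMarisaTrie` and `EfficientNGram.init` later in this diff (the directory name here is just an example):

// Hypothetical example: training with baseFilename "lm" into "./marisa"
// produces four marisa tries, and inference loads them by prefix.
let prefix = "./marisa/lm"
let producedFiles = [
    "\(prefix)_c_abc.marisa",   // raw n-gram counts C(abc)
    "\(prefix)_u_abx.marisa",   // continuation type counts N1+(ab·)
    "\(prefix)_u_xbc.marisa",   // continuation type counts N1+(·bc)
    "\(prefix)_r_xbx.marisa",   // distinct continuations of the middle context
]
// let lm = EfficientNGram(baseFilename: prefix, n: 5, d: 0.75, tokenizer: ZenzTokenizer())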
Sources/EfficientNGram/Inference.swift (new file, 230 lines)

import Foundation
#if canImport(SwiftyMarisa)
import SwiftyMarisa

/// Decodes the value part of an encoded key-value entry
private func decodeKeyValue(_ suffix: some Collection<Int8>) -> UInt32? {
    // The first 5 bytes encode the value
    let d = Int(Int8.max - 1)
    var value = 0
    for item in suffix.prefix(5) {
        value *= d
        value += Int(item) - 1
    }
    return UInt32(value)
}

/// Kneser-Ney language model
public struct EfficientNGram {
    public let n: Int
    public let d: Double

    // Tries
    let c_abc: Marisa
    let u_abx: Marisa
    let u_xbc: Marisa
    let r_xbx: Marisa

    private var tokenizer: ZenzTokenizer

    public init(baseFilename: String, n: Int, d: Double, tokenizer: ZenzTokenizer) {
        self.tokenizer = tokenizer
        // Initialize all stored properties first (placeholder values are fine)
        self.n = n
        self.d = d

        self.c_abc = Marisa()
        self.u_abx = Marisa()
        self.u_xbc = Marisa()
        self.r_xbx = Marisa()

        c_abc.load("\(baseFilename)_c_abc.marisa")
        u_abx.load("\(baseFilename)_u_abx.marisa")
        u_xbc.load("\(baseFilename)_u_xbc.marisa")
        r_xbx.load("\(baseFilename)_r_xbx.marisa")
    }

    /// Looks up the value stored for a key in a trie
    private func getValue(from trie: Marisa, key: [Int]) -> UInt32? {
        let int8s = SwiftTrainer.encodeKey(key: key) + [SwiftTrainer.keyValueDelimiter] // delimiter (as it is negative, it must not appear in the key part)
        let results = trie.search(int8s, .predictive)
        for result in results {
            if let decoded = decodeKeyValue(result.dropFirst(int8s.count)) {
                return decoded
            }
        }
        return nil
    }

    /// Speeds up the "prefix + one next token" case with a bulk lookup
    private func bulkGetValueWithSum(from trie: Marisa, prefix: [Int]) -> (values: [UInt32], sum: UInt32) {
        let int8s = SwiftTrainer.encodeKey(key: prefix) + [SwiftTrainer.predictiveDelimiter] // delimiter for predictive search
        let results = trie.search(int8s, .predictive)
        var dict = [UInt32](repeating: 0, count: self.tokenizer.vocabSize)
        var sum: UInt32 = 0
        for result in results {
            var suffix = result.dropFirst(int8s.count)
            let v1 = suffix.removeFirst()
            let v2 = suffix.removeFirst()
            // Strip the delimiter
            if suffix.first != SwiftTrainer.keyValueDelimiter {
                continue
            }
            suffix.removeFirst()
            if let decoded = decodeKeyValue(suffix) {
                let word = SwiftTrainer.decodeKey(v1: v1, v2: v2)
                dict[word] = decoded
                sum += decoded
            }
        }
        return (dict, sum)
    }

    /// N-gram LM with Kneser-Ney smoothing
    func predict(
        nextWord: Int,
        c_abx_ab: UInt32,
        u_abx_ab: UInt32,
        c_abc_abc: UInt32,
        plf_items: [(
            u_xbc_abc: [UInt32],
            u_xbx_ab: UInt32,
            r_xbx_ab: UInt32
        )]
    ) -> Double {
        // ngram = [a, b, c]
        // abc = "a|b|c"
        // ab = "a|b"
        let alpha, gamma: Double
        if c_abx_ab != 0 {
            alpha = max(0, Double(c_abc_abc) - self.d) / Double(c_abx_ab)
            gamma = self.d * Double(u_abx_ab) / Double(c_abx_ab)
        } else {
            alpha = 0
            gamma = 1
        }

        // Lower-order part (predict_lower)
        var plf = 0.0
        var coef = 1.0
        for (u_xbc_abc, u_xbx_ab, r_xbx_ab) in plf_items {
            let alpha, gamma: Double
            if u_xbx_ab > 0 {
                alpha = max(0, Double(u_xbc_abc[nextWord]) - self.d) / Double(u_xbx_ab)
                gamma = self.d * Double(r_xbx_ab) / Double(u_xbx_ab)
            } else {
                alpha = 0
                gamma = 1
            }
            plf += alpha * coef
            coef *= gamma
        }
        plf += coef / Double(self.tokenizer.vocabSize)

        let prob = alpha + gamma * plf
        return prob
    }

    /// Computes the Kneser-Ney probabilities
    public func bulkPredict(_ ngram: some BidirectionalCollection<Int>) -> [Double] {
        // Adjust ab so that it has exactly n-1 elements
        let ab = if ngram.count > self.n - 1 {
            Array(ngram.suffix(self.n - 1))
        } else if ngram.count == self.n - 1 {
            Array(ngram)
        } else {
            Array(repeating: self.tokenizer.startTokenID, count: self.n - 1 - ngram.count) + Array(ngram)
        }
        let u_abx_ab = self.getValue(from: u_abx, key: ab) ?? 0
        let (c_abc_abc, c_abx_ab) = self.bulkGetValueWithSum(from: self.c_abc, prefix: ab)
        var plf_items: [(u_xbc_abc: [UInt32], u_xbx_ab: UInt32, r_xbx_ab: UInt32)] = []
        for i in 1 ..< self.n - 1 {
            let ab = Array(ab.dropFirst(i))
            let r_xbx_ab = self.getValue(from: self.r_xbx, key: ab) ?? 0
            let (u_xbc_abc, u_xbx_ab) = self.bulkGetValueWithSum(from: self.u_xbc, prefix: ab)
            plf_items.append((u_xbc_abc: u_xbc_abc, u_xbx_ab: u_xbx_ab, r_xbx_ab: r_xbx_ab))
        }
        // Scan all candidates
        var results = [Double]()
        results.reserveCapacity(tokenizer.vocabSize)
        for w in 0 ..< tokenizer.vocabSize {
            results.append(self.predict(nextWord: w, c_abx_ab: c_abx_ab, u_abx_ab: u_abx_ab, c_abc_abc: c_abc_abc[w], plf_items: plf_items))
        }
        return results
    }
}

/// Text generation
package func generateText(
    inputText: String,
    mixAlpha: Double,
    lmBase: EfficientNGram,
    lmPerson: EfficientNGram,
    tokenizer: ZenzTokenizer,
    maxCount: Int = 100
) -> String {
    // Tokenize the input text
    var tokens = tokenizer.encode(text: inputText)
    // Extract the suffix in advance
    var suffix = tokens.suffix(lmBase.n - 1)

    while tokens.count < maxCount {
        var maxProb = -Double.infinity
        var nextWord = -1

        // Scan all candidates
        let pBases = lmBase.bulkPredict(suffix)
        let pPersons = lmPerson.bulkPredict(suffix)
        for (w, (pBase, pPerson)) in zip(pBases, pPersons).enumerated() {
            // Skip if either probability is 0
            if pBase == 0.0 || pPerson == 0.0 {
                continue
            }

            let mixLogProb = (1 - mixAlpha) * log2(pBase) + mixAlpha * log2(pPerson)

            if mixLogProb > maxProb {
                maxProb = mixLogProb
                nextWord = w
            }
        }

        // Stop when there is no candidate or the EOS token was chosen
        if nextWord == -1 || nextWord == tokenizer.endTokenID {
            break
        }

        // For a character-level model, simply append to the token sequence
        // (a word-level model would need extra care, e.g. inserting spaces)
        tokens.append(nextWord)

        // Update the suffix
        suffix.append(nextWord)
        if suffix.count > (lmBase.n - 1) {
            suffix.removeFirst()
        }
    }
    return tokenizer.decode(tokens: tokens)
}
#else
/// Mock implementation
public struct EfficientNGram {
    public init(baseFilename: String, n: Int, d: Double, tokenizer: ZenzTokenizer) {}
    public func bulkPredict(_ ngram: some BidirectionalCollection<Int>) -> [Double] {
        // FIXME: avoid hard-coding
        return [Double].init(repeating: 1/6000, count: 6000)
    }
}

package func generateText(
    inputText: String,
    mixAlpha: Double,
    lmBase: EfficientNGram,
    lmPerson: EfficientNGram,
    tokenizer: ZenzTokenizer,
    maxCount: Int = 100
) -> String {
    return "[Error] Unsupported"
}
#endif
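A compact statement of what `predict` computes may help here. Using the counts stored in the four tries (C from c_abc, the continuation type counts N1+ from u_abx / u_xbc, R from r_xbx) and discount d, the top level is, in LaTeX:

P(c \mid ab) = \frac{\max(C(abc) - d,\ 0)}{C(ab\,\cdot)} + \frac{d \cdot N_{1+}(ab\,\cdot)}{C(ab\,\cdot)}\; P_{\mathrm{low}}(c \mid b)

Each lower order repeats the same pattern with the type counts N_{1+}(\cdot bc), N_{1+}(\cdot b\,\cdot) and R(b\,\cdot) in place of the raw counts, until the recursion bottoms out at the uniform term 1/|V|. Whenever a denominator is zero, the full weight (\gamma = 1) is passed down to the next lower order, exactly as in the `alpha` / `gamma` branches above.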
Sources/EfficientNGram/Tokenizer.swift (new file, 31 lines)

import Tokenizers
import Hub
import Foundation

public struct ZenzTokenizer {
    private let tokenizer: any Tokenizer

    public init() {
        let modelFolder = Bundle.module.resourceURL!.appendingPathComponent("tokenizer", isDirectory: true)
        let hubApi = HubApi.shared
        let tokenizerConfig = try! hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer_config.json"))
        let tokenizerData = try! hubApi.configuration(fileURL: modelFolder.appending(path: "tokenizer.json"))
        let tokenizer = try! AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
        self.tokenizer = tokenizer
    }
    func encode(text: String) -> [Int] {
        return self.tokenizer.encode(text: text)
    }
    func decode(tokens: [Int]) -> String {
        return self.tokenizer.decode(tokens: tokens)
    }
    var startTokenID: Int {
        self.tokenizer.bosTokenId!
    }
    var endTokenID: Int {
        self.tokenizer.eosTokenId!
    }
    var vocabSize: Int {
        // FIXME
        6000
    }
}
Sources/EfficientNGram/Trainer.swift (new file, 226 lines)

import Foundation
#if canImport(SwiftyMarisa)
import SwiftyMarisa

final class SwiftTrainer {
    static let keyValueDelimiter: Int8 = Int8.min
    static let predictiveDelimiter: Int8 = Int8.min + 1
    let n: Int
    let tokenizer: ZenzTokenizer

    /// Equivalent of Python's defaultdict(int)
    private var c_abc = [[Int]: Int]()
    private var u_abx = [[Int]: Int]()
    private var u_xbc = [[Int]: Int]()
    /// Equivalent of Python's defaultdict(set)
    private var s_xbx = [[Int]: Set<Int>]()

    init(n: Int, tokenizer: ZenzTokenizer) {
        self.n = n
        self.tokenizer = tokenizer
    }

    /// Counts a single n-gram (e.g. abc)
    /// Corresponds to count_ngram in the Python version
    private func countNGram(_ ngram: some BidirectionalCollection<Int>) {
        // An n-gram needs at least 2 tokens (the formulas split it into aB, Bc, B, c)
        guard ngram.count >= 2 else { return }

        let aBc = Array(ngram)                   // abc
        let aB = Array(ngram.dropLast())         // ab
        let Bc = Array(ngram.dropFirst())        // bc
        // Middle part B, last word c
        let B = Array(ngram.dropFirst().dropLast())
        let c = ngram.last!

        // C(abc)
        c_abc[aBc, default: 0] += 1

        // On first appearance, update U(...)
        if c_abc[aBc] == 1 {
            // U(ab・)
            u_abx[aB, default: 0] += 1
            // U(・bc)
            u_xbc[Bc, default: 0] += 1
        }
        // s_xbx[B] = s_xbx[B] ∪ {c}
        s_xbx[B, default: Set()].insert(c)
    }

    /// Counts the n-grams of one sentence
    /// Corresponds to count_sent_ngram in the Python version
    private func countSentNGram(n: Int, sent: [Int]) {
        // Prepend (n-1) <s> tokens and append </s>
        let padded = Array(repeating: self.tokenizer.startTokenID, count: n - 1) + sent + [self.tokenizer.endTokenID]
        // Slide a window of n tokens
        for i in 0..<(padded.count - n + 1) {
            countNGram(padded[i..<i+n])
        }
    }

    /// Counts a whole sentence (2-gram through N-gram at once)
    /// Corresponds to count_sent in the Python version
    func countSent(_ sentence: String) {
        let tokens = self.tokenizer.encode(text: sentence)
        for k in 2...n {
            countSentNGram(n: k, sent: tokens)
        }
    }

    static func encodeKey(key: [Int]) -> [Int8] {
        var int8s: [Int8] = []
        int8s.reserveCapacity(key.count * 2 + 1)
        for token in key {
            let (q, r) = token.quotientAndRemainder(dividingBy: Int(Int8.max - 1))
            int8s.append(Int8(q + 1))
            int8s.append(Int8(r + 1))
        }
        return int8s
    }
    static func encodeValue(value: Int) -> [Int8] {
        let div = Int(Int8.max - 1)
        let (q1, r1) = value.quotientAndRemainder(dividingBy: div) // value = q1 * div + r1
        let (q2, r2) = q1.quotientAndRemainder(dividingBy: div)    // value = (q2 * div + r2) * div + r1 = q2 d² + r2 d + r1
        let (q3, r3) = q2.quotientAndRemainder(dividingBy: div)    // value = q3 d³ + r3 d² + r2 d + r1
        let (q4, r4) = q3.quotientAndRemainder(dividingBy: div)    // value = q4 d⁴ + r4 d³ + r3 d² + r2 d + r1
        return [Int8(q4 + 1), Int8(r4 + 1), Int8(r3 + 1), Int8(r2 + 1), Int8(r1 + 1)]
    }

    static func decodeKey(v1: Int8, v2: Int8) -> Int {
        return Int(v1-1) * Int(Int8.max-1) + Int(v2-1)
    }
    /// Builds the byte sequence for a key-value pair
    /// Corresponds to encode_key_value(key, value) in the Python version
    private func encodeKeyValue(key: [Int], value: Int) -> [Int8] {
        let key = Self.encodeKey(key: key)
        return key + [Self.keyValueDelimiter] + Self.encodeValue(value: value)
    }

    private func encodeKeyValueForBulkGet(key: [Int], value: Int) -> [Int8] {
        var key = Self.encodeKey(key: key)
        key.insert(Self.predictiveDelimiter, at: key.count - 2) // One token takes two Int8 values; the delimiter `Int8.min + 1` is inserted just before the last token and used for predictive search
        return key + [Self.keyValueDelimiter] + Self.encodeValue(value: value)
    }

    /// Registers the given [[Int]: Int] in a trie and saves it
    private func buildAndSaveTrie(from dict: [[Int]: Int], to path: String, forBulkGet: Bool = false) {
        let encode = forBulkGet ? encodeKeyValueForBulkGet : encodeKeyValue
        let encodedStrings: [[Int8]] = dict.map(encode)
        let trie = Marisa()
        trie.build { builder in
            for entry in encodedStrings {
                builder(entry)
            }
        }
        trie.save(path)
        print("Saved \(path): \(encodedStrings.count) entries")
    }

    /// Saves the counts above as marisa files
    func saveToMarisaTrie(baseFilename: String, outputDir: String? = nil) {
        let fileManager = FileManager.default

        // Output directory (default: ~/Library/Application Support/SwiftNGram/marisa/)
        let marisaDir: URL
        if let outputDir = outputDir {
            marisaDir = URL(fileURLWithPath: outputDir)
        } else {
            let libraryDir = fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first!
            marisaDir = libraryDir.appendingPathComponent("SwiftNGram/marisa", isDirectory: true)
        }

        // Create the directory if it does not exist
        do {
            try fileManager.createDirectory(
                at: marisaDir,
                withIntermediateDirectories: true, // also create intermediate directories
                attributes: nil
            )
        } catch {
            print("ディレクトリ作成エラー: \(error)")
            return
        }

        // Build the file paths (inside the marisa directory)
        let paths = [
            "\(baseFilename)_c_abc.marisa",
            "\(baseFilename)_u_abx.marisa",
            "\(baseFilename)_u_xbc.marisa",
            "\(baseFilename)_r_xbx.marisa",
        ].map { file in
            marisaDir.appendingPathComponent(file).path
        }

        // Save each trie file
        buildAndSaveTrie(from: c_abc, to: paths[0], forBulkGet: true)
        buildAndSaveTrie(from: u_abx, to: paths[1])
        buildAndSaveTrie(from: u_xbc, to: paths[2], forBulkGet: true)

        let r_xbx: [[Int]: Int] = s_xbx.mapValues { $0.count }
        buildAndSaveTrie(from: r_xbx, to: paths[3])

        // Print the absolute paths
        print("All saved files (absolute paths):")
        for path in paths {
            print(path)
        }
    }
}

/// Reads a file and returns its lines as an array of strings
public func readLinesFromFile(filePath: String) -> [String]? {
    guard let fileHandle = FileHandle(forReadingAtPath: filePath) else {
        print("[Error] ファイルを開けませんでした: \(filePath)")
        return nil
    }
    defer {
        try? fileHandle.close()
    }
    // Read the data as UTF-8
    let data = fileHandle.readDataToEndOfFile()
    guard let text = String(data: data, encoding: .utf8) else {
        print("[Error] UTF-8 で読み込めませんでした: \(filePath)")
        return nil
    }
    // Split on newlines and drop empty lines
    return text.components(separatedBy: .newlines).filter { !$0.isEmpty }
}

/// Trains an n-gram model from an array of sentences and saves the Marisa-Trie files
public func trainNGram(
    lines: [String],
    n: Int,
    baseFilename: String,
    outputDir: String? = nil
) {
    let tokenizer = ZenzTokenizer()
    let trainer = SwiftTrainer(n: n, tokenizer: tokenizer)

    for (i, line) in lines.enumerated() {
        if i % 100 == 0 {
            print(i, "/", lines.count)
        }
        let trimmed = line.trimmingCharacters(in: .whitespacesAndNewlines)
        if !trimmed.isEmpty {
            trainer.countSent(trimmed)
        }
    }

    // Save the trie files (passing the output directory through)
    trainer.saveToMarisaTrie(baseFilename: baseFilename, outputDir: outputDir)
}

/// Example entry point: read a file, train the n-gram model, and save it
public func trainNGramFromFile(filePath: String, n: Int, baseFilename: String, outputDir: String? = nil) {
    guard let lines = readLinesFromFile(filePath: filePath) else {
        return
    }
    trainNGram(lines: lines, n: n, baseFilename: baseFilename, outputDir: outputDir)
}
#else
public func trainNGramFromFile(filePath: String, n: Int, baseFilename: String, outputDir: String? = nil) {
    fatalError("[Error] trainNGramFromFile is unsupported.")
}
#endif
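A quick worked example of the 2-byte token encoding above (a sketch, assuming `SwiftTrainer` is visible, e.g. via `@testable import EfficientNGram` in a SwiftyMarisa-enabled build): every token ID is written in base Int8.max - 1 = 126 with an offset of 1, so the stored bytes are always positive and never collide with the negative delimiters.

// Token 300 = 2 * 126 + 48, so it is stored as [2 + 1, 48 + 1] = [3, 49].
let bytes = SwiftTrainer.encodeKey(key: [300])
assert(bytes == [3, 49])
assert(SwiftTrainer.decodeKey(v1: bytes[0], v2: bytes[1]) == 300)

// Values use five base-126 digits, so any count below 126^5 (≈ 3.2 × 10^10) fits.
let valueBytes = SwiftTrainer.encodeValue(value: 300)
assert(valueBytes == [1, 1, 1, 3, 49])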
Sources/EfficientNGram/tokenizer/README.md (new file, 3 lines)

# tokenizer

This tokenizer data is from [ku-nlp/gpt2-small-japanese-char](https://huggingface.co/ku-nlp/gpt2-small-japanese-char), following CC BY-SA 4.0 License.
Sources/EfficientNGram/tokenizer/config.json (new file, 34 lines)

{
    "_name_or_path": "ku-nlp/gpt2-small-japanese-char",
    "activation_function": "gelu_new",
    "architectures": [
        "GPT2LMHeadModel"
    ],
    "attn_pdrop": 0.1,
    "bos_token_id": 1,
    "embd_pdrop": 0.1,
    "eos_token_id": 2,
    "initializer_range": 0.02,
    "layer_norm_epsilon": 1e-05,
    "model_type": "gpt2",
    "n_embd": 768,
    "n_ctx": 1024,
    "n_head": 12,
    "n_inner": null,
    "n_layer": 12,
    "n_positions": 1024,
    "pad_token_id": 1,
    "reorder_and_upcast_attn": false,
    "resid_pdrop": 0.1,
    "scale_attn_by_inverse_layer_idx": false,
    "scale_attn_weights": true,
    "summary_activation": null,
    "summary_first_dropout": 0.1,
    "summary_proj_to_labels": true,
    "summary_type": "cls_index",
    "summary_use_proj": true,
    "torch_dtype": "float32",
    "transformers_version": "4.30.0",
    "use_cache": true,
    "vocab_size": 6000
}
Sources/EfficientNGram/tokenizer/merges.txt (new file, 5765 lines)

File diff suppressed because it is too large.
Sources/EfficientNGram/tokenizer/special_tokens_map.json (new file, 30 lines)

{
    "bos_token": {
        "content": "<s>",
        "lstrip": false,
        "normalized": true,
        "rstrip": false,
        "single_word": false
    },
    "eos_token": {
        "content": "</s>",
        "lstrip": false,
        "normalized": true,
        "rstrip": false,
        "single_word": false
    },
    "pad_token": {
        "content": "[PAD]",
        "lstrip": false,
        "normalized": true,
        "rstrip": false,
        "single_word": false
    },
    "unk_token": {
        "content": "[UNK]",
        "lstrip": false,
        "normalized": true,
        "rstrip": false,
        "single_word": false
    }
}
Sources/EfficientNGram/tokenizer/tokenizer.json (new file, 29132 lines)

File diff suppressed because it is too large.
Sources/EfficientNGram/tokenizer/tokenizer_config.json (new file, 47 lines)

{
    "add_bos_token": false,
    "add_prefix_space": false,
    "added_tokens_decoder": {
        "0": {
            "content": "[UNK]",
            "lstrip": false,
            "normalized": true,
            "rstrip": false,
            "single_word": false,
            "special": true
        },
        "1": {
            "content": "[PAD]",
            "lstrip": false,
            "normalized": true,
            "rstrip": false,
            "single_word": false,
            "special": true
        },
        "2": {
            "content": "<s>",
            "lstrip": false,
            "normalized": true,
            "rstrip": false,
            "single_word": false,
            "special": true
        },
        "3": {
            "content": "</s>",
            "lstrip": false,
            "normalized": true,
            "rstrip": false,
            "single_word": false,
            "special": true
        }
    },
    "bos_token": "<s>",
    "clean_up_tokenization_spaces": true,
    "eos_token": "</s>",
    "errors": "replace",
    "extra_special_tokens": {},
    "model_max_length": 1000000000000000019884624838656,
    "pad_token": "[PAD]",
    "tokenizer_class": "GPT2Tokenizer",
    "unk_token": "[UNK]"
}
Sources/EfficientNGram/tokenizer/vocab.json (new file, 1 line)

File diff suppressed because one or more lines are too long.
In `KanaKanjiConverter`:

@@ -8,7 +8,7 @@

 import Foundation
 import SwiftUtils
-import SwiftNGram
+import EfficientNGram

 /// Class responsible for managing kana-kanji conversion
 @MainActor public final class KanaKanjiConverter {
@@ -29,7 +29,7 @@ import SwiftNGram
     /// zenz model for Zenzai
     private var zenz: Zenz? = nil
     private var zenzaiCache: Kana2Kanji.ZenzaiCache? = nil
-    private var zenzaiPersonalization: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?
+    private var zenzaiPersonalization: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: EfficientNGram, personal: EfficientNGram)?
     public private(set) var zenzStatus: String = ""

     /// Function that resets the state
@@ -43,7 +43,7 @@ import SwiftNGram
         self.lastData = nil
     }

-    private func getZenzaiPersonalization(mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?) -> (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)? {
+    private func getZenzaiPersonalization(mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode?) -> (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: EfficientNGram, personal: EfficientNGram)? {
         guard let mode else {
             return nil
         }
@@ -51,8 +51,8 @@ import SwiftNGram
             return zenzaiPersonalization
         }
         let tokenizer = ZenzTokenizer()
-        let baseModel = LM(baseFilename: mode.baseNgramLanguageModel, n: mode.n, d: mode.d, tokenizer: tokenizer)
-        let personalModel = LM(baseFilename: mode.personalNgramLanguageModel, n: mode.n, d: mode.d, tokenizer: tokenizer)
+        let baseModel = EfficientNGram(baseFilename: mode.baseNgramLanguageModel, n: mode.n, d: mode.d, tokenizer: tokenizer)
+        let personalModel = EfficientNGram(baseFilename: mode.personalNgramLanguageModel, n: mode.n, d: mode.d, tokenizer: tokenizer)
         self.zenzaiPersonalization = (mode, baseModel, personalModel)
         return (mode, baseModel, personalModel)
     }
In `Kana2Kanji` (Zenzai conversion):

@@ -1,6 +1,6 @@
 import Foundation
 import SwiftUtils
-import SwiftNGram
+import EfficientNGram

 extension Kana2Kanji {
     struct ZenzaiCache: Sendable {
@@ -58,7 +58,7 @@ extension Kana2Kanji {
         zenzaiCache: ZenzaiCache?,
         inferenceLimit: Int,
         requestRichCandidates: Bool,
-        personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?,
+        personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: EfficientNGram, personal: EfficientNGram)?,
         versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
     ) -> (result: LatticeNode, nodes: Nodes, cache: ZenzaiCache) {
         var constraint = zenzaiCache?.getNewConstraint(for: inputData) ?? PrefixConstraint([])
In `Zenz`:

@@ -1,6 +1,6 @@
 import Foundation
 import SwiftUtils
-import SwiftNGram
+import EfficientNGram

 @MainActor package final class Zenz {
     package var resourceURL: URL
@@ -36,7 +36,7 @@ import SwiftNGram
         convertTarget: String,
         candidates: [Candidate],
         requestRichCandidates: Bool,
-        personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?,
+        personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: EfficientNGram, personal: EfficientNGram)?,
         versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
     ) -> ZenzContext.CandidateEvaluationResult {
         guard let zenzContext else {
In `ZenzContext`:

@@ -7,7 +7,7 @@ import SwiftUtils
 import HeapModule
 import Algorithms
 import Foundation
-import SwiftNGram
+import EfficientNGram

 struct FixedSizeHeap<Element: Comparable> {
     private var size: Int
@@ -290,7 +290,7 @@ final class ZenzContext {
         input: String,
         candidate: Candidate,
         requestRichCandidates: Bool,
-        personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: LM, personal: LM)?,
+        personalizationMode: (mode: ConvertRequestOptions.ZenzaiMode.PersonalizationMode, base: EfficientNGram, personal: EfficientNGram)?,
         versionDependentConfig: ConvertRequestOptions.ZenzaiVersionDependentMode
     ) -> CandidateEvaluationResult {
         print("Evaluate", candidate)
Tests/EfficientNGramTests/EfficientNGramTests.swift (new file, 14 lines)

import XCTest
@testable import EfficientNGram
import Tokenizers

class SwiftNGramTests: XCTestCase {
    #if canImport(SwiftyMarisa)
    func testTokenizers() throws {
        let tokenizer = ZenzTokenizer()
        let inputIds = tokenizer.encode(text: "これは日本語です")
        XCTAssertEqual(inputIds, [268, 262, 253, 304, 358, 698, 246, 255])
        XCTAssertEqual(tokenizer.decode(tokens: inputIds), "これは日本語です")
    }
    #endif
}