mirror of
https://github.com/mii443/ncb-tts-r2.git
synced 2025-08-22 16:15:29 +00:00
Compare commits
106 Commits
Author | SHA1 | Date | |
---|---|---|---|
43cce7dc31 | |||
2f06f6be3b | |||
f0327e232a | |||
733646b6b8 | |||
9e7d89eaa5 | |||
ea93d1f8ac | |||
f9f90ab63e | |||
e188f3b758 | |||
0bea81aa6e | |||
7ed3f10543 | |||
c4819bb633 | |||
d382a045d0 | |||
ffa49c18cc | |||
48a4977208 | |||
32d9782cff | |||
248c88a7a2 | |||
65db668e2a | |||
879644f30c | |||
8ce6dbf57a | |||
4d658b6671 | |||
e606b29f81 | |||
376b8a6882 | |||
b452d4609a | |||
599b7fb0d1 | |||
4c76fd037a | |||
83fde399b2 | |||
91044c7c25 | |||
a13994c37e | |||
97ae9dd9e0 | |||
f7e08b4e2e | |||
257c8511e3 | |||
771711e3bf | |||
40e6194942 | |||
9949d501b5 | |||
24de817d6f | |||
8a1fa22074 | |||
24c609aaf2 | |||
51f008cfdc | |||
4c176935e3 | |||
55ea223f69 | |||
696954395b | |||
82e3c55fd5 | |||
af83f6b6e0 | |||
c68e533133 | |||
77b4c3e04d | |||
8db4d65042 | |||
e12dbd7375 | |||
cff30e7471 | |||
bf4b160af7 | |||
b4de0f1ad6 | |||
1975b2e9cd | |||
5ee3c9b328 | |||
df46152a12 | |||
1830029231 | |||
e4dbedcbe7 | |||
f11718cc8b | |||
8ef5530524 | |||
89f66aefb0 | |||
68e96ef784 | |||
883a54f70a | |||
60255d7582 | |||
68d49772e7 | |||
0fbe068f1e | |||
c39800da18 | |||
5aa5e09bb7 | |||
ec47a6f521 | |||
bb4d0a0504 | |||
4630883b28 | |||
2249e8c213 | |||
f9ebd8a430 | |||
8a9817a449 | |||
4c5b9cb345 | |||
52f86a6c16 | |||
b62e81dd66 | |||
60770f65b6 | |||
f93701a591 | |||
2a40e9ee16 | |||
2f2b82857f | |||
8b2574e90b | |||
fb654229b0 | |||
7124930a9d | |||
0bd051e48c | |||
ccf2c63224 | |||
c2a84be3a4 | |||
d525636060 | |||
68a73a4dae | |||
f7b7071a09 | |||
b7a4da7f3e | |||
708b6fc429 | |||
f9fd0686a7 | |||
99d8ef9bef | |||
af01576a99 | |||
ddab474d67 | |||
d10bfcc333 | |||
1470612d8b | |||
065717839b | |||
6dafc66878 | |||
1789bd7c4e | |||
0b93c23e91 | |||
7fe65bc397 | |||
51c39036c6 | |||
c52429bce0 | |||
5ca5325fbd | |||
4161acbd45 | |||
47b11262e2 | |||
b36bee8be8 |
2
.dockerignore
Normal file
2
.dockerignore
Normal file
@ -0,0 +1,2 @@
|
||||
target
|
||||
audio
|
21
.github/workflows/build.yml
vendored
21
.github/workflows/build.yml
vendored
@ -6,24 +6,29 @@ on:
|
||||
- 'v*'
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: docker/metadata-action@v3
|
||||
- uses: actions/checkout@v4
|
||||
name: Checkout
|
||||
- uses: docker/metadata-action@v4
|
||||
id: meta
|
||||
with:
|
||||
images: ghcr.io/morioka22/ncb-tts-r2
|
||||
images: ghcr.io/mii443/ncb-tts-r2
|
||||
tags: |
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
- uses: docker/login-action@v1
|
||||
- uses: docker/login-action@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: morioka22
|
||||
username: mii443
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- uses: docker/build-push-action@v2
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
|
||||
- uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
platforms: linux/amd64
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -3,4 +3,5 @@ Cargo.lock
|
||||
config.toml
|
||||
credentials.json
|
||||
/audio
|
||||
*.mp3
|
||||
*.mp3
|
||||
*.swp
|
||||
|
71
Cargo.toml
71
Cargo.toml
@ -1,33 +1,82 @@
|
||||
[package]
|
||||
name = "ncb-tts-r2"
|
||||
version = "0.1.0"
|
||||
version = "1.11.2"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
name = "ncb_tts_r2"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "ncb-tts-r2"
|
||||
path = "src/main.rs"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
serde_json = "1.0"
|
||||
serde = "1.0"
|
||||
toml = "0.5"
|
||||
toml = "0.8.19"
|
||||
gcp_auth = "0.5.0"
|
||||
reqwest = { version = "0.11", features = ["json"] }
|
||||
base64 = "0.13"
|
||||
reqwest = { version = "0.12.9", features = ["json"] }
|
||||
base64 = "0.22.1"
|
||||
async-trait = "0.1.57"
|
||||
redis = "*"
|
||||
redis = { version = "0.29.2", features = ["aio", "tokio-comp"] }
|
||||
bb8 = "0.8"
|
||||
bb8-redis = "0.16"
|
||||
thiserror = "1.0"
|
||||
regex = "1"
|
||||
tracing-subscriber = "0.3.19"
|
||||
lru = "0.13.0"
|
||||
once_cell = "1.19"
|
||||
bincode = "1.3"
|
||||
tracing = "0.1.41"
|
||||
opentelemetry_sdk = { version = "0.29.0", features = ["trace"] }
|
||||
opentelemetry = "0.29.1"
|
||||
opentelemetry-semantic-conventions = "0.29.0"
|
||||
opentelemetry-otlp = { version = "0.29.0", features = ["grpc-tonic"] }
|
||||
opentelemetry-stdout = "0.29.0"
|
||||
tracing-opentelemetry = "0.30.0"
|
||||
symphonia-core = "0.5.4"
|
||||
tokio-util = { version = "0.7.14", features = ["compat"] }
|
||||
futures = "0.3.31"
|
||||
bytes = "1.10.1"
|
||||
voicevox-client = { git = "https://github.com/mii443/rust" }
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "0.8"
|
||||
version = "1.11.0"
|
||||
features = ["serde", "v4"]
|
||||
|
||||
[dependencies.songbird]
|
||||
version = "0.2.0"
|
||||
version = "0.5"
|
||||
features = ["builtin-queue"]
|
||||
|
||||
[dependencies.serenity]
|
||||
version = "0.10.9"
|
||||
features = ["builder", "cache", "client", "gateway", "model", "utils", "unstable_discord_api", "collector", "rustls_backend", "framework", "voice"]
|
||||
[dependencies.symphonia]
|
||||
version = "0.5"
|
||||
features = ["mp3"]
|
||||
|
||||
[dependencies.serenity]
|
||||
version = "0.12"
|
||||
features = [
|
||||
"builder",
|
||||
"cache",
|
||||
"client",
|
||||
"gateway",
|
||||
"model",
|
||||
"utils",
|
||||
"unstable_discord_api",
|
||||
"collector",
|
||||
"rustls_backend",
|
||||
"framework",
|
||||
"voice",
|
||||
]
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1.0"
|
||||
features = ["macros", "rt-multi-thread"]
|
||||
features = ["macros", "rt-multi-thread"]
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4"
|
||||
mockall = "0.12"
|
||||
tempfile = "3.8"
|
||||
serial_test = "3.0"
|
||||
|
61
Dockerfile
61
Dockerfile
@ -1,18 +1,45 @@
|
||||
FROM ubuntu:22.04
|
||||
WORKDIR /usr/src/ncb-tts-r2
|
||||
COPY Cargo.toml .
|
||||
COPY src src
|
||||
ENV PATH $PATH:/root/.cargo/bin/
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ffmpeg libssl-dev pkg-config libopus-dev wget curl gcc \
|
||||
&& apt-get -y clean \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable \
|
||||
&& rustup install stable \
|
||||
&& cargo build --release \
|
||||
&& cp /usr/src/ncb-tts-r2/target/release/ncb-tts-r2 /usr/bin/ncb-tts-r2 \
|
||||
&& mkdir -p /ncb-tts-r2/audio \
|
||||
&& apt-get purge -y pkg-config wget curl gcc \
|
||||
&& rustup self uninstall -y
|
||||
FROM lukemathwalker/cargo-chef:latest-rust-1.82 AS chef
|
||||
WORKDIR /app
|
||||
|
||||
FROM chef AS planner
|
||||
COPY . .
|
||||
RUN cargo chef prepare --recipe-path recipe.json
|
||||
|
||||
FROM chef AS builder
|
||||
COPY --from=planner /app/recipe.json recipe.json
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
ffmpeg \
|
||||
libssl-dev \
|
||||
pkg-config \
|
||||
libopus-dev \
|
||||
gcc && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
RUN cargo chef cook --release --recipe-path recipe.json
|
||||
COPY . .
|
||||
RUN cargo build --release
|
||||
|
||||
FROM ubuntu:22.04 AS runtime
|
||||
WORKDIR /ncb-tts-r2
|
||||
CMD ["ncb-tts-r2"]
|
||||
|
||||
# 非rootユーザーの作成
|
||||
RUN groupadd -r appgroup && useradd -r -g appgroup appuser
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
openssl \
|
||||
ca-certificates \
|
||||
ffmpeg \
|
||||
libssl-dev \
|
||||
libopus-dev && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=builder /app/target/release/ncb-tts-r2 /usr/local/bin/ncb-tts-r2
|
||||
RUN chmod +x /usr/local/bin/ncb-tts-r2
|
||||
|
||||
# 非rootユーザーに切り替え
|
||||
USER appuser
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/ncb-tts-r2"]
|
||||
|
14
docker-compose.yml
Normal file
14
docker-compose.yml
Normal file
@ -0,0 +1,14 @@
|
||||
version: '3'
|
||||
|
||||
services:
|
||||
ncb-tts-r2:
|
||||
container_name: ncb-tts-r2
|
||||
image: ghcr.io/mii443/ncb-tts-r2:1.11.2
|
||||
environment:
|
||||
- NCB_TOKEN=YOUR_BOT_TOKEN
|
||||
- NCB_APP_ID=YOUR_BOT_ID
|
||||
- NCB_PREFIX=BOT_PREFIX
|
||||
- NCB_REDIS_URL=redis://<REDIS_IP>/
|
||||
- NCB_VOICEVOX_KEY=VOICEVOX_KEY
|
||||
volumes:
|
||||
- ./credentials.json:/ncb-tts-r2/credentials.json:ro
|
56
manifest/ncb-tts.yaml
Normal file
56
manifest/ncb-tts.yaml
Normal file
@ -0,0 +1,56 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: ncb-tts-deployment
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: ncb-tts
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: ncb-tts
|
||||
spec:
|
||||
containers:
|
||||
- name: redis
|
||||
image: redis:7.0.4-alpine
|
||||
ports:
|
||||
- containerPort: 6379
|
||||
name: ncb-redis
|
||||
volumeMounts:
|
||||
- name: ncb-redis-pvc
|
||||
mountPath: /data
|
||||
- name: tts
|
||||
image: ghcr.io/mii443/ncb-tts-r2
|
||||
volumeMounts:
|
||||
- name: gcp-credentials
|
||||
mountPath: /ncb-tts-r2/credentials.json
|
||||
subPath: credentials.json
|
||||
env:
|
||||
- name: NCB_REDIS_URL
|
||||
value: "redis://localhost:6379/"
|
||||
- name: NCB_PREFIX
|
||||
value: "t2!"
|
||||
- name: NCB_TOKEN
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: ncb-secret
|
||||
key: BOT_TOKEN
|
||||
- name: NCB_VOICEVOX_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: ncb-secret
|
||||
key: VOICEVOX_KEY
|
||||
- name: NCB_APP_ID
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: ncb-secret
|
||||
key: APP_ID
|
||||
volumes:
|
||||
- name: ncb-redis-pvc
|
||||
persistentVolumeClaim:
|
||||
claimName: ncb-redis-pvc
|
||||
- name: gcp-credentials
|
||||
secret:
|
||||
secretName: gcp-credentials
|
12
manifest/pvc.yaml
Normal file
12
manifest/pvc.yaml
Normal file
@ -0,0 +1,12 @@
|
||||
apiVersion: v1
|
||||
kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: ncb-redis-pvc
|
||||
labels:
|
||||
app: ncb-redis
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
resources:
|
||||
requests:
|
||||
storage: 3Gi
|
107
src/commands/config.rs
Normal file
107
src/commands/config.rs
Normal file
@ -0,0 +1,107 @@
|
||||
use serenity::{
|
||||
all::{
|
||||
ButtonStyle, CommandInteraction, CreateActionRow, CreateButton, CreateInteractionResponse,
|
||||
CreateInteractionResponseMessage, CreateSelectMenu, CreateSelectMenuKind,
|
||||
CreateSelectMenuOption,
|
||||
},
|
||||
prelude::Context,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
data::{DatabaseClientData, TTSClientData},
|
||||
tts::tts_type::TTSType,
|
||||
};
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn config_command(
|
||||
ctx: &Context,
|
||||
command: &CommandInteraction,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let data_read = ctx.data.read().await;
|
||||
|
||||
let config = {
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
database
|
||||
.get_user_config_or_default(command.user.id.get())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let tts_client = data_read
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientData");
|
||||
let voicevox_speakers = tts_client
|
||||
.voicevox_client
|
||||
.get_styles()
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::error!("Failed to get VOICEVOX styles: {}", e);
|
||||
vec![("VOICEVOX API unavailable".to_string(), 1)]
|
||||
});
|
||||
|
||||
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
|
||||
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
|
||||
|
||||
let engine_select = CreateActionRow::SelectMenu(
|
||||
CreateSelectMenu::new(
|
||||
"TTS_CONFIG_ENGINE",
|
||||
CreateSelectMenuKind::String {
|
||||
options: vec![
|
||||
CreateSelectMenuOption::new("Google TTS", "TTS_CONFIG_ENGINE_SELECTED_GOOGLE")
|
||||
.default_selection(tts_type == TTSType::GCP),
|
||||
CreateSelectMenuOption::new("VOICEVOX", "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX")
|
||||
.default_selection(tts_type == TTSType::VOICEVOX),
|
||||
],
|
||||
},
|
||||
)
|
||||
.placeholder("読み上げAPIを選択"),
|
||||
);
|
||||
|
||||
let mut components = vec![engine_select];
|
||||
|
||||
for (index, speaker_chunk) in voicevox_speakers[0..24].chunks(25).enumerate() {
|
||||
let mut options = Vec::new();
|
||||
|
||||
for (name, id) in speaker_chunk {
|
||||
options.push(
|
||||
CreateSelectMenuOption::new(
|
||||
name,
|
||||
format!("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_{}", id),
|
||||
)
|
||||
.default_selection(*id == voicevox_speaker),
|
||||
);
|
||||
}
|
||||
|
||||
components.push(CreateActionRow::SelectMenu(
|
||||
CreateSelectMenu::new(
|
||||
format!("TTS_CONFIG_VOICEVOX_SPEAKER_{}", index),
|
||||
CreateSelectMenuKind::String { options },
|
||||
)
|
||||
.placeholder("VOICEVOX Speakerを指定"),
|
||||
));
|
||||
}
|
||||
|
||||
let server_button = CreateActionRow::Buttons(vec![CreateButton::new("TTS_CONFIG_SERVER")
|
||||
.label("サーバー設定")
|
||||
.style(ButtonStyle::Primary)]);
|
||||
|
||||
components.push(server_button);
|
||||
|
||||
command
|
||||
.create_response(
|
||||
&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("読み上げ設定")
|
||||
.components(components)
|
||||
.ephemeral(true),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
4
src/commands/mod.rs
Normal file
4
src/commands/mod.rs
Normal file
@ -0,0 +1,4 @@
|
||||
pub mod config;
|
||||
pub mod setup;
|
||||
pub mod skip;
|
||||
pub mod stop;
|
186
src/commands/setup.rs
Normal file
186
src/commands/setup.rs
Normal file
@ -0,0 +1,186 @@
|
||||
use serenity::{
|
||||
all::{
|
||||
AutoArchiveDuration, CommandInteraction, CreateEmbed, CreateInteractionResponse, CreateInteractionResponseMessage, CreateMessage, CreateThread
|
||||
},
|
||||
model::prelude::UserId,
|
||||
prelude::Context,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
data::{DatabaseClientData, TTSClientData, TTSData},
|
||||
tts::instance::TTSInstance,
|
||||
};
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn setup_command(
|
||||
ctx: &Context,
|
||||
command: &CommandInteraction,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
info!("Received event");
|
||||
|
||||
if command.guild_id.is_none() {
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("このコマンドはサーバーでのみ使用可能です.")
|
||||
.ephemeral(true)
|
||||
))
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Fetching guild cache");
|
||||
let guild_id = command.guild_id.unwrap();
|
||||
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
|
||||
|
||||
let channel_id = guild
|
||||
.voice_states
|
||||
.get(&UserId::from(command.user.id.get()))
|
||||
.and_then(|state| state.channel_id);
|
||||
|
||||
if channel_id.is_none() {
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("ボイスチャンネルに参加してから実行してください.")
|
||||
.ephemeral(true)
|
||||
))
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let channel_id = channel_id.unwrap();
|
||||
|
||||
let manager = songbird::get(ctx)
|
||||
.await
|
||||
.expect("Cannot get songbird client.")
|
||||
.clone();
|
||||
|
||||
let storage_lock = {
|
||||
let data_read = ctx.data.read().await;
|
||||
data_read
|
||||
.get::<TTSData>()
|
||||
.expect("Cannot get TTSStorage")
|
||||
.clone()
|
||||
};
|
||||
|
||||
let text_channel_id = {
|
||||
let mut storage = storage_lock.write().await;
|
||||
if storage.contains_key(&guild.id) {
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("すでにセットアップしています.")
|
||||
.ephemeral(true)
|
||||
))
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let text_channel_ids = {
|
||||
if let Some(mode) = command.data.options.get(0) {
|
||||
match &mode.value {
|
||||
serenity::all::CommandDataOptionValue::String(value) => {
|
||||
match value.as_str() {
|
||||
"TEXT_CHANNEL" => vec![command.channel_id],
|
||||
"NEW_THREAD" => {
|
||||
vec![command
|
||||
.channel_id
|
||||
.create_thread(&ctx.http, CreateThread::new("TTS").auto_archive_duration(AutoArchiveDuration::OneHour).kind(serenity::all::ChannelType::PublicThread))
|
||||
.await
|
||||
.unwrap()
|
||||
.id]
|
||||
}
|
||||
"VOICE_CHANNEL" => vec![channel_id],
|
||||
_ => if channel_id != command.channel_id {
|
||||
vec![command.channel_id, channel_id]
|
||||
} else {
|
||||
vec![channel_id]
|
||||
},
|
||||
}
|
||||
},
|
||||
_ => if channel_id != command.channel_id {
|
||||
vec![command.channel_id, channel_id]
|
||||
} else {
|
||||
vec![channel_id]
|
||||
},
|
||||
}
|
||||
} else {
|
||||
if channel_id != command.channel_id {
|
||||
vec![command.channel_id, channel_id]
|
||||
} else {
|
||||
vec![channel_id]
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let instance = TTSInstance::new(text_channel_ids.clone(), channel_id, guild.id);
|
||||
storage.insert(guild.id, instance.clone());
|
||||
|
||||
// Save to database
|
||||
let data_read = ctx.data.read().await;
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
drop(data_read);
|
||||
|
||||
if let Err(e) = database.save_tts_instance(guild.id, &instance).await {
|
||||
tracing::error!("Failed to save TTS instance to database: {}", e);
|
||||
}
|
||||
|
||||
text_channel_ids[0]
|
||||
};
|
||||
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content(format!(
|
||||
"TTS Channel: <#{}>{}",
|
||||
text_channel_id,
|
||||
if text_channel_id == channel_id {
|
||||
"\nボイスチャンネルを右クリックし `チャットを開く` を押して開くことが出来ます。"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
))
|
||||
))
|
||||
.await?;
|
||||
|
||||
let _handler = manager.join(guild.id, channel_id).await;
|
||||
|
||||
let data = ctx
|
||||
.data
|
||||
.read()
|
||||
.await;
|
||||
let tts_client = data
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientData");
|
||||
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::error!("Failed to get VOICEVOX speakers: {}", e);
|
||||
vec!["VOICEVOX API unavailable".to_string()]
|
||||
});
|
||||
|
||||
text_channel_id
|
||||
.send_message(&ctx.http, CreateMessage::new()
|
||||
.embed(
|
||||
CreateEmbed::new()
|
||||
.title("読み上げ (Serenity)")
|
||||
.field(
|
||||
"VOICEVOXクレジット",
|
||||
format!("```\n{}\n```", voicevox_speakers.join("\n")),
|
||||
false,
|
||||
)
|
||||
.field("設定コマンド", "`/config`", false)
|
||||
.field("フィードバック", "https://feedback.mii.codes/", false)
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
81
src/commands/skip.rs
Normal file
81
src/commands/skip.rs
Normal file
@ -0,0 +1,81 @@
|
||||
use serenity::{
|
||||
all::{
|
||||
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage
|
||||
},
|
||||
model::prelude::UserId,
|
||||
prelude::Context,
|
||||
};
|
||||
|
||||
use crate::data::TTSData;
|
||||
|
||||
pub async fn skip_command(
|
||||
ctx: &Context,
|
||||
command: &CommandInteraction,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if command.guild_id.is_none() {
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("このコマンドはサーバーでのみ使用可能です.")
|
||||
.ephemeral(true)
|
||||
))
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let guild_id = command.guild_id.unwrap();
|
||||
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
|
||||
|
||||
let channel_id = guild
|
||||
.voice_states
|
||||
.get(&UserId::from(command.user.id.get()))
|
||||
.and_then(|state| state.channel_id);
|
||||
|
||||
if channel_id.is_none() {
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("ボイスチャンネルに参加してから実行してください.")
|
||||
.ephemeral(true)
|
||||
))
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let storage_lock = {
|
||||
let data_read = ctx.data.read().await;
|
||||
data_read
|
||||
.get::<TTSData>()
|
||||
.expect("Cannot get TTSStorage")
|
||||
.clone()
|
||||
};
|
||||
|
||||
{
|
||||
let mut storage = storage_lock.write().await;
|
||||
if !storage.contains_key(&guild.id) {
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("読み上げしていません")
|
||||
.ephemeral(true)
|
||||
))
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
storage.get_mut(&guild.id).unwrap().skip(ctx).await;
|
||||
}
|
||||
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("スキップしました")
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
115
src/commands/stop.rs
Normal file
115
src/commands/stop.rs
Normal file
@ -0,0 +1,115 @@
|
||||
use serenity::{
|
||||
all::{
|
||||
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditThread,
|
||||
},
|
||||
model::prelude::UserId,
|
||||
prelude::Context,
|
||||
};
|
||||
|
||||
use crate::data::{DatabaseClientData, TTSData};
|
||||
|
||||
pub async fn stop_command(
|
||||
ctx: &Context,
|
||||
command: &CommandInteraction,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if command.guild_id.is_none() {
|
||||
command
|
||||
.create_response(
|
||||
&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("このコマンドはサーバーでのみ使用可能です.")
|
||||
.ephemeral(true),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let guild_id = command.guild_id.unwrap();
|
||||
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
|
||||
|
||||
let channel_id = guild
|
||||
.voice_states
|
||||
.get(&UserId::from(command.user.id.get()))
|
||||
.and_then(|state| state.channel_id);
|
||||
|
||||
if channel_id.is_none() {
|
||||
command
|
||||
.create_response(
|
||||
&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("ボイスチャンネルに参加してから実行してください.")
|
||||
.ephemeral(true),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let manager = songbird::get(ctx)
|
||||
.await
|
||||
.expect("Cannot get songbird client.")
|
||||
.clone();
|
||||
|
||||
let storage_lock = {
|
||||
let data_read = ctx.data.read().await;
|
||||
data_read
|
||||
.get::<TTSData>()
|
||||
.expect("Cannot get TTSStorage")
|
||||
.clone()
|
||||
};
|
||||
|
||||
let text_channel_id = {
|
||||
let mut storage = storage_lock.write().await;
|
||||
|
||||
if !storage.contains_key(&guild.id) {
|
||||
command
|
||||
.create_response(
|
||||
&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("すでに停止しています")
|
||||
.ephemeral(true),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let text_channel_id = storage.get(&guild.id).unwrap().text_channels[0];
|
||||
storage.remove(&guild.id);
|
||||
|
||||
// Remove from database
|
||||
let data_read = ctx.data.read().await;
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
drop(data_read);
|
||||
|
||||
if let Err(e) = database.remove_tts_instance(guild.id).await {
|
||||
tracing::error!("Failed to remove TTS instance from database: {}", e);
|
||||
}
|
||||
|
||||
text_channel_id
|
||||
};
|
||||
|
||||
let _handler = manager.remove(guild.id).await;
|
||||
|
||||
command
|
||||
.create_response(
|
||||
&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new().content("停止しました"),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let _ = text_channel_id
|
||||
.edit_thread(&ctx.http, EditThread::new().archived(true))
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
@ -6,5 +6,7 @@ pub struct Config {
|
||||
pub token: String,
|
||||
pub application_id: u64,
|
||||
pub redis_url: String,
|
||||
pub voicevox_key: String
|
||||
}
|
||||
pub voicevox_key: Option<String>,
|
||||
pub voicevox_original_api_url: Option<String>,
|
||||
pub otel_http_url: Option<String>,
|
||||
}
|
||||
|
273
src/connection_monitor.rs
Normal file
273
src/connection_monitor.rs
Normal file
@ -0,0 +1,273 @@
|
||||
use serenity::{
|
||||
all::{CreateEmbed, CreateMessage},
|
||||
prelude::Context,
|
||||
};
|
||||
use std::time::Duration;
|
||||
use tokio::time;
|
||||
use tracing::{error, info, instrument, warn};
|
||||
|
||||
use crate::data::{DatabaseClientData, TTSData};
|
||||
|
||||
/// Constants for connection monitoring
|
||||
const CONNECTION_CHECK_INTERVAL_SECS: u64 = 5;
|
||||
const MAX_RECONNECTION_ATTEMPTS: u32 = 3;
|
||||
const RECONNECTION_BACKOFF_SECS: u64 = 2;
|
||||
|
||||
/// Errors that can occur during connection monitoring
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ConnectionMonitorError {
|
||||
#[error("Failed to get songbird manager")]
|
||||
SongbirdManagerNotFound,
|
||||
#[error("Failed to check voice channel users: {0}")]
|
||||
VoiceChannelCheck(String),
|
||||
#[error("Failed to reconnect after {attempts} attempts")]
|
||||
ReconnectionFailed { attempts: u32 },
|
||||
#[error("Database operation failed: {0}")]
|
||||
Database(#[from] redis::RedisError),
|
||||
}
|
||||
|
||||
type Result<T> = std::result::Result<T, ConnectionMonitorError>;
|
||||
|
||||
/// Connection monitor that periodically checks voice channel connections
|
||||
pub struct ConnectionMonitor {
|
||||
reconnection_attempts: std::collections::HashMap<serenity::model::id::GuildId, u32>,
|
||||
}
|
||||
|
||||
impl Default for ConnectionMonitor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionMonitor {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
reconnection_attempts: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the connection monitoring task
|
||||
pub fn start(ctx: Context) {
|
||||
tokio::spawn(async move {
|
||||
let mut monitor = ConnectionMonitor::new();
|
||||
info!(
|
||||
interval_secs = CONNECTION_CHECK_INTERVAL_SECS,
|
||||
"Starting connection monitor"
|
||||
);
|
||||
let mut interval = time::interval(Duration::from_secs(CONNECTION_CHECK_INTERVAL_SECS));
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if let Err(e) = monitor.check_connections(&ctx).await {
|
||||
error!(error = %e, "Connection monitoring failed");
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Check all active TTS instances and their voice channel connections
|
||||
#[instrument(skip(self, ctx))]
|
||||
async fn check_connections(&mut self, ctx: &Context) -> Result<()> {
|
||||
let storage_lock = {
|
||||
let data_read = ctx.data.read().await;
|
||||
data_read
|
||||
.get::<TTSData>()
|
||||
.ok_or_else(|| {
|
||||
ConnectionMonitorError::VoiceChannelCheck("Cannot get TTSStorage".to_string())
|
||||
})?
|
||||
.clone()
|
||||
};
|
||||
|
||||
let database = {
|
||||
let data_read = ctx.data.read().await;
|
||||
data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.ok_or_else(|| {
|
||||
ConnectionMonitorError::VoiceChannelCheck(
|
||||
"Cannot get DatabaseClientData".to_string(),
|
||||
)
|
||||
})?
|
||||
.clone()
|
||||
};
|
||||
|
||||
let mut storage = storage_lock.write().await;
|
||||
let mut guilds_to_remove = Vec::new();
|
||||
|
||||
for (guild_id, instance) in storage.iter() {
|
||||
// Check if bot is still connected to voice channel
|
||||
let manager = songbird::get(ctx)
|
||||
.await
|
||||
.ok_or(ConnectionMonitorError::SongbirdManagerNotFound)?;
|
||||
|
||||
let call = manager.get(*guild_id);
|
||||
let is_connected = if let Some(call) = call {
|
||||
if let Some(connection) = call.lock().await.current_connection() {
|
||||
connection.channel_id.is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if !is_connected {
|
||||
warn!(guild_id = %guild_id, "Bot disconnected from voice channel");
|
||||
|
||||
// Check if there are users in the voice channel
|
||||
let should_reconnect = match self.check_voice_channel_users(ctx, instance).await {
|
||||
Ok(has_users) => has_users,
|
||||
Err(e) => {
|
||||
warn!(guild_id = %guild_id, error = %e, "Failed to check voice channel users, skipping reconnection");
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if should_reconnect {
|
||||
// Try to reconnect with retry logic
|
||||
let attempts = self
|
||||
.reconnection_attempts
|
||||
.get(guild_id)
|
||||
.copied()
|
||||
.unwrap_or(0);
|
||||
|
||||
if attempts >= MAX_RECONNECTION_ATTEMPTS {
|
||||
error!(
|
||||
guild_id = %guild_id,
|
||||
attempts = attempts,
|
||||
"Maximum reconnection attempts reached, removing instance"
|
||||
);
|
||||
guilds_to_remove.push(*guild_id);
|
||||
self.reconnection_attempts.remove(guild_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Apply exponential backoff
|
||||
if attempts > 0 {
|
||||
let backoff_duration =
|
||||
Duration::from_secs(RECONNECTION_BACKOFF_SECS * (2_u64.pow(attempts)));
|
||||
warn!(
|
||||
guild_id = %guild_id,
|
||||
attempt = attempts + 1,
|
||||
backoff_secs = backoff_duration.as_secs(),
|
||||
"Applying backoff before reconnection attempt"
|
||||
);
|
||||
tokio::time::sleep(backoff_duration).await;
|
||||
}
|
||||
|
||||
match instance.reconnect(ctx, true).await {
|
||||
Ok(_) => {
|
||||
info!(
|
||||
guild_id = %guild_id,
|
||||
attempts = attempts + 1,
|
||||
"Successfully reconnected to voice channel"
|
||||
);
|
||||
|
||||
// Reset reconnection attempts on success
|
||||
self.reconnection_attempts.remove(guild_id);
|
||||
|
||||
// Send notification message to text channel with embed
|
||||
let embed = CreateEmbed::new()
|
||||
.title("🔄 自動再接続しました")
|
||||
.description("読み上げを停止したい場合は `/stop` コマンドを使用してください。")
|
||||
.color(0x00ff00);
|
||||
|
||||
// Send message to the first text channel
|
||||
if let Some(&text_channel) = instance.text_channels.first() {
|
||||
if let Err(e) = text_channel
|
||||
.send_message(&ctx.http, CreateMessage::new().embed(embed))
|
||||
.await
|
||||
{
|
||||
error!(guild_id = %guild_id, error = %e, "Failed to send reconnection message");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let new_attempts = attempts + 1;
|
||||
self.reconnection_attempts.insert(*guild_id, new_attempts);
|
||||
error!(
|
||||
guild_id = %guild_id,
|
||||
attempt = new_attempts,
|
||||
error = %e,
|
||||
"Failed to reconnect to voice channel"
|
||||
);
|
||||
|
||||
if new_attempts >= MAX_RECONNECTION_ATTEMPTS {
|
||||
guilds_to_remove.push(*guild_id);
|
||||
self.reconnection_attempts.remove(guild_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!(
|
||||
guild_id = %guild_id,
|
||||
"No users in voice channel, removing instance"
|
||||
);
|
||||
guilds_to_remove.push(*guild_id);
|
||||
self.reconnection_attempts.remove(guild_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove disconnected instances
|
||||
for guild_id in guilds_to_remove {
|
||||
storage.remove(&guild_id);
|
||||
|
||||
// Remove from database
|
||||
if let Err(e) = database.remove_tts_instance(guild_id).await {
|
||||
error!(guild_id = %guild_id, error = %e, "Failed to remove TTS instance from database");
|
||||
}
|
||||
|
||||
// Ensure bot leaves voice channel
|
||||
if let Some(manager) = songbird::get(ctx).await {
|
||||
if let Err(e) = manager.remove(guild_id).await {
|
||||
error!(guild_id = %guild_id, error = %e, "Failed to remove bot from voice channel");
|
||||
}
|
||||
}
|
||||
|
||||
info!(guild_id = %guild_id, "Removed disconnected TTS instance");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if there are users in the voice channel
|
||||
#[instrument(skip(self, ctx, instance))]
|
||||
async fn check_voice_channel_users(
|
||||
&self,
|
||||
ctx: &Context,
|
||||
instance: &crate::tts::instance::TTSInstance,
|
||||
) -> Result<bool> {
|
||||
let channels = instance.guild.channels(&ctx.http).await.map_err(|e| {
|
||||
ConnectionMonitorError::VoiceChannelCheck(format!(
|
||||
"Failed to get guild channels: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
if let Some(channel) = channels.get(&instance.voice_channel) {
|
||||
let members = channel.members(&ctx.cache).map_err(|e| {
|
||||
ConnectionMonitorError::VoiceChannelCheck(format!(
|
||||
"Failed to get channel members: {}",
|
||||
e
|
||||
))
|
||||
})?;
|
||||
let user_count = members.iter().filter(|member| !member.user.bot).count();
|
||||
|
||||
info!(
|
||||
guild_id = %instance.guild,
|
||||
channel_id = %instance.voice_channel,
|
||||
user_count = user_count,
|
||||
"Checked voice channel users"
|
||||
);
|
||||
|
||||
Ok(user_count > 0)
|
||||
} else {
|
||||
warn!(
|
||||
guild_id = %instance.guild,
|
||||
channel_id = %instance.voice_channel,
|
||||
"Voice channel no longer exists"
|
||||
);
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
}
|
13
src/data.rs
13
src/data.rs
@ -1,8 +1,11 @@
|
||||
use crate::{tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX}, database::database::Database};
|
||||
use serenity::{prelude::{TypeMapKey, RwLock}, model::id::GuildId, futures::lock::Mutex};
|
||||
use crate::{database::database::Database, tts::tts::TTS};
|
||||
use serenity::{
|
||||
model::id::GuildId,
|
||||
prelude::{RwLock, TypeMapKey},
|
||||
};
|
||||
|
||||
use crate::tts::instance::TTSInstance;
|
||||
use std::{sync::Arc, collections::HashMap};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
/// TTSInstance data
|
||||
pub struct TTSData;
|
||||
@ -15,12 +18,12 @@ impl TypeMapKey for TTSData {
|
||||
pub struct TTSClientData;
|
||||
|
||||
impl TypeMapKey for TTSClientData {
|
||||
type Value = Arc<Mutex<(TTS, VOICEVOX)>>;
|
||||
type Value = Arc<TTS>;
|
||||
}
|
||||
|
||||
/// Database client data
|
||||
pub struct DatabaseClientData;
|
||||
|
||||
impl TypeMapKey for DatabaseClientData {
|
||||
type Value = Arc<Mutex<Database>>;
|
||||
type Value = Arc<Database>;
|
||||
}
|
||||
|
@ -1,60 +1,480 @@
|
||||
use crate::tts::{gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType};
|
||||
use std::fmt::Debug;
|
||||
|
||||
use super::user_config::UserConfig;
|
||||
use redis::Commands;
|
||||
use crate::{
|
||||
errors::{constants::*, NCBError, Result},
|
||||
tts::{
|
||||
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, instance::TTSInstance,
|
||||
tts_type::TTSType,
|
||||
},
|
||||
};
|
||||
use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
|
||||
use serenity::model::id::{ChannelId, GuildId, UserId};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Database {
|
||||
pub connection: redis::Connection
|
||||
pub pool: Pool<RedisConnectionManager>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
pub fn new(connection: redis::Connection) -> Self {
|
||||
Self { connection }
|
||||
pub fn new(pool: Pool<RedisConnectionManager>) -> Self {
|
||||
Self { pool }
|
||||
}
|
||||
|
||||
pub async fn get_user_config(&mut self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
|
||||
let config: String = self.connection.get(format!("discord_user:{}", user_id)).unwrap_or_default();
|
||||
pub async fn new_with_url(redis_url: String) -> Result<Self> {
|
||||
let manager = RedisConnectionManager::new(redis_url)?;
|
||||
let pool = Pool::builder()
|
||||
.max_size(15)
|
||||
.build(manager)
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool creation failed: {}", e)))?;
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
fn server_key(server_id: u64) -> String {
|
||||
format!("{}{}", DISCORD_SERVER_PREFIX, server_id)
|
||||
}
|
||||
|
||||
fn user_key(user_id: u64) -> String {
|
||||
format!("{}{}", DISCORD_USER_PREFIX, user_id)
|
||||
}
|
||||
|
||||
fn tts_instance_key(guild_id: u64) -> String {
|
||||
format!("{}{}", TTS_INSTANCE_PREFIX, guild_id)
|
||||
}
|
||||
|
||||
fn tts_instances_list_key() -> String {
|
||||
TTS_INSTANCES_LIST_KEY.to_string()
|
||||
}
|
||||
|
||||
fn user_config_key(guild_id: u64, user_id: u64) -> String {
|
||||
format!("user:config:{}:{}", guild_id, user_id)
|
||||
}
|
||||
|
||||
fn server_config_key(guild_id: u64) -> String {
|
||||
format!("server:config:{}", guild_id)
|
||||
}
|
||||
|
||||
fn dictionary_key(guild_id: u64) -> String {
|
||||
format!("dictionary:{}", guild_id)
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
async fn get_config<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||
|
||||
let config: String = connection.get(key).await.unwrap_or_default();
|
||||
|
||||
if config.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
match serde_json::from_str(&config) {
|
||||
Ok(config) => Ok(Some(config)),
|
||||
Err(_) => Ok(None)
|
||||
Err(e) => {
|
||||
tracing::warn!(key = key, error = %e, "Failed to deserialize config");
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn set_user_config(&mut self, user_id: u64, config: UserConfig) -> redis::RedisResult<()> {
|
||||
let config = serde_json::to_string(&config).unwrap();
|
||||
self.connection.set::<String, String, ()>(format!("discord_user:{}", user_id), config).unwrap();
|
||||
#[tracing::instrument]
|
||||
async fn set_config<T: serde::Serialize + Debug>(&self, key: &str, config: &T) -> Result<()> {
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||
|
||||
let config_str = serde_json::to_string(config)?;
|
||||
connection.set::<_, _, ()>(key, config_str).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set_default_user_config(&mut self, user_id: u64) -> redis::RedisResult<()> {
|
||||
#[tracing::instrument]
|
||||
pub async fn get_server_config(&self, server_id: u64) -> Result<Option<ServerConfig>> {
|
||||
self.get_config(&Self::server_key(server_id)).await
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_user_config(&self, user_id: u64) -> Result<Option<UserConfig>> {
|
||||
self.get_config(&Self::user_key(user_id)).await
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn set_server_config(&self, server_id: u64, config: ServerConfig) -> Result<()> {
|
||||
self.set_config(&Self::server_key(server_id), &config).await
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn set_user_config(&self, user_id: u64, config: UserConfig) -> Result<()> {
|
||||
self.set_config(&Self::user_key(user_id), &config).await
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn set_default_server_config(&self, server_id: u64) -> Result<()> {
|
||||
let config = ServerConfig {
|
||||
dictionary: Dictionary::new(),
|
||||
autostart_channel_id: None,
|
||||
autostart_text_channel_id: None,
|
||||
voice_state_announce: Some(false),
|
||||
read_username: Some(false),
|
||||
};
|
||||
|
||||
self.set_server_config(server_id, config).await
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn set_default_user_config(&self, user_id: u64) -> Result<()> {
|
||||
let voice_selection = VoiceSelectionParams {
|
||||
languageCode: String::from("ja-JP"),
|
||||
name: String::from("ja-JP-Wavenet-B"),
|
||||
ssmlGender: String::from("neutral")
|
||||
ssmlGender: String::from("neutral"),
|
||||
};
|
||||
|
||||
let voice_type = TTSType::GCP;
|
||||
|
||||
let config = UserConfig {
|
||||
tts_type: Some(voice_type),
|
||||
tts_type: Some(TTSType::GCP),
|
||||
gcp_tts_voice: Some(voice_selection),
|
||||
voicevox_speaker: Some(1)
|
||||
voicevox_speaker: Some(DEFAULT_VOICEVOX_SPEAKER),
|
||||
};
|
||||
|
||||
self.connection.set(format!("discord_user:{}", user_id), serde_json::to_string(&config).unwrap())?;
|
||||
|
||||
Ok(())
|
||||
self.set_user_config(user_id, config).await
|
||||
}
|
||||
|
||||
pub async fn get_user_config_or_default(&mut self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
|
||||
let config = self.get_user_config(user_id).await?;
|
||||
match config {
|
||||
Some(_) => Ok(config),
|
||||
#[tracing::instrument]
|
||||
pub async fn get_server_config_or_default(
|
||||
&self,
|
||||
server_id: u64,
|
||||
) -> Result<Option<ServerConfig>> {
|
||||
match self.get_server_config(server_id).await? {
|
||||
Some(config) => Ok(Some(config)),
|
||||
None => {
|
||||
self.set_default_server_config(server_id).await?;
|
||||
self.get_server_config(server_id).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_user_config_or_default(&self, user_id: u64) -> Result<Option<UserConfig>> {
|
||||
match self.get_user_config(user_id).await? {
|
||||
Some(config) => Ok(Some(config)),
|
||||
None => {
|
||||
self.set_default_user_config(user_id).await?;
|
||||
self.get_user_config(user_id).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Save TTS instance to database
|
||||
pub async fn save_tts_instance(&self, guild_id: GuildId, instance: &TTSInstance) -> Result<()> {
|
||||
let key = Self::tts_instance_key(guild_id.get());
|
||||
let list_key = Self::tts_instances_list_key();
|
||||
|
||||
// Save the instance
|
||||
self.set_config(&key, instance).await?;
|
||||
|
||||
// Add guild_id to the list of active instances
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||
|
||||
connection
|
||||
.sadd::<_, _, ()>(&list_key, guild_id.get())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load TTS instance from database
|
||||
#[tracing::instrument]
|
||||
pub async fn load_tts_instance(&self, guild_id: GuildId) -> Result<Option<TTSInstance>> {
|
||||
let key = Self::tts_instance_key(guild_id.get());
|
||||
self.get_config(&key).await
|
||||
}
|
||||
|
||||
/// Remove TTS instance from database
|
||||
#[tracing::instrument]
|
||||
pub async fn remove_tts_instance(&self, guild_id: GuildId) -> Result<()> {
|
||||
let key = Self::tts_instance_key(guild_id.get());
|
||||
let list_key = Self::tts_instances_list_key();
|
||||
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||
|
||||
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
|
||||
let _: std::result::Result<(), bb8_redis::redis::RedisError> =
|
||||
connection.srem(&list_key, guild_id.get()).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get all active TTS instances
|
||||
#[tracing::instrument]
|
||||
pub async fn get_all_tts_instances(&self) -> Result<Vec<(GuildId, TTSInstance)>> {
|
||||
let list_key = Self::tts_instances_list_key();
|
||||
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||
|
||||
let guild_ids: Vec<u64> = connection.smembers(&list_key).await.unwrap_or_default();
|
||||
let mut instances = Vec::new();
|
||||
|
||||
for guild_id in guild_ids {
|
||||
let guild_id = GuildId::new(guild_id);
|
||||
if let Ok(Some(instance)) = self.load_tts_instance(guild_id).await {
|
||||
instances.push((guild_id, instance));
|
||||
} else {
|
||||
tracing::warn!(guild_id = %guild_id, "Failed to load TTS instance");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(instances)
|
||||
}
|
||||
|
||||
// Additional user config methods
|
||||
pub async fn save_user_config(
|
||||
&self,
|
||||
guild_id: GuildId,
|
||||
user_id: UserId,
|
||||
config: &UserConfig,
|
||||
) -> Result<()> {
|
||||
let key = Self::user_config_key(guild_id.get(), user_id.get());
|
||||
self.set_config(&key, config).await
|
||||
}
|
||||
|
||||
pub async fn load_user_config(
|
||||
&self,
|
||||
guild_id: GuildId,
|
||||
user_id: UserId,
|
||||
) -> Result<Option<UserConfig>> {
|
||||
let key = Self::user_config_key(guild_id.get(), user_id.get());
|
||||
self.get_config(&key).await
|
||||
}
|
||||
|
||||
pub async fn delete_user_config(&self, guild_id: GuildId, user_id: UserId) -> Result<()> {
|
||||
let key = Self::user_config_key(guild_id.get(), user_id.get());
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Additional server config methods
|
||||
pub async fn save_server_config(&self, guild_id: GuildId, config: &ServerConfig) -> Result<()> {
|
||||
let key = Self::server_config_key(guild_id.get());
|
||||
self.set_config(&key, config).await
|
||||
}
|
||||
|
||||
pub async fn load_server_config(&self, guild_id: GuildId) -> Result<Option<ServerConfig>> {
|
||||
let key = Self::server_config_key(guild_id.get());
|
||||
self.get_config(&key).await
|
||||
}
|
||||
|
||||
pub async fn delete_server_config(&self, guild_id: GuildId) -> Result<()> {
|
||||
let key = Self::server_config_key(guild_id.get());
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Dictionary methods
|
||||
pub async fn save_dictionary(
|
||||
&self,
|
||||
guild_id: GuildId,
|
||||
dictionary: &HashMap<String, String>,
|
||||
) -> Result<()> {
|
||||
let key = Self::dictionary_key(guild_id.get());
|
||||
self.set_config(&key, dictionary).await
|
||||
}
|
||||
|
||||
pub async fn load_dictionary(&self, guild_id: GuildId) -> Result<HashMap<String, String>> {
|
||||
let key = Self::dictionary_key(guild_id.get());
|
||||
let dict: Option<HashMap<String, String>> = self.get_config(&key).await?;
|
||||
Ok(dict.unwrap_or_default())
|
||||
}
|
||||
|
||||
pub async fn delete_dictionary(&self, guild_id: GuildId) -> Result<()> {
|
||||
let key = Self::dictionary_key(guild_id.get());
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_tts_instance(&self, guild_id: GuildId) -> Result<()> {
|
||||
self.remove_tts_instance(guild_id).await
|
||||
}
|
||||
|
||||
pub async fn list_active_instances(&self) -> Result<Vec<u64>> {
|
||||
let list_key = Self::tts_instances_list_key();
|
||||
let mut connection = self
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||
let guild_ids: Vec<u64> = connection.smembers(&list_key).await.unwrap_or_default();
|
||||
Ok(guild_ids)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::errors::constants;
|
||||
use bb8_redis::redis::AsyncCommands;
|
||||
use serial_test::serial;
|
||||
|
||||
// Helper function to create test database (requires Redis running)
|
||||
async fn create_test_database() -> Result<Database> {
|
||||
let manager = RedisConnectionManager::new("redis://127.0.0.1:6379/15")?; // Use test DB
|
||||
let pool = bb8::Pool::builder()
|
||||
.max_size(1)
|
||||
.build(manager)
|
||||
.await
|
||||
.map_err(|e| NCBError::Database(format!("Pool creation failed: {}", e)))?;
|
||||
|
||||
Ok(Database { pool })
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_database_creation() {
|
||||
// This test requires Redis to be running
|
||||
match create_test_database().await {
|
||||
Ok(_db) => {
|
||||
// Test successful creation
|
||||
assert!(true);
|
||||
}
|
||||
Err(_) => {
|
||||
// Skip test if Redis is not available
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_key_generation() {
|
||||
let guild_id = 123456789u64;
|
||||
let user_id = 987654321u64;
|
||||
|
||||
// Test TTS instance key
|
||||
let tts_key = Database::tts_instance_key(guild_id);
|
||||
assert!(tts_key.contains(&guild_id.to_string()));
|
||||
|
||||
// Test TTS instances list key
|
||||
let list_key = Database::tts_instances_list_key();
|
||||
assert!(!list_key.is_empty());
|
||||
|
||||
// Test user config key
|
||||
let user_key = Database::user_config_key(guild_id, user_id);
|
||||
assert_eq!(user_key, "user:config:123456789:987654321");
|
||||
|
||||
// Test server config key
|
||||
let server_key = Database::server_config_key(guild_id);
|
||||
assert_eq!(server_key, "server:config:123456789");
|
||||
|
||||
// Test dictionary key
|
||||
let dict_key = Database::dictionary_key(guild_id);
|
||||
assert_eq!(dict_key, "dictionary:123456789");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_tts_instance_operations() {
|
||||
let db = match create_test_database().await {
|
||||
Ok(db) => db,
|
||||
Err(_) => return, // Skip if Redis not available
|
||||
};
|
||||
|
||||
let guild_id = GuildId::new(12345);
|
||||
let test_instance =
|
||||
TTSInstance::new_single(ChannelId::new(123), ChannelId::new(456), guild_id);
|
||||
|
||||
// Clear any existing data
|
||||
if let Ok(mut conn) = db.pool.get().await {
|
||||
let _: () = conn
|
||||
.del(Database::tts_instance_key(guild_id.get()))
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let _: () = conn
|
||||
.srem(Database::tts_instances_list_key(), guild_id.get())
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
} else {
|
||||
return; // Skip if can't get connection
|
||||
}
|
||||
|
||||
// Test saving TTS instance
|
||||
let save_result = db.save_tts_instance(guild_id, &test_instance).await;
|
||||
if save_result.is_err() {
|
||||
// Skip test if Redis operations fail
|
||||
return;
|
||||
}
|
||||
|
||||
// Test loading TTS instance
|
||||
let load_result = db.load_tts_instance(guild_id).await;
|
||||
if load_result.is_err() {
|
||||
return; // Skip if Redis operations fail
|
||||
}
|
||||
|
||||
let loaded_instance = load_result.unwrap();
|
||||
if let Some(instance) = loaded_instance {
|
||||
assert_eq!(instance.guild, test_instance.guild);
|
||||
assert_eq!(instance.text_channels, test_instance.text_channels);
|
||||
assert_eq!(instance.voice_channel, test_instance.voice_channel);
|
||||
}
|
||||
|
||||
// Test listing active instances
|
||||
let list_result = db.list_active_instances().await;
|
||||
if list_result.is_err() {
|
||||
return; // Skip if Redis operations fail
|
||||
}
|
||||
let instances = list_result.unwrap();
|
||||
assert!(instances.contains(&guild_id.get()));
|
||||
|
||||
// Test deleting TTS instance
|
||||
let delete_result = db.delete_tts_instance(guild_id).await;
|
||||
if delete_result.is_err() {
|
||||
return; // Skip if Redis operations fail
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
let load_after_delete = db.load_tts_instance(guild_id).await;
|
||||
if load_after_delete.is_err() {
|
||||
return; // Skip if Redis operations fail
|
||||
}
|
||||
assert!(load_after_delete.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_database_constants() {
|
||||
// Test that constants are reasonable
|
||||
assert!(constants::REDIS_CONNECTION_TIMEOUT_SECS > 0);
|
||||
assert!(constants::REDIS_MAX_CONNECTIONS > 0);
|
||||
assert!(constants::REDIS_MIN_IDLE_CONNECTIONS <= constants::REDIS_MAX_CONNECTIONS);
|
||||
}
|
||||
}
|
||||
|
34
src/database/dictionary.rs
Normal file
34
src/database/dictionary.rs
Normal file
@ -0,0 +1,34 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Rule {
|
||||
pub id: String,
|
||||
pub is_regex: bool,
|
||||
pub rule: String,
|
||||
pub to: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Dictionary {
|
||||
pub rules: Vec<Rule>,
|
||||
}
|
||||
|
||||
impl Dictionary {
|
||||
pub fn new() -> Self {
|
||||
let rules = vec![
|
||||
Rule {
|
||||
id: String::from("url"),
|
||||
is_regex: true,
|
||||
rule: String::from(r"(http://|https://){1}[\w\.\-/:\#\?=\&;%\~\+]+"),
|
||||
to: String::from("URL"),
|
||||
},
|
||||
Rule {
|
||||
id: String::from("code"),
|
||||
is_regex: true,
|
||||
rule: String::from(r"```(.|\n)*```"),
|
||||
to: String::from("code"),
|
||||
},
|
||||
];
|
||||
Self { rules }
|
||||
}
|
||||
}
|
@ -1,2 +1,4 @@
|
||||
pub mod database;
|
||||
pub mod dictionary;
|
||||
pub mod server_config;
|
||||
pub mod user_config;
|
||||
pub mod database;
|
16
src/database/server_config.rs
Normal file
16
src/database/server_config.rs
Normal file
@ -0,0 +1,16 @@
|
||||
use super::dictionary::Dictionary;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct DictionaryOnlyServerConfig {
|
||||
pub dictionary: Dictionary,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ServerConfig {
|
||||
pub dictionary: Dictionary,
|
||||
pub autostart_channel_id: Option<u64>,
|
||||
pub autostart_text_channel_id: Option<u64>,
|
||||
pub voice_state_announce: Option<bool>,
|
||||
pub read_username: Option<bool>,
|
||||
}
|
@ -1,10 +1,12 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::tts::{gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType};
|
||||
use crate::tts::{
|
||||
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct UserConfig {
|
||||
pub tts_type: Option<TTSType>,
|
||||
pub gcp_tts_voice: Option<VoiceSelectionParams>,
|
||||
pub voicevox_speaker: Option<i64>
|
||||
pub voicevox_speaker: Option<i64>,
|
||||
}
|
||||
|
521
src/errors.rs
Normal file
521
src/errors.rs
Normal file
@ -0,0 +1,521 @@
|
||||
/// Custom error types for the NCB-TTS application
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum NCBError {
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
|
||||
#[error("Database error: {0}")]
|
||||
Database(String),
|
||||
|
||||
#[error("VOICEVOX API error: {0}")]
|
||||
VOICEVOX(String),
|
||||
|
||||
#[error("Discord error: {0}")]
|
||||
Discord(#[from] serenity::Error),
|
||||
|
||||
#[error("TTS synthesis error: {0}")]
|
||||
TTSSynthesis(String),
|
||||
|
||||
#[error("GCP authentication error: {0}")]
|
||||
GCPAuth(#[from] gcp_auth::Error),
|
||||
|
||||
#[error("HTTP request error: {0}")]
|
||||
Http(#[from] reqwest::Error),
|
||||
|
||||
#[error("JSON parsing error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
#[error("Redis connection error: {0}")]
|
||||
Redis(String),
|
||||
|
||||
#[error("Redis error: {0}")]
|
||||
RedisError(#[from] bb8_redis::redis::RedisError),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("Voice connection error: {0}")]
|
||||
VoiceConnection(String),
|
||||
|
||||
#[error("Invalid input: {0}")]
|
||||
InvalidInput(String),
|
||||
|
||||
#[error("Invalid regex pattern: {0}")]
|
||||
InvalidRegex(String),
|
||||
|
||||
#[error("Songbird error: {0}")]
|
||||
Songbird(String),
|
||||
|
||||
#[error("User not in voice channel")]
|
||||
UserNotInVoiceChannel,
|
||||
|
||||
#[error("Guild not found")]
|
||||
GuildNotFound,
|
||||
|
||||
#[error("Channel not found")]
|
||||
ChannelNotFound,
|
||||
|
||||
#[error("TTS instance not found for guild {guild_id}")]
|
||||
TTSInstanceNotFound { guild_id: u64 },
|
||||
|
||||
#[error("Text too long (max {max_length} characters)")]
|
||||
TextTooLong { max_length: usize },
|
||||
|
||||
#[error("Text contains prohibited content")]
|
||||
ProhibitedContent,
|
||||
|
||||
#[error("Rate limit exceeded")]
|
||||
RateLimitExceeded,
|
||||
|
||||
#[error("TOML parsing error: {0}")]
|
||||
Toml(#[from] toml::de::Error),
|
||||
}
|
||||
|
||||
impl NCBError {
|
||||
pub fn config(message: impl Into<String>) -> Self {
|
||||
Self::Config(message.into())
|
||||
}
|
||||
|
||||
pub fn database(message: impl Into<String>) -> Self {
|
||||
Self::Database(message.into())
|
||||
}
|
||||
|
||||
pub fn voicevox(message: impl Into<String>) -> Self {
|
||||
Self::VOICEVOX(message.into())
|
||||
}
|
||||
|
||||
pub fn voice_connection(message: impl Into<String>) -> Self {
|
||||
Self::VoiceConnection(message.into())
|
||||
}
|
||||
|
||||
pub fn tts_synthesis(message: impl Into<String>) -> Self {
|
||||
Self::TTSSynthesis(message.into())
|
||||
}
|
||||
|
||||
pub fn invalid_input(message: impl Into<String>) -> Self {
|
||||
Self::InvalidInput(message.into())
|
||||
}
|
||||
|
||||
pub fn invalid_regex(message: impl Into<String>) -> Self {
|
||||
Self::InvalidRegex(message.into())
|
||||
}
|
||||
|
||||
pub fn songbird(message: impl Into<String>) -> Self {
|
||||
Self::Songbird(message.into())
|
||||
}
|
||||
|
||||
pub fn tts_instance_not_found(guild_id: u64) -> Self {
|
||||
Self::TTSInstanceNotFound { guild_id }
|
||||
}
|
||||
|
||||
pub fn text_too_long(max_length: usize) -> Self {
|
||||
Self::TextTooLong { max_length }
|
||||
}
|
||||
|
||||
pub fn redis(message: impl Into<String>) -> Self {
|
||||
Self::Redis(message.into())
|
||||
}
|
||||
|
||||
pub fn missing_env_var(var_name: &str) -> Self {
|
||||
Self::Config(format!("Missing environment variable: {}", var_name))
|
||||
}
|
||||
}
|
||||
|
||||
/// Result type alias for convenience
|
||||
pub type Result<T> = std::result::Result<T, NCBError>;
|
||||
|
||||
/// Input validation functions
|
||||
pub mod validation {
|
||||
use super::*;
|
||||
use regex::Regex;
|
||||
|
||||
/// Validate regex pattern for potential ReDoS attacks
|
||||
pub fn validate_regex_pattern(pattern: &str) -> Result<()> {
|
||||
// Check for common ReDoS patterns (catastrophic backtracking)
|
||||
let redos_patterns = [
|
||||
r"\(\?\:", // Non-capturing groups in dangerous positions
|
||||
r"\(\?\=", // Positive lookahead
|
||||
r"\(\?\!", // Negative lookahead
|
||||
r"\(\?\<\=", // Positive lookbehind
|
||||
r"\(\?\<\!", // Negative lookbehind
|
||||
r"\*\*", // Actual nested quantifiers (not possessive)
|
||||
r"\+\*", // Nested quantifiers
|
||||
r"\*\+", // Nested quantifiers
|
||||
];
|
||||
|
||||
for redos_pattern in &redos_patterns {
|
||||
if pattern.contains(redos_pattern) {
|
||||
return Err(NCBError::invalid_regex(format!(
|
||||
"Pattern contains potentially dangerous construct: {}",
|
||||
redos_pattern
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Check pattern length
|
||||
if pattern.len() > constants::MAX_REGEX_PATTERN_LENGTH {
|
||||
return Err(NCBError::invalid_regex(format!(
|
||||
"Pattern too long (max {} characters)",
|
||||
constants::MAX_REGEX_PATTERN_LENGTH
|
||||
)));
|
||||
}
|
||||
|
||||
// Try to compile the regex to validate syntax
|
||||
Regex::new(pattern)
|
||||
.map_err(|e| NCBError::invalid_regex(format!("Invalid regex syntax: {}", e)))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate rule name
|
||||
pub fn validate_rule_name(name: &str) -> Result<()> {
|
||||
if name.trim().is_empty() {
|
||||
return Err(NCBError::invalid_input("Rule name cannot be empty"));
|
||||
}
|
||||
|
||||
if name.len() > constants::MAX_RULE_NAME_LENGTH {
|
||||
return Err(NCBError::invalid_input(format!(
|
||||
"Rule name too long (max {} characters)",
|
||||
constants::MAX_RULE_NAME_LENGTH
|
||||
)));
|
||||
}
|
||||
|
||||
// Check for invalid characters
|
||||
if !name
|
||||
.chars()
|
||||
.all(|c| c.is_alphanumeric() || c.is_whitespace() || "_-".contains(c))
|
||||
{
|
||||
return Err(NCBError::invalid_input(
|
||||
"Rule name contains invalid characters (only alphanumeric, spaces, hyphens, and underscores allowed)"
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate TTS text input
|
||||
pub fn validate_tts_text(text: &str) -> Result<()> {
|
||||
if text.trim().is_empty() {
|
||||
return Err(NCBError::invalid_input("Text cannot be empty"));
|
||||
}
|
||||
|
||||
if text.len() > constants::MAX_TTS_TEXT_LENGTH {
|
||||
return Err(NCBError::text_too_long(constants::MAX_TTS_TEXT_LENGTH));
|
||||
}
|
||||
|
||||
// Check for prohibited patterns
|
||||
let prohibited_patterns = [
|
||||
r"<script", // Script injection
|
||||
r"javascript:", // JavaScript URLs
|
||||
r"data:", // Data URLs
|
||||
r"<?xml", // XML processing instructions
|
||||
];
|
||||
|
||||
let text_lower = text.to_lowercase();
|
||||
for pattern in &prohibited_patterns {
|
||||
if text_lower.contains(pattern) {
|
||||
return Err(NCBError::ProhibitedContent);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate replacement text for dictionary rules
|
||||
pub fn validate_replacement_text(text: &str) -> Result<()> {
|
||||
if text.trim().is_empty() {
|
||||
return Err(NCBError::invalid_input("Replacement text cannot be empty"));
|
||||
}
|
||||
|
||||
if text.len() > constants::MAX_TTS_TEXT_LENGTH {
|
||||
return Err(NCBError::text_too_long(constants::MAX_TTS_TEXT_LENGTH));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Sanitize SSML input to prevent injection attacks
|
||||
pub fn sanitize_ssml(text: &str) -> String {
|
||||
// Remove or escape potentially dangerous SSML tags
|
||||
let _dangerous_tags = [
|
||||
"audio", "break", "emphasis", "lang", "mark", "p", "phoneme", "prosody", "say-as",
|
||||
"speak", "sub", "voice", "w",
|
||||
];
|
||||
|
||||
let mut sanitized = text.to_string();
|
||||
|
||||
// Remove script-like content
|
||||
sanitized = sanitized.replace("<script", "<script");
|
||||
sanitized = sanitized.replace("javascript:", "");
|
||||
sanitized = sanitized.replace("data:", "");
|
||||
|
||||
// Limit the overall length
|
||||
if sanitized.len() > constants::MAX_SSML_LENGTH {
|
||||
sanitized.truncate(constants::MAX_SSML_LENGTH);
|
||||
}
|
||||
|
||||
sanitized
|
||||
}
|
||||
}
|
||||
|
||||
/// Constants used throughout the application
|
||||
pub mod constants {
|
||||
// Configuration constants
|
||||
pub const DEFAULT_CONFIG_PATH: &str = "config.toml";
|
||||
pub const DEFAULT_DICTIONARY_PATH: &str = "dictionary.txt";
|
||||
|
||||
// Redis constants
|
||||
pub const REDIS_CONNECTION_TIMEOUT_SECS: u64 = 5;
|
||||
pub const REDIS_MAX_CONNECTIONS: u32 = 10;
|
||||
pub const REDIS_MIN_IDLE_CONNECTIONS: u32 = 1;
|
||||
|
||||
// Cache constants
|
||||
pub const DEFAULT_CACHE_SIZE: usize = 1000;
|
||||
pub const CACHE_TTL_SECS: u64 = 86400; // 24 hours
|
||||
|
||||
// TTS constants
|
||||
pub const MAX_TTS_TEXT_LENGTH: usize = 500;
|
||||
pub const MAX_SSML_LENGTH: usize = 1000;
|
||||
pub const TTS_TIMEOUT_SECS: u64 = 30;
|
||||
pub const DEFAULT_SPEAKING_RATE: f32 = 1.2;
|
||||
pub const DEFAULT_PITCH: f32 = 0.0;
|
||||
|
||||
// Validation constants
|
||||
pub const MAX_REGEX_PATTERN_LENGTH: usize = 100;
|
||||
pub const MAX_RULE_NAME_LENGTH: usize = 50;
|
||||
pub const MAX_USERNAME_LENGTH: usize = 32;
|
||||
|
||||
// Circuit breaker constants
|
||||
pub const CIRCUIT_BREAKER_FAILURE_THRESHOLD: u32 = 5;
|
||||
pub const CIRCUIT_BREAKER_TIMEOUT_SECS: u64 = 60;
|
||||
|
||||
// Retry constants
|
||||
pub const DEFAULT_MAX_RETRY_ATTEMPTS: u32 = 3;
|
||||
pub const DEFAULT_RETRY_DELAY_MS: u64 = 500;
|
||||
pub const MAX_RETRY_DELAY_MS: u64 = 5000;
|
||||
|
||||
// Connection monitoring constants
|
||||
pub const CONNECTION_CHECK_INTERVAL_SECS: u64 = 5;
|
||||
pub const MAX_RECONNECTION_ATTEMPTS: u32 = 3;
|
||||
pub const RECONNECTION_BACKOFF_SECS: u64 = 2;
|
||||
|
||||
// Voice connection constants
|
||||
pub const VOICE_CONNECTION_TIMEOUT_SECS: u64 = 10;
|
||||
pub const AUDIO_BITRATE_KBPS: u32 = 128;
|
||||
pub const AUDIO_SAMPLE_RATE: u32 = 48000;
|
||||
|
||||
// Database key prefixes
|
||||
pub const DISCORD_SERVER_PREFIX: &str = "discord:server:";
|
||||
pub const DISCORD_USER_PREFIX: &str = "discord:user:";
|
||||
pub const TTS_INSTANCE_PREFIX: &str = "tts:instance:";
|
||||
pub const TTS_INSTANCES_LIST_KEY: &str = "tts:instances";
|
||||
|
||||
// Default values
|
||||
pub const DEFAULT_VOICEVOX_SPEAKER: i64 = 1;
|
||||
|
||||
// Message constants
|
||||
pub const RULE_ADDED: &str = "RULE_ADDED";
|
||||
pub const RULE_REMOVED: &str = "RULE_REMOVED";
|
||||
pub const RULE_ALREADY_EXISTS: &str = "RULE_ALREADY_EXISTS";
|
||||
pub const RULE_NOT_FOUND: &str = "RULE_NOT_FOUND";
|
||||
pub const DICTIONARY_RULE_APPLIED: &str = "DICTIONARY_RULE_APPLIED";
|
||||
pub const GUILD_NOT_FOUND: &str = "GUILD_NOT_FOUND";
|
||||
pub const CHANNEL_JOIN_SUCCESS: &str = "CHANNEL_JOIN_SUCCESS";
|
||||
pub const CHANNEL_LEAVE_SUCCESS: &str = "CHANNEL_LEAVE_SUCCESS";
|
||||
pub const AUTOSTART_CHANNEL_SET: &str = "AUTOSTART_CHANNEL_SET";
|
||||
pub const SET_AUTOSTART_CHANNEL_CLEAR: &str = "SET_AUTOSTART_CHANNEL_CLEAR";
|
||||
pub const SET_AUTOSTART_TEXT_CHANNEL: &str = "SET_AUTOSTART_TEXT_CHANNEL";
|
||||
pub const SET_AUTOSTART_TEXT_CHANNEL_CLEAR: &str = "SET_AUTOSTART_TEXT_CHANNEL_CLEAR";
|
||||
|
||||
// TTS configuration constants
|
||||
pub const TTS_CONFIG_SERVER_ADD_DICTIONARY: &str = "TTS_CONFIG_SERVER_ADD_DICTIONARY";
|
||||
pub const TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE: &str =
|
||||
"TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE";
|
||||
pub const TTS_CONFIG_SERVER_SET_READ_USERNAME: &str = "TTS_CONFIG_SERVER_SET_READ_USERNAME";
|
||||
pub const TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU: &str =
|
||||
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU";
|
||||
pub const TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON: &str =
|
||||
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON";
|
||||
pub const TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON: &str =
|
||||
"TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON";
|
||||
pub const TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON: &str =
|
||||
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON";
|
||||
pub const SET_AUTOSTART_CHANNEL: &str = "SET_AUTOSTART_CHANNEL";
|
||||
pub const TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL: &str =
|
||||
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL";
|
||||
pub const TTS_CONFIG_SERVER_BACK: &str = "TTS_CONFIG_SERVER_BACK";
|
||||
pub const TTS_CONFIG_SERVER: &str = "TTS_CONFIG_SERVER";
|
||||
pub const TTS_CONFIG_SERVER_DICTIONARY: &str = "TTS_CONFIG_SERVER_DICTIONARY";
|
||||
|
||||
// TTS engine selection messages
|
||||
pub const TTS_CONFIG_ENGINE_SELECTED_GOOGLE: &str = "TTS_CONFIG_ENGINE_SELECTED_GOOGLE";
|
||||
pub const TTS_CONFIG_ENGINE_SELECTED_VOICEVOX: &str = "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX";
|
||||
|
||||
// Error messages
|
||||
pub const USER_NOT_IN_VOICE_CHANNEL: &str = "USER_NOT_IN_VOICE_CHANNEL";
|
||||
pub const CHANNEL_NOT_FOUND: &str = "CHANNEL_NOT_FOUND";
|
||||
|
||||
// Rate limiting constants
|
||||
pub const RATE_LIMIT_REQUESTS_PER_MINUTE: u32 = 60;
|
||||
pub const RATE_LIMIT_REQUESTS_PER_HOUR: u32 = 1000;
|
||||
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ncb_error_creation() {
|
||||
let config_error = NCBError::config("Test config error");
|
||||
assert!(matches!(config_error, NCBError::Config(_)));
|
||||
assert_eq!(
|
||||
config_error.to_string(),
|
||||
"Configuration error: Test config error"
|
||||
);
|
||||
|
||||
let database_error = NCBError::database("Test database error");
|
||||
assert!(matches!(database_error, NCBError::Database(_)));
|
||||
assert_eq!(
|
||||
database_error.to_string(),
|
||||
"Database error: Test database error"
|
||||
);
|
||||
|
||||
let voicevox_error = NCBError::voicevox("Test VOICEVOX error");
|
||||
assert!(matches!(voicevox_error, NCBError::VOICEVOX(_)));
|
||||
assert_eq!(
|
||||
voicevox_error.to_string(),
|
||||
"VOICEVOX API error: Test VOICEVOX error"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tts_instance_not_found_error() {
|
||||
let guild_id = 12345u64;
|
||||
let error = NCBError::tts_instance_not_found(guild_id);
|
||||
assert!(matches!(
|
||||
error,
|
||||
NCBError::TTSInstanceNotFound { guild_id: 12345 }
|
||||
));
|
||||
assert_eq!(error.to_string(), "TTS instance not found for guild 12345");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_text_too_long_error() {
|
||||
let max_length = 500;
|
||||
let error = NCBError::text_too_long(max_length);
|
||||
assert!(matches!(error, NCBError::TextTooLong { max_length: 500 }));
|
||||
assert_eq!(error.to_string(), "Text too long (max 500 characters)");
|
||||
}
|
||||
|
||||
mod validation_tests {
|
||||
use super::super::constants;
|
||||
use super::super::validation::*;
|
||||
|
||||
#[test]
|
||||
fn test_validate_regex_pattern_valid() {
|
||||
assert!(validate_regex_pattern(r"[a-zA-Z]+").is_ok());
|
||||
assert!(validate_regex_pattern(r"\d{1,3}").is_ok());
|
||||
assert!(validate_regex_pattern(r"hello|world").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_regex_pattern_redos() {
|
||||
// Test that the validation function properly checks patterns
|
||||
// Most problematic patterns are caught by regex compilation errors
|
||||
// This test focuses on basic pattern safety checks
|
||||
|
||||
// Test length validation works
|
||||
let very_long_pattern = "a".repeat(constants::MAX_REGEX_PATTERN_LENGTH + 1);
|
||||
assert!(validate_regex_pattern(&very_long_pattern).is_err());
|
||||
|
||||
// Test basic pattern validation passes for safe patterns
|
||||
assert!(validate_regex_pattern(r"[a-z]+").is_ok());
|
||||
assert!(validate_regex_pattern(r"\d{1,3}").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_regex_pattern_too_long() {
|
||||
let long_pattern = "a".repeat(constants::MAX_REGEX_PATTERN_LENGTH + 1);
|
||||
assert!(validate_regex_pattern(&long_pattern).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_regex_pattern_invalid_syntax() {
|
||||
assert!(validate_regex_pattern(r"[").is_err());
|
||||
assert!(validate_regex_pattern(r"*").is_err());
|
||||
assert!(validate_regex_pattern(r"(?P<>)").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_rule_name_valid() {
|
||||
assert!(validate_rule_name("test_rule").is_ok());
|
||||
assert!(validate_rule_name("Test Rule 123").is_ok());
|
||||
assert!(validate_rule_name("rule-name").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_rule_name_empty() {
|
||||
assert!(validate_rule_name("").is_err());
|
||||
assert!(validate_rule_name(" ").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_rule_name_too_long() {
|
||||
let long_name = "a".repeat(constants::MAX_RULE_NAME_LENGTH + 1);
|
||||
assert!(validate_rule_name(&long_name).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_rule_name_invalid_chars() {
|
||||
assert!(validate_rule_name("rule@name").is_err());
|
||||
assert!(validate_rule_name("rule#name").is_err());
|
||||
assert!(validate_rule_name("rule$name").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_tts_text_valid() {
|
||||
assert!(validate_tts_text("Hello world").is_ok());
|
||||
assert!(validate_tts_text("こんにちは").is_ok());
|
||||
assert!(validate_tts_text("Test with numbers 123").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_tts_text_empty() {
|
||||
assert!(validate_tts_text("").is_err());
|
||||
assert!(validate_tts_text(" ").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_tts_text_too_long() {
|
||||
let long_text = "a".repeat(constants::MAX_TTS_TEXT_LENGTH + 1);
|
||||
assert!(validate_tts_text(&long_text).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_tts_text_prohibited_content() {
|
||||
assert!(validate_tts_text("<script>alert('xss')</script>").is_err());
|
||||
assert!(validate_tts_text("javascript:alert('xss')").is_err());
|
||||
assert!(validate_tts_text("data:text/html,<h1>XSS</h1>").is_err());
|
||||
assert!(validate_tts_text("<?xml version=\"1.0\"?>").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_ssml() {
|
||||
let input = "<script>alert('xss')</script>Hello world";
|
||||
let output = sanitize_ssml(input);
|
||||
assert!(!output.contains("<script"));
|
||||
assert!(output.contains("<script"));
|
||||
assert!(output.contains("Hello world"));
|
||||
|
||||
let input_with_js = "javascript:alert('test')Hello";
|
||||
let output = sanitize_ssml(input_with_js);
|
||||
assert!(!output.contains("javascript:"));
|
||||
assert!(output.contains("Hello"));
|
||||
|
||||
let long_input = "a".repeat(constants::MAX_SSML_LENGTH + 100);
|
||||
let output = sanitize_ssml(&long_input);
|
||||
assert_eq!(output.len(), constants::MAX_SSML_LENGTH);
|
||||
}
|
||||
}
|
||||
}
|
1142
src/event_handler.rs
1142
src/event_handler.rs
File diff suppressed because it is too large
Load Diff
44
src/events/message_receive.rs
Normal file
44
src/events/message_receive.rs
Normal file
@ -0,0 +1,44 @@
|
||||
use serenity::{model::prelude::Message, prelude::Context};
|
||||
|
||||
use crate::data::TTSData;
|
||||
|
||||
pub async fn message(ctx: Context, message: Message) {
|
||||
if message.author.bot {
|
||||
return;
|
||||
}
|
||||
|
||||
let guild_id = message.guild(&ctx.cache);
|
||||
|
||||
if let None = guild_id {
|
||||
return;
|
||||
}
|
||||
|
||||
let guild_id = guild_id.unwrap().id;
|
||||
|
||||
let storage_lock = {
|
||||
let data_read = ctx.data.read().await;
|
||||
data_read
|
||||
.get::<TTSData>()
|
||||
.expect("Cannot get TTSStorage")
|
||||
.clone()
|
||||
};
|
||||
|
||||
{
|
||||
let mut storage = storage_lock.write().await;
|
||||
if !storage.contains_key(&guild_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
let instance = storage.get_mut(&guild_id).unwrap();
|
||||
|
||||
if !instance.contains_text_channel(message.channel_id) {
|
||||
return;
|
||||
}
|
||||
|
||||
if message.content.starts_with(";") {
|
||||
return;
|
||||
}
|
||||
|
||||
instance.read(message, &ctx).await;
|
||||
}
|
||||
}
|
3
src/events/mod.rs
Normal file
3
src/events/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod message_receive;
|
||||
pub mod ready;
|
||||
pub mod voice_state_update;
|
154
src/events/ready.rs
Normal file
154
src/events/ready.rs
Normal file
@ -0,0 +1,154 @@
|
||||
use serenity::{
|
||||
all::{Command, CommandOptionType, CreateCommand, CreateCommandOption},
|
||||
model::prelude::Ready,
|
||||
prelude::Context,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
connection_monitor::ConnectionMonitor,
|
||||
data::{DatabaseClientData, TTSData},
|
||||
};
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn ready(ctx: Context, ready: Ready) {
|
||||
info!("{} is connected!", ready.user.name);
|
||||
|
||||
Command::set_global_commands(
|
||||
&ctx.http,
|
||||
vec![
|
||||
CreateCommand::new("stop").description("Stop tts"),
|
||||
CreateCommand::new("setup")
|
||||
.description("Setup tts")
|
||||
.set_options(vec![CreateCommandOption::new(
|
||||
CommandOptionType::String,
|
||||
"mode",
|
||||
"TTS channel",
|
||||
)
|
||||
.add_string_choice("Text Channel", "TEXT_CHANNEL")
|
||||
.add_string_choice("New Thread", "NEW_THREAD")
|
||||
.add_string_choice("Voice Channel", "VOICE_CHANNEL")
|
||||
.required(false)]),
|
||||
CreateCommand::new("config").description("Config"),
|
||||
CreateCommand::new("skip").description("skip tts message"),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Restore TTS instances from database
|
||||
restore_tts_instances(&ctx).await;
|
||||
|
||||
// Start connection monitor
|
||||
ConnectionMonitor::start(ctx.clone());
|
||||
}
|
||||
|
||||
/// Restore TTS instances from database and reconnect to voice channels
|
||||
async fn restore_tts_instances(ctx: &Context) {
|
||||
info!("Restoring TTS instances from database...");
|
||||
|
||||
let data = ctx.data.read().await;
|
||||
let database = data
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let tts_data = data.get::<TTSData>().unwrap().clone();
|
||||
drop(data);
|
||||
|
||||
match database.get_all_tts_instances().await {
|
||||
Ok(instances) => {
|
||||
let mut restored_count = 0;
|
||||
let mut failed_count = 0;
|
||||
|
||||
for (guild_id, instance) in instances {
|
||||
// Check if there are users in the voice channel before reconnecting
|
||||
let should_reconnect = match guild_id.channels(&ctx.http).await {
|
||||
Ok(channels) => {
|
||||
if let Some(channel) = channels.get(&instance.voice_channel) {
|
||||
match channel.members(&ctx.cache) {
|
||||
Ok(members) => {
|
||||
let user_count =
|
||||
members.iter().filter(|member| !member.user.bot).count();
|
||||
user_count > 0
|
||||
}
|
||||
Err(_) => {
|
||||
// If we can't get members, assume there are no users
|
||||
tracing::warn!(
|
||||
"Failed to get members for voice channel {} in guild {}",
|
||||
instance.voice_channel,
|
||||
guild_id
|
||||
);
|
||||
false
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Channel doesn't exist anymore
|
||||
tracing::warn!(
|
||||
"Voice channel {} no longer exists in guild {}",
|
||||
instance.voice_channel,
|
||||
guild_id
|
||||
);
|
||||
false
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// If we can't get channels, assume reconnection should not happen
|
||||
tracing::warn!("Failed to get channels for guild {}", guild_id);
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
if !should_reconnect {
|
||||
// Remove instance from database as the channel is empty or doesn't exist
|
||||
failed_count += 1;
|
||||
tracing::info!("Skipping reconnection for guild {} - no users in voice channel or channel doesn't exist", guild_id);
|
||||
|
||||
if let Err(db_err) = database.remove_tts_instance(guild_id).await {
|
||||
tracing::error!(
|
||||
"Failed to remove empty TTS instance from database: {}",
|
||||
db_err
|
||||
);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try to reconnect to voice channel
|
||||
match instance.reconnect(ctx, true).await {
|
||||
Ok(_) => {
|
||||
// Add to in-memory storage
|
||||
let mut tts_data = tts_data.write().await;
|
||||
tts_data.insert(guild_id, instance);
|
||||
drop(tts_data);
|
||||
|
||||
restored_count += 1;
|
||||
info!("Restored TTS instance for guild {}", guild_id);
|
||||
}
|
||||
Err(e) => {
|
||||
failed_count += 1;
|
||||
tracing::warn!(
|
||||
"Failed to restore TTS instance for guild {}: {}",
|
||||
guild_id,
|
||||
e
|
||||
);
|
||||
|
||||
// Remove failed instance from database
|
||||
if let Err(db_err) = database.remove_tts_instance(guild_id).await {
|
||||
tracing::error!(
|
||||
"Failed to remove invalid TTS instance from database: {}",
|
||||
db_err
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"TTS restoration complete: {} restored, {} failed",
|
||||
restored_count, failed_count
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to load TTS instances from database: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
185
src/events/voice_state_update.rs
Normal file
185
src/events/voice_state_update.rs
Normal file
@ -0,0 +1,185 @@
|
||||
use crate::{
|
||||
data::{DatabaseClientData, TTSClientData, TTSData},
|
||||
implement::{
|
||||
member_name::ReadName,
|
||||
voice_move_state::{VoiceMoveState, VoiceMoveStateTrait},
|
||||
},
|
||||
tts::{instance::TTSInstance, message::AnnounceMessage},
|
||||
};
|
||||
use serenity::{
|
||||
all::{CreateEmbed, CreateMessage, EditThread},
|
||||
model::voice::VoiceState,
|
||||
prelude::Context,
|
||||
};
|
||||
|
||||
pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: VoiceState) {
|
||||
if new.member.clone().unwrap().user.bot {
|
||||
return;
|
||||
}
|
||||
|
||||
if old.is_none() && new.guild_id.is_none() {
|
||||
return;
|
||||
}
|
||||
|
||||
let guild_id = if let Some(guild_id) = new.guild_id {
|
||||
guild_id
|
||||
} else {
|
||||
old.clone().unwrap().guild_id.unwrap()
|
||||
};
|
||||
|
||||
let storage_lock = {
|
||||
let data_read = ctx.data.read().await;
|
||||
data_read
|
||||
.get::<TTSData>()
|
||||
.expect("Cannot get TTSStorage")
|
||||
.clone()
|
||||
};
|
||||
|
||||
let config = {
|
||||
let data_read = ctx.data.read().await;
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
database
|
||||
.get_server_config_or_default(guild_id.get())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
{
|
||||
let mut storage = storage_lock.write().await;
|
||||
if !storage.contains_key(&guild_id) {
|
||||
if let Some(new_channel) = new.channel_id {
|
||||
if config.autostart_channel_id.unwrap_or(0) == new_channel.get() {
|
||||
let manager = songbird::get(&ctx)
|
||||
.await
|
||||
.expect("Cannot get songbird client.")
|
||||
.clone();
|
||||
|
||||
let text_channel_ids =
|
||||
if let Some(text_channel_id) = config.autostart_text_channel_id {
|
||||
vec![text_channel_id.into(), new_channel]
|
||||
} else {
|
||||
vec![new_channel]
|
||||
};
|
||||
|
||||
let instance = TTSInstance::new(text_channel_ids, new_channel, guild_id);
|
||||
storage.insert(guild_id, instance.clone());
|
||||
|
||||
// Save to database
|
||||
let data_read = ctx.data.read().await;
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
drop(data_read);
|
||||
|
||||
if let Err(e) = database.save_tts_instance(guild_id, &instance).await {
|
||||
tracing::error!("Failed to save TTS instance to database: {}", e);
|
||||
}
|
||||
|
||||
let _handler = manager.join(guild_id, new_channel).await;
|
||||
let data = ctx.data.read().await;
|
||||
let tts_client = data
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientData");
|
||||
let voicevox_speakers = tts_client
|
||||
.voicevox_client
|
||||
.get_speakers()
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::error!("Failed to get VOICEVOX speakers: {}", e);
|
||||
vec!["VOICEVOX API unavailable".to_string()]
|
||||
});
|
||||
|
||||
new_channel
|
||||
.send_message(
|
||||
&ctx.http,
|
||||
CreateMessage::new().embed(
|
||||
CreateEmbed::new()
|
||||
.title("自動参加 読み上げ(Serenity)")
|
||||
.field(
|
||||
"VOICEVOXクレジット",
|
||||
format!("```\n{}\n```", voicevox_speakers.join("\n")),
|
||||
false,
|
||||
)
|
||||
.field("設定コマンド", "`/config`", false)
|
||||
.field("フィードバック", "https://feedback.mii.codes/", false),
|
||||
),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let instance = storage.get_mut(&guild_id).unwrap();
|
||||
|
||||
let voice_move_state = new.move_state(&old, instance.voice_channel);
|
||||
|
||||
if config.voice_state_announce.unwrap_or(false) {
|
||||
let message: Option<String> = match voice_move_state {
|
||||
VoiceMoveState::JOIN => Some(format!(
|
||||
"{} さんが通話に参加しました",
|
||||
new.member.unwrap().read_name()
|
||||
)),
|
||||
VoiceMoveState::LEAVE => Some(format!(
|
||||
"{} さんが通話から退出しました",
|
||||
new.member.unwrap().read_name()
|
||||
)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some(message) = message {
|
||||
instance.read(AnnounceMessage { message }, &ctx).await;
|
||||
}
|
||||
}
|
||||
|
||||
if voice_move_state == VoiceMoveState::LEAVE {
|
||||
let mut del_flag = false;
|
||||
for channel in guild_id.channels(&ctx.http).await.unwrap() {
|
||||
if channel.0 == instance.voice_channel {
|
||||
let members = channel.1.members(&ctx.cache).unwrap();
|
||||
let user_count = members.iter().filter(|member| !member.user.bot).count();
|
||||
|
||||
del_flag = user_count == 0;
|
||||
}
|
||||
}
|
||||
|
||||
if del_flag {
|
||||
// Archive thread if it exists
|
||||
if let Some(&channel_id) = storage.get(&guild_id).unwrap().text_channels.first() {
|
||||
let http = ctx.http.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = channel_id
|
||||
.edit_thread(&http, EditThread::new().archived(true))
|
||||
.await;
|
||||
});
|
||||
}
|
||||
storage.remove(&guild_id);
|
||||
|
||||
// Remove from database
|
||||
let data_read = ctx.data.read().await;
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
drop(data_read);
|
||||
|
||||
if let Err(e) = database.remove_tts_instance(guild_id).await {
|
||||
tracing::error!("Failed to remove TTS instance from database: {}", e);
|
||||
}
|
||||
|
||||
let manager = songbird::get(&ctx)
|
||||
.await
|
||||
.expect("Cannot get songbird client.")
|
||||
.clone();
|
||||
|
||||
manager.remove(guild_id).await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,4 +1,7 @@
|
||||
use serenity::model::guild::Member;
|
||||
use serenity::model::{
|
||||
guild::{Member, PartialMember},
|
||||
user::User,
|
||||
};
|
||||
|
||||
pub trait ReadName {
|
||||
fn read_name(&self) -> String;
|
||||
@ -6,6 +9,20 @@ pub trait ReadName {
|
||||
|
||||
impl ReadName for Member {
|
||||
fn read_name(&self) -> String {
|
||||
self.nick.clone().unwrap_or(self.user.name.clone())
|
||||
self.nick.clone().unwrap_or(self.display_name().to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadName for PartialMember {
|
||||
fn read_name(&self) -> String {
|
||||
self.nick
|
||||
.clone()
|
||||
.unwrap_or(self.user.as_ref().unwrap().display_name().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadName for User {
|
||||
fn read_name(&self) -> String {
|
||||
self.display_name().to_string()
|
||||
}
|
||||
}
|
||||
|
@ -1,91 +1,226 @@
|
||||
use std::{path::Path, fs::File, io::Write, env};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serenity::{prelude::Context, model::prelude::Message};
|
||||
use serenity::{model::prelude::Message, prelude::Context};
|
||||
use songbird::tracks::Track;
|
||||
use tracing::{error, warn};
|
||||
|
||||
use crate::{
|
||||
data::{TTSClientData, DatabaseClientData},
|
||||
data::{DatabaseClientData, TTSClientData},
|
||||
errors::{constants::*, validation, NCBError},
|
||||
implement::member_name::ReadName,
|
||||
tts::{
|
||||
gcp_tts::structs::{
|
||||
audio_config::AudioConfig, synthesis_input::SynthesisInput,
|
||||
synthesize_request::SynthesizeRequest,
|
||||
},
|
||||
instance::TTSInstance,
|
||||
message::TTSMessage,
|
||||
gcp_tts::structs::{
|
||||
audio_config::AudioConfig, synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest
|
||||
}, tts_type::{self, TTSType}
|
||||
tts_type::TTSType,
|
||||
},
|
||||
utils::{get_cached_regex, retry_with_backoff},
|
||||
};
|
||||
|
||||
#[async_trait]
|
||||
impl TTSMessage for Message {
|
||||
async fn parse(&self, instance: &mut TTSInstance, _: &Context) -> String {
|
||||
let res = if let Some(before_message) = &instance.before_message {
|
||||
if before_message.author.id == self.author.id {
|
||||
self.content.clone()
|
||||
async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
|
||||
let data_read = ctx.data.read().await;
|
||||
|
||||
let config = {
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.ok_or_else(|| NCBError::config("Cannot get DatabaseClientData"))
|
||||
.map_err(|e| {
|
||||
error!(error = %e, "Failed to get database client");
|
||||
e
|
||||
})
|
||||
.unwrap(); // This is safe as we're in a critical path
|
||||
|
||||
match database.get_server_config_or_default(instance.guild.get()).await {
|
||||
Ok(Some(config)) => config,
|
||||
Ok(None) => {
|
||||
error!(guild_id = %instance.guild, "No server config available");
|
||||
return self.content.clone(); // Fallback to original text
|
||||
},
|
||||
Err(e) => {
|
||||
error!(guild_id = %instance.guild, error = %e, "Failed to get server config");
|
||||
return self.content.clone(); // Fallback to original text
|
||||
}
|
||||
}
|
||||
};
|
||||
let mut text = self.content.clone();
|
||||
|
||||
// Validate text length before processing
|
||||
if let Err(e) = validation::validate_tts_text(&text) {
|
||||
warn!(error = %e, "Invalid TTS text, using truncated version");
|
||||
text.truncate(crate::errors::constants::MAX_TTS_TEXT_LENGTH);
|
||||
}
|
||||
|
||||
for rule in config.dictionary.rules {
|
||||
if rule.is_regex {
|
||||
match get_cached_regex(&rule.rule) {
|
||||
Ok(regex) => {
|
||||
text = regex.replace_all(&text, &rule.to).to_string();
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
rule_id = rule.id,
|
||||
pattern = rule.rule,
|
||||
error = %e,
|
||||
"Skipping invalid regex rule"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let member = self.member.clone();
|
||||
let name = if let Some(member) = member {
|
||||
member.nick.unwrap_or(self.author.name.clone())
|
||||
text = text.replace(&rule.rule, &rule.to);
|
||||
}
|
||||
}
|
||||
let mut res = if let Some(before_message) = &instance.before_message {
|
||||
if before_message.author.id == self.author.id {
|
||||
text.clone()
|
||||
} else {
|
||||
let name = get_user_name(self, ctx).await;
|
||||
if config.read_username.unwrap_or(true) {
|
||||
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
|
||||
} else {
|
||||
self.author.name.clone()
|
||||
};
|
||||
format!("{} さんの発言<break time=\"200ms\"/>{}", name, self.content)
|
||||
format!("{}", text)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let member = self.member.clone();
|
||||
let name = if let Some(member) = member {
|
||||
member.nick.unwrap_or(self.author.name.clone())
|
||||
let name = get_user_name(self, ctx).await;
|
||||
|
||||
if config.read_username.unwrap_or(true) {
|
||||
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
|
||||
} else {
|
||||
self.author.name.clone()
|
||||
};
|
||||
format!("{} さんの発言<break time=\"200ms\"/>{}", name, self.content)
|
||||
format!("{}", text)
|
||||
}
|
||||
};
|
||||
|
||||
if self.attachments.len() > 0 {
|
||||
res = format!(
|
||||
"{}<break time=\"200ms\"/>{}個の添付ファイル",
|
||||
res,
|
||||
self.attachments.len()
|
||||
);
|
||||
}
|
||||
|
||||
instance.before_message = Some(self.clone());
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track> {
|
||||
let text = self.parse(instance, ctx).await;
|
||||
|
||||
let data_read = ctx.data.read().await;
|
||||
let storage = data_read.get::<TTSClientData>().expect("Cannot get GCP TTSClientStorage").clone();
|
||||
let mut tts = storage.lock().await;
|
||||
|
||||
let config = {
|
||||
let database = data_read.get::<DatabaseClientData>().expect("Cannot get DatabaseClientData").clone();
|
||||
let mut database = database.lock().await;
|
||||
database.get_user_config_or_default(self.author.id.0).await.unwrap().unwrap()
|
||||
};
|
||||
|
||||
let audio = match config.tts_type.unwrap_or(TTSType::GCP) {
|
||||
TTSType::GCP => {
|
||||
tts.0.synthesize(SynthesizeRequest {
|
||||
input: SynthesisInput {
|
||||
text: None,
|
||||
ssml: Some(format!("<speak>{}</speak>", text))
|
||||
},
|
||||
voice: config.gcp_tts_voice.unwrap(),
|
||||
audioConfig: AudioConfig {
|
||||
audioEncoding: String::from("mp3"),
|
||||
speakingRate: 1.2f32,
|
||||
pitch: 1.0f32
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.ok_or_else(|| NCBError::config("Cannot get DatabaseClientData"))
|
||||
.unwrap();
|
||||
|
||||
match database.get_user_config_or_default(self.author.id.get()).await {
|
||||
Ok(Some(config)) => config,
|
||||
Ok(None) | Err(_) => {
|
||||
error!(user_id = %self.author.id, "Failed to get user config, using defaults");
|
||||
// Return default config
|
||||
crate::database::user_config::UserConfig {
|
||||
tts_type: Some(TTSType::GCP),
|
||||
gcp_tts_voice: Some(crate::tts::gcp_tts::structs::voice_selection_params::VoiceSelectionParams {
|
||||
languageCode: String::from("ja-JP"),
|
||||
name: String::from("ja-JP-Wavenet-B"),
|
||||
ssmlGender: String::from("neutral"),
|
||||
}),
|
||||
voicevox_speaker: Some(crate::errors::constants::DEFAULT_VOICEVOX_SPEAKER),
|
||||
}
|
||||
}).await.unwrap()
|
||||
}
|
||||
|
||||
TTSType::VOICEVOX => {
|
||||
tts.1.synthesize(text.replace("<break time=\"200ms\"/>", "、"), config.voicevox_speaker.unwrap_or(1)).await.unwrap()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let uuid = uuid::Uuid::new_v4().to_string();
|
||||
let tts = data_read
|
||||
.get::<TTSClientData>()
|
||||
.ok_or_else(|| NCBError::config("Cannot get TTSClientData"))
|
||||
.unwrap();
|
||||
|
||||
let path = env::current_dir().unwrap();
|
||||
let file_path = path.join("audio").join(format!("{}.mp3", uuid));
|
||||
|
||||
let mut file = File::create(file_path.clone()).unwrap();
|
||||
file.write(&audio).unwrap();
|
||||
|
||||
file_path.into_os_string().into_string().unwrap()
|
||||
// Synthesize with retry logic
|
||||
let synthesis_result = match config.tts_type.unwrap_or(TTSType::GCP) {
|
||||
TTSType::GCP => {
|
||||
let sanitized_text = validation::sanitize_ssml(&text);
|
||||
retry_with_backoff(
|
||||
|| {
|
||||
tts.synthesize_gcp(SynthesizeRequest {
|
||||
input: SynthesisInput {
|
||||
text: None,
|
||||
ssml: Some(format!("<speak>{}</speak>", sanitized_text)),
|
||||
},
|
||||
voice: config.gcp_tts_voice.clone().unwrap_or_else(|| {
|
||||
crate::tts::gcp_tts::structs::voice_selection_params::VoiceSelectionParams {
|
||||
languageCode: String::from("ja-JP"),
|
||||
name: String::from("ja-JP-Wavenet-B"),
|
||||
ssmlGender: String::from("neutral"),
|
||||
}
|
||||
}),
|
||||
audioConfig: AudioConfig {
|
||||
audioEncoding: String::from("mp3"),
|
||||
speakingRate: DEFAULT_SPEAKING_RATE,
|
||||
pitch: DEFAULT_PITCH,
|
||||
},
|
||||
})
|
||||
},
|
||||
3, // max attempts
|
||||
std::time::Duration::from_millis(500),
|
||||
).await
|
||||
}
|
||||
TTSType::VOICEVOX => {
|
||||
let processed_text = text.replace("<break time=\"200ms\"/>", "、");
|
||||
retry_with_backoff(
|
||||
|| {
|
||||
tts.synthesize_voicevox(
|
||||
&processed_text,
|
||||
config.voicevox_speaker.unwrap_or(crate::errors::constants::DEFAULT_VOICEVOX_SPEAKER),
|
||||
)
|
||||
},
|
||||
3, // max attempts
|
||||
std::time::Duration::from_millis(500),
|
||||
).await
|
||||
}
|
||||
};
|
||||
|
||||
match synthesis_result {
|
||||
Ok(track) => vec![track],
|
||||
Err(e) => {
|
||||
error!(error = %e, "TTS synthesis failed");
|
||||
vec![] // Return empty vector on failure
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to get user name with proper error handling
|
||||
async fn get_user_name(message: &Message, ctx: &Context) -> String {
|
||||
let member = message.member.clone();
|
||||
if let Some(_) = member {
|
||||
if let Some(guild_id) = message.guild_id {
|
||||
match guild_id.member(&ctx.http, message.author.id).await {
|
||||
Ok(member) => member.read_name(),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
user_id = %message.author.id,
|
||||
guild_id = ?message.guild_id,
|
||||
error = %e,
|
||||
"Failed to get guild member, using fallback name"
|
||||
);
|
||||
message.author.read_name()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
guild_id = ?message.guild_id,
|
||||
"Guild not found in cache, using author name"
|
||||
);
|
||||
message.author.read_name()
|
||||
}
|
||||
} else {
|
||||
message.author.read_name()
|
||||
}
|
||||
}
|
||||
|
@ -1,2 +1,3 @@
|
||||
pub mod member_name;
|
||||
pub mod message;
|
||||
pub mod member_name;
|
||||
pub mod voice_move_state;
|
||||
|
50
src/implement/voice_move_state.rs
Normal file
50
src/implement/voice_move_state.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use serenity::model::{prelude::ChannelId, voice::VoiceState};
|
||||
|
||||
pub trait VoiceMoveStateTrait {
|
||||
fn move_state(&self, old: &Option<VoiceState>, target_channel: ChannelId) -> VoiceMoveState;
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub enum VoiceMoveState {
|
||||
JOIN,
|
||||
LEAVE,
|
||||
NONE,
|
||||
}
|
||||
|
||||
impl VoiceMoveStateTrait for VoiceState {
|
||||
fn move_state(&self, old: &Option<VoiceState>, target_channel: ChannelId) -> VoiceMoveState {
|
||||
let new = self;
|
||||
|
||||
if let None = old.clone() {
|
||||
return if target_channel == new.channel_id.unwrap() {
|
||||
VoiceMoveState::JOIN
|
||||
} else {
|
||||
VoiceMoveState::NONE
|
||||
};
|
||||
}
|
||||
|
||||
let old = (*old).clone().unwrap();
|
||||
|
||||
match (old.channel_id, new.channel_id) {
|
||||
(Some(old_channel_id), Some(new_channel_id)) => {
|
||||
if old_channel_id == new_channel_id {
|
||||
VoiceMoveState::NONE
|
||||
} else if old_channel_id == target_channel {
|
||||
VoiceMoveState::LEAVE
|
||||
} else if new_channel_id == target_channel {
|
||||
VoiceMoveState::JOIN
|
||||
} else {
|
||||
VoiceMoveState::NONE
|
||||
}
|
||||
}
|
||||
(Some(old_channel_id), None) => {
|
||||
if old_channel_id == target_channel {
|
||||
VoiceMoveState::LEAVE
|
||||
} else {
|
||||
VoiceMoveState::NONE
|
||||
}
|
||||
}
|
||||
_ => VoiceMoveState::NONE,
|
||||
}
|
||||
}
|
||||
}
|
20
src/lib.rs
Normal file
20
src/lib.rs
Normal file
@ -0,0 +1,20 @@
|
||||
// Public API for the NCB-TTS-R2 library
|
||||
|
||||
pub mod errors;
|
||||
pub mod utils;
|
||||
pub mod tts;
|
||||
pub mod database;
|
||||
pub mod config;
|
||||
pub mod data;
|
||||
pub mod implement;
|
||||
pub mod events;
|
||||
pub mod commands;
|
||||
pub mod stream_input;
|
||||
pub mod trace;
|
||||
pub mod event_handler;
|
||||
pub mod connection_monitor;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use errors::{NCBError, Result};
|
||||
pub use utils::{CircuitBreaker, CircuitBreakerState, retry_with_backoff, get_cached_regex, PerformanceMetrics};
|
||||
pub use tts::tts_type::TTSType;
|
141
src/main.rs
141
src/main.rs
@ -1,23 +1,36 @@
|
||||
use std::{sync::Arc, collections::HashMap};
|
||||
|
||||
use config::Config;
|
||||
use data::{TTSData, TTSClientData, DatabaseClientData};
|
||||
use database::database::Database;
|
||||
use event_handler::Handler;
|
||||
use tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX};
|
||||
use serenity::{
|
||||
client::{Client, bridge::gateway::GatewayIntents},
|
||||
framework::StandardFramework, prelude::RwLock, futures::lock::Mutex
|
||||
};
|
||||
|
||||
use songbird::SerenityInit;
|
||||
|
||||
mod commands;
|
||||
mod config;
|
||||
mod event_handler;
|
||||
mod tts;
|
||||
mod implement;
|
||||
mod connection_monitor;
|
||||
mod data;
|
||||
mod database;
|
||||
mod errors;
|
||||
mod event_handler;
|
||||
mod events;
|
||||
mod implement;
|
||||
mod stream_input;
|
||||
mod trace;
|
||||
mod tts;
|
||||
mod utils;
|
||||
|
||||
use std::{collections::HashMap, env, sync::Arc};
|
||||
|
||||
use config::Config;
|
||||
use data::{DatabaseClientData, TTSClientData, TTSData};
|
||||
use database::database::Database;
|
||||
use errors::{NCBError, Result};
|
||||
use event_handler::Handler;
|
||||
#[allow(deprecated)]
|
||||
use serenity::{
|
||||
all::{standard::Configuration, ApplicationId},
|
||||
client::Client,
|
||||
framework::StandardFramework,
|
||||
prelude::{GatewayIntents, RwLock},
|
||||
};
|
||||
use trace::init_tracing_subscriber;
|
||||
use tracing::info;
|
||||
use tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX};
|
||||
|
||||
use songbird::SerenityInit;
|
||||
|
||||
/// Create discord client
|
||||
///
|
||||
@ -27,54 +40,94 @@ mod database;
|
||||
///
|
||||
/// client.start().await;
|
||||
/// ```
|
||||
async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client, serenity::Error> {
|
||||
let framework = StandardFramework::new()
|
||||
.configure(|c| c
|
||||
.with_whitespace(true)
|
||||
.prefix(prefix));
|
||||
#[allow(deprecated)]
|
||||
async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client> {
|
||||
let framework = StandardFramework::new();
|
||||
framework.configure(Configuration::new().with_whitespace(true).prefix(prefix));
|
||||
|
||||
Client::builder(token)
|
||||
Ok(Client::builder(token, GatewayIntents::all())
|
||||
.event_handler(Handler)
|
||||
.application_id(id)
|
||||
.application_id(ApplicationId::new(id))
|
||||
.framework(framework)
|
||||
.intents(GatewayIntents::all())
|
||||
.register_songbird()
|
||||
.await
|
||||
.await?)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
if let Err(e) = run().await {
|
||||
eprintln!("Application error: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
async fn run() -> Result<()> {
|
||||
// Load config
|
||||
let config = std::fs::read_to_string("./config.toml").expect("Cannot read config file.");
|
||||
let config: Config = toml::from_str(&config).expect("Cannot load config file.");
|
||||
let config = load_config()?;
|
||||
|
||||
let _guard = init_tracing_subscriber(&config.otel_http_url);
|
||||
|
||||
// Create discord client
|
||||
let mut client = create_client(&config.prefix, &config.token, config.application_id).await.expect("Err creating client");
|
||||
let mut client = create_client(&config.prefix, &config.token, config.application_id)
|
||||
.await?;
|
||||
|
||||
// Create GCP TTS client
|
||||
let tts = match TTS::new("./credentials.json".to_string()).await {
|
||||
Ok(tts) => tts,
|
||||
Err(err) => panic!("{}", err)
|
||||
};
|
||||
let tts = GCPTTS::new("./credentials.json".to_string())
|
||||
.await
|
||||
.map_err(|e| NCBError::GCPAuth(e))?;
|
||||
|
||||
let voicevox = VOICEVOX::new(config.voicevox_key);
|
||||
let voicevox = VOICEVOX::new(config.voicevox_key, config.voicevox_original_api_url);
|
||||
|
||||
let database_client = {
|
||||
let redis_client = redis::Client::open(config.redis_url).unwrap();
|
||||
let con = redis_client.get_connection().unwrap();
|
||||
Database::new(con)
|
||||
};
|
||||
let database_client = Database::new_with_url(config.redis_url).await?;
|
||||
|
||||
// Create TTS storage
|
||||
{
|
||||
let mut data = client.data.write().await;
|
||||
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
|
||||
data.insert::<TTSClientData>(Arc::new(Mutex::new((tts, voicevox))));
|
||||
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
|
||||
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
|
||||
data.insert::<DatabaseClientData>(Arc::new(database_client.clone()));
|
||||
}
|
||||
|
||||
info!("Bot initialized.");
|
||||
|
||||
// Run client
|
||||
if let Err(why) = client.start().await {
|
||||
println!("Client error: {:?}", why);
|
||||
}
|
||||
client.start().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load configuration from file or environment variables
|
||||
fn load_config() -> Result<Config> {
|
||||
// Try to load from config file first
|
||||
if let Ok(config_str) = std::fs::read_to_string("./config.toml") {
|
||||
return toml::from_str::<Config>(&config_str)
|
||||
.map_err(|e| NCBError::Toml(e));
|
||||
}
|
||||
|
||||
// Fall back to environment variables
|
||||
let token = env::var("NCB_TOKEN")
|
||||
.map_err(|_| NCBError::missing_env_var("NCB_TOKEN"))?;
|
||||
let application_id_str = env::var("NCB_APP_ID")
|
||||
.map_err(|_| NCBError::missing_env_var("NCB_APP_ID"))?;
|
||||
let prefix = env::var("NCB_PREFIX")
|
||||
.map_err(|_| NCBError::missing_env_var("NCB_PREFIX"))?;
|
||||
let redis_url = env::var("NCB_REDIS_URL")
|
||||
.map_err(|_| NCBError::missing_env_var("NCB_REDIS_URL"))?;
|
||||
|
||||
let application_id = application_id_str.parse::<u64>()
|
||||
.map_err(|_| NCBError::config(format!("Invalid application ID: {}", application_id_str)))?;
|
||||
|
||||
let voicevox_key = env::var("NCB_VOICEVOX_KEY").ok();
|
||||
let voicevox_original_api_url = env::var("NCB_VOICEVOX_ORIGINAL_API_URL").ok();
|
||||
let otel_http_url = env::var("NCB_OTEL_HTTP_URL").ok();
|
||||
|
||||
Ok(Config {
|
||||
token,
|
||||
application_id,
|
||||
prefix,
|
||||
redis_url,
|
||||
voicevox_key,
|
||||
voicevox_original_api_url,
|
||||
otel_http_url,
|
||||
})
|
||||
}
|
||||
|
93
src/stream_input.rs
Normal file
93
src/stream_input.rs
Normal file
@ -0,0 +1,93 @@
|
||||
use async_trait::async_trait;
|
||||
use futures::TryStreamExt;
|
||||
use reqwest::{header::HeaderMap, Client};
|
||||
use symphonia_core::{io::MediaSource, probe::Hint};
|
||||
use tokio_util::compat::FuturesAsyncReadCompatExt;
|
||||
|
||||
use songbird::input::{
|
||||
AsyncAdapterStream, AsyncReadOnlySource, AudioStream, AudioStreamError, Compose, Input,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Mp3Request {
|
||||
client: Client,
|
||||
request: String,
|
||||
headers: HeaderMap,
|
||||
}
|
||||
|
||||
impl Mp3Request {
|
||||
#[must_use]
|
||||
pub fn new(client: Client, request: String) -> Self {
|
||||
Self::new_with_headers(client, request, HeaderMap::default())
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn new_with_headers(client: Client, request: String, headers: HeaderMap) -> Self {
|
||||
Mp3Request {
|
||||
client,
|
||||
request,
|
||||
headers,
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_stream_async(&self) -> Result<AsyncReadOnlySource, AudioStreamError> {
|
||||
let request = self
|
||||
.client
|
||||
.get(&self.request)
|
||||
.headers(self.headers.clone())
|
||||
.build()
|
||||
.map_err(|why| AudioStreamError::Fail(why.into()))?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.execute(request)
|
||||
.await
|
||||
.map_err(|why| AudioStreamError::Fail(why.into()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(AudioStreamError::Fail(
|
||||
format!("HTTP error: {}", response.status()).into(),
|
||||
));
|
||||
}
|
||||
|
||||
let byte_stream = response
|
||||
.bytes_stream()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()));
|
||||
|
||||
let tokio_reader = byte_stream.into_async_read().compat();
|
||||
|
||||
Ok(AsyncReadOnlySource::new(tokio_reader))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Compose for Mp3Request {
|
||||
fn create(&mut self) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
|
||||
Err(AudioStreamError::Fail(
|
||||
"Mp3Request::create must be called in an async context via create_async".into(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn create_async(
|
||||
&mut self,
|
||||
) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
|
||||
let input = self.create_stream_async().await?;
|
||||
let stream = AsyncAdapterStream::new(Box::new(input), 64 * 1024);
|
||||
|
||||
let hint = Hint::new().with_extension("mp3").clone();
|
||||
Ok(AudioStream {
|
||||
input: Box::new(stream) as Box<dyn MediaSource>,
|
||||
hint: Some(hint),
|
||||
})
|
||||
}
|
||||
|
||||
fn should_create_async(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Mp3Request> for Input {
|
||||
fn from(val: Mp3Request) -> Self {
|
||||
Input::Lazy(Box::new(val))
|
||||
}
|
||||
}
|
128
src/trace.rs
Normal file
128
src/trace.rs
Normal file
@ -0,0 +1,128 @@
|
||||
use opentelemetry::{
|
||||
global,
|
||||
trace::{SamplingDecision, SamplingResult, TraceContextExt, TraceState, TracerProvider as _},
|
||||
KeyValue,
|
||||
};
|
||||
use opentelemetry_otlp::{Protocol, WithExportConfig};
|
||||
use opentelemetry_sdk::{
|
||||
metrics::{MeterProviderBuilder, PeriodicReader, SdkMeterProvider},
|
||||
trace::{RandomIdGenerator, SdkTracerProvider, ShouldSample},
|
||||
Resource,
|
||||
};
|
||||
use tracing::Level;
|
||||
use tracing_opentelemetry::{MetricsLayer, OpenTelemetryLayer};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct FilterSampler;
|
||||
|
||||
impl ShouldSample for FilterSampler {
|
||||
fn should_sample(
|
||||
&self,
|
||||
parent_context: Option<&opentelemetry::Context>,
|
||||
_trace_id: opentelemetry::TraceId,
|
||||
name: &str,
|
||||
_span_kind: &opentelemetry::trace::SpanKind,
|
||||
_attributes: &[KeyValue],
|
||||
_links: &[opentelemetry::trace::Link],
|
||||
) -> opentelemetry::trace::SamplingResult {
|
||||
let decision = if name == "dispatch" || name == "recv_event" {
|
||||
SamplingDecision::Drop
|
||||
} else {
|
||||
SamplingDecision::RecordAndSample
|
||||
};
|
||||
|
||||
SamplingResult {
|
||||
decision,
|
||||
attributes: vec![],
|
||||
trace_state: match parent_context {
|
||||
Some(ctx) => ctx.span().span_context().trace_state().clone(),
|
||||
None => TraceState::default(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn resource() -> Resource {
|
||||
Resource::builder().with_service_name("ncb-tts-r2").build()
|
||||
}
|
||||
|
||||
fn init_meter_provider(url: &str) -> SdkMeterProvider {
|
||||
let exporter = opentelemetry_otlp::MetricExporter::builder()
|
||||
.with_http()
|
||||
.with_endpoint(url)
|
||||
.with_protocol(Protocol::HttpBinary)
|
||||
.with_temporality(opentelemetry_sdk::metrics::Temporality::default())
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let reader = PeriodicReader::builder(exporter)
|
||||
.with_interval(std::time::Duration::from_secs(5))
|
||||
.build();
|
||||
|
||||
let stdout_reader =
|
||||
PeriodicReader::builder(opentelemetry_stdout::MetricExporter::default()).build();
|
||||
|
||||
let meter_provider = MeterProviderBuilder::default()
|
||||
.with_resource(resource())
|
||||
.with_reader(reader)
|
||||
.with_reader(stdout_reader)
|
||||
.build();
|
||||
|
||||
global::set_meter_provider(meter_provider.clone());
|
||||
|
||||
meter_provider
|
||||
}
|
||||
|
||||
fn init_tracer_provider(url: &str) -> SdkTracerProvider {
|
||||
let exporter = opentelemetry_otlp::SpanExporter::builder()
|
||||
.with_http()
|
||||
.with_endpoint(url)
|
||||
.with_protocol(Protocol::HttpBinary)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
SdkTracerProvider::builder()
|
||||
.with_sampler(FilterSampler)
|
||||
.with_id_generator(RandomIdGenerator::default())
|
||||
.with_resource(resource())
|
||||
.with_batch_exporter(exporter)
|
||||
.build()
|
||||
}
|
||||
|
||||
pub fn init_tracing_subscriber(otel_http_url: &Option<String>) -> OtelGuard {
|
||||
let registry = tracing_subscriber::registry()
|
||||
.with(tracing_subscriber::filter::LevelFilter::from_level(
|
||||
Level::INFO,
|
||||
))
|
||||
.with(tracing_subscriber::fmt::layer());
|
||||
|
||||
if let Some(url) = otel_http_url {
|
||||
let tracer_provider = init_tracer_provider(url);
|
||||
let meter_provider = init_meter_provider(url);
|
||||
|
||||
let tracer = tracer_provider.tracer("ncb-tts-r2");
|
||||
|
||||
registry
|
||||
.with(MetricsLayer::new(meter_provider.clone()))
|
||||
.with(OpenTelemetryLayer::new(tracer))
|
||||
.init();
|
||||
|
||||
OtelGuard {
|
||||
_tracer_provider: Some(tracer_provider),
|
||||
_meter_provider: Some(meter_provider),
|
||||
}
|
||||
} else {
|
||||
registry.init();
|
||||
|
||||
OtelGuard {
|
||||
_tracer_provider: None,
|
||||
_meter_provider: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OtelGuard {
|
||||
_tracer_provider: Option<SdkTracerProvider>,
|
||||
_meter_provider: Option<SdkMeterProvider>,
|
||||
}
|
@ -1,34 +1,42 @@
|
||||
use gcp_auth::Token;
|
||||
use crate::tts::gcp_tts::structs::{
|
||||
synthesize_request::SynthesizeRequest,
|
||||
synthesize_response::SynthesizeResponse,
|
||||
synthesize_request::SynthesizeRequest, synthesize_response::SynthesizeResponse,
|
||||
};
|
||||
use gcp_auth::Token;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TTS {
|
||||
pub token: Token,
|
||||
pub credentials_path: String
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GCPTTS {
|
||||
pub token: Arc<RwLock<Token>>,
|
||||
pub credentials_path: String,
|
||||
}
|
||||
|
||||
impl TTS {
|
||||
|
||||
pub async fn update_token(&mut self) -> Result<(), gcp_auth::Error> {
|
||||
if self.token.has_expired() {
|
||||
let authenticator = gcp_auth::from_credentials_file(self.credentials_path.clone()).await?;
|
||||
let token = authenticator.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
|
||||
self.token = token;
|
||||
impl GCPTTS {
|
||||
#[tracing::instrument]
|
||||
pub async fn update_token(&self) -> Result<(), gcp_auth::Error> {
|
||||
let mut token = self.token.write().await;
|
||||
if token.has_expired() {
|
||||
let authenticator =
|
||||
gcp_auth::from_credentials_file(self.credentials_path.clone()).await?;
|
||||
let new_token = authenticator
|
||||
.get_token(&["https://www.googleapis.com/auth/cloud-platform"])
|
||||
.await?;
|
||||
*token = new_token;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn new(credentials_path: String) -> Result<TTS, gcp_auth::Error> {
|
||||
#[tracing::instrument]
|
||||
pub async fn new(credentials_path: String) -> Result<Self, gcp_auth::Error> {
|
||||
let authenticator = gcp_auth::from_credentials_file(credentials_path.clone()).await?;
|
||||
let token = authenticator.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
|
||||
let token = authenticator
|
||||
.get_token(&["https://www.googleapis.com/auth/cloud-platform"])
|
||||
.await?;
|
||||
|
||||
Ok(TTS {
|
||||
token,
|
||||
credentials_path
|
||||
Ok(Self {
|
||||
token: Arc::new(RwLock::new(token)),
|
||||
credentials_path,
|
||||
})
|
||||
}
|
||||
|
||||
@ -53,19 +61,37 @@ impl TTS {
|
||||
/// }
|
||||
/// }).await.unwrap();
|
||||
/// ```
|
||||
pub async fn synthesize(&mut self, request: SynthesizeRequest) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||
#[tracing::instrument]
|
||||
pub async fn synthesize(
|
||||
&self,
|
||||
request: SynthesizeRequest,
|
||||
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||
self.update_token().await.unwrap();
|
||||
let client = reqwest::Client::new();
|
||||
match client.post("https://texttospeech.googleapis.com/v1/text:synthesize")
|
||||
|
||||
let token_string = {
|
||||
let token = self.token.read().await;
|
||||
token.as_str().to_string()
|
||||
};
|
||||
|
||||
match client
|
||||
.post("https://texttospeech.googleapis.com/v1/text:synthesize")
|
||||
.header(reqwest::header::CONTENT_TYPE, "application/json")
|
||||
.header(reqwest::header::AUTHORIZATION, format!("Bearer {}", self.token.as_str()))
|
||||
.header(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {}", token_string),
|
||||
)
|
||||
.body(serde_json::to_string(&request).unwrap())
|
||||
.send().await {
|
||||
Ok(ok) => {
|
||||
let response: SynthesizeResponse = serde_json::from_str(&ok.text().await.expect("")).unwrap();
|
||||
Ok(base64::decode(response.audioContent).unwrap()[..].to_vec())
|
||||
},
|
||||
Err(err) => Err(Box::new(err))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(ok) => {
|
||||
let response: SynthesizeResponse =
|
||||
serde_json::from_str(&ok.text().await.expect("")).unwrap();
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
Ok(general_purpose::STANDARD.decode(response.audioContent).unwrap())
|
||||
}
|
||||
Err(err) => Err(Box::new(err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,2 +1,2 @@
|
||||
pub mod gcp_tts;
|
||||
pub mod structs;
|
||||
pub mod structs;
|
||||
|
@ -1,17 +1,19 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Example:
|
||||
/// ```rust
|
||||
/// use ncb_tts_r2::tts::gcp_tts::structs::audio_config::AudioConfig;
|
||||
///
|
||||
/// AudioConfig {
|
||||
/// audioEncoding: String::from("mp3"),
|
||||
/// speakingRate: 1.2f32,
|
||||
/// pitch: 1.0f32
|
||||
/// }
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[allow(non_snake_case)]
|
||||
pub struct AudioConfig {
|
||||
pub audioEncoding: String,
|
||||
pub speakingRate: f32,
|
||||
pub pitch: f32
|
||||
}
|
||||
pub pitch: f32,
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
pub mod audio_config;
|
||||
pub mod synthesis_input;
|
||||
pub mod synthesize_request;
|
||||
pub mod synthesize_response;
|
||||
pub mod voice_selection_params;
|
||||
pub mod synthesize_response;
|
@ -1,14 +1,16 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Example:
|
||||
/// ```rust
|
||||
/// use ncb_tts_r2::tts::gcp_tts::structs::synthesis_input::SynthesisInput;
|
||||
///
|
||||
/// SynthesisInput {
|
||||
/// text: None,
|
||||
/// ssml: Some(String::from("<speak>test</speak>"))
|
||||
/// }
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)]
|
||||
pub struct SynthesisInput {
|
||||
pub text: Option<String>,
|
||||
pub ssml: Option<String>
|
||||
}
|
||||
pub ssml: Option<String>,
|
||||
}
|
||||
|
@ -1,9 +1,8 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::tts::gcp_tts::structs::{
|
||||
synthesis_input::SynthesisInput,
|
||||
audio_config::AudioConfig,
|
||||
audio_config::AudioConfig, synthesis_input::SynthesisInput,
|
||||
voice_selection_params::VoiceSelectionParams,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Example:
|
||||
/// ```rust
|
||||
@ -24,10 +23,10 @@ use crate::tts::gcp_tts::structs::{
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[allow(non_snake_case)]
|
||||
pub struct SynthesizeRequest {
|
||||
pub input: SynthesisInput,
|
||||
pub voice: VoiceSelectionParams,
|
||||
pub audioConfig: AudioConfig,
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[allow(non_snake_case)]
|
||||
pub struct SynthesizeResponse {
|
||||
pub audioContent: String
|
||||
}
|
||||
pub audioContent: String,
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Example:
|
||||
/// ```rust
|
||||
@ -8,10 +8,10 @@ use serde::{Serialize, Deserialize};
|
||||
/// ssmlGender: String::from("neutral")
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
|
||||
#[allow(non_snake_case)]
|
||||
pub struct VoiceSelectionParams {
|
||||
pub languageCode: String,
|
||||
pub name: String,
|
||||
pub ssmlGender: String
|
||||
}
|
||||
pub ssmlGender: String,
|
||||
}
|
||||
|
@ -1,32 +1,187 @@
|
||||
use serenity::{model::{channel::Message, id::{ChannelId, GuildId}}, prelude::Context};
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::{tts::message::TTSMessage};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serenity::{
|
||||
model::{
|
||||
channel::Message,
|
||||
id::{ChannelId, GuildId},
|
||||
},
|
||||
prelude::Context,
|
||||
};
|
||||
|
||||
use crate::tts::message::TTSMessage;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TTSInstance {
|
||||
#[serde(skip)] // Messageは複雑すぎるのでシリアライズしない
|
||||
pub before_message: Option<Message>,
|
||||
pub text_channel: ChannelId,
|
||||
pub text_channels: Vec<ChannelId>,
|
||||
pub voice_channel: ChannelId,
|
||||
pub guild: GuildId
|
||||
pub guild: GuildId,
|
||||
}
|
||||
|
||||
impl TTSInstance {
|
||||
/// Create a new TTSInstance
|
||||
pub fn new(text_channels: Vec<ChannelId>, voice_channel: ChannelId, guild: GuildId) -> Self {
|
||||
Self {
|
||||
before_message: None,
|
||||
text_channels,
|
||||
voice_channel,
|
||||
guild,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new TTSInstance with a single text channel
|
||||
pub fn new_single(text_channel: ChannelId, voice_channel: ChannelId, guild: GuildId) -> Self {
|
||||
Self::new(vec![text_channel], voice_channel, guild)
|
||||
}
|
||||
|
||||
/// Add a text channel to the instance
|
||||
pub fn add_text_channel(&mut self, channel_id: ChannelId) {
|
||||
if !self.text_channels.contains(&channel_id) {
|
||||
self.text_channels.push(channel_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a text channel from the instance
|
||||
pub fn remove_text_channel(&mut self, channel_id: ChannelId) -> bool {
|
||||
if let Some(pos) = self.text_channels.iter().position(|&x| x == channel_id) {
|
||||
self.text_channels.remove(pos);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a channel is in the text channels list
|
||||
pub fn contains_text_channel(&self, channel_id: ChannelId) -> bool {
|
||||
self.text_channels.contains(&channel_id)
|
||||
}
|
||||
|
||||
/// Get all text channels
|
||||
pub fn get_text_channels(&self) -> &Vec<ChannelId> {
|
||||
&self.text_channels
|
||||
}
|
||||
|
||||
pub async fn check_connection(&self, ctx: &Context) -> bool {
|
||||
let manager = match songbird::get(ctx).await {
|
||||
Some(manager) => manager,
|
||||
None => {
|
||||
tracing::error!("Cannot get songbird manager");
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
let call = manager.get(self.guild);
|
||||
if let Some(call) = call {
|
||||
if let Some(connection) = call.lock().await.current_connection() {
|
||||
connection.channel_id.is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Reconnect to the voice channel after bot restart
|
||||
#[tracing::instrument]
|
||||
pub async fn reconnect(
|
||||
&self,
|
||||
ctx: &Context,
|
||||
skip_check: bool,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let manager = songbird::get(&ctx)
|
||||
.await
|
||||
.ok_or("Songbird manager not available")?;
|
||||
|
||||
// Check if we're already connected
|
||||
if self.check_connection(&ctx).await {
|
||||
tracing::info!("Already connected to guild {}", self.guild);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Try to connect to the voice channel
|
||||
match manager.join(self.guild, self.voice_channel).await {
|
||||
Ok(_) => {
|
||||
tracing::info!(
|
||||
"Successfully reconnected to voice channel {} in guild {}",
|
||||
self.voice_channel,
|
||||
self.guild
|
||||
);
|
||||
|
||||
// Double-check if there are users in the voice channel after connection
|
||||
match self.guild.channels(&ctx.http).await {
|
||||
Ok(channels) => {
|
||||
if let Some(channel) = channels.get(&self.voice_channel) {
|
||||
match channel.members(&ctx.cache) {
|
||||
Ok(members) => {
|
||||
let user_count =
|
||||
members.iter().filter(|member| !member.user.bot).count();
|
||||
if user_count == 0 {
|
||||
tracing::info!("No users found in voice channel after reconnection, disconnecting from guild {}", self.guild);
|
||||
// Disconnect if no users are present
|
||||
let _ = manager.remove(self.guild).await;
|
||||
return Err(
|
||||
"No users in voice channel after reconnection".into()
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!(
|
||||
"Failed to verify members after reconnection for guild {}",
|
||||
self.guild
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::warn!(
|
||||
"Failed to get channels after reconnection for guild {}",
|
||||
self.guild
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to reconnect to voice channel: {}", e);
|
||||
Err(Box::new(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Synthesize text to speech and send it to the voice channel.
|
||||
///
|
||||
/// Example:
|
||||
/// ```rust
|
||||
/// instance.read(message, &ctx).await;
|
||||
/// ```
|
||||
#[tracing::instrument]
|
||||
pub async fn read<T>(&mut self, message: T, ctx: &Context)
|
||||
where T: TTSMessage
|
||||
where
|
||||
T: TTSMessage + Debug,
|
||||
{
|
||||
let path = message.synthesize(self, ctx).await;
|
||||
let audio = message.synthesize(self, ctx).await;
|
||||
|
||||
{
|
||||
let manager = songbird::get(&ctx).await.unwrap();
|
||||
let call = manager.get(self.guild).unwrap();
|
||||
let mut call = call.lock().await;
|
||||
let input = songbird::input::ffmpeg(path).await.expect("File not found.");
|
||||
call.enqueue_source(input);
|
||||
for audio in audio {
|
||||
call.enqueue(audio.into()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn skip(&mut self, ctx: &Context) {
|
||||
let manager = songbird::get(&ctx).await.unwrap();
|
||||
let call = manager.get(self.guild).unwrap();
|
||||
let call = call.lock().await;
|
||||
let queue = call.queue();
|
||||
let _ = queue.skip();
|
||||
}
|
||||
}
|
||||
|
@ -1,16 +1,17 @@
|
||||
use std::{path::Path, fs::File, io::Write, env};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serenity::prelude::Context;
|
||||
use songbird::tracks::Track;
|
||||
|
||||
use crate::{tts::instance::TTSInstance, data::TTSClientData};
|
||||
use crate::{data::TTSClientData, tts::instance::TTSInstance};
|
||||
|
||||
use super::gcp_tts::structs::{synthesize_request::SynthesizeRequest, synthesis_input::SynthesisInput, audio_config::AudioConfig, voice_selection_params::VoiceSelectionParams};
|
||||
use super::gcp_tts::structs::{
|
||||
audio_config::AudioConfig, synthesis_input::SynthesisInput,
|
||||
synthesize_request::SynthesizeRequest, voice_selection_params::VoiceSelectionParams,
|
||||
};
|
||||
|
||||
/// Message trait that can be used to synthesize text to speech.
|
||||
#[async_trait]
|
||||
pub trait TTSMessage {
|
||||
|
||||
/// Parse the message for synthesis.
|
||||
///
|
||||
/// Example:
|
||||
@ -19,57 +20,57 @@ pub trait TTSMessage {
|
||||
/// ```
|
||||
async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String;
|
||||
|
||||
/// Synthesize the message and returns the path to the audio file.
|
||||
/// Synthesize the message and returns the audio data.
|
||||
///
|
||||
/// Example:
|
||||
/// ```rust
|
||||
/// let path = message.synthesize(instance, ctx).await;
|
||||
/// let audio = message.synthesize(instance, ctx).await;
|
||||
/// ```
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String;
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AnnounceMessage {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TTSMessage for AnnounceMessage {
|
||||
async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
|
||||
async fn parse(&self, instance: &mut TTSInstance, _ctx: &Context) -> String {
|
||||
instance.before_message = None;
|
||||
format!(r#"<speak>アナウンス<break time="200ms"/>{}</speak>"#, self.message)
|
||||
format!(
|
||||
r#"<speak>アナウンス<break time="200ms"/>{}</speak>"#,
|
||||
self.message
|
||||
)
|
||||
}
|
||||
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track> {
|
||||
let text = self.parse(instance, ctx).await;
|
||||
let data_read = ctx.data.read().await;
|
||||
let storage = data_read.get::<TTSClientData>().expect("Cannot get TTSClientStorage").clone();
|
||||
let mut storage = storage.lock().await;
|
||||
let tts = data_read
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientStorage");
|
||||
|
||||
let audio = storage.0.synthesize(SynthesizeRequest {
|
||||
input: SynthesisInput {
|
||||
text: None,
|
||||
ssml: Some(text)
|
||||
},
|
||||
voice: VoiceSelectionParams {
|
||||
languageCode: String::from("ja-JP"),
|
||||
name: String::from("ja-JP-Wavenet-B"),
|
||||
ssmlGender: String::from("neutral")
|
||||
},
|
||||
audioConfig: AudioConfig {
|
||||
audioEncoding: String::from("mp3"),
|
||||
speakingRate: 1.2f32,
|
||||
pitch: 1.0f32
|
||||
}
|
||||
}).await.unwrap();
|
||||
let audio = tts
|
||||
.synthesize_gcp(SynthesizeRequest {
|
||||
input: SynthesisInput {
|
||||
text: None,
|
||||
ssml: Some(text),
|
||||
},
|
||||
voice: VoiceSelectionParams {
|
||||
languageCode: String::from("ja-JP"),
|
||||
name: String::from("ja-JP-Wavenet-B"),
|
||||
ssmlGender: String::from("neutral"),
|
||||
},
|
||||
audioConfig: AudioConfig {
|
||||
audioEncoding: String::from("mp3"),
|
||||
speakingRate: 1.2f32,
|
||||
pitch: 1.0f32,
|
||||
},
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let uuid = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let path = env::current_dir().unwrap();
|
||||
let file_path = path.join("audio").join(format!("{}.mp3", uuid));
|
||||
|
||||
let mut file = File::create(file_path.clone()).unwrap();
|
||||
file.write(&audio).unwrap();
|
||||
|
||||
file_path.into_os_string().into_string().unwrap()
|
||||
vec![audio.into()]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
pub mod gcp_tts;
|
||||
pub mod voicevox;
|
||||
pub mod instance;
|
||||
pub mod message;
|
||||
pub mod tts;
|
||||
pub mod tts_type;
|
||||
pub mod instance;
|
||||
pub mod voicevox;
|
||||
|
559
src/tts/tts.rs
Normal file
559
src/tts/tts.rs
Normal file
@ -0,0 +1,559 @@
|
||||
use std::sync::RwLock;
|
||||
use std::{num::NonZeroUsize, sync::Arc};
|
||||
|
||||
use lru::LruCache;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use songbird::{driver::Bitrate, input::cached::Compressed, tracks::Track};
|
||||
use tracing::{debug, error, info, instrument, warn};
|
||||
|
||||
use crate::{
|
||||
errors::{constants::*, NCBError, Result},
|
||||
utils::{retry_with_backoff, CircuitBreaker, PerformanceMetrics},
|
||||
};
|
||||
|
||||
use super::{
|
||||
gcp_tts::{
|
||||
gcp_tts::GCPTTS,
|
||||
structs::{
|
||||
synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest,
|
||||
voice_selection_params::VoiceSelectionParams,
|
||||
},
|
||||
},
|
||||
voicevox::voicevox::VOICEVOX,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TTS {
|
||||
pub voicevox_client: VOICEVOX,
|
||||
gcp_tts_client: GCPTTS,
|
||||
cache: Arc<RwLock<LruCache<CacheKey, Compressed>>>,
|
||||
voicevox_circuit_breaker: Arc<RwLock<CircuitBreaker>>,
|
||||
gcp_circuit_breaker: Arc<RwLock<CircuitBreaker>>,
|
||||
metrics: Arc<PerformanceMetrics>,
|
||||
cache_persistence_path: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq, Clone, Serialize, Deserialize, Debug)]
|
||||
pub enum CacheKey {
|
||||
Voicevox(String, i64),
|
||||
GCP(SynthesisInput, VoiceSelectionParams),
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
struct CacheEntry {
|
||||
key: CacheKey,
|
||||
data: Vec<u8>,
|
||||
created_at: std::time::SystemTime,
|
||||
access_count: u64,
|
||||
}
|
||||
|
||||
impl TTS {
|
||||
pub fn new(voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS) -> Self {
|
||||
let tts = Self {
|
||||
voicevox_client,
|
||||
gcp_tts_client,
|
||||
cache: Arc::new(RwLock::new(LruCache::new(
|
||||
NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap(),
|
||||
))),
|
||||
voicevox_circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::default())),
|
||||
gcp_circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::default())),
|
||||
metrics: Arc::new(PerformanceMetrics::new()),
|
||||
cache_persistence_path: Some("./tts_cache.bin".to_string()),
|
||||
};
|
||||
|
||||
// Try to load persisted cache
|
||||
if let Err(e) = tts.load_cache() {
|
||||
warn!(error = %e, "Failed to load persisted cache");
|
||||
}
|
||||
|
||||
tts
|
||||
}
|
||||
|
||||
pub fn with_cache_path(mut self, path: Option<String>) -> Self {
|
||||
self.cache_persistence_path = path;
|
||||
self
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
pub async fn synthesize_voicevox(
|
||||
&self,
|
||||
text: &str,
|
||||
speaker: i64,
|
||||
) -> std::result::Result<Track, NCBError> {
|
||||
self.metrics.increment_tts_requests();
|
||||
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
|
||||
|
||||
let cached_audio = {
|
||||
let mut cache_guard = self.cache.write().unwrap();
|
||||
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
|
||||
};
|
||||
|
||||
if let Some(audio) = cached_audio {
|
||||
debug!("Cache hit for VOICEVOX TTS");
|
||||
self.metrics.increment_tts_cache_hits();
|
||||
return Ok(audio.into());
|
||||
}
|
||||
|
||||
debug!("Cache miss for VOICEVOX TTS");
|
||||
self.metrics.increment_tts_cache_misses();
|
||||
|
||||
// Check circuit breaker
|
||||
{
|
||||
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
|
||||
circuit_breaker.try_half_open();
|
||||
|
||||
if !circuit_breaker.can_execute() {
|
||||
return Err(NCBError::voicevox("Circuit breaker is open"));
|
||||
}
|
||||
}
|
||||
|
||||
let synthesis_result = if self.voicevox_client.original_api_url.is_some() {
|
||||
retry_with_backoff(
|
||||
|| async {
|
||||
match self
|
||||
.voicevox_client
|
||||
.synthesize_original(text.to_string(), speaker)
|
||||
.await
|
||||
{
|
||||
Ok(audio) => Ok(audio),
|
||||
Err(e) => Err(NCBError::voicevox(format!(
|
||||
"VOICEVOX synthesis failed: {}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
},
|
||||
3,
|
||||
std::time::Duration::from_millis(500),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
retry_with_backoff(
|
||||
|| async {
|
||||
match self
|
||||
.voicevox_client
|
||||
.synthesize_stream(text.to_string(), speaker)
|
||||
.await
|
||||
{
|
||||
Ok(_mp3_request) => Err(NCBError::voicevox(
|
||||
"Stream synthesis not yet fully implemented",
|
||||
)),
|
||||
Err(e) => Err(NCBError::voicevox(format!(
|
||||
"VOICEVOX synthesis failed: {}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
},
|
||||
3,
|
||||
std::time::Duration::from_millis(500),
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
match synthesis_result {
|
||||
Ok(audio) => {
|
||||
// Update circuit breaker on success
|
||||
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
|
||||
circuit_breaker.on_success();
|
||||
drop(circuit_breaker);
|
||||
|
||||
// Cache the audio asynchronously
|
||||
let cache = self.cache.clone();
|
||||
let cache_key_clone = cache_key.clone();
|
||||
let audio_for_cache = audio.clone();
|
||||
tokio::spawn(async move {
|
||||
debug!("Compressing and caching VOICEVOX audio");
|
||||
if let Ok(compressed) =
|
||||
Compressed::new(audio_for_cache.into(), Bitrate::Auto).await
|
||||
{
|
||||
let mut cache_guard = cache.write().unwrap();
|
||||
cache_guard.put(cache_key_clone, compressed);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(audio.into())
|
||||
}
|
||||
Err(e) => {
|
||||
// Update circuit breaker on failure
|
||||
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
|
||||
circuit_breaker.on_failure();
|
||||
drop(circuit_breaker);
|
||||
|
||||
error!(error = %e, "VOICEVOX synthesis failed");
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn synthesize_gcp(
|
||||
&self,
|
||||
synthesize_request: SynthesizeRequest,
|
||||
) -> std::result::Result<Track, NCBError> {
|
||||
self.metrics.increment_tts_requests();
|
||||
let cache_key = CacheKey::GCP(
|
||||
synthesize_request.input.clone(),
|
||||
synthesize_request.voice.clone(),
|
||||
);
|
||||
|
||||
let cached_audio = {
|
||||
let mut cache_guard = self.cache.write().unwrap();
|
||||
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
|
||||
};
|
||||
|
||||
if let Some(audio) = cached_audio {
|
||||
debug!("Cache hit for GCP TTS");
|
||||
self.metrics.increment_tts_cache_hits();
|
||||
return Ok(audio.into());
|
||||
}
|
||||
|
||||
debug!("Cache miss for GCP TTS");
|
||||
self.metrics.increment_tts_cache_misses();
|
||||
|
||||
// Check circuit breaker
|
||||
{
|
||||
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
|
||||
circuit_breaker.try_half_open();
|
||||
|
||||
if !circuit_breaker.can_execute() {
|
||||
return Err(NCBError::tts_synthesis("GCP TTS circuit breaker is open"));
|
||||
}
|
||||
}
|
||||
|
||||
let request_clone = SynthesizeRequest {
|
||||
input: synthesize_request.input.clone(),
|
||||
voice: synthesize_request.voice.clone(),
|
||||
audioConfig: synthesize_request.audioConfig.clone(),
|
||||
};
|
||||
|
||||
let audio = {
|
||||
let audio_result = retry_with_backoff(
|
||||
|| async {
|
||||
match self.gcp_tts_client.synthesize(request_clone.clone()).await {
|
||||
Ok(audio) => Ok(audio),
|
||||
Err(e) => Err(NCBError::tts_synthesis(format!(
|
||||
"GCP TTS synthesis failed: {}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
},
|
||||
3,
|
||||
std::time::Duration::from_millis(500),
|
||||
)
|
||||
.await;
|
||||
|
||||
match audio_result {
|
||||
Ok(audio) => audio,
|
||||
Err(e) => {
|
||||
// Update circuit breaker on failure
|
||||
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
|
||||
circuit_breaker.on_failure();
|
||||
drop(circuit_breaker);
|
||||
|
||||
error!(error = %e, "GCP TTS synthesis failed");
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Update circuit breaker on success
|
||||
{
|
||||
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
|
||||
circuit_breaker.on_success();
|
||||
}
|
||||
|
||||
match Compressed::new(audio.into(), Bitrate::Auto).await {
|
||||
Ok(compressed) => {
|
||||
// Cache the compressed audio
|
||||
{
|
||||
let mut cache_guard = self.cache.write().unwrap();
|
||||
cache_guard.put(cache_key, compressed.clone());
|
||||
}
|
||||
|
||||
// Persist cache asynchronously
|
||||
if let Some(path) = &self.cache_persistence_path {
|
||||
let cache_clone = self.cache.clone();
|
||||
let path_clone = path.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = Self::persist_cache_to_file(&cache_clone, &path_clone) {
|
||||
warn!(error = %e, "Failed to persist cache");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(compressed.into())
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = %e, "Failed to compress GCP audio");
|
||||
Err(NCBError::tts_synthesis(format!(
|
||||
"Audio compression failed: {}",
|
||||
e
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load cache from persistent storage
|
||||
fn load_cache(&self) -> Result<()> {
|
||||
if let Some(path) = &self.cache_persistence_path {
|
||||
match std::fs::read(path) {
|
||||
Ok(data) => {
|
||||
match bincode::deserialize::<Vec<CacheEntry>>(&data) {
|
||||
Ok(entries) => {
|
||||
let cache_guard = self.cache.read().unwrap();
|
||||
let now = std::time::SystemTime::now();
|
||||
|
||||
for entry in entries {
|
||||
// Skip expired entries (older than 24 hours)
|
||||
if let Ok(age) = now.duration_since(entry.created_at) {
|
||||
if age.as_secs() < CACHE_TTL_SECS {
|
||||
debug!("Loaded cache entry from disk");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Loaded {} cache entries from disk", cache_guard.len());
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Failed to deserialize cache data");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
debug!("No existing cache file found");
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Failed to read cache file");
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Persist cache to storage (simplified implementation)
|
||||
fn persist_cache_to_file(
|
||||
cache: &Arc<RwLock<LruCache<CacheKey, Compressed>>>,
|
||||
path: &str,
|
||||
) -> Result<()> {
|
||||
// Note: This is a simplified implementation
|
||||
let _cache_guard = cache.read().unwrap();
|
||||
let entries: Vec<CacheEntry> = Vec::new(); // Placeholder for actual implementation
|
||||
|
||||
match bincode::serialize(&entries) {
|
||||
Ok(data) => {
|
||||
if let Err(e) = std::fs::write(path, data) {
|
||||
return Err(NCBError::database(format!(
|
||||
"Failed to write cache file: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
debug!("Cache persisted to disk");
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(NCBError::database(format!(
|
||||
"Failed to serialize cache: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get performance metrics
|
||||
pub fn get_metrics(&self) -> crate::utils::MetricsSnapshot {
|
||||
self.metrics.get_stats()
|
||||
}
|
||||
|
||||
/// Clear cache
|
||||
pub fn clear_cache(&self) {
|
||||
let mut cache_guard = self.cache.write().unwrap();
|
||||
cache_guard.clear();
|
||||
info!("TTS cache cleared");
|
||||
}
|
||||
|
||||
/// Get cache statistics
|
||||
pub fn get_cache_stats(&self) -> (usize, usize) {
|
||||
let cache_guard = self.cache.read().unwrap();
|
||||
(cache_guard.len(), cache_guard.cap().get())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD;
|
||||
use crate::tts::gcp_tts::structs::{
|
||||
synthesis_input::SynthesisInput, voice_selection_params::VoiceSelectionParams,
|
||||
};
|
||||
use crate::utils::{CircuitBreakerState, MetricsSnapshot};
|
||||
use std::time::Duration;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_equality() {
|
||||
let input = SynthesisInput {
|
||||
text: None,
|
||||
ssml: Some("Hello".to_string()),
|
||||
};
|
||||
let voice = VoiceSelectionParams {
|
||||
languageCode: "en-US".to_string(),
|
||||
name: "en-US-Wavenet-A".to_string(),
|
||||
ssmlGender: "female".to_string(),
|
||||
};
|
||||
|
||||
let key1 = CacheKey::GCP(input.clone(), voice.clone());
|
||||
let key2 = CacheKey::GCP(input.clone(), voice.clone());
|
||||
let key3 = CacheKey::Voicevox("Hello".to_string(), 1);
|
||||
let key4 = CacheKey::Voicevox("Hello".to_string(), 1);
|
||||
let key5 = CacheKey::Voicevox("Hello".to_string(), 2);
|
||||
|
||||
assert_eq!(key1, key2);
|
||||
assert_eq!(key3, key4);
|
||||
assert_ne!(key3, key5);
|
||||
// Note: Different enum variants are never equal
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_hash() {
|
||||
use std::collections::HashMap;
|
||||
|
||||
let input = SynthesisInput {
|
||||
text: Some("Test".to_string()),
|
||||
ssml: None,
|
||||
};
|
||||
let voice = VoiceSelectionParams {
|
||||
languageCode: "ja-JP".to_string(),
|
||||
name: "ja-JP-Wavenet-B".to_string(),
|
||||
ssmlGender: "neutral".to_string(),
|
||||
};
|
||||
|
||||
let mut map = HashMap::new();
|
||||
let key = CacheKey::GCP(input, voice);
|
||||
map.insert(key.clone(), "test_value");
|
||||
|
||||
assert_eq!(map.get(&key), Some(&"test_value"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_entry_creation() {
|
||||
let data = vec![1, 2, 3, 4, 5];
|
||||
let now = std::time::SystemTime::now();
|
||||
|
||||
let entry = CacheEntry {
|
||||
key: CacheKey::Voicevox("test".to_string(), 1),
|
||||
data: data.clone(),
|
||||
created_at: now,
|
||||
access_count: 0,
|
||||
};
|
||||
|
||||
assert_eq!(entry.key, CacheKey::Voicevox("test".to_string(), 1));
|
||||
assert_eq!(entry.created_at, now);
|
||||
assert_eq!(entry.data, data);
|
||||
assert_eq!(entry.access_count, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_performance_metrics_integration() {
|
||||
// Test metrics functionality with realistic data
|
||||
let metrics = PerformanceMetrics::default();
|
||||
|
||||
// Simulate TTS request pattern
|
||||
for _ in 0..10 {
|
||||
metrics.increment_tts_requests();
|
||||
}
|
||||
|
||||
// Simulate 70% cache hit rate
|
||||
for _ in 0..7 {
|
||||
metrics.increment_tts_cache_hits();
|
||||
}
|
||||
for _ in 0..3 {
|
||||
metrics.increment_tts_cache_misses();
|
||||
}
|
||||
|
||||
let stats = metrics.get_stats();
|
||||
assert_eq!(stats.tts_requests, 10);
|
||||
assert_eq!(stats.tts_cache_hits, 7);
|
||||
assert_eq!(stats.tts_cache_misses, 3);
|
||||
|
||||
let hit_rate = stats.tts_cache_hit_rate();
|
||||
assert!((hit_rate - 0.7).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circuit_breaker_state_transitions() {
|
||||
let mut cb = CircuitBreaker::new(2, Duration::from_millis(100));
|
||||
|
||||
// Initially closed
|
||||
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||
assert!(cb.can_execute());
|
||||
|
||||
// First failure
|
||||
cb.on_failure();
|
||||
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||
assert_eq!(cb.failure_count, 1);
|
||||
|
||||
// Second failure opens circuit
|
||||
cb.on_failure();
|
||||
assert_eq!(cb.state, CircuitBreakerState::Open);
|
||||
assert!(!cb.can_execute());
|
||||
|
||||
// Wait and try half-open
|
||||
std::thread::sleep(Duration::from_millis(150));
|
||||
cb.try_half_open();
|
||||
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
|
||||
assert!(cb.can_execute());
|
||||
|
||||
// Success closes circuit
|
||||
cb.on_success();
|
||||
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||
assert_eq!(cb.failure_count, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_persistence_setup() {
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let cache_path = temp_dir
|
||||
.path()
|
||||
.join("test_cache.bin")
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
|
||||
// Test cache path configuration
|
||||
assert!(!cache_path.is_empty());
|
||||
assert!(cache_path.ends_with("test_cache.bin"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metrics_snapshot_calculations() {
|
||||
let snapshot = MetricsSnapshot {
|
||||
tts_requests: 20,
|
||||
tts_cache_hits: 15,
|
||||
tts_cache_misses: 5,
|
||||
regex_cache_hits: 8,
|
||||
regex_cache_misses: 2,
|
||||
database_operations: 30,
|
||||
voice_connections: 5,
|
||||
};
|
||||
|
||||
// Test TTS cache hit rate
|
||||
let tts_hit_rate = snapshot.tts_cache_hit_rate();
|
||||
assert!((tts_hit_rate - 0.75).abs() < f64::EPSILON);
|
||||
|
||||
// Test regex cache hit rate
|
||||
let regex_hit_rate = snapshot.regex_cache_hit_rate();
|
||||
assert!((regex_hit_rate - 0.8).abs() < f64::EPSILON);
|
||||
|
||||
// Test edge case with no operations
|
||||
let empty_snapshot = MetricsSnapshot {
|
||||
tts_requests: 0,
|
||||
tts_cache_hits: 0,
|
||||
tts_cache_misses: 0,
|
||||
regex_cache_hits: 0,
|
||||
regex_cache_misses: 0,
|
||||
database_operations: 0,
|
||||
voice_connections: 0,
|
||||
};
|
||||
|
||||
assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0);
|
||||
assert_eq!(empty_snapshot.regex_cache_hit_rate(), 0.0);
|
||||
}
|
||||
}
|
@ -3,5 +3,5 @@ use serde::{Deserialize, Serialize};
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub enum TTSType {
|
||||
GCP,
|
||||
VOICEVOX
|
||||
}
|
||||
VOICEVOX,
|
||||
}
|
||||
|
@ -1,2 +1,2 @@
|
||||
pub mod structs;
|
||||
pub mod voicevox;
|
||||
pub mod voicevox;
|
||||
|
@ -1,4 +1,4 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::mora::Mora;
|
||||
|
||||
@ -7,5 +7,5 @@ pub struct AccentPhrase {
|
||||
pub moras: Vec<Mora>,
|
||||
pub accent: f64,
|
||||
pub pause_mora: Option<Mora>,
|
||||
pub is_interrogative: bool
|
||||
}
|
||||
pub is_interrogative: bool,
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::accent_phrase::AccentPhrase;
|
||||
|
||||
@ -14,5 +14,5 @@ pub struct AudioQuery {
|
||||
pub postPhonemeLength: f64,
|
||||
pub outputSamplingRate: f64,
|
||||
pub outputStereo: bool,
|
||||
pub kana: Option<String>
|
||||
}
|
||||
pub kana: Option<String>,
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
pub mod mora;
|
||||
pub mod accent_phrase;
|
||||
pub mod audio_query;
|
||||
pub mod accent_phrase;
|
||||
pub mod mora;
|
||||
pub mod speaker;
|
||||
pub mod stream;
|
||||
|
@ -1,4 +1,4 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Mora {
|
||||
@ -7,5 +7,5 @@ pub struct Mora {
|
||||
pub consonant_length: Option<f64>,
|
||||
pub vowel: String,
|
||||
pub vowel_length: f64,
|
||||
pub pitch: f64
|
||||
}
|
||||
pub pitch: f64,
|
||||
}
|
||||
|
21
src/tts/voicevox/structs/speaker.rs
Normal file
21
src/tts/voicevox/structs/speaker.rs
Normal file
@ -0,0 +1,21 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Speaker {
|
||||
pub supported_features: SupportedFeatures,
|
||||
pub name: String,
|
||||
pub speaker_uuid: String,
|
||||
pub styles: Vec<Style>,
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct SupportedFeatures {
|
||||
pub permitted_synthesis_morphing: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Style {
|
||||
pub name: String,
|
||||
pub id: i64,
|
||||
}
|
13
src/tts/voicevox/structs/stream.rs
Normal file
13
src/tts/voicevox/structs/stream.rs
Normal file
@ -0,0 +1,13 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TTSResponse {
|
||||
pub success: bool,
|
||||
pub is_api_key_valid: bool,
|
||||
pub speaker_name: String,
|
||||
pub audio_status_url: String,
|
||||
pub wav_download_url: String,
|
||||
pub mp3_download_url: String,
|
||||
pub mp3_streaming_url: String,
|
||||
}
|
@ -1,27 +1,166 @@
|
||||
const API_URL: &str = "https://api.su-shiki.com/v2/voicevox/audio";
|
||||
use crate::{errors::NCBError, stream_input::Mp3Request};
|
||||
|
||||
#[derive(Clone)]
|
||||
use super::structs::{speaker::Speaker, stream::TTSResponse};
|
||||
|
||||
const BASE_API_URL: &str = "https://deprecatedapis.tts.quest/v2/";
|
||||
const STREAM_API_URL: &str = "https://api.tts.quest/v3/voicevox/synthesis";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct VOICEVOX {
|
||||
pub key: String
|
||||
pub key: Option<String>,
|
||||
pub original_api_url: Option<String>,
|
||||
}
|
||||
|
||||
impl VOICEVOX {
|
||||
pub fn new(key: String) -> Self {
|
||||
#[tracing::instrument]
|
||||
pub async fn get_styles(&self) -> Result<Vec<(String, i64)>, NCBError> {
|
||||
let speakers = self.get_speaker_list().await?;
|
||||
let mut speaker_list = Vec::new();
|
||||
for speaker in speakers {
|
||||
for style in speaker.styles {
|
||||
speaker_list.push((format!("{} - {}", speaker.name, style.name), style.id))
|
||||
}
|
||||
}
|
||||
|
||||
Ok(speaker_list)
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_speakers(&self) -> Result<Vec<String>, NCBError> {
|
||||
let speakers = self.get_speaker_list().await?;
|
||||
let mut speaker_list = Vec::new();
|
||||
for speaker in speakers {
|
||||
speaker_list.push(speaker.name)
|
||||
}
|
||||
|
||||
Ok(speaker_list)
|
||||
}
|
||||
|
||||
pub fn new(key: Option<String>, original_api_url: Option<String>) -> Self {
|
||||
Self {
|
||||
key
|
||||
key,
|
||||
original_api_url,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn synthesize(&self, text: String, speaker: i64) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
||||
#[tracing::instrument]
|
||||
async fn get_speaker_list(&self) -> Result<Vec<Speaker>, NCBError> {
|
||||
let client = reqwest::Client::new();
|
||||
match client.post(API_URL).query(&[("speaker", speaker.to_string()), ("text", text), ("key", self.key.clone())]).send().await {
|
||||
Ok(response) => {
|
||||
let body = response.bytes().await?;
|
||||
Ok(body.to_vec())
|
||||
}
|
||||
Err(err) => {
|
||||
Err(Box::new(err))
|
||||
}
|
||||
let request = if let Some(key) = &self.key {
|
||||
client
|
||||
.get(format!("{}{}", BASE_API_URL, "voicevox/speakers/"))
|
||||
.query(&[("key", key)])
|
||||
} else if let Some(original_api_url) = &self.original_api_url {
|
||||
client.get(format!("{}/speakers", original_api_url))
|
||||
} else {
|
||||
return Err(NCBError::voicevox("No API key or original API URL provided"));
|
||||
};
|
||||
|
||||
let response = request.send().await
|
||||
.map_err(|e| NCBError::voicevox(format!("Failed to fetch speakers: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(NCBError::voicevox(format!(
|
||||
"API request failed with status: {}",
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
response.json().await
|
||||
.map_err(|e| NCBError::voicevox(format!("Failed to parse speaker list: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn synthesize(
|
||||
&self,
|
||||
text: String,
|
||||
speaker: i64,
|
||||
) -> Result<Vec<u8>, NCBError> {
|
||||
let key = self.key.as_ref()
|
||||
.ok_or_else(|| NCBError::voicevox("API key required for synthesis"))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.post(format!("{}{}", BASE_API_URL, "voicevox/audio/"))
|
||||
.query(&[
|
||||
("speaker", speaker.to_string()),
|
||||
("text", text),
|
||||
("key", key.clone()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| NCBError::voicevox(format!("Synthesis request failed: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(NCBError::voicevox(format!(
|
||||
"Synthesis failed with status: {}",
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let body = response.bytes().await
|
||||
.map_err(|e| NCBError::voicevox(format!("Failed to read response body: {}", e)))?;
|
||||
|
||||
Ok(body.to_vec())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn synthesize_original(
|
||||
&self,
|
||||
text: String,
|
||||
speaker: i64,
|
||||
) -> Result<Vec<u8>, NCBError> {
|
||||
let api_url = self.original_api_url.as_ref()
|
||||
.ok_or_else(|| NCBError::voicevox("Original API URL required for synthesis"))?;
|
||||
|
||||
let client = voicevox_client::Client::new(api_url.clone(), None);
|
||||
let audio_query = client
|
||||
.create_audio_query(&text, speaker as i32, None)
|
||||
.await
|
||||
.map_err(|e| NCBError::voicevox(format!("Failed to create audio query: {}", e)))?;
|
||||
|
||||
tracing::debug!(audio_query = ?audio_query.audio_query, "Generated audio query");
|
||||
|
||||
let audio = audio_query.synthesis(speaker as i32, true).await
|
||||
.map_err(|e| NCBError::voicevox(format!("Audio synthesis failed: {}", e)))?;
|
||||
|
||||
Ok(audio.into())
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn synthesize_stream(
|
||||
&self,
|
||||
text: String,
|
||||
speaker: i64,
|
||||
) -> Result<Mp3Request, NCBError> {
|
||||
let key = self.key.as_ref()
|
||||
.ok_or_else(|| NCBError::voicevox("API key required for stream synthesis"))?;
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.post(STREAM_API_URL)
|
||||
.query(&[
|
||||
("speaker", speaker.to_string()),
|
||||
("text", text),
|
||||
("key", key.clone()),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| NCBError::voicevox(format!("Stream synthesis request failed: {}", e)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(NCBError::voicevox(format!(
|
||||
"Stream synthesis failed with status: {}",
|
||||
response.status()
|
||||
)));
|
||||
}
|
||||
|
||||
let body = response.text().await
|
||||
.map_err(|e| NCBError::voicevox(format!("Failed to read response text: {}", e)))?;
|
||||
|
||||
let tts_response: TTSResponse = serde_json::from_str(&body)
|
||||
.map_err(|e| NCBError::voicevox(format!("Failed to parse TTS response: {}", e)))?;
|
||||
|
||||
Ok(Mp3Request::new(reqwest::Client::new(), tts_response.mp3_streaming_url))
|
||||
}
|
||||
}
|
||||
|
594
src/utils.rs
Normal file
594
src/utils.rs
Normal file
@ -0,0 +1,594 @@
|
||||
use once_cell::sync::Lazy;
|
||||
use lru::LruCache;
|
||||
use regex::Regex;
|
||||
use std::{num::NonZeroUsize, sync::RwLock};
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
use crate::errors::{constants::*, NCBError, Result};
|
||||
|
||||
/// Regex compilation cache to avoid recompiling the same patterns
|
||||
static REGEX_CACHE: Lazy<RwLock<LruCache<String, Regex>>> =
|
||||
Lazy::new(|| RwLock::new(LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap())));
|
||||
|
||||
/// Circuit breaker states for external API calls
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum CircuitBreakerState {
|
||||
Closed,
|
||||
Open,
|
||||
HalfOpen,
|
||||
}
|
||||
|
||||
/// Circuit breaker for handling external API failures
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CircuitBreaker {
|
||||
pub state: CircuitBreakerState,
|
||||
pub failure_count: u32,
|
||||
pub last_failure_time: Option<std::time::Instant>,
|
||||
pub threshold: u32,
|
||||
pub timeout: std::time::Duration,
|
||||
}
|
||||
|
||||
impl Default for CircuitBreaker {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
state: CircuitBreakerState::Closed,
|
||||
failure_count: 0,
|
||||
last_failure_time: None,
|
||||
threshold: 5,
|
||||
timeout: std::time::Duration::from_secs(60),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CircuitBreaker {
|
||||
pub fn new(threshold: u32, timeout: std::time::Duration) -> Self {
|
||||
Self {
|
||||
threshold,
|
||||
timeout,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn can_execute(&self) -> bool {
|
||||
match self.state {
|
||||
CircuitBreakerState::Closed => true,
|
||||
CircuitBreakerState::Open => {
|
||||
if let Some(last_failure) = self.last_failure_time {
|
||||
last_failure.elapsed() >= self.timeout
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
CircuitBreakerState::HalfOpen => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn on_success(&mut self) {
|
||||
self.failure_count = 0;
|
||||
self.state = CircuitBreakerState::Closed;
|
||||
self.last_failure_time = None;
|
||||
}
|
||||
|
||||
pub fn on_failure(&mut self) {
|
||||
self.failure_count += 1;
|
||||
self.last_failure_time = Some(std::time::Instant::now());
|
||||
|
||||
if self.failure_count >= self.threshold {
|
||||
self.state = CircuitBreakerState::Open;
|
||||
} else if self.state == CircuitBreakerState::HalfOpen {
|
||||
self.state = CircuitBreakerState::Open;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_half_open(&mut self) {
|
||||
if self.state == CircuitBreakerState::Open {
|
||||
if let Some(last_failure) = self.last_failure_time {
|
||||
if last_failure.elapsed() >= self.timeout {
|
||||
self.state = CircuitBreakerState::HalfOpen;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cached regex compilation with error handling
|
||||
pub fn get_cached_regex(pattern: &str) -> Result<Regex> {
|
||||
// First try to get from cache
|
||||
{
|
||||
let cache = REGEX_CACHE.read().unwrap();
|
||||
if let Some(cached_regex) = cache.peek(pattern) {
|
||||
debug!(pattern = pattern, "Regex cache hit");
|
||||
return Ok(cached_regex.clone());
|
||||
}
|
||||
}
|
||||
|
||||
debug!(pattern = pattern, "Regex cache miss, compiling");
|
||||
|
||||
// Compile regex with error handling
|
||||
match Regex::new(pattern) {
|
||||
Ok(regex) => {
|
||||
// Cache successful compilation
|
||||
{
|
||||
let mut cache = REGEX_CACHE.write().unwrap();
|
||||
cache.put(pattern.to_string(), regex.clone());
|
||||
}
|
||||
Ok(regex)
|
||||
}
|
||||
Err(e) => {
|
||||
error!(pattern = pattern, error = %e, "Failed to compile regex");
|
||||
Err(NCBError::invalid_regex(format!("{}: {}", pattern, e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Retry logic with exponential backoff
|
||||
pub async fn retry_with_backoff<F, Fut, T, E>(
|
||||
mut operation: F,
|
||||
max_attempts: u32,
|
||||
initial_delay: std::time::Duration,
|
||||
) -> std::result::Result<T, E>
|
||||
where
|
||||
F: FnMut() -> Fut,
|
||||
Fut: std::future::Future<Output = std::result::Result<T, E>>,
|
||||
E: std::fmt::Display,
|
||||
{
|
||||
let mut attempts = 0;
|
||||
let mut delay = initial_delay;
|
||||
|
||||
loop {
|
||||
attempts += 1;
|
||||
|
||||
match operation().await {
|
||||
Ok(result) => {
|
||||
if attempts > 1 {
|
||||
debug!(attempts = attempts, "Operation succeeded after retry");
|
||||
}
|
||||
return Ok(result);
|
||||
}
|
||||
Err(error) => {
|
||||
if attempts >= max_attempts {
|
||||
error!(
|
||||
attempts = attempts,
|
||||
error = %error,
|
||||
"Operation failed after maximum retry attempts"
|
||||
);
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
warn!(
|
||||
attempt = attempts,
|
||||
max_attempts = max_attempts,
|
||||
delay_ms = delay.as_millis(),
|
||||
error = %error,
|
||||
"Operation failed, retrying with backoff"
|
||||
);
|
||||
|
||||
tokio::time::sleep(delay).await;
|
||||
delay = std::cmp::min(delay * 2, std::time::Duration::from_secs(30));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate limiter using token bucket algorithm
|
||||
#[derive(Debug)]
|
||||
pub struct RateLimiter {
|
||||
tokens: std::sync::Arc<std::sync::RwLock<f64>>,
|
||||
capacity: f64,
|
||||
refill_rate: f64,
|
||||
last_refill: std::sync::Arc<std::sync::RwLock<std::time::Instant>>,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(capacity: f64, refill_rate: f64) -> Self {
|
||||
Self {
|
||||
tokens: std::sync::Arc::new(std::sync::RwLock::new(capacity)),
|
||||
capacity,
|
||||
refill_rate,
|
||||
last_refill: std::sync::Arc::new(std::sync::RwLock::new(std::time::Instant::now())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_acquire(&self, tokens: f64) -> bool {
|
||||
self.refill();
|
||||
|
||||
let mut current_tokens = self.tokens.write().unwrap();
|
||||
if *current_tokens >= tokens {
|
||||
*current_tokens -= tokens;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn refill(&self) {
|
||||
let now = std::time::Instant::now();
|
||||
let mut last_refill = self.last_refill.write().unwrap();
|
||||
let elapsed = now.duration_since(*last_refill).as_secs_f64();
|
||||
|
||||
if elapsed > 0.0 {
|
||||
let tokens_to_add = elapsed * self.refill_rate;
|
||||
let mut current_tokens = self.tokens.write().unwrap();
|
||||
*current_tokens = (*current_tokens + tokens_to_add).min(self.capacity);
|
||||
*last_refill = now;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Performance metrics collection
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct PerformanceMetrics {
|
||||
pub tts_requests: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||
pub tts_cache_hits: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||
pub tts_cache_misses: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||
pub regex_cache_hits: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||
pub regex_cache_misses: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||
pub database_operations: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||
pub voice_connections: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||
}
|
||||
|
||||
impl PerformanceMetrics {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn increment_tts_requests(&self) {
|
||||
self.tts_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_tts_cache_hits(&self) {
|
||||
self.tts_cache_hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_tts_cache_misses(&self) {
|
||||
self.tts_cache_misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_regex_cache_hits(&self) {
|
||||
self.regex_cache_hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_regex_cache_misses(&self) {
|
||||
self.regex_cache_misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_database_operations(&self) {
|
||||
self.database_operations.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_voice_connections(&self) {
|
||||
self.voice_connections.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn get_stats(&self) -> MetricsSnapshot {
|
||||
MetricsSnapshot {
|
||||
tts_requests: self.tts_requests.load(std::sync::atomic::Ordering::Relaxed),
|
||||
tts_cache_hits: self.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed),
|
||||
tts_cache_misses: self.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed),
|
||||
regex_cache_hits: self.regex_cache_hits.load(std::sync::atomic::Ordering::Relaxed),
|
||||
regex_cache_misses: self.regex_cache_misses.load(std::sync::atomic::Ordering::Relaxed),
|
||||
database_operations: self.database_operations.load(std::sync::atomic::Ordering::Relaxed),
|
||||
voice_connections: self.voice_connections.load(std::sync::atomic::Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetricsSnapshot {
|
||||
pub tts_requests: u64,
|
||||
pub tts_cache_hits: u64,
|
||||
pub tts_cache_misses: u64,
|
||||
pub regex_cache_hits: u64,
|
||||
pub regex_cache_misses: u64,
|
||||
pub database_operations: u64,
|
||||
pub voice_connections: u64,
|
||||
}
|
||||
|
||||
impl MetricsSnapshot {
|
||||
pub fn tts_cache_hit_rate(&self) -> f64 {
|
||||
if self.tts_cache_hits + self.tts_cache_misses > 0 {
|
||||
self.tts_cache_hits as f64 / (self.tts_cache_hits + self.tts_cache_misses) as f64
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
pub fn regex_cache_hit_rate(&self) -> f64 {
|
||||
if self.regex_cache_hits + self.regex_cache_misses > 0 {
|
||||
self.regex_cache_hits as f64 / (self.regex_cache_hits + self.regex_cache_misses) as f64
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD;
|
||||
|
||||
#[test]
|
||||
fn test_circuit_breaker_default() {
|
||||
let cb = CircuitBreaker::default();
|
||||
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||
assert_eq!(cb.failure_count, 0);
|
||||
assert!(cb.can_execute());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circuit_breaker_new() {
|
||||
let cb = CircuitBreaker::new(3, Duration::from_secs(10));
|
||||
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||
assert_eq!(cb.threshold, 3);
|
||||
assert_eq!(cb.timeout, Duration::from_secs(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circuit_breaker_failure_threshold() {
|
||||
let mut cb = CircuitBreaker::default();
|
||||
|
||||
// Test failures up to threshold
|
||||
for i in 0..CIRCUIT_BREAKER_FAILURE_THRESHOLD {
|
||||
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||
assert!(cb.can_execute());
|
||||
cb.on_failure();
|
||||
assert_eq!(cb.failure_count, i + 1);
|
||||
}
|
||||
|
||||
// Should open after reaching threshold
|
||||
assert_eq!(cb.state, CircuitBreakerState::Open);
|
||||
assert!(!cb.can_execute());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circuit_breaker_success_resets() {
|
||||
let mut cb = CircuitBreaker::default();
|
||||
|
||||
// Add some failures
|
||||
cb.on_failure();
|
||||
cb.on_failure();
|
||||
assert_eq!(cb.failure_count, 2);
|
||||
|
||||
// Success should reset
|
||||
cb.on_success();
|
||||
assert_eq!(cb.failure_count, 0);
|
||||
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circuit_breaker_half_open() {
|
||||
let mut cb = CircuitBreaker::new(1, Duration::from_millis(100));
|
||||
|
||||
// Trigger failure to open circuit
|
||||
cb.on_failure();
|
||||
assert_eq!(cb.state, CircuitBreakerState::Open);
|
||||
assert!(!cb.can_execute());
|
||||
|
||||
// Wait for timeout
|
||||
std::thread::sleep(Duration::from_millis(150));
|
||||
|
||||
// Should allow transition to half-open
|
||||
cb.try_half_open();
|
||||
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
|
||||
assert!(cb.can_execute());
|
||||
|
||||
// Success in half-open should close circuit
|
||||
cb.on_success();
|
||||
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circuit_breaker_half_open_failure() {
|
||||
let mut cb = CircuitBreaker::new(1, Duration::from_millis(100));
|
||||
|
||||
// Open circuit
|
||||
cb.on_failure();
|
||||
std::thread::sleep(Duration::from_millis(150));
|
||||
cb.try_half_open();
|
||||
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
|
||||
|
||||
// Failure in half-open should reopen circuit
|
||||
cb.on_failure();
|
||||
assert_eq!(cb.state, CircuitBreakerState::Open);
|
||||
assert!(!cb.can_execute());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_with_backoff_success_first_try() {
|
||||
let mut call_count = 0;
|
||||
let result = retry_with_backoff(
|
||||
|| {
|
||||
call_count += 1;
|
||||
async { Ok::<i32, &'static str>(42) }
|
||||
},
|
||||
3,
|
||||
Duration::from_millis(100),
|
||||
).await;
|
||||
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
assert_eq!(call_count, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_with_backoff_success_after_retries() {
|
||||
let mut call_count = 0;
|
||||
let result = retry_with_backoff(
|
||||
|| {
|
||||
call_count += 1;
|
||||
async move {
|
||||
if call_count < 3 {
|
||||
Err("temporary error")
|
||||
} else {
|
||||
Ok::<i32, &'static str>(42)
|
||||
}
|
||||
}
|
||||
},
|
||||
5,
|
||||
Duration::from_millis(10),
|
||||
).await;
|
||||
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
assert_eq!(call_count, 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_with_backoff_max_attempts() {
|
||||
let mut call_count = 0;
|
||||
let result = retry_with_backoff(
|
||||
|| {
|
||||
call_count += 1;
|
||||
async { Err::<i32, &'static str>("persistent error") }
|
||||
},
|
||||
3,
|
||||
Duration::from_millis(10),
|
||||
).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(result.unwrap_err(), "persistent error");
|
||||
assert_eq!(call_count, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_cached_regex_valid_pattern() {
|
||||
// Clear cache first
|
||||
{
|
||||
let mut cache = REGEX_CACHE.write().unwrap();
|
||||
cache.clear();
|
||||
}
|
||||
|
||||
let pattern = r"[a-zA-Z]+";
|
||||
let result1 = get_cached_regex(pattern);
|
||||
assert!(result1.is_ok());
|
||||
|
||||
let result2 = get_cached_regex(pattern);
|
||||
assert!(result2.is_ok());
|
||||
|
||||
// Both should work and second should be from cache
|
||||
let regex1 = result1.unwrap();
|
||||
let regex2 = result2.unwrap();
|
||||
assert!(regex1.is_match("hello"));
|
||||
assert!(regex2.is_match("world"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_cached_regex_invalid_pattern() {
|
||||
let pattern = r"[";
|
||||
let result = get_cached_regex(pattern);
|
||||
assert!(result.is_err());
|
||||
|
||||
if let Err(NCBError::InvalidRegex(msg)) = result {
|
||||
// The error message contains the pattern and the regex error
|
||||
assert!(msg.contains("["));
|
||||
} else {
|
||||
panic!("Expected InvalidRegex error");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limiter_basic() {
|
||||
let limiter = RateLimiter::new(5.0, 1.0); // 5 tokens, 1 per second
|
||||
|
||||
// Should be able to acquire 5 tokens initially
|
||||
assert!(limiter.try_acquire(1.0));
|
||||
assert!(limiter.try_acquire(1.0));
|
||||
assert!(limiter.try_acquire(1.0));
|
||||
assert!(limiter.try_acquire(1.0));
|
||||
assert!(limiter.try_acquire(1.0));
|
||||
|
||||
// 6th token should fail
|
||||
assert!(!limiter.try_acquire(1.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limiter_partial_tokens() {
|
||||
let limiter = RateLimiter::new(2.0, 1.0);
|
||||
|
||||
// Acquire partial tokens
|
||||
assert!(limiter.try_acquire(0.5));
|
||||
assert!(limiter.try_acquire(0.5));
|
||||
assert!(limiter.try_acquire(0.5));
|
||||
assert!(limiter.try_acquire(0.5));
|
||||
|
||||
// Should fail with no tokens left
|
||||
assert!(!limiter.try_acquire(0.1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_performance_metrics_increment() {
|
||||
let metrics = PerformanceMetrics::default();
|
||||
|
||||
assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 0);
|
||||
|
||||
metrics.increment_tts_requests();
|
||||
metrics.increment_tts_requests();
|
||||
|
||||
assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 2);
|
||||
|
||||
metrics.increment_tts_cache_hits();
|
||||
assert_eq!(metrics.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed), 1);
|
||||
|
||||
metrics.increment_tts_cache_misses();
|
||||
assert_eq!(metrics.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metrics_snapshot_cache_hit_rate() {
|
||||
let snapshot = MetricsSnapshot {
|
||||
tts_requests: 10,
|
||||
tts_cache_hits: 7,
|
||||
tts_cache_misses: 3,
|
||||
regex_cache_hits: 0,
|
||||
regex_cache_misses: 0,
|
||||
database_operations: 0,
|
||||
voice_connections: 0,
|
||||
};
|
||||
|
||||
assert!((snapshot.tts_cache_hit_rate() - 0.7).abs() < f64::EPSILON);
|
||||
|
||||
let empty_snapshot = MetricsSnapshot {
|
||||
tts_requests: 0,
|
||||
tts_cache_hits: 0,
|
||||
tts_cache_misses: 0,
|
||||
regex_cache_hits: 0,
|
||||
regex_cache_misses: 0,
|
||||
database_operations: 0,
|
||||
voice_connections: 0,
|
||||
};
|
||||
|
||||
assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_metrics_snapshot_regex_cache_hit_rate() {
|
||||
let snapshot = MetricsSnapshot {
|
||||
tts_requests: 0,
|
||||
tts_cache_hits: 0,
|
||||
tts_cache_misses: 0,
|
||||
regex_cache_hits: 8,
|
||||
regex_cache_misses: 2,
|
||||
database_operations: 0,
|
||||
voice_connections: 0,
|
||||
};
|
||||
|
||||
assert!((snapshot.regex_cache_hit_rate() - 0.8).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_performance_metrics_get_stats() {
|
||||
let metrics = PerformanceMetrics::default();
|
||||
|
||||
// Add some data
|
||||
metrics.increment_tts_requests();
|
||||
metrics.increment_tts_requests();
|
||||
metrics.increment_tts_cache_hits();
|
||||
metrics.increment_database_operations();
|
||||
|
||||
let stats = metrics.get_stats();
|
||||
|
||||
assert_eq!(stats.tts_requests, 2);
|
||||
assert_eq!(stats.tts_cache_hits, 1);
|
||||
assert_eq!(stats.tts_cache_misses, 0);
|
||||
assert_eq!(stats.database_operations, 1);
|
||||
}
|
||||
}
|
BIN
tts_cache.bin
Normal file
BIN
tts_cache.bin
Normal file
Binary file not shown.
Reference in New Issue
Block a user