106 Commits

Author SHA1 Message Date
43cce7dc31 fix: 自動参加のバグ修正
アナウンスが無効の時に自動参加も無効になるバグを修正
2025-05-28 17:22:08 +09:00
2f06f6be3b Merge branch 'master' of github.com:mii443/ncb-tts-r2 2025-05-28 16:10:51 +09:00
f0327e232a feat: テキストチャンネルの自動参加設定を追加
- 複数のテキストチャンネルをサポートするために、TTSインスタンスの構造を変更
- 自動参加テキストチャンネルの設定と解除をUIセレクトメニューで実装
- 再接続時にテキストチャンネルに通知を送信する機能を強化
- コードの可読性向上のために、エラーハンドリングとロギングを改善

🤖 Generated with [Claude Code](https://claude.ai/code)
2025-05-28 16:08:34 +09:00
733646b6b8 refactor: Major overhaul with error handling, resilience patterns, and observability
- Add library configuration to support both lib and binary targets
- Implement unified error handling with NCBError throughout the codebase
- Add circuit breaker pattern for external API calls (Voicevox, GCP TTS)
- Introduce comprehensive performance metrics and monitoring
- Add cache persistence with disk storage support
- Implement retry mechanism with exponential backoff
- Add configuration file support (config.toml) with env var fallback
- Enhance logging with structured tracing (debug, warn, error levels)
- Add extensive unit tests for cache, metrics, and circuit breaker
- Update base64 decoding to use modern API
- Improve API error handling for Voicevox and GCP TTS clients

Breaking changes:
- Function signatures now return Result<T, NCBError> instead of panicking
- Cache key structure modified with serialization support
2025-05-28 01:01:12 +09:00
9e7d89eaa5 Update build.yml 2025-05-26 18:27:22 +09:00
ea93d1f8ac Update build.yml 2025-05-26 18:17:27 +09:00
f9f90ab63e Update build.yml 2025-05-26 18:08:12 +09:00
e188f3b758 Update build.yml 2025-05-26 17:53:50 +09:00
0bea81aa6e feat: add connection monitor with auto-reconnect notifications
Implements automatic voice channel connection monitoring with user-friendly notifications:
- Monitor connections every 5 seconds to detect disconnections
- Auto-reconnect when users are present in voice channel
- Send embed notifications to text channel upon reconnection
- Include /stop command information in reconnection messages

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-05-26 17:39:27 +09:00
7ed3f10543 feat: add clear option for autostart channel setting
Allow users to remove/clear the autostart channel setting through the UI select menu.
Added "解除" option and dynamic response messages for both setting and clearing operations.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-05-26 16:35:50 +09:00
c4819bb633 check users before restoring instance 2025-05-26 00:02:00 +09:00
d382a045d0 feat: restoring instance 2025-05-25 23:42:01 +09:00
ffa49c18cc remove arm64 support 2025-05-25 22:51:01 +09:00
48a4977208 update gh action 2025-05-25 22:50:11 +09:00
32d9782cff Merge pull request #2 from mii443/stable
Stable
2025-05-25 22:28:21 +09:00
248c88a7a2 update t v1.11.2 2025-05-25 22:27:42 +09:00
65db668e2a update Docker things 2025-05-25 00:10:08 +09:00
879644f30c fix gcp version 2025-05-25 00:02:44 +09:00
8ce6dbf57a Merge branch 'stable' of github.com:mii443/ncb-tts-r2 into stable 2025-05-24 23:56:06 +09:00
4d658b6671 WIP: ローカルの変更を保存 2025-05-24 23:49:30 +09:00
e606b29f81 add server configs 2025-05-24 23:44:59 +09:00
376b8a6882 Merge pull request #1 from mii443/stable
Stable
2025-05-24 21:54:25 +09:00
b452d4609a Merge branch 'master' into stable 2025-05-24 21:54:11 +09:00
599b7fb0d1 v1.10.1 2025-04-27 17:17:44 +09:00
4c76fd037a v1.9.2 2025-04-27 17:17:06 +09:00
83fde399b2 support original voicevox api 2025-04-27 17:16:17 +09:00
91044c7c25 fix user count 2025-04-24 15:32:43 +00:00
a13994c37e fix user count 2025-04-12 18:06:40 +09:00
97ae9dd9e0 optimize database lock 2025-04-11 18:07:46 +09:00
f7e08b4e2e add mp3 streaming for voicevox 2025-04-11 15:10:01 +09:00
257c8511e3 change type 2025-04-11 14:03:13 +09:00
771711e3bf update 2025-04-11 10:43:21 +09:00
40e6194942 supress warn 2025-04-10 15:06:37 +09:00
9949d501b5 fix warn, optimize 2025-04-10 14:57:09 +09:00
24de817d6f add otel http url config 2025-04-10 14:17:30 +09:00
8a1fa22074 add tracing opentelemetry 2025-04-10 13:40:29 +09:00
24c609aaf2 update 2025-04-05 17:53:51 +09:00
51f008cfdc extend cache size 2025-04-05 17:53:16 +09:00
4c176935e3 reduce Mutex lock 2025-04-04 22:14:23 +09:00
55ea223f69 gcp synthesize without mutable 2025-04-04 21:59:53 +09:00
696954395b support vec audio 2025-04-04 21:46:32 +09:00
82e3c55fd5 fix warnings 2025-04-04 21:11:46 +09:00
af83f6b6e0 change log level, remove unused import 2025-04-04 21:07:48 +09:00
c68e533133 implement compressed local audio cache 2025-04-04 20:45:20 +09:00
77b4c3e04d update workflow 2025-04-03 03:05:59 +09:00
8db4d65042 change versions 2025-04-03 03:04:35 +09:00
e12dbd7375 fix voice move state bug 2025-04-03 03:03:06 +09:00
cff30e7471 fix reading name 2025-04-03 02:58:02 +09:00
bf4b160af7 implement tts without disk I/O 2025-04-03 02:43:26 +09:00
b4de0f1ad6 fix warnings 2025-04-03 02:38:40 +09:00
1975b2e9cd fix instance.rs bug 2025-04-03 02:32:59 +09:00
5ee3c9b328 fix errors 2025-04-03 02:04:17 +09:00
df46152a12 update crates, fix event_handler, main 2025-03-30 23:34:25 +09:00
1830029231 add ARM build 2025-02-21 16:17:16 +09:00
e4dbedcbe7 Update docker-compose.yml 2025-02-21 16:15:23 +09:00
f11718cc8b Update docker-compose.yml 2025-02-21 16:14:47 +09:00
8ef5530524 update serenity to poise 2024-11-19 10:15:52 +00:00
89f66aefb0 update dependencies 2024-11-19 10:00:15 +00:00
mii
68e96ef784 change base image 2024-01-16 15:01:45 +00:00
mii
883a54f70a update rust version 2024-01-16 14:24:46 +00:00
mii
60255d7582 fix action row 2024-01-16 14:13:19 +00:00
mii
68d49772e7 update version to 1.7.0 2023-10-03 15:15:58 +00:00
mii
0fbe068f1e fix action row 2023-10-03 15:13:55 +00:00
mii
c39800da18 auto load voiecvox speaker list 2023-10-03 14:54:34 +00:00
mii
5aa5e09bb7 downgrade base image 2023-09-30 12:33:54 +00:00
mii
ec47a6f521 fix github workflow 2023-09-30 11:15:09 +00:00
mii
bb4d0a0504 fix github workflow 2023-09-30 11:05:05 +00:00
mii
4630883b28 Update version 2023-09-30 10:55:57 +00:00
mii
2249e8c213 Add autostart feature 2023-09-30 10:53:57 +00:00
mii
f9ebd8a430 wip autostart settings 2023-09-28 10:53:41 +00:00
mii
8a9817a449 update version to 1.5.1 2023-04-13 04:44:09 +00:00
mii
4c5b9cb345 update songbird 2023-04-13 04:43:20 +00:00
mii
52f86a6c16 add ignore 2023-04-06 01:46:40 +00:00
mii
b62e81dd66 add nurserobo_type_t 2023-01-15 07:50:59 +00:00
mii
60770f65b6 add skip command 2022-12-08 08:25:00 +00:00
mii
f93701a591 add dictionary, server config 2022-12-07 08:21:43 +00:00
mii
2a40e9ee16 change regex 2022-12-07 03:36:43 +00:00
mii
2f2b82857f update docker build 2022-12-07 03:24:49 +00:00
mii
8b2574e90b fix code block regex 2022-12-06 15:55:17 +00:00
mii
fb654229b0 fix command register 2022-12-06 15:35:19 +00:00
mii
7124930a9d fix code block regex 2022-12-06 14:46:03 +00:00
mii
0bd051e48c add feedback url 2022-12-06 14:39:14 +00:00
mii
ccf2c63224 change setup message 2022-12-06 14:29:50 +00:00
mii
c2a84be3a4 TTS Channel mode feature 2022-12-06 14:04:14 +00:00
mii
d525636060 add kubernetes manifest, add code block regex 2022-12-06 13:01:32 +00:00
mii
68a73a4dae support auto archive thread 2022-11-19 02:51:09 +00:00
mii
f7b7071a09 update docker-compose.yml image version 2022-11-18 09:26:19 +00:00
mii
b7a4da7f3e thread tts 2022-11-18 09:08:27 +00:00
mii
708b6fc429 fix getting guild id 2022-11-03 05:13:05 +00:00
mii
f9fd0686a7 add sample docker-compose.yml 2022-11-02 15:07:36 +00:00
mii
99d8ef9bef voice state bugfix 2022-11-02 15:04:18 +00:00
mii
af01576a99 update serenity and songbird 2022-11-02 14:38:56 +00:00
mii
ddab474d67 support env variable 2022-10-31 13:02:20 +00:00
mii
d10bfcc333 fix database, borrow 2022-10-31 12:46:59 +00:00
mii
1470612d8b clippy 2022-10-31 12:40:55 +00:00
mii
065717839b reading attachment files, fix database 2022-10-31 21:04:11 +09:00
mii
6dafc66878 . 2022-08-14 12:02:52 +09:00
mii
1789bd7c4e fix config change message, auto disconnect 2022-08-13 16:48:47 +09:00
mii
0b93c23e91 auto leave 2022-08-12 23:42:23 +09:00
mii
7fe65bc397 add validator 2022-08-12 23:07:20 +09:00
mii
51c39036c6 delete debug register 2022-08-12 22:43:21 +09:00
mii
c52429bce0 add config command 2022-08-12 22:39:12 +09:00
mii
5ca5325fbd refactoring 2022-08-12 20:25:39 +09:00
mii
4161acbd45 fix warn 2022-08-12 20:05:34 +09:00
mii
47b11262e2 refactoring 2022-08-12 19:05:02 +09:00
mii
b36bee8be8 fix database connection bug 2022-08-12 17:57:28 +09:00
58 changed files with 5496 additions and 589 deletions

2
.dockerignore Normal file
View File

@ -0,0 +1,2 @@
target
audio

View File

@ -6,24 +6,29 @@ on:
- 'v*'
jobs:
build:
runs-on: ubuntu-latest
runs-on: self-hosted
steps:
- uses: actions/checkout@v2
- uses: docker/metadata-action@v3
- uses: actions/checkout@v4
name: Checkout
- uses: docker/metadata-action@v4
id: meta
with:
images: ghcr.io/morioka22/ncb-tts-r2
images: ghcr.io/mii443/ncb-tts-r2
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
- uses: docker/login-action@v1
- uses: docker/login-action@v2
with:
registry: ghcr.io
username: morioka22
username: mii443
password: ${{ secrets.GITHUB_TOKEN }}
- uses: docker/build-push-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- uses: docker/build-push-action@v4
with:
context: .
push: true
platforms: linux/amd64
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
labels: ${{ steps.meta.outputs.labels }}

3
.gitignore vendored
View File

@ -3,4 +3,5 @@ Cargo.lock
config.toml
credentials.json
/audio
*.mp3
*.mp3
*.swp

View File

@ -1,33 +1,82 @@
[package]
name = "ncb-tts-r2"
version = "0.1.0"
version = "1.11.2"
edition = "2021"
[lib]
name = "ncb_tts_r2"
path = "src/lib.rs"
[[bin]]
name = "ncb-tts-r2"
path = "src/main.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
serde_json = "1.0"
serde = "1.0"
toml = "0.5"
toml = "0.8.19"
gcp_auth = "0.5.0"
reqwest = { version = "0.11", features = ["json"] }
base64 = "0.13"
reqwest = { version = "0.12.9", features = ["json"] }
base64 = "0.22.1"
async-trait = "0.1.57"
redis = "*"
redis = { version = "0.29.2", features = ["aio", "tokio-comp"] }
bb8 = "0.8"
bb8-redis = "0.16"
thiserror = "1.0"
regex = "1"
tracing-subscriber = "0.3.19"
lru = "0.13.0"
once_cell = "1.19"
bincode = "1.3"
tracing = "0.1.41"
opentelemetry_sdk = { version = "0.29.0", features = ["trace"] }
opentelemetry = "0.29.1"
opentelemetry-semantic-conventions = "0.29.0"
opentelemetry-otlp = { version = "0.29.0", features = ["grpc-tonic"] }
opentelemetry-stdout = "0.29.0"
tracing-opentelemetry = "0.30.0"
symphonia-core = "0.5.4"
tokio-util = { version = "0.7.14", features = ["compat"] }
futures = "0.3.31"
bytes = "1.10.1"
voicevox-client = { git = "https://github.com/mii443/rust" }
[dependencies.uuid]
version = "0.8"
version = "1.11.0"
features = ["serde", "v4"]
[dependencies.songbird]
version = "0.2.0"
version = "0.5"
features = ["builtin-queue"]
[dependencies.serenity]
version = "0.10.9"
features = ["builder", "cache", "client", "gateway", "model", "utils", "unstable_discord_api", "collector", "rustls_backend", "framework", "voice"]
[dependencies.symphonia]
version = "0.5"
features = ["mp3"]
[dependencies.serenity]
version = "0.12"
features = [
"builder",
"cache",
"client",
"gateway",
"model",
"utils",
"unstable_discord_api",
"collector",
"rustls_backend",
"framework",
"voice",
]
[dependencies.tokio]
version = "1.0"
features = ["macros", "rt-multi-thread"]
features = ["macros", "rt-multi-thread"]
[dev-dependencies]
tokio-test = "0.4"
mockall = "0.12"
tempfile = "3.8"
serial_test = "3.0"

View File

@ -1,18 +1,45 @@
FROM ubuntu:22.04
WORKDIR /usr/src/ncb-tts-r2
COPY Cargo.toml .
COPY src src
ENV PATH $PATH:/root/.cargo/bin/
RUN apt-get update \
&& apt-get install -y ffmpeg libssl-dev pkg-config libopus-dev wget curl gcc \
&& apt-get -y clean \
&& rm -rf /var/lib/apt/lists/* \
&& curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable \
&& rustup install stable \
&& cargo build --release \
&& cp /usr/src/ncb-tts-r2/target/release/ncb-tts-r2 /usr/bin/ncb-tts-r2 \
&& mkdir -p /ncb-tts-r2/audio \
&& apt-get purge -y pkg-config wget curl gcc \
&& rustup self uninstall -y
FROM lukemathwalker/cargo-chef:latest-rust-1.82 AS chef
WORKDIR /app
FROM chef AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
COPY --from=planner /app/recipe.json recipe.json
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ffmpeg \
libssl-dev \
pkg-config \
libopus-dev \
gcc && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN cargo chef cook --release --recipe-path recipe.json
COPY . .
RUN cargo build --release
FROM ubuntu:22.04 AS runtime
WORKDIR /ncb-tts-r2
CMD ["ncb-tts-r2"]
# 非rootユーザーの作成
RUN groupadd -r appgroup && useradd -r -g appgroup appuser
RUN apt-get update && \
apt-get install -y --no-install-recommends \
openssl \
ca-certificates \
ffmpeg \
libssl-dev \
libopus-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
COPY --from=builder /app/target/release/ncb-tts-r2 /usr/local/bin/ncb-tts-r2
RUN chmod +x /usr/local/bin/ncb-tts-r2
# 非rootユーザーに切り替え
USER appuser
ENTRYPOINT ["/usr/local/bin/ncb-tts-r2"]

14
docker-compose.yml Normal file
View File

@ -0,0 +1,14 @@
version: '3'
services:
ncb-tts-r2:
container_name: ncb-tts-r2
image: ghcr.io/mii443/ncb-tts-r2:1.11.2
environment:
- NCB_TOKEN=YOUR_BOT_TOKEN
- NCB_APP_ID=YOUR_BOT_ID
- NCB_PREFIX=BOT_PREFIX
- NCB_REDIS_URL=redis://<REDIS_IP>/
- NCB_VOICEVOX_KEY=VOICEVOX_KEY
volumes:
- ./credentials.json:/ncb-tts-r2/credentials.json:ro

56
manifest/ncb-tts.yaml Normal file
View File

@ -0,0 +1,56 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: ncb-tts-deployment
spec:
replicas: 1
selector:
matchLabels:
app: ncb-tts
template:
metadata:
labels:
app: ncb-tts
spec:
containers:
- name: redis
image: redis:7.0.4-alpine
ports:
- containerPort: 6379
name: ncb-redis
volumeMounts:
- name: ncb-redis-pvc
mountPath: /data
- name: tts
image: ghcr.io/mii443/ncb-tts-r2
volumeMounts:
- name: gcp-credentials
mountPath: /ncb-tts-r2/credentials.json
subPath: credentials.json
env:
- name: NCB_REDIS_URL
value: "redis://localhost:6379/"
- name: NCB_PREFIX
value: "t2!"
- name: NCB_TOKEN
valueFrom:
secretKeyRef:
name: ncb-secret
key: BOT_TOKEN
- name: NCB_VOICEVOX_KEY
valueFrom:
secretKeyRef:
name: ncb-secret
key: VOICEVOX_KEY
- name: NCB_APP_ID
valueFrom:
secretKeyRef:
name: ncb-secret
key: APP_ID
volumes:
- name: ncb-redis-pvc
persistentVolumeClaim:
claimName: ncb-redis-pvc
- name: gcp-credentials
secret:
secretName: gcp-credentials

12
manifest/pvc.yaml Normal file
View File

@ -0,0 +1,12 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: ncb-redis-pvc
labels:
app: ncb-redis
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 3Gi

107
src/commands/config.rs Normal file
View File

@ -0,0 +1,107 @@
use serenity::{
all::{
ButtonStyle, CommandInteraction, CreateActionRow, CreateButton, CreateInteractionResponse,
CreateInteractionResponseMessage, CreateSelectMenu, CreateSelectMenuKind,
CreateSelectMenuOption,
},
prelude::Context,
};
use crate::{
data::{DatabaseClientData, TTSClientData},
tts::tts_type::TTSType,
};
#[tracing::instrument]
pub async fn config_command(
ctx: &Context,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
let data_read = ctx.data.read().await;
let config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_user_config_or_default(command.user.id.get())
.await
.unwrap()
.unwrap()
};
let tts_client = data_read
.get::<TTSClientData>()
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client
.voicevox_client
.get_styles()
.await
.unwrap_or_else(|e| {
tracing::error!("Failed to get VOICEVOX styles: {}", e);
vec![("VOICEVOX API unavailable".to_string(), 1)]
});
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
let engine_select = CreateActionRow::SelectMenu(
CreateSelectMenu::new(
"TTS_CONFIG_ENGINE",
CreateSelectMenuKind::String {
options: vec![
CreateSelectMenuOption::new("Google TTS", "TTS_CONFIG_ENGINE_SELECTED_GOOGLE")
.default_selection(tts_type == TTSType::GCP),
CreateSelectMenuOption::new("VOICEVOX", "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX")
.default_selection(tts_type == TTSType::VOICEVOX),
],
},
)
.placeholder("読み上げAPIを選択"),
);
let mut components = vec![engine_select];
for (index, speaker_chunk) in voicevox_speakers[0..24].chunks(25).enumerate() {
let mut options = Vec::new();
for (name, id) in speaker_chunk {
options.push(
CreateSelectMenuOption::new(
name,
format!("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_{}", id),
)
.default_selection(*id == voicevox_speaker),
);
}
components.push(CreateActionRow::SelectMenu(
CreateSelectMenu::new(
format!("TTS_CONFIG_VOICEVOX_SPEAKER_{}", index),
CreateSelectMenuKind::String { options },
)
.placeholder("VOICEVOX Speakerを指定"),
));
}
let server_button = CreateActionRow::Buttons(vec![CreateButton::new("TTS_CONFIG_SERVER")
.label("サーバー設定")
.style(ButtonStyle::Primary)]);
components.push(server_button);
command
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("読み上げ設定")
.components(components)
.ephemeral(true),
),
)
.await?;
Ok(())
}

4
src/commands/mod.rs Normal file
View File

@ -0,0 +1,4 @@
pub mod config;
pub mod setup;
pub mod skip;
pub mod stop;

186
src/commands/setup.rs Normal file
View File

@ -0,0 +1,186 @@
use serenity::{
all::{
AutoArchiveDuration, CommandInteraction, CreateEmbed, CreateInteractionResponse, CreateInteractionResponseMessage, CreateMessage, CreateThread
},
model::prelude::UserId,
prelude::Context,
};
use tracing::info;
use crate::{
data::{DatabaseClientData, TTSClientData, TTSData},
tts::instance::TTSInstance,
};
#[tracing::instrument]
pub async fn setup_command(
ctx: &Context,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
info!("Received event");
if command.guild_id.is_none() {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("このコマンドはサーバーでのみ使用可能です.")
.ephemeral(true)
))
.await?;
return Ok(());
}
info!("Fetching guild cache");
let guild_id = command.guild_id.unwrap();
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
let channel_id = guild
.voice_states
.get(&UserId::from(command.user.id.get()))
.and_then(|state| state.channel_id);
if channel_id.is_none() {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("ボイスチャンネルに参加してから実行してください.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let channel_id = channel_id.unwrap();
let manager = songbird::get(ctx)
.await
.expect("Cannot get songbird client.")
.clone();
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.clone()
};
let text_channel_id = {
let mut storage = storage_lock.write().await;
if storage.contains_key(&guild.id) {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("すでにセットアップしています.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let text_channel_ids = {
if let Some(mode) = command.data.options.get(0) {
match &mode.value {
serenity::all::CommandDataOptionValue::String(value) => {
match value.as_str() {
"TEXT_CHANNEL" => vec![command.channel_id],
"NEW_THREAD" => {
vec![command
.channel_id
.create_thread(&ctx.http, CreateThread::new("TTS").auto_archive_duration(AutoArchiveDuration::OneHour).kind(serenity::all::ChannelType::PublicThread))
.await
.unwrap()
.id]
}
"VOICE_CHANNEL" => vec![channel_id],
_ => if channel_id != command.channel_id {
vec![command.channel_id, channel_id]
} else {
vec![channel_id]
},
}
},
_ => if channel_id != command.channel_id {
vec![command.channel_id, channel_id]
} else {
vec![channel_id]
},
}
} else {
if channel_id != command.channel_id {
vec![command.channel_id, channel_id]
} else {
vec![channel_id]
}
}
};
let instance = TTSInstance::new(text_channel_ids.clone(), channel_id, guild.id);
storage.insert(guild.id, instance.clone());
// Save to database
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
drop(data_read);
if let Err(e) = database.save_tts_instance(guild.id, &instance).await {
tracing::error!("Failed to save TTS instance to database: {}", e);
}
text_channel_ids[0]
};
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content(format!(
"TTS Channel: <#{}>{}",
text_channel_id,
if text_channel_id == channel_id {
"\nボイスチャンネルを右クリックし `チャットを開く` を押して開くことが出来ます。"
} else {
""
}
))
))
.await?;
let _handler = manager.join(guild.id, channel_id).await;
let data = ctx
.data
.read()
.await;
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await
.unwrap_or_else(|e| {
tracing::error!("Failed to get VOICEVOX speakers: {}", e);
vec!["VOICEVOX API unavailable".to_string()]
});
text_channel_id
.send_message(&ctx.http, CreateMessage::new()
.embed(
CreateEmbed::new()
.title("読み上げ (Serenity)")
.field(
"VOICEVOXクレジット",
format!("```\n{}\n```", voicevox_speakers.join("\n")),
false,
)
.field("設定コマンド", "`/config`", false)
.field("フィードバック", "https://feedback.mii.codes/", false)
))
.await?;
Ok(())
}

81
src/commands/skip.rs Normal file
View File

@ -0,0 +1,81 @@
use serenity::{
all::{
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage
},
model::prelude::UserId,
prelude::Context,
};
use crate::data::TTSData;
pub async fn skip_command(
ctx: &Context,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
if command.guild_id.is_none() {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("このコマンドはサーバーでのみ使用可能です.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let guild_id = command.guild_id.unwrap();
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
let channel_id = guild
.voice_states
.get(&UserId::from(command.user.id.get()))
.and_then(|state| state.channel_id);
if channel_id.is_none() {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("ボイスチャンネルに参加してから実行してください.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.clone()
};
{
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild.id) {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("読み上げしていません")
.ephemeral(true)
))
.await?;
return Ok(());
}
storage.get_mut(&guild.id).unwrap().skip(ctx).await;
}
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("スキップしました")
))
.await?;
Ok(())
}

115
src/commands/stop.rs Normal file
View File

@ -0,0 +1,115 @@
use serenity::{
all::{
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditThread,
},
model::prelude::UserId,
prelude::Context,
};
use crate::data::{DatabaseClientData, TTSData};
pub async fn stop_command(
ctx: &Context,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
if command.guild_id.is_none() {
command
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("このコマンドはサーバーでのみ使用可能です.")
.ephemeral(true),
),
)
.await?;
return Ok(());
}
let guild_id = command.guild_id.unwrap();
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
let channel_id = guild
.voice_states
.get(&UserId::from(command.user.id.get()))
.and_then(|state| state.channel_id);
if channel_id.is_none() {
command
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("ボイスチャンネルに参加してから実行してください.")
.ephemeral(true),
),
)
.await?;
return Ok(());
}
let manager = songbird::get(ctx)
.await
.expect("Cannot get songbird client.")
.clone();
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.clone()
};
let text_channel_id = {
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild.id) {
command
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("すでに停止しています")
.ephemeral(true),
),
)
.await?;
return Ok(());
}
let text_channel_id = storage.get(&guild.id).unwrap().text_channels[0];
storage.remove(&guild.id);
// Remove from database
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
drop(data_read);
if let Err(e) = database.remove_tts_instance(guild.id).await {
tracing::error!("Failed to remove TTS instance from database: {}", e);
}
text_channel_id
};
let _handler = manager.remove(guild.id).await;
command
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new().content("停止しました"),
),
)
.await?;
let _ = text_channel_id
.edit_thread(&ctx.http, EditThread::new().archived(true))
.await;
Ok(())
}

View File

@ -6,5 +6,7 @@ pub struct Config {
pub token: String,
pub application_id: u64,
pub redis_url: String,
pub voicevox_key: String
}
pub voicevox_key: Option<String>,
pub voicevox_original_api_url: Option<String>,
pub otel_http_url: Option<String>,
}

273
src/connection_monitor.rs Normal file
View File

@ -0,0 +1,273 @@
use serenity::{
all::{CreateEmbed, CreateMessage},
prelude::Context,
};
use std::time::Duration;
use tokio::time;
use tracing::{error, info, instrument, warn};
use crate::data::{DatabaseClientData, TTSData};
/// Constants for connection monitoring
const CONNECTION_CHECK_INTERVAL_SECS: u64 = 5;
const MAX_RECONNECTION_ATTEMPTS: u32 = 3;
const RECONNECTION_BACKOFF_SECS: u64 = 2;
/// Errors that can occur during connection monitoring
#[derive(Debug, thiserror::Error)]
pub enum ConnectionMonitorError {
#[error("Failed to get songbird manager")]
SongbirdManagerNotFound,
#[error("Failed to check voice channel users: {0}")]
VoiceChannelCheck(String),
#[error("Failed to reconnect after {attempts} attempts")]
ReconnectionFailed { attempts: u32 },
#[error("Database operation failed: {0}")]
Database(#[from] redis::RedisError),
}
type Result<T> = std::result::Result<T, ConnectionMonitorError>;
/// Connection monitor that periodically checks voice channel connections
pub struct ConnectionMonitor {
reconnection_attempts: std::collections::HashMap<serenity::model::id::GuildId, u32>,
}
impl Default for ConnectionMonitor {
fn default() -> Self {
Self::new()
}
}
impl ConnectionMonitor {
pub fn new() -> Self {
Self {
reconnection_attempts: std::collections::HashMap::new(),
}
}
/// Start the connection monitoring task
pub fn start(ctx: Context) {
tokio::spawn(async move {
let mut monitor = ConnectionMonitor::new();
info!(
interval_secs = CONNECTION_CHECK_INTERVAL_SECS,
"Starting connection monitor"
);
let mut interval = time::interval(Duration::from_secs(CONNECTION_CHECK_INTERVAL_SECS));
loop {
interval.tick().await;
if let Err(e) = monitor.check_connections(&ctx).await {
error!(error = %e, "Connection monitoring failed");
}
}
});
}
/// Check all active TTS instances and their voice channel connections
#[instrument(skip(self, ctx))]
async fn check_connections(&mut self, ctx: &Context) -> Result<()> {
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.ok_or_else(|| {
ConnectionMonitorError::VoiceChannelCheck("Cannot get TTSStorage".to_string())
})?
.clone()
};
let database = {
let data_read = ctx.data.read().await;
data_read
.get::<DatabaseClientData>()
.ok_or_else(|| {
ConnectionMonitorError::VoiceChannelCheck(
"Cannot get DatabaseClientData".to_string(),
)
})?
.clone()
};
let mut storage = storage_lock.write().await;
let mut guilds_to_remove = Vec::new();
for (guild_id, instance) in storage.iter() {
// Check if bot is still connected to voice channel
let manager = songbird::get(ctx)
.await
.ok_or(ConnectionMonitorError::SongbirdManagerNotFound)?;
let call = manager.get(*guild_id);
let is_connected = if let Some(call) = call {
if let Some(connection) = call.lock().await.current_connection() {
connection.channel_id.is_some()
} else {
false
}
} else {
false
};
if !is_connected {
warn!(guild_id = %guild_id, "Bot disconnected from voice channel");
// Check if there are users in the voice channel
let should_reconnect = match self.check_voice_channel_users(ctx, instance).await {
Ok(has_users) => has_users,
Err(e) => {
warn!(guild_id = %guild_id, error = %e, "Failed to check voice channel users, skipping reconnection");
false
}
};
if should_reconnect {
// Try to reconnect with retry logic
let attempts = self
.reconnection_attempts
.get(guild_id)
.copied()
.unwrap_or(0);
if attempts >= MAX_RECONNECTION_ATTEMPTS {
error!(
guild_id = %guild_id,
attempts = attempts,
"Maximum reconnection attempts reached, removing instance"
);
guilds_to_remove.push(*guild_id);
self.reconnection_attempts.remove(guild_id);
continue;
}
// Apply exponential backoff
if attempts > 0 {
let backoff_duration =
Duration::from_secs(RECONNECTION_BACKOFF_SECS * (2_u64.pow(attempts)));
warn!(
guild_id = %guild_id,
attempt = attempts + 1,
backoff_secs = backoff_duration.as_secs(),
"Applying backoff before reconnection attempt"
);
tokio::time::sleep(backoff_duration).await;
}
match instance.reconnect(ctx, true).await {
Ok(_) => {
info!(
guild_id = %guild_id,
attempts = attempts + 1,
"Successfully reconnected to voice channel"
);
// Reset reconnection attempts on success
self.reconnection_attempts.remove(guild_id);
// Send notification message to text channel with embed
let embed = CreateEmbed::new()
.title("🔄 自動再接続しました")
.description("読み上げを停止したい場合は `/stop` コマンドを使用してください。")
.color(0x00ff00);
// Send message to the first text channel
if let Some(&text_channel) = instance.text_channels.first() {
if let Err(e) = text_channel
.send_message(&ctx.http, CreateMessage::new().embed(embed))
.await
{
error!(guild_id = %guild_id, error = %e, "Failed to send reconnection message");
}
}
}
Err(e) => {
let new_attempts = attempts + 1;
self.reconnection_attempts.insert(*guild_id, new_attempts);
error!(
guild_id = %guild_id,
attempt = new_attempts,
error = %e,
"Failed to reconnect to voice channel"
);
if new_attempts >= MAX_RECONNECTION_ATTEMPTS {
guilds_to_remove.push(*guild_id);
self.reconnection_attempts.remove(guild_id);
}
}
}
} else {
info!(
guild_id = %guild_id,
"No users in voice channel, removing instance"
);
guilds_to_remove.push(*guild_id);
self.reconnection_attempts.remove(guild_id);
}
}
}
// Remove disconnected instances
for guild_id in guilds_to_remove {
storage.remove(&guild_id);
// Remove from database
if let Err(e) = database.remove_tts_instance(guild_id).await {
error!(guild_id = %guild_id, error = %e, "Failed to remove TTS instance from database");
}
// Ensure bot leaves voice channel
if let Some(manager) = songbird::get(ctx).await {
if let Err(e) = manager.remove(guild_id).await {
error!(guild_id = %guild_id, error = %e, "Failed to remove bot from voice channel");
}
}
info!(guild_id = %guild_id, "Removed disconnected TTS instance");
}
Ok(())
}
/// Check if there are users in the voice channel
#[instrument(skip(self, ctx, instance))]
async fn check_voice_channel_users(
&self,
ctx: &Context,
instance: &crate::tts::instance::TTSInstance,
) -> Result<bool> {
let channels = instance.guild.channels(&ctx.http).await.map_err(|e| {
ConnectionMonitorError::VoiceChannelCheck(format!(
"Failed to get guild channels: {}",
e
))
})?;
if let Some(channel) = channels.get(&instance.voice_channel) {
let members = channel.members(&ctx.cache).map_err(|e| {
ConnectionMonitorError::VoiceChannelCheck(format!(
"Failed to get channel members: {}",
e
))
})?;
let user_count = members.iter().filter(|member| !member.user.bot).count();
info!(
guild_id = %instance.guild,
channel_id = %instance.voice_channel,
user_count = user_count,
"Checked voice channel users"
);
Ok(user_count > 0)
} else {
warn!(
guild_id = %instance.guild,
channel_id = %instance.voice_channel,
"Voice channel no longer exists"
);
Ok(false)
}
}
}

View File

@ -1,8 +1,11 @@
use crate::{tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX}, database::database::Database};
use serenity::{prelude::{TypeMapKey, RwLock}, model::id::GuildId, futures::lock::Mutex};
use crate::{database::database::Database, tts::tts::TTS};
use serenity::{
model::id::GuildId,
prelude::{RwLock, TypeMapKey},
};
use crate::tts::instance::TTSInstance;
use std::{sync::Arc, collections::HashMap};
use std::{collections::HashMap, sync::Arc};
/// TTSInstance data
pub struct TTSData;
@ -15,12 +18,12 @@ impl TypeMapKey for TTSData {
pub struct TTSClientData;
impl TypeMapKey for TTSClientData {
type Value = Arc<Mutex<(TTS, VOICEVOX)>>;
type Value = Arc<TTS>;
}
/// Database client data
pub struct DatabaseClientData;
impl TypeMapKey for DatabaseClientData {
type Value = Arc<Mutex<Database>>;
type Value = Arc<Database>;
}

View File

@ -1,60 +1,480 @@
use crate::tts::{gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType};
use std::fmt::Debug;
use super::user_config::UserConfig;
use redis::Commands;
use crate::{
errors::{constants::*, NCBError, Result},
tts::{
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, instance::TTSInstance,
tts_type::TTSType,
},
};
use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
use serenity::model::id::{ChannelId, GuildId, UserId};
use std::collections::HashMap;
use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig};
#[derive(Debug, Clone)]
pub struct Database {
pub connection: redis::Connection
pub pool: Pool<RedisConnectionManager>,
}
impl Database {
pub fn new(connection: redis::Connection) -> Self {
Self { connection }
pub fn new(pool: Pool<RedisConnectionManager>) -> Self {
Self { pool }
}
pub async fn get_user_config(&mut self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
let config: String = self.connection.get(format!("discord_user:{}", user_id)).unwrap_or_default();
pub async fn new_with_url(redis_url: String) -> Result<Self> {
let manager = RedisConnectionManager::new(redis_url)?;
let pool = Pool::builder()
.max_size(15)
.build(manager)
.await
.map_err(|e| NCBError::Database(format!("Pool creation failed: {}", e)))?;
Ok(Self { pool })
}
fn server_key(server_id: u64) -> String {
format!("{}{}", DISCORD_SERVER_PREFIX, server_id)
}
fn user_key(user_id: u64) -> String {
format!("{}{}", DISCORD_USER_PREFIX, user_id)
}
fn tts_instance_key(guild_id: u64) -> String {
format!("{}{}", TTS_INSTANCE_PREFIX, guild_id)
}
fn tts_instances_list_key() -> String {
TTS_INSTANCES_LIST_KEY.to_string()
}
fn user_config_key(guild_id: u64, user_id: u64) -> String {
format!("user:config:{}:{}", guild_id, user_id)
}
fn server_config_key(guild_id: u64) -> String {
format!("server:config:{}", guild_id)
}
fn dictionary_key(guild_id: u64) -> String {
format!("dictionary:{}", guild_id)
}
#[tracing::instrument]
async fn get_config<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let config: String = connection.get(key).await.unwrap_or_default();
if config.is_empty() {
return Ok(None);
}
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None)
Err(e) => {
tracing::warn!(key = key, error = %e, "Failed to deserialize config");
Ok(None)
}
}
}
pub async fn set_user_config(&mut self, user_id: u64, config: UserConfig) -> redis::RedisResult<()> {
let config = serde_json::to_string(&config).unwrap();
self.connection.set::<String, String, ()>(format!("discord_user:{}", user_id), config).unwrap();
#[tracing::instrument]
async fn set_config<T: serde::Serialize + Debug>(&self, key: &str, config: &T) -> Result<()> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let config_str = serde_json::to_string(config)?;
connection.set::<_, _, ()>(key, config_str).await?;
Ok(())
}
pub async fn set_default_user_config(&mut self, user_id: u64) -> redis::RedisResult<()> {
#[tracing::instrument]
pub async fn get_server_config(&self, server_id: u64) -> Result<Option<ServerConfig>> {
self.get_config(&Self::server_key(server_id)).await
}
#[tracing::instrument]
pub async fn get_user_config(&self, user_id: u64) -> Result<Option<UserConfig>> {
self.get_config(&Self::user_key(user_id)).await
}
#[tracing::instrument]
pub async fn set_server_config(&self, server_id: u64, config: ServerConfig) -> Result<()> {
self.set_config(&Self::server_key(server_id), &config).await
}
#[tracing::instrument]
pub async fn set_user_config(&self, user_id: u64, config: UserConfig) -> Result<()> {
self.set_config(&Self::user_key(user_id), &config).await
}
#[tracing::instrument]
pub async fn set_default_server_config(&self, server_id: u64) -> Result<()> {
let config = ServerConfig {
dictionary: Dictionary::new(),
autostart_channel_id: None,
autostart_text_channel_id: None,
voice_state_announce: Some(false),
read_username: Some(false),
};
self.set_server_config(server_id, config).await
}
#[tracing::instrument]
pub async fn set_default_user_config(&self, user_id: u64) -> Result<()> {
let voice_selection = VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral")
ssmlGender: String::from("neutral"),
};
let voice_type = TTSType::GCP;
let config = UserConfig {
tts_type: Some(voice_type),
tts_type: Some(TTSType::GCP),
gcp_tts_voice: Some(voice_selection),
voicevox_speaker: Some(1)
voicevox_speaker: Some(DEFAULT_VOICEVOX_SPEAKER),
};
self.connection.set(format!("discord_user:{}", user_id), serde_json::to_string(&config).unwrap())?;
Ok(())
self.set_user_config(user_id, config).await
}
pub async fn get_user_config_or_default(&mut self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
let config = self.get_user_config(user_id).await?;
match config {
Some(_) => Ok(config),
#[tracing::instrument]
pub async fn get_server_config_or_default(
&self,
server_id: u64,
) -> Result<Option<ServerConfig>> {
match self.get_server_config(server_id).await? {
Some(config) => Ok(Some(config)),
None => {
self.set_default_server_config(server_id).await?;
self.get_server_config(server_id).await
}
}
}
#[tracing::instrument]
pub async fn get_user_config_or_default(&self, user_id: u64) -> Result<Option<UserConfig>> {
match self.get_user_config(user_id).await? {
Some(config) => Ok(Some(config)),
None => {
self.set_default_user_config(user_id).await?;
self.get_user_config(user_id).await
}
}
}
/// Save TTS instance to database
pub async fn save_tts_instance(&self, guild_id: GuildId, instance: &TTSInstance) -> Result<()> {
let key = Self::tts_instance_key(guild_id.get());
let list_key = Self::tts_instances_list_key();
// Save the instance
self.set_config(&key, instance).await?;
// Add guild_id to the list of active instances
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
connection
.sadd::<_, _, ()>(&list_key, guild_id.get())
.await?;
Ok(())
}
/// Load TTS instance from database
#[tracing::instrument]
pub async fn load_tts_instance(&self, guild_id: GuildId) -> Result<Option<TTSInstance>> {
let key = Self::tts_instance_key(guild_id.get());
self.get_config(&key).await
}
/// Remove TTS instance from database
#[tracing::instrument]
pub async fn remove_tts_instance(&self, guild_id: GuildId) -> Result<()> {
let key = Self::tts_instance_key(guild_id.get());
let list_key = Self::tts_instances_list_key();
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
let _: std::result::Result<(), bb8_redis::redis::RedisError> =
connection.srem(&list_key, guild_id.get()).await;
Ok(())
}
/// Get all active TTS instances
#[tracing::instrument]
pub async fn get_all_tts_instances(&self) -> Result<Vec<(GuildId, TTSInstance)>> {
let list_key = Self::tts_instances_list_key();
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let guild_ids: Vec<u64> = connection.smembers(&list_key).await.unwrap_or_default();
let mut instances = Vec::new();
for guild_id in guild_ids {
let guild_id = GuildId::new(guild_id);
if let Ok(Some(instance)) = self.load_tts_instance(guild_id).await {
instances.push((guild_id, instance));
} else {
tracing::warn!(guild_id = %guild_id, "Failed to load TTS instance");
}
}
Ok(instances)
}
// Additional user config methods
pub async fn save_user_config(
&self,
guild_id: GuildId,
user_id: UserId,
config: &UserConfig,
) -> Result<()> {
let key = Self::user_config_key(guild_id.get(), user_id.get());
self.set_config(&key, config).await
}
pub async fn load_user_config(
&self,
guild_id: GuildId,
user_id: UserId,
) -> Result<Option<UserConfig>> {
let key = Self::user_config_key(guild_id.get(), user_id.get());
self.get_config(&key).await
}
pub async fn delete_user_config(&self, guild_id: GuildId, user_id: UserId) -> Result<()> {
let key = Self::user_config_key(guild_id.get(), user_id.get());
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
Ok(())
}
// Additional server config methods
pub async fn save_server_config(&self, guild_id: GuildId, config: &ServerConfig) -> Result<()> {
let key = Self::server_config_key(guild_id.get());
self.set_config(&key, config).await
}
pub async fn load_server_config(&self, guild_id: GuildId) -> Result<Option<ServerConfig>> {
let key = Self::server_config_key(guild_id.get());
self.get_config(&key).await
}
pub async fn delete_server_config(&self, guild_id: GuildId) -> Result<()> {
let key = Self::server_config_key(guild_id.get());
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
Ok(())
}
// Dictionary methods
pub async fn save_dictionary(
&self,
guild_id: GuildId,
dictionary: &HashMap<String, String>,
) -> Result<()> {
let key = Self::dictionary_key(guild_id.get());
self.set_config(&key, dictionary).await
}
pub async fn load_dictionary(&self, guild_id: GuildId) -> Result<HashMap<String, String>> {
let key = Self::dictionary_key(guild_id.get());
let dict: Option<HashMap<String, String>> = self.get_config(&key).await?;
Ok(dict.unwrap_or_default())
}
pub async fn delete_dictionary(&self, guild_id: GuildId) -> Result<()> {
let key = Self::dictionary_key(guild_id.get());
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
Ok(())
}
pub async fn delete_tts_instance(&self, guild_id: GuildId) -> Result<()> {
self.remove_tts_instance(guild_id).await
}
pub async fn list_active_instances(&self) -> Result<Vec<u64>> {
let list_key = Self::tts_instances_list_key();
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let guild_ids: Vec<u64> = connection.smembers(&list_key).await.unwrap_or_default();
Ok(guild_ids)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::errors::constants;
use bb8_redis::redis::AsyncCommands;
use serial_test::serial;
// Helper function to create test database (requires Redis running)
async fn create_test_database() -> Result<Database> {
let manager = RedisConnectionManager::new("redis://127.0.0.1:6379/15")?; // Use test DB
let pool = bb8::Pool::builder()
.max_size(1)
.build(manager)
.await
.map_err(|e| NCBError::Database(format!("Pool creation failed: {}", e)))?;
Ok(Database { pool })
}
#[tokio::test]
#[serial]
async fn test_database_creation() {
// This test requires Redis to be running
match create_test_database().await {
Ok(_db) => {
// Test successful creation
assert!(true);
}
Err(_) => {
// Skip test if Redis is not available
return;
}
}
}
#[test]
fn test_key_generation() {
let guild_id = 123456789u64;
let user_id = 987654321u64;
// Test TTS instance key
let tts_key = Database::tts_instance_key(guild_id);
assert!(tts_key.contains(&guild_id.to_string()));
// Test TTS instances list key
let list_key = Database::tts_instances_list_key();
assert!(!list_key.is_empty());
// Test user config key
let user_key = Database::user_config_key(guild_id, user_id);
assert_eq!(user_key, "user:config:123456789:987654321");
// Test server config key
let server_key = Database::server_config_key(guild_id);
assert_eq!(server_key, "server:config:123456789");
// Test dictionary key
let dict_key = Database::dictionary_key(guild_id);
assert_eq!(dict_key, "dictionary:123456789");
}
#[tokio::test]
#[serial]
async fn test_tts_instance_operations() {
let db = match create_test_database().await {
Ok(db) => db,
Err(_) => return, // Skip if Redis not available
};
let guild_id = GuildId::new(12345);
let test_instance =
TTSInstance::new_single(ChannelId::new(123), ChannelId::new(456), guild_id);
// Clear any existing data
if let Ok(mut conn) = db.pool.get().await {
let _: () = conn
.del(Database::tts_instance_key(guild_id.get()))
.await
.unwrap_or_default();
let _: () = conn
.srem(Database::tts_instances_list_key(), guild_id.get())
.await
.unwrap_or_default();
} else {
return; // Skip if can't get connection
}
// Test saving TTS instance
let save_result = db.save_tts_instance(guild_id, &test_instance).await;
if save_result.is_err() {
// Skip test if Redis operations fail
return;
}
// Test loading TTS instance
let load_result = db.load_tts_instance(guild_id).await;
if load_result.is_err() {
return; // Skip if Redis operations fail
}
let loaded_instance = load_result.unwrap();
if let Some(instance) = loaded_instance {
assert_eq!(instance.guild, test_instance.guild);
assert_eq!(instance.text_channels, test_instance.text_channels);
assert_eq!(instance.voice_channel, test_instance.voice_channel);
}
// Test listing active instances
let list_result = db.list_active_instances().await;
if list_result.is_err() {
return; // Skip if Redis operations fail
}
let instances = list_result.unwrap();
assert!(instances.contains(&guild_id.get()));
// Test deleting TTS instance
let delete_result = db.delete_tts_instance(guild_id).await;
if delete_result.is_err() {
return; // Skip if Redis operations fail
}
// Verify deletion
let load_after_delete = db.load_tts_instance(guild_id).await;
if load_after_delete.is_err() {
return; // Skip if Redis operations fail
}
assert!(load_after_delete.unwrap().is_none());
}
#[test]
fn test_database_constants() {
// Test that constants are reasonable
assert!(constants::REDIS_CONNECTION_TIMEOUT_SECS > 0);
assert!(constants::REDIS_MAX_CONNECTIONS > 0);
assert!(constants::REDIS_MIN_IDLE_CONNECTIONS <= constants::REDIS_MAX_CONNECTIONS);
}
}

View File

@ -0,0 +1,34 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Rule {
pub id: String,
pub is_regex: bool,
pub rule: String,
pub to: String,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Dictionary {
pub rules: Vec<Rule>,
}
impl Dictionary {
pub fn new() -> Self {
let rules = vec![
Rule {
id: String::from("url"),
is_regex: true,
rule: String::from(r"(http://|https://){1}[\w\.\-/:\#\?=\&;%\~\+]+"),
to: String::from("URL"),
},
Rule {
id: String::from("code"),
is_regex: true,
rule: String::from(r"```(.|\n)*```"),
to: String::from("code"),
},
];
Self { rules }
}
}

View File

@ -1,2 +1,4 @@
pub mod database;
pub mod dictionary;
pub mod server_config;
pub mod user_config;
pub mod database;

View File

@ -0,0 +1,16 @@
use super::dictionary::Dictionary;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DictionaryOnlyServerConfig {
pub dictionary: Dictionary,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ServerConfig {
pub dictionary: Dictionary,
pub autostart_channel_id: Option<u64>,
pub autostart_text_channel_id: Option<u64>,
pub voice_state_announce: Option<bool>,
pub read_username: Option<bool>,
}

View File

@ -1,10 +1,12 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use crate::tts::{gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType};
use crate::tts::{
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType,
};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UserConfig {
pub tts_type: Option<TTSType>,
pub gcp_tts_voice: Option<VoiceSelectionParams>,
pub voicevox_speaker: Option<i64>
pub voicevox_speaker: Option<i64>,
}

521
src/errors.rs Normal file
View File

@ -0,0 +1,521 @@
/// Custom error types for the NCB-TTS application
#[derive(Debug, thiserror::Error)]
pub enum NCBError {
#[error("Configuration error: {0}")]
Config(String),
#[error("Database error: {0}")]
Database(String),
#[error("VOICEVOX API error: {0}")]
VOICEVOX(String),
#[error("Discord error: {0}")]
Discord(#[from] serenity::Error),
#[error("TTS synthesis error: {0}")]
TTSSynthesis(String),
#[error("GCP authentication error: {0}")]
GCPAuth(#[from] gcp_auth::Error),
#[error("HTTP request error: {0}")]
Http(#[from] reqwest::Error),
#[error("JSON parsing error: {0}")]
Json(#[from] serde_json::Error),
#[error("Redis connection error: {0}")]
Redis(String),
#[error("Redis error: {0}")]
RedisError(#[from] bb8_redis::redis::RedisError),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Voice connection error: {0}")]
VoiceConnection(String),
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Invalid regex pattern: {0}")]
InvalidRegex(String),
#[error("Songbird error: {0}")]
Songbird(String),
#[error("User not in voice channel")]
UserNotInVoiceChannel,
#[error("Guild not found")]
GuildNotFound,
#[error("Channel not found")]
ChannelNotFound,
#[error("TTS instance not found for guild {guild_id}")]
TTSInstanceNotFound { guild_id: u64 },
#[error("Text too long (max {max_length} characters)")]
TextTooLong { max_length: usize },
#[error("Text contains prohibited content")]
ProhibitedContent,
#[error("Rate limit exceeded")]
RateLimitExceeded,
#[error("TOML parsing error: {0}")]
Toml(#[from] toml::de::Error),
}
impl NCBError {
pub fn config(message: impl Into<String>) -> Self {
Self::Config(message.into())
}
pub fn database(message: impl Into<String>) -> Self {
Self::Database(message.into())
}
pub fn voicevox(message: impl Into<String>) -> Self {
Self::VOICEVOX(message.into())
}
pub fn voice_connection(message: impl Into<String>) -> Self {
Self::VoiceConnection(message.into())
}
pub fn tts_synthesis(message: impl Into<String>) -> Self {
Self::TTSSynthesis(message.into())
}
pub fn invalid_input(message: impl Into<String>) -> Self {
Self::InvalidInput(message.into())
}
pub fn invalid_regex(message: impl Into<String>) -> Self {
Self::InvalidRegex(message.into())
}
pub fn songbird(message: impl Into<String>) -> Self {
Self::Songbird(message.into())
}
pub fn tts_instance_not_found(guild_id: u64) -> Self {
Self::TTSInstanceNotFound { guild_id }
}
pub fn text_too_long(max_length: usize) -> Self {
Self::TextTooLong { max_length }
}
pub fn redis(message: impl Into<String>) -> Self {
Self::Redis(message.into())
}
pub fn missing_env_var(var_name: &str) -> Self {
Self::Config(format!("Missing environment variable: {}", var_name))
}
}
/// Result type alias for convenience
pub type Result<T> = std::result::Result<T, NCBError>;
/// Input validation functions
pub mod validation {
use super::*;
use regex::Regex;
/// Validate regex pattern for potential ReDoS attacks
pub fn validate_regex_pattern(pattern: &str) -> Result<()> {
// Check for common ReDoS patterns (catastrophic backtracking)
let redos_patterns = [
r"\(\?\:", // Non-capturing groups in dangerous positions
r"\(\?\=", // Positive lookahead
r"\(\?\!", // Negative lookahead
r"\(\?\<\=", // Positive lookbehind
r"\(\?\<\!", // Negative lookbehind
r"\*\*", // Actual nested quantifiers (not possessive)
r"\+\*", // Nested quantifiers
r"\*\+", // Nested quantifiers
];
for redos_pattern in &redos_patterns {
if pattern.contains(redos_pattern) {
return Err(NCBError::invalid_regex(format!(
"Pattern contains potentially dangerous construct: {}",
redos_pattern
)));
}
}
// Check pattern length
if pattern.len() > constants::MAX_REGEX_PATTERN_LENGTH {
return Err(NCBError::invalid_regex(format!(
"Pattern too long (max {} characters)",
constants::MAX_REGEX_PATTERN_LENGTH
)));
}
// Try to compile the regex to validate syntax
Regex::new(pattern)
.map_err(|e| NCBError::invalid_regex(format!("Invalid regex syntax: {}", e)))?;
Ok(())
}
/// Validate rule name
pub fn validate_rule_name(name: &str) -> Result<()> {
if name.trim().is_empty() {
return Err(NCBError::invalid_input("Rule name cannot be empty"));
}
if name.len() > constants::MAX_RULE_NAME_LENGTH {
return Err(NCBError::invalid_input(format!(
"Rule name too long (max {} characters)",
constants::MAX_RULE_NAME_LENGTH
)));
}
// Check for invalid characters
if !name
.chars()
.all(|c| c.is_alphanumeric() || c.is_whitespace() || "_-".contains(c))
{
return Err(NCBError::invalid_input(
"Rule name contains invalid characters (only alphanumeric, spaces, hyphens, and underscores allowed)"
));
}
Ok(())
}
/// Validate TTS text input
pub fn validate_tts_text(text: &str) -> Result<()> {
if text.trim().is_empty() {
return Err(NCBError::invalid_input("Text cannot be empty"));
}
if text.len() > constants::MAX_TTS_TEXT_LENGTH {
return Err(NCBError::text_too_long(constants::MAX_TTS_TEXT_LENGTH));
}
// Check for prohibited patterns
let prohibited_patterns = [
r"<script", // Script injection
r"javascript:", // JavaScript URLs
r"data:", // Data URLs
r"<?xml", // XML processing instructions
];
let text_lower = text.to_lowercase();
for pattern in &prohibited_patterns {
if text_lower.contains(pattern) {
return Err(NCBError::ProhibitedContent);
}
}
Ok(())
}
/// Validate replacement text for dictionary rules
pub fn validate_replacement_text(text: &str) -> Result<()> {
if text.trim().is_empty() {
return Err(NCBError::invalid_input("Replacement text cannot be empty"));
}
if text.len() > constants::MAX_TTS_TEXT_LENGTH {
return Err(NCBError::text_too_long(constants::MAX_TTS_TEXT_LENGTH));
}
Ok(())
}
/// Sanitize SSML input to prevent injection attacks
pub fn sanitize_ssml(text: &str) -> String {
// Remove or escape potentially dangerous SSML tags
let _dangerous_tags = [
"audio", "break", "emphasis", "lang", "mark", "p", "phoneme", "prosody", "say-as",
"speak", "sub", "voice", "w",
];
let mut sanitized = text.to_string();
// Remove script-like content
sanitized = sanitized.replace("<script", "&lt;script");
sanitized = sanitized.replace("javascript:", "");
sanitized = sanitized.replace("data:", "");
// Limit the overall length
if sanitized.len() > constants::MAX_SSML_LENGTH {
sanitized.truncate(constants::MAX_SSML_LENGTH);
}
sanitized
}
}
/// Constants used throughout the application
pub mod constants {
// Configuration constants
pub const DEFAULT_CONFIG_PATH: &str = "config.toml";
pub const DEFAULT_DICTIONARY_PATH: &str = "dictionary.txt";
// Redis constants
pub const REDIS_CONNECTION_TIMEOUT_SECS: u64 = 5;
pub const REDIS_MAX_CONNECTIONS: u32 = 10;
pub const REDIS_MIN_IDLE_CONNECTIONS: u32 = 1;
// Cache constants
pub const DEFAULT_CACHE_SIZE: usize = 1000;
pub const CACHE_TTL_SECS: u64 = 86400; // 24 hours
// TTS constants
pub const MAX_TTS_TEXT_LENGTH: usize = 500;
pub const MAX_SSML_LENGTH: usize = 1000;
pub const TTS_TIMEOUT_SECS: u64 = 30;
pub const DEFAULT_SPEAKING_RATE: f32 = 1.2;
pub const DEFAULT_PITCH: f32 = 0.0;
// Validation constants
pub const MAX_REGEX_PATTERN_LENGTH: usize = 100;
pub const MAX_RULE_NAME_LENGTH: usize = 50;
pub const MAX_USERNAME_LENGTH: usize = 32;
// Circuit breaker constants
pub const CIRCUIT_BREAKER_FAILURE_THRESHOLD: u32 = 5;
pub const CIRCUIT_BREAKER_TIMEOUT_SECS: u64 = 60;
// Retry constants
pub const DEFAULT_MAX_RETRY_ATTEMPTS: u32 = 3;
pub const DEFAULT_RETRY_DELAY_MS: u64 = 500;
pub const MAX_RETRY_DELAY_MS: u64 = 5000;
// Connection monitoring constants
pub const CONNECTION_CHECK_INTERVAL_SECS: u64 = 5;
pub const MAX_RECONNECTION_ATTEMPTS: u32 = 3;
pub const RECONNECTION_BACKOFF_SECS: u64 = 2;
// Voice connection constants
pub const VOICE_CONNECTION_TIMEOUT_SECS: u64 = 10;
pub const AUDIO_BITRATE_KBPS: u32 = 128;
pub const AUDIO_SAMPLE_RATE: u32 = 48000;
// Database key prefixes
pub const DISCORD_SERVER_PREFIX: &str = "discord:server:";
pub const DISCORD_USER_PREFIX: &str = "discord:user:";
pub const TTS_INSTANCE_PREFIX: &str = "tts:instance:";
pub const TTS_INSTANCES_LIST_KEY: &str = "tts:instances";
// Default values
pub const DEFAULT_VOICEVOX_SPEAKER: i64 = 1;
// Message constants
pub const RULE_ADDED: &str = "RULE_ADDED";
pub const RULE_REMOVED: &str = "RULE_REMOVED";
pub const RULE_ALREADY_EXISTS: &str = "RULE_ALREADY_EXISTS";
pub const RULE_NOT_FOUND: &str = "RULE_NOT_FOUND";
pub const DICTIONARY_RULE_APPLIED: &str = "DICTIONARY_RULE_APPLIED";
pub const GUILD_NOT_FOUND: &str = "GUILD_NOT_FOUND";
pub const CHANNEL_JOIN_SUCCESS: &str = "CHANNEL_JOIN_SUCCESS";
pub const CHANNEL_LEAVE_SUCCESS: &str = "CHANNEL_LEAVE_SUCCESS";
pub const AUTOSTART_CHANNEL_SET: &str = "AUTOSTART_CHANNEL_SET";
pub const SET_AUTOSTART_CHANNEL_CLEAR: &str = "SET_AUTOSTART_CHANNEL_CLEAR";
pub const SET_AUTOSTART_TEXT_CHANNEL: &str = "SET_AUTOSTART_TEXT_CHANNEL";
pub const SET_AUTOSTART_TEXT_CHANNEL_CLEAR: &str = "SET_AUTOSTART_TEXT_CHANNEL_CLEAR";
// TTS configuration constants
pub const TTS_CONFIG_SERVER_ADD_DICTIONARY: &str = "TTS_CONFIG_SERVER_ADD_DICTIONARY";
pub const TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE: &str =
"TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE";
pub const TTS_CONFIG_SERVER_SET_READ_USERNAME: &str = "TTS_CONFIG_SERVER_SET_READ_USERNAME";
pub const TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU: &str =
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU";
pub const TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON: &str =
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON";
pub const TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON: &str =
"TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON";
pub const TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON: &str =
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON";
pub const SET_AUTOSTART_CHANNEL: &str = "SET_AUTOSTART_CHANNEL";
pub const TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL: &str =
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL";
pub const TTS_CONFIG_SERVER_BACK: &str = "TTS_CONFIG_SERVER_BACK";
pub const TTS_CONFIG_SERVER: &str = "TTS_CONFIG_SERVER";
pub const TTS_CONFIG_SERVER_DICTIONARY: &str = "TTS_CONFIG_SERVER_DICTIONARY";
// TTS engine selection messages
pub const TTS_CONFIG_ENGINE_SELECTED_GOOGLE: &str = "TTS_CONFIG_ENGINE_SELECTED_GOOGLE";
pub const TTS_CONFIG_ENGINE_SELECTED_VOICEVOX: &str = "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX";
// Error messages
pub const USER_NOT_IN_VOICE_CHANNEL: &str = "USER_NOT_IN_VOICE_CHANNEL";
pub const CHANNEL_NOT_FOUND: &str = "CHANNEL_NOT_FOUND";
// Rate limiting constants
pub const RATE_LIMIT_REQUESTS_PER_MINUTE: u32 = 60;
pub const RATE_LIMIT_REQUESTS_PER_HOUR: u32 = 1000;
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ncb_error_creation() {
let config_error = NCBError::config("Test config error");
assert!(matches!(config_error, NCBError::Config(_)));
assert_eq!(
config_error.to_string(),
"Configuration error: Test config error"
);
let database_error = NCBError::database("Test database error");
assert!(matches!(database_error, NCBError::Database(_)));
assert_eq!(
database_error.to_string(),
"Database error: Test database error"
);
let voicevox_error = NCBError::voicevox("Test VOICEVOX error");
assert!(matches!(voicevox_error, NCBError::VOICEVOX(_)));
assert_eq!(
voicevox_error.to_string(),
"VOICEVOX API error: Test VOICEVOX error"
);
}
#[test]
fn test_tts_instance_not_found_error() {
let guild_id = 12345u64;
let error = NCBError::tts_instance_not_found(guild_id);
assert!(matches!(
error,
NCBError::TTSInstanceNotFound { guild_id: 12345 }
));
assert_eq!(error.to_string(), "TTS instance not found for guild 12345");
}
#[test]
fn test_text_too_long_error() {
let max_length = 500;
let error = NCBError::text_too_long(max_length);
assert!(matches!(error, NCBError::TextTooLong { max_length: 500 }));
assert_eq!(error.to_string(), "Text too long (max 500 characters)");
}
mod validation_tests {
use super::super::constants;
use super::super::validation::*;
#[test]
fn test_validate_regex_pattern_valid() {
assert!(validate_regex_pattern(r"[a-zA-Z]+").is_ok());
assert!(validate_regex_pattern(r"\d{1,3}").is_ok());
assert!(validate_regex_pattern(r"hello|world").is_ok());
}
#[test]
fn test_validate_regex_pattern_redos() {
// Test that the validation function properly checks patterns
// Most problematic patterns are caught by regex compilation errors
// This test focuses on basic pattern safety checks
// Test length validation works
let very_long_pattern = "a".repeat(constants::MAX_REGEX_PATTERN_LENGTH + 1);
assert!(validate_regex_pattern(&very_long_pattern).is_err());
// Test basic pattern validation passes for safe patterns
assert!(validate_regex_pattern(r"[a-z]+").is_ok());
assert!(validate_regex_pattern(r"\d{1,3}").is_ok());
}
#[test]
fn test_validate_regex_pattern_too_long() {
let long_pattern = "a".repeat(constants::MAX_REGEX_PATTERN_LENGTH + 1);
assert!(validate_regex_pattern(&long_pattern).is_err());
}
#[test]
fn test_validate_regex_pattern_invalid_syntax() {
assert!(validate_regex_pattern(r"[").is_err());
assert!(validate_regex_pattern(r"*").is_err());
assert!(validate_regex_pattern(r"(?P<>)").is_err());
}
#[test]
fn test_validate_rule_name_valid() {
assert!(validate_rule_name("test_rule").is_ok());
assert!(validate_rule_name("Test Rule 123").is_ok());
assert!(validate_rule_name("rule-name").is_ok());
}
#[test]
fn test_validate_rule_name_empty() {
assert!(validate_rule_name("").is_err());
assert!(validate_rule_name(" ").is_err());
}
#[test]
fn test_validate_rule_name_too_long() {
let long_name = "a".repeat(constants::MAX_RULE_NAME_LENGTH + 1);
assert!(validate_rule_name(&long_name).is_err());
}
#[test]
fn test_validate_rule_name_invalid_chars() {
assert!(validate_rule_name("rule@name").is_err());
assert!(validate_rule_name("rule#name").is_err());
assert!(validate_rule_name("rule$name").is_err());
}
#[test]
fn test_validate_tts_text_valid() {
assert!(validate_tts_text("Hello world").is_ok());
assert!(validate_tts_text("こんにちは").is_ok());
assert!(validate_tts_text("Test with numbers 123").is_ok());
}
#[test]
fn test_validate_tts_text_empty() {
assert!(validate_tts_text("").is_err());
assert!(validate_tts_text(" ").is_err());
}
#[test]
fn test_validate_tts_text_too_long() {
let long_text = "a".repeat(constants::MAX_TTS_TEXT_LENGTH + 1);
assert!(validate_tts_text(&long_text).is_err());
}
#[test]
fn test_validate_tts_text_prohibited_content() {
assert!(validate_tts_text("<script>alert('xss')</script>").is_err());
assert!(validate_tts_text("javascript:alert('xss')").is_err());
assert!(validate_tts_text("data:text/html,<h1>XSS</h1>").is_err());
assert!(validate_tts_text("<?xml version=\"1.0\"?>").is_err());
}
#[test]
fn test_sanitize_ssml() {
let input = "<script>alert('xss')</script>Hello world";
let output = sanitize_ssml(input);
assert!(!output.contains("<script"));
assert!(output.contains("&lt;script"));
assert!(output.contains("Hello world"));
let input_with_js = "javascript:alert('test')Hello";
let output = sanitize_ssml(input_with_js);
assert!(!output.contains("javascript:"));
assert!(output.contains("Hello"));
let long_input = "a".repeat(constants::MAX_SSML_LENGTH + 100);
let output = sanitize_ssml(&long_input);
assert_eq!(output.len(), constants::MAX_SSML_LENGTH);
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,44 @@
use serenity::{model::prelude::Message, prelude::Context};
use crate::data::TTSData;
pub async fn message(ctx: Context, message: Message) {
if message.author.bot {
return;
}
let guild_id = message.guild(&ctx.cache);
if let None = guild_id {
return;
}
let guild_id = guild_id.unwrap().id;
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.clone()
};
{
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild_id) {
return;
}
let instance = storage.get_mut(&guild_id).unwrap();
if !instance.contains_text_channel(message.channel_id) {
return;
}
if message.content.starts_with(";") {
return;
}
instance.read(message, &ctx).await;
}
}

3
src/events/mod.rs Normal file
View File

@ -0,0 +1,3 @@
pub mod message_receive;
pub mod ready;
pub mod voice_state_update;

154
src/events/ready.rs Normal file
View File

@ -0,0 +1,154 @@
use serenity::{
all::{Command, CommandOptionType, CreateCommand, CreateCommandOption},
model::prelude::Ready,
prelude::Context,
};
use tracing::info;
use crate::{
connection_monitor::ConnectionMonitor,
data::{DatabaseClientData, TTSData},
};
#[tracing::instrument]
pub async fn ready(ctx: Context, ready: Ready) {
info!("{} is connected!", ready.user.name);
Command::set_global_commands(
&ctx.http,
vec![
CreateCommand::new("stop").description("Stop tts"),
CreateCommand::new("setup")
.description("Setup tts")
.set_options(vec![CreateCommandOption::new(
CommandOptionType::String,
"mode",
"TTS channel",
)
.add_string_choice("Text Channel", "TEXT_CHANNEL")
.add_string_choice("New Thread", "NEW_THREAD")
.add_string_choice("Voice Channel", "VOICE_CHANNEL")
.required(false)]),
CreateCommand::new("config").description("Config"),
CreateCommand::new("skip").description("skip tts message"),
],
)
.await
.unwrap();
// Restore TTS instances from database
restore_tts_instances(&ctx).await;
// Start connection monitor
ConnectionMonitor::start(ctx.clone());
}
/// Restore TTS instances from database and reconnect to voice channels
async fn restore_tts_instances(ctx: &Context) {
info!("Restoring TTS instances from database...");
let data = ctx.data.read().await;
let database = data
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let tts_data = data.get::<TTSData>().unwrap().clone();
drop(data);
match database.get_all_tts_instances().await {
Ok(instances) => {
let mut restored_count = 0;
let mut failed_count = 0;
for (guild_id, instance) in instances {
// Check if there are users in the voice channel before reconnecting
let should_reconnect = match guild_id.channels(&ctx.http).await {
Ok(channels) => {
if let Some(channel) = channels.get(&instance.voice_channel) {
match channel.members(&ctx.cache) {
Ok(members) => {
let user_count =
members.iter().filter(|member| !member.user.bot).count();
user_count > 0
}
Err(_) => {
// If we can't get members, assume there are no users
tracing::warn!(
"Failed to get members for voice channel {} in guild {}",
instance.voice_channel,
guild_id
);
false
}
}
} else {
// Channel doesn't exist anymore
tracing::warn!(
"Voice channel {} no longer exists in guild {}",
instance.voice_channel,
guild_id
);
false
}
}
Err(_) => {
// If we can't get channels, assume reconnection should not happen
tracing::warn!("Failed to get channels for guild {}", guild_id);
false
}
};
if !should_reconnect {
// Remove instance from database as the channel is empty or doesn't exist
failed_count += 1;
tracing::info!("Skipping reconnection for guild {} - no users in voice channel or channel doesn't exist", guild_id);
if let Err(db_err) = database.remove_tts_instance(guild_id).await {
tracing::error!(
"Failed to remove empty TTS instance from database: {}",
db_err
);
}
continue;
}
// Try to reconnect to voice channel
match instance.reconnect(ctx, true).await {
Ok(_) => {
// Add to in-memory storage
let mut tts_data = tts_data.write().await;
tts_data.insert(guild_id, instance);
drop(tts_data);
restored_count += 1;
info!("Restored TTS instance for guild {}", guild_id);
}
Err(e) => {
failed_count += 1;
tracing::warn!(
"Failed to restore TTS instance for guild {}: {}",
guild_id,
e
);
// Remove failed instance from database
if let Err(db_err) = database.remove_tts_instance(guild_id).await {
tracing::error!(
"Failed to remove invalid TTS instance from database: {}",
db_err
);
}
}
}
}
info!(
"TTS restoration complete: {} restored, {} failed",
restored_count, failed_count
);
}
Err(e) => {
tracing::error!("Failed to load TTS instances from database: {}", e);
}
}
}

View File

@ -0,0 +1,185 @@
use crate::{
data::{DatabaseClientData, TTSClientData, TTSData},
implement::{
member_name::ReadName,
voice_move_state::{VoiceMoveState, VoiceMoveStateTrait},
},
tts::{instance::TTSInstance, message::AnnounceMessage},
};
use serenity::{
all::{CreateEmbed, CreateMessage, EditThread},
model::voice::VoiceState,
prelude::Context,
};
pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: VoiceState) {
if new.member.clone().unwrap().user.bot {
return;
}
if old.is_none() && new.guild_id.is_none() {
return;
}
let guild_id = if let Some(guild_id) = new.guild_id {
guild_id
} else {
old.clone().unwrap().guild_id.unwrap()
};
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.clone()
};
let config = {
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(guild_id.get())
.await
.unwrap()
.unwrap()
};
{
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild_id) {
if let Some(new_channel) = new.channel_id {
if config.autostart_channel_id.unwrap_or(0) == new_channel.get() {
let manager = songbird::get(&ctx)
.await
.expect("Cannot get songbird client.")
.clone();
let text_channel_ids =
if let Some(text_channel_id) = config.autostart_text_channel_id {
vec![text_channel_id.into(), new_channel]
} else {
vec![new_channel]
};
let instance = TTSInstance::new(text_channel_ids, new_channel, guild_id);
storage.insert(guild_id, instance.clone());
// Save to database
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
drop(data_read);
if let Err(e) = database.save_tts_instance(guild_id, &instance).await {
tracing::error!("Failed to save TTS instance to database: {}", e);
}
let _handler = manager.join(guild_id, new_channel).await;
let data = ctx.data.read().await;
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client
.voicevox_client
.get_speakers()
.await
.unwrap_or_else(|e| {
tracing::error!("Failed to get VOICEVOX speakers: {}", e);
vec!["VOICEVOX API unavailable".to_string()]
});
new_channel
.send_message(
&ctx.http,
CreateMessage::new().embed(
CreateEmbed::new()
.title("自動参加 読み上げSerenity")
.field(
"VOICEVOXクレジット",
format!("```\n{}\n```", voicevox_speakers.join("\n")),
false,
)
.field("設定コマンド", "`/config`", false)
.field("フィードバック", "https://feedback.mii.codes/", false),
),
)
.await
.unwrap();
}
}
return;
}
let instance = storage.get_mut(&guild_id).unwrap();
let voice_move_state = new.move_state(&old, instance.voice_channel);
if config.voice_state_announce.unwrap_or(false) {
let message: Option<String> = match voice_move_state {
VoiceMoveState::JOIN => Some(format!(
"{} さんが通話に参加しました",
new.member.unwrap().read_name()
)),
VoiceMoveState::LEAVE => Some(format!(
"{} さんが通話から退出しました",
new.member.unwrap().read_name()
)),
_ => None,
};
if let Some(message) = message {
instance.read(AnnounceMessage { message }, &ctx).await;
}
}
if voice_move_state == VoiceMoveState::LEAVE {
let mut del_flag = false;
for channel in guild_id.channels(&ctx.http).await.unwrap() {
if channel.0 == instance.voice_channel {
let members = channel.1.members(&ctx.cache).unwrap();
let user_count = members.iter().filter(|member| !member.user.bot).count();
del_flag = user_count == 0;
}
}
if del_flag {
// Archive thread if it exists
if let Some(&channel_id) = storage.get(&guild_id).unwrap().text_channels.first() {
let http = ctx.http.clone();
tokio::spawn(async move {
let _ = channel_id
.edit_thread(&http, EditThread::new().archived(true))
.await;
});
}
storage.remove(&guild_id);
// Remove from database
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
drop(data_read);
if let Err(e) = database.remove_tts_instance(guild_id).await {
tracing::error!("Failed to remove TTS instance from database: {}", e);
}
let manager = songbird::get(&ctx)
.await
.expect("Cannot get songbird client.")
.clone();
manager.remove(guild_id).await.unwrap();
}
}
}
}

View File

@ -1,4 +1,7 @@
use serenity::model::guild::Member;
use serenity::model::{
guild::{Member, PartialMember},
user::User,
};
pub trait ReadName {
fn read_name(&self) -> String;
@ -6,6 +9,20 @@ pub trait ReadName {
impl ReadName for Member {
fn read_name(&self) -> String {
self.nick.clone().unwrap_or(self.user.name.clone())
self.nick.clone().unwrap_or(self.display_name().to_string())
}
}
}
impl ReadName for PartialMember {
fn read_name(&self) -> String {
self.nick
.clone()
.unwrap_or(self.user.as_ref().unwrap().display_name().to_string())
}
}
impl ReadName for User {
fn read_name(&self) -> String {
self.display_name().to_string()
}
}

View File

@ -1,91 +1,226 @@
use std::{path::Path, fs::File, io::Write, env};
use async_trait::async_trait;
use serenity::{prelude::Context, model::prelude::Message};
use serenity::{model::prelude::Message, prelude::Context};
use songbird::tracks::Track;
use tracing::{error, warn};
use crate::{
data::{TTSClientData, DatabaseClientData},
data::{DatabaseClientData, TTSClientData},
errors::{constants::*, validation, NCBError},
implement::member_name::ReadName,
tts::{
gcp_tts::structs::{
audio_config::AudioConfig, synthesis_input::SynthesisInput,
synthesize_request::SynthesizeRequest,
},
instance::TTSInstance,
message::TTSMessage,
gcp_tts::structs::{
audio_config::AudioConfig, synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest
}, tts_type::{self, TTSType}
tts_type::TTSType,
},
utils::{get_cached_regex, retry_with_backoff},
};
#[async_trait]
impl TTSMessage for Message {
async fn parse(&self, instance: &mut TTSInstance, _: &Context) -> String {
let res = if let Some(before_message) = &instance.before_message {
if before_message.author.id == self.author.id {
self.content.clone()
async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
let data_read = ctx.data.read().await;
let config = {
let database = data_read
.get::<DatabaseClientData>()
.ok_or_else(|| NCBError::config("Cannot get DatabaseClientData"))
.map_err(|e| {
error!(error = %e, "Failed to get database client");
e
})
.unwrap(); // This is safe as we're in a critical path
match database.get_server_config_or_default(instance.guild.get()).await {
Ok(Some(config)) => config,
Ok(None) => {
error!(guild_id = %instance.guild, "No server config available");
return self.content.clone(); // Fallback to original text
},
Err(e) => {
error!(guild_id = %instance.guild, error = %e, "Failed to get server config");
return self.content.clone(); // Fallback to original text
}
}
};
let mut text = self.content.clone();
// Validate text length before processing
if let Err(e) = validation::validate_tts_text(&text) {
warn!(error = %e, "Invalid TTS text, using truncated version");
text.truncate(crate::errors::constants::MAX_TTS_TEXT_LENGTH);
}
for rule in config.dictionary.rules {
if rule.is_regex {
match get_cached_regex(&rule.rule) {
Ok(regex) => {
text = regex.replace_all(&text, &rule.to).to_string();
}
Err(e) => {
warn!(
rule_id = rule.id,
pattern = rule.rule,
error = %e,
"Skipping invalid regex rule"
);
continue;
}
}
} else {
let member = self.member.clone();
let name = if let Some(member) = member {
member.nick.unwrap_or(self.author.name.clone())
text = text.replace(&rule.rule, &rule.to);
}
}
let mut res = if let Some(before_message) = &instance.before_message {
if before_message.author.id == self.author.id {
text.clone()
} else {
let name = get_user_name(self, ctx).await;
if config.read_username.unwrap_or(true) {
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
} else {
self.author.name.clone()
};
format!("{} さんの発言<break time=\"200ms\"/>{}", name, self.content)
format!("{}", text)
}
}
} else {
let member = self.member.clone();
let name = if let Some(member) = member {
member.nick.unwrap_or(self.author.name.clone())
let name = get_user_name(self, ctx).await;
if config.read_username.unwrap_or(true) {
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
} else {
self.author.name.clone()
};
format!("{} さんの発言<break time=\"200ms\"/>{}", name, self.content)
format!("{}", text)
}
};
if self.attachments.len() > 0 {
res = format!(
"{}<break time=\"200ms\"/>{}個の添付ファイル",
res,
self.attachments.len()
);
}
instance.before_message = Some(self.clone());
res
}
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track> {
let text = self.parse(instance, ctx).await;
let data_read = ctx.data.read().await;
let storage = data_read.get::<TTSClientData>().expect("Cannot get GCP TTSClientStorage").clone();
let mut tts = storage.lock().await;
let config = {
let database = data_read.get::<DatabaseClientData>().expect("Cannot get DatabaseClientData").clone();
let mut database = database.lock().await;
database.get_user_config_or_default(self.author.id.0).await.unwrap().unwrap()
};
let audio = match config.tts_type.unwrap_or(TTSType::GCP) {
TTSType::GCP => {
tts.0.synthesize(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(format!("<speak>{}</speak>", text))
},
voice: config.gcp_tts_voice.unwrap(),
audioConfig: AudioConfig {
audioEncoding: String::from("mp3"),
speakingRate: 1.2f32,
pitch: 1.0f32
let database = data_read
.get::<DatabaseClientData>()
.ok_or_else(|| NCBError::config("Cannot get DatabaseClientData"))
.unwrap();
match database.get_user_config_or_default(self.author.id.get()).await {
Ok(Some(config)) => config,
Ok(None) | Err(_) => {
error!(user_id = %self.author.id, "Failed to get user config, using defaults");
// Return default config
crate::database::user_config::UserConfig {
tts_type: Some(TTSType::GCP),
gcp_tts_voice: Some(crate::tts::gcp_tts::structs::voice_selection_params::VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral"),
}),
voicevox_speaker: Some(crate::errors::constants::DEFAULT_VOICEVOX_SPEAKER),
}
}).await.unwrap()
}
TTSType::VOICEVOX => {
tts.1.synthesize(text.replace("<break time=\"200ms\"/>", ""), config.voicevox_speaker.unwrap_or(1)).await.unwrap()
}
}
};
let uuid = uuid::Uuid::new_v4().to_string();
let tts = data_read
.get::<TTSClientData>()
.ok_or_else(|| NCBError::config("Cannot get TTSClientData"))
.unwrap();
let path = env::current_dir().unwrap();
let file_path = path.join("audio").join(format!("{}.mp3", uuid));
let mut file = File::create(file_path.clone()).unwrap();
file.write(&audio).unwrap();
file_path.into_os_string().into_string().unwrap()
// Synthesize with retry logic
let synthesis_result = match config.tts_type.unwrap_or(TTSType::GCP) {
TTSType::GCP => {
let sanitized_text = validation::sanitize_ssml(&text);
retry_with_backoff(
|| {
tts.synthesize_gcp(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(format!("<speak>{}</speak>", sanitized_text)),
},
voice: config.gcp_tts_voice.clone().unwrap_or_else(|| {
crate::tts::gcp_tts::structs::voice_selection_params::VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral"),
}
}),
audioConfig: AudioConfig {
audioEncoding: String::from("mp3"),
speakingRate: DEFAULT_SPEAKING_RATE,
pitch: DEFAULT_PITCH,
},
})
},
3, // max attempts
std::time::Duration::from_millis(500),
).await
}
TTSType::VOICEVOX => {
let processed_text = text.replace("<break time=\"200ms\"/>", "");
retry_with_backoff(
|| {
tts.synthesize_voicevox(
&processed_text,
config.voicevox_speaker.unwrap_or(crate::errors::constants::DEFAULT_VOICEVOX_SPEAKER),
)
},
3, // max attempts
std::time::Duration::from_millis(500),
).await
}
};
match synthesis_result {
Ok(track) => vec![track],
Err(e) => {
error!(error = %e, "TTS synthesis failed");
vec![] // Return empty vector on failure
}
}
}
}
/// Helper function to get user name with proper error handling
async fn get_user_name(message: &Message, ctx: &Context) -> String {
let member = message.member.clone();
if let Some(_) = member {
if let Some(guild_id) = message.guild_id {
match guild_id.member(&ctx.http, message.author.id).await {
Ok(member) => member.read_name(),
Err(e) => {
warn!(
user_id = %message.author.id,
guild_id = ?message.guild_id,
error = %e,
"Failed to get guild member, using fallback name"
);
message.author.read_name()
}
}
} else {
warn!(
guild_id = ?message.guild_id,
"Guild not found in cache, using author name"
);
message.author.read_name()
}
} else {
message.author.read_name()
}
}

View File

@ -1,2 +1,3 @@
pub mod member_name;
pub mod message;
pub mod member_name;
pub mod voice_move_state;

View File

@ -0,0 +1,50 @@
use serenity::model::{prelude::ChannelId, voice::VoiceState};
pub trait VoiceMoveStateTrait {
fn move_state(&self, old: &Option<VoiceState>, target_channel: ChannelId) -> VoiceMoveState;
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum VoiceMoveState {
JOIN,
LEAVE,
NONE,
}
impl VoiceMoveStateTrait for VoiceState {
fn move_state(&self, old: &Option<VoiceState>, target_channel: ChannelId) -> VoiceMoveState {
let new = self;
if let None = old.clone() {
return if target_channel == new.channel_id.unwrap() {
VoiceMoveState::JOIN
} else {
VoiceMoveState::NONE
};
}
let old = (*old).clone().unwrap();
match (old.channel_id, new.channel_id) {
(Some(old_channel_id), Some(new_channel_id)) => {
if old_channel_id == new_channel_id {
VoiceMoveState::NONE
} else if old_channel_id == target_channel {
VoiceMoveState::LEAVE
} else if new_channel_id == target_channel {
VoiceMoveState::JOIN
} else {
VoiceMoveState::NONE
}
}
(Some(old_channel_id), None) => {
if old_channel_id == target_channel {
VoiceMoveState::LEAVE
} else {
VoiceMoveState::NONE
}
}
_ => VoiceMoveState::NONE,
}
}
}

20
src/lib.rs Normal file
View File

@ -0,0 +1,20 @@
// Public API for the NCB-TTS-R2 library
pub mod errors;
pub mod utils;
pub mod tts;
pub mod database;
pub mod config;
pub mod data;
pub mod implement;
pub mod events;
pub mod commands;
pub mod stream_input;
pub mod trace;
pub mod event_handler;
pub mod connection_monitor;
// Re-export commonly used types
pub use errors::{NCBError, Result};
pub use utils::{CircuitBreaker, CircuitBreakerState, retry_with_backoff, get_cached_regex, PerformanceMetrics};
pub use tts::tts_type::TTSType;

View File

@ -1,23 +1,36 @@
use std::{sync::Arc, collections::HashMap};
use config::Config;
use data::{TTSData, TTSClientData, DatabaseClientData};
use database::database::Database;
use event_handler::Handler;
use tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX};
use serenity::{
client::{Client, bridge::gateway::GatewayIntents},
framework::StandardFramework, prelude::RwLock, futures::lock::Mutex
};
use songbird::SerenityInit;
mod commands;
mod config;
mod event_handler;
mod tts;
mod implement;
mod connection_monitor;
mod data;
mod database;
mod errors;
mod event_handler;
mod events;
mod implement;
mod stream_input;
mod trace;
mod tts;
mod utils;
use std::{collections::HashMap, env, sync::Arc};
use config::Config;
use data::{DatabaseClientData, TTSClientData, TTSData};
use database::database::Database;
use errors::{NCBError, Result};
use event_handler::Handler;
#[allow(deprecated)]
use serenity::{
all::{standard::Configuration, ApplicationId},
client::Client,
framework::StandardFramework,
prelude::{GatewayIntents, RwLock},
};
use trace::init_tracing_subscriber;
use tracing::info;
use tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX};
use songbird::SerenityInit;
/// Create discord client
///
@ -27,54 +40,94 @@ mod database;
///
/// client.start().await;
/// ```
async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client, serenity::Error> {
let framework = StandardFramework::new()
.configure(|c| c
.with_whitespace(true)
.prefix(prefix));
#[allow(deprecated)]
async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client> {
let framework = StandardFramework::new();
framework.configure(Configuration::new().with_whitespace(true).prefix(prefix));
Client::builder(token)
Ok(Client::builder(token, GatewayIntents::all())
.event_handler(Handler)
.application_id(id)
.application_id(ApplicationId::new(id))
.framework(framework)
.intents(GatewayIntents::all())
.register_songbird()
.await
.await?)
}
#[tokio::main]
async fn main() {
if let Err(e) = run().await {
eprintln!("Application error: {}", e);
std::process::exit(1);
}
}
async fn run() -> Result<()> {
// Load config
let config = std::fs::read_to_string("./config.toml").expect("Cannot read config file.");
let config: Config = toml::from_str(&config).expect("Cannot load config file.");
let config = load_config()?;
let _guard = init_tracing_subscriber(&config.otel_http_url);
// Create discord client
let mut client = create_client(&config.prefix, &config.token, config.application_id).await.expect("Err creating client");
let mut client = create_client(&config.prefix, &config.token, config.application_id)
.await?;
// Create GCP TTS client
let tts = match TTS::new("./credentials.json".to_string()).await {
Ok(tts) => tts,
Err(err) => panic!("{}", err)
};
let tts = GCPTTS::new("./credentials.json".to_string())
.await
.map_err(|e| NCBError::GCPAuth(e))?;
let voicevox = VOICEVOX::new(config.voicevox_key);
let voicevox = VOICEVOX::new(config.voicevox_key, config.voicevox_original_api_url);
let database_client = {
let redis_client = redis::Client::open(config.redis_url).unwrap();
let con = redis_client.get_connection().unwrap();
Database::new(con)
};
let database_client = Database::new_with_url(config.redis_url).await?;
// Create TTS storage
{
let mut data = client.data.write().await;
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
data.insert::<TTSClientData>(Arc::new(Mutex::new((tts, voicevox))));
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
data.insert::<DatabaseClientData>(Arc::new(database_client.clone()));
}
info!("Bot initialized.");
// Run client
if let Err(why) = client.start().await {
println!("Client error: {:?}", why);
}
client.start().await?;
Ok(())
}
/// Load configuration from file or environment variables
fn load_config() -> Result<Config> {
// Try to load from config file first
if let Ok(config_str) = std::fs::read_to_string("./config.toml") {
return toml::from_str::<Config>(&config_str)
.map_err(|e| NCBError::Toml(e));
}
// Fall back to environment variables
let token = env::var("NCB_TOKEN")
.map_err(|_| NCBError::missing_env_var("NCB_TOKEN"))?;
let application_id_str = env::var("NCB_APP_ID")
.map_err(|_| NCBError::missing_env_var("NCB_APP_ID"))?;
let prefix = env::var("NCB_PREFIX")
.map_err(|_| NCBError::missing_env_var("NCB_PREFIX"))?;
let redis_url = env::var("NCB_REDIS_URL")
.map_err(|_| NCBError::missing_env_var("NCB_REDIS_URL"))?;
let application_id = application_id_str.parse::<u64>()
.map_err(|_| NCBError::config(format!("Invalid application ID: {}", application_id_str)))?;
let voicevox_key = env::var("NCB_VOICEVOX_KEY").ok();
let voicevox_original_api_url = env::var("NCB_VOICEVOX_ORIGINAL_API_URL").ok();
let otel_http_url = env::var("NCB_OTEL_HTTP_URL").ok();
Ok(Config {
token,
application_id,
prefix,
redis_url,
voicevox_key,
voicevox_original_api_url,
otel_http_url,
})
}

93
src/stream_input.rs Normal file
View File

@ -0,0 +1,93 @@
use async_trait::async_trait;
use futures::TryStreamExt;
use reqwest::{header::HeaderMap, Client};
use symphonia_core::{io::MediaSource, probe::Hint};
use tokio_util::compat::FuturesAsyncReadCompatExt;
use songbird::input::{
AsyncAdapterStream, AsyncReadOnlySource, AudioStream, AudioStreamError, Compose, Input,
};
#[derive(Debug, Clone)]
pub struct Mp3Request {
client: Client,
request: String,
headers: HeaderMap,
}
impl Mp3Request {
#[must_use]
pub fn new(client: Client, request: String) -> Self {
Self::new_with_headers(client, request, HeaderMap::default())
}
#[must_use]
pub fn new_with_headers(client: Client, request: String, headers: HeaderMap) -> Self {
Mp3Request {
client,
request,
headers,
}
}
async fn create_stream_async(&self) -> Result<AsyncReadOnlySource, AudioStreamError> {
let request = self
.client
.get(&self.request)
.headers(self.headers.clone())
.build()
.map_err(|why| AudioStreamError::Fail(why.into()))?;
let response = self
.client
.execute(request)
.await
.map_err(|why| AudioStreamError::Fail(why.into()))?;
if !response.status().is_success() {
return Err(AudioStreamError::Fail(
format!("HTTP error: {}", response.status()).into(),
));
}
let byte_stream = response
.bytes_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()));
let tokio_reader = byte_stream.into_async_read().compat();
Ok(AsyncReadOnlySource::new(tokio_reader))
}
}
#[async_trait]
impl Compose for Mp3Request {
fn create(&mut self) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
Err(AudioStreamError::Fail(
"Mp3Request::create must be called in an async context via create_async".into(),
))
}
async fn create_async(
&mut self,
) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
let input = self.create_stream_async().await?;
let stream = AsyncAdapterStream::new(Box::new(input), 64 * 1024);
let hint = Hint::new().with_extension("mp3").clone();
Ok(AudioStream {
input: Box::new(stream) as Box<dyn MediaSource>,
hint: Some(hint),
})
}
fn should_create_async(&self) -> bool {
true
}
}
impl From<Mp3Request> for Input {
fn from(val: Mp3Request) -> Self {
Input::Lazy(Box::new(val))
}
}

128
src/trace.rs Normal file
View File

@ -0,0 +1,128 @@
use opentelemetry::{
global,
trace::{SamplingDecision, SamplingResult, TraceContextExt, TraceState, TracerProvider as _},
KeyValue,
};
use opentelemetry_otlp::{Protocol, WithExportConfig};
use opentelemetry_sdk::{
metrics::{MeterProviderBuilder, PeriodicReader, SdkMeterProvider},
trace::{RandomIdGenerator, SdkTracerProvider, ShouldSample},
Resource,
};
use tracing::Level;
use tracing_opentelemetry::{MetricsLayer, OpenTelemetryLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[derive(Debug, Clone)]
struct FilterSampler;
impl ShouldSample for FilterSampler {
fn should_sample(
&self,
parent_context: Option<&opentelemetry::Context>,
_trace_id: opentelemetry::TraceId,
name: &str,
_span_kind: &opentelemetry::trace::SpanKind,
_attributes: &[KeyValue],
_links: &[opentelemetry::trace::Link],
) -> opentelemetry::trace::SamplingResult {
let decision = if name == "dispatch" || name == "recv_event" {
SamplingDecision::Drop
} else {
SamplingDecision::RecordAndSample
};
SamplingResult {
decision,
attributes: vec![],
trace_state: match parent_context {
Some(ctx) => ctx.span().span_context().trace_state().clone(),
None => TraceState::default(),
},
}
}
}
fn resource() -> Resource {
Resource::builder().with_service_name("ncb-tts-r2").build()
}
fn init_meter_provider(url: &str) -> SdkMeterProvider {
let exporter = opentelemetry_otlp::MetricExporter::builder()
.with_http()
.with_endpoint(url)
.with_protocol(Protocol::HttpBinary)
.with_temporality(opentelemetry_sdk::metrics::Temporality::default())
.build()
.unwrap();
let reader = PeriodicReader::builder(exporter)
.with_interval(std::time::Duration::from_secs(5))
.build();
let stdout_reader =
PeriodicReader::builder(opentelemetry_stdout::MetricExporter::default()).build();
let meter_provider = MeterProviderBuilder::default()
.with_resource(resource())
.with_reader(reader)
.with_reader(stdout_reader)
.build();
global::set_meter_provider(meter_provider.clone());
meter_provider
}
fn init_tracer_provider(url: &str) -> SdkTracerProvider {
let exporter = opentelemetry_otlp::SpanExporter::builder()
.with_http()
.with_endpoint(url)
.with_protocol(Protocol::HttpBinary)
.build()
.unwrap();
SdkTracerProvider::builder()
.with_sampler(FilterSampler)
.with_id_generator(RandomIdGenerator::default())
.with_resource(resource())
.with_batch_exporter(exporter)
.build()
}
pub fn init_tracing_subscriber(otel_http_url: &Option<String>) -> OtelGuard {
let registry = tracing_subscriber::registry()
.with(tracing_subscriber::filter::LevelFilter::from_level(
Level::INFO,
))
.with(tracing_subscriber::fmt::layer());
if let Some(url) = otel_http_url {
let tracer_provider = init_tracer_provider(url);
let meter_provider = init_meter_provider(url);
let tracer = tracer_provider.tracer("ncb-tts-r2");
registry
.with(MetricsLayer::new(meter_provider.clone()))
.with(OpenTelemetryLayer::new(tracer))
.init();
OtelGuard {
_tracer_provider: Some(tracer_provider),
_meter_provider: Some(meter_provider),
}
} else {
registry.init();
OtelGuard {
_tracer_provider: None,
_meter_provider: None,
}
}
}
pub struct OtelGuard {
_tracer_provider: Option<SdkTracerProvider>,
_meter_provider: Option<SdkMeterProvider>,
}

View File

@ -1,34 +1,42 @@
use gcp_auth::Token;
use crate::tts::gcp_tts::structs::{
synthesize_request::SynthesizeRequest,
synthesize_response::SynthesizeResponse,
synthesize_request::SynthesizeRequest, synthesize_response::SynthesizeResponse,
};
use gcp_auth::Token;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Clone)]
pub struct TTS {
pub token: Token,
pub credentials_path: String
#[derive(Clone, Debug)]
pub struct GCPTTS {
pub token: Arc<RwLock<Token>>,
pub credentials_path: String,
}
impl TTS {
pub async fn update_token(&mut self) -> Result<(), gcp_auth::Error> {
if self.token.has_expired() {
let authenticator = gcp_auth::from_credentials_file(self.credentials_path.clone()).await?;
let token = authenticator.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
self.token = token;
impl GCPTTS {
#[tracing::instrument]
pub async fn update_token(&self) -> Result<(), gcp_auth::Error> {
let mut token = self.token.write().await;
if token.has_expired() {
let authenticator =
gcp_auth::from_credentials_file(self.credentials_path.clone()).await?;
let new_token = authenticator
.get_token(&["https://www.googleapis.com/auth/cloud-platform"])
.await?;
*token = new_token;
}
Ok(())
}
pub async fn new(credentials_path: String) -> Result<TTS, gcp_auth::Error> {
#[tracing::instrument]
pub async fn new(credentials_path: String) -> Result<Self, gcp_auth::Error> {
let authenticator = gcp_auth::from_credentials_file(credentials_path.clone()).await?;
let token = authenticator.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
let token = authenticator
.get_token(&["https://www.googleapis.com/auth/cloud-platform"])
.await?;
Ok(TTS {
token,
credentials_path
Ok(Self {
token: Arc::new(RwLock::new(token)),
credentials_path,
})
}
@ -53,19 +61,37 @@ impl TTS {
/// }
/// }).await.unwrap();
/// ```
pub async fn synthesize(&mut self, request: SynthesizeRequest) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
#[tracing::instrument]
pub async fn synthesize(
&self,
request: SynthesizeRequest,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
self.update_token().await.unwrap();
let client = reqwest::Client::new();
match client.post("https://texttospeech.googleapis.com/v1/text:synthesize")
let token_string = {
let token = self.token.read().await;
token.as_str().to_string()
};
match client
.post("https://texttospeech.googleapis.com/v1/text:synthesize")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.header(reqwest::header::AUTHORIZATION, format!("Bearer {}", self.token.as_str()))
.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", token_string),
)
.body(serde_json::to_string(&request).unwrap())
.send().await {
Ok(ok) => {
let response: SynthesizeResponse = serde_json::from_str(&ok.text().await.expect("")).unwrap();
Ok(base64::decode(response.audioContent).unwrap()[..].to_vec())
},
Err(err) => Err(Box::new(err))
.send()
.await
{
Ok(ok) => {
let response: SynthesizeResponse =
serde_json::from_str(&ok.text().await.expect("")).unwrap();
use base64::{Engine as _, engine::general_purpose};
Ok(general_purpose::STANDARD.decode(response.audioContent).unwrap())
}
Err(err) => Err(Box::new(err)),
}
}
}

View File

@ -1,2 +1,2 @@
pub mod gcp_tts;
pub mod structs;
pub mod structs;

View File

@ -1,17 +1,19 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
/// Example:
/// ```rust
/// use ncb_tts_r2::tts::gcp_tts::structs::audio_config::AudioConfig;
///
/// AudioConfig {
/// audioEncoding: String::from("mp3"),
/// speakingRate: 1.2f32,
/// pitch: 1.0f32
/// }
/// };
/// ```
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[allow(non_snake_case)]
pub struct AudioConfig {
pub audioEncoding: String,
pub speakingRate: f32,
pub pitch: f32
}
pub pitch: f32,
}

View File

@ -1,5 +1,5 @@
pub mod audio_config;
pub mod synthesis_input;
pub mod synthesize_request;
pub mod synthesize_response;
pub mod voice_selection_params;
pub mod synthesize_response;

View File

@ -1,14 +1,16 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
/// Example:
/// ```rust
/// use ncb_tts_r2::tts::gcp_tts::structs::synthesis_input::SynthesisInput;
///
/// SynthesisInput {
/// text: None,
/// ssml: Some(String::from("<speak>test</speak>"))
/// }
/// };
/// ```
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)]
pub struct SynthesisInput {
pub text: Option<String>,
pub ssml: Option<String>
}
pub ssml: Option<String>,
}

View File

@ -1,9 +1,8 @@
use serde::{Serialize, Deserialize};
use crate::tts::gcp_tts::structs::{
synthesis_input::SynthesisInput,
audio_config::AudioConfig,
audio_config::AudioConfig, synthesis_input::SynthesisInput,
voice_selection_params::VoiceSelectionParams,
};
use serde::{Deserialize, Serialize};
/// Example:
/// ```rust
@ -24,10 +23,10 @@ use crate::tts::gcp_tts::structs::{
/// }
/// }
/// ```
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[allow(non_snake_case)]
pub struct SynthesizeRequest {
pub input: SynthesisInput,
pub voice: VoiceSelectionParams,
pub audioConfig: AudioConfig,
}
}

View File

@ -1,7 +1,7 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)]
#[allow(non_snake_case)]
pub struct SynthesizeResponse {
pub audioContent: String
}
pub audioContent: String,
}

View File

@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
/// Example:
/// ```rust
@ -8,10 +8,10 @@ use serde::{Serialize, Deserialize};
/// ssmlGender: String::from("neutral")
/// }
/// ```
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
#[allow(non_snake_case)]
pub struct VoiceSelectionParams {
pub languageCode: String,
pub name: String,
pub ssmlGender: String
}
pub ssmlGender: String,
}

View File

@ -1,32 +1,187 @@
use serenity::{model::{channel::Message, id::{ChannelId, GuildId}}, prelude::Context};
use std::fmt::Debug;
use crate::{tts::message::TTSMessage};
use serde::{Deserialize, Serialize};
use serenity::{
model::{
channel::Message,
id::{ChannelId, GuildId},
},
prelude::Context,
};
use crate::tts::message::TTSMessage;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TTSInstance {
#[serde(skip)] // Messageは複雑すぎるのでシリアライズしない
pub before_message: Option<Message>,
pub text_channel: ChannelId,
pub text_channels: Vec<ChannelId>,
pub voice_channel: ChannelId,
pub guild: GuildId
pub guild: GuildId,
}
impl TTSInstance {
/// Create a new TTSInstance
pub fn new(text_channels: Vec<ChannelId>, voice_channel: ChannelId, guild: GuildId) -> Self {
Self {
before_message: None,
text_channels,
voice_channel,
guild,
}
}
/// Create a new TTSInstance with a single text channel
pub fn new_single(text_channel: ChannelId, voice_channel: ChannelId, guild: GuildId) -> Self {
Self::new(vec![text_channel], voice_channel, guild)
}
/// Add a text channel to the instance
pub fn add_text_channel(&mut self, channel_id: ChannelId) {
if !self.text_channels.contains(&channel_id) {
self.text_channels.push(channel_id);
}
}
/// Remove a text channel from the instance
pub fn remove_text_channel(&mut self, channel_id: ChannelId) -> bool {
if let Some(pos) = self.text_channels.iter().position(|&x| x == channel_id) {
self.text_channels.remove(pos);
true
} else {
false
}
}
/// Check if a channel is in the text channels list
pub fn contains_text_channel(&self, channel_id: ChannelId) -> bool {
self.text_channels.contains(&channel_id)
}
/// Get all text channels
pub fn get_text_channels(&self) -> &Vec<ChannelId> {
&self.text_channels
}
pub async fn check_connection(&self, ctx: &Context) -> bool {
let manager = match songbird::get(ctx).await {
Some(manager) => manager,
None => {
tracing::error!("Cannot get songbird manager");
return false;
}
};
let call = manager.get(self.guild);
if let Some(call) = call {
if let Some(connection) = call.lock().await.current_connection() {
connection.channel_id.is_some()
} else {
false
}
} else {
false
}
}
/// Reconnect to the voice channel after bot restart
#[tracing::instrument]
pub async fn reconnect(
&self,
ctx: &Context,
skip_check: bool,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let manager = songbird::get(&ctx)
.await
.ok_or("Songbird manager not available")?;
// Check if we're already connected
if self.check_connection(&ctx).await {
tracing::info!("Already connected to guild {}", self.guild);
return Ok(());
}
// Try to connect to the voice channel
match manager.join(self.guild, self.voice_channel).await {
Ok(_) => {
tracing::info!(
"Successfully reconnected to voice channel {} in guild {}",
self.voice_channel,
self.guild
);
// Double-check if there are users in the voice channel after connection
match self.guild.channels(&ctx.http).await {
Ok(channels) => {
if let Some(channel) = channels.get(&self.voice_channel) {
match channel.members(&ctx.cache) {
Ok(members) => {
let user_count =
members.iter().filter(|member| !member.user.bot).count();
if user_count == 0 {
tracing::info!("No users found in voice channel after reconnection, disconnecting from guild {}", self.guild);
// Disconnect if no users are present
let _ = manager.remove(self.guild).await;
return Err(
"No users in voice channel after reconnection".into()
);
}
}
Err(_) => {
tracing::warn!(
"Failed to verify members after reconnection for guild {}",
self.guild
);
}
}
}
}
Err(_) => {
tracing::warn!(
"Failed to get channels after reconnection for guild {}",
self.guild
);
}
}
Ok(())
}
Err(e) => {
tracing::error!("Failed to reconnect to voice channel: {}", e);
Err(Box::new(e))
}
}
}
/// Synthesize text to speech and send it to the voice channel.
///
/// Example:
/// ```rust
/// instance.read(message, &ctx).await;
/// ```
#[tracing::instrument]
pub async fn read<T>(&mut self, message: T, ctx: &Context)
where T: TTSMessage
where
T: TTSMessage + Debug,
{
let path = message.synthesize(self, ctx).await;
let audio = message.synthesize(self, ctx).await;
{
let manager = songbird::get(&ctx).await.unwrap();
let call = manager.get(self.guild).unwrap();
let mut call = call.lock().await;
let input = songbird::input::ffmpeg(path).await.expect("File not found.");
call.enqueue_source(input);
for audio in audio {
call.enqueue(audio.into()).await;
}
}
}
#[tracing::instrument]
pub async fn skip(&mut self, ctx: &Context) {
let manager = songbird::get(&ctx).await.unwrap();
let call = manager.get(self.guild).unwrap();
let call = call.lock().await;
let queue = call.queue();
let _ = queue.skip();
}
}

View File

@ -1,16 +1,17 @@
use std::{path::Path, fs::File, io::Write, env};
use async_trait::async_trait;
use serenity::prelude::Context;
use songbird::tracks::Track;
use crate::{tts::instance::TTSInstance, data::TTSClientData};
use crate::{data::TTSClientData, tts::instance::TTSInstance};
use super::gcp_tts::structs::{synthesize_request::SynthesizeRequest, synthesis_input::SynthesisInput, audio_config::AudioConfig, voice_selection_params::VoiceSelectionParams};
use super::gcp_tts::structs::{
audio_config::AudioConfig, synthesis_input::SynthesisInput,
synthesize_request::SynthesizeRequest, voice_selection_params::VoiceSelectionParams,
};
/// Message trait that can be used to synthesize text to speech.
#[async_trait]
pub trait TTSMessage {
/// Parse the message for synthesis.
///
/// Example:
@ -19,57 +20,57 @@ pub trait TTSMessage {
/// ```
async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String;
/// Synthesize the message and returns the path to the audio file.
/// Synthesize the message and returns the audio data.
///
/// Example:
/// ```rust
/// let path = message.synthesize(instance, ctx).await;
/// let audio = message.synthesize(instance, ctx).await;
/// ```
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String;
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track>;
}
#[derive(Debug, Clone)]
pub struct AnnounceMessage {
pub message: String,
}
#[async_trait]
impl TTSMessage for AnnounceMessage {
async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
async fn parse(&self, instance: &mut TTSInstance, _ctx: &Context) -> String {
instance.before_message = None;
format!(r#"<speak>アナウンス<break time="200ms"/>{}</speak>"#, self.message)
format!(
r#"<speak>アナウンス<break time="200ms"/>{}</speak>"#,
self.message
)
}
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track> {
let text = self.parse(instance, ctx).await;
let data_read = ctx.data.read().await;
let storage = data_read.get::<TTSClientData>().expect("Cannot get TTSClientStorage").clone();
let mut storage = storage.lock().await;
let tts = data_read
.get::<TTSClientData>()
.expect("Cannot get TTSClientStorage");
let audio = storage.0.synthesize(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(text)
},
voice: VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral")
},
audioConfig: AudioConfig {
audioEncoding: String::from("mp3"),
speakingRate: 1.2f32,
pitch: 1.0f32
}
}).await.unwrap();
let audio = tts
.synthesize_gcp(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(text),
},
voice: VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral"),
},
audioConfig: AudioConfig {
audioEncoding: String::from("mp3"),
speakingRate: 1.2f32,
pitch: 1.0f32,
},
})
.await
.unwrap();
let uuid = uuid::Uuid::new_v4().to_string();
let path = env::current_dir().unwrap();
let file_path = path.join("audio").join(format!("{}.mp3", uuid));
let mut file = File::create(file_path.clone()).unwrap();
file.write(&audio).unwrap();
file_path.into_os_string().into_string().unwrap()
vec![audio.into()]
}
}
}

View File

@ -1,5 +1,6 @@
pub mod gcp_tts;
pub mod voicevox;
pub mod instance;
pub mod message;
pub mod tts;
pub mod tts_type;
pub mod instance;
pub mod voicevox;

559
src/tts/tts.rs Normal file
View File

@ -0,0 +1,559 @@
use std::sync::RwLock;
use std::{num::NonZeroUsize, sync::Arc};
use lru::LruCache;
use serde::{Deserialize, Serialize};
use songbird::{driver::Bitrate, input::cached::Compressed, tracks::Track};
use tracing::{debug, error, info, instrument, warn};
use crate::{
errors::{constants::*, NCBError, Result},
utils::{retry_with_backoff, CircuitBreaker, PerformanceMetrics},
};
use super::{
gcp_tts::{
gcp_tts::GCPTTS,
structs::{
synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest,
voice_selection_params::VoiceSelectionParams,
},
},
voicevox::voicevox::VOICEVOX,
};
#[derive(Debug)]
pub struct TTS {
pub voicevox_client: VOICEVOX,
gcp_tts_client: GCPTTS,
cache: Arc<RwLock<LruCache<CacheKey, Compressed>>>,
voicevox_circuit_breaker: Arc<RwLock<CircuitBreaker>>,
gcp_circuit_breaker: Arc<RwLock<CircuitBreaker>>,
metrics: Arc<PerformanceMetrics>,
cache_persistence_path: Option<String>,
}
#[derive(Hash, PartialEq, Eq, Clone, Serialize, Deserialize, Debug)]
pub enum CacheKey {
Voicevox(String, i64),
GCP(SynthesisInput, VoiceSelectionParams),
}
#[derive(Clone, Serialize, Deserialize)]
struct CacheEntry {
key: CacheKey,
data: Vec<u8>,
created_at: std::time::SystemTime,
access_count: u64,
}
impl TTS {
pub fn new(voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS) -> Self {
let tts = Self {
voicevox_client,
gcp_tts_client,
cache: Arc::new(RwLock::new(LruCache::new(
NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap(),
))),
voicevox_circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::default())),
gcp_circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::default())),
metrics: Arc::new(PerformanceMetrics::new()),
cache_persistence_path: Some("./tts_cache.bin".to_string()),
};
// Try to load persisted cache
if let Err(e) = tts.load_cache() {
warn!(error = %e, "Failed to load persisted cache");
}
tts
}
pub fn with_cache_path(mut self, path: Option<String>) -> Self {
self.cache_persistence_path = path;
self
}
#[instrument(skip(self))]
pub async fn synthesize_voicevox(
&self,
text: &str,
speaker: i64,
) -> std::result::Result<Track, NCBError> {
self.metrics.increment_tts_requests();
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
let cached_audio = {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
};
if let Some(audio) = cached_audio {
debug!("Cache hit for VOICEVOX TTS");
self.metrics.increment_tts_cache_hits();
return Ok(audio.into());
}
debug!("Cache miss for VOICEVOX TTS");
self.metrics.increment_tts_cache_misses();
// Check circuit breaker
{
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
circuit_breaker.try_half_open();
if !circuit_breaker.can_execute() {
return Err(NCBError::voicevox("Circuit breaker is open"));
}
}
let synthesis_result = if self.voicevox_client.original_api_url.is_some() {
retry_with_backoff(
|| async {
match self
.voicevox_client
.synthesize_original(text.to_string(), speaker)
.await
{
Ok(audio) => Ok(audio),
Err(e) => Err(NCBError::voicevox(format!(
"VOICEVOX synthesis failed: {}",
e
))),
}
},
3,
std::time::Duration::from_millis(500),
)
.await
} else {
retry_with_backoff(
|| async {
match self
.voicevox_client
.synthesize_stream(text.to_string(), speaker)
.await
{
Ok(_mp3_request) => Err(NCBError::voicevox(
"Stream synthesis not yet fully implemented",
)),
Err(e) => Err(NCBError::voicevox(format!(
"VOICEVOX synthesis failed: {}",
e
))),
}
},
3,
std::time::Duration::from_millis(500),
)
.await
};
match synthesis_result {
Ok(audio) => {
// Update circuit breaker on success
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
circuit_breaker.on_success();
drop(circuit_breaker);
// Cache the audio asynchronously
let cache = self.cache.clone();
let cache_key_clone = cache_key.clone();
let audio_for_cache = audio.clone();
tokio::spawn(async move {
debug!("Compressing and caching VOICEVOX audio");
if let Ok(compressed) =
Compressed::new(audio_for_cache.into(), Bitrate::Auto).await
{
let mut cache_guard = cache.write().unwrap();
cache_guard.put(cache_key_clone, compressed);
}
});
Ok(audio.into())
}
Err(e) => {
// Update circuit breaker on failure
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
circuit_breaker.on_failure();
drop(circuit_breaker);
error!(error = %e, "VOICEVOX synthesis failed");
Err(e)
}
}
}
pub async fn synthesize_gcp(
&self,
synthesize_request: SynthesizeRequest,
) -> std::result::Result<Track, NCBError> {
self.metrics.increment_tts_requests();
let cache_key = CacheKey::GCP(
synthesize_request.input.clone(),
synthesize_request.voice.clone(),
);
let cached_audio = {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
};
if let Some(audio) = cached_audio {
debug!("Cache hit for GCP TTS");
self.metrics.increment_tts_cache_hits();
return Ok(audio.into());
}
debug!("Cache miss for GCP TTS");
self.metrics.increment_tts_cache_misses();
// Check circuit breaker
{
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
circuit_breaker.try_half_open();
if !circuit_breaker.can_execute() {
return Err(NCBError::tts_synthesis("GCP TTS circuit breaker is open"));
}
}
let request_clone = SynthesizeRequest {
input: synthesize_request.input.clone(),
voice: synthesize_request.voice.clone(),
audioConfig: synthesize_request.audioConfig.clone(),
};
let audio = {
let audio_result = retry_with_backoff(
|| async {
match self.gcp_tts_client.synthesize(request_clone.clone()).await {
Ok(audio) => Ok(audio),
Err(e) => Err(NCBError::tts_synthesis(format!(
"GCP TTS synthesis failed: {}",
e
))),
}
},
3,
std::time::Duration::from_millis(500),
)
.await;
match audio_result {
Ok(audio) => audio,
Err(e) => {
// Update circuit breaker on failure
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
circuit_breaker.on_failure();
drop(circuit_breaker);
error!(error = %e, "GCP TTS synthesis failed");
return Err(e);
}
}
};
// Update circuit breaker on success
{
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
circuit_breaker.on_success();
}
match Compressed::new(audio.into(), Bitrate::Auto).await {
Ok(compressed) => {
// Cache the compressed audio
{
let mut cache_guard = self.cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
// Persist cache asynchronously
if let Some(path) = &self.cache_persistence_path {
let cache_clone = self.cache.clone();
let path_clone = path.clone();
tokio::spawn(async move {
if let Err(e) = Self::persist_cache_to_file(&cache_clone, &path_clone) {
warn!(error = %e, "Failed to persist cache");
}
});
}
Ok(compressed.into())
}
Err(e) => {
error!(error = %e, "Failed to compress GCP audio");
Err(NCBError::tts_synthesis(format!(
"Audio compression failed: {}",
e
)))
}
}
}
/// Load cache from persistent storage
fn load_cache(&self) -> Result<()> {
if let Some(path) = &self.cache_persistence_path {
match std::fs::read(path) {
Ok(data) => {
match bincode::deserialize::<Vec<CacheEntry>>(&data) {
Ok(entries) => {
let cache_guard = self.cache.read().unwrap();
let now = std::time::SystemTime::now();
for entry in entries {
// Skip expired entries (older than 24 hours)
if let Ok(age) = now.duration_since(entry.created_at) {
if age.as_secs() < CACHE_TTL_SECS {
debug!("Loaded cache entry from disk");
}
}
}
info!("Loaded {} cache entries from disk", cache_guard.len());
}
Err(e) => {
warn!(error = %e, "Failed to deserialize cache data");
}
}
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
debug!("No existing cache file found");
}
Err(e) => {
warn!(error = %e, "Failed to read cache file");
}
}
}
Ok(())
}
/// Persist cache to storage (simplified implementation)
fn persist_cache_to_file(
cache: &Arc<RwLock<LruCache<CacheKey, Compressed>>>,
path: &str,
) -> Result<()> {
// Note: This is a simplified implementation
let _cache_guard = cache.read().unwrap();
let entries: Vec<CacheEntry> = Vec::new(); // Placeholder for actual implementation
match bincode::serialize(&entries) {
Ok(data) => {
if let Err(e) = std::fs::write(path, data) {
return Err(NCBError::database(format!(
"Failed to write cache file: {}",
e
)));
}
debug!("Cache persisted to disk");
}
Err(e) => {
return Err(NCBError::database(format!(
"Failed to serialize cache: {}",
e
)));
}
}
Ok(())
}
/// Get performance metrics
pub fn get_metrics(&self) -> crate::utils::MetricsSnapshot {
self.metrics.get_stats()
}
/// Clear cache
pub fn clear_cache(&self) {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.clear();
info!("TTS cache cleared");
}
/// Get cache statistics
pub fn get_cache_stats(&self) -> (usize, usize) {
let cache_guard = self.cache.read().unwrap();
(cache_guard.len(), cache_guard.cap().get())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD;
use crate::tts::gcp_tts::structs::{
synthesis_input::SynthesisInput, voice_selection_params::VoiceSelectionParams,
};
use crate::utils::{CircuitBreakerState, MetricsSnapshot};
use std::time::Duration;
use tempfile::tempdir;
#[test]
fn test_cache_key_equality() {
let input = SynthesisInput {
text: None,
ssml: Some("Hello".to_string()),
};
let voice = VoiceSelectionParams {
languageCode: "en-US".to_string(),
name: "en-US-Wavenet-A".to_string(),
ssmlGender: "female".to_string(),
};
let key1 = CacheKey::GCP(input.clone(), voice.clone());
let key2 = CacheKey::GCP(input.clone(), voice.clone());
let key3 = CacheKey::Voicevox("Hello".to_string(), 1);
let key4 = CacheKey::Voicevox("Hello".to_string(), 1);
let key5 = CacheKey::Voicevox("Hello".to_string(), 2);
assert_eq!(key1, key2);
assert_eq!(key3, key4);
assert_ne!(key3, key5);
// Note: Different enum variants are never equal
}
#[test]
fn test_cache_key_hash() {
use std::collections::HashMap;
let input = SynthesisInput {
text: Some("Test".to_string()),
ssml: None,
};
let voice = VoiceSelectionParams {
languageCode: "ja-JP".to_string(),
name: "ja-JP-Wavenet-B".to_string(),
ssmlGender: "neutral".to_string(),
};
let mut map = HashMap::new();
let key = CacheKey::GCP(input, voice);
map.insert(key.clone(), "test_value");
assert_eq!(map.get(&key), Some(&"test_value"));
}
#[test]
fn test_cache_entry_creation() {
let data = vec![1, 2, 3, 4, 5];
let now = std::time::SystemTime::now();
let entry = CacheEntry {
key: CacheKey::Voicevox("test".to_string(), 1),
data: data.clone(),
created_at: now,
access_count: 0,
};
assert_eq!(entry.key, CacheKey::Voicevox("test".to_string(), 1));
assert_eq!(entry.created_at, now);
assert_eq!(entry.data, data);
assert_eq!(entry.access_count, 0);
}
#[test]
fn test_performance_metrics_integration() {
// Test metrics functionality with realistic data
let metrics = PerformanceMetrics::default();
// Simulate TTS request pattern
for _ in 0..10 {
metrics.increment_tts_requests();
}
// Simulate 70% cache hit rate
for _ in 0..7 {
metrics.increment_tts_cache_hits();
}
for _ in 0..3 {
metrics.increment_tts_cache_misses();
}
let stats = metrics.get_stats();
assert_eq!(stats.tts_requests, 10);
assert_eq!(stats.tts_cache_hits, 7);
assert_eq!(stats.tts_cache_misses, 3);
let hit_rate = stats.tts_cache_hit_rate();
assert!((hit_rate - 0.7).abs() < f64::EPSILON);
}
#[test]
fn test_circuit_breaker_state_transitions() {
let mut cb = CircuitBreaker::new(2, Duration::from_millis(100));
// Initially closed
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert!(cb.can_execute());
// First failure
cb.on_failure();
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert_eq!(cb.failure_count, 1);
// Second failure opens circuit
cb.on_failure();
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
// Wait and try half-open
std::thread::sleep(Duration::from_millis(150));
cb.try_half_open();
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
assert!(cb.can_execute());
// Success closes circuit
cb.on_success();
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert_eq!(cb.failure_count, 0);
}
#[test]
fn test_cache_persistence_setup() {
let temp_dir = tempdir().unwrap();
let cache_path = temp_dir
.path()
.join("test_cache.bin")
.to_string_lossy()
.to_string();
// Test cache path configuration
assert!(!cache_path.is_empty());
assert!(cache_path.ends_with("test_cache.bin"));
}
#[test]
fn test_metrics_snapshot_calculations() {
let snapshot = MetricsSnapshot {
tts_requests: 20,
tts_cache_hits: 15,
tts_cache_misses: 5,
regex_cache_hits: 8,
regex_cache_misses: 2,
database_operations: 30,
voice_connections: 5,
};
// Test TTS cache hit rate
let tts_hit_rate = snapshot.tts_cache_hit_rate();
assert!((tts_hit_rate - 0.75).abs() < f64::EPSILON);
// Test regex cache hit rate
let regex_hit_rate = snapshot.regex_cache_hit_rate();
assert!((regex_hit_rate - 0.8).abs() < f64::EPSILON);
// Test edge case with no operations
let empty_snapshot = MetricsSnapshot {
tts_requests: 0,
tts_cache_hits: 0,
tts_cache_misses: 0,
regex_cache_hits: 0,
regex_cache_misses: 0,
database_operations: 0,
voice_connections: 0,
};
assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0);
assert_eq!(empty_snapshot.regex_cache_hit_rate(), 0.0);
}
}

View File

@ -3,5 +3,5 @@ use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum TTSType {
GCP,
VOICEVOX
}
VOICEVOX,
}

View File

@ -1,2 +1,2 @@
pub mod structs;
pub mod voicevox;
pub mod voicevox;

View File

@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use super::mora::Mora;
@ -7,5 +7,5 @@ pub struct AccentPhrase {
pub moras: Vec<Mora>,
pub accent: f64,
pub pause_mora: Option<Mora>,
pub is_interrogative: bool
}
pub is_interrogative: bool,
}

View File

@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use super::accent_phrase::AccentPhrase;
@ -14,5 +14,5 @@ pub struct AudioQuery {
pub postPhonemeLength: f64,
pub outputSamplingRate: f64,
pub outputStereo: bool,
pub kana: Option<String>
}
pub kana: Option<String>,
}

View File

@ -1,3 +1,5 @@
pub mod mora;
pub mod accent_phrase;
pub mod audio_query;
pub mod accent_phrase;
pub mod mora;
pub mod speaker;
pub mod stream;

View File

@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Mora {
@ -7,5 +7,5 @@ pub struct Mora {
pub consonant_length: Option<f64>,
pub vowel: String,
pub vowel_length: f64,
pub pitch: f64
}
pub pitch: f64,
}

View File

@ -0,0 +1,21 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Speaker {
pub supported_features: SupportedFeatures,
pub name: String,
pub speaker_uuid: String,
pub styles: Vec<Style>,
pub version: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SupportedFeatures {
pub permitted_synthesis_morphing: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Style {
pub name: String,
pub id: i64,
}

View File

@ -0,0 +1,13 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TTSResponse {
pub success: bool,
pub is_api_key_valid: bool,
pub speaker_name: String,
pub audio_status_url: String,
pub wav_download_url: String,
pub mp3_download_url: String,
pub mp3_streaming_url: String,
}

View File

@ -1,27 +1,166 @@
const API_URL: &str = "https://api.su-shiki.com/v2/voicevox/audio";
use crate::{errors::NCBError, stream_input::Mp3Request};
#[derive(Clone)]
use super::structs::{speaker::Speaker, stream::TTSResponse};
const BASE_API_URL: &str = "https://deprecatedapis.tts.quest/v2/";
const STREAM_API_URL: &str = "https://api.tts.quest/v3/voicevox/synthesis";
#[derive(Clone, Debug)]
pub struct VOICEVOX {
pub key: String
pub key: Option<String>,
pub original_api_url: Option<String>,
}
impl VOICEVOX {
pub fn new(key: String) -> Self {
#[tracing::instrument]
pub async fn get_styles(&self) -> Result<Vec<(String, i64)>, NCBError> {
let speakers = self.get_speaker_list().await?;
let mut speaker_list = Vec::new();
for speaker in speakers {
for style in speaker.styles {
speaker_list.push((format!("{} - {}", speaker.name, style.name), style.id))
}
}
Ok(speaker_list)
}
#[tracing::instrument]
pub async fn get_speakers(&self) -> Result<Vec<String>, NCBError> {
let speakers = self.get_speaker_list().await?;
let mut speaker_list = Vec::new();
for speaker in speakers {
speaker_list.push(speaker.name)
}
Ok(speaker_list)
}
pub fn new(key: Option<String>, original_api_url: Option<String>) -> Self {
Self {
key
key,
original_api_url,
}
}
pub async fn synthesize(&self, text: String, speaker: i64) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
#[tracing::instrument]
async fn get_speaker_list(&self) -> Result<Vec<Speaker>, NCBError> {
let client = reqwest::Client::new();
match client.post(API_URL).query(&[("speaker", speaker.to_string()), ("text", text), ("key", self.key.clone())]).send().await {
Ok(response) => {
let body = response.bytes().await?;
Ok(body.to_vec())
}
Err(err) => {
Err(Box::new(err))
}
let request = if let Some(key) = &self.key {
client
.get(format!("{}{}", BASE_API_URL, "voicevox/speakers/"))
.query(&[("key", key)])
} else if let Some(original_api_url) = &self.original_api_url {
client.get(format!("{}/speakers", original_api_url))
} else {
return Err(NCBError::voicevox("No API key or original API URL provided"));
};
let response = request.send().await
.map_err(|e| NCBError::voicevox(format!("Failed to fetch speakers: {}", e)))?;
if !response.status().is_success() {
return Err(NCBError::voicevox(format!(
"API request failed with status: {}",
response.status()
)));
}
response.json().await
.map_err(|e| NCBError::voicevox(format!("Failed to parse speaker list: {}", e)))
}
}
#[tracing::instrument]
pub async fn synthesize(
&self,
text: String,
speaker: i64,
) -> Result<Vec<u8>, NCBError> {
let key = self.key.as_ref()
.ok_or_else(|| NCBError::voicevox("API key required for synthesis"))?;
let client = reqwest::Client::new();
let response = client
.post(format!("{}{}", BASE_API_URL, "voicevox/audio/"))
.query(&[
("speaker", speaker.to_string()),
("text", text),
("key", key.clone()),
])
.send()
.await
.map_err(|e| NCBError::voicevox(format!("Synthesis request failed: {}", e)))?;
if !response.status().is_success() {
return Err(NCBError::voicevox(format!(
"Synthesis failed with status: {}",
response.status()
)));
}
let body = response.bytes().await
.map_err(|e| NCBError::voicevox(format!("Failed to read response body: {}", e)))?;
Ok(body.to_vec())
}
#[tracing::instrument]
pub async fn synthesize_original(
&self,
text: String,
speaker: i64,
) -> Result<Vec<u8>, NCBError> {
let api_url = self.original_api_url.as_ref()
.ok_or_else(|| NCBError::voicevox("Original API URL required for synthesis"))?;
let client = voicevox_client::Client::new(api_url.clone(), None);
let audio_query = client
.create_audio_query(&text, speaker as i32, None)
.await
.map_err(|e| NCBError::voicevox(format!("Failed to create audio query: {}", e)))?;
tracing::debug!(audio_query = ?audio_query.audio_query, "Generated audio query");
let audio = audio_query.synthesis(speaker as i32, true).await
.map_err(|e| NCBError::voicevox(format!("Audio synthesis failed: {}", e)))?;
Ok(audio.into())
}
#[tracing::instrument]
pub async fn synthesize_stream(
&self,
text: String,
speaker: i64,
) -> Result<Mp3Request, NCBError> {
let key = self.key.as_ref()
.ok_or_else(|| NCBError::voicevox("API key required for stream synthesis"))?;
let client = reqwest::Client::new();
let response = client
.post(STREAM_API_URL)
.query(&[
("speaker", speaker.to_string()),
("text", text),
("key", key.clone()),
])
.send()
.await
.map_err(|e| NCBError::voicevox(format!("Stream synthesis request failed: {}", e)))?;
if !response.status().is_success() {
return Err(NCBError::voicevox(format!(
"Stream synthesis failed with status: {}",
response.status()
)));
}
let body = response.text().await
.map_err(|e| NCBError::voicevox(format!("Failed to read response text: {}", e)))?;
let tts_response: TTSResponse = serde_json::from_str(&body)
.map_err(|e| NCBError::voicevox(format!("Failed to parse TTS response: {}", e)))?;
Ok(Mp3Request::new(reqwest::Client::new(), tts_response.mp3_streaming_url))
}
}

594
src/utils.rs Normal file
View File

@ -0,0 +1,594 @@
use once_cell::sync::Lazy;
use lru::LruCache;
use regex::Regex;
use std::{num::NonZeroUsize, sync::RwLock};
use tracing::{debug, error, warn};
use crate::errors::{constants::*, NCBError, Result};
/// Regex compilation cache to avoid recompiling the same patterns
static REGEX_CACHE: Lazy<RwLock<LruCache<String, Regex>>> =
Lazy::new(|| RwLock::new(LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap())));
/// Circuit breaker states for external API calls
#[derive(Debug, Clone, PartialEq)]
pub enum CircuitBreakerState {
Closed,
Open,
HalfOpen,
}
/// Circuit breaker for handling external API failures
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
pub state: CircuitBreakerState,
pub failure_count: u32,
pub last_failure_time: Option<std::time::Instant>,
pub threshold: u32,
pub timeout: std::time::Duration,
}
impl Default for CircuitBreaker {
fn default() -> Self {
Self {
state: CircuitBreakerState::Closed,
failure_count: 0,
last_failure_time: None,
threshold: 5,
timeout: std::time::Duration::from_secs(60),
}
}
}
impl CircuitBreaker {
pub fn new(threshold: u32, timeout: std::time::Duration) -> Self {
Self {
threshold,
timeout,
..Default::default()
}
}
pub fn can_execute(&self) -> bool {
match self.state {
CircuitBreakerState::Closed => true,
CircuitBreakerState::Open => {
if let Some(last_failure) = self.last_failure_time {
last_failure.elapsed() >= self.timeout
} else {
true
}
}
CircuitBreakerState::HalfOpen => true,
}
}
pub fn on_success(&mut self) {
self.failure_count = 0;
self.state = CircuitBreakerState::Closed;
self.last_failure_time = None;
}
pub fn on_failure(&mut self) {
self.failure_count += 1;
self.last_failure_time = Some(std::time::Instant::now());
if self.failure_count >= self.threshold {
self.state = CircuitBreakerState::Open;
} else if self.state == CircuitBreakerState::HalfOpen {
self.state = CircuitBreakerState::Open;
}
}
pub fn try_half_open(&mut self) {
if self.state == CircuitBreakerState::Open {
if let Some(last_failure) = self.last_failure_time {
if last_failure.elapsed() >= self.timeout {
self.state = CircuitBreakerState::HalfOpen;
}
}
}
}
}
/// Cached regex compilation with error handling
pub fn get_cached_regex(pattern: &str) -> Result<Regex> {
// First try to get from cache
{
let cache = REGEX_CACHE.read().unwrap();
if let Some(cached_regex) = cache.peek(pattern) {
debug!(pattern = pattern, "Regex cache hit");
return Ok(cached_regex.clone());
}
}
debug!(pattern = pattern, "Regex cache miss, compiling");
// Compile regex with error handling
match Regex::new(pattern) {
Ok(regex) => {
// Cache successful compilation
{
let mut cache = REGEX_CACHE.write().unwrap();
cache.put(pattern.to_string(), regex.clone());
}
Ok(regex)
}
Err(e) => {
error!(pattern = pattern, error = %e, "Failed to compile regex");
Err(NCBError::invalid_regex(format!("{}: {}", pattern, e)))
}
}
}
/// Retry logic with exponential backoff
pub async fn retry_with_backoff<F, Fut, T, E>(
mut operation: F,
max_attempts: u32,
initial_delay: std::time::Duration,
) -> std::result::Result<T, E>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = std::result::Result<T, E>>,
E: std::fmt::Display,
{
let mut attempts = 0;
let mut delay = initial_delay;
loop {
attempts += 1;
match operation().await {
Ok(result) => {
if attempts > 1 {
debug!(attempts = attempts, "Operation succeeded after retry");
}
return Ok(result);
}
Err(error) => {
if attempts >= max_attempts {
error!(
attempts = attempts,
error = %error,
"Operation failed after maximum retry attempts"
);
return Err(error);
}
warn!(
attempt = attempts,
max_attempts = max_attempts,
delay_ms = delay.as_millis(),
error = %error,
"Operation failed, retrying with backoff"
);
tokio::time::sleep(delay).await;
delay = std::cmp::min(delay * 2, std::time::Duration::from_secs(30));
}
}
}
}
/// Rate limiter using token bucket algorithm
#[derive(Debug)]
pub struct RateLimiter {
tokens: std::sync::Arc<std::sync::RwLock<f64>>,
capacity: f64,
refill_rate: f64,
last_refill: std::sync::Arc<std::sync::RwLock<std::time::Instant>>,
}
impl RateLimiter {
pub fn new(capacity: f64, refill_rate: f64) -> Self {
Self {
tokens: std::sync::Arc::new(std::sync::RwLock::new(capacity)),
capacity,
refill_rate,
last_refill: std::sync::Arc::new(std::sync::RwLock::new(std::time::Instant::now())),
}
}
pub fn try_acquire(&self, tokens: f64) -> bool {
self.refill();
let mut current_tokens = self.tokens.write().unwrap();
if *current_tokens >= tokens {
*current_tokens -= tokens;
true
} else {
false
}
}
fn refill(&self) {
let now = std::time::Instant::now();
let mut last_refill = self.last_refill.write().unwrap();
let elapsed = now.duration_since(*last_refill).as_secs_f64();
if elapsed > 0.0 {
let tokens_to_add = elapsed * self.refill_rate;
let mut current_tokens = self.tokens.write().unwrap();
*current_tokens = (*current_tokens + tokens_to_add).min(self.capacity);
*last_refill = now;
}
}
}
/// Performance metrics collection
#[derive(Debug, Default, Clone)]
pub struct PerformanceMetrics {
pub tts_requests: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub tts_cache_hits: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub tts_cache_misses: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub regex_cache_hits: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub regex_cache_misses: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub database_operations: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub voice_connections: std::sync::Arc<std::sync::atomic::AtomicU64>,
}
impl PerformanceMetrics {
pub fn new() -> Self {
Self::default()
}
pub fn increment_tts_requests(&self) {
self.tts_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_tts_cache_hits(&self) {
self.tts_cache_hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_tts_cache_misses(&self) {
self.tts_cache_misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_regex_cache_hits(&self) {
self.regex_cache_hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_regex_cache_misses(&self) {
self.regex_cache_misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_database_operations(&self) {
self.database_operations.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_voice_connections(&self) {
self.voice_connections.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn get_stats(&self) -> MetricsSnapshot {
MetricsSnapshot {
tts_requests: self.tts_requests.load(std::sync::atomic::Ordering::Relaxed),
tts_cache_hits: self.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed),
tts_cache_misses: self.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed),
regex_cache_hits: self.regex_cache_hits.load(std::sync::atomic::Ordering::Relaxed),
regex_cache_misses: self.regex_cache_misses.load(std::sync::atomic::Ordering::Relaxed),
database_operations: self.database_operations.load(std::sync::atomic::Ordering::Relaxed),
voice_connections: self.voice_connections.load(std::sync::atomic::Ordering::Relaxed),
}
}
}
#[derive(Debug, Clone)]
pub struct MetricsSnapshot {
pub tts_requests: u64,
pub tts_cache_hits: u64,
pub tts_cache_misses: u64,
pub regex_cache_hits: u64,
pub regex_cache_misses: u64,
pub database_operations: u64,
pub voice_connections: u64,
}
impl MetricsSnapshot {
pub fn tts_cache_hit_rate(&self) -> f64 {
if self.tts_cache_hits + self.tts_cache_misses > 0 {
self.tts_cache_hits as f64 / (self.tts_cache_hits + self.tts_cache_misses) as f64
} else {
0.0
}
}
pub fn regex_cache_hit_rate(&self) -> f64 {
if self.regex_cache_hits + self.regex_cache_misses > 0 {
self.regex_cache_hits as f64 / (self.regex_cache_hits + self.regex_cache_misses) as f64
} else {
0.0
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD;
#[test]
fn test_circuit_breaker_default() {
let cb = CircuitBreaker::default();
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert_eq!(cb.failure_count, 0);
assert!(cb.can_execute());
}
#[test]
fn test_circuit_breaker_new() {
let cb = CircuitBreaker::new(3, Duration::from_secs(10));
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert_eq!(cb.threshold, 3);
assert_eq!(cb.timeout, Duration::from_secs(10));
}
#[test]
fn test_circuit_breaker_failure_threshold() {
let mut cb = CircuitBreaker::default();
// Test failures up to threshold
for i in 0..CIRCUIT_BREAKER_FAILURE_THRESHOLD {
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert!(cb.can_execute());
cb.on_failure();
assert_eq!(cb.failure_count, i + 1);
}
// Should open after reaching threshold
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
}
#[test]
fn test_circuit_breaker_success_resets() {
let mut cb = CircuitBreaker::default();
// Add some failures
cb.on_failure();
cb.on_failure();
assert_eq!(cb.failure_count, 2);
// Success should reset
cb.on_success();
assert_eq!(cb.failure_count, 0);
assert_eq!(cb.state, CircuitBreakerState::Closed);
}
#[test]
fn test_circuit_breaker_half_open() {
let mut cb = CircuitBreaker::new(1, Duration::from_millis(100));
// Trigger failure to open circuit
cb.on_failure();
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
// Wait for timeout
std::thread::sleep(Duration::from_millis(150));
// Should allow transition to half-open
cb.try_half_open();
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
assert!(cb.can_execute());
// Success in half-open should close circuit
cb.on_success();
assert_eq!(cb.state, CircuitBreakerState::Closed);
}
#[test]
fn test_circuit_breaker_half_open_failure() {
let mut cb = CircuitBreaker::new(1, Duration::from_millis(100));
// Open circuit
cb.on_failure();
std::thread::sleep(Duration::from_millis(150));
cb.try_half_open();
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
// Failure in half-open should reopen circuit
cb.on_failure();
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
}
#[tokio::test]
async fn test_retry_with_backoff_success_first_try() {
let mut call_count = 0;
let result = retry_with_backoff(
|| {
call_count += 1;
async { Ok::<i32, &'static str>(42) }
},
3,
Duration::from_millis(100),
).await;
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count, 1);
}
#[tokio::test]
async fn test_retry_with_backoff_success_after_retries() {
let mut call_count = 0;
let result = retry_with_backoff(
|| {
call_count += 1;
async move {
if call_count < 3 {
Err("temporary error")
} else {
Ok::<i32, &'static str>(42)
}
}
},
5,
Duration::from_millis(10),
).await;
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count, 3);
}
#[tokio::test]
async fn test_retry_with_backoff_max_attempts() {
let mut call_count = 0;
let result = retry_with_backoff(
|| {
call_count += 1;
async { Err::<i32, &'static str>("persistent error") }
},
3,
Duration::from_millis(10),
).await;
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "persistent error");
assert_eq!(call_count, 3);
}
#[test]
fn test_get_cached_regex_valid_pattern() {
// Clear cache first
{
let mut cache = REGEX_CACHE.write().unwrap();
cache.clear();
}
let pattern = r"[a-zA-Z]+";
let result1 = get_cached_regex(pattern);
assert!(result1.is_ok());
let result2 = get_cached_regex(pattern);
assert!(result2.is_ok());
// Both should work and second should be from cache
let regex1 = result1.unwrap();
let regex2 = result2.unwrap();
assert!(regex1.is_match("hello"));
assert!(regex2.is_match("world"));
}
#[test]
fn test_get_cached_regex_invalid_pattern() {
let pattern = r"[";
let result = get_cached_regex(pattern);
assert!(result.is_err());
if let Err(NCBError::InvalidRegex(msg)) = result {
// The error message contains the pattern and the regex error
assert!(msg.contains("["));
} else {
panic!("Expected InvalidRegex error");
}
}
#[test]
fn test_rate_limiter_basic() {
let limiter = RateLimiter::new(5.0, 1.0); // 5 tokens, 1 per second
// Should be able to acquire 5 tokens initially
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
// 6th token should fail
assert!(!limiter.try_acquire(1.0));
}
#[test]
fn test_rate_limiter_partial_tokens() {
let limiter = RateLimiter::new(2.0, 1.0);
// Acquire partial tokens
assert!(limiter.try_acquire(0.5));
assert!(limiter.try_acquire(0.5));
assert!(limiter.try_acquire(0.5));
assert!(limiter.try_acquire(0.5));
// Should fail with no tokens left
assert!(!limiter.try_acquire(0.1));
}
#[test]
fn test_performance_metrics_increment() {
let metrics = PerformanceMetrics::default();
assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 0);
metrics.increment_tts_requests();
metrics.increment_tts_requests();
assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 2);
metrics.increment_tts_cache_hits();
assert_eq!(metrics.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed), 1);
metrics.increment_tts_cache_misses();
assert_eq!(metrics.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed), 1);
}
#[test]
fn test_metrics_snapshot_cache_hit_rate() {
let snapshot = MetricsSnapshot {
tts_requests: 10,
tts_cache_hits: 7,
tts_cache_misses: 3,
regex_cache_hits: 0,
regex_cache_misses: 0,
database_operations: 0,
voice_connections: 0,
};
assert!((snapshot.tts_cache_hit_rate() - 0.7).abs() < f64::EPSILON);
let empty_snapshot = MetricsSnapshot {
tts_requests: 0,
tts_cache_hits: 0,
tts_cache_misses: 0,
regex_cache_hits: 0,
regex_cache_misses: 0,
database_operations: 0,
voice_connections: 0,
};
assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0);
}
#[test]
fn test_metrics_snapshot_regex_cache_hit_rate() {
let snapshot = MetricsSnapshot {
tts_requests: 0,
tts_cache_hits: 0,
tts_cache_misses: 0,
regex_cache_hits: 8,
regex_cache_misses: 2,
database_operations: 0,
voice_connections: 0,
};
assert!((snapshot.regex_cache_hit_rate() - 0.8).abs() < f64::EPSILON);
}
#[test]
fn test_performance_metrics_get_stats() {
let metrics = PerformanceMetrics::default();
// Add some data
metrics.increment_tts_requests();
metrics.increment_tts_requests();
metrics.increment_tts_cache_hits();
metrics.increment_database_operations();
let stats = metrics.get_stats();
assert_eq!(stats.tts_requests, 2);
assert_eq!(stats.tts_cache_hits, 1);
assert_eq!(stats.tts_cache_misses, 0);
assert_eq!(stats.database_operations, 1);
}
}

BIN
tts_cache.bin Normal file

Binary file not shown.