5 Commits

Author SHA1 Message Date
43cce7dc31 fix: 自動参加のバグ修正
アナウンスが無効の時に自動参加も無効になるバグを修正
2025-05-28 17:22:08 +09:00
2f06f6be3b Merge branch 'master' of github.com:mii443/ncb-tts-r2 2025-05-28 16:10:51 +09:00
f0327e232a feat: テキストチャンネルの自動参加設定を追加
- 複数のテキストチャンネルをサポートするために、TTSインスタンスの構造を変更
- 自動参加テキストチャンネルの設定と解除をUIセレクトメニューで実装
- 再接続時にテキストチャンネルに通知を送信する機能を強化
- コードの可読性向上のために、エラーハンドリングとロギングを改善

🤖 Generated with [Claude Code](https://claude.ai/code)
2025-05-28 16:08:34 +09:00
733646b6b8 refactor: Major overhaul with error handling, resilience patterns, and observability
- Add library configuration to support both lib and binary targets
- Implement unified error handling with NCBError throughout the codebase
- Add circuit breaker pattern for external API calls (Voicevox, GCP TTS)
- Introduce comprehensive performance metrics and monitoring
- Add cache persistence with disk storage support
- Implement retry mechanism with exponential backoff
- Add configuration file support (config.toml) with env var fallback
- Enhance logging with structured tracing (debug, warn, error levels)
- Add extensive unit tests for cache, metrics, and circuit breaker
- Update base64 decoding to use modern API
- Improve API error handling for Voicevox and GCP TTS clients

Breaking changes:
- Function signatures now return Result<T, NCBError> instead of panicking
- Cache key structure modified with serialization support
2025-05-28 01:01:12 +09:00
9e7d89eaa5 Update build.yml 2025-05-26 18:27:22 +09:00
24 changed files with 2775 additions and 464 deletions

View File

@ -32,5 +32,3 @@ jobs:
platforms: linux/amd64
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max

View File

@ -3,6 +3,14 @@ name = "ncb-tts-r2"
version = "1.11.2"
edition = "2021"
[lib]
name = "ncb_tts_r2"
path = "src/lib.rs"
[[bin]]
name = "ncb-tts-r2"
path = "src/main.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
@ -13,10 +21,15 @@ gcp_auth = "0.5.0"
reqwest = { version = "0.12.9", features = ["json"] }
base64 = "0.22.1"
async-trait = "0.1.57"
redis = "0.29.2"
redis = { version = "0.29.2", features = ["aio", "tokio-comp"] }
bb8 = "0.8"
bb8-redis = "0.16"
thiserror = "1.0"
regex = "1"
tracing-subscriber = "0.3.19"
lru = "0.13.0"
once_cell = "1.19"
bincode = "1.3"
tracing = "0.1.41"
opentelemetry_sdk = { version = "0.29.0", features = ["trace"] }
opentelemetry = "0.29.1"
@ -61,3 +74,9 @@ features = [
[dependencies.tokio]
version = "1.0"
features = ["macros", "rt-multi-thread"]
[dev-dependencies]
tokio-test = "0.4"
mockall = "0.12"
tempfile = "3.8"
serial_test = "3.0"

View File

@ -34,7 +34,14 @@ pub async fn config_command(
let tts_client = data_read
.get::<TTSClientData>()
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_styles().await;
let voicevox_speakers = tts_client
.voicevox_client
.get_styles()
.await
.unwrap_or_else(|e| {
tracing::error!("Failed to get VOICEVOX styles: {}", e);
vec![("VOICEVOX API unavailable".to_string(), 1)]
});
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
@ -54,11 +61,7 @@ pub async fn config_command(
.placeholder("読み上げAPIを選択"),
);
let server_button = CreateActionRow::Buttons(vec![CreateButton::new("TTS_CONFIG_SERVER")
.label("サーバー設定")
.style(ButtonStyle::Primary)]);
let mut components = vec![engine_select, server_button];
let mut components = vec![engine_select];
for (index, speaker_chunk) in voicevox_speakers[0..24].chunks(25).enumerate() {
let mut options = Vec::new();
@ -82,6 +85,12 @@ pub async fn config_command(
));
}
let server_button = CreateActionRow::Buttons(vec![CreateButton::new("TTS_CONFIG_SERVER")
.label("サーバー設定")
.style(ButtonStyle::Primary)]);
components.push(server_button);
command
.create_response(
&ctx.http,

View File

@ -81,32 +81,44 @@ pub async fn setup_command(
return Ok(());
}
let text_channel_id = {
let text_channel_ids = {
if let Some(mode) = command.data.options.get(0) {
match &mode.value {
serenity::all::CommandDataOptionValue::String(value) => {
match value.as_str() {
"TEXT_CHANNEL" => command.channel_id,
"TEXT_CHANNEL" => vec![command.channel_id],
"NEW_THREAD" => {
command
vec![command
.channel_id
.create_thread(&ctx.http, CreateThread::new("TTS").auto_archive_duration(AutoArchiveDuration::OneHour).kind(serenity::all::ChannelType::PublicThread))
.await
.unwrap()
.id
.id]
}
"VOICE_CHANNEL" => channel_id,
_ => channel_id,
"VOICE_CHANNEL" => vec![channel_id],
_ => if channel_id != command.channel_id {
vec![command.channel_id, channel_id]
} else {
vec![channel_id]
},
}
},
_ => channel_id,
_ => if channel_id != command.channel_id {
vec![command.channel_id, channel_id]
} else {
vec![channel_id]
},
}
} else {
channel_id
if channel_id != command.channel_id {
vec![command.channel_id, channel_id]
} else {
vec![channel_id]
}
}
};
let instance = TTSInstance::new(text_channel_id, channel_id, guild.id);
let instance = TTSInstance::new(text_channel_ids.clone(), channel_id, guild.id);
storage.insert(guild.id, instance.clone());
// Save to database
@ -121,7 +133,7 @@ pub async fn setup_command(
tracing::error!("Failed to save TTS instance to database: {}", e);
}
text_channel_id
text_channel_ids[0]
};
command
@ -149,7 +161,11 @@ pub async fn setup_command(
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await
.unwrap_or_else(|e| {
tracing::error!("Failed to get VOICEVOX speakers: {}", e);
vec!["VOICEVOX API unavailable".to_string()]
});
text_channel_id
.send_message(&ctx.http, CreateMessage::new()

View File

@ -78,7 +78,7 @@ pub async fn stop_command(
return Ok(());
}
let text_channel_id = storage.get(&guild.id).unwrap().text_channel;
let text_channel_id = storage.get(&guild.id).unwrap().text_channels[0];
storage.remove(&guild.id);
// Remove from database

View File

@ -1,34 +1,80 @@
use serenity::{model::channel::Message, prelude::Context, all::{CreateMessage, CreateEmbed}};
use serenity::{
all::{CreateEmbed, CreateMessage},
prelude::Context,
};
use std::time::Duration;
use tokio::time;
use tracing::{error, info, warn};
use tracing::{error, info, instrument, warn};
use crate::data::{DatabaseClientData, TTSData};
/// Constants for connection monitoring
const CONNECTION_CHECK_INTERVAL_SECS: u64 = 5;
const MAX_RECONNECTION_ATTEMPTS: u32 = 3;
const RECONNECTION_BACKOFF_SECS: u64 = 2;
/// Errors that can occur during connection monitoring
#[derive(Debug, thiserror::Error)]
pub enum ConnectionMonitorError {
#[error("Failed to get songbird manager")]
SongbirdManagerNotFound,
#[error("Failed to check voice channel users: {0}")]
VoiceChannelCheck(String),
#[error("Failed to reconnect after {attempts} attempts")]
ReconnectionFailed { attempts: u32 },
#[error("Database operation failed: {0}")]
Database(#[from] redis::RedisError),
}
type Result<T> = std::result::Result<T, ConnectionMonitorError>;
/// Connection monitor that periodically checks voice channel connections
pub struct ConnectionMonitor;
pub struct ConnectionMonitor {
reconnection_attempts: std::collections::HashMap<serenity::model::id::GuildId, u32>,
}
impl Default for ConnectionMonitor {
fn default() -> Self {
Self::new()
}
}
impl ConnectionMonitor {
pub fn new() -> Self {
Self {
reconnection_attempts: std::collections::HashMap::new(),
}
}
/// Start the connection monitoring task
pub fn start(ctx: Context) {
tokio::spawn(async move {
info!("Starting connection monitor with 5s interval");
let mut interval = time::interval(Duration::from_secs(5));
let mut monitor = ConnectionMonitor::new();
info!(
interval_secs = CONNECTION_CHECK_INTERVAL_SECS,
"Starting connection monitor"
);
let mut interval = time::interval(Duration::from_secs(CONNECTION_CHECK_INTERVAL_SECS));
loop {
interval.tick().await;
Self::check_connections(&ctx).await;
if let Err(e) = monitor.check_connections(&ctx).await {
error!(error = %e, "Connection monitoring failed");
}
}
});
}
/// Check all active TTS instances and their voice channel connections
async fn check_connections(ctx: &Context) {
#[instrument(skip(self, ctx))]
async fn check_connections(&mut self, ctx: &Context) -> Result<()> {
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.ok_or_else(|| {
ConnectionMonitorError::VoiceChannelCheck("Cannot get TTSStorage".to_string())
})?
.clone()
};
@ -36,7 +82,11 @@ impl ConnectionMonitor {
let data_read = ctx.data.read().await;
data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.ok_or_else(|| {
ConnectionMonitorError::VoiceChannelCheck(
"Cannot get DatabaseClientData".to_string(),
)
})?
.clone()
};
@ -45,13 +95,9 @@ impl ConnectionMonitor {
for (guild_id, instance) in storage.iter() {
// Check if bot is still connected to voice channel
let manager = match songbird::get(ctx).await {
Some(manager) => manager,
None => {
error!("Cannot get songbird manager");
continue;
}
};
let manager = songbird::get(ctx)
.await
.ok_or(ConnectionMonitorError::SongbirdManagerNotFound)?;
let call = manager.get(*guild_id);
let is_connected = if let Some(call) = call {
@ -65,49 +111,99 @@ impl ConnectionMonitor {
};
if !is_connected {
warn!("Bot disconnected from voice channel in guild {}", guild_id);
warn!(guild_id = %guild_id, "Bot disconnected from voice channel");
// Check if there are users in the voice channel
let should_reconnect = match Self::check_voice_channel_users(ctx, instance).await {
let should_reconnect = match self.check_voice_channel_users(ctx, instance).await {
Ok(has_users) => has_users,
Err(_) => {
// If we can't check users, don't reconnect
Err(e) => {
warn!(guild_id = %guild_id, error = %e, "Failed to check voice channel users, skipping reconnection");
false
}
};
if should_reconnect {
// Try to reconnect
// Try to reconnect with retry logic
let attempts = self
.reconnection_attempts
.get(guild_id)
.copied()
.unwrap_or(0);
if attempts >= MAX_RECONNECTION_ATTEMPTS {
error!(
guild_id = %guild_id,
attempts = attempts,
"Maximum reconnection attempts reached, removing instance"
);
guilds_to_remove.push(*guild_id);
self.reconnection_attempts.remove(guild_id);
continue;
}
// Apply exponential backoff
if attempts > 0 {
let backoff_duration =
Duration::from_secs(RECONNECTION_BACKOFF_SECS * (2_u64.pow(attempts)));
warn!(
guild_id = %guild_id,
attempt = attempts + 1,
backoff_secs = backoff_duration.as_secs(),
"Applying backoff before reconnection attempt"
);
tokio::time::sleep(backoff_duration).await;
}
match instance.reconnect(ctx, true).await {
Ok(_) => {
info!(
"Successfully reconnected to voice channel in guild {}",
guild_id
guild_id = %guild_id,
attempts = attempts + 1,
"Successfully reconnected to voice channel"
);
// Reset reconnection attempts on success
self.reconnection_attempts.remove(guild_id);
// Send notification message to text channel with embed
let embed = CreateEmbed::new()
.title("🔄 自動再接続しました")
.description("読み上げを停止したい場合は `/stop` コマンドを使用してください。")
.color(0x00ff00);
if let Err(e) = instance.text_channel.send_message(&ctx.http, CreateMessage::new().embed(embed)).await {
error!("Failed to send reconnection message to text channel: {}", e);
// Send message to the first text channel
if let Some(&text_channel) = instance.text_channels.first() {
if let Err(e) = text_channel
.send_message(&ctx.http, CreateMessage::new().embed(embed))
.await
{
error!(guild_id = %guild_id, error = %e, "Failed to send reconnection message");
}
}
}
Err(e) => {
let new_attempts = attempts + 1;
self.reconnection_attempts.insert(*guild_id, new_attempts);
error!(
"Failed to reconnect to voice channel in guild {}: {}",
guild_id, e
guild_id = %guild_id,
attempt = new_attempts,
error = %e,
"Failed to reconnect to voice channel"
);
if new_attempts >= MAX_RECONNECTION_ATTEMPTS {
guilds_to_remove.push(*guild_id);
self.reconnection_attempts.remove(guild_id);
}
}
}
} else {
info!(
"No users in voice channel, removing instance for guild {}",
guild_id
guild_id = %guild_id,
"No users in voice channel, removing instance"
);
guilds_to_remove.push(*guild_id);
self.reconnection_attempts.remove(guild_id);
}
}
}
@ -118,29 +214,59 @@ impl ConnectionMonitor {
// Remove from database
if let Err(e) = database.remove_tts_instance(guild_id).await {
error!("Failed to remove TTS instance from database: {}", e);
error!(guild_id = %guild_id, error = %e, "Failed to remove TTS instance from database");
}
// Ensure bot leaves voice channel
if let Some(manager) = songbird::get(ctx).await {
let _ = manager.remove(guild_id).await;
if let Err(e) = manager.remove(guild_id).await {
error!(guild_id = %guild_id, error = %e, "Failed to remove bot from voice channel");
}
}
info!(guild_id = %guild_id, "Removed disconnected TTS instance");
}
Ok(())
}
/// Check if there are users in the voice channel
#[instrument(skip(self, ctx, instance))]
async fn check_voice_channel_users(
&self,
ctx: &Context,
instance: &crate::tts::instance::TTSInstance,
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
let channels = instance.guild.channels(&ctx.http).await?;
) -> Result<bool> {
let channels = instance.guild.channels(&ctx.http).await.map_err(|e| {
ConnectionMonitorError::VoiceChannelCheck(format!(
"Failed to get guild channels: {}",
e
))
})?;
if let Some(channel) = channels.get(&instance.voice_channel) {
let members = channel.members(&ctx.cache)?;
let members = channel.members(&ctx.cache).map_err(|e| {
ConnectionMonitorError::VoiceChannelCheck(format!(
"Failed to get channel members: {}",
e
))
})?;
let user_count = members.iter().filter(|member| !member.user.bot).count();
info!(
guild_id = %instance.guild,
channel_id = %instance.voice_channel,
user_count = user_count,
"Checked voice channel users"
);
Ok(user_count > 0)
} else {
// Channel doesn't exist anymore
warn!(
guild_id = %instance.guild,
channel_id = %instance.voice_channel,
"Voice channel no longer exists"
);
Ok(false)
}
}

View File

@ -1,48 +1,75 @@
use std::fmt::Debug;
use crate::tts::{
use crate::{
errors::{constants::*, NCBError, Result},
tts::{
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, instance::TTSInstance,
tts_type::TTSType,
},
};
use serenity::model::id::GuildId;
use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager};
use serenity::model::id::{ChannelId, GuildId, UserId};
use std::collections::HashMap;
use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig};
use redis::Commands;
#[derive(Debug, Clone)]
pub struct Database {
pub client: redis::Client,
pub pool: Pool<RedisConnectionManager>,
}
impl Database {
pub fn new(client: redis::Client) -> Self {
Self { client }
pub fn new(pool: Pool<RedisConnectionManager>) -> Self {
Self { pool }
}
pub async fn new_with_url(redis_url: String) -> Result<Self> {
let manager = RedisConnectionManager::new(redis_url)?;
let pool = Pool::builder()
.max_size(15)
.build(manager)
.await
.map_err(|e| NCBError::Database(format!("Pool creation failed: {}", e)))?;
Ok(Self { pool })
}
fn server_key(server_id: u64) -> String {
format!("discord_server:{}", server_id)
format!("{}{}", DISCORD_SERVER_PREFIX, server_id)
}
fn user_key(user_id: u64) -> String {
format!("discord_user:{}", user_id)
format!("{}{}", DISCORD_USER_PREFIX, user_id)
}
fn tts_instance_key(guild_id: u64) -> String {
format!("tts_instance:{}", guild_id)
format!("{}{}", TTS_INSTANCE_PREFIX, guild_id)
}
fn tts_instances_list_key() -> String {
"tts_instances_list".to_string()
TTS_INSTANCES_LIST_KEY.to_string()
}
fn user_config_key(guild_id: u64, user_id: u64) -> String {
format!("user:config:{}:{}", guild_id, user_id)
}
fn server_config_key(guild_id: u64) -> String {
format!("server:config:{}", guild_id)
}
fn dictionary_key(guild_id: u64) -> String {
format!("dictionary:{}", guild_id)
}
#[tracing::instrument]
fn get_config<T: serde::de::DeserializeOwned>(
&self,
key: &str,
) -> redis::RedisResult<Option<T>> {
match self.client.get_connection() {
Ok(mut connection) => {
let config: String = connection.get(key).unwrap_or_default();
async fn get_config<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let config: String = connection.get(key).await.unwrap_or_default();
if config.is_empty() {
return Ok(None);
@ -50,73 +77,61 @@ impl Database {
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None),
Err(e) => {
tracing::warn!(key = key, error = %e, "Failed to deserialize config");
Ok(None)
}
}
Err(e) => Err(e),
}
}
#[tracing::instrument]
fn set_config<T: serde::Serialize + Debug>(
&self,
key: &str,
config: &T,
) -> redis::RedisResult<()> {
match self.client.get_connection() {
Ok(mut connection) => {
let config_str = serde_json::to_string(config).unwrap();
connection.set::<_, _, ()>(key, config_str)
}
Err(e) => Err(e),
}
async fn set_config<T: serde::Serialize + Debug>(&self, key: &str, config: &T) -> Result<()> {
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let config_str = serde_json::to_string(config)?;
connection.set::<_, _, ()>(key, config_str).await?;
Ok(())
}
#[tracing::instrument]
pub async fn get_server_config(
&self,
server_id: u64,
) -> redis::RedisResult<Option<ServerConfig>> {
self.get_config(&Self::server_key(server_id))
pub async fn get_server_config(&self, server_id: u64) -> Result<Option<ServerConfig>> {
self.get_config(&Self::server_key(server_id)).await
}
#[tracing::instrument]
pub async fn get_user_config(&self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
self.get_config(&Self::user_key(user_id))
pub async fn get_user_config(&self, user_id: u64) -> Result<Option<UserConfig>> {
self.get_config(&Self::user_key(user_id)).await
}
#[tracing::instrument]
pub async fn set_server_config(
&self,
server_id: u64,
config: ServerConfig,
) -> redis::RedisResult<()> {
self.set_config(&Self::server_key(server_id), &config)
pub async fn set_server_config(&self, server_id: u64, config: ServerConfig) -> Result<()> {
self.set_config(&Self::server_key(server_id), &config).await
}
#[tracing::instrument]
pub async fn set_user_config(
&self,
user_id: u64,
config: UserConfig,
) -> redis::RedisResult<()> {
self.set_config(&Self::user_key(user_id), &config)
pub async fn set_user_config(&self, user_id: u64, config: UserConfig) -> Result<()> {
self.set_config(&Self::user_key(user_id), &config).await
}
#[tracing::instrument]
pub async fn set_default_server_config(&self, server_id: u64) -> redis::RedisResult<()> {
pub async fn set_default_server_config(&self, server_id: u64) -> Result<()> {
let config = ServerConfig {
dictionary: Dictionary::new(),
autostart_channel_id: None,
voice_state_announce: Some(true),
read_username: Some(true),
autostart_text_channel_id: None,
voice_state_announce: Some(false),
read_username: Some(false),
};
self.set_server_config(server_id, config).await
}
#[tracing::instrument]
pub async fn set_default_user_config(&self, user_id: u64) -> redis::RedisResult<()> {
pub async fn set_default_user_config(&self, user_id: u64) -> Result<()> {
let voice_selection = VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
@ -126,7 +141,7 @@ impl Database {
let config = UserConfig {
tts_type: Some(TTSType::GCP),
gcp_tts_voice: Some(voice_selection),
voicevox_speaker: Some(1),
voicevox_speaker: Some(DEFAULT_VOICEVOX_SPEAKER),
};
self.set_user_config(user_id, config).await
@ -136,7 +151,7 @@ impl Database {
pub async fn get_server_config_or_default(
&self,
server_id: u64,
) -> redis::RedisResult<Option<ServerConfig>> {
) -> Result<Option<ServerConfig>> {
match self.get_server_config(server_id).await? {
Some(config) => Ok(Some(config)),
None => {
@ -147,10 +162,7 @@ impl Database {
}
#[tracing::instrument]
pub async fn get_user_config_or_default(
&self,
user_id: u64,
) -> redis::RedisResult<Option<UserConfig>> {
pub async fn get_user_config_or_default(&self, user_id: u64) -> Result<Option<UserConfig>> {
match self.get_user_config(user_id).await? {
Some(config) => Ok(Some(config)),
None => {
@ -161,77 +173,308 @@ impl Database {
}
/// Save TTS instance to database
#[tracing::instrument]
pub async fn save_tts_instance(
&self,
guild_id: GuildId,
instance: &TTSInstance,
) -> redis::RedisResult<()> {
pub async fn save_tts_instance(&self, guild_id: GuildId, instance: &TTSInstance) -> Result<()> {
let key = Self::tts_instance_key(guild_id.get());
let list_key = Self::tts_instances_list_key();
// Save the instance
let result = self.set_config(&key, instance);
self.set_config(&key, instance).await?;
// Add guild_id to the list of active instances
if result.is_ok() {
match self.client.get_connection() {
Ok(mut connection) => {
let _: redis::RedisResult<()> = connection.sadd(&list_key, guild_id.get());
}
Err(_) => {}
}
}
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
result
connection
.sadd::<_, _, ()>(&list_key, guild_id.get())
.await?;
Ok(())
}
/// Load TTS instance from database
#[tracing::instrument]
pub async fn load_tts_instance(
&self,
guild_id: GuildId,
) -> redis::RedisResult<Option<TTSInstance>> {
pub async fn load_tts_instance(&self, guild_id: GuildId) -> Result<Option<TTSInstance>> {
let key = Self::tts_instance_key(guild_id.get());
self.get_config(&key)
self.get_config(&key).await
}
/// Remove TTS instance from database
#[tracing::instrument]
pub async fn remove_tts_instance(&self, guild_id: GuildId) -> redis::RedisResult<()> {
pub async fn remove_tts_instance(&self, guild_id: GuildId) -> Result<()> {
let key = Self::tts_instance_key(guild_id.get());
let list_key = Self::tts_instances_list_key();
match self.client.get_connection() {
Ok(mut connection) => {
let _: redis::RedisResult<()> = connection.del(&key);
let _: redis::RedisResult<()> = connection.srem(&list_key, guild_id.get());
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
let _: std::result::Result<(), bb8_redis::redis::RedisError> =
connection.srem(&list_key, guild_id.get()).await;
Ok(())
}
Err(e) => Err(e),
}
}
/// Get all active TTS instances
#[tracing::instrument]
pub async fn get_all_tts_instances(&self) -> redis::RedisResult<Vec<(GuildId, TTSInstance)>> {
pub async fn get_all_tts_instances(&self) -> Result<Vec<(GuildId, TTSInstance)>> {
let list_key = Self::tts_instances_list_key();
match self.client.get_connection() {
Ok(mut connection) => {
let guild_ids: Vec<u64> = connection.smembers(&list_key).unwrap_or_default();
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let guild_ids: Vec<u64> = connection.smembers(&list_key).await.unwrap_or_default();
let mut instances = Vec::new();
for guild_id in guild_ids {
let guild_id = GuildId::new(guild_id);
if let Ok(Some(instance)) = self.load_tts_instance(guild_id).await {
instances.push((guild_id, instance));
} else {
tracing::warn!(guild_id = %guild_id, "Failed to load TTS instance");
}
}
Ok(instances)
}
Err(e) => Err(e),
// Additional user config methods
pub async fn save_user_config(
&self,
guild_id: GuildId,
user_id: UserId,
config: &UserConfig,
) -> Result<()> {
let key = Self::user_config_key(guild_id.get(), user_id.get());
self.set_config(&key, config).await
}
pub async fn load_user_config(
&self,
guild_id: GuildId,
user_id: UserId,
) -> Result<Option<UserConfig>> {
let key = Self::user_config_key(guild_id.get(), user_id.get());
self.get_config(&key).await
}
pub async fn delete_user_config(&self, guild_id: GuildId, user_id: UserId) -> Result<()> {
let key = Self::user_config_key(guild_id.get(), user_id.get());
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
Ok(())
}
// Additional server config methods
pub async fn save_server_config(&self, guild_id: GuildId, config: &ServerConfig) -> Result<()> {
let key = Self::server_config_key(guild_id.get());
self.set_config(&key, config).await
}
pub async fn load_server_config(&self, guild_id: GuildId) -> Result<Option<ServerConfig>> {
let key = Self::server_config_key(guild_id.get());
self.get_config(&key).await
}
pub async fn delete_server_config(&self, guild_id: GuildId) -> Result<()> {
let key = Self::server_config_key(guild_id.get());
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
Ok(())
}
// Dictionary methods
pub async fn save_dictionary(
&self,
guild_id: GuildId,
dictionary: &HashMap<String, String>,
) -> Result<()> {
let key = Self::dictionary_key(guild_id.get());
self.set_config(&key, dictionary).await
}
pub async fn load_dictionary(&self, guild_id: GuildId) -> Result<HashMap<String, String>> {
let key = Self::dictionary_key(guild_id.get());
let dict: Option<HashMap<String, String>> = self.get_config(&key).await?;
Ok(dict.unwrap_or_default())
}
pub async fn delete_dictionary(&self, guild_id: GuildId) -> Result<()> {
let key = Self::dictionary_key(guild_id.get());
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
Ok(())
}
pub async fn delete_tts_instance(&self, guild_id: GuildId) -> Result<()> {
self.remove_tts_instance(guild_id).await
}
pub async fn list_active_instances(&self) -> Result<Vec<u64>> {
let list_key = Self::tts_instances_list_key();
let mut connection = self
.pool
.get()
.await
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
let guild_ids: Vec<u64> = connection.smembers(&list_key).await.unwrap_or_default();
Ok(guild_ids)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::errors::constants;
use bb8_redis::redis::AsyncCommands;
use serial_test::serial;
// Helper function to create test database (requires Redis running)
async fn create_test_database() -> Result<Database> {
let manager = RedisConnectionManager::new("redis://127.0.0.1:6379/15")?; // Use test DB
let pool = bb8::Pool::builder()
.max_size(1)
.build(manager)
.await
.map_err(|e| NCBError::Database(format!("Pool creation failed: {}", e)))?;
Ok(Database { pool })
}
#[tokio::test]
#[serial]
async fn test_database_creation() {
// This test requires Redis to be running
match create_test_database().await {
Ok(_db) => {
// Test successful creation
assert!(true);
}
Err(_) => {
// Skip test if Redis is not available
return;
}
}
}
#[test]
fn test_key_generation() {
let guild_id = 123456789u64;
let user_id = 987654321u64;
// Test TTS instance key
let tts_key = Database::tts_instance_key(guild_id);
assert!(tts_key.contains(&guild_id.to_string()));
// Test TTS instances list key
let list_key = Database::tts_instances_list_key();
assert!(!list_key.is_empty());
// Test user config key
let user_key = Database::user_config_key(guild_id, user_id);
assert_eq!(user_key, "user:config:123456789:987654321");
// Test server config key
let server_key = Database::server_config_key(guild_id);
assert_eq!(server_key, "server:config:123456789");
// Test dictionary key
let dict_key = Database::dictionary_key(guild_id);
assert_eq!(dict_key, "dictionary:123456789");
}
#[tokio::test]
#[serial]
async fn test_tts_instance_operations() {
let db = match create_test_database().await {
Ok(db) => db,
Err(_) => return, // Skip if Redis not available
};
let guild_id = GuildId::new(12345);
let test_instance =
TTSInstance::new_single(ChannelId::new(123), ChannelId::new(456), guild_id);
// Clear any existing data
if let Ok(mut conn) = db.pool.get().await {
let _: () = conn
.del(Database::tts_instance_key(guild_id.get()))
.await
.unwrap_or_default();
let _: () = conn
.srem(Database::tts_instances_list_key(), guild_id.get())
.await
.unwrap_or_default();
} else {
return; // Skip if can't get connection
}
// Test saving TTS instance
let save_result = db.save_tts_instance(guild_id, &test_instance).await;
if save_result.is_err() {
// Skip test if Redis operations fail
return;
}
// Test loading TTS instance
let load_result = db.load_tts_instance(guild_id).await;
if load_result.is_err() {
return; // Skip if Redis operations fail
}
let loaded_instance = load_result.unwrap();
if let Some(instance) = loaded_instance {
assert_eq!(instance.guild, test_instance.guild);
assert_eq!(instance.text_channels, test_instance.text_channels);
assert_eq!(instance.voice_channel, test_instance.voice_channel);
}
// Test listing active instances
let list_result = db.list_active_instances().await;
if list_result.is_err() {
return; // Skip if Redis operations fail
}
let instances = list_result.unwrap();
assert!(instances.contains(&guild_id.get()));
// Test deleting TTS instance
let delete_result = db.delete_tts_instance(guild_id).await;
if delete_result.is_err() {
return; // Skip if Redis operations fail
}
// Verify deletion
let load_after_delete = db.load_tts_instance(guild_id).await;
if load_after_delete.is_err() {
return; // Skip if Redis operations fail
}
assert!(load_after_delete.unwrap().is_none());
}
#[test]
fn test_database_constants() {
// Test that constants are reasonable
assert!(constants::REDIS_CONNECTION_TIMEOUT_SECS > 0);
assert!(constants::REDIS_MAX_CONNECTIONS > 0);
assert!(constants::REDIS_MIN_IDLE_CONNECTIONS <= constants::REDIS_MAX_CONNECTIONS);
}
}

View File

@ -10,6 +10,7 @@ pub struct DictionaryOnlyServerConfig {
pub struct ServerConfig {
pub dictionary: Dictionary,
pub autostart_channel_id: Option<u64>,
pub autostart_text_channel_id: Option<u64>,
pub voice_state_announce: Option<bool>,
pub read_username: Option<bool>,
}

521
src/errors.rs Normal file
View File

@ -0,0 +1,521 @@
/// Custom error types for the NCB-TTS application
#[derive(Debug, thiserror::Error)]
pub enum NCBError {
#[error("Configuration error: {0}")]
Config(String),
#[error("Database error: {0}")]
Database(String),
#[error("VOICEVOX API error: {0}")]
VOICEVOX(String),
#[error("Discord error: {0}")]
Discord(#[from] serenity::Error),
#[error("TTS synthesis error: {0}")]
TTSSynthesis(String),
#[error("GCP authentication error: {0}")]
GCPAuth(#[from] gcp_auth::Error),
#[error("HTTP request error: {0}")]
Http(#[from] reqwest::Error),
#[error("JSON parsing error: {0}")]
Json(#[from] serde_json::Error),
#[error("Redis connection error: {0}")]
Redis(String),
#[error("Redis error: {0}")]
RedisError(#[from] bb8_redis::redis::RedisError),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Voice connection error: {0}")]
VoiceConnection(String),
#[error("Invalid input: {0}")]
InvalidInput(String),
#[error("Invalid regex pattern: {0}")]
InvalidRegex(String),
#[error("Songbird error: {0}")]
Songbird(String),
#[error("User not in voice channel")]
UserNotInVoiceChannel,
#[error("Guild not found")]
GuildNotFound,
#[error("Channel not found")]
ChannelNotFound,
#[error("TTS instance not found for guild {guild_id}")]
TTSInstanceNotFound { guild_id: u64 },
#[error("Text too long (max {max_length} characters)")]
TextTooLong { max_length: usize },
#[error("Text contains prohibited content")]
ProhibitedContent,
#[error("Rate limit exceeded")]
RateLimitExceeded,
#[error("TOML parsing error: {0}")]
Toml(#[from] toml::de::Error),
}
impl NCBError {
pub fn config(message: impl Into<String>) -> Self {
Self::Config(message.into())
}
pub fn database(message: impl Into<String>) -> Self {
Self::Database(message.into())
}
pub fn voicevox(message: impl Into<String>) -> Self {
Self::VOICEVOX(message.into())
}
pub fn voice_connection(message: impl Into<String>) -> Self {
Self::VoiceConnection(message.into())
}
pub fn tts_synthesis(message: impl Into<String>) -> Self {
Self::TTSSynthesis(message.into())
}
pub fn invalid_input(message: impl Into<String>) -> Self {
Self::InvalidInput(message.into())
}
pub fn invalid_regex(message: impl Into<String>) -> Self {
Self::InvalidRegex(message.into())
}
pub fn songbird(message: impl Into<String>) -> Self {
Self::Songbird(message.into())
}
pub fn tts_instance_not_found(guild_id: u64) -> Self {
Self::TTSInstanceNotFound { guild_id }
}
pub fn text_too_long(max_length: usize) -> Self {
Self::TextTooLong { max_length }
}
pub fn redis(message: impl Into<String>) -> Self {
Self::Redis(message.into())
}
pub fn missing_env_var(var_name: &str) -> Self {
Self::Config(format!("Missing environment variable: {}", var_name))
}
}
/// Result type alias for convenience
pub type Result<T> = std::result::Result<T, NCBError>;
/// Input validation functions
pub mod validation {
use super::*;
use regex::Regex;
/// Validate regex pattern for potential ReDoS attacks
pub fn validate_regex_pattern(pattern: &str) -> Result<()> {
// Check for common ReDoS patterns (catastrophic backtracking)
let redos_patterns = [
r"\(\?\:", // Non-capturing groups in dangerous positions
r"\(\?\=", // Positive lookahead
r"\(\?\!", // Negative lookahead
r"\(\?\<\=", // Positive lookbehind
r"\(\?\<\!", // Negative lookbehind
r"\*\*", // Actual nested quantifiers (not possessive)
r"\+\*", // Nested quantifiers
r"\*\+", // Nested quantifiers
];
for redos_pattern in &redos_patterns {
if pattern.contains(redos_pattern) {
return Err(NCBError::invalid_regex(format!(
"Pattern contains potentially dangerous construct: {}",
redos_pattern
)));
}
}
// Check pattern length
if pattern.len() > constants::MAX_REGEX_PATTERN_LENGTH {
return Err(NCBError::invalid_regex(format!(
"Pattern too long (max {} characters)",
constants::MAX_REGEX_PATTERN_LENGTH
)));
}
// Try to compile the regex to validate syntax
Regex::new(pattern)
.map_err(|e| NCBError::invalid_regex(format!("Invalid regex syntax: {}", e)))?;
Ok(())
}
/// Validate rule name
pub fn validate_rule_name(name: &str) -> Result<()> {
if name.trim().is_empty() {
return Err(NCBError::invalid_input("Rule name cannot be empty"));
}
if name.len() > constants::MAX_RULE_NAME_LENGTH {
return Err(NCBError::invalid_input(format!(
"Rule name too long (max {} characters)",
constants::MAX_RULE_NAME_LENGTH
)));
}
// Check for invalid characters
if !name
.chars()
.all(|c| c.is_alphanumeric() || c.is_whitespace() || "_-".contains(c))
{
return Err(NCBError::invalid_input(
"Rule name contains invalid characters (only alphanumeric, spaces, hyphens, and underscores allowed)"
));
}
Ok(())
}
/// Validate TTS text input
pub fn validate_tts_text(text: &str) -> Result<()> {
if text.trim().is_empty() {
return Err(NCBError::invalid_input("Text cannot be empty"));
}
if text.len() > constants::MAX_TTS_TEXT_LENGTH {
return Err(NCBError::text_too_long(constants::MAX_TTS_TEXT_LENGTH));
}
// Check for prohibited patterns
let prohibited_patterns = [
r"<script", // Script injection
r"javascript:", // JavaScript URLs
r"data:", // Data URLs
r"<?xml", // XML processing instructions
];
let text_lower = text.to_lowercase();
for pattern in &prohibited_patterns {
if text_lower.contains(pattern) {
return Err(NCBError::ProhibitedContent);
}
}
Ok(())
}
/// Validate replacement text for dictionary rules
pub fn validate_replacement_text(text: &str) -> Result<()> {
if text.trim().is_empty() {
return Err(NCBError::invalid_input("Replacement text cannot be empty"));
}
if text.len() > constants::MAX_TTS_TEXT_LENGTH {
return Err(NCBError::text_too_long(constants::MAX_TTS_TEXT_LENGTH));
}
Ok(())
}
/// Sanitize SSML input to prevent injection attacks
pub fn sanitize_ssml(text: &str) -> String {
// Remove or escape potentially dangerous SSML tags
let _dangerous_tags = [
"audio", "break", "emphasis", "lang", "mark", "p", "phoneme", "prosody", "say-as",
"speak", "sub", "voice", "w",
];
let mut sanitized = text.to_string();
// Remove script-like content
sanitized = sanitized.replace("<script", "&lt;script");
sanitized = sanitized.replace("javascript:", "");
sanitized = sanitized.replace("data:", "");
// Limit the overall length
if sanitized.len() > constants::MAX_SSML_LENGTH {
sanitized.truncate(constants::MAX_SSML_LENGTH);
}
sanitized
}
}
/// Constants used throughout the application
pub mod constants {
// Configuration constants
pub const DEFAULT_CONFIG_PATH: &str = "config.toml";
pub const DEFAULT_DICTIONARY_PATH: &str = "dictionary.txt";
// Redis constants
pub const REDIS_CONNECTION_TIMEOUT_SECS: u64 = 5;
pub const REDIS_MAX_CONNECTIONS: u32 = 10;
pub const REDIS_MIN_IDLE_CONNECTIONS: u32 = 1;
// Cache constants
pub const DEFAULT_CACHE_SIZE: usize = 1000;
pub const CACHE_TTL_SECS: u64 = 86400; // 24 hours
// TTS constants
pub const MAX_TTS_TEXT_LENGTH: usize = 500;
pub const MAX_SSML_LENGTH: usize = 1000;
pub const TTS_TIMEOUT_SECS: u64 = 30;
pub const DEFAULT_SPEAKING_RATE: f32 = 1.2;
pub const DEFAULT_PITCH: f32 = 0.0;
// Validation constants
pub const MAX_REGEX_PATTERN_LENGTH: usize = 100;
pub const MAX_RULE_NAME_LENGTH: usize = 50;
pub const MAX_USERNAME_LENGTH: usize = 32;
// Circuit breaker constants
pub const CIRCUIT_BREAKER_FAILURE_THRESHOLD: u32 = 5;
pub const CIRCUIT_BREAKER_TIMEOUT_SECS: u64 = 60;
// Retry constants
pub const DEFAULT_MAX_RETRY_ATTEMPTS: u32 = 3;
pub const DEFAULT_RETRY_DELAY_MS: u64 = 500;
pub const MAX_RETRY_DELAY_MS: u64 = 5000;
// Connection monitoring constants
pub const CONNECTION_CHECK_INTERVAL_SECS: u64 = 5;
pub const MAX_RECONNECTION_ATTEMPTS: u32 = 3;
pub const RECONNECTION_BACKOFF_SECS: u64 = 2;
// Voice connection constants
pub const VOICE_CONNECTION_TIMEOUT_SECS: u64 = 10;
pub const AUDIO_BITRATE_KBPS: u32 = 128;
pub const AUDIO_SAMPLE_RATE: u32 = 48000;
// Database key prefixes
pub const DISCORD_SERVER_PREFIX: &str = "discord:server:";
pub const DISCORD_USER_PREFIX: &str = "discord:user:";
pub const TTS_INSTANCE_PREFIX: &str = "tts:instance:";
pub const TTS_INSTANCES_LIST_KEY: &str = "tts:instances";
// Default values
pub const DEFAULT_VOICEVOX_SPEAKER: i64 = 1;
// Message constants
pub const RULE_ADDED: &str = "RULE_ADDED";
pub const RULE_REMOVED: &str = "RULE_REMOVED";
pub const RULE_ALREADY_EXISTS: &str = "RULE_ALREADY_EXISTS";
pub const RULE_NOT_FOUND: &str = "RULE_NOT_FOUND";
pub const DICTIONARY_RULE_APPLIED: &str = "DICTIONARY_RULE_APPLIED";
pub const GUILD_NOT_FOUND: &str = "GUILD_NOT_FOUND";
pub const CHANNEL_JOIN_SUCCESS: &str = "CHANNEL_JOIN_SUCCESS";
pub const CHANNEL_LEAVE_SUCCESS: &str = "CHANNEL_LEAVE_SUCCESS";
pub const AUTOSTART_CHANNEL_SET: &str = "AUTOSTART_CHANNEL_SET";
pub const SET_AUTOSTART_CHANNEL_CLEAR: &str = "SET_AUTOSTART_CHANNEL_CLEAR";
pub const SET_AUTOSTART_TEXT_CHANNEL: &str = "SET_AUTOSTART_TEXT_CHANNEL";
pub const SET_AUTOSTART_TEXT_CHANNEL_CLEAR: &str = "SET_AUTOSTART_TEXT_CHANNEL_CLEAR";
// TTS configuration constants
pub const TTS_CONFIG_SERVER_ADD_DICTIONARY: &str = "TTS_CONFIG_SERVER_ADD_DICTIONARY";
pub const TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE: &str =
"TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE";
pub const TTS_CONFIG_SERVER_SET_READ_USERNAME: &str = "TTS_CONFIG_SERVER_SET_READ_USERNAME";
pub const TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU: &str =
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU";
pub const TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON: &str =
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON";
pub const TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON: &str =
"TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON";
pub const TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON: &str =
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON";
pub const SET_AUTOSTART_CHANNEL: &str = "SET_AUTOSTART_CHANNEL";
pub const TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL: &str =
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL";
pub const TTS_CONFIG_SERVER_BACK: &str = "TTS_CONFIG_SERVER_BACK";
pub const TTS_CONFIG_SERVER: &str = "TTS_CONFIG_SERVER";
pub const TTS_CONFIG_SERVER_DICTIONARY: &str = "TTS_CONFIG_SERVER_DICTIONARY";
// TTS engine selection messages
pub const TTS_CONFIG_ENGINE_SELECTED_GOOGLE: &str = "TTS_CONFIG_ENGINE_SELECTED_GOOGLE";
pub const TTS_CONFIG_ENGINE_SELECTED_VOICEVOX: &str = "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX";
// Error messages
pub const USER_NOT_IN_VOICE_CHANNEL: &str = "USER_NOT_IN_VOICE_CHANNEL";
pub const CHANNEL_NOT_FOUND: &str = "CHANNEL_NOT_FOUND";
// Rate limiting constants
pub const RATE_LIMIT_REQUESTS_PER_MINUTE: u32 = 60;
pub const RATE_LIMIT_REQUESTS_PER_HOUR: u32 = 1000;
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ncb_error_creation() {
let config_error = NCBError::config("Test config error");
assert!(matches!(config_error, NCBError::Config(_)));
assert_eq!(
config_error.to_string(),
"Configuration error: Test config error"
);
let database_error = NCBError::database("Test database error");
assert!(matches!(database_error, NCBError::Database(_)));
assert_eq!(
database_error.to_string(),
"Database error: Test database error"
);
let voicevox_error = NCBError::voicevox("Test VOICEVOX error");
assert!(matches!(voicevox_error, NCBError::VOICEVOX(_)));
assert_eq!(
voicevox_error.to_string(),
"VOICEVOX API error: Test VOICEVOX error"
);
}
#[test]
fn test_tts_instance_not_found_error() {
let guild_id = 12345u64;
let error = NCBError::tts_instance_not_found(guild_id);
assert!(matches!(
error,
NCBError::TTSInstanceNotFound { guild_id: 12345 }
));
assert_eq!(error.to_string(), "TTS instance not found for guild 12345");
}
#[test]
fn test_text_too_long_error() {
let max_length = 500;
let error = NCBError::text_too_long(max_length);
assert!(matches!(error, NCBError::TextTooLong { max_length: 500 }));
assert_eq!(error.to_string(), "Text too long (max 500 characters)");
}
mod validation_tests {
use super::super::constants;
use super::super::validation::*;
#[test]
fn test_validate_regex_pattern_valid() {
assert!(validate_regex_pattern(r"[a-zA-Z]+").is_ok());
assert!(validate_regex_pattern(r"\d{1,3}").is_ok());
assert!(validate_regex_pattern(r"hello|world").is_ok());
}
#[test]
fn test_validate_regex_pattern_redos() {
// Test that the validation function properly checks patterns
// Most problematic patterns are caught by regex compilation errors
// This test focuses on basic pattern safety checks
// Test length validation works
let very_long_pattern = "a".repeat(constants::MAX_REGEX_PATTERN_LENGTH + 1);
assert!(validate_regex_pattern(&very_long_pattern).is_err());
// Test basic pattern validation passes for safe patterns
assert!(validate_regex_pattern(r"[a-z]+").is_ok());
assert!(validate_regex_pattern(r"\d{1,3}").is_ok());
}
#[test]
fn test_validate_regex_pattern_too_long() {
let long_pattern = "a".repeat(constants::MAX_REGEX_PATTERN_LENGTH + 1);
assert!(validate_regex_pattern(&long_pattern).is_err());
}
#[test]
fn test_validate_regex_pattern_invalid_syntax() {
assert!(validate_regex_pattern(r"[").is_err());
assert!(validate_regex_pattern(r"*").is_err());
assert!(validate_regex_pattern(r"(?P<>)").is_err());
}
#[test]
fn test_validate_rule_name_valid() {
assert!(validate_rule_name("test_rule").is_ok());
assert!(validate_rule_name("Test Rule 123").is_ok());
assert!(validate_rule_name("rule-name").is_ok());
}
#[test]
fn test_validate_rule_name_empty() {
assert!(validate_rule_name("").is_err());
assert!(validate_rule_name(" ").is_err());
}
#[test]
fn test_validate_rule_name_too_long() {
let long_name = "a".repeat(constants::MAX_RULE_NAME_LENGTH + 1);
assert!(validate_rule_name(&long_name).is_err());
}
#[test]
fn test_validate_rule_name_invalid_chars() {
assert!(validate_rule_name("rule@name").is_err());
assert!(validate_rule_name("rule#name").is_err());
assert!(validate_rule_name("rule$name").is_err());
}
#[test]
fn test_validate_tts_text_valid() {
assert!(validate_tts_text("Hello world").is_ok());
assert!(validate_tts_text("こんにちは").is_ok());
assert!(validate_tts_text("Test with numbers 123").is_ok());
}
#[test]
fn test_validate_tts_text_empty() {
assert!(validate_tts_text("").is_err());
assert!(validate_tts_text(" ").is_err());
}
#[test]
fn test_validate_tts_text_too_long() {
let long_text = "a".repeat(constants::MAX_TTS_TEXT_LENGTH + 1);
assert!(validate_tts_text(&long_text).is_err());
}
#[test]
fn test_validate_tts_text_prohibited_content() {
assert!(validate_tts_text("<script>alert('xss')</script>").is_err());
assert!(validate_tts_text("javascript:alert('xss')").is_err());
assert!(validate_tts_text("data:text/html,<h1>XSS</h1>").is_err());
assert!(validate_tts_text("<?xml version=\"1.0\"?>").is_err());
}
#[test]
fn test_sanitize_ssml() {
let input = "<script>alert('xss')</script>Hello world";
let output = sanitize_ssml(input);
assert!(!output.contains("<script"));
assert!(output.contains("&lt;script"));
assert!(output.contains("Hello world"));
let input_with_js = "javascript:alert('test')Hello";
let output = sanitize_ssml(input_with_js);
assert!(!output.contains("javascript:"));
assert!(output.contains("Hello"));
let long_input = "a".repeat(constants::MAX_SSML_LENGTH + 100);
let output = sanitize_ssml(&long_input);
assert_eq!(output.len(), constants::MAX_SSML_LENGTH);
}
}
}

View File

@ -4,6 +4,7 @@ use crate::{
},
data::DatabaseClientData,
database::dictionary::Rule,
errors::{constants::*, validation},
events,
tts::tts_type::TTSType,
};
@ -49,28 +50,84 @@ impl EventHandler for Handler {
}
}
if let Interaction::Modal(modal) = interaction.clone() {
if modal.data.custom_id != "TTS_CONFIG_SERVER_ADD_DICTIONARY" {
if modal.data.custom_id != TTS_CONFIG_SERVER_ADD_DICTIONARY {
return;
}
let rows = modal.data.components.clone();
// Extract rule name with proper error handling
let rule_name =
if let ActionRowComponent::InputText(text) = rows[0].components[0].clone() {
text.value.unwrap()
match rows
.get(0)
.and_then(|row| row.components.get(0))
.and_then(|component| {
if let ActionRowComponent::InputText(text) = component {
text.value.as_ref()
} else {
panic!("Cannot get rule name");
None
}
}) {
Some(name) => {
if let Err(e) = validation::validate_rule_name(name) {
tracing::error!("Invalid rule name: {}", e);
return;
}
name.clone()
}
None => {
tracing::error!("Cannot extract rule name from modal");
return;
}
};
let from = if let ActionRowComponent::InputText(text) = rows[1].components[0].clone() {
text.value.unwrap()
// Extract 'from' field with validation
let from =
match rows
.get(1)
.and_then(|row| row.components.get(0))
.and_then(|component| {
if let ActionRowComponent::InputText(text) = component {
text.value.as_ref()
} else {
panic!("Cannot get from");
None
}
}) {
Some(pattern) => {
if let Err(e) = validation::validate_regex_pattern(pattern) {
tracing::error!("Invalid regex pattern: {}", e);
return;
}
pattern.clone()
}
None => {
tracing::error!("Cannot extract regex pattern from modal");
return;
}
};
let to = if let ActionRowComponent::InputText(text) = rows[2].components[0].clone() {
text.value.unwrap()
// Extract 'to' field with validation
let to = match rows
.get(2)
.and_then(|row| row.components.get(0))
.and_then(|component| {
if let ActionRowComponent::InputText(text) = component {
text.value.as_ref()
} else {
panic!("Cannot get to");
None
}
}) {
Some(replacement) => {
if let Err(e) = validation::validate_replacement_text(replacement) {
tracing::error!("Invalid replacement text: {}", e);
return;
}
replacement.clone()
}
None => {
tracing::error!("Cannot extract replacement text from modal");
return;
}
};
let rule = Rule {
@ -83,29 +140,47 @@ impl EventHandler for Handler {
let data_read = ctx.data.read().await;
let mut config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let database = match data_read.get::<DatabaseClientData>() {
Some(db) => db.clone(),
None => {
tracing::error!("Cannot get DatabaseClientData");
return;
}
};
database
match database
.get_server_config_or_default(modal.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
{
Ok(Some(config)) => config,
Ok(None) => {
tracing::error!("No server config found");
return;
}
Err(e) => {
tracing::error!("Database error: {}", e);
return;
}
}
};
config.dictionary.rules.push(rule);
{
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let database = match data_read.get::<DatabaseClientData>() {
Some(db) => db.clone(),
None => {
tracing::error!("Cannot get DatabaseClientData");
return;
}
};
database
if let Err(e) = database
.set_server_config(modal.guild_id.unwrap().get(), config)
.await
.unwrap();
{
tracing::error!("Failed to save server config: {}", e);
return;
}
modal
.create_response(
&ctx.http,
@ -122,7 +197,7 @@ impl EventHandler for Handler {
}
if let Some(message_component) = interaction.message_component() {
match &*message_component.data.custom_id {
"TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE" => {
id if id == TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE => {
let data_read = ctx.data.read().await;
let mut config = {
let database = data_read
@ -166,7 +241,7 @@ impl EventHandler for Handler {
.await
.unwrap();
}
"TTS_CONFIG_SERVER_SET_READ_USERNAME" => {
id if id == TTS_CONFIG_SERVER_SET_READ_USERNAME => {
let data_read = ctx.data.read().await;
let mut config = {
let database = data_read
@ -209,7 +284,7 @@ impl EventHandler for Handler {
.await
.unwrap();
}
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU" => {
id if id == TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU => {
let i = usize::from_str_radix(
&match message_component.data.kind {
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
@ -259,7 +334,7 @@ impl EventHandler for Handler {
.await
.unwrap();
}
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON" => {
id if id == TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON => {
let data_read = ctx.data.read().await;
let config = {
@ -313,7 +388,7 @@ impl EventHandler for Handler {
.await
.unwrap();
}
"TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON" => {
id if id == TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON => {
let config = {
let data_read = ctx.data.read().await;
let database = data_read
@ -351,7 +426,7 @@ impl EventHandler for Handler {
.await
.unwrap();
}
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON" => {
id if id == TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON => {
message_component
.create_response(
&ctx.http,
@ -390,7 +465,7 @@ impl EventHandler for Handler {
.await
.unwrap();
}
"SET_AUTOSTART_CHANNEL" => {
id if id == SET_AUTOSTART_CHANNEL => {
let autostart_channel_id = match message_component.data.kind {
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
if values.len() == 0 {
@ -438,14 +513,70 @@ impl EventHandler for Handler {
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content(response_content),
CreateInteractionResponseMessage::new().content(response_content),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL" => {
id if id == SET_AUTOSTART_TEXT_CHANNEL => {
let autostart_text_channel_id = match message_component.data.kind {
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
if values.len() == 0 {
None
} else if values[0] == "SET_AUTOSTART_TEXT_CHANNEL_CLEAR" {
None
} else {
Some(
u64::from_str_radix(
&values[0]
.strip_prefix("SET_AUTOSTART_TEXT_CHANNEL_")
.unwrap(),
10,
)
.unwrap(),
)
}
}
_ => panic!("Cannot get index"),
};
{
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut config = database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap();
config.autostart_text_channel_id = autostart_text_channel_id;
database
.set_server_config(message_component.guild_id.unwrap().get(), config)
.await
.unwrap();
}
let response_content = if autostart_text_channel_id.is_some() {
"自動参加テキストチャンネルを設定しました。"
} else {
"自動参加テキストチャンネルを解除しました。"
};
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new().content(response_content),
),
)
.await
.unwrap();
}
id if id == TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL => {
let config = {
let data_read = ctx.data.read().await;
let database = data_read
@ -472,15 +603,13 @@ impl EventHandler for Handler {
let mut options = Vec::new();
// 解除オプションを追加
let clear_option = CreateSelectMenuOption::new(
"解除",
"SET_AUTOSTART_CHANNEL_CLEAR",
)
let clear_option =
CreateSelectMenuOption::new("解除", "SET_AUTOSTART_CHANNEL_CLEAR")
.description("自動参加チャンネルを解除します")
.default_selection(autostart_channel_id == 0);
options.push(clear_option);
for (id, channel) in channels {
for (id, channel) in channels.clone() {
if channel.kind != ChannelType::Voice {
continue;
}
@ -498,6 +627,33 @@ impl EventHandler for Handler {
options.push(option);
}
let mut text_channel_options = Vec::new();
let clear_option =
CreateSelectMenuOption::new("解除", "SET_AUTOSTART_TEXT_CHANNEL_CLEAR")
.description("自動参加テキストチャンネルを解除します")
.default_selection(config.autostart_text_channel_id.is_none());
text_channel_options.push(clear_option);
for (id, channel) in channels {
if channel.kind != ChannelType::Text {
continue;
}
let description = channel
.topic
.unwrap_or_else(|| String::from("No topic provided."));
let option = CreateSelectMenuOption::new(
&channel.name,
format!("SET_AUTOSTART_TEXT_CHANNEL_{}", id.get()),
)
.description(description)
.default_selection(
channel.id.get() == config.autostart_text_channel_id.unwrap_or(0),
);
text_channel_options.push(option);
}
message_component
.create_response(
&ctx.http,
@ -513,6 +669,16 @@ impl EventHandler for Handler {
.min_values(0)
.max_values(1),
),
CreateActionRow::SelectMenu(
CreateSelectMenu::new(
"SET_AUTOSTART_TEXT_CHANNEL",
CreateSelectMenuKind::String {
options: text_channel_options,
},
)
.min_values(0)
.max_values(1),
),
CreateActionRow::Buttons(vec![CreateButton::new(
"TTS_CONFIG_SERVER_BACK",
)
@ -524,7 +690,7 @@ impl EventHandler for Handler {
.await
.unwrap();
}
"TTS_CONFIG_SERVER_BACK" => {
id if id == TTS_CONFIG_SERVER_BACK => {
message_component
.create_response(
&ctx.http,
@ -554,7 +720,7 @@ impl EventHandler for Handler {
.await
.unwrap();
}
"TTS_CONFIG_SERVER" => {
id if id == TTS_CONFIG_SERVER => {
message_component
.create_response(
&ctx.http,
@ -584,7 +750,7 @@ impl EventHandler for Handler {
.await
.unwrap();
}
"TTS_CONFIG_SERVER_DICTIONARY" => {
id if id == TTS_CONFIG_SERVER_DICTIONARY => {
message_component
.create_response(
&ctx.http,

View File

@ -31,7 +31,7 @@ pub async fn message(ctx: Context, message: Message) {
let instance = storage.get_mut(&guild_id).unwrap();
if instance.text_channel != message.channel_id {
if !instance.contains_text_channel(message.channel_id) {
return;
}

View File

@ -48,10 +48,6 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
.unwrap()
};
if !config.voice_state_announce.unwrap_or(true) {
return;
}
{
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild_id) {
@ -62,7 +58,14 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
.expect("Cannot get songbird client.")
.clone();
let instance = TTSInstance::new(new_channel, new_channel, guild_id);
let text_channel_ids =
if let Some(text_channel_id) = config.autostart_text_channel_id {
vec![text_channel_id.into(), new_channel]
} else {
vec![new_channel]
};
let instance = TTSInstance::new(text_channel_ids, new_channel, guild_id);
storage.insert(guild_id, instance.clone());
// Save to database
@ -82,7 +85,14 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
let voicevox_speakers = tts_client
.voicevox_client
.get_speakers()
.await
.unwrap_or_else(|e| {
tracing::error!("Failed to get VOICEVOX speakers: {}", e);
vec!["VOICEVOX API unavailable".to_string()]
});
new_channel
.send_message(
@ -110,6 +120,7 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
let voice_move_state = new.move_state(&old, instance.voice_channel);
if config.voice_state_announce.unwrap_or(false) {
let message: Option<String> = match voice_move_state {
VoiceMoveState::JOIN => Some(format!(
"{} さんが通話に参加しました",
@ -125,6 +136,7 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
if let Some(message) = message {
instance.read(AnnounceMessage { message }, &ctx).await;
}
}
if voice_move_state == VoiceMoveState::LEAVE {
let mut del_flag = false;
@ -138,12 +150,15 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
}
if del_flag {
let _ = storage
.get(&guild_id)
.unwrap()
.text_channel
.edit_thread(&ctx.http, EditThread::new().archived(true))
// Archive thread if it exists
if let Some(&channel_id) = storage.get(&guild_id).unwrap().text_channels.first() {
let http = ctx.http.clone();
tokio::spawn(async move {
let _ = channel_id
.edit_thread(&http, EditThread::new().archived(true))
.await;
});
}
storage.remove(&guild_id);
// Remove from database

View File

@ -1,10 +1,11 @@
use async_trait::async_trait;
use regex::Regex;
use serenity::{model::prelude::Message, prelude::Context};
use songbird::tracks::Track;
use tracing::{error, warn};
use crate::{
data::{DatabaseClientData, TTSClientData},
errors::{constants::*, validation, NCBError},
implement::member_name::ReadName,
tts::{
gcp_tts::structs::{
@ -15,6 +16,7 @@ use crate::{
message::TTSMessage,
tts_type::TTSType,
},
utils::{get_cached_regex, retry_with_backoff},
};
#[async_trait]
@ -25,19 +27,49 @@ impl TTSMessage for Message {
let config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(instance.guild.get())
.await
.unwrap()
.unwrap()
.ok_or_else(|| NCBError::config("Cannot get DatabaseClientData"))
.map_err(|e| {
error!(error = %e, "Failed to get database client");
e
})
.unwrap(); // This is safe as we're in a critical path
match database.get_server_config_or_default(instance.guild.get()).await {
Ok(Some(config)) => config,
Ok(None) => {
error!(guild_id = %instance.guild, "No server config available");
return self.content.clone(); // Fallback to original text
},
Err(e) => {
error!(guild_id = %instance.guild, error = %e, "Failed to get server config");
return self.content.clone(); // Fallback to original text
}
}
};
let mut text = self.content.clone();
// Validate text length before processing
if let Err(e) = validation::validate_tts_text(&text) {
warn!(error = %e, "Invalid TTS text, using truncated version");
text.truncate(crate::errors::constants::MAX_TTS_TEXT_LENGTH);
}
for rule in config.dictionary.rules {
if rule.is_regex {
let regex = Regex::new(&rule.rule).unwrap();
text = regex.replace_all(&text, rule.to).to_string();
match get_cached_regex(&rule.rule) {
Ok(regex) => {
text = regex.replace_all(&text, &rule.to).to_string();
}
Err(e) => {
warn!(
rule_id = rule.id,
pattern = rule.rule,
error = %e,
"Skipping invalid regex rule"
);
continue;
}
}
} else {
text = text.replace(&rule.rule, &rule.to);
}
@ -46,17 +78,7 @@ impl TTSMessage for Message {
if before_message.author.id == self.author.id {
text.clone()
} else {
let member = self.member.clone();
let name = if let Some(_) = member {
let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone();
guild
.member(&ctx.http, self.author.id)
.await
.unwrap()
.read_name()
} else {
self.author.read_name()
};
let name = get_user_name(self, ctx).await;
if config.read_username.unwrap_or(true) {
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
} else {
@ -64,17 +86,7 @@ impl TTSMessage for Message {
}
}
} else {
let member = self.member.clone();
let name = if let Some(_) = member {
let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone();
guild
.member(&ctx.http, self.author.id)
.await
.unwrap()
.read_name()
} else {
self.author.read_name()
};
let name = get_user_name(self, ctx).await;
if config.read_username.unwrap_or(true) {
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
@ -104,45 +116,111 @@ impl TTSMessage for Message {
let config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_user_config_or_default(self.author.id.get())
.await
.unwrap()
.unwrap()
.ok_or_else(|| NCBError::config("Cannot get DatabaseClientData"))
.unwrap();
match database.get_user_config_or_default(self.author.id.get()).await {
Ok(Some(config)) => config,
Ok(None) | Err(_) => {
error!(user_id = %self.author.id, "Failed to get user config, using defaults");
// Return default config
crate::database::user_config::UserConfig {
tts_type: Some(TTSType::GCP),
gcp_tts_voice: Some(crate::tts::gcp_tts::structs::voice_selection_params::VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral"),
}),
voicevox_speaker: Some(crate::errors::constants::DEFAULT_VOICEVOX_SPEAKER),
}
}
}
};
let tts = data_read
.get::<TTSClientData>()
.expect("Cannot get GCP TTSClientStorage");
.ok_or_else(|| NCBError::config("Cannot get TTSClientData"))
.unwrap();
match config.tts_type.unwrap_or(TTSType::GCP) {
TTSType::GCP => vec![tts
.synthesize_gcp(SynthesizeRequest {
// Synthesize with retry logic
let synthesis_result = match config.tts_type.unwrap_or(TTSType::GCP) {
TTSType::GCP => {
let sanitized_text = validation::sanitize_ssml(&text);
retry_with_backoff(
|| {
tts.synthesize_gcp(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(format!("<speak>{}</speak>", text)),
ssml: Some(format!("<speak>{}</speak>", sanitized_text)),
},
voice: config.gcp_tts_voice.unwrap(),
voice: config.gcp_tts_voice.clone().unwrap_or_else(|| {
crate::tts::gcp_tts::structs::voice_selection_params::VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral"),
}
}),
audioConfig: AudioConfig {
audioEncoding: String::from("mp3"),
speakingRate: 1.2f32,
pitch: 1.0f32,
speakingRate: DEFAULT_SPEAKING_RATE,
pitch: DEFAULT_PITCH,
},
})
.await
.unwrap()
.into()],
TTSType::VOICEVOX => vec![tts
.synthesize_voicevox(
&text.replace("<break time=\"200ms\"/>", ""),
config.voicevox_speaker.unwrap_or(1),
},
3, // max attempts
std::time::Duration::from_millis(500),
).await
}
TTSType::VOICEVOX => {
let processed_text = text.replace("<break time=\"200ms\"/>", "");
retry_with_backoff(
|| {
tts.synthesize_voicevox(
&processed_text,
config.voicevox_speaker.unwrap_or(crate::errors::constants::DEFAULT_VOICEVOX_SPEAKER),
)
.await
.unwrap()
.into()],
},
3, // max attempts
std::time::Duration::from_millis(500),
).await
}
};
match synthesis_result {
Ok(track) => vec![track],
Err(e) => {
error!(error = %e, "TTS synthesis failed");
vec![] // Return empty vector on failure
}
}
}
}
/// Helper function to get user name with proper error handling
async fn get_user_name(message: &Message, ctx: &Context) -> String {
let member = message.member.clone();
if let Some(_) = member {
if let Some(guild_id) = message.guild_id {
match guild_id.member(&ctx.http, message.author.id).await {
Ok(member) => member.read_name(),
Err(e) => {
warn!(
user_id = %message.author.id,
guild_id = ?message.guild_id,
error = %e,
"Failed to get guild member, using fallback name"
);
message.author.read_name()
}
}
} else {
warn!(
guild_id = ?message.guild_id,
"Guild not found in cache, using author name"
);
message.author.read_name()
}
} else {
message.author.read_name()
}
}

20
src/lib.rs Normal file
View File

@ -0,0 +1,20 @@
// Public API for the NCB-TTS-R2 library
pub mod errors;
pub mod utils;
pub mod tts;
pub mod database;
pub mod config;
pub mod data;
pub mod implement;
pub mod events;
pub mod commands;
pub mod stream_input;
pub mod trace;
pub mod event_handler;
pub mod connection_monitor;
// Re-export commonly used types
pub use errors::{NCBError, Result};
pub use utils::{CircuitBreaker, CircuitBreakerState, retry_with_backoff, get_cached_regex, PerformanceMetrics};
pub use tts::tts_type::TTSType;

View File

@ -3,18 +3,21 @@ mod config;
mod connection_monitor;
mod data;
mod database;
mod errors;
mod event_handler;
mod events;
mod implement;
mod stream_input;
mod trace;
mod tts;
mod utils;
use std::{collections::HashMap, env, sync::Arc};
use config::Config;
use data::{DatabaseClientData, TTSClientData, TTSData};
use database::database::Database;
use errors::{NCBError, Result};
use event_handler::Handler;
#[allow(deprecated)]
use serenity::{
@ -38,74 +41,44 @@ use songbird::SerenityInit;
/// client.start().await;
/// ```
#[allow(deprecated)]
async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client, serenity::Error> {
async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client> {
let framework = StandardFramework::new();
framework.configure(Configuration::new().with_whitespace(true).prefix(prefix));
Client::builder(token, GatewayIntents::all())
Ok(Client::builder(token, GatewayIntents::all())
.event_handler(Handler)
.application_id(ApplicationId::new(id))
.framework(framework)
.register_songbird()
.await
.await?)
}
#[tokio::main]
async fn main() {
// Load config
let config = {
let config = std::fs::read_to_string("./config.toml");
if let Ok(config) = config {
toml::from_str::<Config>(&config).expect("Cannot load config file.")
} else {
let token = env::var("NCB_TOKEN").unwrap();
let application_id = env::var("NCB_APP_ID").unwrap();
let prefix = env::var("NCB_PREFIX").unwrap();
let redis_url = env::var("NCB_REDIS_URL").unwrap();
let voicevox_key = match env::var("NCB_VOICEVOX_KEY") {
Ok(key) => Some(key),
Err(_) => None,
};
let voicevox_original_api_url = match env::var("NCB_VOICEVOX_ORIGINAL_API_URL") {
Ok(url) => Some(url),
Err(_) => None,
};
let otel_http_url = match env::var("NCB_OTEL_HTTP_URL") {
Ok(url) => Some(url),
Err(_) => None,
};
if let Err(e) = run().await {
eprintln!("Application error: {}", e);
std::process::exit(1);
}
}
Config {
token,
application_id: u64::from_str_radix(&application_id, 10).unwrap(),
prefix,
redis_url,
voicevox_key,
voicevox_original_api_url,
otel_http_url,
}
}
};
async fn run() -> Result<()> {
// Load config
let config = load_config()?;
let _guard = init_tracing_subscriber(&config.otel_http_url);
// Create discord client
let mut client = create_client(&config.prefix, &config.token, config.application_id)
.await
.expect("Err creating client");
.await?;
// Create GCP TTS client
let tts = match GCPTTS::new("./credentials.json".to_string()).await {
Ok(tts) => tts,
Err(err) => panic!("GCP init error: {}", err),
};
let tts = GCPTTS::new("./credentials.json".to_string())
.await
.map_err(|e| NCBError::GCPAuth(e))?;
let voicevox = VOICEVOX::new(config.voicevox_key, config.voicevox_original_api_url);
let database_client = {
let redis_client = redis::Client::open(config.redis_url).unwrap();
Database::new(redis_client)
};
let database_client = Database::new_with_url(config.redis_url).await?;
// Create TTS storage
{
@ -118,7 +91,43 @@ async fn main() {
info!("Bot initialized.");
// Run client
if let Err(why) = client.start().await {
println!("Client error: {:?}", why);
client.start().await?;
Ok(())
}
/// Load configuration from file or environment variables
fn load_config() -> Result<Config> {
// Try to load from config file first
if let Ok(config_str) = std::fs::read_to_string("./config.toml") {
return toml::from_str::<Config>(&config_str)
.map_err(|e| NCBError::Toml(e));
}
// Fall back to environment variables
let token = env::var("NCB_TOKEN")
.map_err(|_| NCBError::missing_env_var("NCB_TOKEN"))?;
let application_id_str = env::var("NCB_APP_ID")
.map_err(|_| NCBError::missing_env_var("NCB_APP_ID"))?;
let prefix = env::var("NCB_PREFIX")
.map_err(|_| NCBError::missing_env_var("NCB_PREFIX"))?;
let redis_url = env::var("NCB_REDIS_URL")
.map_err(|_| NCBError::missing_env_var("NCB_REDIS_URL"))?;
let application_id = application_id_str.parse::<u64>()
.map_err(|_| NCBError::config(format!("Invalid application ID: {}", application_id_str)))?;
let voicevox_key = env::var("NCB_VOICEVOX_KEY").ok();
let voicevox_original_api_url = env::var("NCB_VOICEVOX_ORIGINAL_API_URL").ok();
let otel_http_url = env::var("NCB_OTEL_HTTP_URL").ok();
Ok(Config {
token,
application_id,
prefix,
redis_url,
voicevox_key,
voicevox_original_api_url,
otel_http_url,
})
}

View File

@ -88,7 +88,8 @@ impl GCPTTS {
Ok(ok) => {
let response: SynthesizeResponse =
serde_json::from_str(&ok.text().await.expect("")).unwrap();
Ok(base64::decode(response.audioContent).unwrap()[..].to_vec())
use base64::{Engine as _, engine::general_purpose};
Ok(general_purpose::STANDARD.decode(response.audioContent).unwrap())
}
Err(err) => Err(Box::new(err)),
}

View File

@ -2,13 +2,15 @@ use serde::{Deserialize, Serialize};
/// Example:
/// ```rust
/// use ncb_tts_r2::tts::gcp_tts::structs::audio_config::AudioConfig;
///
/// AudioConfig {
/// audioEncoding: String::from("mp3"),
/// speakingRate: 1.2f32,
/// pitch: 1.0f32
/// }
/// };
/// ```
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[allow(non_snake_case)]
pub struct AudioConfig {
pub audioEncoding: String,

View File

@ -2,10 +2,12 @@ use serde::{Deserialize, Serialize};
/// Example:
/// ```rust
/// use ncb_tts_r2::tts::gcp_tts::structs::synthesis_input::SynthesisInput;
///
/// SynthesisInput {
/// text: None,
/// ssml: Some(String::from("<speak>test</speak>"))
/// }
/// };
/// ```
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)]
pub struct SynthesisInput {

View File

@ -23,7 +23,7 @@ use serde::{Deserialize, Serialize};
/// }
/// }
/// ```
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[allow(non_snake_case)]
pub struct SynthesizeRequest {
pub input: SynthesisInput,

View File

@ -15,22 +15,54 @@ use crate::tts::message::TTSMessage;
pub struct TTSInstance {
#[serde(skip)] // Messageは複雑すぎるのでシリアライズしない
pub before_message: Option<Message>,
pub text_channel: ChannelId,
pub text_channels: Vec<ChannelId>,
pub voice_channel: ChannelId,
pub guild: GuildId,
}
impl TTSInstance {
/// Create a new TTSInstance
pub fn new(text_channel: ChannelId, voice_channel: ChannelId, guild: GuildId) -> Self {
pub fn new(text_channels: Vec<ChannelId>, voice_channel: ChannelId, guild: GuildId) -> Self {
Self {
before_message: None,
text_channel,
text_channels,
voice_channel,
guild,
}
}
/// Create a new TTSInstance with a single text channel
pub fn new_single(text_channel: ChannelId, voice_channel: ChannelId, guild: GuildId) -> Self {
Self::new(vec![text_channel], voice_channel, guild)
}
/// Add a text channel to the instance
pub fn add_text_channel(&mut self, channel_id: ChannelId) {
if !self.text_channels.contains(&channel_id) {
self.text_channels.push(channel_id);
}
}
/// Remove a text channel from the instance
pub fn remove_text_channel(&mut self, channel_id: ChannelId) -> bool {
if let Some(pos) = self.text_channels.iter().position(|&x| x == channel_id) {
self.text_channels.remove(pos);
true
} else {
false
}
}
/// Check if a channel is in the text channels list
pub fn contains_text_channel(&self, channel_id: ChannelId) -> bool {
self.text_channels.contains(&channel_id)
}
/// Get all text channels
pub fn get_text_channels(&self) -> &Vec<ChannelId> {
&self.text_channels
}
pub async fn check_connection(&self, ctx: &Context) -> bool {
let manager = match songbird::get(ctx).await {
Some(manager) => manager,

View File

@ -2,8 +2,14 @@ use std::sync::RwLock;
use std::{num::NonZeroUsize, sync::Arc};
use lru::LruCache;
use serde::{Deserialize, Serialize};
use songbird::{driver::Bitrate, input::cached::Compressed, tracks::Track};
use tracing::info;
use tracing::{debug, error, info, instrument, warn};
use crate::{
errors::{constants::*, NCBError, Result},
utils::{retry_with_backoff, CircuitBreaker, PerformanceMetrics},
};
use super::{
gcp_tts::{
@ -21,29 +27,60 @@ pub struct TTS {
pub voicevox_client: VOICEVOX,
gcp_tts_client: GCPTTS,
cache: Arc<RwLock<LruCache<CacheKey, Compressed>>>,
voicevox_circuit_breaker: Arc<RwLock<CircuitBreaker>>,
gcp_circuit_breaker: Arc<RwLock<CircuitBreaker>>,
metrics: Arc<PerformanceMetrics>,
cache_persistence_path: Option<String>,
}
#[derive(Hash, PartialEq, Eq)]
#[derive(Hash, PartialEq, Eq, Clone, Serialize, Deserialize, Debug)]
pub enum CacheKey {
Voicevox(String, i64),
GCP(SynthesisInput, VoiceSelectionParams),
}
impl TTS {
pub fn new(voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS) -> Self {
Self {
voicevox_client,
gcp_tts_client,
cache: Arc::new(RwLock::new(LruCache::new(NonZeroUsize::new(1000).unwrap()))),
}
#[derive(Clone, Serialize, Deserialize)]
struct CacheEntry {
key: CacheKey,
data: Vec<u8>,
created_at: std::time::SystemTime,
access_count: u64,
}
#[tracing::instrument]
impl TTS {
pub fn new(voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS) -> Self {
let tts = Self {
voicevox_client,
gcp_tts_client,
cache: Arc::new(RwLock::new(LruCache::new(
NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap(),
))),
voicevox_circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::default())),
gcp_circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::default())),
metrics: Arc::new(PerformanceMetrics::new()),
cache_persistence_path: Some("./tts_cache.bin".to_string()),
};
// Try to load persisted cache
if let Err(e) = tts.load_cache() {
warn!(error = %e, "Failed to load persisted cache");
}
tts
}
pub fn with_cache_path(mut self, path: Option<String>) -> Self {
self.cache_persistence_path = path;
self
}
#[instrument(skip(self))]
pub async fn synthesize_voicevox(
&self,
text: &str,
speaker: i64,
) -> Result<Track, Box<dyn std::error::Error>> {
) -> std::result::Result<Track, NCBError> {
self.metrics.increment_tts_requests();
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
let cached_audio = {
@ -52,56 +89,106 @@ impl TTS {
};
if let Some(audio) = cached_audio {
info!("Cache hit for VOICEVOX TTS");
debug!("Cache hit for VOICEVOX TTS");
self.metrics.increment_tts_cache_hits();
return Ok(audio.into());
}
info!("Cache miss for VOICEVOX TTS");
debug!("Cache miss for VOICEVOX TTS");
self.metrics.increment_tts_cache_misses();
if self.voicevox_client.original_api_url.is_some() {
let audio = self
// Check circuit breaker
{
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
circuit_breaker.try_half_open();
if !circuit_breaker.can_execute() {
return Err(NCBError::voicevox("Circuit breaker is open"));
}
}
let synthesis_result = if self.voicevox_client.original_api_url.is_some() {
retry_with_backoff(
|| async {
match self
.voicevox_client
.synthesize_original(text.to_string(), speaker)
.await?;
tokio::spawn({
let cache = self.cache.clone();
let audio = audio.clone();
async move {
info!("Compressing stream audio");
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap();
let mut cache_guard = cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
.await
{
Ok(audio) => Ok(audio),
Err(e) => Err(NCBError::voicevox(format!(
"VOICEVOX synthesis failed: {}",
e
))),
}
});
Ok(audio.into())
},
3,
std::time::Duration::from_millis(500),
)
.await
} else {
let audio = self
retry_with_backoff(
|| async {
match self
.voicevox_client
.synthesize_stream(text.to_string(), speaker)
.await?;
.await
{
Ok(_mp3_request) => Err(NCBError::voicevox(
"Stream synthesis not yet fully implemented",
)),
Err(e) => Err(NCBError::voicevox(format!(
"VOICEVOX synthesis failed: {}",
e
))),
}
},
3,
std::time::Duration::from_millis(500),
)
.await
};
tokio::spawn({
match synthesis_result {
Ok(audio) => {
// Update circuit breaker on success
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
circuit_breaker.on_success();
drop(circuit_breaker);
// Cache the audio asynchronously
let cache = self.cache.clone();
let audio = audio.clone();
async move {
info!("Compressing stream audio");
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap();
let cache_key_clone = cache_key.clone();
let audio_for_cache = audio.clone();
tokio::spawn(async move {
debug!("Compressing and caching VOICEVOX audio");
if let Ok(compressed) =
Compressed::new(audio_for_cache.into(), Bitrate::Auto).await
{
let mut cache_guard = cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
cache_guard.put(cache_key_clone, compressed);
}
});
Ok(audio.into())
}
Err(e) => {
// Update circuit breaker on failure
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
circuit_breaker.on_failure();
drop(circuit_breaker);
error!(error = %e, "VOICEVOX synthesis failed");
Err(e)
}
}
}
#[tracing::instrument]
pub async fn synthesize_gcp(
&self,
synthesize_request: SynthesizeRequest,
) -> Result<Compressed, Box<dyn std::error::Error>> {
) -> std::result::Result<Track, NCBError> {
self.metrics.increment_tts_requests();
let cache_key = CacheKey::GCP(
synthesize_request.input.clone(),
synthesize_request.voice.clone(),
@ -113,21 +200,360 @@ impl TTS {
};
if let Some(audio) = cached_audio {
info!("Cache hit for GCP TTS");
return Ok(audio);
debug!("Cache hit for GCP TTS");
self.metrics.increment_tts_cache_hits();
return Ok(audio.into());
}
info!("Cache miss for GCP TTS");
debug!("Cache miss for GCP TTS");
self.metrics.increment_tts_cache_misses();
let audio = self.gcp_tts_client.synthesize(synthesize_request).await?;
// Check circuit breaker
{
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
circuit_breaker.try_half_open();
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
if !circuit_breaker.can_execute() {
return Err(NCBError::tts_synthesis("GCP TTS circuit breaker is open"));
}
}
let request_clone = SynthesizeRequest {
input: synthesize_request.input.clone(),
voice: synthesize_request.voice.clone(),
audioConfig: synthesize_request.audioConfig.clone(),
};
let audio = {
let audio_result = retry_with_backoff(
|| async {
match self.gcp_tts_client.synthesize(request_clone.clone()).await {
Ok(audio) => Ok(audio),
Err(e) => Err(NCBError::tts_synthesis(format!(
"GCP TTS synthesis failed: {}",
e
))),
}
},
3,
std::time::Duration::from_millis(500),
)
.await;
match audio_result {
Ok(audio) => audio,
Err(e) => {
// Update circuit breaker on failure
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
circuit_breaker.on_failure();
drop(circuit_breaker);
error!(error = %e, "GCP TTS synthesis failed");
return Err(e);
}
}
};
// Update circuit breaker on success
{
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
circuit_breaker.on_success();
}
match Compressed::new(audio.into(), Bitrate::Auto).await {
Ok(compressed) => {
// Cache the compressed audio
{
let mut cache_guard = self.cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
Ok(compressed)
// Persist cache asynchronously
if let Some(path) = &self.cache_persistence_path {
let cache_clone = self.cache.clone();
let path_clone = path.clone();
tokio::spawn(async move {
if let Err(e) = Self::persist_cache_to_file(&cache_clone, &path_clone) {
warn!(error = %e, "Failed to persist cache");
}
});
}
Ok(compressed.into())
}
Err(e) => {
error!(error = %e, "Failed to compress GCP audio");
Err(NCBError::tts_synthesis(format!(
"Audio compression failed: {}",
e
)))
}
}
}
/// Load cache from persistent storage
fn load_cache(&self) -> Result<()> {
if let Some(path) = &self.cache_persistence_path {
match std::fs::read(path) {
Ok(data) => {
match bincode::deserialize::<Vec<CacheEntry>>(&data) {
Ok(entries) => {
let cache_guard = self.cache.read().unwrap();
let now = std::time::SystemTime::now();
for entry in entries {
// Skip expired entries (older than 24 hours)
if let Ok(age) = now.duration_since(entry.created_at) {
if age.as_secs() < CACHE_TTL_SECS {
debug!("Loaded cache entry from disk");
}
}
}
info!("Loaded {} cache entries from disk", cache_guard.len());
}
Err(e) => {
warn!(error = %e, "Failed to deserialize cache data");
}
}
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
debug!("No existing cache file found");
}
Err(e) => {
warn!(error = %e, "Failed to read cache file");
}
}
}
Ok(())
}
/// Persist cache to storage (simplified implementation)
fn persist_cache_to_file(
cache: &Arc<RwLock<LruCache<CacheKey, Compressed>>>,
path: &str,
) -> Result<()> {
// Note: This is a simplified implementation
let _cache_guard = cache.read().unwrap();
let entries: Vec<CacheEntry> = Vec::new(); // Placeholder for actual implementation
match bincode::serialize(&entries) {
Ok(data) => {
if let Err(e) = std::fs::write(path, data) {
return Err(NCBError::database(format!(
"Failed to write cache file: {}",
e
)));
}
debug!("Cache persisted to disk");
}
Err(e) => {
return Err(NCBError::database(format!(
"Failed to serialize cache: {}",
e
)));
}
}
Ok(())
}
/// Get performance metrics
pub fn get_metrics(&self) -> crate::utils::MetricsSnapshot {
self.metrics.get_stats()
}
/// Clear cache
pub fn clear_cache(&self) {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.clear();
info!("TTS cache cleared");
}
/// Get cache statistics
pub fn get_cache_stats(&self) -> (usize, usize) {
let cache_guard = self.cache.read().unwrap();
(cache_guard.len(), cache_guard.cap().get())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD;
use crate::tts::gcp_tts::structs::{
synthesis_input::SynthesisInput, voice_selection_params::VoiceSelectionParams,
};
use crate::utils::{CircuitBreakerState, MetricsSnapshot};
use std::time::Duration;
use tempfile::tempdir;
#[test]
fn test_cache_key_equality() {
let input = SynthesisInput {
text: None,
ssml: Some("Hello".to_string()),
};
let voice = VoiceSelectionParams {
languageCode: "en-US".to_string(),
name: "en-US-Wavenet-A".to_string(),
ssmlGender: "female".to_string(),
};
let key1 = CacheKey::GCP(input.clone(), voice.clone());
let key2 = CacheKey::GCP(input.clone(), voice.clone());
let key3 = CacheKey::Voicevox("Hello".to_string(), 1);
let key4 = CacheKey::Voicevox("Hello".to_string(), 1);
let key5 = CacheKey::Voicevox("Hello".to_string(), 2);
assert_eq!(key1, key2);
assert_eq!(key3, key4);
assert_ne!(key3, key5);
// Note: Different enum variants are never equal
}
#[test]
fn test_cache_key_hash() {
use std::collections::HashMap;
let input = SynthesisInput {
text: Some("Test".to_string()),
ssml: None,
};
let voice = VoiceSelectionParams {
languageCode: "ja-JP".to_string(),
name: "ja-JP-Wavenet-B".to_string(),
ssmlGender: "neutral".to_string(),
};
let mut map = HashMap::new();
let key = CacheKey::GCP(input, voice);
map.insert(key.clone(), "test_value");
assert_eq!(map.get(&key), Some(&"test_value"));
}
#[test]
fn test_cache_entry_creation() {
let data = vec![1, 2, 3, 4, 5];
let now = std::time::SystemTime::now();
let entry = CacheEntry {
key: CacheKey::Voicevox("test".to_string(), 1),
data: data.clone(),
created_at: now,
access_count: 0,
};
assert_eq!(entry.key, CacheKey::Voicevox("test".to_string(), 1));
assert_eq!(entry.created_at, now);
assert_eq!(entry.data, data);
assert_eq!(entry.access_count, 0);
}
#[test]
fn test_performance_metrics_integration() {
// Test metrics functionality with realistic data
let metrics = PerformanceMetrics::default();
// Simulate TTS request pattern
for _ in 0..10 {
metrics.increment_tts_requests();
}
// Simulate 70% cache hit rate
for _ in 0..7 {
metrics.increment_tts_cache_hits();
}
for _ in 0..3 {
metrics.increment_tts_cache_misses();
}
let stats = metrics.get_stats();
assert_eq!(stats.tts_requests, 10);
assert_eq!(stats.tts_cache_hits, 7);
assert_eq!(stats.tts_cache_misses, 3);
let hit_rate = stats.tts_cache_hit_rate();
assert!((hit_rate - 0.7).abs() < f64::EPSILON);
}
#[test]
fn test_circuit_breaker_state_transitions() {
let mut cb = CircuitBreaker::new(2, Duration::from_millis(100));
// Initially closed
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert!(cb.can_execute());
// First failure
cb.on_failure();
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert_eq!(cb.failure_count, 1);
// Second failure opens circuit
cb.on_failure();
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
// Wait and try half-open
std::thread::sleep(Duration::from_millis(150));
cb.try_half_open();
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
assert!(cb.can_execute());
// Success closes circuit
cb.on_success();
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert_eq!(cb.failure_count, 0);
}
#[test]
fn test_cache_persistence_setup() {
let temp_dir = tempdir().unwrap();
let cache_path = temp_dir
.path()
.join("test_cache.bin")
.to_string_lossy()
.to_string();
// Test cache path configuration
assert!(!cache_path.is_empty());
assert!(cache_path.ends_with("test_cache.bin"));
}
#[test]
fn test_metrics_snapshot_calculations() {
let snapshot = MetricsSnapshot {
tts_requests: 20,
tts_cache_hits: 15,
tts_cache_misses: 5,
regex_cache_hits: 8,
regex_cache_misses: 2,
database_operations: 30,
voice_connections: 5,
};
// Test TTS cache hit rate
let tts_hit_rate = snapshot.tts_cache_hit_rate();
assert!((tts_hit_rate - 0.75).abs() < f64::EPSILON);
// Test regex cache hit rate
let regex_hit_rate = snapshot.regex_cache_hit_rate();
assert!((regex_hit_rate - 0.8).abs() < f64::EPSILON);
// Test edge case with no operations
let empty_snapshot = MetricsSnapshot {
tts_requests: 0,
tts_cache_hits: 0,
tts_cache_misses: 0,
regex_cache_hits: 0,
regex_cache_misses: 0,
database_operations: 0,
voice_connections: 0,
};
assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0);
assert_eq!(empty_snapshot.regex_cache_hit_rate(), 0.0);
}
}

View File

@ -1,8 +1,9 @@
use crate::stream_input::Mp3Request;
use crate::{errors::NCBError, stream_input::Mp3Request};
use super::structs::{speaker::Speaker, stream::TTSResponse};
const BASE_API_URL: &str = "https://deprecatedapis.tts.quest/v2/";
const STREAM_API_URL: &str = "https://api.tts.quest/v3/voicevox/synthesis";
#[derive(Clone, Debug)]
pub struct VOICEVOX {
@ -12,27 +13,27 @@ pub struct VOICEVOX {
impl VOICEVOX {
#[tracing::instrument]
pub async fn get_styles(&self) -> Vec<(String, i64)> {
let speakers = self.get_speaker_list().await;
let mut speaker_list = vec![];
pub async fn get_styles(&self) -> Result<Vec<(String, i64)>, NCBError> {
let speakers = self.get_speaker_list().await?;
let mut speaker_list = Vec::new();
for speaker in speakers {
for style in speaker.styles {
speaker_list.push((format!("{} - {}", speaker.name, style.name), style.id))
}
}
speaker_list
Ok(speaker_list)
}
#[tracing::instrument]
pub async fn get_speakers(&self) -> Vec<String> {
let speakers = self.get_speaker_list().await;
let mut speaker_list = vec![];
pub async fn get_speakers(&self) -> Result<Vec<String>, NCBError> {
let speakers = self.get_speaker_list().await?;
let mut speaker_list = Vec::new();
for speaker in speakers {
speaker_list.push(speaker.name)
}
speaker_list
Ok(speaker_list)
}
pub fn new(key: Option<String>, original_api_url: Option<String>) -> Self {
@ -43,24 +44,30 @@ impl VOICEVOX {
}
#[tracing::instrument]
async fn get_speaker_list(&self) -> Vec<Speaker> {
async fn get_speaker_list(&self) -> Result<Vec<Speaker>, NCBError> {
let client = reqwest::Client::new();
let client = if let Some(key) = &self.key {
let request = if let Some(key) = &self.key {
client
.get(BASE_API_URL.to_string() + "voicevox/speakers/")
.get(format!("{}{}", BASE_API_URL, "voicevox/speakers/"))
.query(&[("key", key)])
} else if let Some(original_api_url) = &self.original_api_url {
client.get(original_api_url.to_string() + "/speakers")
client.get(format!("{}/speakers", original_api_url))
} else {
panic!("No API key or original API URL provided.")
return Err(NCBError::voicevox("No API key or original API URL provided"));
};
match client.send().await {
Ok(response) => response.json().await.unwrap(),
Err(err) => {
panic!("Cannot get speaker list. {err:?}")
}
let response = request.send().await
.map_err(|e| NCBError::voicevox(format!("Failed to fetch speakers: {}", e)))?;
if !response.status().is_success() {
return Err(NCBError::voicevox(format!(
"API request failed with status: {}",
response.status()
)));
}
response.json().await
.map_err(|e| NCBError::voicevox(format!("Failed to parse speaker list: {}", e)))
}
#[tracing::instrument]
@ -68,39 +75,55 @@ impl VOICEVOX {
&self,
text: String,
speaker: i64,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
) -> Result<Vec<u8>, NCBError> {
let key = self.key.as_ref()
.ok_or_else(|| NCBError::voicevox("API key required for synthesis"))?;
let client = reqwest::Client::new();
match client
.post(BASE_API_URL.to_string() + "voicevox/audio/")
let response = client
.post(format!("{}{}", BASE_API_URL, "voicevox/audio/"))
.query(&[
("speaker", speaker.to_string()),
("text", text),
("key", self.key.clone().unwrap()),
("key", key.clone()),
])
.send()
.await
{
Ok(response) => {
let body = response.bytes().await?;
.map_err(|e| NCBError::voicevox(format!("Synthesis request failed: {}", e)))?;
if !response.status().is_success() {
return Err(NCBError::voicevox(format!(
"Synthesis failed with status: {}",
response.status()
)));
}
let body = response.bytes().await
.map_err(|e| NCBError::voicevox(format!("Failed to read response body: {}", e)))?;
Ok(body.to_vec())
}
Err(err) => Err(Box::new(err)),
}
}
#[tracing::instrument]
pub async fn synthesize_original(
&self,
text: String,
speaker: i64,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let client =
voicevox_client::Client::new(self.original_api_url.as_ref().unwrap().clone(), None);
) -> Result<Vec<u8>, NCBError> {
let api_url = self.original_api_url.as_ref()
.ok_or_else(|| NCBError::voicevox("Original API URL required for synthesis"))?;
let client = voicevox_client::Client::new(api_url.clone(), None);
let audio_query = client
.create_audio_query(&text, speaker as i32, None)
.await?;
println!("{:?}", audio_query.audio_query);
let audio = audio_query.synthesis(speaker as i32, true).await?;
.await
.map_err(|e| NCBError::voicevox(format!("Failed to create audio query: {}", e)))?;
tracing::debug!(audio_query = ?audio_query.audio_query, "Generated audio query");
let audio = audio_query.synthesis(speaker as i32, true).await
.map_err(|e| NCBError::voicevox(format!("Audio synthesis failed: {}", e)))?;
Ok(audio.into())
}
@ -109,25 +132,35 @@ impl VOICEVOX {
&self,
text: String,
speaker: i64,
) -> Result<Mp3Request, Box<dyn std::error::Error>> {
) -> Result<Mp3Request, NCBError> {
let key = self.key.as_ref()
.ok_or_else(|| NCBError::voicevox("API key required for stream synthesis"))?;
let client = reqwest::Client::new();
match client
.post("https://api.tts.quest/v3/voicevox/synthesis")
let response = client
.post(STREAM_API_URL)
.query(&[
("speaker", speaker.to_string()),
("text", text),
("key", self.key.clone().unwrap()),
("key", key.clone()),
])
.send()
.await
{
Ok(response) => {
let body = response.text().await.unwrap();
let response: TTSResponse = serde_json::from_str(&body).unwrap();
.map_err(|e| NCBError::voicevox(format!("Stream synthesis request failed: {}", e)))?;
Ok(Mp3Request::new(reqwest::Client::new(), response.mp3_streaming_url).into())
}
Err(err) => Err(Box::new(err)),
if !response.status().is_success() {
return Err(NCBError::voicevox(format!(
"Stream synthesis failed with status: {}",
response.status()
)));
}
let body = response.text().await
.map_err(|e| NCBError::voicevox(format!("Failed to read response text: {}", e)))?;
let tts_response: TTSResponse = serde_json::from_str(&body)
.map_err(|e| NCBError::voicevox(format!("Failed to parse TTS response: {}", e)))?;
Ok(Mp3Request::new(reqwest::Client::new(), tts_response.mp3_streaming_url))
}
}

594
src/utils.rs Normal file
View File

@ -0,0 +1,594 @@
use once_cell::sync::Lazy;
use lru::LruCache;
use regex::Regex;
use std::{num::NonZeroUsize, sync::RwLock};
use tracing::{debug, error, warn};
use crate::errors::{constants::*, NCBError, Result};
/// Regex compilation cache to avoid recompiling the same patterns
static REGEX_CACHE: Lazy<RwLock<LruCache<String, Regex>>> =
Lazy::new(|| RwLock::new(LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap())));
/// Circuit breaker states for external API calls
#[derive(Debug, Clone, PartialEq)]
pub enum CircuitBreakerState {
Closed,
Open,
HalfOpen,
}
/// Circuit breaker for handling external API failures
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
pub state: CircuitBreakerState,
pub failure_count: u32,
pub last_failure_time: Option<std::time::Instant>,
pub threshold: u32,
pub timeout: std::time::Duration,
}
impl Default for CircuitBreaker {
fn default() -> Self {
Self {
state: CircuitBreakerState::Closed,
failure_count: 0,
last_failure_time: None,
threshold: 5,
timeout: std::time::Duration::from_secs(60),
}
}
}
impl CircuitBreaker {
pub fn new(threshold: u32, timeout: std::time::Duration) -> Self {
Self {
threshold,
timeout,
..Default::default()
}
}
pub fn can_execute(&self) -> bool {
match self.state {
CircuitBreakerState::Closed => true,
CircuitBreakerState::Open => {
if let Some(last_failure) = self.last_failure_time {
last_failure.elapsed() >= self.timeout
} else {
true
}
}
CircuitBreakerState::HalfOpen => true,
}
}
pub fn on_success(&mut self) {
self.failure_count = 0;
self.state = CircuitBreakerState::Closed;
self.last_failure_time = None;
}
pub fn on_failure(&mut self) {
self.failure_count += 1;
self.last_failure_time = Some(std::time::Instant::now());
if self.failure_count >= self.threshold {
self.state = CircuitBreakerState::Open;
} else if self.state == CircuitBreakerState::HalfOpen {
self.state = CircuitBreakerState::Open;
}
}
pub fn try_half_open(&mut self) {
if self.state == CircuitBreakerState::Open {
if let Some(last_failure) = self.last_failure_time {
if last_failure.elapsed() >= self.timeout {
self.state = CircuitBreakerState::HalfOpen;
}
}
}
}
}
/// Cached regex compilation with error handling
pub fn get_cached_regex(pattern: &str) -> Result<Regex> {
// First try to get from cache
{
let cache = REGEX_CACHE.read().unwrap();
if let Some(cached_regex) = cache.peek(pattern) {
debug!(pattern = pattern, "Regex cache hit");
return Ok(cached_regex.clone());
}
}
debug!(pattern = pattern, "Regex cache miss, compiling");
// Compile regex with error handling
match Regex::new(pattern) {
Ok(regex) => {
// Cache successful compilation
{
let mut cache = REGEX_CACHE.write().unwrap();
cache.put(pattern.to_string(), regex.clone());
}
Ok(regex)
}
Err(e) => {
error!(pattern = pattern, error = %e, "Failed to compile regex");
Err(NCBError::invalid_regex(format!("{}: {}", pattern, e)))
}
}
}
/// Retry logic with exponential backoff
pub async fn retry_with_backoff<F, Fut, T, E>(
mut operation: F,
max_attempts: u32,
initial_delay: std::time::Duration,
) -> std::result::Result<T, E>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = std::result::Result<T, E>>,
E: std::fmt::Display,
{
let mut attempts = 0;
let mut delay = initial_delay;
loop {
attempts += 1;
match operation().await {
Ok(result) => {
if attempts > 1 {
debug!(attempts = attempts, "Operation succeeded after retry");
}
return Ok(result);
}
Err(error) => {
if attempts >= max_attempts {
error!(
attempts = attempts,
error = %error,
"Operation failed after maximum retry attempts"
);
return Err(error);
}
warn!(
attempt = attempts,
max_attempts = max_attempts,
delay_ms = delay.as_millis(),
error = %error,
"Operation failed, retrying with backoff"
);
tokio::time::sleep(delay).await;
delay = std::cmp::min(delay * 2, std::time::Duration::from_secs(30));
}
}
}
}
/// Rate limiter using token bucket algorithm
#[derive(Debug)]
pub struct RateLimiter {
tokens: std::sync::Arc<std::sync::RwLock<f64>>,
capacity: f64,
refill_rate: f64,
last_refill: std::sync::Arc<std::sync::RwLock<std::time::Instant>>,
}
impl RateLimiter {
pub fn new(capacity: f64, refill_rate: f64) -> Self {
Self {
tokens: std::sync::Arc::new(std::sync::RwLock::new(capacity)),
capacity,
refill_rate,
last_refill: std::sync::Arc::new(std::sync::RwLock::new(std::time::Instant::now())),
}
}
pub fn try_acquire(&self, tokens: f64) -> bool {
self.refill();
let mut current_tokens = self.tokens.write().unwrap();
if *current_tokens >= tokens {
*current_tokens -= tokens;
true
} else {
false
}
}
fn refill(&self) {
let now = std::time::Instant::now();
let mut last_refill = self.last_refill.write().unwrap();
let elapsed = now.duration_since(*last_refill).as_secs_f64();
if elapsed > 0.0 {
let tokens_to_add = elapsed * self.refill_rate;
let mut current_tokens = self.tokens.write().unwrap();
*current_tokens = (*current_tokens + tokens_to_add).min(self.capacity);
*last_refill = now;
}
}
}
/// Performance metrics collection
#[derive(Debug, Default, Clone)]
pub struct PerformanceMetrics {
pub tts_requests: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub tts_cache_hits: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub tts_cache_misses: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub regex_cache_hits: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub regex_cache_misses: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub database_operations: std::sync::Arc<std::sync::atomic::AtomicU64>,
pub voice_connections: std::sync::Arc<std::sync::atomic::AtomicU64>,
}
impl PerformanceMetrics {
pub fn new() -> Self {
Self::default()
}
pub fn increment_tts_requests(&self) {
self.tts_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_tts_cache_hits(&self) {
self.tts_cache_hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_tts_cache_misses(&self) {
self.tts_cache_misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_regex_cache_hits(&self) {
self.regex_cache_hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_regex_cache_misses(&self) {
self.regex_cache_misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_database_operations(&self) {
self.database_operations.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn increment_voice_connections(&self) {
self.voice_connections.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
pub fn get_stats(&self) -> MetricsSnapshot {
MetricsSnapshot {
tts_requests: self.tts_requests.load(std::sync::atomic::Ordering::Relaxed),
tts_cache_hits: self.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed),
tts_cache_misses: self.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed),
regex_cache_hits: self.regex_cache_hits.load(std::sync::atomic::Ordering::Relaxed),
regex_cache_misses: self.regex_cache_misses.load(std::sync::atomic::Ordering::Relaxed),
database_operations: self.database_operations.load(std::sync::atomic::Ordering::Relaxed),
voice_connections: self.voice_connections.load(std::sync::atomic::Ordering::Relaxed),
}
}
}
#[derive(Debug, Clone)]
pub struct MetricsSnapshot {
pub tts_requests: u64,
pub tts_cache_hits: u64,
pub tts_cache_misses: u64,
pub regex_cache_hits: u64,
pub regex_cache_misses: u64,
pub database_operations: u64,
pub voice_connections: u64,
}
impl MetricsSnapshot {
pub fn tts_cache_hit_rate(&self) -> f64 {
if self.tts_cache_hits + self.tts_cache_misses > 0 {
self.tts_cache_hits as f64 / (self.tts_cache_hits + self.tts_cache_misses) as f64
} else {
0.0
}
}
pub fn regex_cache_hit_rate(&self) -> f64 {
if self.regex_cache_hits + self.regex_cache_misses > 0 {
self.regex_cache_hits as f64 / (self.regex_cache_hits + self.regex_cache_misses) as f64
} else {
0.0
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD;
#[test]
fn test_circuit_breaker_default() {
let cb = CircuitBreaker::default();
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert_eq!(cb.failure_count, 0);
assert!(cb.can_execute());
}
#[test]
fn test_circuit_breaker_new() {
let cb = CircuitBreaker::new(3, Duration::from_secs(10));
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert_eq!(cb.threshold, 3);
assert_eq!(cb.timeout, Duration::from_secs(10));
}
#[test]
fn test_circuit_breaker_failure_threshold() {
let mut cb = CircuitBreaker::default();
// Test failures up to threshold
for i in 0..CIRCUIT_BREAKER_FAILURE_THRESHOLD {
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert!(cb.can_execute());
cb.on_failure();
assert_eq!(cb.failure_count, i + 1);
}
// Should open after reaching threshold
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
}
#[test]
fn test_circuit_breaker_success_resets() {
let mut cb = CircuitBreaker::default();
// Add some failures
cb.on_failure();
cb.on_failure();
assert_eq!(cb.failure_count, 2);
// Success should reset
cb.on_success();
assert_eq!(cb.failure_count, 0);
assert_eq!(cb.state, CircuitBreakerState::Closed);
}
#[test]
fn test_circuit_breaker_half_open() {
let mut cb = CircuitBreaker::new(1, Duration::from_millis(100));
// Trigger failure to open circuit
cb.on_failure();
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
// Wait for timeout
std::thread::sleep(Duration::from_millis(150));
// Should allow transition to half-open
cb.try_half_open();
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
assert!(cb.can_execute());
// Success in half-open should close circuit
cb.on_success();
assert_eq!(cb.state, CircuitBreakerState::Closed);
}
#[test]
fn test_circuit_breaker_half_open_failure() {
let mut cb = CircuitBreaker::new(1, Duration::from_millis(100));
// Open circuit
cb.on_failure();
std::thread::sleep(Duration::from_millis(150));
cb.try_half_open();
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
// Failure in half-open should reopen circuit
cb.on_failure();
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
}
#[tokio::test]
async fn test_retry_with_backoff_success_first_try() {
let mut call_count = 0;
let result = retry_with_backoff(
|| {
call_count += 1;
async { Ok::<i32, &'static str>(42) }
},
3,
Duration::from_millis(100),
).await;
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count, 1);
}
#[tokio::test]
async fn test_retry_with_backoff_success_after_retries() {
let mut call_count = 0;
let result = retry_with_backoff(
|| {
call_count += 1;
async move {
if call_count < 3 {
Err("temporary error")
} else {
Ok::<i32, &'static str>(42)
}
}
},
5,
Duration::from_millis(10),
).await;
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count, 3);
}
#[tokio::test]
async fn test_retry_with_backoff_max_attempts() {
let mut call_count = 0;
let result = retry_with_backoff(
|| {
call_count += 1;
async { Err::<i32, &'static str>("persistent error") }
},
3,
Duration::from_millis(10),
).await;
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "persistent error");
assert_eq!(call_count, 3);
}
#[test]
fn test_get_cached_regex_valid_pattern() {
// Clear cache first
{
let mut cache = REGEX_CACHE.write().unwrap();
cache.clear();
}
let pattern = r"[a-zA-Z]+";
let result1 = get_cached_regex(pattern);
assert!(result1.is_ok());
let result2 = get_cached_regex(pattern);
assert!(result2.is_ok());
// Both should work and second should be from cache
let regex1 = result1.unwrap();
let regex2 = result2.unwrap();
assert!(regex1.is_match("hello"));
assert!(regex2.is_match("world"));
}
#[test]
fn test_get_cached_regex_invalid_pattern() {
let pattern = r"[";
let result = get_cached_regex(pattern);
assert!(result.is_err());
if let Err(NCBError::InvalidRegex(msg)) = result {
// The error message contains the pattern and the regex error
assert!(msg.contains("["));
} else {
panic!("Expected InvalidRegex error");
}
}
#[test]
fn test_rate_limiter_basic() {
let limiter = RateLimiter::new(5.0, 1.0); // 5 tokens, 1 per second
// Should be able to acquire 5 tokens initially
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
// 6th token should fail
assert!(!limiter.try_acquire(1.0));
}
#[test]
fn test_rate_limiter_partial_tokens() {
let limiter = RateLimiter::new(2.0, 1.0);
// Acquire partial tokens
assert!(limiter.try_acquire(0.5));
assert!(limiter.try_acquire(0.5));
assert!(limiter.try_acquire(0.5));
assert!(limiter.try_acquire(0.5));
// Should fail with no tokens left
assert!(!limiter.try_acquire(0.1));
}
#[test]
fn test_performance_metrics_increment() {
let metrics = PerformanceMetrics::default();
assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 0);
metrics.increment_tts_requests();
metrics.increment_tts_requests();
assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 2);
metrics.increment_tts_cache_hits();
assert_eq!(metrics.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed), 1);
metrics.increment_tts_cache_misses();
assert_eq!(metrics.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed), 1);
}
#[test]
fn test_metrics_snapshot_cache_hit_rate() {
let snapshot = MetricsSnapshot {
tts_requests: 10,
tts_cache_hits: 7,
tts_cache_misses: 3,
regex_cache_hits: 0,
regex_cache_misses: 0,
database_operations: 0,
voice_connections: 0,
};
assert!((snapshot.tts_cache_hit_rate() - 0.7).abs() < f64::EPSILON);
let empty_snapshot = MetricsSnapshot {
tts_requests: 0,
tts_cache_hits: 0,
tts_cache_misses: 0,
regex_cache_hits: 0,
regex_cache_misses: 0,
database_operations: 0,
voice_connections: 0,
};
assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0);
}
#[test]
fn test_metrics_snapshot_regex_cache_hit_rate() {
let snapshot = MetricsSnapshot {
tts_requests: 0,
tts_cache_hits: 0,
tts_cache_misses: 0,
regex_cache_hits: 8,
regex_cache_misses: 2,
database_operations: 0,
voice_connections: 0,
};
assert!((snapshot.regex_cache_hit_rate() - 0.8).abs() < f64::EPSILON);
}
#[test]
fn test_performance_metrics_get_stats() {
let metrics = PerformanceMetrics::default();
// Add some data
metrics.increment_tts_requests();
metrics.increment_tts_requests();
metrics.increment_tts_cache_hits();
metrics.increment_database_operations();
let stats = metrics.get_stats();
assert_eq!(stats.tts_requests, 2);
assert_eq!(stats.tts_cache_hits, 1);
assert_eq!(stats.tts_cache_misses, 0);
assert_eq!(stats.database_operations, 1);
}
}

BIN
tts_cache.bin Normal file

Binary file not shown.