diff --git a/Cargo.toml b/Cargo.toml index dae0969..a98aff6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,14 @@ name = "ncb-tts-r2" version = "1.11.2" edition = "2021" +[lib] +name = "ncb_tts_r2" +path = "src/lib.rs" + +[[bin]] +name = "ncb-tts-r2" +path = "src/main.rs" + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] @@ -13,10 +21,15 @@ gcp_auth = "0.5.0" reqwest = { version = "0.12.9", features = ["json"] } base64 = "0.22.1" async-trait = "0.1.57" -redis = "0.29.2" +redis = { version = "0.29.2", features = ["aio", "tokio-comp"] } +bb8 = "0.8" +bb8-redis = "0.16" +thiserror = "1.0" regex = "1" tracing-subscriber = "0.3.19" lru = "0.13.0" +once_cell = "1.19" +bincode = "1.3" tracing = "0.1.41" opentelemetry_sdk = { version = "0.29.0", features = ["trace"] } opentelemetry = "0.29.1" @@ -61,3 +74,9 @@ features = [ [dependencies.tokio] version = "1.0" features = ["macros", "rt-multi-thread"] + +[dev-dependencies] +tokio-test = "0.4" +mockall = "0.12" +tempfile = "3.8" +serial_test = "3.0" diff --git a/src/commands/config.rs b/src/commands/config.rs index 7c88f92..a1964f7 100644 --- a/src/commands/config.rs +++ b/src/commands/config.rs @@ -34,7 +34,11 @@ pub async fn config_command( let tts_client = data_read .get::() .expect("Cannot get TTSClientData"); - let voicevox_speakers = tts_client.voicevox_client.get_styles().await; + let voicevox_speakers = tts_client.voicevox_client.get_styles().await + .unwrap_or_else(|e| { + tracing::error!("Failed to get VOICEVOX styles: {}", e); + vec![("VOICEVOX API unavailable".to_string(), 1)] + }); let voicevox_speaker = config.voicevox_speaker.unwrap_or(1); let tts_type = config.tts_type.unwrap_or(TTSType::GCP); diff --git a/src/commands/setup.rs b/src/commands/setup.rs index fb62434..eb30765 100644 --- a/src/commands/setup.rs +++ b/src/commands/setup.rs @@ -149,7 +149,11 @@ pub async fn setup_command( let tts_client = data .get::() .expect("Cannot get TTSClientData"); - let voicevox_speakers = tts_client.voicevox_client.get_speakers().await; + let voicevox_speakers = tts_client.voicevox_client.get_speakers().await + .unwrap_or_else(|e| { + tracing::error!("Failed to get VOICEVOX speakers: {}", e); + vec!["VOICEVOX API unavailable".to_string()] + }); text_channel_id .send_message(&ctx.http, CreateMessage::new() diff --git a/src/connection_monitor.rs b/src/connection_monitor.rs index df25eef..c35bbec 100644 --- a/src/connection_monitor.rs +++ b/src/connection_monitor.rs @@ -1,34 +1,75 @@ -use serenity::{model::channel::Message, prelude::Context, all::{CreateMessage, CreateEmbed}}; +use serenity::{prelude::Context, all::{CreateMessage, CreateEmbed}}; use std::time::Duration; use tokio::time; -use tracing::{error, info, warn}; +use tracing::{error, info, warn, instrument}; use crate::data::{DatabaseClientData, TTSData}; +/// Constants for connection monitoring +const CONNECTION_CHECK_INTERVAL_SECS: u64 = 5; +const MAX_RECONNECTION_ATTEMPTS: u32 = 3; +const RECONNECTION_BACKOFF_SECS: u64 = 2; + +/// Errors that can occur during connection monitoring +#[derive(Debug, thiserror::Error)] +pub enum ConnectionMonitorError { + #[error("Failed to get songbird manager")] + SongbirdManagerNotFound, + #[error("Failed to check voice channel users: {0}")] + VoiceChannelCheck(String), + #[error("Failed to reconnect after {attempts} attempts")] + ReconnectionFailed { attempts: u32 }, + #[error("Database operation failed: {0}")] + Database(#[from] redis::RedisError), +} + +type Result = std::result::Result; + /// Connection monitor that periodically checks voice channel connections -pub struct ConnectionMonitor; +pub struct ConnectionMonitor { + reconnection_attempts: std::collections::HashMap, +} + +impl Default for ConnectionMonitor { + fn default() -> Self { + Self::new() + } +} impl ConnectionMonitor { + pub fn new() -> Self { + Self { + reconnection_attempts: std::collections::HashMap::new(), + } + } + /// Start the connection monitoring task pub fn start(ctx: Context) { tokio::spawn(async move { - info!("Starting connection monitor with 5s interval"); - let mut interval = time::interval(Duration::from_secs(5)); + let mut monitor = ConnectionMonitor::new(); + info!( + interval_secs = CONNECTION_CHECK_INTERVAL_SECS, + "Starting connection monitor" + ); + let mut interval = time::interval(Duration::from_secs(CONNECTION_CHECK_INTERVAL_SECS)); loop { interval.tick().await; - Self::check_connections(&ctx).await; + if let Err(e) = monitor.check_connections(&ctx).await { + error!(error = %e, "Connection monitoring failed"); + } } }); } /// Check all active TTS instances and their voice channel connections - async fn check_connections(ctx: &Context) { + #[instrument(skip(self, ctx))] + async fn check_connections(&mut self, ctx: &Context) -> Result<()> { let storage_lock = { let data_read = ctx.data.read().await; data_read .get::() - .expect("Cannot get TTSStorage") + .ok_or_else(|| ConnectionMonitorError::VoiceChannelCheck("Cannot get TTSStorage".to_string()))? .clone() }; @@ -36,7 +77,7 @@ impl ConnectionMonitor { let data_read = ctx.data.read().await; data_read .get::() - .expect("Cannot get DatabaseClientData") + .ok_or_else(|| ConnectionMonitorError::VoiceChannelCheck("Cannot get DatabaseClientData".to_string()))? .clone() }; @@ -45,13 +86,8 @@ impl ConnectionMonitor { for (guild_id, instance) in storage.iter() { // Check if bot is still connected to voice channel - let manager = match songbird::get(ctx).await { - Some(manager) => manager, - None => { - error!("Cannot get songbird manager"); - continue; - } - }; + let manager = songbird::get(ctx).await + .ok_or(ConnectionMonitorError::SongbirdManagerNotFound)?; let call = manager.get(*guild_id); let is_connected = if let Some(call) = call { @@ -65,49 +101,87 @@ impl ConnectionMonitor { }; if !is_connected { - warn!("Bot disconnected from voice channel in guild {}", guild_id); + warn!(guild_id = %guild_id, "Bot disconnected from voice channel"); // Check if there are users in the voice channel - let should_reconnect = match Self::check_voice_channel_users(ctx, instance).await { + let should_reconnect = match self.check_voice_channel_users(ctx, instance).await { Ok(has_users) => has_users, - Err(_) => { - // If we can't check users, don't reconnect + Err(e) => { + warn!(guild_id = %guild_id, error = %e, "Failed to check voice channel users, skipping reconnection"); false } }; if should_reconnect { - // Try to reconnect + // Try to reconnect with retry logic + let attempts = self.reconnection_attempts.get(guild_id).copied().unwrap_or(0); + + if attempts >= MAX_RECONNECTION_ATTEMPTS { + error!( + guild_id = %guild_id, + attempts = attempts, + "Maximum reconnection attempts reached, removing instance" + ); + guilds_to_remove.push(*guild_id); + self.reconnection_attempts.remove(guild_id); + continue; + } + + // Apply exponential backoff + if attempts > 0 { + let backoff_duration = Duration::from_secs(RECONNECTION_BACKOFF_SECS * (2_u64.pow(attempts))); + warn!( + guild_id = %guild_id, + attempt = attempts + 1, + backoff_secs = backoff_duration.as_secs(), + "Applying backoff before reconnection attempt" + ); + tokio::time::sleep(backoff_duration).await; + } + match instance.reconnect(ctx, true).await { Ok(_) => { info!( - "Successfully reconnected to voice channel in guild {}", - guild_id + guild_id = %guild_id, + attempts = attempts + 1, + "Successfully reconnected to voice channel" ); + // Reset reconnection attempts on success + self.reconnection_attempts.remove(guild_id); + // Send notification message to text channel with embed let embed = CreateEmbed::new() .title("๐Ÿ”„ ่‡ชๅ‹•ๅ†ๆŽฅ็ถšใ—ใพใ—ใŸ") .description("่ชญใฟไธŠใ’ใ‚’ๅœๆญขใ—ใŸใ„ๅ ดๅˆใฏ `/stop` ใ‚ณใƒžใƒณใƒ‰ใ‚’ไฝฟ็”จใ—ใฆใใ ใ•ใ„ใ€‚") .color(0x00ff00); if let Err(e) = instance.text_channel.send_message(&ctx.http, CreateMessage::new().embed(embed)).await { - error!("Failed to send reconnection message to text channel: {}", e); + error!(guild_id = %guild_id, error = %e, "Failed to send reconnection message"); } } Err(e) => { + let new_attempts = attempts + 1; + self.reconnection_attempts.insert(*guild_id, new_attempts); error!( - "Failed to reconnect to voice channel in guild {}: {}", - guild_id, e + guild_id = %guild_id, + attempt = new_attempts, + error = %e, + "Failed to reconnect to voice channel" ); - guilds_to_remove.push(*guild_id); + + if new_attempts >= MAX_RECONNECTION_ATTEMPTS { + guilds_to_remove.push(*guild_id); + self.reconnection_attempts.remove(guild_id); + } } } } else { info!( - "No users in voice channel, removing instance for guild {}", - guild_id + guild_id = %guild_id, + "No users in voice channel, removing instance" ); guilds_to_remove.push(*guild_id); + self.reconnection_attempts.remove(guild_id); } } } @@ -118,29 +192,51 @@ impl ConnectionMonitor { // Remove from database if let Err(e) = database.remove_tts_instance(guild_id).await { - error!("Failed to remove TTS instance from database: {}", e); + error!(guild_id = %guild_id, error = %e, "Failed to remove TTS instance from database"); } // Ensure bot leaves voice channel if let Some(manager) = songbird::get(ctx).await { - let _ = manager.remove(guild_id).await; + if let Err(e) = manager.remove(guild_id).await { + error!(guild_id = %guild_id, error = %e, "Failed to remove bot from voice channel"); + } } + + info!(guild_id = %guild_id, "Removed disconnected TTS instance"); } + + Ok(()) } /// Check if there are users in the voice channel + #[instrument(skip(self, ctx, instance))] async fn check_voice_channel_users( + &self, ctx: &Context, instance: &crate::tts::instance::TTSInstance, - ) -> Result> { - let channels = instance.guild.channels(&ctx.http).await?; + ) -> Result { + let channels = instance.guild.channels(&ctx.http).await + .map_err(|e| ConnectionMonitorError::VoiceChannelCheck(format!("Failed to get guild channels: {}", e)))?; if let Some(channel) = channels.get(&instance.voice_channel) { - let members = channel.members(&ctx.cache)?; + let members = channel.members(&ctx.cache) + .map_err(|e| ConnectionMonitorError::VoiceChannelCheck(format!("Failed to get channel members: {}", e)))?; let user_count = members.iter().filter(|member| !member.user.bot).count(); + + info!( + guild_id = %instance.guild, + channel_id = %instance.voice_channel, + user_count = user_count, + "Checked voice channel users" + ); + Ok(user_count > 0) } else { - // Channel doesn't exist anymore + warn!( + guild_id = %instance.guild, + channel_id = %instance.voice_channel, + "Voice channel no longer exists" + ); Ok(false) } } diff --git a/src/database/database.rs b/src/database/database.rs index 6b4e4e7..1eacc40 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -1,88 +1,114 @@ use std::fmt::Debug; -use crate::tts::{ - gcp_tts::structs::voice_selection_params::VoiceSelectionParams, instance::TTSInstance, - tts_type::TTSType, +use bb8_redis::{bb8::Pool, RedisConnectionManager, redis::AsyncCommands}; +use crate::{ + errors::{NCBError, Result, constants::*}, + tts::{ + gcp_tts::structs::voice_selection_params::VoiceSelectionParams, instance::TTSInstance, + tts_type::TTSType, + }, }; -use serenity::model::id::GuildId; +use serenity::model::id::{GuildId, UserId, ChannelId}; +use std::collections::HashMap; use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig}; -use redis::Commands; #[derive(Debug, Clone)] pub struct Database { - pub client: redis::Client, + pub pool: Pool, } impl Database { - pub fn new(client: redis::Client) -> Self { - Self { client } + pub fn new(pool: Pool) -> Self { + Self { pool } + } + + pub async fn new_with_url(redis_url: String) -> Result { + let manager = RedisConnectionManager::new(redis_url)?; + let pool = Pool::builder() + .max_size(15) + .build(manager) + .await + .map_err(|e| NCBError::Database(format!("Pool creation failed: {}", e)))?; + Ok(Self { pool }) } fn server_key(server_id: u64) -> String { - format!("discord_server:{}", server_id) + format!("{}{}", DISCORD_SERVER_PREFIX, server_id) } fn user_key(user_id: u64) -> String { - format!("discord_user:{}", user_id) + format!("{}{}", DISCORD_USER_PREFIX, user_id) } fn tts_instance_key(guild_id: u64) -> String { - format!("tts_instance:{}", guild_id) + format!("{}{}", TTS_INSTANCE_PREFIX, guild_id) } fn tts_instances_list_key() -> String { - "tts_instances_list".to_string() + TTS_INSTANCES_LIST_KEY.to_string() + } + + fn user_config_key(guild_id: u64, user_id: u64) -> String { + format!("user:config:{}:{}", guild_id, user_id) + } + + fn server_config_key(guild_id: u64) -> String { + format!("server:config:{}", guild_id) + } + + fn dictionary_key(guild_id: u64) -> String { + format!("dictionary:{}", guild_id) } #[tracing::instrument] - fn get_config( + async fn get_config( &self, key: &str, - ) -> redis::RedisResult> { - match self.client.get_connection() { - Ok(mut connection) => { - let config: String = connection.get(key).unwrap_or_default(); + ) -> Result> { + let mut connection = self.pool.get().await + .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; + + let config: String = connection.get(key).await.unwrap_or_default(); - if config.is_empty() { - return Ok(None); - } + if config.is_empty() { + return Ok(None); + } - match serde_json::from_str(&config) { - Ok(config) => Ok(Some(config)), - Err(_) => Ok(None), - } + match serde_json::from_str(&config) { + Ok(config) => Ok(Some(config)), + Err(e) => { + tracing::warn!(key = key, error = %e, "Failed to deserialize config"); + Ok(None) } - Err(e) => Err(e), } } #[tracing::instrument] - fn set_config( + async fn set_config( &self, key: &str, config: &T, - ) -> redis::RedisResult<()> { - match self.client.get_connection() { - Ok(mut connection) => { - let config_str = serde_json::to_string(config).unwrap(); - connection.set::<_, _, ()>(key, config_str) - } - Err(e) => Err(e), - } + ) -> Result<()> { + let mut connection = self.pool.get().await + .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; + + let config_str = serde_json::to_string(config)?; + connection.set::<_, _, ()>(key, config_str).await?; + Ok(()) } #[tracing::instrument] pub async fn get_server_config( &self, server_id: u64, - ) -> redis::RedisResult> { - self.get_config(&Self::server_key(server_id)) + ) -> Result> { + self.get_config(&Self::server_key(server_id)).await } #[tracing::instrument] - pub async fn get_user_config(&self, user_id: u64) -> redis::RedisResult> { - self.get_config(&Self::user_key(user_id)) + pub async fn get_user_config(&self, user_id: u64) -> Result> { + self.get_config(&Self::user_key(user_id)).await } #[tracing::instrument] @@ -90,8 +116,8 @@ impl Database { &self, server_id: u64, config: ServerConfig, - ) -> redis::RedisResult<()> { - self.set_config(&Self::server_key(server_id), &config) + ) -> Result<()> { + self.set_config(&Self::server_key(server_id), &config).await } #[tracing::instrument] @@ -99,12 +125,12 @@ impl Database { &self, user_id: u64, config: UserConfig, - ) -> redis::RedisResult<()> { - self.set_config(&Self::user_key(user_id), &config) + ) -> Result<()> { + self.set_config(&Self::user_key(user_id), &config).await } #[tracing::instrument] - pub async fn set_default_server_config(&self, server_id: u64) -> redis::RedisResult<()> { + pub async fn set_default_server_config(&self, server_id: u64) -> Result<()> { let config = ServerConfig { dictionary: Dictionary::new(), autostart_channel_id: None, @@ -116,7 +142,7 @@ impl Database { } #[tracing::instrument] - pub async fn set_default_user_config(&self, user_id: u64) -> redis::RedisResult<()> { + pub async fn set_default_user_config(&self, user_id: u64) -> Result<()> { let voice_selection = VoiceSelectionParams { languageCode: String::from("ja-JP"), name: String::from("ja-JP-Wavenet-B"), @@ -126,7 +152,7 @@ impl Database { let config = UserConfig { tts_type: Some(TTSType::GCP), gcp_tts_voice: Some(voice_selection), - voicevox_speaker: Some(1), + voicevox_speaker: Some(DEFAULT_VOICEVOX_SPEAKER), }; self.set_user_config(user_id, config).await @@ -136,7 +162,7 @@ impl Database { pub async fn get_server_config_or_default( &self, server_id: u64, - ) -> redis::RedisResult> { + ) -> Result> { match self.get_server_config(server_id).await? { Some(config) => Ok(Some(config)), None => { @@ -150,7 +176,7 @@ impl Database { pub async fn get_user_config_or_default( &self, user_id: u64, - ) -> redis::RedisResult> { + ) -> Result> { match self.get_user_config(user_id).await? { Some(config) => Ok(Some(config)), None => { @@ -161,29 +187,23 @@ impl Database { } /// Save TTS instance to database - #[tracing::instrument] pub async fn save_tts_instance( &self, guild_id: GuildId, instance: &TTSInstance, - ) -> redis::RedisResult<()> { + ) -> Result<()> { let key = Self::tts_instance_key(guild_id.get()); let list_key = Self::tts_instances_list_key(); // Save the instance - let result = self.set_config(&key, instance); + self.set_config(&key, instance).await?; // Add guild_id to the list of active instances - if result.is_ok() { - match self.client.get_connection() { - Ok(mut connection) => { - let _: redis::RedisResult<()> = connection.sadd(&list_key, guild_id.get()); - } - Err(_) => {} - } - } - - result + let mut connection = self.pool.get().await + .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; + + connection.sadd::<_, _, ()>(&list_key, guild_id.get()).await?; + Ok(()) } /// Load TTS instance from database @@ -191,47 +211,278 @@ impl Database { pub async fn load_tts_instance( &self, guild_id: GuildId, - ) -> redis::RedisResult> { + ) -> Result> { let key = Self::tts_instance_key(guild_id.get()); - self.get_config(&key) + self.get_config(&key).await } /// Remove TTS instance from database #[tracing::instrument] - pub async fn remove_tts_instance(&self, guild_id: GuildId) -> redis::RedisResult<()> { + pub async fn remove_tts_instance(&self, guild_id: GuildId) -> Result<()> { let key = Self::tts_instance_key(guild_id.get()); let list_key = Self::tts_instances_list_key(); - match self.client.get_connection() { - Ok(mut connection) => { - let _: redis::RedisResult<()> = connection.del(&key); - let _: redis::RedisResult<()> = connection.srem(&list_key, guild_id.get()); - Ok(()) - } - Err(e) => Err(e), - } + let mut connection = self.pool.get().await + .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; + + let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await; + let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.srem(&list_key, guild_id.get()).await; + + Ok(()) } /// Get all active TTS instances #[tracing::instrument] - pub async fn get_all_tts_instances(&self) -> redis::RedisResult> { + pub async fn get_all_tts_instances(&self) -> Result> { let list_key = Self::tts_instances_list_key(); - match self.client.get_connection() { - Ok(mut connection) => { - let guild_ids: Vec = connection.smembers(&list_key).unwrap_or_default(); - let mut instances = Vec::new(); + let mut connection = self.pool.get().await + .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; + + let guild_ids: Vec = connection.smembers(&list_key).await.unwrap_or_default(); + let mut instances = Vec::new(); - for guild_id in guild_ids { - let guild_id = GuildId::new(guild_id); - if let Ok(Some(instance)) = self.load_tts_instance(guild_id).await { - instances.push((guild_id, instance)); - } - } - - Ok(instances) + for guild_id in guild_ids { + let guild_id = GuildId::new(guild_id); + if let Ok(Some(instance)) = self.load_tts_instance(guild_id).await { + instances.push((guild_id, instance)); + } else { + tracing::warn!(guild_id = %guild_id, "Failed to load TTS instance"); } - Err(e) => Err(e), } + + Ok(instances) + } + + // Additional user config methods + pub async fn save_user_config( + &self, + guild_id: GuildId, + user_id: UserId, + config: &UserConfig, + ) -> Result<()> { + let key = Self::user_config_key(guild_id.get(), user_id.get()); + self.set_config(&key, config).await + } + + pub async fn load_user_config( + &self, + guild_id: GuildId, + user_id: UserId, + ) -> Result> { + let key = Self::user_config_key(guild_id.get(), user_id.get()); + self.get_config(&key).await + } + + pub async fn delete_user_config( + &self, + guild_id: GuildId, + user_id: UserId, + ) -> Result<()> { + let key = Self::user_config_key(guild_id.get(), user_id.get()); + let mut connection = self.pool.get().await + .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; + let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await; + Ok(()) + } + + // Additional server config methods + pub async fn save_server_config( + &self, + guild_id: GuildId, + config: &ServerConfig, + ) -> Result<()> { + let key = Self::server_config_key(guild_id.get()); + self.set_config(&key, config).await + } + + pub async fn load_server_config( + &self, + guild_id: GuildId, + ) -> Result> { + let key = Self::server_config_key(guild_id.get()); + self.get_config(&key).await + } + + pub async fn delete_server_config(&self, guild_id: GuildId) -> Result<()> { + let key = Self::server_config_key(guild_id.get()); + let mut connection = self.pool.get().await + .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; + let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await; + Ok(()) + } + + // Dictionary methods + pub async fn save_dictionary( + &self, + guild_id: GuildId, + dictionary: &HashMap, + ) -> Result<()> { + let key = Self::dictionary_key(guild_id.get()); + self.set_config(&key, dictionary).await + } + + pub async fn load_dictionary( + &self, + guild_id: GuildId, + ) -> Result> { + let key = Self::dictionary_key(guild_id.get()); + let dict: Option> = self.get_config(&key).await?; + Ok(dict.unwrap_or_default()) + } + + pub async fn delete_dictionary(&self, guild_id: GuildId) -> Result<()> { + let key = Self::dictionary_key(guild_id.get()); + let mut connection = self.pool.get().await + .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; + let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await; + Ok(()) + } + + pub async fn delete_tts_instance(&self, guild_id: GuildId) -> Result<()> { + self.remove_tts_instance(guild_id).await + } + + pub async fn list_active_instances(&self) -> Result> { + let list_key = Self::tts_instances_list_key(); + let mut connection = self.pool.get().await + .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; + let guild_ids: Vec = connection.smembers(&list_key).await.unwrap_or_default(); + Ok(guild_ids) } } + +#[cfg(test)] +mod tests { + use super::*; + use bb8_redis::redis::AsyncCommands; + use serial_test::serial; + use crate::errors::constants; + + // Helper function to create test database (requires Redis running) + async fn create_test_database() -> Result { + let manager = RedisConnectionManager::new("redis://127.0.0.1:6379/15")?; // Use test DB + let pool = bb8::Pool::builder() + .max_size(1) + .build(manager) + .await + .map_err(|e| NCBError::Database(format!("Pool creation failed: {}", e)))?; + + Ok(Database { pool }) + } + + #[tokio::test] + #[serial] + async fn test_database_creation() { + // This test requires Redis to be running + match create_test_database().await { + Ok(_db) => { + // Test successful creation + assert!(true); + } + Err(_) => { + // Skip test if Redis is not available + return; + } + } + } + + #[test] + fn test_key_generation() { + let guild_id = 123456789u64; + let user_id = 987654321u64; + + // Test TTS instance key + let tts_key = Database::tts_instance_key(guild_id); + assert!(tts_key.contains(&guild_id.to_string())); + + // Test TTS instances list key + let list_key = Database::tts_instances_list_key(); + assert!(!list_key.is_empty()); + + // Test user config key + let user_key = Database::user_config_key(guild_id, user_id); + assert_eq!(user_key, "user:config:123456789:987654321"); + + // Test server config key + let server_key = Database::server_config_key(guild_id); + assert_eq!(server_key, "server:config:123456789"); + + // Test dictionary key + let dict_key = Database::dictionary_key(guild_id); + assert_eq!(dict_key, "dictionary:123456789"); + } + + #[tokio::test] + #[serial] + async fn test_tts_instance_operations() { + let db = match create_test_database().await { + Ok(db) => db, + Err(_) => return, // Skip if Redis not available + }; + + let guild_id = GuildId::new(12345); + let test_instance = TTSInstance::new( + ChannelId::new(123), + ChannelId::new(456), + guild_id + ); + + // Clear any existing data + if let Ok(mut conn) = db.pool.get().await { + let _: () = conn.del(Database::tts_instance_key(guild_id.get())).await.unwrap_or_default(); + let _: () = conn.srem(Database::tts_instances_list_key(), guild_id.get()).await.unwrap_or_default(); + } else { + return; // Skip if can't get connection + } + + // Test saving TTS instance + let save_result = db.save_tts_instance(guild_id, &test_instance).await; + if save_result.is_err() { + // Skip test if Redis operations fail + return; + } + + // Test loading TTS instance + let load_result = db.load_tts_instance(guild_id).await; + if load_result.is_err() { + return; // Skip if Redis operations fail + } + + let loaded_instance = load_result.unwrap(); + if let Some(instance) = loaded_instance { + assert_eq!(instance.guild, test_instance.guild); + assert_eq!(instance.text_channel, test_instance.text_channel); + assert_eq!(instance.voice_channel, test_instance.voice_channel); + } + + // Test listing active instances + let list_result = db.list_active_instances().await; + if list_result.is_err() { + return; // Skip if Redis operations fail + } + let instances = list_result.unwrap(); + assert!(instances.contains(&guild_id.get())); + + // Test deleting TTS instance + let delete_result = db.delete_tts_instance(guild_id).await; + if delete_result.is_err() { + return; // Skip if Redis operations fail + } + + // Verify deletion + let load_after_delete = db.load_tts_instance(guild_id).await; + if load_after_delete.is_err() { + return; // Skip if Redis operations fail + } + assert!(load_after_delete.unwrap().is_none()); + } + + #[test] + fn test_database_constants() { + // Test that constants are reasonable + assert!(constants::REDIS_CONNECTION_TIMEOUT_SECS > 0); + assert!(constants::REDIS_MAX_CONNECTIONS > 0); + assert!(constants::REDIS_MIN_IDLE_CONNECTIONS <= constants::REDIS_MAX_CONNECTIONS); + } +} \ No newline at end of file diff --git a/src/errors.rs b/src/errors.rs new file mode 100644 index 0000000..8c50865 --- /dev/null +++ b/src/errors.rs @@ -0,0 +1,519 @@ +/// Custom error types for the NCB-TTS application +#[derive(Debug, thiserror::Error)] +pub enum NCBError { + #[error("Configuration error: {0}")] + Config(String), + + #[error("Database error: {0}")] + Database(String), + + #[error("VOICEVOX API error: {0}")] + VOICEVOX(String), + + #[error("Discord error: {0}")] + Discord(#[from] serenity::Error), + + #[error("TTS synthesis error: {0}")] + TTSSynthesis(String), + + #[error("GCP authentication error: {0}")] + GCPAuth(#[from] gcp_auth::Error), + + #[error("HTTP request error: {0}")] + Http(#[from] reqwest::Error), + + #[error("JSON parsing error: {0}")] + Json(#[from] serde_json::Error), + + #[error("Redis connection error: {0}")] + Redis(String), + + #[error("Redis error: {0}")] + RedisError(#[from] bb8_redis::redis::RedisError), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("Voice connection error: {0}")] + VoiceConnection(String), + + #[error("Invalid input: {0}")] + InvalidInput(String), + + #[error("Invalid regex pattern: {0}")] + InvalidRegex(String), + + #[error("Songbird error: {0}")] + Songbird(String), + + #[error("User not in voice channel")] + UserNotInVoiceChannel, + + #[error("Guild not found")] + GuildNotFound, + + #[error("Channel not found")] + ChannelNotFound, + + #[error("TTS instance not found for guild {guild_id}")] + TTSInstanceNotFound { guild_id: u64 }, + + #[error("Text too long (max {max_length} characters)")] + TextTooLong { max_length: usize }, + + #[error("Text contains prohibited content")] + ProhibitedContent, + + #[error("Rate limit exceeded")] + RateLimitExceeded, + + #[error("TOML parsing error: {0}")] + Toml(#[from] toml::de::Error), +} + +impl NCBError { + pub fn config(message: impl Into) -> Self { + Self::Config(message.into()) + } + + pub fn database(message: impl Into) -> Self { + Self::Database(message.into()) + } + + pub fn voicevox(message: impl Into) -> Self { + Self::VOICEVOX(message.into()) + } + + pub fn voice_connection(message: impl Into) -> Self { + Self::VoiceConnection(message.into()) + } + + pub fn tts_synthesis(message: impl Into) -> Self { + Self::TTSSynthesis(message.into()) + } + + pub fn invalid_input(message: impl Into) -> Self { + Self::InvalidInput(message.into()) + } + + pub fn invalid_regex(message: impl Into) -> Self { + Self::InvalidRegex(message.into()) + } + + pub fn songbird(message: impl Into) -> Self { + Self::Songbird(message.into()) + } + + pub fn tts_instance_not_found(guild_id: u64) -> Self { + Self::TTSInstanceNotFound { guild_id } + } + + pub fn text_too_long(max_length: usize) -> Self { + Self::TextTooLong { max_length } + } + + pub fn redis(message: impl Into) -> Self { + Self::Redis(message.into()) + } + + pub fn missing_env_var(var_name: &str) -> Self { + Self::Config(format!("Missing environment variable: {}", var_name)) + } +} + +/// Result type alias for convenience +pub type Result = std::result::Result; + +/// Input validation functions +pub mod validation { + use super::*; + use regex::Regex; + + /// Validate regex pattern for potential ReDoS attacks + pub fn validate_regex_pattern(pattern: &str) -> Result<()> { + // Check for common ReDoS patterns (catastrophic backtracking) + let redos_patterns = [ + r"\(\?\:", // Non-capturing groups in dangerous positions + r"\(\?\=", // Positive lookahead + r"\(\?\!", // Negative lookahead + r"\(\?\<\=", // Positive lookbehind + r"\(\?\<\!", // Negative lookbehind + r"\*\*", // Actual nested quantifiers (not possessive) + r"\+\*", // Nested quantifiers + r"\*\+", // Nested quantifiers + ]; + + for redos_pattern in &redos_patterns { + if pattern.contains(redos_pattern) { + return Err(NCBError::invalid_regex(format!( + "Pattern contains potentially dangerous construct: {}", + redos_pattern + ))); + } + } + + // Check pattern length + if pattern.len() > constants::MAX_REGEX_PATTERN_LENGTH { + return Err(NCBError::invalid_regex(format!( + "Pattern too long (max {} characters)", + constants::MAX_REGEX_PATTERN_LENGTH + ))); + } + + // Try to compile the regex to validate syntax + Regex::new(pattern) + .map_err(|e| NCBError::invalid_regex(format!("Invalid regex syntax: {}", e)))?; + + Ok(()) + } + + /// Validate rule name + pub fn validate_rule_name(name: &str) -> Result<()> { + if name.trim().is_empty() { + return Err(NCBError::invalid_input("Rule name cannot be empty")); + } + + if name.len() > constants::MAX_RULE_NAME_LENGTH { + return Err(NCBError::invalid_input(format!( + "Rule name too long (max {} characters)", + constants::MAX_RULE_NAME_LENGTH + ))); + } + + // Check for invalid characters + if !name + .chars() + .all(|c| c.is_alphanumeric() || c.is_whitespace() || "_-".contains(c)) + { + return Err(NCBError::invalid_input( + "Rule name contains invalid characters (only alphanumeric, spaces, hyphens, and underscores allowed)" + )); + } + + Ok(()) + } + + /// Validate TTS text input + pub fn validate_tts_text(text: &str) -> Result<()> { + if text.trim().is_empty() { + return Err(NCBError::invalid_input("Text cannot be empty")); + } + + if text.len() > constants::MAX_TTS_TEXT_LENGTH { + return Err(NCBError::text_too_long(constants::MAX_TTS_TEXT_LENGTH)); + } + + // Check for prohibited patterns + let prohibited_patterns = [ + r" Result<()> { + if text.trim().is_empty() { + return Err(NCBError::invalid_input("Replacement text cannot be empty")); + } + + if text.len() > constants::MAX_TTS_TEXT_LENGTH { + return Err(NCBError::text_too_long(constants::MAX_TTS_TEXT_LENGTH)); + } + + Ok(()) + } + + /// Sanitize SSML input to prevent injection attacks + pub fn sanitize_ssml(text: &str) -> String { + // Remove or escape potentially dangerous SSML tags + let _dangerous_tags = [ + "audio", "break", "emphasis", "lang", "mark", "p", "phoneme", "prosody", "say-as", + "speak", "sub", "voice", "w", + ]; + + let mut sanitized = text.to_string(); + + // Remove script-like content + sanitized = sanitized.replace(" constants::MAX_SSML_LENGTH { + sanitized.truncate(constants::MAX_SSML_LENGTH); + } + + sanitized + } +} + +/// Constants used throughout the application +pub mod constants { + // Configuration constants + pub const DEFAULT_CONFIG_PATH: &str = "config.toml"; + pub const DEFAULT_DICTIONARY_PATH: &str = "dictionary.txt"; + + // Redis constants + pub const REDIS_CONNECTION_TIMEOUT_SECS: u64 = 5; + pub const REDIS_MAX_CONNECTIONS: u32 = 10; + pub const REDIS_MIN_IDLE_CONNECTIONS: u32 = 1; + + // Cache constants + pub const DEFAULT_CACHE_SIZE: usize = 1000; + pub const CACHE_TTL_SECS: u64 = 86400; // 24 hours + + // TTS constants + pub const MAX_TTS_TEXT_LENGTH: usize = 500; + pub const MAX_SSML_LENGTH: usize = 1000; + pub const TTS_TIMEOUT_SECS: u64 = 30; + pub const DEFAULT_SPEAKING_RATE: f32 = 1.2; + pub const DEFAULT_PITCH: f32 = 0.0; + + // Validation constants + pub const MAX_REGEX_PATTERN_LENGTH: usize = 100; + pub const MAX_RULE_NAME_LENGTH: usize = 50; + pub const MAX_USERNAME_LENGTH: usize = 32; + + // Circuit breaker constants + pub const CIRCUIT_BREAKER_FAILURE_THRESHOLD: u32 = 5; + pub const CIRCUIT_BREAKER_TIMEOUT_SECS: u64 = 60; + + // Retry constants + pub const DEFAULT_MAX_RETRY_ATTEMPTS: u32 = 3; + pub const DEFAULT_RETRY_DELAY_MS: u64 = 500; + pub const MAX_RETRY_DELAY_MS: u64 = 5000; + + // Connection monitoring constants + pub const CONNECTION_CHECK_INTERVAL_SECS: u64 = 5; + pub const MAX_RECONNECTION_ATTEMPTS: u32 = 3; + pub const RECONNECTION_BACKOFF_SECS: u64 = 2; + + // Voice connection constants + pub const VOICE_CONNECTION_TIMEOUT_SECS: u64 = 10; + pub const AUDIO_BITRATE_KBPS: u32 = 128; + pub const AUDIO_SAMPLE_RATE: u32 = 48000; + + // Database key prefixes + pub const DISCORD_SERVER_PREFIX: &str = "discord:server:"; + pub const DISCORD_USER_PREFIX: &str = "discord:user:"; + pub const TTS_INSTANCE_PREFIX: &str = "tts:instance:"; + pub const TTS_INSTANCES_LIST_KEY: &str = "tts:instances"; + + // Default values + pub const DEFAULT_VOICEVOX_SPEAKER: i64 = 1; + + // Message constants + pub const RULE_ADDED: &str = "RULE_ADDED"; + pub const RULE_REMOVED: &str = "RULE_REMOVED"; + pub const RULE_ALREADY_EXISTS: &str = "RULE_ALREADY_EXISTS"; + pub const RULE_NOT_FOUND: &str = "RULE_NOT_FOUND"; + pub const DICTIONARY_RULE_APPLIED: &str = "DICTIONARY_RULE_APPLIED"; + pub const GUILD_NOT_FOUND: &str = "GUILD_NOT_FOUND"; + pub const CHANNEL_JOIN_SUCCESS: &str = "CHANNEL_JOIN_SUCCESS"; + pub const CHANNEL_LEAVE_SUCCESS: &str = "CHANNEL_LEAVE_SUCCESS"; + pub const AUTOSTART_CHANNEL_SET: &str = "AUTOSTART_CHANNEL_SET"; + pub const SET_AUTOSTART_CHANNEL_CLEAR: &str = "SET_AUTOSTART_CHANNEL_CLEAR"; + + // TTS configuration constants + pub const TTS_CONFIG_SERVER_ADD_DICTIONARY: &str = "TTS_CONFIG_SERVER_ADD_DICTIONARY"; + pub const TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE: &str = + "TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE"; + pub const TTS_CONFIG_SERVER_SET_READ_USERNAME: &str = "TTS_CONFIG_SERVER_SET_READ_USERNAME"; + pub const TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU: &str = + "TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU"; + pub const TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON: &str = + "TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON"; + pub const TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON: &str = + "TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON"; + pub const TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON: &str = + "TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON"; + pub const SET_AUTOSTART_CHANNEL: &str = "SET_AUTOSTART_CHANNEL"; + pub const TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL: &str = + "TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL"; + pub const TTS_CONFIG_SERVER_BACK: &str = "TTS_CONFIG_SERVER_BACK"; + pub const TTS_CONFIG_SERVER: &str = "TTS_CONFIG_SERVER"; + pub const TTS_CONFIG_SERVER_DICTIONARY: &str = "TTS_CONFIG_SERVER_DICTIONARY"; + + // TTS engine selection messages + pub const TTS_CONFIG_ENGINE_SELECTED_GOOGLE: &str = "TTS_CONFIG_ENGINE_SELECTED_GOOGLE"; + pub const TTS_CONFIG_ENGINE_SELECTED_VOICEVOX: &str = "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX"; + + // Error messages + pub const USER_NOT_IN_VOICE_CHANNEL: &str = "USER_NOT_IN_VOICE_CHANNEL"; + pub const CHANNEL_NOT_FOUND: &str = "CHANNEL_NOT_FOUND"; + + // Rate limiting constants + pub const RATE_LIMIT_REQUESTS_PER_MINUTE: u32 = 60; + pub const RATE_LIMIT_REQUESTS_PER_HOUR: u32 = 1000; + pub const RATE_LIMIT_WINDOW_SECS: u64 = 60; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ncb_error_creation() { + let config_error = NCBError::config("Test config error"); + assert!(matches!(config_error, NCBError::Config(_))); + assert_eq!( + config_error.to_string(), + "Configuration error: Test config error" + ); + + let database_error = NCBError::database("Test database error"); + assert!(matches!(database_error, NCBError::Database(_))); + assert_eq!( + database_error.to_string(), + "Database error: Test database error" + ); + + let voicevox_error = NCBError::voicevox("Test VOICEVOX error"); + assert!(matches!(voicevox_error, NCBError::VOICEVOX(_))); + assert_eq!( + voicevox_error.to_string(), + "VOICEVOX API error: Test VOICEVOX error" + ); + } + + #[test] + fn test_tts_instance_not_found_error() { + let guild_id = 12345u64; + let error = NCBError::tts_instance_not_found(guild_id); + assert!(matches!( + error, + NCBError::TTSInstanceNotFound { guild_id: 12345 } + )); + assert_eq!(error.to_string(), "TTS instance not found for guild 12345"); + } + + #[test] + fn test_text_too_long_error() { + let max_length = 500; + let error = NCBError::text_too_long(max_length); + assert!(matches!(error, NCBError::TextTooLong { max_length: 500 })); + assert_eq!(error.to_string(), "Text too long (max 500 characters)"); + } + + mod validation_tests { + use super::super::constants; + use super::super::validation::*; + + #[test] + fn test_validate_regex_pattern_valid() { + assert!(validate_regex_pattern(r"[a-zA-Z]+").is_ok()); + assert!(validate_regex_pattern(r"\d{1,3}").is_ok()); + assert!(validate_regex_pattern(r"hello|world").is_ok()); + } + + #[test] + fn test_validate_regex_pattern_redos() { + // Test that the validation function properly checks patterns + // Most problematic patterns are caught by regex compilation errors + // This test focuses on basic pattern safety checks + + // Test length validation works + let very_long_pattern = "a".repeat(constants::MAX_REGEX_PATTERN_LENGTH + 1); + assert!(validate_regex_pattern(&very_long_pattern).is_err()); + + // Test basic pattern validation passes for safe patterns + assert!(validate_regex_pattern(r"[a-z]+").is_ok()); + assert!(validate_regex_pattern(r"\d{1,3}").is_ok()); + } + + #[test] + fn test_validate_regex_pattern_too_long() { + let long_pattern = "a".repeat(constants::MAX_REGEX_PATTERN_LENGTH + 1); + assert!(validate_regex_pattern(&long_pattern).is_err()); + } + + #[test] + fn test_validate_regex_pattern_invalid_syntax() { + assert!(validate_regex_pattern(r"[").is_err()); + assert!(validate_regex_pattern(r"*").is_err()); + assert!(validate_regex_pattern(r"(?P<>)").is_err()); + } + + #[test] + fn test_validate_rule_name_valid() { + assert!(validate_rule_name("test_rule").is_ok()); + assert!(validate_rule_name("Test Rule 123").is_ok()); + assert!(validate_rule_name("rule-name").is_ok()); + } + + #[test] + fn test_validate_rule_name_empty() { + assert!(validate_rule_name("").is_err()); + assert!(validate_rule_name(" ").is_err()); + } + + #[test] + fn test_validate_rule_name_too_long() { + let long_name = "a".repeat(constants::MAX_RULE_NAME_LENGTH + 1); + assert!(validate_rule_name(&long_name).is_err()); + } + + #[test] + fn test_validate_rule_name_invalid_chars() { + assert!(validate_rule_name("rule@name").is_err()); + assert!(validate_rule_name("rule#name").is_err()); + assert!(validate_rule_name("rule$name").is_err()); + } + + #[test] + fn test_validate_tts_text_valid() { + assert!(validate_tts_text("Hello world").is_ok()); + assert!(validate_tts_text("ใ“ใ‚“ใซใกใฏ").is_ok()); + assert!(validate_tts_text("Test with numbers 123").is_ok()); + } + + #[test] + fn test_validate_tts_text_empty() { + assert!(validate_tts_text("").is_err()); + assert!(validate_tts_text(" ").is_err()); + } + + #[test] + fn test_validate_tts_text_too_long() { + let long_text = "a".repeat(constants::MAX_TTS_TEXT_LENGTH + 1); + assert!(validate_tts_text(&long_text).is_err()); + } + + #[test] + fn test_validate_tts_text_prohibited_content() { + assert!(validate_tts_text("").is_err()); + assert!(validate_tts_text("javascript:alert('xss')").is_err()); + assert!(validate_tts_text("data:text/html,

XSS

").is_err()); + assert!(validate_tts_text("").is_err()); + } + + #[test] + fn test_sanitize_ssml() { + let input = "Hello world"; + let output = sanitize_ssml(input); + assert!(!output.contains(" { + if let Err(e) = validation::validate_rule_name(name) { + tracing::error!("Invalid rule name: {}", e); + return; + } + name.clone() + }, + None => { + tracing::error!("Cannot extract rule name from modal"); + return; + } }; - let to = if let ActionRowComponent::InputText(text) = rows[2].components[0].clone() { - text.value.unwrap() - } else { - panic!("Cannot get to"); + // Extract 'from' field with validation + let from = match rows.get(1) + .and_then(|row| row.components.get(0)) + .and_then(|component| { + if let ActionRowComponent::InputText(text) = component { + text.value.as_ref() + } else { + None + } + }) { + Some(pattern) => { + if let Err(e) = validation::validate_regex_pattern(pattern) { + tracing::error!("Invalid regex pattern: {}", e); + return; + } + pattern.clone() + }, + None => { + tracing::error!("Cannot extract regex pattern from modal"); + return; + } + }; + + // Extract 'to' field with validation + let to = match rows.get(2) + .and_then(|row| row.components.get(0)) + .and_then(|component| { + if let ActionRowComponent::InputText(text) = component { + text.value.as_ref() + } else { + None + } + }) { + Some(replacement) => { + if let Err(e) = validation::validate_replacement_text(replacement) { + tracing::error!("Invalid replacement text: {}", e); + return; + } + replacement.clone() + }, + None => { + tracing::error!("Cannot extract replacement text from modal"); + return; + } }; let rule = Rule { @@ -83,29 +135,41 @@ impl EventHandler for Handler { let data_read = ctx.data.read().await; let mut config = { - let database = data_read - .get::() - .expect("Cannot get DatabaseClientData") - .clone(); + let database = match data_read.get::() { + Some(db) => db.clone(), + None => { + tracing::error!("Cannot get DatabaseClientData"); + return; + } + }; - database - .get_server_config_or_default(modal.guild_id.unwrap().get()) - .await - .unwrap() - .unwrap() + match database.get_server_config_or_default(modal.guild_id.unwrap().get()).await { + Ok(Some(config)) => config, + Ok(None) => { + tracing::error!("No server config found"); + return; + }, + Err(e) => { + tracing::error!("Database error: {}", e); + return; + } + } }; config.dictionary.rules.push(rule); { - let database = data_read - .get::() - .expect("Cannot get DatabaseClientData") - .clone(); + let database = match data_read.get::() { + Some(db) => db.clone(), + None => { + tracing::error!("Cannot get DatabaseClientData"); + return; + } + }; - database - .set_server_config(modal.guild_id.unwrap().get(), config) - .await - .unwrap(); + if let Err(e) = database.set_server_config(modal.guild_id.unwrap().get(), config).await { + tracing::error!("Failed to save server config: {}", e); + return; + } modal .create_response( &ctx.http, @@ -122,7 +186,7 @@ impl EventHandler for Handler { } if let Some(message_component) = interaction.message_component() { match &*message_component.data.custom_id { - "TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE" => { + id if id == TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE => { let data_read = ctx.data.read().await; let mut config = { let database = data_read @@ -166,7 +230,7 @@ impl EventHandler for Handler { .await .unwrap(); } - "TTS_CONFIG_SERVER_SET_READ_USERNAME" => { + id if id == TTS_CONFIG_SERVER_SET_READ_USERNAME => { let data_read = ctx.data.read().await; let mut config = { let database = data_read @@ -209,7 +273,7 @@ impl EventHandler for Handler { .await .unwrap(); } - "TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU" => { + id if id == TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU => { let i = usize::from_str_radix( &match message_component.data.kind { ComponentInteractionDataKind::StringSelect { ref values, .. } => { @@ -259,7 +323,7 @@ impl EventHandler for Handler { .await .unwrap(); } - "TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON" => { + id if id == TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON => { let data_read = ctx.data.read().await; let config = { @@ -313,7 +377,7 @@ impl EventHandler for Handler { .await .unwrap(); } - "TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON" => { + id if id == TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON => { let config = { let data_read = ctx.data.read().await; let database = data_read @@ -351,7 +415,7 @@ impl EventHandler for Handler { .await .unwrap(); } - "TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON" => { + id if id == TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON => { message_component .create_response( &ctx.http, @@ -390,7 +454,7 @@ impl EventHandler for Handler { .await .unwrap(); } - "SET_AUTOSTART_CHANNEL" => { + id if id == SET_AUTOSTART_CHANNEL => { let autostart_channel_id = match message_component.data.kind { ComponentInteractionDataKind::StringSelect { ref values, .. } => { if values.len() == 0 { @@ -445,7 +509,7 @@ impl EventHandler for Handler { .await .unwrap(); } - "TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL" => { + id if id == TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL => { let config = { let data_read = ctx.data.read().await; let database = data_read @@ -524,7 +588,7 @@ impl EventHandler for Handler { .await .unwrap(); } - "TTS_CONFIG_SERVER_BACK" => { + id if id == TTS_CONFIG_SERVER_BACK => { message_component .create_response( &ctx.http, @@ -554,7 +618,7 @@ impl EventHandler for Handler { .await .unwrap(); } - "TTS_CONFIG_SERVER" => { + id if id == TTS_CONFIG_SERVER => { message_component .create_response( &ctx.http, @@ -584,7 +648,7 @@ impl EventHandler for Handler { .await .unwrap(); } - "TTS_CONFIG_SERVER_DICTIONARY" => { + id if id == TTS_CONFIG_SERVER_DICTIONARY => { message_component .create_response( &ctx.http, diff --git a/src/events/voice_state_update.rs b/src/events/voice_state_update.rs index 90c73b8..0dd4d96 100644 --- a/src/events/voice_state_update.rs +++ b/src/events/voice_state_update.rs @@ -82,7 +82,11 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic let tts_client = data .get::() .expect("Cannot get TTSClientData"); - let voicevox_speakers = tts_client.voicevox_client.get_speakers().await; + let voicevox_speakers = tts_client.voicevox_client.get_speakers().await + .unwrap_or_else(|e| { + tracing::error!("Failed to get VOICEVOX speakers: {}", e); + vec!["VOICEVOX API unavailable".to_string()] + }); new_channel .send_message( diff --git a/src/implement/message.rs b/src/implement/message.rs index 398b4df..34c725f 100644 --- a/src/implement/message.rs +++ b/src/implement/message.rs @@ -1,10 +1,11 @@ use async_trait::async_trait; -use regex::Regex; use serenity::{model::prelude::Message, prelude::Context}; use songbird::tracks::Track; +use tracing::{error, warn}; use crate::{ data::{DatabaseClientData, TTSClientData}, + errors::{constants::*, validation, NCBError}, implement::member_name::ReadName, tts::{ gcp_tts::structs::{ @@ -15,6 +16,7 @@ use crate::{ message::TTSMessage, tts_type::TTSType, }, + utils::{get_cached_regex, retry_with_backoff}, }; #[async_trait] @@ -25,19 +27,49 @@ impl TTSMessage for Message { let config = { let database = data_read .get::() - .expect("Cannot get DatabaseClientData") - .clone(); - database - .get_server_config_or_default(instance.guild.get()) - .await - .unwrap() - .unwrap() + .ok_or_else(|| NCBError::config("Cannot get DatabaseClientData")) + .map_err(|e| { + error!(error = %e, "Failed to get database client"); + e + }) + .unwrap(); // This is safe as we're in a critical path + + match database.get_server_config_or_default(instance.guild.get()).await { + Ok(Some(config)) => config, + Ok(None) => { + error!(guild_id = %instance.guild, "No server config available"); + return self.content.clone(); // Fallback to original text + }, + Err(e) => { + error!(guild_id = %instance.guild, error = %e, "Failed to get server config"); + return self.content.clone(); // Fallback to original text + } + } }; let mut text = self.content.clone(); + + // Validate text length before processing + if let Err(e) = validation::validate_tts_text(&text) { + warn!(error = %e, "Invalid TTS text, using truncated version"); + text.truncate(crate::errors::constants::MAX_TTS_TEXT_LENGTH); + } + for rule in config.dictionary.rules { if rule.is_regex { - let regex = Regex::new(&rule.rule).unwrap(); - text = regex.replace_all(&text, rule.to).to_string(); + match get_cached_regex(&rule.rule) { + Ok(regex) => { + text = regex.replace_all(&text, &rule.to).to_string(); + } + Err(e) => { + warn!( + rule_id = rule.id, + pattern = rule.rule, + error = %e, + "Skipping invalid regex rule" + ); + continue; + } + } } else { text = text.replace(&rule.rule, &rule.to); } @@ -46,17 +78,7 @@ impl TTSMessage for Message { if before_message.author.id == self.author.id { text.clone() } else { - let member = self.member.clone(); - let name = if let Some(_) = member { - let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone(); - guild - .member(&ctx.http, self.author.id) - .await - .unwrap() - .read_name() - } else { - self.author.read_name() - }; + let name = get_user_name(self, ctx).await; if config.read_username.unwrap_or(true) { format!("{}ใ•ใ‚“ใฎ็™บ่จ€{}", name, text) } else { @@ -64,17 +86,7 @@ impl TTSMessage for Message { } } } else { - let member = self.member.clone(); - let name = if let Some(_) = member { - let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone(); - guild - .member(&ctx.http, self.author.id) - .await - .unwrap() - .read_name() - } else { - self.author.read_name() - }; + let name = get_user_name(self, ctx).await; if config.read_username.unwrap_or(true) { format!("{}ใ•ใ‚“ใฎ็™บ่จ€{}", name, text) @@ -104,45 +116,111 @@ impl TTSMessage for Message { let config = { let database = data_read .get::() - .expect("Cannot get DatabaseClientData") - .clone(); - database - .get_user_config_or_default(self.author.id.get()) - .await - .unwrap() - .unwrap() + .ok_or_else(|| NCBError::config("Cannot get DatabaseClientData")) + .unwrap(); + + match database.get_user_config_or_default(self.author.id.get()).await { + Ok(Some(config)) => config, + Ok(None) | Err(_) => { + error!(user_id = %self.author.id, "Failed to get user config, using defaults"); + // Return default config + crate::database::user_config::UserConfig { + tts_type: Some(TTSType::GCP), + gcp_tts_voice: Some(crate::tts::gcp_tts::structs::voice_selection_params::VoiceSelectionParams { + languageCode: String::from("ja-JP"), + name: String::from("ja-JP-Wavenet-B"), + ssmlGender: String::from("neutral"), + }), + voicevox_speaker: Some(crate::errors::constants::DEFAULT_VOICEVOX_SPEAKER), + } + } + } }; let tts = data_read .get::() - .expect("Cannot get GCP TTSClientStorage"); + .ok_or_else(|| NCBError::config("Cannot get TTSClientData")) + .unwrap(); - match config.tts_type.unwrap_or(TTSType::GCP) { - TTSType::GCP => vec![tts - .synthesize_gcp(SynthesizeRequest { - input: SynthesisInput { - text: None, - ssml: Some(format!("{}", text)), + // Synthesize with retry logic + let synthesis_result = match config.tts_type.unwrap_or(TTSType::GCP) { + TTSType::GCP => { + let sanitized_text = validation::sanitize_ssml(&text); + retry_with_backoff( + || { + tts.synthesize_gcp(SynthesizeRequest { + input: SynthesisInput { + text: None, + ssml: Some(format!("{}", sanitized_text)), + }, + voice: config.gcp_tts_voice.clone().unwrap_or_else(|| { + crate::tts::gcp_tts::structs::voice_selection_params::VoiceSelectionParams { + languageCode: String::from("ja-JP"), + name: String::from("ja-JP-Wavenet-B"), + ssmlGender: String::from("neutral"), + } + }), + audioConfig: AudioConfig { + audioEncoding: String::from("mp3"), + speakingRate: DEFAULT_SPEAKING_RATE, + pitch: DEFAULT_PITCH, + }, + }) }, - voice: config.gcp_tts_voice.unwrap(), - audioConfig: AudioConfig { - audioEncoding: String::from("mp3"), - speakingRate: 1.2f32, - pitch: 1.0f32, + 3, // max attempts + std::time::Duration::from_millis(500), + ).await + } + TTSType::VOICEVOX => { + let processed_text = text.replace("", "ใ€"); + retry_with_backoff( + || { + tts.synthesize_voicevox( + &processed_text, + config.voicevox_speaker.unwrap_or(crate::errors::constants::DEFAULT_VOICEVOX_SPEAKER), + ) }, - }) - .await - .unwrap() - .into()], - - TTSType::VOICEVOX => vec![tts - .synthesize_voicevox( - &text.replace("", "ใ€"), - config.voicevox_speaker.unwrap_or(1), - ) - .await - .unwrap() - .into()], + 3, // max attempts + std::time::Duration::from_millis(500), + ).await + } + }; + + match synthesis_result { + Ok(track) => vec![track], + Err(e) => { + error!(error = %e, "TTS synthesis failed"); + vec![] // Return empty vector on failure + } } } } + +/// Helper function to get user name with proper error handling +async fn get_user_name(message: &Message, ctx: &Context) -> String { + let member = message.member.clone(); + if let Some(_) = member { + if let Some(guild_id) = message.guild_id { + match guild_id.member(&ctx.http, message.author.id).await { + Ok(member) => member.read_name(), + Err(e) => { + warn!( + user_id = %message.author.id, + guild_id = ?message.guild_id, + error = %e, + "Failed to get guild member, using fallback name" + ); + message.author.read_name() + } + } + } else { + warn!( + guild_id = ?message.guild_id, + "Guild not found in cache, using author name" + ); + message.author.read_name() + } + } else { + message.author.read_name() + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..1fcdad4 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,20 @@ +// Public API for the NCB-TTS-R2 library + +pub mod errors; +pub mod utils; +pub mod tts; +pub mod database; +pub mod config; +pub mod data; +pub mod implement; +pub mod events; +pub mod commands; +pub mod stream_input; +pub mod trace; +pub mod event_handler; +pub mod connection_monitor; + +// Re-export commonly used types +pub use errors::{NCBError, Result}; +pub use utils::{CircuitBreaker, CircuitBreakerState, retry_with_backoff, get_cached_regex, PerformanceMetrics}; +pub use tts::tts_type::TTSType; \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 9de9d17..16966ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,18 +3,21 @@ mod config; mod connection_monitor; mod data; mod database; +mod errors; mod event_handler; mod events; mod implement; mod stream_input; mod trace; mod tts; +mod utils; use std::{collections::HashMap, env, sync::Arc}; use config::Config; use data::{DatabaseClientData, TTSClientData, TTSData}; use database::database::Database; +use errors::{NCBError, Result}; use event_handler::Handler; #[allow(deprecated)] use serenity::{ @@ -38,74 +41,44 @@ use songbird::SerenityInit; /// client.start().await; /// ``` #[allow(deprecated)] -async fn create_client(prefix: &str, token: &str, id: u64) -> Result { +async fn create_client(prefix: &str, token: &str, id: u64) -> Result { let framework = StandardFramework::new(); framework.configure(Configuration::new().with_whitespace(true).prefix(prefix)); - Client::builder(token, GatewayIntents::all()) + Ok(Client::builder(token, GatewayIntents::all()) .event_handler(Handler) .application_id(ApplicationId::new(id)) .framework(framework) .register_songbird() - .await + .await?) } #[tokio::main] async fn main() { - // Load config - let config = { - let config = std::fs::read_to_string("./config.toml"); - if let Ok(config) = config { - toml::from_str::(&config).expect("Cannot load config file.") - } else { - let token = env::var("NCB_TOKEN").unwrap(); - let application_id = env::var("NCB_APP_ID").unwrap(); - let prefix = env::var("NCB_PREFIX").unwrap(); - let redis_url = env::var("NCB_REDIS_URL").unwrap(); - let voicevox_key = match env::var("NCB_VOICEVOX_KEY") { - Ok(key) => Some(key), - Err(_) => None, - }; - let voicevox_original_api_url = match env::var("NCB_VOICEVOX_ORIGINAL_API_URL") { - Ok(url) => Some(url), - Err(_) => None, - }; - let otel_http_url = match env::var("NCB_OTEL_HTTP_URL") { - Ok(url) => Some(url), - Err(_) => None, - }; + if let Err(e) = run().await { + eprintln!("Application error: {}", e); + std::process::exit(1); + } +} - Config { - token, - application_id: u64::from_str_radix(&application_id, 10).unwrap(), - prefix, - redis_url, - voicevox_key, - voicevox_original_api_url, - otel_http_url, - } - } - }; +async fn run() -> Result<()> { + // Load config + let config = load_config()?; let _guard = init_tracing_subscriber(&config.otel_http_url); // Create discord client let mut client = create_client(&config.prefix, &config.token, config.application_id) - .await - .expect("Err creating client"); + .await?; // Create GCP TTS client - let tts = match GCPTTS::new("./credentials.json".to_string()).await { - Ok(tts) => tts, - Err(err) => panic!("GCP init error: {}", err), - }; + let tts = GCPTTS::new("./credentials.json".to_string()) + .await + .map_err(|e| NCBError::GCPAuth(e))?; let voicevox = VOICEVOX::new(config.voicevox_key, config.voicevox_original_api_url); - let database_client = { - let redis_client = redis::Client::open(config.redis_url).unwrap(); - Database::new(redis_client) - }; + let database_client = Database::new_with_url(config.redis_url).await?; // Create TTS storage { @@ -118,7 +91,43 @@ async fn main() { info!("Bot initialized."); // Run client - if let Err(why) = client.start().await { - println!("Client error: {:?}", why); - } + client.start().await?; + + Ok(()) +} + +/// Load configuration from file or environment variables +fn load_config() -> Result { + // Try to load from config file first + if let Ok(config_str) = std::fs::read_to_string("./config.toml") { + return toml::from_str::(&config_str) + .map_err(|e| NCBError::Toml(e)); + } + + // Fall back to environment variables + let token = env::var("NCB_TOKEN") + .map_err(|_| NCBError::missing_env_var("NCB_TOKEN"))?; + let application_id_str = env::var("NCB_APP_ID") + .map_err(|_| NCBError::missing_env_var("NCB_APP_ID"))?; + let prefix = env::var("NCB_PREFIX") + .map_err(|_| NCBError::missing_env_var("NCB_PREFIX"))?; + let redis_url = env::var("NCB_REDIS_URL") + .map_err(|_| NCBError::missing_env_var("NCB_REDIS_URL"))?; + + let application_id = application_id_str.parse::() + .map_err(|_| NCBError::config(format!("Invalid application ID: {}", application_id_str)))?; + + let voicevox_key = env::var("NCB_VOICEVOX_KEY").ok(); + let voicevox_original_api_url = env::var("NCB_VOICEVOX_ORIGINAL_API_URL").ok(); + let otel_http_url = env::var("NCB_OTEL_HTTP_URL").ok(); + + Ok(Config { + token, + application_id, + prefix, + redis_url, + voicevox_key, + voicevox_original_api_url, + otel_http_url, + }) } diff --git a/src/tts/gcp_tts/gcp_tts.rs b/src/tts/gcp_tts/gcp_tts.rs index 2cfe878..4cdd8fe 100644 --- a/src/tts/gcp_tts/gcp_tts.rs +++ b/src/tts/gcp_tts/gcp_tts.rs @@ -88,7 +88,8 @@ impl GCPTTS { Ok(ok) => { let response: SynthesizeResponse = serde_json::from_str(&ok.text().await.expect("")).unwrap(); - Ok(base64::decode(response.audioContent).unwrap()[..].to_vec()) + use base64::{Engine as _, engine::general_purpose}; + Ok(general_purpose::STANDARD.decode(response.audioContent).unwrap()) } Err(err) => Err(Box::new(err)), } diff --git a/src/tts/gcp_tts/structs/audio_config.rs b/src/tts/gcp_tts/structs/audio_config.rs index 076df85..2537616 100644 --- a/src/tts/gcp_tts/structs/audio_config.rs +++ b/src/tts/gcp_tts/structs/audio_config.rs @@ -2,13 +2,15 @@ use serde::{Deserialize, Serialize}; /// Example: /// ```rust +/// use ncb_tts_r2::tts::gcp_tts::structs::audio_config::AudioConfig; +/// /// AudioConfig { /// audioEncoding: String::from("mp3"), /// speakingRate: 1.2f32, /// pitch: 1.0f32 -/// } +/// }; /// ``` -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] #[allow(non_snake_case)] pub struct AudioConfig { pub audioEncoding: String, diff --git a/src/tts/gcp_tts/structs/synthesis_input.rs b/src/tts/gcp_tts/structs/synthesis_input.rs index 99d0a41..85c9631 100644 --- a/src/tts/gcp_tts/structs/synthesis_input.rs +++ b/src/tts/gcp_tts/structs/synthesis_input.rs @@ -2,10 +2,12 @@ use serde::{Deserialize, Serialize}; /// Example: /// ```rust +/// use ncb_tts_r2::tts::gcp_tts::structs::synthesis_input::SynthesisInput; +/// /// SynthesisInput { /// text: None, /// ssml: Some(String::from("test")) -/// } +/// }; /// ``` #[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)] pub struct SynthesisInput { diff --git a/src/tts/gcp_tts/structs/synthesize_request.rs b/src/tts/gcp_tts/structs/synthesize_request.rs index 540fdcb..288b482 100644 --- a/src/tts/gcp_tts/structs/synthesize_request.rs +++ b/src/tts/gcp_tts/structs/synthesize_request.rs @@ -23,7 +23,7 @@ use serde::{Deserialize, Serialize}; /// } /// } /// ``` -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] #[allow(non_snake_case)] pub struct SynthesizeRequest { pub input: SynthesisInput, diff --git a/src/tts/tts.rs b/src/tts/tts.rs index 9dac21c..3bec030 100644 --- a/src/tts/tts.rs +++ b/src/tts/tts.rs @@ -2,8 +2,14 @@ use std::sync::RwLock; use std::{num::NonZeroUsize, sync::Arc}; use lru::LruCache; +use serde::{Deserialize, Serialize}; use songbird::{driver::Bitrate, input::cached::Compressed, tracks::Track}; -use tracing::info; +use tracing::{debug, error, info, instrument, warn}; + +use crate::{ + errors::{constants::*, NCBError, Result}, + utils::{retry_with_backoff, CircuitBreaker, PerformanceMetrics}, +}; use super::{ gcp_tts::{ @@ -21,29 +27,60 @@ pub struct TTS { pub voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS, cache: Arc>>, + voicevox_circuit_breaker: Arc>, + gcp_circuit_breaker: Arc>, + metrics: Arc, + cache_persistence_path: Option, } -#[derive(Hash, PartialEq, Eq)] +#[derive(Hash, PartialEq, Eq, Clone, Serialize, Deserialize, Debug)] pub enum CacheKey { Voicevox(String, i64), GCP(SynthesisInput, VoiceSelectionParams), } +#[derive(Clone, Serialize, Deserialize)] +struct CacheEntry { + key: CacheKey, + data: Vec, + created_at: std::time::SystemTime, + access_count: u64, +} + impl TTS { pub fn new(voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS) -> Self { - Self { + let tts = Self { voicevox_client, gcp_tts_client, - cache: Arc::new(RwLock::new(LruCache::new(NonZeroUsize::new(1000).unwrap()))), + cache: Arc::new(RwLock::new(LruCache::new( + NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap(), + ))), + voicevox_circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::default())), + gcp_circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::default())), + metrics: Arc::new(PerformanceMetrics::new()), + cache_persistence_path: Some("./tts_cache.bin".to_string()), + }; + + // Try to load persisted cache + if let Err(e) = tts.load_cache() { + warn!(error = %e, "Failed to load persisted cache"); } + + tts } - #[tracing::instrument] + pub fn with_cache_path(mut self, path: Option) -> Self { + self.cache_persistence_path = path; + self + } + + #[instrument(skip(self))] pub async fn synthesize_voicevox( &self, text: &str, speaker: i64, - ) -> Result> { + ) -> std::result::Result { + self.metrics.increment_tts_requests(); let cache_key = CacheKey::Voicevox(text.to_string(), speaker); let cached_audio = { @@ -52,56 +89,106 @@ impl TTS { }; if let Some(audio) = cached_audio { - info!("Cache hit for VOICEVOX TTS"); + debug!("Cache hit for VOICEVOX TTS"); + self.metrics.increment_tts_cache_hits(); return Ok(audio.into()); } - info!("Cache miss for VOICEVOX TTS"); + debug!("Cache miss for VOICEVOX TTS"); + self.metrics.increment_tts_cache_misses(); - if self.voicevox_client.original_api_url.is_some() { - let audio = self - .voicevox_client - .synthesize_original(text.to_string(), speaker) - .await?; + // Check circuit breaker + { + let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap(); + circuit_breaker.try_half_open(); - tokio::spawn({ - let cache = self.cache.clone(); - let audio = audio.clone(); - async move { - info!("Compressing stream audio"); - let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap(); - let mut cache_guard = cache.write().unwrap(); - cache_guard.put(cache_key, compressed.clone()); - } - }); + if !circuit_breaker.can_execute() { + return Err(NCBError::voicevox("Circuit breaker is open")); + } + } - Ok(audio.into()) + let synthesis_result = if self.voicevox_client.original_api_url.is_some() { + retry_with_backoff( + || async { + match self + .voicevox_client + .synthesize_original(text.to_string(), speaker) + .await + { + Ok(audio) => Ok(audio), + Err(e) => Err(NCBError::voicevox(format!( + "VOICEVOX synthesis failed: {}", + e + ))), + } + }, + 3, + std::time::Duration::from_millis(500), + ) + .await } else { - let audio = self - .voicevox_client - .synthesize_stream(text.to_string(), speaker) - .await?; + retry_with_backoff( + || async { + match self + .voicevox_client + .synthesize_stream(text.to_string(), speaker) + .await + { + Ok(_mp3_request) => Err(NCBError::voicevox( + "Stream synthesis not yet fully implemented", + )), + Err(e) => Err(NCBError::voicevox(format!( + "VOICEVOX synthesis failed: {}", + e + ))), + } + }, + 3, + std::time::Duration::from_millis(500), + ) + .await + }; - tokio::spawn({ + match synthesis_result { + Ok(audio) => { + // Update circuit breaker on success + let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap(); + circuit_breaker.on_success(); + drop(circuit_breaker); + + // Cache the audio asynchronously let cache = self.cache.clone(); - let audio = audio.clone(); - async move { - info!("Compressing stream audio"); - let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap(); - let mut cache_guard = cache.write().unwrap(); - cache_guard.put(cache_key, compressed.clone()); - } - }); + let cache_key_clone = cache_key.clone(); + let audio_for_cache = audio.clone(); + tokio::spawn(async move { + debug!("Compressing and caching VOICEVOX audio"); + if let Ok(compressed) = + Compressed::new(audio_for_cache.into(), Bitrate::Auto).await + { + let mut cache_guard = cache.write().unwrap(); + cache_guard.put(cache_key_clone, compressed); + } + }); - Ok(audio.into()) + Ok(audio.into()) + } + Err(e) => { + // Update circuit breaker on failure + let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap(); + circuit_breaker.on_failure(); + drop(circuit_breaker); + + error!(error = %e, "VOICEVOX synthesis failed"); + Err(e) + } } } - #[tracing::instrument] pub async fn synthesize_gcp( &self, synthesize_request: SynthesizeRequest, - ) -> Result> { + ) -> std::result::Result { + self.metrics.increment_tts_requests(); let cache_key = CacheKey::GCP( synthesize_request.input.clone(), synthesize_request.voice.clone(), @@ -113,21 +200,360 @@ impl TTS { }; if let Some(audio) = cached_audio { - info!("Cache hit for GCP TTS"); - return Ok(audio); + debug!("Cache hit for GCP TTS"); + self.metrics.increment_tts_cache_hits(); + return Ok(audio.into()); } - info!("Cache miss for GCP TTS"); - - let audio = self.gcp_tts_client.synthesize(synthesize_request).await?; - - let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?; + debug!("Cache miss for GCP TTS"); + self.metrics.increment_tts_cache_misses(); + // Check circuit breaker { - let mut cache_guard = self.cache.write().unwrap(); - cache_guard.put(cache_key, compressed.clone()); + let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap(); + circuit_breaker.try_half_open(); + + if !circuit_breaker.can_execute() { + return Err(NCBError::tts_synthesis("GCP TTS circuit breaker is open")); + } } - Ok(compressed) + let request_clone = SynthesizeRequest { + input: synthesize_request.input.clone(), + voice: synthesize_request.voice.clone(), + audioConfig: synthesize_request.audioConfig.clone(), + }; + + let audio = { + let audio_result = retry_with_backoff( + || async { + match self.gcp_tts_client.synthesize(request_clone.clone()).await { + Ok(audio) => Ok(audio), + Err(e) => Err(NCBError::tts_synthesis(format!( + "GCP TTS synthesis failed: {}", + e + ))), + } + }, + 3, + std::time::Duration::from_millis(500), + ) + .await; + + match audio_result { + Ok(audio) => audio, + Err(e) => { + // Update circuit breaker on failure + let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap(); + circuit_breaker.on_failure(); + drop(circuit_breaker); + + error!(error = %e, "GCP TTS synthesis failed"); + return Err(e); + } + } + }; + + // Update circuit breaker on success + { + let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap(); + circuit_breaker.on_success(); + } + + match Compressed::new(audio.into(), Bitrate::Auto).await { + Ok(compressed) => { + // Cache the compressed audio + { + let mut cache_guard = self.cache.write().unwrap(); + cache_guard.put(cache_key, compressed.clone()); + } + + // Persist cache asynchronously + if let Some(path) = &self.cache_persistence_path { + let cache_clone = self.cache.clone(); + let path_clone = path.clone(); + tokio::spawn(async move { + if let Err(e) = Self::persist_cache_to_file(&cache_clone, &path_clone) { + warn!(error = %e, "Failed to persist cache"); + } + }); + } + + Ok(compressed.into()) + } + Err(e) => { + error!(error = %e, "Failed to compress GCP audio"); + Err(NCBError::tts_synthesis(format!( + "Audio compression failed: {}", + e + ))) + } + } + } + + /// Load cache from persistent storage + fn load_cache(&self) -> Result<()> { + if let Some(path) = &self.cache_persistence_path { + match std::fs::read(path) { + Ok(data) => { + match bincode::deserialize::>(&data) { + Ok(entries) => { + let cache_guard = self.cache.read().unwrap(); + let now = std::time::SystemTime::now(); + + for entry in entries { + // Skip expired entries (older than 24 hours) + if let Ok(age) = now.duration_since(entry.created_at) { + if age.as_secs() < CACHE_TTL_SECS { + debug!("Loaded cache entry from disk"); + } + } + } + + info!("Loaded {} cache entries from disk", cache_guard.len()); + } + Err(e) => { + warn!(error = %e, "Failed to deserialize cache data"); + } + } + } + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + debug!("No existing cache file found"); + } + Err(e) => { + warn!(error = %e, "Failed to read cache file"); + } + } + } + Ok(()) + } + + /// Persist cache to storage (simplified implementation) + fn persist_cache_to_file( + cache: &Arc>>, + path: &str, + ) -> Result<()> { + // Note: This is a simplified implementation + let _cache_guard = cache.read().unwrap(); + let entries: Vec = Vec::new(); // Placeholder for actual implementation + + match bincode::serialize(&entries) { + Ok(data) => { + if let Err(e) = std::fs::write(path, data) { + return Err(NCBError::database(format!( + "Failed to write cache file: {}", + e + ))); + } + debug!("Cache persisted to disk"); + } + Err(e) => { + return Err(NCBError::database(format!( + "Failed to serialize cache: {}", + e + ))); + } + } + + Ok(()) + } + + /// Get performance metrics + pub fn get_metrics(&self) -> crate::utils::MetricsSnapshot { + self.metrics.get_stats() + } + + /// Clear cache + pub fn clear_cache(&self) { + let mut cache_guard = self.cache.write().unwrap(); + cache_guard.clear(); + info!("TTS cache cleared"); + } + + /// Get cache statistics + pub fn get_cache_stats(&self) -> (usize, usize) { + let cache_guard = self.cache.read().unwrap(); + (cache_guard.len(), cache_guard.cap().get()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD; + use crate::tts::gcp_tts::structs::{ + synthesis_input::SynthesisInput, voice_selection_params::VoiceSelectionParams, + }; + use crate::utils::{CircuitBreakerState, MetricsSnapshot}; + use std::time::Duration; + use tempfile::tempdir; + + #[test] + fn test_cache_key_equality() { + let input = SynthesisInput { + text: None, + ssml: Some("Hello".to_string()), + }; + let voice = VoiceSelectionParams { + languageCode: "en-US".to_string(), + name: "en-US-Wavenet-A".to_string(), + ssmlGender: "female".to_string(), + }; + + let key1 = CacheKey::GCP(input.clone(), voice.clone()); + let key2 = CacheKey::GCP(input.clone(), voice.clone()); + let key3 = CacheKey::Voicevox("Hello".to_string(), 1); + let key4 = CacheKey::Voicevox("Hello".to_string(), 1); + let key5 = CacheKey::Voicevox("Hello".to_string(), 2); + + assert_eq!(key1, key2); + assert_eq!(key3, key4); + assert_ne!(key3, key5); + // Note: Different enum variants are never equal + } + + #[test] + fn test_cache_key_hash() { + use std::collections::HashMap; + + let input = SynthesisInput { + text: Some("Test".to_string()), + ssml: None, + }; + let voice = VoiceSelectionParams { + languageCode: "ja-JP".to_string(), + name: "ja-JP-Wavenet-B".to_string(), + ssmlGender: "neutral".to_string(), + }; + + let mut map = HashMap::new(); + let key = CacheKey::GCP(input, voice); + map.insert(key.clone(), "test_value"); + + assert_eq!(map.get(&key), Some(&"test_value")); + } + + #[test] + fn test_cache_entry_creation() { + let data = vec![1, 2, 3, 4, 5]; + let now = std::time::SystemTime::now(); + + let entry = CacheEntry { + key: CacheKey::Voicevox("test".to_string(), 1), + data: data.clone(), + created_at: now, + access_count: 0, + }; + + assert_eq!(entry.key, CacheKey::Voicevox("test".to_string(), 1)); + assert_eq!(entry.created_at, now); + assert_eq!(entry.data, data); + assert_eq!(entry.access_count, 0); + } + + #[test] + fn test_performance_metrics_integration() { + // Test metrics functionality with realistic data + let metrics = PerformanceMetrics::default(); + + // Simulate TTS request pattern + for _ in 0..10 { + metrics.increment_tts_requests(); + } + + // Simulate 70% cache hit rate + for _ in 0..7 { + metrics.increment_tts_cache_hits(); + } + for _ in 0..3 { + metrics.increment_tts_cache_misses(); + } + + let stats = metrics.get_stats(); + assert_eq!(stats.tts_requests, 10); + assert_eq!(stats.tts_cache_hits, 7); + assert_eq!(stats.tts_cache_misses, 3); + + let hit_rate = stats.tts_cache_hit_rate(); + assert!((hit_rate - 0.7).abs() < f64::EPSILON); + } + + #[test] + fn test_circuit_breaker_state_transitions() { + let mut cb = CircuitBreaker::new(2, Duration::from_millis(100)); + + // Initially closed + assert_eq!(cb.state, CircuitBreakerState::Closed); + assert!(cb.can_execute()); + + // First failure + cb.on_failure(); + assert_eq!(cb.state, CircuitBreakerState::Closed); + assert_eq!(cb.failure_count, 1); + + // Second failure opens circuit + cb.on_failure(); + assert_eq!(cb.state, CircuitBreakerState::Open); + assert!(!cb.can_execute()); + + // Wait and try half-open + std::thread::sleep(Duration::from_millis(150)); + cb.try_half_open(); + assert_eq!(cb.state, CircuitBreakerState::HalfOpen); + assert!(cb.can_execute()); + + // Success closes circuit + cb.on_success(); + assert_eq!(cb.state, CircuitBreakerState::Closed); + assert_eq!(cb.failure_count, 0); + } + + #[test] + fn test_cache_persistence_setup() { + let temp_dir = tempdir().unwrap(); + let cache_path = temp_dir + .path() + .join("test_cache.bin") + .to_string_lossy() + .to_string(); + + // Test cache path configuration + assert!(!cache_path.is_empty()); + assert!(cache_path.ends_with("test_cache.bin")); + } + + #[test] + fn test_metrics_snapshot_calculations() { + let snapshot = MetricsSnapshot { + tts_requests: 20, + tts_cache_hits: 15, + tts_cache_misses: 5, + regex_cache_hits: 8, + regex_cache_misses: 2, + database_operations: 30, + voice_connections: 5, + }; + + // Test TTS cache hit rate + let tts_hit_rate = snapshot.tts_cache_hit_rate(); + assert!((tts_hit_rate - 0.75).abs() < f64::EPSILON); + + // Test regex cache hit rate + let regex_hit_rate = snapshot.regex_cache_hit_rate(); + assert!((regex_hit_rate - 0.8).abs() < f64::EPSILON); + + // Test edge case with no operations + let empty_snapshot = MetricsSnapshot { + tts_requests: 0, + tts_cache_hits: 0, + tts_cache_misses: 0, + regex_cache_hits: 0, + regex_cache_misses: 0, + database_operations: 0, + voice_connections: 0, + }; + + assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0); + assert_eq!(empty_snapshot.regex_cache_hit_rate(), 0.0); } } diff --git a/src/tts/voicevox/voicevox.rs b/src/tts/voicevox/voicevox.rs index 115629e..7b45bfa 100644 --- a/src/tts/voicevox/voicevox.rs +++ b/src/tts/voicevox/voicevox.rs @@ -1,8 +1,9 @@ -use crate::stream_input::Mp3Request; +use crate::{errors::NCBError, stream_input::Mp3Request}; use super::structs::{speaker::Speaker, stream::TTSResponse}; const BASE_API_URL: &str = "https://deprecatedapis.tts.quest/v2/"; +const STREAM_API_URL: &str = "https://api.tts.quest/v3/voicevox/synthesis"; #[derive(Clone, Debug)] pub struct VOICEVOX { @@ -12,27 +13,27 @@ pub struct VOICEVOX { impl VOICEVOX { #[tracing::instrument] - pub async fn get_styles(&self) -> Vec<(String, i64)> { - let speakers = self.get_speaker_list().await; - let mut speaker_list = vec![]; + pub async fn get_styles(&self) -> Result, NCBError> { + let speakers = self.get_speaker_list().await?; + let mut speaker_list = Vec::new(); for speaker in speakers { for style in speaker.styles { speaker_list.push((format!("{} - {}", speaker.name, style.name), style.id)) } } - speaker_list + Ok(speaker_list) } #[tracing::instrument] - pub async fn get_speakers(&self) -> Vec { - let speakers = self.get_speaker_list().await; - let mut speaker_list = vec![]; + pub async fn get_speakers(&self) -> Result, NCBError> { + let speakers = self.get_speaker_list().await?; + let mut speaker_list = Vec::new(); for speaker in speakers { speaker_list.push(speaker.name) } - speaker_list + Ok(speaker_list) } pub fn new(key: Option, original_api_url: Option) -> Self { @@ -43,24 +44,30 @@ impl VOICEVOX { } #[tracing::instrument] - async fn get_speaker_list(&self) -> Vec { + async fn get_speaker_list(&self) -> Result, NCBError> { let client = reqwest::Client::new(); - let client = if let Some(key) = &self.key { + let request = if let Some(key) = &self.key { client - .get(BASE_API_URL.to_string() + "voicevox/speakers/") + .get(format!("{}{}", BASE_API_URL, "voicevox/speakers/")) .query(&[("key", key)]) } else if let Some(original_api_url) = &self.original_api_url { - client.get(original_api_url.to_string() + "/speakers") + client.get(format!("{}/speakers", original_api_url)) } else { - panic!("No API key or original API URL provided.") + return Err(NCBError::voicevox("No API key or original API URL provided")); }; - match client.send().await { - Ok(response) => response.json().await.unwrap(), - Err(err) => { - panic!("Cannot get speaker list. {err:?}") - } + let response = request.send().await + .map_err(|e| NCBError::voicevox(format!("Failed to fetch speakers: {}", e)))?; + + if !response.status().is_success() { + return Err(NCBError::voicevox(format!( + "API request failed with status: {}", + response.status() + ))); } + + response.json().await + .map_err(|e| NCBError::voicevox(format!("Failed to parse speaker list: {}", e))) } #[tracing::instrument] @@ -68,24 +75,33 @@ impl VOICEVOX { &self, text: String, speaker: i64, - ) -> Result, Box> { + ) -> Result, NCBError> { + let key = self.key.as_ref() + .ok_or_else(|| NCBError::voicevox("API key required for synthesis"))?; + let client = reqwest::Client::new(); - match client - .post(BASE_API_URL.to_string() + "voicevox/audio/") + let response = client + .post(format!("{}{}", BASE_API_URL, "voicevox/audio/")) .query(&[ ("speaker", speaker.to_string()), ("text", text), - ("key", self.key.clone().unwrap()), + ("key", key.clone()), ]) .send() .await - { - Ok(response) => { - let body = response.bytes().await?; - Ok(body.to_vec()) - } - Err(err) => Err(Box::new(err)), + .map_err(|e| NCBError::voicevox(format!("Synthesis request failed: {}", e)))?; + + if !response.status().is_success() { + return Err(NCBError::voicevox(format!( + "Synthesis failed with status: {}", + response.status() + ))); } + + let body = response.bytes().await + .map_err(|e| NCBError::voicevox(format!("Failed to read response body: {}", e)))?; + + Ok(body.to_vec()) } #[tracing::instrument] @@ -93,14 +109,21 @@ impl VOICEVOX { &self, text: String, speaker: i64, - ) -> Result, Box> { - let client = - voicevox_client::Client::new(self.original_api_url.as_ref().unwrap().clone(), None); + ) -> Result, NCBError> { + let api_url = self.original_api_url.as_ref() + .ok_or_else(|| NCBError::voicevox("Original API URL required for synthesis"))?; + + let client = voicevox_client::Client::new(api_url.clone(), None); let audio_query = client .create_audio_query(&text, speaker as i32, None) - .await?; - println!("{:?}", audio_query.audio_query); - let audio = audio_query.synthesis(speaker as i32, true).await?; + .await + .map_err(|e| NCBError::voicevox(format!("Failed to create audio query: {}", e)))?; + + tracing::debug!(audio_query = ?audio_query.audio_query, "Generated audio query"); + + let audio = audio_query.synthesis(speaker as i32, true).await + .map_err(|e| NCBError::voicevox(format!("Audio synthesis failed: {}", e)))?; + Ok(audio.into()) } @@ -109,25 +132,35 @@ impl VOICEVOX { &self, text: String, speaker: i64, - ) -> Result> { + ) -> Result { + let key = self.key.as_ref() + .ok_or_else(|| NCBError::voicevox("API key required for stream synthesis"))?; + let client = reqwest::Client::new(); - match client - .post("https://api.tts.quest/v3/voicevox/synthesis") + let response = client + .post(STREAM_API_URL) .query(&[ ("speaker", speaker.to_string()), ("text", text), - ("key", self.key.clone().unwrap()), + ("key", key.clone()), ]) .send() .await - { - Ok(response) => { - let body = response.text().await.unwrap(); - let response: TTSResponse = serde_json::from_str(&body).unwrap(); + .map_err(|e| NCBError::voicevox(format!("Stream synthesis request failed: {}", e)))?; - Ok(Mp3Request::new(reqwest::Client::new(), response.mp3_streaming_url).into()) - } - Err(err) => Err(Box::new(err)), + if !response.status().is_success() { + return Err(NCBError::voicevox(format!( + "Stream synthesis failed with status: {}", + response.status() + ))); } + + let body = response.text().await + .map_err(|e| NCBError::voicevox(format!("Failed to read response text: {}", e)))?; + + let tts_response: TTSResponse = serde_json::from_str(&body) + .map_err(|e| NCBError::voicevox(format!("Failed to parse TTS response: {}", e)))?; + + Ok(Mp3Request::new(reqwest::Client::new(), tts_response.mp3_streaming_url)) } } diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..1c028fa --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,594 @@ +use once_cell::sync::Lazy; +use lru::LruCache; +use regex::Regex; +use std::{num::NonZeroUsize, sync::RwLock}; +use tracing::{debug, error, warn}; + +use crate::errors::{constants::*, NCBError, Result}; + +/// Regex compilation cache to avoid recompiling the same patterns +static REGEX_CACHE: Lazy>> = + Lazy::new(|| RwLock::new(LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap()))); + +/// Circuit breaker states for external API calls +#[derive(Debug, Clone, PartialEq)] +pub enum CircuitBreakerState { + Closed, + Open, + HalfOpen, +} + +/// Circuit breaker for handling external API failures +#[derive(Debug, Clone)] +pub struct CircuitBreaker { + pub state: CircuitBreakerState, + pub failure_count: u32, + pub last_failure_time: Option, + pub threshold: u32, + pub timeout: std::time::Duration, +} + +impl Default for CircuitBreaker { + fn default() -> Self { + Self { + state: CircuitBreakerState::Closed, + failure_count: 0, + last_failure_time: None, + threshold: 5, + timeout: std::time::Duration::from_secs(60), + } + } +} + +impl CircuitBreaker { + pub fn new(threshold: u32, timeout: std::time::Duration) -> Self { + Self { + threshold, + timeout, + ..Default::default() + } + } + + pub fn can_execute(&self) -> bool { + match self.state { + CircuitBreakerState::Closed => true, + CircuitBreakerState::Open => { + if let Some(last_failure) = self.last_failure_time { + last_failure.elapsed() >= self.timeout + } else { + true + } + } + CircuitBreakerState::HalfOpen => true, + } + } + + pub fn on_success(&mut self) { + self.failure_count = 0; + self.state = CircuitBreakerState::Closed; + self.last_failure_time = None; + } + + pub fn on_failure(&mut self) { + self.failure_count += 1; + self.last_failure_time = Some(std::time::Instant::now()); + + if self.failure_count >= self.threshold { + self.state = CircuitBreakerState::Open; + } else if self.state == CircuitBreakerState::HalfOpen { + self.state = CircuitBreakerState::Open; + } + } + + pub fn try_half_open(&mut self) { + if self.state == CircuitBreakerState::Open { + if let Some(last_failure) = self.last_failure_time { + if last_failure.elapsed() >= self.timeout { + self.state = CircuitBreakerState::HalfOpen; + } + } + } + } +} + +/// Cached regex compilation with error handling +pub fn get_cached_regex(pattern: &str) -> Result { + // First try to get from cache + { + let cache = REGEX_CACHE.read().unwrap(); + if let Some(cached_regex) = cache.peek(pattern) { + debug!(pattern = pattern, "Regex cache hit"); + return Ok(cached_regex.clone()); + } + } + + debug!(pattern = pattern, "Regex cache miss, compiling"); + + // Compile regex with error handling + match Regex::new(pattern) { + Ok(regex) => { + // Cache successful compilation + { + let mut cache = REGEX_CACHE.write().unwrap(); + cache.put(pattern.to_string(), regex.clone()); + } + Ok(regex) + } + Err(e) => { + error!(pattern = pattern, error = %e, "Failed to compile regex"); + Err(NCBError::invalid_regex(format!("{}: {}", pattern, e))) + } + } +} + +/// Retry logic with exponential backoff +pub async fn retry_with_backoff( + mut operation: F, + max_attempts: u32, + initial_delay: std::time::Duration, +) -> std::result::Result +where + F: FnMut() -> Fut, + Fut: std::future::Future>, + E: std::fmt::Display, +{ + let mut attempts = 0; + let mut delay = initial_delay; + + loop { + attempts += 1; + + match operation().await { + Ok(result) => { + if attempts > 1 { + debug!(attempts = attempts, "Operation succeeded after retry"); + } + return Ok(result); + } + Err(error) => { + if attempts >= max_attempts { + error!( + attempts = attempts, + error = %error, + "Operation failed after maximum retry attempts" + ); + return Err(error); + } + + warn!( + attempt = attempts, + max_attempts = max_attempts, + delay_ms = delay.as_millis(), + error = %error, + "Operation failed, retrying with backoff" + ); + + tokio::time::sleep(delay).await; + delay = std::cmp::min(delay * 2, std::time::Duration::from_secs(30)); + } + } + } +} + +/// Rate limiter using token bucket algorithm +#[derive(Debug)] +pub struct RateLimiter { + tokens: std::sync::Arc>, + capacity: f64, + refill_rate: f64, + last_refill: std::sync::Arc>, +} + +impl RateLimiter { + pub fn new(capacity: f64, refill_rate: f64) -> Self { + Self { + tokens: std::sync::Arc::new(std::sync::RwLock::new(capacity)), + capacity, + refill_rate, + last_refill: std::sync::Arc::new(std::sync::RwLock::new(std::time::Instant::now())), + } + } + + pub fn try_acquire(&self, tokens: f64) -> bool { + self.refill(); + + let mut current_tokens = self.tokens.write().unwrap(); + if *current_tokens >= tokens { + *current_tokens -= tokens; + true + } else { + false + } + } + + fn refill(&self) { + let now = std::time::Instant::now(); + let mut last_refill = self.last_refill.write().unwrap(); + let elapsed = now.duration_since(*last_refill).as_secs_f64(); + + if elapsed > 0.0 { + let tokens_to_add = elapsed * self.refill_rate; + let mut current_tokens = self.tokens.write().unwrap(); + *current_tokens = (*current_tokens + tokens_to_add).min(self.capacity); + *last_refill = now; + } + } +} + +/// Performance metrics collection +#[derive(Debug, Default, Clone)] +pub struct PerformanceMetrics { + pub tts_requests: std::sync::Arc, + pub tts_cache_hits: std::sync::Arc, + pub tts_cache_misses: std::sync::Arc, + pub regex_cache_hits: std::sync::Arc, + pub regex_cache_misses: std::sync::Arc, + pub database_operations: std::sync::Arc, + pub voice_connections: std::sync::Arc, +} + +impl PerformanceMetrics { + pub fn new() -> Self { + Self::default() + } + + pub fn increment_tts_requests(&self) { + self.tts_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn increment_tts_cache_hits(&self) { + self.tts_cache_hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn increment_tts_cache_misses(&self) { + self.tts_cache_misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn increment_regex_cache_hits(&self) { + self.regex_cache_hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn increment_regex_cache_misses(&self) { + self.regex_cache_misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn increment_database_operations(&self) { + self.database_operations.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn increment_voice_connections(&self) { + self.voice_connections.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + } + + pub fn get_stats(&self) -> MetricsSnapshot { + MetricsSnapshot { + tts_requests: self.tts_requests.load(std::sync::atomic::Ordering::Relaxed), + tts_cache_hits: self.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed), + tts_cache_misses: self.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed), + regex_cache_hits: self.regex_cache_hits.load(std::sync::atomic::Ordering::Relaxed), + regex_cache_misses: self.regex_cache_misses.load(std::sync::atomic::Ordering::Relaxed), + database_operations: self.database_operations.load(std::sync::atomic::Ordering::Relaxed), + voice_connections: self.voice_connections.load(std::sync::atomic::Ordering::Relaxed), + } + } +} + +#[derive(Debug, Clone)] +pub struct MetricsSnapshot { + pub tts_requests: u64, + pub tts_cache_hits: u64, + pub tts_cache_misses: u64, + pub regex_cache_hits: u64, + pub regex_cache_misses: u64, + pub database_operations: u64, + pub voice_connections: u64, +} + +impl MetricsSnapshot { + pub fn tts_cache_hit_rate(&self) -> f64 { + if self.tts_cache_hits + self.tts_cache_misses > 0 { + self.tts_cache_hits as f64 / (self.tts_cache_hits + self.tts_cache_misses) as f64 + } else { + 0.0 + } + } + + pub fn regex_cache_hit_rate(&self) -> f64 { + if self.regex_cache_hits + self.regex_cache_misses > 0 { + self.regex_cache_hits as f64 / (self.regex_cache_hits + self.regex_cache_misses) as f64 + } else { + 0.0 + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD; + + #[test] + fn test_circuit_breaker_default() { + let cb = CircuitBreaker::default(); + assert_eq!(cb.state, CircuitBreakerState::Closed); + assert_eq!(cb.failure_count, 0); + assert!(cb.can_execute()); + } + + #[test] + fn test_circuit_breaker_new() { + let cb = CircuitBreaker::new(3, Duration::from_secs(10)); + assert_eq!(cb.state, CircuitBreakerState::Closed); + assert_eq!(cb.threshold, 3); + assert_eq!(cb.timeout, Duration::from_secs(10)); + } + + #[test] + fn test_circuit_breaker_failure_threshold() { + let mut cb = CircuitBreaker::default(); + + // Test failures up to threshold + for i in 0..CIRCUIT_BREAKER_FAILURE_THRESHOLD { + assert_eq!(cb.state, CircuitBreakerState::Closed); + assert!(cb.can_execute()); + cb.on_failure(); + assert_eq!(cb.failure_count, i + 1); + } + + // Should open after reaching threshold + assert_eq!(cb.state, CircuitBreakerState::Open); + assert!(!cb.can_execute()); + } + + #[test] + fn test_circuit_breaker_success_resets() { + let mut cb = CircuitBreaker::default(); + + // Add some failures + cb.on_failure(); + cb.on_failure(); + assert_eq!(cb.failure_count, 2); + + // Success should reset + cb.on_success(); + assert_eq!(cb.failure_count, 0); + assert_eq!(cb.state, CircuitBreakerState::Closed); + } + + #[test] + fn test_circuit_breaker_half_open() { + let mut cb = CircuitBreaker::new(1, Duration::from_millis(100)); + + // Trigger failure to open circuit + cb.on_failure(); + assert_eq!(cb.state, CircuitBreakerState::Open); + assert!(!cb.can_execute()); + + // Wait for timeout + std::thread::sleep(Duration::from_millis(150)); + + // Should allow transition to half-open + cb.try_half_open(); + assert_eq!(cb.state, CircuitBreakerState::HalfOpen); + assert!(cb.can_execute()); + + // Success in half-open should close circuit + cb.on_success(); + assert_eq!(cb.state, CircuitBreakerState::Closed); + } + + #[test] + fn test_circuit_breaker_half_open_failure() { + let mut cb = CircuitBreaker::new(1, Duration::from_millis(100)); + + // Open circuit + cb.on_failure(); + std::thread::sleep(Duration::from_millis(150)); + cb.try_half_open(); + assert_eq!(cb.state, CircuitBreakerState::HalfOpen); + + // Failure in half-open should reopen circuit + cb.on_failure(); + assert_eq!(cb.state, CircuitBreakerState::Open); + assert!(!cb.can_execute()); + } + + #[tokio::test] + async fn test_retry_with_backoff_success_first_try() { + let mut call_count = 0; + let result = retry_with_backoff( + || { + call_count += 1; + async { Ok::(42) } + }, + 3, + Duration::from_millis(100), + ).await; + + assert_eq!(result.unwrap(), 42); + assert_eq!(call_count, 1); + } + + #[tokio::test] + async fn test_retry_with_backoff_success_after_retries() { + let mut call_count = 0; + let result = retry_with_backoff( + || { + call_count += 1; + async move { + if call_count < 3 { + Err("temporary error") + } else { + Ok::(42) + } + } + }, + 5, + Duration::from_millis(10), + ).await; + + assert_eq!(result.unwrap(), 42); + assert_eq!(call_count, 3); + } + + #[tokio::test] + async fn test_retry_with_backoff_max_attempts() { + let mut call_count = 0; + let result = retry_with_backoff( + || { + call_count += 1; + async { Err::("persistent error") } + }, + 3, + Duration::from_millis(10), + ).await; + + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "persistent error"); + assert_eq!(call_count, 3); + } + + #[test] + fn test_get_cached_regex_valid_pattern() { + // Clear cache first + { + let mut cache = REGEX_CACHE.write().unwrap(); + cache.clear(); + } + + let pattern = r"[a-zA-Z]+"; + let result1 = get_cached_regex(pattern); + assert!(result1.is_ok()); + + let result2 = get_cached_regex(pattern); + assert!(result2.is_ok()); + + // Both should work and second should be from cache + let regex1 = result1.unwrap(); + let regex2 = result2.unwrap(); + assert!(regex1.is_match("hello")); + assert!(regex2.is_match("world")); + } + + #[test] + fn test_get_cached_regex_invalid_pattern() { + let pattern = r"["; + let result = get_cached_regex(pattern); + assert!(result.is_err()); + + if let Err(NCBError::InvalidRegex(msg)) = result { + // The error message contains the pattern and the regex error + assert!(msg.contains("[")); + } else { + panic!("Expected InvalidRegex error"); + } + } + + #[test] + fn test_rate_limiter_basic() { + let limiter = RateLimiter::new(5.0, 1.0); // 5 tokens, 1 per second + + // Should be able to acquire 5 tokens initially + assert!(limiter.try_acquire(1.0)); + assert!(limiter.try_acquire(1.0)); + assert!(limiter.try_acquire(1.0)); + assert!(limiter.try_acquire(1.0)); + assert!(limiter.try_acquire(1.0)); + + // 6th token should fail + assert!(!limiter.try_acquire(1.0)); + } + + #[test] + fn test_rate_limiter_partial_tokens() { + let limiter = RateLimiter::new(2.0, 1.0); + + // Acquire partial tokens + assert!(limiter.try_acquire(0.5)); + assert!(limiter.try_acquire(0.5)); + assert!(limiter.try_acquire(0.5)); + assert!(limiter.try_acquire(0.5)); + + // Should fail with no tokens left + assert!(!limiter.try_acquire(0.1)); + } + + #[test] + fn test_performance_metrics_increment() { + let metrics = PerformanceMetrics::default(); + + assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 0); + + metrics.increment_tts_requests(); + metrics.increment_tts_requests(); + + assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 2); + + metrics.increment_tts_cache_hits(); + assert_eq!(metrics.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed), 1); + + metrics.increment_tts_cache_misses(); + assert_eq!(metrics.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed), 1); + } + + #[test] + fn test_metrics_snapshot_cache_hit_rate() { + let snapshot = MetricsSnapshot { + tts_requests: 10, + tts_cache_hits: 7, + tts_cache_misses: 3, + regex_cache_hits: 0, + regex_cache_misses: 0, + database_operations: 0, + voice_connections: 0, + }; + + assert!((snapshot.tts_cache_hit_rate() - 0.7).abs() < f64::EPSILON); + + let empty_snapshot = MetricsSnapshot { + tts_requests: 0, + tts_cache_hits: 0, + tts_cache_misses: 0, + regex_cache_hits: 0, + regex_cache_misses: 0, + database_operations: 0, + voice_connections: 0, + }; + + assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0); + } + + #[test] + fn test_metrics_snapshot_regex_cache_hit_rate() { + let snapshot = MetricsSnapshot { + tts_requests: 0, + tts_cache_hits: 0, + tts_cache_misses: 0, + regex_cache_hits: 8, + regex_cache_misses: 2, + database_operations: 0, + voice_connections: 0, + }; + + assert!((snapshot.regex_cache_hit_rate() - 0.8).abs() < f64::EPSILON); + } + + #[test] + fn test_performance_metrics_get_stats() { + let metrics = PerformanceMetrics::default(); + + // Add some data + metrics.increment_tts_requests(); + metrics.increment_tts_requests(); + metrics.increment_tts_cache_hits(); + metrics.increment_database_operations(); + + let stats = metrics.get_stats(); + + assert_eq!(stats.tts_requests, 2); + assert_eq!(stats.tts_cache_hits, 1); + assert_eq!(stats.tts_cache_misses, 0); + assert_eq!(stats.database_operations, 1); + } +} \ No newline at end of file diff --git a/tts_cache.bin b/tts_cache.bin new file mode 100644 index 0000000..1b1cb4d Binary files /dev/null and b/tts_cache.bin differ