diff --git a/src/commands/config.rs b/src/commands/config.rs index a1964f7..b18e69e 100644 --- a/src/commands/config.rs +++ b/src/commands/config.rs @@ -34,7 +34,10 @@ pub async fn config_command( let tts_client = data_read .get::() .expect("Cannot get TTSClientData"); - let voicevox_speakers = tts_client.voicevox_client.get_styles().await + let voicevox_speakers = tts_client + .voicevox_client + .get_styles() + .await .unwrap_or_else(|e| { tracing::error!("Failed to get VOICEVOX styles: {}", e); vec![("VOICEVOX API unavailable".to_string(), 1)] @@ -58,11 +61,7 @@ pub async fn config_command( .placeholder("読み上げAPIを選択"), ); - let server_button = CreateActionRow::Buttons(vec![CreateButton::new("TTS_CONFIG_SERVER") - .label("サーバー設定") - .style(ButtonStyle::Primary)]); - - let mut components = vec![engine_select, server_button]; + let mut components = vec![engine_select]; for (index, speaker_chunk) in voicevox_speakers[0..24].chunks(25).enumerate() { let mut options = Vec::new(); @@ -86,6 +85,12 @@ pub async fn config_command( )); } + let server_button = CreateActionRow::Buttons(vec![CreateButton::new("TTS_CONFIG_SERVER") + .label("サーバー設定") + .style(ButtonStyle::Primary)]); + + components.push(server_button); + command .create_response( &ctx.http, diff --git a/src/commands/setup.rs b/src/commands/setup.rs index eb30765..47e8529 100644 --- a/src/commands/setup.rs +++ b/src/commands/setup.rs @@ -81,32 +81,44 @@ pub async fn setup_command( return Ok(()); } - let text_channel_id = { + let text_channel_ids = { if let Some(mode) = command.data.options.get(0) { match &mode.value { serenity::all::CommandDataOptionValue::String(value) => { match value.as_str() { - "TEXT_CHANNEL" => command.channel_id, + "TEXT_CHANNEL" => vec![command.channel_id], "NEW_THREAD" => { - command - .channel_id - .create_thread(&ctx.http, CreateThread::new("TTS").auto_archive_duration(AutoArchiveDuration::OneHour).kind(serenity::all::ChannelType::PublicThread)) - .await - .unwrap() - .id + vec![command + .channel_id + .create_thread(&ctx.http, CreateThread::new("TTS").auto_archive_duration(AutoArchiveDuration::OneHour).kind(serenity::all::ChannelType::PublicThread)) + .await + .unwrap() + .id] } - "VOICE_CHANNEL" => channel_id, - _ => channel_id, + "VOICE_CHANNEL" => vec![channel_id], + _ => if channel_id != command.channel_id { + vec![command.channel_id, channel_id] + } else { + vec![channel_id] + }, } }, - _ => channel_id, + _ => if channel_id != command.channel_id { + vec![command.channel_id, channel_id] + } else { + vec![channel_id] + }, } } else { - channel_id + if channel_id != command.channel_id { + vec![command.channel_id, channel_id] + } else { + vec![channel_id] + } } }; - let instance = TTSInstance::new(text_channel_id, channel_id, guild.id); + let instance = TTSInstance::new(text_channel_ids.clone(), channel_id, guild.id); storage.insert(guild.id, instance.clone()); // Save to database @@ -121,7 +133,7 @@ pub async fn setup_command( tracing::error!("Failed to save TTS instance to database: {}", e); } - text_channel_id + text_channel_ids[0] }; command diff --git a/src/commands/stop.rs b/src/commands/stop.rs index ce462ab..315a429 100644 --- a/src/commands/stop.rs +++ b/src/commands/stop.rs @@ -78,7 +78,7 @@ pub async fn stop_command( return Ok(()); } - let text_channel_id = storage.get(&guild.id).unwrap().text_channel; + let text_channel_id = storage.get(&guild.id).unwrap().text_channels[0]; storage.remove(&guild.id); // Remove from database diff --git a/src/connection_monitor.rs b/src/connection_monitor.rs index c35bbec..89de178 100644 --- a/src/connection_monitor.rs +++ b/src/connection_monitor.rs @@ -1,7 +1,10 @@ -use serenity::{prelude::Context, all::{CreateMessage, CreateEmbed}}; +use serenity::{ + all::{CreateEmbed, CreateMessage}, + prelude::Context, +}; use std::time::Duration; use tokio::time; -use tracing::{error, info, warn, instrument}; +use tracing::{error, info, instrument, warn}; use crate::data::{DatabaseClientData, TTSData}; @@ -69,7 +72,9 @@ impl ConnectionMonitor { let data_read = ctx.data.read().await; data_read .get::() - .ok_or_else(|| ConnectionMonitorError::VoiceChannelCheck("Cannot get TTSStorage".to_string()))? + .ok_or_else(|| { + ConnectionMonitorError::VoiceChannelCheck("Cannot get TTSStorage".to_string()) + })? .clone() }; @@ -77,7 +82,11 @@ impl ConnectionMonitor { let data_read = ctx.data.read().await; data_read .get::() - .ok_or_else(|| ConnectionMonitorError::VoiceChannelCheck("Cannot get DatabaseClientData".to_string()))? + .ok_or_else(|| { + ConnectionMonitorError::VoiceChannelCheck( + "Cannot get DatabaseClientData".to_string(), + ) + })? .clone() }; @@ -86,7 +95,8 @@ impl ConnectionMonitor { for (guild_id, instance) in storage.iter() { // Check if bot is still connected to voice channel - let manager = songbird::get(ctx).await + let manager = songbird::get(ctx) + .await .ok_or(ConnectionMonitorError::SongbirdManagerNotFound)?; let call = manager.get(*guild_id); @@ -114,8 +124,12 @@ impl ConnectionMonitor { if should_reconnect { // Try to reconnect with retry logic - let attempts = self.reconnection_attempts.get(guild_id).copied().unwrap_or(0); - + let attempts = self + .reconnection_attempts + .get(guild_id) + .copied() + .unwrap_or(0); + if attempts >= MAX_RECONNECTION_ATTEMPTS { error!( guild_id = %guild_id, @@ -129,7 +143,8 @@ impl ConnectionMonitor { // Apply exponential backoff if attempts > 0 { - let backoff_duration = Duration::from_secs(RECONNECTION_BACKOFF_SECS * (2_u64.pow(attempts))); + let backoff_duration = + Duration::from_secs(RECONNECTION_BACKOFF_SECS * (2_u64.pow(attempts))); warn!( guild_id = %guild_id, attempt = attempts + 1, @@ -146,17 +161,24 @@ impl ConnectionMonitor { attempts = attempts + 1, "Successfully reconnected to voice channel" ); - + // Reset reconnection attempts on success self.reconnection_attempts.remove(guild_id); - + // Send notification message to text channel with embed let embed = CreateEmbed::new() .title("🔄 自動再接続しました") .description("読み上げを停止したい場合は `/stop` コマンドを使用してください。") .color(0x00ff00); - if let Err(e) = instance.text_channel.send_message(&ctx.http, CreateMessage::new().embed(embed)).await { - error!(guild_id = %guild_id, error = %e, "Failed to send reconnection message"); + + // Send message to the first text channel + if let Some(&text_channel) = instance.text_channels.first() { + if let Err(e) = text_channel + .send_message(&ctx.http, CreateMessage::new().embed(embed)) + .await + { + error!(guild_id = %guild_id, error = %e, "Failed to send reconnection message"); + } } } Err(e) => { @@ -168,7 +190,7 @@ impl ConnectionMonitor { error = %e, "Failed to reconnect to voice channel" ); - + if new_attempts >= MAX_RECONNECTION_ATTEMPTS { guilds_to_remove.push(*guild_id); self.reconnection_attempts.remove(guild_id); @@ -201,10 +223,10 @@ impl ConnectionMonitor { error!(guild_id = %guild_id, error = %e, "Failed to remove bot from voice channel"); } } - + info!(guild_id = %guild_id, "Removed disconnected TTS instance"); } - + Ok(()) } @@ -215,21 +237,29 @@ impl ConnectionMonitor { ctx: &Context, instance: &crate::tts::instance::TTSInstance, ) -> Result { - let channels = instance.guild.channels(&ctx.http).await - .map_err(|e| ConnectionMonitorError::VoiceChannelCheck(format!("Failed to get guild channels: {}", e)))?; + let channels = instance.guild.channels(&ctx.http).await.map_err(|e| { + ConnectionMonitorError::VoiceChannelCheck(format!( + "Failed to get guild channels: {}", + e + )) + })?; if let Some(channel) = channels.get(&instance.voice_channel) { - let members = channel.members(&ctx.cache) - .map_err(|e| ConnectionMonitorError::VoiceChannelCheck(format!("Failed to get channel members: {}", e)))?; + let members = channel.members(&ctx.cache).map_err(|e| { + ConnectionMonitorError::VoiceChannelCheck(format!( + "Failed to get channel members: {}", + e + )) + })?; let user_count = members.iter().filter(|member| !member.user.bot).count(); - + info!( guild_id = %instance.guild, channel_id = %instance.voice_channel, user_count = user_count, "Checked voice channel users" ); - + Ok(user_count > 0) } else { warn!( diff --git a/src/database/database.rs b/src/database/database.rs index 1eacc40..ddaf37e 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -1,14 +1,14 @@ use std::fmt::Debug; -use bb8_redis::{bb8::Pool, RedisConnectionManager, redis::AsyncCommands}; use crate::{ - errors::{NCBError, Result, constants::*}, + errors::{constants::*, NCBError, Result}, tts::{ gcp_tts::structs::voice_selection_params::VoiceSelectionParams, instance::TTSInstance, tts_type::TTSType, }, }; -use serenity::model::id::{GuildId, UserId, ChannelId}; +use bb8_redis::{bb8::Pool, redis::AsyncCommands, RedisConnectionManager}; +use serenity::model::id::{ChannelId, GuildId, UserId}; use std::collections::HashMap; use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig}; @@ -22,7 +22,7 @@ impl Database { pub fn new(pool: Pool) -> Self { Self { pool } } - + pub async fn new_with_url(redis_url: String) -> Result { let manager = RedisConnectionManager::new(redis_url)?; let pool = Pool::builder() @@ -62,13 +62,13 @@ impl Database { } #[tracing::instrument] - async fn get_config( - &self, - key: &str, - ) -> Result> { - let mut connection = self.pool.get().await + async fn get_config(&self, key: &str) -> Result> { + let mut connection = self + .pool + .get() + .await .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; - + let config: String = connection.get(key).await.unwrap_or_default(); if config.is_empty() { @@ -85,24 +85,20 @@ impl Database { } #[tracing::instrument] - async fn set_config( - &self, - key: &str, - config: &T, - ) -> Result<()> { - let mut connection = self.pool.get().await + async fn set_config(&self, key: &str, config: &T) -> Result<()> { + let mut connection = self + .pool + .get() + .await .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; - + let config_str = serde_json::to_string(config)?; connection.set::<_, _, ()>(key, config_str).await?; Ok(()) } #[tracing::instrument] - pub async fn get_server_config( - &self, - server_id: u64, - ) -> Result> { + pub async fn get_server_config(&self, server_id: u64) -> Result> { self.get_config(&Self::server_key(server_id)).await } @@ -112,20 +108,12 @@ impl Database { } #[tracing::instrument] - pub async fn set_server_config( - &self, - server_id: u64, - config: ServerConfig, - ) -> Result<()> { + pub async fn set_server_config(&self, server_id: u64, config: ServerConfig) -> Result<()> { self.set_config(&Self::server_key(server_id), &config).await } #[tracing::instrument] - pub async fn set_user_config( - &self, - user_id: u64, - config: UserConfig, - ) -> Result<()> { + pub async fn set_user_config(&self, user_id: u64, config: UserConfig) -> Result<()> { self.set_config(&Self::user_key(user_id), &config).await } @@ -134,8 +122,9 @@ impl Database { let config = ServerConfig { dictionary: Dictionary::new(), autostart_channel_id: None, - voice_state_announce: Some(true), - read_username: Some(true), + autostart_text_channel_id: None, + voice_state_announce: Some(false), + read_username: Some(false), }; self.set_server_config(server_id, config).await @@ -173,10 +162,7 @@ impl Database { } #[tracing::instrument] - pub async fn get_user_config_or_default( - &self, - user_id: u64, - ) -> Result> { + pub async fn get_user_config_or_default(&self, user_id: u64) -> Result> { match self.get_user_config(user_id).await? { Some(config) => Ok(Some(config)), None => { @@ -187,11 +173,7 @@ impl Database { } /// Save TTS instance to database - pub async fn save_tts_instance( - &self, - guild_id: GuildId, - instance: &TTSInstance, - ) -> Result<()> { + pub async fn save_tts_instance(&self, guild_id: GuildId, instance: &TTSInstance) -> Result<()> { let key = Self::tts_instance_key(guild_id.get()); let list_key = Self::tts_instances_list_key(); @@ -199,19 +181,21 @@ impl Database { self.set_config(&key, instance).await?; // Add guild_id to the list of active instances - let mut connection = self.pool.get().await + let mut connection = self + .pool + .get() + .await .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; - - connection.sadd::<_, _, ()>(&list_key, guild_id.get()).await?; + + connection + .sadd::<_, _, ()>(&list_key, guild_id.get()) + .await?; Ok(()) } /// Load TTS instance from database #[tracing::instrument] - pub async fn load_tts_instance( - &self, - guild_id: GuildId, - ) -> Result> { + pub async fn load_tts_instance(&self, guild_id: GuildId) -> Result> { let key = Self::tts_instance_key(guild_id.get()); self.get_config(&key).await } @@ -222,12 +206,16 @@ impl Database { let key = Self::tts_instance_key(guild_id.get()); let list_key = Self::tts_instances_list_key(); - let mut connection = self.pool.get().await + let mut connection = self + .pool + .get() + .await .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; - + let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await; - let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.srem(&list_key, guild_id.get()).await; - + let _: std::result::Result<(), bb8_redis::redis::RedisError> = + connection.srem(&list_key, guild_id.get()).await; + Ok(()) } @@ -236,9 +224,12 @@ impl Database { pub async fn get_all_tts_instances(&self) -> Result> { let list_key = Self::tts_instances_list_key(); - let mut connection = self.pool.get().await + let mut connection = self + .pool + .get() + .await .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; - + let guild_ids: Vec = connection.smembers(&list_key).await.unwrap_or_default(); let mut instances = Vec::new(); @@ -274,39 +265,34 @@ impl Database { self.get_config(&key).await } - pub async fn delete_user_config( - &self, - guild_id: GuildId, - user_id: UserId, - ) -> Result<()> { + pub async fn delete_user_config(&self, guild_id: GuildId, user_id: UserId) -> Result<()> { let key = Self::user_config_key(guild_id.get(), user_id.get()); - let mut connection = self.pool.get().await + let mut connection = self + .pool + .get() + .await .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await; Ok(()) } // Additional server config methods - pub async fn save_server_config( - &self, - guild_id: GuildId, - config: &ServerConfig, - ) -> Result<()> { + pub async fn save_server_config(&self, guild_id: GuildId, config: &ServerConfig) -> Result<()> { let key = Self::server_config_key(guild_id.get()); self.set_config(&key, config).await } - pub async fn load_server_config( - &self, - guild_id: GuildId, - ) -> Result> { + pub async fn load_server_config(&self, guild_id: GuildId) -> Result> { let key = Self::server_config_key(guild_id.get()); self.get_config(&key).await } pub async fn delete_server_config(&self, guild_id: GuildId) -> Result<()> { let key = Self::server_config_key(guild_id.get()); - let mut connection = self.pool.get().await + let mut connection = self + .pool + .get() + .await .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await; Ok(()) @@ -322,10 +308,7 @@ impl Database { self.set_config(&key, dictionary).await } - pub async fn load_dictionary( - &self, - guild_id: GuildId, - ) -> Result> { + pub async fn load_dictionary(&self, guild_id: GuildId) -> Result> { let key = Self::dictionary_key(guild_id.get()); let dict: Option> = self.get_config(&key).await?; Ok(dict.unwrap_or_default()) @@ -333,7 +316,10 @@ impl Database { pub async fn delete_dictionary(&self, guild_id: GuildId) -> Result<()> { let key = Self::dictionary_key(guild_id.get()); - let mut connection = self.pool.get().await + let mut connection = self + .pool + .get() + .await .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await; Ok(()) @@ -345,7 +331,10 @@ impl Database { pub async fn list_active_instances(&self) -> Result> { let list_key = Self::tts_instances_list_key(); - let mut connection = self.pool.get().await + let mut connection = self + .pool + .get() + .await .map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?; let guild_ids: Vec = connection.smembers(&list_key).await.unwrap_or_default(); Ok(guild_ids) @@ -355,9 +344,9 @@ impl Database { #[cfg(test)] mod tests { use super::*; + use crate::errors::constants; use bb8_redis::redis::AsyncCommands; use serial_test::serial; - use crate::errors::constants; // Helper function to create test database (requires Redis running) async fn create_test_database() -> Result { @@ -422,16 +411,19 @@ mod tests { }; let guild_id = GuildId::new(12345); - let test_instance = TTSInstance::new( - ChannelId::new(123), - ChannelId::new(456), - guild_id - ); + let test_instance = + TTSInstance::new_single(ChannelId::new(123), ChannelId::new(456), guild_id); // Clear any existing data if let Ok(mut conn) = db.pool.get().await { - let _: () = conn.del(Database::tts_instance_key(guild_id.get())).await.unwrap_or_default(); - let _: () = conn.srem(Database::tts_instances_list_key(), guild_id.get()).await.unwrap_or_default(); + let _: () = conn + .del(Database::tts_instance_key(guild_id.get())) + .await + .unwrap_or_default(); + let _: () = conn + .srem(Database::tts_instances_list_key(), guild_id.get()) + .await + .unwrap_or_default(); } else { return; // Skip if can't get connection } @@ -452,7 +444,7 @@ mod tests { let loaded_instance = load_result.unwrap(); if let Some(instance) = loaded_instance { assert_eq!(instance.guild, test_instance.guild); - assert_eq!(instance.text_channel, test_instance.text_channel); + assert_eq!(instance.text_channels, test_instance.text_channels); assert_eq!(instance.voice_channel, test_instance.voice_channel); } @@ -485,4 +477,4 @@ mod tests { assert!(constants::REDIS_MAX_CONNECTIONS > 0); assert!(constants::REDIS_MIN_IDLE_CONNECTIONS <= constants::REDIS_MAX_CONNECTIONS); } -} \ No newline at end of file +} diff --git a/src/database/server_config.rs b/src/database/server_config.rs index 592fc81..94188bc 100644 --- a/src/database/server_config.rs +++ b/src/database/server_config.rs @@ -10,6 +10,7 @@ pub struct DictionaryOnlyServerConfig { pub struct ServerConfig { pub dictionary: Dictionary, pub autostart_channel_id: Option, + pub autostart_text_channel_id: Option, pub voice_state_announce: Option, pub read_username: Option, } diff --git a/src/errors.rs b/src/errors.rs index 8c50865..265a14f 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -324,6 +324,8 @@ pub mod constants { pub const CHANNEL_LEAVE_SUCCESS: &str = "CHANNEL_LEAVE_SUCCESS"; pub const AUTOSTART_CHANNEL_SET: &str = "AUTOSTART_CHANNEL_SET"; pub const SET_AUTOSTART_CHANNEL_CLEAR: &str = "SET_AUTOSTART_CHANNEL_CLEAR"; + pub const SET_AUTOSTART_TEXT_CHANNEL: &str = "SET_AUTOSTART_TEXT_CHANNEL"; + pub const SET_AUTOSTART_TEXT_CHANNEL_CLEAR: &str = "SET_AUTOSTART_TEXT_CHANNEL_CLEAR"; // TTS configuration constants pub const TTS_CONFIG_SERVER_ADD_DICTIONARY: &str = "TTS_CONFIG_SERVER_ADD_DICTIONARY"; diff --git a/src/event_handler.rs b/src/event_handler.rs index 6c42960..0fef2ca 100644 --- a/src/event_handler.rs +++ b/src/event_handler.rs @@ -55,55 +55,60 @@ impl EventHandler for Handler { } let rows = modal.data.components.clone(); - + // Extract rule name with proper error handling - let rule_name = match rows.get(0) - .and_then(|row| row.components.get(0)) - .and_then(|component| { - if let ActionRowComponent::InputText(text) = component { - text.value.as_ref() - } else { - None + let rule_name = + match rows + .get(0) + .and_then(|row| row.components.get(0)) + .and_then(|component| { + if let ActionRowComponent::InputText(text) = component { + text.value.as_ref() + } else { + None + } + }) { + Some(name) => { + if let Err(e) = validation::validate_rule_name(name) { + tracing::error!("Invalid rule name: {}", e); + return; + } + name.clone() } - }) { - Some(name) => { - if let Err(e) = validation::validate_rule_name(name) { - tracing::error!("Invalid rule name: {}", e); + None => { + tracing::error!("Cannot extract rule name from modal"); return; } - name.clone() - }, - None => { - tracing::error!("Cannot extract rule name from modal"); - return; - } - }; + }; // Extract 'from' field with validation - let from = match rows.get(1) - .and_then(|row| row.components.get(0)) - .and_then(|component| { - if let ActionRowComponent::InputText(text) = component { - text.value.as_ref() - } else { - None + let from = + match rows + .get(1) + .and_then(|row| row.components.get(0)) + .and_then(|component| { + if let ActionRowComponent::InputText(text) = component { + text.value.as_ref() + } else { + None + } + }) { + Some(pattern) => { + if let Err(e) = validation::validate_regex_pattern(pattern) { + tracing::error!("Invalid regex pattern: {}", e); + return; + } + pattern.clone() } - }) { - Some(pattern) => { - if let Err(e) = validation::validate_regex_pattern(pattern) { - tracing::error!("Invalid regex pattern: {}", e); + None => { + tracing::error!("Cannot extract regex pattern from modal"); return; } - pattern.clone() - }, - None => { - tracing::error!("Cannot extract regex pattern from modal"); - return; - } - }; + }; // Extract 'to' field with validation - let to = match rows.get(2) + let to = match rows + .get(2) .and_then(|row| row.components.get(0)) .and_then(|component| { if let ActionRowComponent::InputText(text) = component { @@ -118,7 +123,7 @@ impl EventHandler for Handler { return; } replacement.clone() - }, + } None => { tracing::error!("Cannot extract replacement text from modal"); return; @@ -143,12 +148,15 @@ impl EventHandler for Handler { } }; - match database.get_server_config_or_default(modal.guild_id.unwrap().get()).await { + match database + .get_server_config_or_default(modal.guild_id.unwrap().get()) + .await + { Ok(Some(config)) => config, Ok(None) => { tracing::error!("No server config found"); return; - }, + } Err(e) => { tracing::error!("Database error: {}", e); return; @@ -166,7 +174,10 @@ impl EventHandler for Handler { } }; - if let Err(e) = database.set_server_config(modal.guild_id.unwrap().get(), config).await { + if let Err(e) = database + .set_server_config(modal.guild_id.unwrap().get(), config) + .await + { tracing::error!("Failed to save server config: {}", e); return; } @@ -502,8 +513,64 @@ impl EventHandler for Handler { .create_response( &ctx.http, CreateInteractionResponse::UpdateMessage( - CreateInteractionResponseMessage::new() - .content(response_content), + CreateInteractionResponseMessage::new().content(response_content), + ), + ) + .await + .unwrap(); + } + id if id == SET_AUTOSTART_TEXT_CHANNEL => { + let autostart_text_channel_id = match message_component.data.kind { + ComponentInteractionDataKind::StringSelect { ref values, .. } => { + if values.len() == 0 { + None + } else if values[0] == "SET_AUTOSTART_TEXT_CHANNEL_CLEAR" { + None + } else { + Some( + u64::from_str_radix( + &values[0] + .strip_prefix("SET_AUTOSTART_TEXT_CHANNEL_") + .unwrap(), + 10, + ) + .unwrap(), + ) + } + } + _ => panic!("Cannot get index"), + }; + + { + let data_read = ctx.data.read().await; + let database = data_read + .get::() + .expect("Cannot get DatabaseClientData") + .clone(); + + let mut config = database + .get_server_config_or_default(message_component.guild_id.unwrap().get()) + .await + .unwrap() + .unwrap(); + config.autostart_text_channel_id = autostart_text_channel_id; + database + .set_server_config(message_component.guild_id.unwrap().get(), config) + .await + .unwrap(); + } + + let response_content = if autostart_text_channel_id.is_some() { + "自動参加テキストチャンネルを設定しました。" + } else { + "自動参加テキストチャンネルを解除しました。" + }; + + message_component + .create_response( + &ctx.http, + CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new().content(response_content), ), ) .await @@ -534,17 +601,15 @@ impl EventHandler for Handler { .unwrap(); let mut options = Vec::new(); - + // 解除オプションを追加 - let clear_option = CreateSelectMenuOption::new( - "解除", - "SET_AUTOSTART_CHANNEL_CLEAR", - ) - .description("自動参加チャンネルを解除します") - .default_selection(autostart_channel_id == 0); + let clear_option = + CreateSelectMenuOption::new("解除", "SET_AUTOSTART_CHANNEL_CLEAR") + .description("自動参加チャンネルを解除します") + .default_selection(autostart_channel_id == 0); options.push(clear_option); - for (id, channel) in channels { + for (id, channel) in channels.clone() { if channel.kind != ChannelType::Voice { continue; } @@ -562,6 +627,33 @@ impl EventHandler for Handler { options.push(option); } + let mut text_channel_options = Vec::new(); + + let clear_option = + CreateSelectMenuOption::new("解除", "SET_AUTOSTART_TEXT_CHANNEL_CLEAR") + .description("自動参加テキストチャンネルを解除します") + .default_selection(config.autostart_text_channel_id.is_none()); + text_channel_options.push(clear_option); + + for (id, channel) in channels { + if channel.kind != ChannelType::Text { + continue; + } + + let description = channel + .topic + .unwrap_or_else(|| String::from("No topic provided.")); + let option = CreateSelectMenuOption::new( + &channel.name, + format!("SET_AUTOSTART_TEXT_CHANNEL_{}", id.get()), + ) + .description(description) + .default_selection( + channel.id.get() == config.autostart_text_channel_id.unwrap_or(0), + ); + text_channel_options.push(option); + } + message_component .create_response( &ctx.http, @@ -577,6 +669,16 @@ impl EventHandler for Handler { .min_values(0) .max_values(1), ), + CreateActionRow::SelectMenu( + CreateSelectMenu::new( + "SET_AUTOSTART_TEXT_CHANNEL", + CreateSelectMenuKind::String { + options: text_channel_options, + }, + ) + .min_values(0) + .max_values(1), + ), CreateActionRow::Buttons(vec![CreateButton::new( "TTS_CONFIG_SERVER_BACK", ) diff --git a/src/events/message_receive.rs b/src/events/message_receive.rs index d80f470..690bc9e 100644 --- a/src/events/message_receive.rs +++ b/src/events/message_receive.rs @@ -31,7 +31,7 @@ pub async fn message(ctx: Context, message: Message) { let instance = storage.get_mut(&guild_id).unwrap(); - if instance.text_channel != message.channel_id { + if !instance.contains_text_channel(message.channel_id) { return; } diff --git a/src/events/voice_state_update.rs b/src/events/voice_state_update.rs index 0dd4d96..88d4770 100644 --- a/src/events/voice_state_update.rs +++ b/src/events/voice_state_update.rs @@ -62,7 +62,14 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic .expect("Cannot get songbird client.") .clone(); - let instance = TTSInstance::new(new_channel, new_channel, guild_id); + let text_channel_ids = + if let Some(text_channel_id) = config.autostart_text_channel_id { + vec![text_channel_id.into(), new_channel] + } else { + vec![new_channel] + }; + + let instance = TTSInstance::new(text_channel_ids, new_channel, guild_id); storage.insert(guild_id, instance.clone()); // Save to database @@ -82,7 +89,10 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic let tts_client = data .get::() .expect("Cannot get TTSClientData"); - let voicevox_speakers = tts_client.voicevox_client.get_speakers().await + let voicevox_speakers = tts_client + .voicevox_client + .get_speakers() + .await .unwrap_or_else(|e| { tracing::error!("Failed to get VOICEVOX speakers: {}", e); vec!["VOICEVOX API unavailable".to_string()] @@ -142,12 +152,15 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic } if del_flag { - let _ = storage - .get(&guild_id) - .unwrap() - .text_channel - .edit_thread(&ctx.http, EditThread::new().archived(true)) - .await; + // Archive thread if it exists + if let Some(&channel_id) = storage.get(&guild_id).unwrap().text_channels.first() { + let http = ctx.http.clone(); + tokio::spawn(async move { + let _ = channel_id + .edit_thread(&http, EditThread::new().archived(true)) + .await; + }); + } storage.remove(&guild_id); // Remove from database diff --git a/src/tts/instance.rs b/src/tts/instance.rs index fd44034..35b8fc9 100644 --- a/src/tts/instance.rs +++ b/src/tts/instance.rs @@ -15,22 +15,54 @@ use crate::tts::message::TTSMessage; pub struct TTSInstance { #[serde(skip)] // Messageは複雑すぎるのでシリアライズしない pub before_message: Option, - pub text_channel: ChannelId, + pub text_channels: Vec, pub voice_channel: ChannelId, pub guild: GuildId, } impl TTSInstance { /// Create a new TTSInstance - pub fn new(text_channel: ChannelId, voice_channel: ChannelId, guild: GuildId) -> Self { + pub fn new(text_channels: Vec, voice_channel: ChannelId, guild: GuildId) -> Self { Self { before_message: None, - text_channel, + text_channels, voice_channel, guild, } } + /// Create a new TTSInstance with a single text channel + pub fn new_single(text_channel: ChannelId, voice_channel: ChannelId, guild: GuildId) -> Self { + Self::new(vec![text_channel], voice_channel, guild) + } + + /// Add a text channel to the instance + pub fn add_text_channel(&mut self, channel_id: ChannelId) { + if !self.text_channels.contains(&channel_id) { + self.text_channels.push(channel_id); + } + } + + /// Remove a text channel from the instance + pub fn remove_text_channel(&mut self, channel_id: ChannelId) -> bool { + if let Some(pos) = self.text_channels.iter().position(|&x| x == channel_id) { + self.text_channels.remove(pos); + true + } else { + false + } + } + + /// Check if a channel is in the text channels list + pub fn contains_text_channel(&self, channel_id: ChannelId) -> bool { + self.text_channels.contains(&channel_id) + } + + /// Get all text channels + pub fn get_text_channels(&self) -> &Vec { + &self.text_channels + } + pub async fn check_connection(&self, ctx: &Context) -> bool { let manager = match songbird::get(ctx).await { Some(manager) => manager,