feat: restoring instance

This commit is contained in:
mii443
2025-05-25 23:42:01 +09:00
parent ffa49c18cc
commit d382a045d0
7 changed files with 277 additions and 41 deletions

View File

@ -8,7 +8,7 @@ use serenity::{
use tracing::info;
use crate::{
data::{TTSClientData, TTSData},
data::{DatabaseClientData, TTSClientData, TTSData},
tts::instance::TTSInstance,
};
@ -106,15 +106,20 @@ pub async fn setup_command(
}
};
storage.insert(
guild.id,
TTSInstance {
before_message: None,
guild: guild.id,
text_channel: text_channel_id,
voice_channel: channel_id,
},
);
let instance = TTSInstance::new(text_channel_id, channel_id, guild.id);
storage.insert(guild.id, instance.clone());
// Save to database
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
drop(data_read);
if let Err(e) = database.save_tts_instance(guild.id, &instance).await {
tracing::error!("Failed to save TTS instance to database: {}", e);
}
text_channel_id
};

View File

@ -1,13 +1,12 @@
use serenity::{
all::{
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditThread
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditThread,
},
model::prelude::UserId,
prelude::Context
prelude::Context,
};
use crate::data::TTSData;
use crate::data::{DatabaseClientData, TTSData};
pub async fn stop_command(
ctx: &Context,
@ -15,12 +14,14 @@ pub async fn stop_command(
) -> Result<(), Box<dyn std::error::Error>> {
if command.guild_id.is_none() {
command
.create_response(&ctx.http,
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("このコマンドはサーバーでのみ使用可能です.")
.ephemeral(true)
))
.ephemeral(true),
),
)
.await?;
return Ok(());
}
@ -35,12 +36,14 @@ pub async fn stop_command(
if channel_id.is_none() {
command
.create_response(&ctx.http,
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("ボイスチャンネルに参加してから実行してください.")
.ephemeral(true)
))
.ephemeral(true),
),
)
.await?;
return Ok(());
}
@ -60,32 +63,48 @@ pub async fn stop_command(
let text_channel_id = {
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild.id) {
command
.create_response(&ctx.http,
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("すでに停止しています")
.ephemeral(true)
))
.ephemeral(true),
),
)
.await?;
return Ok(());
}
let text_channel_id = storage.get(&guild.id).unwrap().text_channel;
storage.remove(&guild.id);
// Remove from database
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
drop(data_read);
if let Err(e) = database.remove_tts_instance(guild.id).await {
tracing::error!("Failed to remove TTS instance from database: {}", e);
}
text_channel_id
};
let _handler = manager.remove(guild.id).await;
command
.create_response(&ctx.http,
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("停止しました")
))
CreateInteractionResponseMessage::new().content("停止しました"),
),
)
.await?;
let _ = text_channel_id
@ -93,4 +112,4 @@ pub async fn stop_command(
.await;
Ok(())
}
}

View File

@ -1,8 +1,10 @@
use std::fmt::Debug;
use crate::tts::{
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType,
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, instance::TTSInstance,
tts_type::TTSType,
};
use serenity::model::id::GuildId;
use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig};
use redis::Commands;
@ -25,6 +27,14 @@ impl Database {
format!("discord_user:{}", user_id)
}
fn tts_instance_key(guild_id: u64) -> String {
format!("tts_instance:{}", guild_id)
}
fn tts_instances_list_key() -> String {
"tts_instances_list".to_string()
}
#[tracing::instrument]
fn get_config<T: serde::de::DeserializeOwned>(
&self,
@ -149,4 +159,79 @@ impl Database {
}
}
}
/// Save TTS instance to database
#[tracing::instrument]
pub async fn save_tts_instance(
&self,
guild_id: GuildId,
instance: &TTSInstance,
) -> redis::RedisResult<()> {
let key = Self::tts_instance_key(guild_id.get());
let list_key = Self::tts_instances_list_key();
// Save the instance
let result = self.set_config(&key, instance);
// Add guild_id to the list of active instances
if result.is_ok() {
match self.client.get_connection() {
Ok(mut connection) => {
let _: redis::RedisResult<()> = connection.sadd(&list_key, guild_id.get());
}
Err(_) => {}
}
}
result
}
/// Load TTS instance from database
#[tracing::instrument]
pub async fn load_tts_instance(
&self,
guild_id: GuildId,
) -> redis::RedisResult<Option<TTSInstance>> {
let key = Self::tts_instance_key(guild_id.get());
self.get_config(&key)
}
/// Remove TTS instance from database
#[tracing::instrument]
pub async fn remove_tts_instance(&self, guild_id: GuildId) -> redis::RedisResult<()> {
let key = Self::tts_instance_key(guild_id.get());
let list_key = Self::tts_instances_list_key();
match self.client.get_connection() {
Ok(mut connection) => {
let _: redis::RedisResult<()> = connection.del(&key);
let _: redis::RedisResult<()> = connection.srem(&list_key, guild_id.get());
Ok(())
}
Err(e) => Err(e),
}
}
/// Get all active TTS instances
#[tracing::instrument]
pub async fn get_all_tts_instances(&self) -> redis::RedisResult<Vec<(GuildId, TTSInstance)>> {
let list_key = Self::tts_instances_list_key();
match self.client.get_connection() {
Ok(mut connection) => {
let guild_ids: Vec<u64> = connection.smembers(&list_key).unwrap_or_default();
let mut instances = Vec::new();
for guild_id in guild_ids {
let guild_id = GuildId::new(guild_id);
if let Ok(Some(instance)) = self.load_tts_instance(guild_id).await {
instances.push((guild_id, instance));
}
}
Ok(instances)
}
Err(e) => Err(e),
}
}
}

View File

@ -5,6 +5,8 @@ use serenity::{
};
use tracing::info;
use crate::data::{DatabaseClientData, TTSData};
#[tracing::instrument]
pub async fn ready(ctx: Context, ready: Ready) {
info!("{} is connected!", ready.user.name);
@ -30,4 +32,66 @@ pub async fn ready(ctx: Context, ready: Ready) {
)
.await
.unwrap();
// Restore TTS instances from database
restore_tts_instances(&ctx).await;
}
/// Restore TTS instances from database and reconnect to voice channels
async fn restore_tts_instances(ctx: &Context) {
info!("Restoring TTS instances from database...");
let data = ctx.data.read().await;
let database = data
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let tts_data = data.get::<TTSData>().unwrap().clone();
drop(data);
match database.get_all_tts_instances().await {
Ok(instances) => {
let mut restored_count = 0;
let mut failed_count = 0;
for (guild_id, instance) in instances {
// Try to reconnect to voice channel
match instance.reconnect(ctx).await {
Ok(_) => {
// Add to in-memory storage
let mut tts_data = tts_data.write().await;
tts_data.insert(guild_id, instance);
drop(tts_data);
restored_count += 1;
info!("Restored TTS instance for guild {}", guild_id);
}
Err(e) => {
failed_count += 1;
tracing::warn!(
"Failed to restore TTS instance for guild {}: {}",
guild_id,
e
);
// Remove failed instance from database
if let Err(db_err) = database.remove_tts_instance(guild_id).await {
tracing::error!(
"Failed to remove invalid TTS instance from database: {}",
db_err
);
}
}
}
}
info!(
"TTS restoration complete: {} restored, {} failed",
restored_count, failed_count
);
}
Err(e) => {
tracing::error!("Failed to load TTS instances from database: {}", e);
}
}
}

View File

@ -61,15 +61,21 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
.await
.expect("Cannot get songbird client.")
.clone();
storage.insert(
guild_id,
TTSInstance {
before_message: None,
guild: guild_id,
text_channel: new_channel,
voice_channel: new_channel,
},
);
let instance = TTSInstance::new(new_channel, new_channel, guild_id);
storage.insert(guild_id, instance.clone());
// Save to database
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
drop(data_read);
if let Err(e) = database.save_tts_instance(guild_id, &instance).await {
tracing::error!("Failed to save TTS instance to database: {}", e);
}
let _handler = manager.join(guild_id, new_channel).await;
let data = ctx.data.read().await;
@ -140,6 +146,18 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
.await;
storage.remove(&guild_id);
// Remove from database
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
drop(data_read);
if let Err(e) = database.remove_tts_instance(guild_id).await {
tracing::error!("Failed to remove TTS instance from database: {}", e);
}
let manager = songbird::get(&ctx)
.await
.expect("Cannot get songbird client.")

View File

@ -111,7 +111,7 @@ async fn main() {
let mut data = client.data.write().await;
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
data.insert::<DatabaseClientData>(Arc::new(database_client));
data.insert::<DatabaseClientData>(Arc::new(database_client.clone()));
}
info!("Bot initialized.");

View File

@ -1,5 +1,6 @@
use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use serenity::{
model::{
channel::Message,
@ -10,8 +11,9 @@ use serenity::{
use crate::tts::message::TTSMessage;
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TTSInstance {
#[serde(skip)] // Messageは複雑すぎるのでシリアライズしない
pub before_message: Option<Message>,
pub text_channel: ChannelId,
pub voice_channel: ChannelId,
@ -19,6 +21,49 @@ pub struct TTSInstance {
}
impl TTSInstance {
/// Create a new TTSInstance
pub fn new(text_channel: ChannelId, voice_channel: ChannelId, guild: GuildId) -> Self {
Self {
before_message: None,
text_channel,
voice_channel,
guild,
}
}
/// Reconnect to the voice channel after bot restart
#[tracing::instrument]
pub async fn reconnect(
&self,
ctx: &Context,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let manager = songbird::get(&ctx)
.await
.ok_or("Songbird manager not available")?;
// Check if we're already connected
if manager.get(self.guild).is_some() {
tracing::info!("Already connected to guild {}", self.guild);
return Ok(());
}
// Try to connect to the voice channel
match manager.join(self.guild, self.voice_channel).await {
Ok(_) => {
tracing::info!(
"Successfully reconnected to voice channel {} in guild {}",
self.voice_channel,
self.guild
);
Ok(())
}
Err(e) => {
tracing::error!("Failed to reconnect to voice channel: {}", e);
Err(Box::new(e))
}
}
}
/// Synthesize text to speech and send it to the voice channel.
///
/// Example: