mirror of
https://github.com/mii443/ncb-tts-r2.git
synced 2025-08-22 16:15:29 +00:00
feat: restoring instance
This commit is contained in:
@ -8,7 +8,7 @@ use serenity::{
|
||||
use tracing::info;
|
||||
|
||||
use crate::{
|
||||
data::{TTSClientData, TTSData},
|
||||
data::{DatabaseClientData, TTSClientData, TTSData},
|
||||
tts::instance::TTSInstance,
|
||||
};
|
||||
|
||||
@ -106,15 +106,20 @@ pub async fn setup_command(
|
||||
}
|
||||
};
|
||||
|
||||
storage.insert(
|
||||
guild.id,
|
||||
TTSInstance {
|
||||
before_message: None,
|
||||
guild: guild.id,
|
||||
text_channel: text_channel_id,
|
||||
voice_channel: channel_id,
|
||||
},
|
||||
);
|
||||
let instance = TTSInstance::new(text_channel_id, channel_id, guild.id);
|
||||
storage.insert(guild.id, instance.clone());
|
||||
|
||||
// Save to database
|
||||
let data_read = ctx.data.read().await;
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
drop(data_read);
|
||||
|
||||
if let Err(e) = database.save_tts_instance(guild.id, &instance).await {
|
||||
tracing::error!("Failed to save TTS instance to database: {}", e);
|
||||
}
|
||||
|
||||
text_channel_id
|
||||
};
|
||||
|
@ -1,13 +1,12 @@
|
||||
|
||||
use serenity::{
|
||||
all::{
|
||||
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditThread
|
||||
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditThread,
|
||||
},
|
||||
model::prelude::UserId,
|
||||
prelude::Context
|
||||
prelude::Context,
|
||||
};
|
||||
|
||||
use crate::data::TTSData;
|
||||
use crate::data::{DatabaseClientData, TTSData};
|
||||
|
||||
pub async fn stop_command(
|
||||
ctx: &Context,
|
||||
@ -15,12 +14,14 @@ pub async fn stop_command(
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if command.guild_id.is_none() {
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
.create_response(
|
||||
&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("このコマンドはサーバーでのみ使用可能です.")
|
||||
.ephemeral(true)
|
||||
))
|
||||
.ephemeral(true),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
@ -35,12 +36,14 @@ pub async fn stop_command(
|
||||
|
||||
if channel_id.is_none() {
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
.create_response(
|
||||
&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("ボイスチャンネルに参加してから実行してください.")
|
||||
.ephemeral(true)
|
||||
))
|
||||
.ephemeral(true),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
@ -60,32 +63,48 @@ pub async fn stop_command(
|
||||
|
||||
let text_channel_id = {
|
||||
let mut storage = storage_lock.write().await;
|
||||
|
||||
|
||||
if !storage.contains_key(&guild.id) {
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
.create_response(
|
||||
&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("すでに停止しています")
|
||||
.ephemeral(true)
|
||||
))
|
||||
.ephemeral(true),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let text_channel_id = storage.get(&guild.id).unwrap().text_channel;
|
||||
storage.remove(&guild.id);
|
||||
|
||||
// Remove from database
|
||||
let data_read = ctx.data.read().await;
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
drop(data_read);
|
||||
|
||||
if let Err(e) = database.remove_tts_instance(guild.id).await {
|
||||
tracing::error!("Failed to remove TTS instance from database: {}", e);
|
||||
}
|
||||
|
||||
text_channel_id
|
||||
};
|
||||
|
||||
let _handler = manager.remove(guild.id).await;
|
||||
|
||||
command
|
||||
.create_response(&ctx.http,
|
||||
.create_response(
|
||||
&ctx.http,
|
||||
CreateInteractionResponse::Message(
|
||||
CreateInteractionResponseMessage::new()
|
||||
.content("停止しました")
|
||||
))
|
||||
CreateInteractionResponseMessage::new().content("停止しました"),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let _ = text_channel_id
|
||||
@ -93,4 +112,4 @@ pub async fn stop_command(
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,10 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::tts::{
|
||||
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType,
|
||||
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, instance::TTSInstance,
|
||||
tts_type::TTSType,
|
||||
};
|
||||
use serenity::model::id::GuildId;
|
||||
|
||||
use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig};
|
||||
use redis::Commands;
|
||||
@ -25,6 +27,14 @@ impl Database {
|
||||
format!("discord_user:{}", user_id)
|
||||
}
|
||||
|
||||
fn tts_instance_key(guild_id: u64) -> String {
|
||||
format!("tts_instance:{}", guild_id)
|
||||
}
|
||||
|
||||
fn tts_instances_list_key() -> String {
|
||||
"tts_instances_list".to_string()
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
fn get_config<T: serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
@ -149,4 +159,79 @@ impl Database {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Save TTS instance to database
|
||||
#[tracing::instrument]
|
||||
pub async fn save_tts_instance(
|
||||
&self,
|
||||
guild_id: GuildId,
|
||||
instance: &TTSInstance,
|
||||
) -> redis::RedisResult<()> {
|
||||
let key = Self::tts_instance_key(guild_id.get());
|
||||
let list_key = Self::tts_instances_list_key();
|
||||
|
||||
// Save the instance
|
||||
let result = self.set_config(&key, instance);
|
||||
|
||||
// Add guild_id to the list of active instances
|
||||
if result.is_ok() {
|
||||
match self.client.get_connection() {
|
||||
Ok(mut connection) => {
|
||||
let _: redis::RedisResult<()> = connection.sadd(&list_key, guild_id.get());
|
||||
}
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Load TTS instance from database
|
||||
#[tracing::instrument]
|
||||
pub async fn load_tts_instance(
|
||||
&self,
|
||||
guild_id: GuildId,
|
||||
) -> redis::RedisResult<Option<TTSInstance>> {
|
||||
let key = Self::tts_instance_key(guild_id.get());
|
||||
self.get_config(&key)
|
||||
}
|
||||
|
||||
/// Remove TTS instance from database
|
||||
#[tracing::instrument]
|
||||
pub async fn remove_tts_instance(&self, guild_id: GuildId) -> redis::RedisResult<()> {
|
||||
let key = Self::tts_instance_key(guild_id.get());
|
||||
let list_key = Self::tts_instances_list_key();
|
||||
|
||||
match self.client.get_connection() {
|
||||
Ok(mut connection) => {
|
||||
let _: redis::RedisResult<()> = connection.del(&key);
|
||||
let _: redis::RedisResult<()> = connection.srem(&list_key, guild_id.get());
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all active TTS instances
|
||||
#[tracing::instrument]
|
||||
pub async fn get_all_tts_instances(&self) -> redis::RedisResult<Vec<(GuildId, TTSInstance)>> {
|
||||
let list_key = Self::tts_instances_list_key();
|
||||
|
||||
match self.client.get_connection() {
|
||||
Ok(mut connection) => {
|
||||
let guild_ids: Vec<u64> = connection.smembers(&list_key).unwrap_or_default();
|
||||
let mut instances = Vec::new();
|
||||
|
||||
for guild_id in guild_ids {
|
||||
let guild_id = GuildId::new(guild_id);
|
||||
if let Ok(Some(instance)) = self.load_tts_instance(guild_id).await {
|
||||
instances.push((guild_id, instance));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(instances)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,8 @@ use serenity::{
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use crate::data::{DatabaseClientData, TTSData};
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn ready(ctx: Context, ready: Ready) {
|
||||
info!("{} is connected!", ready.user.name);
|
||||
@ -30,4 +32,66 @@ pub async fn ready(ctx: Context, ready: Ready) {
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Restore TTS instances from database
|
||||
restore_tts_instances(&ctx).await;
|
||||
}
|
||||
|
||||
/// Restore TTS instances from database and reconnect to voice channels
|
||||
async fn restore_tts_instances(ctx: &Context) {
|
||||
info!("Restoring TTS instances from database...");
|
||||
|
||||
let data = ctx.data.read().await;
|
||||
let database = data
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let tts_data = data.get::<TTSData>().unwrap().clone();
|
||||
drop(data);
|
||||
|
||||
match database.get_all_tts_instances().await {
|
||||
Ok(instances) => {
|
||||
let mut restored_count = 0;
|
||||
let mut failed_count = 0;
|
||||
|
||||
for (guild_id, instance) in instances {
|
||||
// Try to reconnect to voice channel
|
||||
match instance.reconnect(ctx).await {
|
||||
Ok(_) => {
|
||||
// Add to in-memory storage
|
||||
let mut tts_data = tts_data.write().await;
|
||||
tts_data.insert(guild_id, instance);
|
||||
drop(tts_data);
|
||||
|
||||
restored_count += 1;
|
||||
info!("Restored TTS instance for guild {}", guild_id);
|
||||
}
|
||||
Err(e) => {
|
||||
failed_count += 1;
|
||||
tracing::warn!(
|
||||
"Failed to restore TTS instance for guild {}: {}",
|
||||
guild_id,
|
||||
e
|
||||
);
|
||||
|
||||
// Remove failed instance from database
|
||||
if let Err(db_err) = database.remove_tts_instance(guild_id).await {
|
||||
tracing::error!(
|
||||
"Failed to remove invalid TTS instance from database: {}",
|
||||
db_err
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!(
|
||||
"TTS restoration complete: {} restored, {} failed",
|
||||
restored_count, failed_count
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to load TTS instances from database: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -61,15 +61,21 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
|
||||
.await
|
||||
.expect("Cannot get songbird client.")
|
||||
.clone();
|
||||
storage.insert(
|
||||
guild_id,
|
||||
TTSInstance {
|
||||
before_message: None,
|
||||
guild: guild_id,
|
||||
text_channel: new_channel,
|
||||
voice_channel: new_channel,
|
||||
},
|
||||
);
|
||||
|
||||
let instance = TTSInstance::new(new_channel, new_channel, guild_id);
|
||||
storage.insert(guild_id, instance.clone());
|
||||
|
||||
// Save to database
|
||||
let data_read = ctx.data.read().await;
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
drop(data_read);
|
||||
|
||||
if let Err(e) = database.save_tts_instance(guild_id, &instance).await {
|
||||
tracing::error!("Failed to save TTS instance to database: {}", e);
|
||||
}
|
||||
|
||||
let _handler = manager.join(guild_id, new_channel).await;
|
||||
let data = ctx.data.read().await;
|
||||
@ -140,6 +146,18 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
|
||||
.await;
|
||||
storage.remove(&guild_id);
|
||||
|
||||
// Remove from database
|
||||
let data_read = ctx.data.read().await;
|
||||
let database = data_read
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
drop(data_read);
|
||||
|
||||
if let Err(e) = database.remove_tts_instance(guild_id).await {
|
||||
tracing::error!("Failed to remove TTS instance from database: {}", e);
|
||||
}
|
||||
|
||||
let manager = songbird::get(&ctx)
|
||||
.await
|
||||
.expect("Cannot get songbird client.")
|
||||
|
@ -111,7 +111,7 @@ async fn main() {
|
||||
let mut data = client.data.write().await;
|
||||
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
|
||||
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
|
||||
data.insert::<DatabaseClientData>(Arc::new(database_client));
|
||||
data.insert::<DatabaseClientData>(Arc::new(database_client.clone()));
|
||||
}
|
||||
|
||||
info!("Bot initialized.");
|
||||
|
@ -1,5 +1,6 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serenity::{
|
||||
model::{
|
||||
channel::Message,
|
||||
@ -10,8 +11,9 @@ use serenity::{
|
||||
|
||||
use crate::tts::message::TTSMessage;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TTSInstance {
|
||||
#[serde(skip)] // Messageは複雑すぎるのでシリアライズしない
|
||||
pub before_message: Option<Message>,
|
||||
pub text_channel: ChannelId,
|
||||
pub voice_channel: ChannelId,
|
||||
@ -19,6 +21,49 @@ pub struct TTSInstance {
|
||||
}
|
||||
|
||||
impl TTSInstance {
|
||||
/// Create a new TTSInstance
|
||||
pub fn new(text_channel: ChannelId, voice_channel: ChannelId, guild: GuildId) -> Self {
|
||||
Self {
|
||||
before_message: None,
|
||||
text_channel,
|
||||
voice_channel,
|
||||
guild,
|
||||
}
|
||||
}
|
||||
|
||||
/// Reconnect to the voice channel after bot restart
|
||||
#[tracing::instrument]
|
||||
pub async fn reconnect(
|
||||
&self,
|
||||
ctx: &Context,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let manager = songbird::get(&ctx)
|
||||
.await
|
||||
.ok_or("Songbird manager not available")?;
|
||||
|
||||
// Check if we're already connected
|
||||
if manager.get(self.guild).is_some() {
|
||||
tracing::info!("Already connected to guild {}", self.guild);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Try to connect to the voice channel
|
||||
match manager.join(self.guild, self.voice_channel).await {
|
||||
Ok(_) => {
|
||||
tracing::info!(
|
||||
"Successfully reconnected to voice channel {} in guild {}",
|
||||
self.voice_channel,
|
||||
self.guild
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to reconnect to voice channel: {}", e);
|
||||
Err(Box::new(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Synthesize text to speech and send it to the voice channel.
|
||||
///
|
||||
/// Example:
|
||||
|
Reference in New Issue
Block a user