From 97ae9dd9e0237c059bb84aaa85b966ad37fc7ed9 Mon Sep 17 00:00:00 2001 From: mii443 Date: Fri, 11 Apr 2025 18:07:46 +0900 Subject: [PATCH] optimize database lock --- src/commands/config.rs | 1 - src/data.rs | 3 +- src/database/database.rs | 140 ++++++++++++++----------------- src/event_handler.rs | 20 ++--- src/events/voice_state_update.rs | 33 ++++---- src/implement/message.rs | 2 - src/main.rs | 3 +- 7 files changed, 96 insertions(+), 106 deletions(-) diff --git a/src/commands/config.rs b/src/commands/config.rs index 718a05f..7c88f92 100644 --- a/src/commands/config.rs +++ b/src/commands/config.rs @@ -24,7 +24,6 @@ pub async fn config_command( .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; database .get_user_config_or_default(command.user.id.get()) .await diff --git a/src/data.rs b/src/data.rs index e7b4e8d..2409307 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,6 +1,5 @@ use crate::{database::database::Database, tts::tts::TTS}; use serenity::{ - futures::lock::Mutex, model::id::GuildId, prelude::{RwLock, TypeMapKey}, }; @@ -26,5 +25,5 @@ impl TypeMapKey for TTSClientData { pub struct DatabaseClientData; impl TypeMapKey for DatabaseClientData { - type Value = Arc>; + type Value = Arc; } diff --git a/src/database/database.rs b/src/database/database.rs index 56bea33..c3bbba9 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use crate::tts::{ gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType, }; @@ -15,127 +17,116 @@ impl Database { Self { client } } - #[tracing::instrument] - pub async fn get_server_config( - &mut self, - server_id: u64, - ) -> redis::RedisResult> { - if let Ok(mut connection) = self.client.get_connection() { - let config: String = connection - .get(format!("discord_server:{}", server_id)) - .unwrap_or_default(); + fn server_key(server_id: u64) -> String { + format!("discord_server:{}", server_id) + } - match serde_json::from_str(&config) { - Ok(config) => Ok(Some(config)), - Err(_) => Ok(None), + fn user_key(user_id: u64) -> String { + format!("discord_user:{}", user_id) + } + + #[tracing::instrument] + fn get_config( + &self, + key: &str, + ) -> redis::RedisResult> { + match self.client.get_connection() { + Ok(mut connection) => { + let config: String = connection.get(key).unwrap_or_default(); + + if config.is_empty() { + return Ok(None); + } + + match serde_json::from_str(&config) { + Ok(config) => Ok(Some(config)), + Err(_) => Ok(None), + } } - } else { - Ok(None) + Err(e) => Err(e), } } #[tracing::instrument] - pub async fn get_user_config( - &mut self, - user_id: u64, - ) -> redis::RedisResult> { - if let Ok(mut connection) = self.client.get_connection() { - let config: String = connection - .get(format!("discord_user:{}", user_id)) - .unwrap_or_default(); - - match serde_json::from_str(&config) { - Ok(config) => Ok(Some(config)), - Err(_) => Ok(None), + fn set_config( + &self, + key: &str, + config: &T, + ) -> redis::RedisResult<()> { + match self.client.get_connection() { + Ok(mut connection) => { + let config_str = serde_json::to_string(config).unwrap(); + connection.set::<_, _, ()>(key, config_str) } - } else { - Ok(None) + Err(e) => Err(e), } } + #[tracing::instrument] + pub async fn get_server_config( + &self, + server_id: u64, + ) -> redis::RedisResult> { + self.get_config(&Self::server_key(server_id)) + } + + #[tracing::instrument] + pub async fn get_user_config(&self, user_id: u64) -> redis::RedisResult> { + self.get_config(&Self::user_key(user_id)) + } + #[tracing::instrument] pub async fn set_server_config( - &mut self, + &self, server_id: u64, config: ServerConfig, ) -> redis::RedisResult<()> { - let config = serde_json::to_string(&config).unwrap(); - self.client - .get_connection() - .unwrap() - .set::(format!("discord_server:{}", server_id), config) - .unwrap(); - Ok(()) + self.set_config(&Self::server_key(server_id), &config) } #[tracing::instrument] pub async fn set_user_config( - &mut self, + &self, user_id: u64, config: UserConfig, ) -> redis::RedisResult<()> { - let config = serde_json::to_string(&config).unwrap(); - self.client - .get_connection() - .unwrap() - .set::(format!("discord_user:{}", user_id), config) - .unwrap(); - Ok(()) + self.set_config(&Self::user_key(user_id), &config) } #[tracing::instrument] - pub async fn set_default_server_config(&mut self, server_id: u64) -> redis::RedisResult<()> { + pub async fn set_default_server_config(&self, server_id: u64) -> redis::RedisResult<()> { let config = ServerConfig { dictionary: Dictionary::new(), autostart_channel_id: None, }; - self.client - .get_connection() - .unwrap() - .set::( - format!("discord_server:{}", server_id), - serde_json::to_string(&config).unwrap(), - )?; - - Ok(()) + self.set_server_config(server_id, config).await } #[tracing::instrument] - pub async fn set_default_user_config(&mut self, user_id: u64) -> redis::RedisResult<()> { + pub async fn set_default_user_config(&self, user_id: u64) -> redis::RedisResult<()> { let voice_selection = VoiceSelectionParams { languageCode: String::from("ja-JP"), name: String::from("ja-JP-Wavenet-B"), ssmlGender: String::from("neutral"), }; - let voice_type = TTSType::GCP; - let config = UserConfig { - tts_type: Some(voice_type), + tts_type: Some(TTSType::GCP), gcp_tts_voice: Some(voice_selection), voicevox_speaker: Some(1), }; - self.client - .get_connection() - .unwrap() - .set::( - format!("discord_user:{}", user_id), - serde_json::to_string(&config).unwrap(), - )?; - - Ok(()) + self.set_user_config(user_id, config).await } #[tracing::instrument] pub async fn get_server_config_or_default( - &mut self, + &self, server_id: u64, ) -> redis::RedisResult> { - let config = self.get_server_config(server_id).await?; - match config { - Some(_) => Ok(config), + match self.get_server_config(server_id).await? { + Some(config) => Ok(Some(config)), None => { self.set_default_server_config(server_id).await?; self.get_server_config(server_id).await @@ -145,12 +136,11 @@ impl Database { #[tracing::instrument] pub async fn get_user_config_or_default( - &mut self, + &self, user_id: u64, ) -> redis::RedisResult> { - let config = self.get_user_config(user_id).await?; - match config { - Some(_) => Ok(config), + match self.get_user_config(user_id).await? { + Some(config) => Ok(Some(config)), None => { self.set_default_user_config(user_id).await?; self.get_user_config(user_id).await diff --git a/src/event_handler.rs b/src/event_handler.rs index 251c30f..e986ca2 100644 --- a/src/event_handler.rs +++ b/src/event_handler.rs @@ -87,7 +87,7 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database .get_server_config_or_default(modal.guild_id.unwrap().get()) .await @@ -101,7 +101,7 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database .set_server_config(modal.guild_id.unwrap().get(), config) .await @@ -140,7 +140,7 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database .get_server_config_or_default(message_component.guild_id.unwrap().get()) .await @@ -154,7 +154,7 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database .set_server_config(message_component.guild_id.unwrap().get(), config) .await @@ -180,7 +180,7 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database .get_server_config_or_default(message_component.guild_id.unwrap().get()) .await @@ -233,7 +233,7 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database .get_server_config_or_default(message_component.guild_id.unwrap().get()) .await @@ -326,7 +326,7 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + let mut config = database .get_server_config_or_default(message_component.guild_id.unwrap().get()) .await @@ -357,7 +357,7 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database .get_server_config_or_default(message_component.guild_id.unwrap().get()) .await @@ -460,7 +460,7 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database .get_user_config_or_default(message_component.user.id.get()) .await @@ -499,7 +499,7 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database .set_user_config(message_component.user.id.get(), config.clone()) .await diff --git a/src/events/voice_state_update.rs b/src/events/voice_state_update.rs index f22d60f..b25d716 100644 --- a/src/events/voice_state_update.rs +++ b/src/events/voice_state_update.rs @@ -6,7 +6,11 @@ use crate::{ }, tts::{instance::TTSInstance, message::AnnounceMessage}, }; -use serenity::{all::{CreateEmbed, CreateMessage, EditThread}, model::voice::VoiceState, prelude::Context}; +use serenity::{ + all::{CreateEmbed, CreateMessage, EditThread}, + model::voice::VoiceState, + prelude::Context, +}; pub async fn voice_state_update(ctx: Context, old: Option, new: VoiceState) { if new.member.clone().unwrap().user.bot { @@ -37,7 +41,6 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; database .get_server_config_or_default(guild_id.get()) .await @@ -65,24 +68,26 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic ); let _handler = manager.join(guild_id, new_channel).await; - let data = ctx - .data - .read() - .await; + let data = ctx.data.read().await; let tts_client = data .get::() .expect("Cannot get TTSClientData"); let voicevox_speakers = tts_client.voicevox_client.get_speakers().await; new_channel - .send_message(&ctx.http, - CreateMessage::new() - .embed( - CreateEmbed::new() - .title("自動参加 読み上げ(Serenity)") - .field("VOICEVOXクレジット", format!("```\n{}\n```", voicevox_speakers.join("\n")), false) - .field("設定コマンド", "`/config`", false) - .field("フィードバック", "https://feedback.mii.codes/", false)) + .send_message( + &ctx.http, + CreateMessage::new().embed( + CreateEmbed::new() + .title("自動参加 読み上げ(Serenity)") + .field( + "VOICEVOXクレジット", + format!("```\n{}\n```", voicevox_speakers.join("\n")), + false, + ) + .field("設定コマンド", "`/config`", false) + .field("フィードバック", "https://feedback.mii.codes/", false), + ), ) .await .unwrap(); diff --git a/src/implement/message.rs b/src/implement/message.rs index afc3e3e..6e93d1f 100644 --- a/src/implement/message.rs +++ b/src/implement/message.rs @@ -27,7 +27,6 @@ impl TTSMessage for Message { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; database .get_server_config_or_default(instance.guild.get()) .await @@ -98,7 +97,6 @@ impl TTSMessage for Message { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; database .get_user_config_or_default(self.author.id.get()) .await diff --git a/src/main.rs b/src/main.rs index 3cc77f8..525823c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,7 +20,6 @@ use serenity::{ all::{standard::Configuration, ApplicationId}, client::Client, framework::StandardFramework, - futures::lock::Mutex, prelude::{GatewayIntents, RwLock}, }; use trace::init_tracing_subscriber; @@ -104,7 +103,7 @@ async fn main() { let mut data = client.data.write().await; data.insert::(Arc::new(RwLock::new(HashMap::default()))); data.insert::(Arc::new(TTS::new(voicevox, tts))); - data.insert::(Arc::new(Mutex::new(database_client))); + data.insert::(Arc::new(database_client)); } info!("Bot initialized.");