optimize database lock

This commit is contained in:
mii443
2025-04-11 18:07:46 +09:00
parent f7e08b4e2e
commit 97ae9dd9e0
7 changed files with 96 additions and 106 deletions

View File

@ -24,7 +24,6 @@ pub async fn config_command(
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_user_config_or_default(command.user.id.get())
.await

View File

@ -1,6 +1,5 @@
use crate::{database::database::Database, tts::tts::TTS};
use serenity::{
futures::lock::Mutex,
model::id::GuildId,
prelude::{RwLock, TypeMapKey},
};
@ -26,5 +25,5 @@ impl TypeMapKey for TTSClientData {
pub struct DatabaseClientData;
impl TypeMapKey for DatabaseClientData {
type Value = Arc<Mutex<Database>>;
type Value = Arc<Database>;
}

View File

@ -1,3 +1,5 @@
use std::fmt::Debug;
use crate::tts::{
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType,
};
@ -15,127 +17,116 @@ impl Database {
Self { client }
}
fn server_key(server_id: u64) -> String {
format!("discord_server:{}", server_id)
}
fn user_key(user_id: u64) -> String {
format!("discord_user:{}", user_id)
}
#[tracing::instrument]
fn get_config<T: serde::de::DeserializeOwned>(
&self,
key: &str,
) -> redis::RedisResult<Option<T>> {
match self.client.get_connection() {
Ok(mut connection) => {
let config: String = connection.get(key).unwrap_or_default();
if config.is_empty() {
return Ok(None);
}
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None),
}
}
Err(e) => Err(e),
}
}
#[tracing::instrument]
fn set_config<T: serde::Serialize + Debug>(
&self,
key: &str,
config: &T,
) -> redis::RedisResult<()> {
match self.client.get_connection() {
Ok(mut connection) => {
let config_str = serde_json::to_string(config).unwrap();
connection.set::<_, _, ()>(key, config_str)
}
Err(e) => Err(e),
}
}
#[tracing::instrument]
pub async fn get_server_config(
&mut self,
&self,
server_id: u64,
) -> redis::RedisResult<Option<ServerConfig>> {
if let Ok(mut connection) = self.client.get_connection() {
let config: String = connection
.get(format!("discord_server:{}", server_id))
.unwrap_or_default();
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None),
}
} else {
Ok(None)
}
self.get_config(&Self::server_key(server_id))
}
#[tracing::instrument]
pub async fn get_user_config(
&mut self,
user_id: u64,
) -> redis::RedisResult<Option<UserConfig>> {
if let Ok(mut connection) = self.client.get_connection() {
let config: String = connection
.get(format!("discord_user:{}", user_id))
.unwrap_or_default();
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None),
}
} else {
Ok(None)
}
pub async fn get_user_config(&self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
self.get_config(&Self::user_key(user_id))
}
#[tracing::instrument]
pub async fn set_server_config(
&mut self,
&self,
server_id: u64,
config: ServerConfig,
) -> redis::RedisResult<()> {
let config = serde_json::to_string(&config).unwrap();
self.client
.get_connection()
.unwrap()
.set::<String, String, ()>(format!("discord_server:{}", server_id), config)
.unwrap();
Ok(())
self.set_config(&Self::server_key(server_id), &config)
}
#[tracing::instrument]
pub async fn set_user_config(
&mut self,
&self,
user_id: u64,
config: UserConfig,
) -> redis::RedisResult<()> {
let config = serde_json::to_string(&config).unwrap();
self.client
.get_connection()
.unwrap()
.set::<String, String, ()>(format!("discord_user:{}", user_id), config)
.unwrap();
Ok(())
self.set_config(&Self::user_key(user_id), &config)
}
#[tracing::instrument]
pub async fn set_default_server_config(&mut self, server_id: u64) -> redis::RedisResult<()> {
pub async fn set_default_server_config(&self, server_id: u64) -> redis::RedisResult<()> {
let config = ServerConfig {
dictionary: Dictionary::new(),
autostart_channel_id: None,
};
self.client
.get_connection()
.unwrap()
.set::<String, String, ()>(
format!("discord_server:{}", server_id),
serde_json::to_string(&config).unwrap(),
)?;
Ok(())
self.set_server_config(server_id, config).await
}
#[tracing::instrument]
pub async fn set_default_user_config(&mut self, user_id: u64) -> redis::RedisResult<()> {
pub async fn set_default_user_config(&self, user_id: u64) -> redis::RedisResult<()> {
let voice_selection = VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral"),
};
let voice_type = TTSType::GCP;
let config = UserConfig {
tts_type: Some(voice_type),
tts_type: Some(TTSType::GCP),
gcp_tts_voice: Some(voice_selection),
voicevox_speaker: Some(1),
};
self.client
.get_connection()
.unwrap()
.set::<String, String, ()>(
format!("discord_user:{}", user_id),
serde_json::to_string(&config).unwrap(),
)?;
Ok(())
self.set_user_config(user_id, config).await
}
#[tracing::instrument]
pub async fn get_server_config_or_default(
&mut self,
&self,
server_id: u64,
) -> redis::RedisResult<Option<ServerConfig>> {
let config = self.get_server_config(server_id).await?;
match config {
Some(_) => Ok(config),
match self.get_server_config(server_id).await? {
Some(config) => Ok(Some(config)),
None => {
self.set_default_server_config(server_id).await?;
self.get_server_config(server_id).await
@ -145,12 +136,11 @@ impl Database {
#[tracing::instrument]
pub async fn get_user_config_or_default(
&mut self,
&self,
user_id: u64,
) -> redis::RedisResult<Option<UserConfig>> {
let config = self.get_user_config(user_id).await?;
match config {
Some(_) => Ok(config),
match self.get_user_config(user_id).await? {
Some(config) => Ok(Some(config)),
None => {
self.set_default_user_config(user_id).await?;
self.get_user_config(user_id).await

View File

@ -87,7 +87,7 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(modal.guild_id.unwrap().get())
.await
@ -101,7 +101,7 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.set_server_config(modal.guild_id.unwrap().get(), config)
.await
@ -140,7 +140,7 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
@ -154,7 +154,7 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.set_server_config(message_component.guild_id.unwrap().get(), config)
.await
@ -180,7 +180,7 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
@ -233,7 +233,7 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
@ -326,7 +326,7 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
let mut config = database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
@ -357,7 +357,7 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
@ -460,7 +460,7 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_user_config_or_default(message_component.user.id.get())
.await
@ -499,7 +499,7 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.set_user_config(message_component.user.id.get(), config.clone())
.await

View File

@ -6,7 +6,11 @@ use crate::{
},
tts::{instance::TTSInstance, message::AnnounceMessage},
};
use serenity::{all::{CreateEmbed, CreateMessage, EditThread}, model::voice::VoiceState, prelude::Context};
use serenity::{
all::{CreateEmbed, CreateMessage, EditThread},
model::voice::VoiceState,
prelude::Context,
};
pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: VoiceState) {
if new.member.clone().unwrap().user.bot {
@ -37,7 +41,6 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(guild_id.get())
.await
@ -65,24 +68,26 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
);
let _handler = manager.join(guild_id, new_channel).await;
let data = ctx
.data
.read()
.await;
let data = ctx.data.read().await;
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
new_channel
.send_message(&ctx.http,
CreateMessage::new()
.embed(
.send_message(
&ctx.http,
CreateMessage::new().embed(
CreateEmbed::new()
.title("自動参加 読み上げSerenity")
.field("VOICEVOXクレジット", format!("```\n{}\n```", voicevox_speakers.join("\n")), false)
.field(
"VOICEVOXクレジット",
format!("```\n{}\n```", voicevox_speakers.join("\n")),
false,
)
.field("設定コマンド", "`/config`", false)
.field("フィードバック", "https://feedback.mii.codes/", false))
.field("フィードバック", "https://feedback.mii.codes/", false),
),
)
.await
.unwrap();

View File

@ -27,7 +27,6 @@ impl TTSMessage for Message {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(instance.guild.get())
.await
@ -98,7 +97,6 @@ impl TTSMessage for Message {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_user_config_or_default(self.author.id.get())
.await

View File

@ -20,7 +20,6 @@ use serenity::{
all::{standard::Configuration, ApplicationId},
client::Client,
framework::StandardFramework,
futures::lock::Mutex,
prelude::{GatewayIntents, RwLock},
};
use trace::init_tracing_subscriber;
@ -104,7 +103,7 @@ async fn main() {
let mut data = client.data.write().await;
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
data.insert::<DatabaseClientData>(Arc::new(database_client));
}
info!("Bot initialized.");