mirror of
https://github.com/mii443/ncb-tts-r2.git
synced 2025-08-22 16:15:29 +00:00
optimize database lock
This commit is contained in:
@ -24,7 +24,6 @@ pub async fn config_command(
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
database
|
||||
.get_user_config_or_default(command.user.id.get())
|
||||
.await
|
||||
|
@ -1,6 +1,5 @@
|
||||
use crate::{database::database::Database, tts::tts::TTS};
|
||||
use serenity::{
|
||||
futures::lock::Mutex,
|
||||
model::id::GuildId,
|
||||
prelude::{RwLock, TypeMapKey},
|
||||
};
|
||||
@ -26,5 +25,5 @@ impl TypeMapKey for TTSClientData {
|
||||
pub struct DatabaseClientData;
|
||||
|
||||
impl TypeMapKey for DatabaseClientData {
|
||||
type Value = Arc<Mutex<Database>>;
|
||||
type Value = Arc<Database>;
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::tts::{
|
||||
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType,
|
||||
};
|
||||
@ -15,127 +17,116 @@ impl Database {
|
||||
Self { client }
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_server_config(
|
||||
&mut self,
|
||||
server_id: u64,
|
||||
) -> redis::RedisResult<Option<ServerConfig>> {
|
||||
if let Ok(mut connection) = self.client.get_connection() {
|
||||
let config: String = connection
|
||||
.get(format!("discord_server:{}", server_id))
|
||||
.unwrap_or_default();
|
||||
fn server_key(server_id: u64) -> String {
|
||||
format!("discord_server:{}", server_id)
|
||||
}
|
||||
|
||||
match serde_json::from_str(&config) {
|
||||
Ok(config) => Ok(Some(config)),
|
||||
Err(_) => Ok(None),
|
||||
fn user_key(user_id: u64) -> String {
|
||||
format!("discord_user:{}", user_id)
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
fn get_config<T: serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
key: &str,
|
||||
) -> redis::RedisResult<Option<T>> {
|
||||
match self.client.get_connection() {
|
||||
Ok(mut connection) => {
|
||||
let config: String = connection.get(key).unwrap_or_default();
|
||||
|
||||
if config.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
match serde_json::from_str(&config) {
|
||||
Ok(config) => Ok(Some(config)),
|
||||
Err(_) => Ok(None),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_user_config(
|
||||
&mut self,
|
||||
user_id: u64,
|
||||
) -> redis::RedisResult<Option<UserConfig>> {
|
||||
if let Ok(mut connection) = self.client.get_connection() {
|
||||
let config: String = connection
|
||||
.get(format!("discord_user:{}", user_id))
|
||||
.unwrap_or_default();
|
||||
|
||||
match serde_json::from_str(&config) {
|
||||
Ok(config) => Ok(Some(config)),
|
||||
Err(_) => Ok(None),
|
||||
fn set_config<T: serde::Serialize + Debug>(
|
||||
&self,
|
||||
key: &str,
|
||||
config: &T,
|
||||
) -> redis::RedisResult<()> {
|
||||
match self.client.get_connection() {
|
||||
Ok(mut connection) => {
|
||||
let config_str = serde_json::to_string(config).unwrap();
|
||||
connection.set::<_, _, ()>(key, config_str)
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_server_config(
|
||||
&self,
|
||||
server_id: u64,
|
||||
) -> redis::RedisResult<Option<ServerConfig>> {
|
||||
self.get_config(&Self::server_key(server_id))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_user_config(&self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
|
||||
self.get_config(&Self::user_key(user_id))
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn set_server_config(
|
||||
&mut self,
|
||||
&self,
|
||||
server_id: u64,
|
||||
config: ServerConfig,
|
||||
) -> redis::RedisResult<()> {
|
||||
let config = serde_json::to_string(&config).unwrap();
|
||||
self.client
|
||||
.get_connection()
|
||||
.unwrap()
|
||||
.set::<String, String, ()>(format!("discord_server:{}", server_id), config)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
self.set_config(&Self::server_key(server_id), &config)
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn set_user_config(
|
||||
&mut self,
|
||||
&self,
|
||||
user_id: u64,
|
||||
config: UserConfig,
|
||||
) -> redis::RedisResult<()> {
|
||||
let config = serde_json::to_string(&config).unwrap();
|
||||
self.client
|
||||
.get_connection()
|
||||
.unwrap()
|
||||
.set::<String, String, ()>(format!("discord_user:{}", user_id), config)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
self.set_config(&Self::user_key(user_id), &config)
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn set_default_server_config(&mut self, server_id: u64) -> redis::RedisResult<()> {
|
||||
pub async fn set_default_server_config(&self, server_id: u64) -> redis::RedisResult<()> {
|
||||
let config = ServerConfig {
|
||||
dictionary: Dictionary::new(),
|
||||
autostart_channel_id: None,
|
||||
};
|
||||
|
||||
self.client
|
||||
.get_connection()
|
||||
.unwrap()
|
||||
.set::<String, String, ()>(
|
||||
format!("discord_server:{}", server_id),
|
||||
serde_json::to_string(&config).unwrap(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
self.set_server_config(server_id, config).await
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn set_default_user_config(&mut self, user_id: u64) -> redis::RedisResult<()> {
|
||||
pub async fn set_default_user_config(&self, user_id: u64) -> redis::RedisResult<()> {
|
||||
let voice_selection = VoiceSelectionParams {
|
||||
languageCode: String::from("ja-JP"),
|
||||
name: String::from("ja-JP-Wavenet-B"),
|
||||
ssmlGender: String::from("neutral"),
|
||||
};
|
||||
|
||||
let voice_type = TTSType::GCP;
|
||||
|
||||
let config = UserConfig {
|
||||
tts_type: Some(voice_type),
|
||||
tts_type: Some(TTSType::GCP),
|
||||
gcp_tts_voice: Some(voice_selection),
|
||||
voicevox_speaker: Some(1),
|
||||
};
|
||||
|
||||
self.client
|
||||
.get_connection()
|
||||
.unwrap()
|
||||
.set::<String, String, ()>(
|
||||
format!("discord_user:{}", user_id),
|
||||
serde_json::to_string(&config).unwrap(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
self.set_user_config(user_id, config).await
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_server_config_or_default(
|
||||
&mut self,
|
||||
&self,
|
||||
server_id: u64,
|
||||
) -> redis::RedisResult<Option<ServerConfig>> {
|
||||
let config = self.get_server_config(server_id).await?;
|
||||
match config {
|
||||
Some(_) => Ok(config),
|
||||
match self.get_server_config(server_id).await? {
|
||||
Some(config) => Ok(Some(config)),
|
||||
None => {
|
||||
self.set_default_server_config(server_id).await?;
|
||||
self.get_server_config(server_id).await
|
||||
@ -145,12 +136,11 @@ impl Database {
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn get_user_config_or_default(
|
||||
&mut self,
|
||||
&self,
|
||||
user_id: u64,
|
||||
) -> redis::RedisResult<Option<UserConfig>> {
|
||||
let config = self.get_user_config(user_id).await?;
|
||||
match config {
|
||||
Some(_) => Ok(config),
|
||||
match self.get_user_config(user_id).await? {
|
||||
Some(config) => Ok(Some(config)),
|
||||
None => {
|
||||
self.set_default_user_config(user_id).await?;
|
||||
self.get_user_config(user_id).await
|
||||
|
@ -87,7 +87,7 @@ impl EventHandler for Handler {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
|
||||
database
|
||||
.get_server_config_or_default(modal.guild_id.unwrap().get())
|
||||
.await
|
||||
@ -101,7 +101,7 @@ impl EventHandler for Handler {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
|
||||
database
|
||||
.set_server_config(modal.guild_id.unwrap().get(), config)
|
||||
.await
|
||||
@ -140,7 +140,7 @@ impl EventHandler for Handler {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
|
||||
database
|
||||
.get_server_config_or_default(message_component.guild_id.unwrap().get())
|
||||
.await
|
||||
@ -154,7 +154,7 @@ impl EventHandler for Handler {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
|
||||
database
|
||||
.set_server_config(message_component.guild_id.unwrap().get(), config)
|
||||
.await
|
||||
@ -180,7 +180,7 @@ impl EventHandler for Handler {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
|
||||
database
|
||||
.get_server_config_or_default(message_component.guild_id.unwrap().get())
|
||||
.await
|
||||
@ -233,7 +233,7 @@ impl EventHandler for Handler {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
|
||||
database
|
||||
.get_server_config_or_default(message_component.guild_id.unwrap().get())
|
||||
.await
|
||||
@ -326,7 +326,7 @@ impl EventHandler for Handler {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
|
||||
let mut config = database
|
||||
.get_server_config_or_default(message_component.guild_id.unwrap().get())
|
||||
.await
|
||||
@ -357,7 +357,7 @@ impl EventHandler for Handler {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
|
||||
database
|
||||
.get_server_config_or_default(message_component.guild_id.unwrap().get())
|
||||
.await
|
||||
@ -460,7 +460,7 @@ impl EventHandler for Handler {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
|
||||
database
|
||||
.get_user_config_or_default(message_component.user.id.get())
|
||||
.await
|
||||
@ -499,7 +499,7 @@ impl EventHandler for Handler {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
|
||||
database
|
||||
.set_user_config(message_component.user.id.get(), config.clone())
|
||||
.await
|
||||
|
@ -6,7 +6,11 @@ use crate::{
|
||||
},
|
||||
tts::{instance::TTSInstance, message::AnnounceMessage},
|
||||
};
|
||||
use serenity::{all::{CreateEmbed, CreateMessage, EditThread}, model::voice::VoiceState, prelude::Context};
|
||||
use serenity::{
|
||||
all::{CreateEmbed, CreateMessage, EditThread},
|
||||
model::voice::VoiceState,
|
||||
prelude::Context,
|
||||
};
|
||||
|
||||
pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: VoiceState) {
|
||||
if new.member.clone().unwrap().user.bot {
|
||||
@ -37,7 +41,6 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
database
|
||||
.get_server_config_or_default(guild_id.get())
|
||||
.await
|
||||
@ -65,24 +68,26 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
|
||||
);
|
||||
|
||||
let _handler = manager.join(guild_id, new_channel).await;
|
||||
let data = ctx
|
||||
.data
|
||||
.read()
|
||||
.await;
|
||||
let data = ctx.data.read().await;
|
||||
let tts_client = data
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientData");
|
||||
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
|
||||
|
||||
new_channel
|
||||
.send_message(&ctx.http,
|
||||
CreateMessage::new()
|
||||
.embed(
|
||||
CreateEmbed::new()
|
||||
.title("自動参加 読み上げ(Serenity)")
|
||||
.field("VOICEVOXクレジット", format!("```\n{}\n```", voicevox_speakers.join("\n")), false)
|
||||
.field("設定コマンド", "`/config`", false)
|
||||
.field("フィードバック", "https://feedback.mii.codes/", false))
|
||||
.send_message(
|
||||
&ctx.http,
|
||||
CreateMessage::new().embed(
|
||||
CreateEmbed::new()
|
||||
.title("自動参加 読み上げ(Serenity)")
|
||||
.field(
|
||||
"VOICEVOXクレジット",
|
||||
format!("```\n{}\n```", voicevox_speakers.join("\n")),
|
||||
false,
|
||||
)
|
||||
.field("設定コマンド", "`/config`", false)
|
||||
.field("フィードバック", "https://feedback.mii.codes/", false),
|
||||
),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -27,7 +27,6 @@ impl TTSMessage for Message {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
database
|
||||
.get_server_config_or_default(instance.guild.get())
|
||||
.await
|
||||
@ -98,7 +97,6 @@ impl TTSMessage for Message {
|
||||
.get::<DatabaseClientData>()
|
||||
.expect("Cannot get DatabaseClientData")
|
||||
.clone();
|
||||
let mut database = database.lock().await;
|
||||
database
|
||||
.get_user_config_or_default(self.author.id.get())
|
||||
.await
|
||||
|
@ -20,7 +20,6 @@ use serenity::{
|
||||
all::{standard::Configuration, ApplicationId},
|
||||
client::Client,
|
||||
framework::StandardFramework,
|
||||
futures::lock::Mutex,
|
||||
prelude::{GatewayIntents, RwLock},
|
||||
};
|
||||
use trace::init_tracing_subscriber;
|
||||
@ -104,7 +103,7 @@ async fn main() {
|
||||
let mut data = client.data.write().await;
|
||||
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
|
||||
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
|
||||
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
|
||||
data.insert::<DatabaseClientData>(Arc::new(database_client));
|
||||
}
|
||||
|
||||
info!("Bot initialized.");
|
||||
|
Reference in New Issue
Block a user