diff --git a/src/commands/config.rs b/src/commands/config.rs index 5047499..a8bb573 100644 --- a/src/commands/config.rs +++ b/src/commands/config.rs @@ -31,9 +31,8 @@ pub async fn config_command( let tts_client = data_read .get::() - .expect("Cannot get TTSClientData") - .clone(); - let voicevox_speakers = tts_client.lock().await.voicevox_client.get_styles().await; + .expect("Cannot get TTSClientData"); + let voicevox_speakers = tts_client.voicevox_client.get_styles().await; let voicevox_speaker = config.voicevox_speaker.unwrap_or(1); let tts_type = config.tts_type.unwrap_or(TTSType::GCP); diff --git a/src/commands/setup.rs b/src/commands/setup.rs index 4a9891b..7c0caea 100644 --- a/src/commands/setup.rs +++ b/src/commands/setup.rs @@ -135,14 +135,14 @@ pub async fn setup_command( let _handler = manager.join(guild.id, channel_id).await; - let tts_client = ctx + let data = ctx .data .read() - .await + .await; + let tts_client = data .get::() - .expect("Cannot get TTSClientData") - .clone(); - let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await; + .expect("Cannot get TTSClientData"); + let voicevox_speakers = tts_client.voicevox_client.get_speakers().await; text_channel_id .send_message(&ctx.http, CreateMessage::new() diff --git a/src/data.rs b/src/data.rs index 5c80a9f..f0052fe 100644 --- a/src/data.rs +++ b/src/data.rs @@ -22,7 +22,7 @@ impl TypeMapKey for TTSData { pub struct TTSClientData; impl TypeMapKey for TTSClientData { - type Value = Arc>; + type Value = Arc; } /// Database client data diff --git a/src/events/voice_state_update.rs b/src/events/voice_state_update.rs index 9b2ba1e..f22d60f 100644 --- a/src/events/voice_state_update.rs +++ b/src/events/voice_state_update.rs @@ -65,14 +65,14 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic ); let _handler = manager.join(guild_id, new_channel).await; - let tts_client = ctx + let data = ctx .data .read() - .await + .await; + let tts_client = data .get::() - .expect("Cannot get TTSClientData") - .clone(); - let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await; + .expect("Cannot get TTSClientData"); + let voicevox_speakers = tts_client.voicevox_client.get_speakers().await; new_channel .send_message(&ctx.http, diff --git a/src/implement/message.rs b/src/implement/message.rs index a93d4b6..d3dd6b9 100644 --- a/src/implement/message.rs +++ b/src/implement/message.rs @@ -82,11 +82,6 @@ impl TTSMessage for Message { let text = self.parse(instance, ctx).await; let data_read = ctx.data.read().await; - let storage = data_read - .get::() - .expect("Cannot get GCP TTSClientStorage") - .clone(); - let mut tts = storage.lock().await; let config = { let database = data_read @@ -101,6 +96,10 @@ impl TTSMessage for Message { .unwrap() }; + let tts = data_read + .get::() + .expect("Cannot get GCP TTSClientStorage"); + let audio = match config.tts_type.unwrap_or(TTSType::GCP) { TTSType::GCP => tts .synthesize_gcp(SynthesizeRequest { diff --git a/src/main.rs b/src/main.rs index 4737664..ae7d7d0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -88,7 +88,7 @@ async fn main() { { let mut data = client.data.write().await; data.insert::(Arc::new(RwLock::new(HashMap::default()))); - data.insert::(Arc::new(Mutex::new(TTS::new(voicevox, tts)))); + data.insert::(Arc::new(TTS::new(voicevox, tts))); data.insert::(Arc::new(Mutex::new(database_client))); } diff --git a/src/tts/message.rs b/src/tts/message.rs index e821473..68ff20a 100644 --- a/src/tts/message.rs +++ b/src/tts/message.rs @@ -46,13 +46,11 @@ impl TTSMessage for AnnounceMessage { async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec { let text = self.parse(instance, ctx).await; let data_read = ctx.data.read().await; - let storage = data_read + let tts = data_read .get::() - .expect("Cannot get TTSClientStorage") - .clone(); - let mut storage = storage.lock().await; + .expect("Cannot get TTSClientStorage"); - let audio = storage + let audio = tts .synthesize_gcp(SynthesizeRequest { input: SynthesisInput { text: None, diff --git a/src/tts/tts.rs b/src/tts/tts.rs index ebbd2cc..d05ff36 100644 --- a/src/tts/tts.rs +++ b/src/tts/tts.rs @@ -1,4 +1,5 @@ use std::num::NonZeroUsize; +use std::sync::RwLock; use lru::LruCache; use songbird::{driver::Bitrate, input::cached::Compressed}; @@ -9,7 +10,7 @@ use super::{gcp_tts::{gcp_tts::GCPTTS, structs::{synthesis_input::SynthesisInput pub struct TTS { pub voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS, - cache: LruCache, + cache: RwLock>, } #[derive(Hash, PartialEq, Eq)] @@ -26,17 +27,23 @@ impl TTS { Self { voicevox_client, gcp_tts_client, - cache: LruCache::new(NonZeroUsize::new(100).unwrap()), + cache: RwLock::new(LruCache::new(NonZeroUsize::new(100).unwrap())), } } - pub async fn synthesize_voicevox(&mut self, text: &str, speaker: i64) -> Result> { + pub async fn synthesize_voicevox(&self, text: &str, speaker: i64) -> Result> { let cache_key = CacheKey::Voicevox(text.to_string(), speaker); - if let Some(audio) = self.cache.get(&cache_key) { + let cached_audio = { + let mut cache_guard = self.cache.write().unwrap(); + cache_guard.get(&cache_key).map(|audio| audio.new_handle()) + }; + + if let Some(audio) = cached_audio { info!("Cache hit for VOICEVOX TTS"); - return Ok(audio.new_handle()); + return Ok(audio); } + info!("Cache miss for VOICEVOX TTS"); let audio = self.voicevox_client @@ -45,21 +52,30 @@ impl TTS { let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?; - self.cache.put(cache_key, compressed.clone()); + { + let mut cache_guard = self.cache.write().unwrap(); + cache_guard.put(cache_key, compressed.clone()); + } Ok(compressed) } - pub async fn synthesize_gcp(&mut self, synthesize_request: SynthesizeRequest) -> Result> { + pub async fn synthesize_gcp(&self, synthesize_request: SynthesizeRequest) -> Result> { let cache_key = CacheKey::GCP( synthesize_request.input.clone(), synthesize_request.voice.clone(), ); - if let Some(audio) = self.cache.get(&cache_key) { + let cached_audio = { + let mut cache_guard = self.cache.write().unwrap(); + cache_guard.get(&cache_key).map(|audio| audio.new_handle()) + }; + + if let Some(audio) = cached_audio { info!("Cache hit for GCP TTS"); - return Ok(audio.new_handle()); + return Ok(audio); } + info!("Cache miss for GCP TTS"); let audio = self.gcp_tts_client @@ -68,8 +84,11 @@ impl TTS { let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?; - self.cache.put(cache_key, compressed.clone()); + { + let mut cache_guard = self.cache.write().unwrap(); + cache_guard.put(cache_key, compressed.clone()); + } Ok(compressed) } -} +} \ No newline at end of file