reduce Mutex lock

This commit is contained in:
mii443
2025-04-04 22:14:23 +09:00
parent 55ea223f69
commit 4c176935e3
8 changed files with 51 additions and 36 deletions

View File

@ -31,9 +31,8 @@ pub async fn config_command(
let tts_client = data_read
.get::<TTSClientData>()
.expect("Cannot get TTSClientData")
.clone();
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_styles().await;
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_styles().await;
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);

View File

@ -135,14 +135,14 @@ pub async fn setup_command(
let _handler = manager.join(guild.id, channel_id).await;
let tts_client = ctx
let data = ctx
.data
.read()
.await
.await;
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData")
.clone();
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await;
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
text_channel_id
.send_message(&ctx.http, CreateMessage::new()

View File

@ -22,7 +22,7 @@ impl TypeMapKey for TTSData {
pub struct TTSClientData;
impl TypeMapKey for TTSClientData {
type Value = Arc<Mutex<TTS>>;
type Value = Arc<TTS>;
}
/// Database client data

View File

@ -65,14 +65,14 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
);
let _handler = manager.join(guild_id, new_channel).await;
let tts_client = ctx
let data = ctx
.data
.read()
.await
.await;
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData")
.clone();
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await;
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
new_channel
.send_message(&ctx.http,

View File

@ -82,11 +82,6 @@ impl TTSMessage for Message {
let text = self.parse(instance, ctx).await;
let data_read = ctx.data.read().await;
let storage = data_read
.get::<TTSClientData>()
.expect("Cannot get GCP TTSClientStorage")
.clone();
let mut tts = storage.lock().await;
let config = {
let database = data_read
@ -101,6 +96,10 @@ impl TTSMessage for Message {
.unwrap()
};
let tts = data_read
.get::<TTSClientData>()
.expect("Cannot get GCP TTSClientStorage");
let audio = match config.tts_type.unwrap_or(TTSType::GCP) {
TTSType::GCP => tts
.synthesize_gcp(SynthesizeRequest {

View File

@ -88,7 +88,7 @@ async fn main() {
{
let mut data = client.data.write().await;
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
data.insert::<TTSClientData>(Arc::new(Mutex::new(TTS::new(voicevox, tts))));
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
}

View File

@ -46,13 +46,11 @@ impl TTSMessage for AnnounceMessage {
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Compressed> {
let text = self.parse(instance, ctx).await;
let data_read = ctx.data.read().await;
let storage = data_read
let tts = data_read
.get::<TTSClientData>()
.expect("Cannot get TTSClientStorage")
.clone();
let mut storage = storage.lock().await;
.expect("Cannot get TTSClientStorage");
let audio = storage
let audio = tts
.synthesize_gcp(SynthesizeRequest {
input: SynthesisInput {
text: None,

View File

@ -1,4 +1,5 @@
use std::num::NonZeroUsize;
use std::sync::RwLock;
use lru::LruCache;
use songbird::{driver::Bitrate, input::cached::Compressed};
@ -9,7 +10,7 @@ use super::{gcp_tts::{gcp_tts::GCPTTS, structs::{synthesis_input::SynthesisInput
pub struct TTS {
pub voicevox_client: VOICEVOX,
gcp_tts_client: GCPTTS,
cache: LruCache<CacheKey, Compressed>,
cache: RwLock<LruCache<CacheKey, Compressed>>,
}
#[derive(Hash, PartialEq, Eq)]
@ -26,17 +27,23 @@ impl TTS {
Self {
voicevox_client,
gcp_tts_client,
cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
cache: RwLock::new(LruCache::new(NonZeroUsize::new(100).unwrap())),
}
}
pub async fn synthesize_voicevox(&mut self, text: &str, speaker: i64) -> Result<Compressed, Box<dyn std::error::Error>> {
pub async fn synthesize_voicevox(&self, text: &str, speaker: i64) -> Result<Compressed, Box<dyn std::error::Error>> {
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
if let Some(audio) = self.cache.get(&cache_key) {
let cached_audio = {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
};
if let Some(audio) = cached_audio {
info!("Cache hit for VOICEVOX TTS");
return Ok(audio.new_handle());
return Ok(audio);
}
info!("Cache miss for VOICEVOX TTS");
let audio = self.voicevox_client
@ -45,21 +52,30 @@ impl TTS {
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
self.cache.put(cache_key, compressed.clone());
{
let mut cache_guard = self.cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
Ok(compressed)
}
pub async fn synthesize_gcp(&mut self, synthesize_request: SynthesizeRequest) -> Result<Compressed, Box<dyn std::error::Error>> {
pub async fn synthesize_gcp(&self, synthesize_request: SynthesizeRequest) -> Result<Compressed, Box<dyn std::error::Error>> {
let cache_key = CacheKey::GCP(
synthesize_request.input.clone(),
synthesize_request.voice.clone(),
);
if let Some(audio) = self.cache.get(&cache_key) {
let cached_audio = {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
};
if let Some(audio) = cached_audio {
info!("Cache hit for GCP TTS");
return Ok(audio.new_handle());
return Ok(audio);
}
info!("Cache miss for GCP TTS");
let audio = self.gcp_tts_client
@ -68,7 +84,10 @@ impl TTS {
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
self.cache.put(cache_key, compressed.clone());
{
let mut cache_guard = self.cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
Ok(compressed)
}