mirror of
https://github.com/mii443/ncb-tts-r2.git
synced 2025-08-22 16:15:29 +00:00
reduce Mutex lock
This commit is contained in:
@ -31,9 +31,8 @@ pub async fn config_command(
|
||||
|
||||
let tts_client = data_read
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientData")
|
||||
.clone();
|
||||
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_styles().await;
|
||||
.expect("Cannot get TTSClientData");
|
||||
let voicevox_speakers = tts_client.voicevox_client.get_styles().await;
|
||||
|
||||
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
|
||||
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
|
||||
|
@ -135,14 +135,14 @@ pub async fn setup_command(
|
||||
|
||||
let _handler = manager.join(guild.id, channel_id).await;
|
||||
|
||||
let tts_client = ctx
|
||||
let data = ctx
|
||||
.data
|
||||
.read()
|
||||
.await
|
||||
.await;
|
||||
let tts_client = data
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientData")
|
||||
.clone();
|
||||
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await;
|
||||
.expect("Cannot get TTSClientData");
|
||||
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
|
||||
|
||||
text_channel_id
|
||||
.send_message(&ctx.http, CreateMessage::new()
|
||||
|
@ -22,7 +22,7 @@ impl TypeMapKey for TTSData {
|
||||
pub struct TTSClientData;
|
||||
|
||||
impl TypeMapKey for TTSClientData {
|
||||
type Value = Arc<Mutex<TTS>>;
|
||||
type Value = Arc<TTS>;
|
||||
}
|
||||
|
||||
/// Database client data
|
||||
|
@ -65,14 +65,14 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
|
||||
);
|
||||
|
||||
let _handler = manager.join(guild_id, new_channel).await;
|
||||
let tts_client = ctx
|
||||
let data = ctx
|
||||
.data
|
||||
.read()
|
||||
.await
|
||||
.await;
|
||||
let tts_client = data
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientData")
|
||||
.clone();
|
||||
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await;
|
||||
.expect("Cannot get TTSClientData");
|
||||
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
|
||||
|
||||
new_channel
|
||||
.send_message(&ctx.http,
|
||||
|
@ -82,11 +82,6 @@ impl TTSMessage for Message {
|
||||
let text = self.parse(instance, ctx).await;
|
||||
|
||||
let data_read = ctx.data.read().await;
|
||||
let storage = data_read
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get GCP TTSClientStorage")
|
||||
.clone();
|
||||
let mut tts = storage.lock().await;
|
||||
|
||||
let config = {
|
||||
let database = data_read
|
||||
@ -101,6 +96,10 @@ impl TTSMessage for Message {
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let tts = data_read
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get GCP TTSClientStorage");
|
||||
|
||||
let audio = match config.tts_type.unwrap_or(TTSType::GCP) {
|
||||
TTSType::GCP => tts
|
||||
.synthesize_gcp(SynthesizeRequest {
|
||||
|
@ -88,7 +88,7 @@ async fn main() {
|
||||
{
|
||||
let mut data = client.data.write().await;
|
||||
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
|
||||
data.insert::<TTSClientData>(Arc::new(Mutex::new(TTS::new(voicevox, tts))));
|
||||
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
|
||||
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
|
||||
}
|
||||
|
||||
|
@ -46,13 +46,11 @@ impl TTSMessage for AnnounceMessage {
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Compressed> {
|
||||
let text = self.parse(instance, ctx).await;
|
||||
let data_read = ctx.data.read().await;
|
||||
let storage = data_read
|
||||
let tts = data_read
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientStorage")
|
||||
.clone();
|
||||
let mut storage = storage.lock().await;
|
||||
.expect("Cannot get TTSClientStorage");
|
||||
|
||||
let audio = storage
|
||||
let audio = tts
|
||||
.synthesize_gcp(SynthesizeRequest {
|
||||
input: SynthesisInput {
|
||||
text: None,
|
||||
|
@ -1,4 +1,5 @@
|
||||
use std::num::NonZeroUsize;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use lru::LruCache;
|
||||
use songbird::{driver::Bitrate, input::cached::Compressed};
|
||||
@ -9,7 +10,7 @@ use super::{gcp_tts::{gcp_tts::GCPTTS, structs::{synthesis_input::SynthesisInput
|
||||
pub struct TTS {
|
||||
pub voicevox_client: VOICEVOX,
|
||||
gcp_tts_client: GCPTTS,
|
||||
cache: LruCache<CacheKey, Compressed>,
|
||||
cache: RwLock<LruCache<CacheKey, Compressed>>,
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq)]
|
||||
@ -26,17 +27,23 @@ impl TTS {
|
||||
Self {
|
||||
voicevox_client,
|
||||
gcp_tts_client,
|
||||
cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
|
||||
cache: RwLock::new(LruCache::new(NonZeroUsize::new(100).unwrap())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn synthesize_voicevox(&mut self, text: &str, speaker: i64) -> Result<Compressed, Box<dyn std::error::Error>> {
|
||||
pub async fn synthesize_voicevox(&self, text: &str, speaker: i64) -> Result<Compressed, Box<dyn std::error::Error>> {
|
||||
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
|
||||
|
||||
if let Some(audio) = self.cache.get(&cache_key) {
|
||||
let cached_audio = {
|
||||
let mut cache_guard = self.cache.write().unwrap();
|
||||
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
|
||||
};
|
||||
|
||||
if let Some(audio) = cached_audio {
|
||||
info!("Cache hit for VOICEVOX TTS");
|
||||
return Ok(audio.new_handle());
|
||||
return Ok(audio);
|
||||
}
|
||||
|
||||
info!("Cache miss for VOICEVOX TTS");
|
||||
|
||||
let audio = self.voicevox_client
|
||||
@ -45,21 +52,30 @@ impl TTS {
|
||||
|
||||
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
|
||||
|
||||
self.cache.put(cache_key, compressed.clone());
|
||||
{
|
||||
let mut cache_guard = self.cache.write().unwrap();
|
||||
cache_guard.put(cache_key, compressed.clone());
|
||||
}
|
||||
|
||||
Ok(compressed)
|
||||
}
|
||||
|
||||
pub async fn synthesize_gcp(&mut self, synthesize_request: SynthesizeRequest) -> Result<Compressed, Box<dyn std::error::Error>> {
|
||||
pub async fn synthesize_gcp(&self, synthesize_request: SynthesizeRequest) -> Result<Compressed, Box<dyn std::error::Error>> {
|
||||
let cache_key = CacheKey::GCP(
|
||||
synthesize_request.input.clone(),
|
||||
synthesize_request.voice.clone(),
|
||||
);
|
||||
|
||||
if let Some(audio) = self.cache.get(&cache_key) {
|
||||
let cached_audio = {
|
||||
let mut cache_guard = self.cache.write().unwrap();
|
||||
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
|
||||
};
|
||||
|
||||
if let Some(audio) = cached_audio {
|
||||
info!("Cache hit for GCP TTS");
|
||||
return Ok(audio.new_handle());
|
||||
return Ok(audio);
|
||||
}
|
||||
|
||||
info!("Cache miss for GCP TTS");
|
||||
|
||||
let audio = self.gcp_tts_client
|
||||
@ -68,8 +84,11 @@ impl TTS {
|
||||
|
||||
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
|
||||
|
||||
self.cache.put(cache_key, compressed.clone());
|
||||
{
|
||||
let mut cache_guard = self.cache.write().unwrap();
|
||||
cache_guard.put(cache_key, compressed.clone());
|
||||
}
|
||||
|
||||
Ok(compressed)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user