mirror of
https://github.com/mii443/ncb-tts-r2.git
synced 2025-08-22 16:15:29 +00:00
reduce Mutex lock
This commit is contained in:
@ -31,9 +31,8 @@ pub async fn config_command(
|
|||||||
|
|
||||||
let tts_client = data_read
|
let tts_client = data_read
|
||||||
.get::<TTSClientData>()
|
.get::<TTSClientData>()
|
||||||
.expect("Cannot get TTSClientData")
|
.expect("Cannot get TTSClientData");
|
||||||
.clone();
|
let voicevox_speakers = tts_client.voicevox_client.get_styles().await;
|
||||||
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_styles().await;
|
|
||||||
|
|
||||||
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
|
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
|
||||||
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
|
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
|
||||||
|
@ -135,14 +135,14 @@ pub async fn setup_command(
|
|||||||
|
|
||||||
let _handler = manager.join(guild.id, channel_id).await;
|
let _handler = manager.join(guild.id, channel_id).await;
|
||||||
|
|
||||||
let tts_client = ctx
|
let data = ctx
|
||||||
.data
|
.data
|
||||||
.read()
|
.read()
|
||||||
.await
|
.await;
|
||||||
|
let tts_client = data
|
||||||
.get::<TTSClientData>()
|
.get::<TTSClientData>()
|
||||||
.expect("Cannot get TTSClientData")
|
.expect("Cannot get TTSClientData");
|
||||||
.clone();
|
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
|
||||||
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await;
|
|
||||||
|
|
||||||
text_channel_id
|
text_channel_id
|
||||||
.send_message(&ctx.http, CreateMessage::new()
|
.send_message(&ctx.http, CreateMessage::new()
|
||||||
|
@ -22,7 +22,7 @@ impl TypeMapKey for TTSData {
|
|||||||
pub struct TTSClientData;
|
pub struct TTSClientData;
|
||||||
|
|
||||||
impl TypeMapKey for TTSClientData {
|
impl TypeMapKey for TTSClientData {
|
||||||
type Value = Arc<Mutex<TTS>>;
|
type Value = Arc<TTS>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Database client data
|
/// Database client data
|
||||||
|
@ -65,14 +65,14 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
|
|||||||
);
|
);
|
||||||
|
|
||||||
let _handler = manager.join(guild_id, new_channel).await;
|
let _handler = manager.join(guild_id, new_channel).await;
|
||||||
let tts_client = ctx
|
let data = ctx
|
||||||
.data
|
.data
|
||||||
.read()
|
.read()
|
||||||
.await
|
.await;
|
||||||
|
let tts_client = data
|
||||||
.get::<TTSClientData>()
|
.get::<TTSClientData>()
|
||||||
.expect("Cannot get TTSClientData")
|
.expect("Cannot get TTSClientData");
|
||||||
.clone();
|
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
|
||||||
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await;
|
|
||||||
|
|
||||||
new_channel
|
new_channel
|
||||||
.send_message(&ctx.http,
|
.send_message(&ctx.http,
|
||||||
|
@ -82,11 +82,6 @@ impl TTSMessage for Message {
|
|||||||
let text = self.parse(instance, ctx).await;
|
let text = self.parse(instance, ctx).await;
|
||||||
|
|
||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
let storage = data_read
|
|
||||||
.get::<TTSClientData>()
|
|
||||||
.expect("Cannot get GCP TTSClientStorage")
|
|
||||||
.clone();
|
|
||||||
let mut tts = storage.lock().await;
|
|
||||||
|
|
||||||
let config = {
|
let config = {
|
||||||
let database = data_read
|
let database = data_read
|
||||||
@ -101,6 +96,10 @@ impl TTSMessage for Message {
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let tts = data_read
|
||||||
|
.get::<TTSClientData>()
|
||||||
|
.expect("Cannot get GCP TTSClientStorage");
|
||||||
|
|
||||||
let audio = match config.tts_type.unwrap_or(TTSType::GCP) {
|
let audio = match config.tts_type.unwrap_or(TTSType::GCP) {
|
||||||
TTSType::GCP => tts
|
TTSType::GCP => tts
|
||||||
.synthesize_gcp(SynthesizeRequest {
|
.synthesize_gcp(SynthesizeRequest {
|
||||||
|
@ -88,7 +88,7 @@ async fn main() {
|
|||||||
{
|
{
|
||||||
let mut data = client.data.write().await;
|
let mut data = client.data.write().await;
|
||||||
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
|
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
|
||||||
data.insert::<TTSClientData>(Arc::new(Mutex::new(TTS::new(voicevox, tts))));
|
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
|
||||||
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
|
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -46,13 +46,11 @@ impl TTSMessage for AnnounceMessage {
|
|||||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Compressed> {
|
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Compressed> {
|
||||||
let text = self.parse(instance, ctx).await;
|
let text = self.parse(instance, ctx).await;
|
||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
let storage = data_read
|
let tts = data_read
|
||||||
.get::<TTSClientData>()
|
.get::<TTSClientData>()
|
||||||
.expect("Cannot get TTSClientStorage")
|
.expect("Cannot get TTSClientStorage");
|
||||||
.clone();
|
|
||||||
let mut storage = storage.lock().await;
|
|
||||||
|
|
||||||
let audio = storage
|
let audio = tts
|
||||||
.synthesize_gcp(SynthesizeRequest {
|
.synthesize_gcp(SynthesizeRequest {
|
||||||
input: SynthesisInput {
|
input: SynthesisInput {
|
||||||
text: None,
|
text: None,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
use std::num::NonZeroUsize;
|
use std::num::NonZeroUsize;
|
||||||
|
use std::sync::RwLock;
|
||||||
|
|
||||||
use lru::LruCache;
|
use lru::LruCache;
|
||||||
use songbird::{driver::Bitrate, input::cached::Compressed};
|
use songbird::{driver::Bitrate, input::cached::Compressed};
|
||||||
@ -9,7 +10,7 @@ use super::{gcp_tts::{gcp_tts::GCPTTS, structs::{synthesis_input::SynthesisInput
|
|||||||
pub struct TTS {
|
pub struct TTS {
|
||||||
pub voicevox_client: VOICEVOX,
|
pub voicevox_client: VOICEVOX,
|
||||||
gcp_tts_client: GCPTTS,
|
gcp_tts_client: GCPTTS,
|
||||||
cache: LruCache<CacheKey, Compressed>,
|
cache: RwLock<LruCache<CacheKey, Compressed>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Hash, PartialEq, Eq)]
|
#[derive(Hash, PartialEq, Eq)]
|
||||||
@ -26,17 +27,23 @@ impl TTS {
|
|||||||
Self {
|
Self {
|
||||||
voicevox_client,
|
voicevox_client,
|
||||||
gcp_tts_client,
|
gcp_tts_client,
|
||||||
cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
|
cache: RwLock::new(LruCache::new(NonZeroUsize::new(100).unwrap())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn synthesize_voicevox(&mut self, text: &str, speaker: i64) -> Result<Compressed, Box<dyn std::error::Error>> {
|
pub async fn synthesize_voicevox(&self, text: &str, speaker: i64) -> Result<Compressed, Box<dyn std::error::Error>> {
|
||||||
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
|
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
|
||||||
|
|
||||||
if let Some(audio) = self.cache.get(&cache_key) {
|
let cached_audio = {
|
||||||
|
let mut cache_guard = self.cache.write().unwrap();
|
||||||
|
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(audio) = cached_audio {
|
||||||
info!("Cache hit for VOICEVOX TTS");
|
info!("Cache hit for VOICEVOX TTS");
|
||||||
return Ok(audio.new_handle());
|
return Ok(audio);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Cache miss for VOICEVOX TTS");
|
info!("Cache miss for VOICEVOX TTS");
|
||||||
|
|
||||||
let audio = self.voicevox_client
|
let audio = self.voicevox_client
|
||||||
@ -45,21 +52,30 @@ impl TTS {
|
|||||||
|
|
||||||
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
|
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
|
||||||
|
|
||||||
self.cache.put(cache_key, compressed.clone());
|
{
|
||||||
|
let mut cache_guard = self.cache.write().unwrap();
|
||||||
|
cache_guard.put(cache_key, compressed.clone());
|
||||||
|
}
|
||||||
|
|
||||||
Ok(compressed)
|
Ok(compressed)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn synthesize_gcp(&mut self, synthesize_request: SynthesizeRequest) -> Result<Compressed, Box<dyn std::error::Error>> {
|
pub async fn synthesize_gcp(&self, synthesize_request: SynthesizeRequest) -> Result<Compressed, Box<dyn std::error::Error>> {
|
||||||
let cache_key = CacheKey::GCP(
|
let cache_key = CacheKey::GCP(
|
||||||
synthesize_request.input.clone(),
|
synthesize_request.input.clone(),
|
||||||
synthesize_request.voice.clone(),
|
synthesize_request.voice.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some(audio) = self.cache.get(&cache_key) {
|
let cached_audio = {
|
||||||
|
let mut cache_guard = self.cache.write().unwrap();
|
||||||
|
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(audio) = cached_audio {
|
||||||
info!("Cache hit for GCP TTS");
|
info!("Cache hit for GCP TTS");
|
||||||
return Ok(audio.new_handle());
|
return Ok(audio);
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Cache miss for GCP TTS");
|
info!("Cache miss for GCP TTS");
|
||||||
|
|
||||||
let audio = self.gcp_tts_client
|
let audio = self.gcp_tts_client
|
||||||
@ -68,8 +84,11 @@ impl TTS {
|
|||||||
|
|
||||||
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
|
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
|
||||||
|
|
||||||
self.cache.put(cache_key, compressed.clone());
|
{
|
||||||
|
let mut cache_guard = self.cache.write().unwrap();
|
||||||
|
cache_guard.put(cache_key, compressed.clone());
|
||||||
|
}
|
||||||
|
|
||||||
Ok(compressed)
|
Ok(compressed)
|
||||||
}
|
}
|
||||||
}
|
}
|
Reference in New Issue
Block a user