diff --git a/Cargo.toml b/Cargo.toml index 7f0579b..16e75b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,8 @@ async-trait = "0.1.57" redis = "*" regex = "1" tracing-subscriber = "0.3.19" +lru = "0.13.0" +tracing = "0.1.41" [dependencies.uuid] version = "0.8" diff --git a/src/commands/config.rs b/src/commands/config.rs index 351368f..5047499 100644 --- a/src/commands/config.rs +++ b/src/commands/config.rs @@ -33,7 +33,7 @@ pub async fn config_command( .get::() .expect("Cannot get TTSClientData") .clone(); - let voicevox_speakers = tts_client.lock().await.1.get_styles().await; + let voicevox_speakers = tts_client.lock().await.voicevox_client.get_styles().await; let voicevox_speaker = config.voicevox_speaker.unwrap_or(1); let tts_type = config.tts_type.unwrap_or(TTSType::GCP); diff --git a/src/commands/setup.rs b/src/commands/setup.rs index 17e8168..4a9891b 100644 --- a/src/commands/setup.rs +++ b/src/commands/setup.rs @@ -142,7 +142,7 @@ pub async fn setup_command( .get::() .expect("Cannot get TTSClientData") .clone(); - let voicevox_speakers = tts_client.lock().await.1.get_speakers().await; + let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await; text_channel_id .send_message(&ctx.http, CreateMessage::new() diff --git a/src/data.rs b/src/data.rs index 7ecbc86..b51a628 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,6 +1,6 @@ use crate::{ database::database::Database, - tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX}, + tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX}, }; use serenity::{ futures::lock::Mutex, @@ -22,7 +22,7 @@ impl TypeMapKey for TTSData { pub struct TTSClientData; impl TypeMapKey for TTSClientData { - type Value = Arc>; + type Value = Arc>; } /// Database client data diff --git a/src/events/voice_state_update.rs b/src/events/voice_state_update.rs index e6a98bd..9b2ba1e 100644 --- a/src/events/voice_state_update.rs +++ b/src/events/voice_state_update.rs @@ -72,7 +72,7 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic .get::() .expect("Cannot get TTSClientData") .clone(); - let voicevox_speakers = tts_client.lock().await.1.get_speakers().await; + let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await; new_channel .send_message(&ctx.http, diff --git a/src/implement/message.rs b/src/implement/message.rs index a3b01dc..39d54a9 100644 --- a/src/implement/message.rs +++ b/src/implement/message.rs @@ -1,6 +1,7 @@ use async_trait::async_trait; use regex::Regex; use serenity::{model::prelude::Message, prelude::Context}; +use songbird::input::cached::Compressed; use crate::{ data::{DatabaseClientData, TTSClientData}, implement::member_name::ReadName, tts::{ @@ -77,7 +78,7 @@ impl TTSMessage for Message { res } - async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec { + async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Compressed { let text = self.parse(instance, ctx).await; let data_read = ctx.data.read().await; @@ -102,8 +103,7 @@ impl TTSMessage for Message { let audio = match config.tts_type.unwrap_or(TTSType::GCP) { TTSType::GCP => tts - .0 - .synthesize(SynthesizeRequest { + .synthesize_gcp(SynthesizeRequest { input: SynthesisInput { text: None, ssml: Some(format!("{}", text)), @@ -119,9 +119,8 @@ impl TTSMessage for Message { .unwrap(), TTSType::VOICEVOX => tts - .1 - .synthesize( - text.replace("", "、"), + .synthesize_voicevox( + &text.replace("", "、"), config.voicevox_speaker.unwrap_or(1), ) .await diff --git a/src/main.rs b/src/main.rs index 2fb8dee..9c95643 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,8 @@ use event_handler::Handler; use serenity::{ all::{standard::Configuration, ApplicationId}, client::Client, framework::StandardFramework, futures::lock::Mutex, prelude::{GatewayIntents, RwLock} }; -use tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX}; +use tracing::Level; +use tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX}; use songbird::SerenityInit; @@ -42,7 +43,7 @@ async fn create_client(prefix: &str, token: &str, id: u64) -> Result tts, Err(err) => panic!("GCP init error: {}", err), }; @@ -87,7 +88,7 @@ async fn main() { { let mut data = client.data.write().await; data.insert::(Arc::new(RwLock::new(HashMap::default()))); - data.insert::(Arc::new(Mutex::new((tts, voicevox)))); + data.insert::(Arc::new(Mutex::new(TTS::new(voicevox, tts)))); data.insert::(Arc::new(Mutex::new(database_client))); } diff --git a/src/tts/gcp_tts/gcp_tts.rs b/src/tts/gcp_tts/gcp_tts.rs index acf9d33..344e12a 100644 --- a/src/tts/gcp_tts/gcp_tts.rs +++ b/src/tts/gcp_tts/gcp_tts.rs @@ -4,12 +4,12 @@ use crate::tts::gcp_tts::structs::{ use gcp_auth::Token; #[derive(Clone)] -pub struct TTS { +pub struct GCPTTS { pub token: Token, pub credentials_path: String, } -impl TTS { +impl GCPTTS { pub async fn update_token(&mut self) -> Result<(), gcp_auth::Error> { if self.token.has_expired() { let authenticator = @@ -23,13 +23,13 @@ impl TTS { Ok(()) } - pub async fn new(credentials_path: String) -> Result { + pub async fn new(credentials_path: String) -> Result { let authenticator = gcp_auth::from_credentials_file(credentials_path.clone()).await?; let token = authenticator .get_token(&["https://www.googleapis.com/auth/cloud-platform"]) .await?; - Ok(TTS { + Ok(Self { token, credentials_path, }) diff --git a/src/tts/gcp_tts/structs/synthesis_input.rs b/src/tts/gcp_tts/structs/synthesis_input.rs index d7a1464..99d0a41 100644 --- a/src/tts/gcp_tts/structs/synthesis_input.rs +++ b/src/tts/gcp_tts/structs/synthesis_input.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; /// ssml: Some(String::from("test")) /// } /// ``` -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)] pub struct SynthesisInput { pub text: Option, pub ssml: Option, diff --git a/src/tts/gcp_tts/structs/voice_selection_params.rs b/src/tts/gcp_tts/structs/voice_selection_params.rs index 37c78bd..442e94c 100644 --- a/src/tts/gcp_tts/structs/voice_selection_params.rs +++ b/src/tts/gcp_tts/structs/voice_selection_params.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; /// ssmlGender: String::from("neutral") /// } /// ``` -#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] #[allow(non_snake_case)] pub struct VoiceSelectionParams { pub languageCode: String, diff --git a/src/tts/message.rs b/src/tts/message.rs index 9ab4648..e9718ce 100644 --- a/src/tts/message.rs +++ b/src/tts/message.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; use serenity::prelude::Context; +use songbird::input::cached::Compressed; use crate::{data::TTSClientData, tts::instance::TTSInstance}; @@ -25,7 +26,7 @@ pub trait TTSMessage { /// ```rust /// let audio = message.synthesize(instance, ctx).await; /// ``` - async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec; + async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Compressed; } pub struct AnnounceMessage { @@ -42,7 +43,7 @@ impl TTSMessage for AnnounceMessage { ) } - async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec { + async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Compressed { let text = self.parse(instance, ctx).await; let data_read = ctx.data.read().await; let storage = data_read @@ -52,8 +53,7 @@ impl TTSMessage for AnnounceMessage { let mut storage = storage.lock().await; let audio = storage - .0 - .synthesize(SynthesizeRequest { + .synthesize_gcp(SynthesizeRequest { input: SynthesisInput { text: None, ssml: Some(text), diff --git a/src/tts/mod.rs b/src/tts/mod.rs index 0ad87a4..e8bd604 100644 --- a/src/tts/mod.rs +++ b/src/tts/mod.rs @@ -3,3 +3,4 @@ pub mod instance; pub mod message; pub mod tts_type; pub mod voicevox; +pub mod tts; \ No newline at end of file diff --git a/src/tts/tts.rs b/src/tts/tts.rs new file mode 100644 index 0000000..d9a01ea --- /dev/null +++ b/src/tts/tts.rs @@ -0,0 +1,75 @@ +use std::num::NonZeroUsize; + +use lru::LruCache; +use songbird::{driver::Bitrate, input::cached::Compressed}; +use tracing::info; + +use super::{gcp_tts::{gcp_tts::GCPTTS, structs::{synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest, voice_selection_params::VoiceSelectionParams}}, voicevox::voicevox::VOICEVOX}; + +pub struct TTS { + pub voicevox_client: VOICEVOX, + gcp_tts_client: GCPTTS, + cache: LruCache, +} + +#[derive(Hash, PartialEq, Eq)] +pub enum CacheKey { + Voicevox(String, i64), + GCP(SynthesisInput, VoiceSelectionParams), +} + +impl TTS { + pub fn new( + voicevox_client: VOICEVOX, + gcp_tts_client: GCPTTS, + ) -> Self { + Self { + voicevox_client, + gcp_tts_client, + cache: LruCache::new(NonZeroUsize::new(100).unwrap()), + } + } + + pub async fn synthesize_voicevox(&mut self, text: &str, speaker: i64) -> Result> { + let cache_key = CacheKey::Voicevox(text.to_string(), speaker); + + if let Some(audio) = self.cache.get(&cache_key) { + info!("Cache hit for VOICEVOX TTS"); + return Ok(audio.clone()); + } + info!("Cache miss for VOICEVOX TTS"); + + let audio = self.voicevox_client + .synthesize(text.to_string(), speaker) + .await?; + + let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?; + + self.cache.put(cache_key, compressed.clone()); + + Ok(compressed) + } + + pub async fn synthesize_gcp(&mut self, synthesize_request: SynthesizeRequest) -> Result> { + let cache_key = CacheKey::GCP( + synthesize_request.input.clone(), + synthesize_request.voice.clone(), + ); + + if let Some(audio) = self.cache.get(&cache_key) { + info!("Cache hit for GCP TTS"); + return Ok(audio.clone()); + } + info!("Cache miss for GCP TTS"); + + let audio = self.gcp_tts_client + .synthesize(synthesize_request) + .await?; + + let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?; + + self.cache.put(cache_key, compressed.clone()); + + Ok(compressed) + } +}