implement compressed local audio cache

This commit is contained in:
mii443
2025-04-04 20:45:20 +09:00
parent 77b4c3e04d
commit c68e533133
13 changed files with 103 additions and 25 deletions

View File

@ -16,6 +16,8 @@ async-trait = "0.1.57"
redis = "*"
regex = "1"
tracing-subscriber = "0.3.19"
lru = "0.13.0"
tracing = "0.1.41"
[dependencies.uuid]
version = "0.8"

View File

@ -33,7 +33,7 @@ pub async fn config_command(
.get::<TTSClientData>()
.expect("Cannot get TTSClientData")
.clone();
let voicevox_speakers = tts_client.lock().await.1.get_styles().await;
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_styles().await;
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);

View File

@ -142,7 +142,7 @@ pub async fn setup_command(
.get::<TTSClientData>()
.expect("Cannot get TTSClientData")
.clone();
let voicevox_speakers = tts_client.lock().await.1.get_speakers().await;
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await;
text_channel_id
.send_message(&ctx.http, CreateMessage::new()

View File

@ -1,6 +1,6 @@
use crate::{
database::database::Database,
tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX},
tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX},
};
use serenity::{
futures::lock::Mutex,
@ -22,7 +22,7 @@ impl TypeMapKey for TTSData {
pub struct TTSClientData;
impl TypeMapKey for TTSClientData {
type Value = Arc<Mutex<(TTS, VOICEVOX)>>;
type Value = Arc<Mutex<TTS>>;
}
/// Database client data

View File

@ -72,7 +72,7 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
.get::<TTSClientData>()
.expect("Cannot get TTSClientData")
.clone();
let voicevox_speakers = tts_client.lock().await.1.get_speakers().await;
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await;
new_channel
.send_message(&ctx.http,

View File

@ -1,6 +1,7 @@
use async_trait::async_trait;
use regex::Regex;
use serenity::{model::prelude::Message, prelude::Context};
use songbird::input::cached::Compressed;
use crate::{
data::{DatabaseClientData, TTSClientData}, implement::member_name::ReadName, tts::{
@ -77,7 +78,7 @@ impl TTSMessage for Message {
res
}
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<u8> {
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Compressed {
let text = self.parse(instance, ctx).await;
let data_read = ctx.data.read().await;
@ -102,8 +103,7 @@ impl TTSMessage for Message {
let audio = match config.tts_type.unwrap_or(TTSType::GCP) {
TTSType::GCP => tts
.0
.synthesize(SynthesizeRequest {
.synthesize_gcp(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(format!("<speak>{}</speak>", text)),
@ -119,9 +119,8 @@ impl TTSMessage for Message {
.unwrap(),
TTSType::VOICEVOX => tts
.1
.synthesize(
text.replace("<break time=\"200ms\"/>", ""),
.synthesize_voicevox(
&text.replace("<break time=\"200ms\"/>", ""),
config.voicevox_speaker.unwrap_or(1),
)
.await

View File

@ -16,7 +16,8 @@ use event_handler::Handler;
use serenity::{
all::{standard::Configuration, ApplicationId}, client::Client, framework::StandardFramework, futures::lock::Mutex, prelude::{GatewayIntents, RwLock}
};
use tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX};
use tracing::Level;
use tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX};
use songbird::SerenityInit;
@ -42,7 +43,7 @@ async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client, ser
#[tokio::main]
async fn main() {
tracing_subscriber::fmt().init();
tracing_subscriber::fmt().with_max_level(Level::DEBUG).init();
// Load config
let config = {
let config = std::fs::read_to_string("./config.toml");
@ -71,7 +72,7 @@ async fn main() {
.expect("Err creating client");
// Create GCP TTS client
let tts = match TTS::new("./credentials.json".to_string()).await {
let tts = match GCPTTS::new("./credentials.json".to_string()).await {
Ok(tts) => tts,
Err(err) => panic!("GCP init error: {}", err),
};
@ -87,7 +88,7 @@ async fn main() {
{
let mut data = client.data.write().await;
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
data.insert::<TTSClientData>(Arc::new(Mutex::new((tts, voicevox))));
data.insert::<TTSClientData>(Arc::new(Mutex::new(TTS::new(voicevox, tts))));
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
}

View File

@ -4,12 +4,12 @@ use crate::tts::gcp_tts::structs::{
use gcp_auth::Token;
#[derive(Clone)]
pub struct TTS {
pub struct GCPTTS {
pub token: Token,
pub credentials_path: String,
}
impl TTS {
impl GCPTTS {
pub async fn update_token(&mut self) -> Result<(), gcp_auth::Error> {
if self.token.has_expired() {
let authenticator =
@ -23,13 +23,13 @@ impl TTS {
Ok(())
}
pub async fn new(credentials_path: String) -> Result<TTS, gcp_auth::Error> {
pub async fn new(credentials_path: String) -> Result<Self, gcp_auth::Error> {
let authenticator = gcp_auth::from_credentials_file(credentials_path.clone()).await?;
let token = authenticator
.get_token(&["https://www.googleapis.com/auth/cloud-platform"])
.await?;
Ok(TTS {
Ok(Self {
token,
credentials_path,
})

View File

@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
/// ssml: Some(String::from("<speak>test</speak>"))
/// }
/// ```
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)]
pub struct SynthesisInput {
pub text: Option<String>,
pub ssml: Option<String>,

View File

@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
/// ssmlGender: String::from("neutral")
/// }
/// ```
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
#[allow(non_snake_case)]
pub struct VoiceSelectionParams {
pub languageCode: String,

View File

@ -1,5 +1,6 @@
use async_trait::async_trait;
use serenity::prelude::Context;
use songbird::input::cached::Compressed;
use crate::{data::TTSClientData, tts::instance::TTSInstance};
@ -25,7 +26,7 @@ pub trait TTSMessage {
/// ```rust
/// let audio = message.synthesize(instance, ctx).await;
/// ```
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<u8>;
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Compressed;
}
pub struct AnnounceMessage {
@ -42,7 +43,7 @@ impl TTSMessage for AnnounceMessage {
)
}
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<u8> {
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Compressed {
let text = self.parse(instance, ctx).await;
let data_read = ctx.data.read().await;
let storage = data_read
@ -52,8 +53,7 @@ impl TTSMessage for AnnounceMessage {
let mut storage = storage.lock().await;
let audio = storage
.0
.synthesize(SynthesizeRequest {
.synthesize_gcp(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(text),

View File

@ -3,3 +3,4 @@ pub mod instance;
pub mod message;
pub mod tts_type;
pub mod voicevox;
pub mod tts;

75
src/tts/tts.rs Normal file
View File

@ -0,0 +1,75 @@
use std::num::NonZeroUsize;
use lru::LruCache;
use songbird::{driver::Bitrate, input::cached::Compressed};
use tracing::info;
use super::{gcp_tts::{gcp_tts::GCPTTS, structs::{synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest, voice_selection_params::VoiceSelectionParams}}, voicevox::voicevox::VOICEVOX};
/// Front-end over both speech back-ends that keeps recently synthesized
/// clips in an in-memory LRU cache.
///
/// Cached clips are stored as songbird [`Compressed`] sources and are looked
/// up by [`CacheKey`].
pub struct TTS {
    /// VOICEVOX client; public so callers can query it directly.
    pub voicevox_client: VOICEVOX,
    /// Google Cloud Text-to-Speech client, only reachable through this type.
    gcp_tts_client: GCPTTS,
    /// Recently synthesized clips, keyed by the request that produced them.
    cache: LruCache<CacheKey, Compressed>,
}
/// Identifies one synthesized clip inside the [`TTS`] cache.
///
/// Two requests map to the same cache entry exactly when their keys compare
/// equal.
///
/// NOTE(review): the GCP key is built from the request's `input` and `voice`
/// only. If the request type carries further fields that influence the
/// resulting audio, requests differing only in those fields would share an
/// entry — confirm against `SynthesizeRequest`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum CacheKey {
    /// VOICEVOX request: the text to speak and the speaker id.
    Voicevox(String, i64),
    /// GCP request: the synthesis input and the selected voice.
    GCP(SynthesisInput, VoiceSelectionParams),
}
impl TTS {
pub fn new(
voicevox_client: VOICEVOX,
gcp_tts_client: GCPTTS,
) -> Self {
Self {
voicevox_client,
gcp_tts_client,
cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
}
}
pub async fn synthesize_voicevox(&mut self, text: &str, speaker: i64) -> Result<Compressed, Box<dyn std::error::Error>> {
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
if let Some(audio) = self.cache.get(&cache_key) {
info!("Cache hit for VOICEVOX TTS");
return Ok(audio.clone());
}
info!("Cache miss for VOICEVOX TTS");
let audio = self.voicevox_client
.synthesize(text.to_string(), speaker)
.await?;
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
self.cache.put(cache_key, compressed.clone());
Ok(compressed)
}
pub async fn synthesize_gcp(&mut self, synthesize_request: SynthesizeRequest) -> Result<Compressed, Box<dyn std::error::Error>> {
let cache_key = CacheKey::GCP(
synthesize_request.input.clone(),
synthesize_request.voice.clone(),
);
if let Some(audio) = self.cache.get(&cache_key) {
info!("Cache hit for GCP TTS");
return Ok(audio.clone());
}
info!("Cache miss for GCP TTS");
let audio = self.gcp_tts_client
.synthesize(synthesize_request)
.await?;
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
self.cache.put(cache_key, compressed.clone());
Ok(compressed)
}
}