mirror of
https://github.com/mii443/ncb-tts-r2.git
synced 2025-08-22 16:15:29 +00:00
implement compressed local audio cache
This commit is contained in:
@ -16,6 +16,8 @@ async-trait = "0.1.57"
|
||||
redis = "*"
|
||||
regex = "1"
|
||||
tracing-subscriber = "0.3.19"
|
||||
lru = "0.13.0"
|
||||
tracing = "0.1.41"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "0.8"
|
||||
|
@ -33,7 +33,7 @@ pub async fn config_command(
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientData")
|
||||
.clone();
|
||||
let voicevox_speakers = tts_client.lock().await.1.get_styles().await;
|
||||
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_styles().await;
|
||||
|
||||
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
|
||||
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
|
||||
|
@ -142,7 +142,7 @@ pub async fn setup_command(
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientData")
|
||||
.clone();
|
||||
let voicevox_speakers = tts_client.lock().await.1.get_speakers().await;
|
||||
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await;
|
||||
|
||||
text_channel_id
|
||||
.send_message(&ctx.http, CreateMessage::new()
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
database::database::Database,
|
||||
tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX},
|
||||
tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX},
|
||||
};
|
||||
use serenity::{
|
||||
futures::lock::Mutex,
|
||||
@ -22,7 +22,7 @@ impl TypeMapKey for TTSData {
|
||||
pub struct TTSClientData;
|
||||
|
||||
impl TypeMapKey for TTSClientData {
|
||||
type Value = Arc<Mutex<(TTS, VOICEVOX)>>;
|
||||
type Value = Arc<Mutex<TTS>>;
|
||||
}
|
||||
|
||||
/// Database client data
|
||||
|
@ -72,7 +72,7 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
|
||||
.get::<TTSClientData>()
|
||||
.expect("Cannot get TTSClientData")
|
||||
.clone();
|
||||
let voicevox_speakers = tts_client.lock().await.1.get_speakers().await;
|
||||
let voicevox_speakers = tts_client.lock().await.voicevox_client.get_speakers().await;
|
||||
|
||||
new_channel
|
||||
.send_message(&ctx.http,
|
||||
|
@ -1,6 +1,7 @@
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serenity::{model::prelude::Message, prelude::Context};
|
||||
use songbird::input::cached::Compressed;
|
||||
|
||||
use crate::{
|
||||
data::{DatabaseClientData, TTSClientData}, implement::member_name::ReadName, tts::{
|
||||
@ -77,7 +78,7 @@ impl TTSMessage for Message {
|
||||
res
|
||||
}
|
||||
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<u8> {
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Compressed {
|
||||
let text = self.parse(instance, ctx).await;
|
||||
|
||||
let data_read = ctx.data.read().await;
|
||||
@ -102,8 +103,7 @@ impl TTSMessage for Message {
|
||||
|
||||
let audio = match config.tts_type.unwrap_or(TTSType::GCP) {
|
||||
TTSType::GCP => tts
|
||||
.0
|
||||
.synthesize(SynthesizeRequest {
|
||||
.synthesize_gcp(SynthesizeRequest {
|
||||
input: SynthesisInput {
|
||||
text: None,
|
||||
ssml: Some(format!("<speak>{}</speak>", text)),
|
||||
@ -119,9 +119,8 @@ impl TTSMessage for Message {
|
||||
.unwrap(),
|
||||
|
||||
TTSType::VOICEVOX => tts
|
||||
.1
|
||||
.synthesize(
|
||||
text.replace("<break time=\"200ms\"/>", "、"),
|
||||
.synthesize_voicevox(
|
||||
&text.replace("<break time=\"200ms\"/>", "、"),
|
||||
config.voicevox_speaker.unwrap_or(1),
|
||||
)
|
||||
.await
|
||||
|
@ -16,7 +16,8 @@ use event_handler::Handler;
|
||||
use serenity::{
|
||||
all::{standard::Configuration, ApplicationId}, client::Client, framework::StandardFramework, futures::lock::Mutex, prelude::{GatewayIntents, RwLock}
|
||||
};
|
||||
use tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX};
|
||||
use tracing::Level;
|
||||
use tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX};
|
||||
|
||||
use songbird::SerenityInit;
|
||||
|
||||
@ -42,7 +43,7 @@ async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client, ser
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
tracing_subscriber::fmt().init();
|
||||
tracing_subscriber::fmt().with_max_level(Level::DEBUG).init();
|
||||
// Load config
|
||||
let config = {
|
||||
let config = std::fs::read_to_string("./config.toml");
|
||||
@ -71,7 +72,7 @@ async fn main() {
|
||||
.expect("Err creating client");
|
||||
|
||||
// Create GCP TTS client
|
||||
let tts = match TTS::new("./credentials.json".to_string()).await {
|
||||
let tts = match GCPTTS::new("./credentials.json".to_string()).await {
|
||||
Ok(tts) => tts,
|
||||
Err(err) => panic!("GCP init error: {}", err),
|
||||
};
|
||||
@ -87,7 +88,7 @@ async fn main() {
|
||||
{
|
||||
let mut data = client.data.write().await;
|
||||
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
|
||||
data.insert::<TTSClientData>(Arc::new(Mutex::new((tts, voicevox))));
|
||||
data.insert::<TTSClientData>(Arc::new(Mutex::new(TTS::new(voicevox, tts))));
|
||||
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
|
||||
}
|
||||
|
||||
|
@ -4,12 +4,12 @@ use crate::tts::gcp_tts::structs::{
|
||||
use gcp_auth::Token;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TTS {
|
||||
pub struct GCPTTS {
|
||||
pub token: Token,
|
||||
pub credentials_path: String,
|
||||
}
|
||||
|
||||
impl TTS {
|
||||
impl GCPTTS {
|
||||
pub async fn update_token(&mut self) -> Result<(), gcp_auth::Error> {
|
||||
if self.token.has_expired() {
|
||||
let authenticator =
|
||||
@ -23,13 +23,13 @@ impl TTS {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn new(credentials_path: String) -> Result<TTS, gcp_auth::Error> {
|
||||
pub async fn new(credentials_path: String) -> Result<Self, gcp_auth::Error> {
|
||||
let authenticator = gcp_auth::from_credentials_file(credentials_path.clone()).await?;
|
||||
let token = authenticator
|
||||
.get_token(&["https://www.googleapis.com/auth/cloud-platform"])
|
||||
.await?;
|
||||
|
||||
Ok(TTS {
|
||||
Ok(Self {
|
||||
token,
|
||||
credentials_path,
|
||||
})
|
||||
|
@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
||||
/// ssml: Some(String::from("<speak>test</speak>"))
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)]
|
||||
pub struct SynthesisInput {
|
||||
pub text: Option<String>,
|
||||
pub ssml: Option<String>,
|
||||
|
@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
|
||||
/// ssmlGender: String::from("neutral")
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
|
||||
#[allow(non_snake_case)]
|
||||
pub struct VoiceSelectionParams {
|
||||
pub languageCode: String,
|
||||
|
@ -1,5 +1,6 @@
|
||||
use async_trait::async_trait;
|
||||
use serenity::prelude::Context;
|
||||
use songbird::input::cached::Compressed;
|
||||
|
||||
use crate::{data::TTSClientData, tts::instance::TTSInstance};
|
||||
|
||||
@ -25,7 +26,7 @@ pub trait TTSMessage {
|
||||
/// ```rust
|
||||
/// let audio = message.synthesize(instance, ctx).await;
|
||||
/// ```
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<u8>;
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Compressed;
|
||||
}
|
||||
|
||||
pub struct AnnounceMessage {
|
||||
@ -42,7 +43,7 @@ impl TTSMessage for AnnounceMessage {
|
||||
)
|
||||
}
|
||||
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<u8> {
|
||||
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Compressed {
|
||||
let text = self.parse(instance, ctx).await;
|
||||
let data_read = ctx.data.read().await;
|
||||
let storage = data_read
|
||||
@ -52,8 +53,7 @@ impl TTSMessage for AnnounceMessage {
|
||||
let mut storage = storage.lock().await;
|
||||
|
||||
let audio = storage
|
||||
.0
|
||||
.synthesize(SynthesizeRequest {
|
||||
.synthesize_gcp(SynthesizeRequest {
|
||||
input: SynthesisInput {
|
||||
text: None,
|
||||
ssml: Some(text),
|
||||
|
@ -3,3 +3,4 @@ pub mod instance;
|
||||
pub mod message;
|
||||
pub mod tts_type;
|
||||
pub mod voicevox;
|
||||
pub mod tts;
|
75
src/tts/tts.rs
Normal file
75
src/tts/tts.rs
Normal file
@ -0,0 +1,75 @@
|
||||
use std::num::NonZeroUsize;
|
||||
|
||||
use lru::LruCache;
|
||||
use songbird::{driver::Bitrate, input::cached::Compressed};
|
||||
use tracing::info;
|
||||
|
||||
use super::{gcp_tts::{gcp_tts::GCPTTS, structs::{synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest, voice_selection_params::VoiceSelectionParams}}, voicevox::voicevox::VOICEVOX};
|
||||
|
||||
pub struct TTS {
|
||||
pub voicevox_client: VOICEVOX,
|
||||
gcp_tts_client: GCPTTS,
|
||||
cache: LruCache<CacheKey, Compressed>,
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq)]
|
||||
pub enum CacheKey {
|
||||
Voicevox(String, i64),
|
||||
GCP(SynthesisInput, VoiceSelectionParams),
|
||||
}
|
||||
|
||||
impl TTS {
|
||||
pub fn new(
|
||||
voicevox_client: VOICEVOX,
|
||||
gcp_tts_client: GCPTTS,
|
||||
) -> Self {
|
||||
Self {
|
||||
voicevox_client,
|
||||
gcp_tts_client,
|
||||
cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn synthesize_voicevox(&mut self, text: &str, speaker: i64) -> Result<Compressed, Box<dyn std::error::Error>> {
|
||||
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
|
||||
|
||||
if let Some(audio) = self.cache.get(&cache_key) {
|
||||
info!("Cache hit for VOICEVOX TTS");
|
||||
return Ok(audio.clone());
|
||||
}
|
||||
info!("Cache miss for VOICEVOX TTS");
|
||||
|
||||
let audio = self.voicevox_client
|
||||
.synthesize(text.to_string(), speaker)
|
||||
.await?;
|
||||
|
||||
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
|
||||
|
||||
self.cache.put(cache_key, compressed.clone());
|
||||
|
||||
Ok(compressed)
|
||||
}
|
||||
|
||||
pub async fn synthesize_gcp(&mut self, synthesize_request: SynthesizeRequest) -> Result<Compressed, Box<dyn std::error::Error>> {
|
||||
let cache_key = CacheKey::GCP(
|
||||
synthesize_request.input.clone(),
|
||||
synthesize_request.voice.clone(),
|
||||
);
|
||||
|
||||
if let Some(audio) = self.cache.get(&cache_key) {
|
||||
info!("Cache hit for GCP TTS");
|
||||
return Ok(audio.clone());
|
||||
}
|
||||
info!("Cache miss for GCP TTS");
|
||||
|
||||
let audio = self.gcp_tts_client
|
||||
.synthesize(synthesize_request)
|
||||
.await?;
|
||||
|
||||
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
|
||||
|
||||
self.cache.put(cache_key, compressed.clone());
|
||||
|
||||
Ok(compressed)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user