Files
ncb-tts-r2/src/tts/tts.rs
2025-04-04 22:14:23 +09:00

94 lines
2.8 KiB
Rust

use std::num::NonZeroUsize;
use std::sync::RwLock;
use lru::LruCache;
use songbird::{driver::Bitrate, input::cached::Compressed};
use tracing::info;
use super::{gcp_tts::{gcp_tts::GCPTTS, structs::{synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest, voice_selection_params::VoiceSelectionParams}}, voicevox::voicevox::VOICEVOX};
/// Text-to-speech front-end that synthesizes speech through VOICEVOX or Google Cloud TTS
/// and keeps recently synthesized clips in an in-memory LRU cache.
pub struct TTS {
    /// VOICEVOX client used by `synthesize_voicevox`.
    pub voicevox_client: VOICEVOX,
    /// Google Cloud TTS client used by `synthesize_gcp`.
    gcp_tts_client: GCPTTS,
    /// Compressed audio keyed by the parameters that produced it.
    ///
    /// Accessed through `write()` even for lookups, because `LruCache::get` takes
    /// `&mut self` (it updates recency order).
    cache: RwLock<LruCache<CacheKey, Compressed>>,
}
/// Key identifying one synthesized clip in the `TTS` audio cache.
///
/// `Hash` and `Eq` are derived because `LruCache` requires them of its keys.
#[derive(Hash, PartialEq, Eq)]
pub enum CacheKey {
    /// VOICEVOX request: the input text and the speaker id passed to the VOICEVOX client.
    Voicevox(String, i64),
    /// Google Cloud TTS request: the request's `input` and `voice` fields.
    ///
    /// NOTE(review): only these two fields form the key; if other request fields affect the
    /// produced audio, requests differing only in them share an entry — confirm intended.
    GCP(SynthesisInput, VoiceSelectionParams),
}
impl TTS {
pub fn new(
voicevox_client: VOICEVOX,
gcp_tts_client: GCPTTS,
) -> Self {
Self {
voicevox_client,
gcp_tts_client,
cache: RwLock::new(LruCache::new(NonZeroUsize::new(100).unwrap())),
}
}
pub async fn synthesize_voicevox(&self, text: &str, speaker: i64) -> Result<Compressed, Box<dyn std::error::Error>> {
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
let cached_audio = {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
};
if let Some(audio) = cached_audio {
info!("Cache hit for VOICEVOX TTS");
return Ok(audio);
}
info!("Cache miss for VOICEVOX TTS");
let audio = self.voicevox_client
.synthesize(text.to_string(), speaker)
.await?;
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
{
let mut cache_guard = self.cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
Ok(compressed)
}
pub async fn synthesize_gcp(&self, synthesize_request: SynthesizeRequest) -> Result<Compressed, Box<dyn std::error::Error>> {
let cache_key = CacheKey::GCP(
synthesize_request.input.clone(),
synthesize_request.voice.clone(),
);
let cached_audio = {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
};
if let Some(audio) = cached_audio {
info!("Cache hit for GCP TTS");
return Ok(audio);
}
info!("Cache miss for GCP TTS");
let audio = self.gcp_tts_client
.synthesize(synthesize_request)
.await?;
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
{
let mut cache_guard = self.cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
Ok(compressed)
}
}