mirror of
https://github.com/mii443/ncb-tts-r2.git
synced 2025-08-22 16:15:29 +00:00
refactor: Major overhaul with error handling, resilience patterns, and observability
- Add library configuration to support both lib and binary targets - Implement unified error handling with NCBError throughout the codebase - Add circuit breaker pattern for external API calls (Voicevox, GCP TTS) - Introduce comprehensive performance metrics and monitoring - Add cache persistence with disk storage support - Implement retry mechanism with exponential backoff - Add configuration file support (config.toml) with env var fallback - Enhance logging with structured tracing (debug, warn, error levels) - Add extensive unit tests for cache, metrics, and circuit breaker - Update base64 decoding to use modern API - Improve API error handling for Voicevox and GCP TTS clients Breaking changes: - Function signatures now return Result<T, NCBError> instead of panicking - Cache key structure modified with serialization support
This commit is contained in:
21
Cargo.toml
21
Cargo.toml
@ -3,6 +3,14 @@ name = "ncb-tts-r2"
|
|||||||
version = "1.11.2"
|
version = "1.11.2"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "ncb_tts_r2"
|
||||||
|
path = "src/lib.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "ncb-tts-r2"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
@ -13,10 +21,15 @@ gcp_auth = "0.5.0"
|
|||||||
reqwest = { version = "0.12.9", features = ["json"] }
|
reqwest = { version = "0.12.9", features = ["json"] }
|
||||||
base64 = "0.22.1"
|
base64 = "0.22.1"
|
||||||
async-trait = "0.1.57"
|
async-trait = "0.1.57"
|
||||||
redis = "0.29.2"
|
redis = { version = "0.29.2", features = ["aio", "tokio-comp"] }
|
||||||
|
bb8 = "0.8"
|
||||||
|
bb8-redis = "0.16"
|
||||||
|
thiserror = "1.0"
|
||||||
regex = "1"
|
regex = "1"
|
||||||
tracing-subscriber = "0.3.19"
|
tracing-subscriber = "0.3.19"
|
||||||
lru = "0.13.0"
|
lru = "0.13.0"
|
||||||
|
once_cell = "1.19"
|
||||||
|
bincode = "1.3"
|
||||||
tracing = "0.1.41"
|
tracing = "0.1.41"
|
||||||
opentelemetry_sdk = { version = "0.29.0", features = ["trace"] }
|
opentelemetry_sdk = { version = "0.29.0", features = ["trace"] }
|
||||||
opentelemetry = "0.29.1"
|
opentelemetry = "0.29.1"
|
||||||
@ -61,3 +74,9 @@ features = [
|
|||||||
[dependencies.tokio]
|
[dependencies.tokio]
|
||||||
version = "1.0"
|
version = "1.0"
|
||||||
features = ["macros", "rt-multi-thread"]
|
features = ["macros", "rt-multi-thread"]
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tokio-test = "0.4"
|
||||||
|
mockall = "0.12"
|
||||||
|
tempfile = "3.8"
|
||||||
|
serial_test = "3.0"
|
||||||
|
@ -34,7 +34,11 @@ pub async fn config_command(
|
|||||||
let tts_client = data_read
|
let tts_client = data_read
|
||||||
.get::<TTSClientData>()
|
.get::<TTSClientData>()
|
||||||
.expect("Cannot get TTSClientData");
|
.expect("Cannot get TTSClientData");
|
||||||
let voicevox_speakers = tts_client.voicevox_client.get_styles().await;
|
let voicevox_speakers = tts_client.voicevox_client.get_styles().await
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
tracing::error!("Failed to get VOICEVOX styles: {}", e);
|
||||||
|
vec![("VOICEVOX API unavailable".to_string(), 1)]
|
||||||
|
});
|
||||||
|
|
||||||
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
|
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
|
||||||
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
|
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
|
||||||
|
@ -149,7 +149,11 @@ pub async fn setup_command(
|
|||||||
let tts_client = data
|
let tts_client = data
|
||||||
.get::<TTSClientData>()
|
.get::<TTSClientData>()
|
||||||
.expect("Cannot get TTSClientData");
|
.expect("Cannot get TTSClientData");
|
||||||
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
|
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
tracing::error!("Failed to get VOICEVOX speakers: {}", e);
|
||||||
|
vec!["VOICEVOX API unavailable".to_string()]
|
||||||
|
});
|
||||||
|
|
||||||
text_channel_id
|
text_channel_id
|
||||||
.send_message(&ctx.http, CreateMessage::new()
|
.send_message(&ctx.http, CreateMessage::new()
|
||||||
|
@ -1,34 +1,75 @@
|
|||||||
use serenity::{model::channel::Message, prelude::Context, all::{CreateMessage, CreateEmbed}};
|
use serenity::{prelude::Context, all::{CreateMessage, CreateEmbed}};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::time;
|
use tokio::time;
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn, instrument};
|
||||||
|
|
||||||
use crate::data::{DatabaseClientData, TTSData};
|
use crate::data::{DatabaseClientData, TTSData};
|
||||||
|
|
||||||
|
/// Constants for connection monitoring
|
||||||
|
const CONNECTION_CHECK_INTERVAL_SECS: u64 = 5;
|
||||||
|
const MAX_RECONNECTION_ATTEMPTS: u32 = 3;
|
||||||
|
const RECONNECTION_BACKOFF_SECS: u64 = 2;
|
||||||
|
|
||||||
|
/// Errors that can occur during connection monitoring
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ConnectionMonitorError {
|
||||||
|
#[error("Failed to get songbird manager")]
|
||||||
|
SongbirdManagerNotFound,
|
||||||
|
#[error("Failed to check voice channel users: {0}")]
|
||||||
|
VoiceChannelCheck(String),
|
||||||
|
#[error("Failed to reconnect after {attempts} attempts")]
|
||||||
|
ReconnectionFailed { attempts: u32 },
|
||||||
|
#[error("Database operation failed: {0}")]
|
||||||
|
Database(#[from] redis::RedisError),
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result<T> = std::result::Result<T, ConnectionMonitorError>;
|
||||||
|
|
||||||
/// Connection monitor that periodically checks voice channel connections
|
/// Connection monitor that periodically checks voice channel connections
|
||||||
pub struct ConnectionMonitor;
|
pub struct ConnectionMonitor {
|
||||||
|
reconnection_attempts: std::collections::HashMap<serenity::model::id::GuildId, u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ConnectionMonitor {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ConnectionMonitor {
|
impl ConnectionMonitor {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
reconnection_attempts: std::collections::HashMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Start the connection monitoring task
|
/// Start the connection monitoring task
|
||||||
pub fn start(ctx: Context) {
|
pub fn start(ctx: Context) {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
info!("Starting connection monitor with 5s interval");
|
let mut monitor = ConnectionMonitor::new();
|
||||||
let mut interval = time::interval(Duration::from_secs(5));
|
info!(
|
||||||
|
interval_secs = CONNECTION_CHECK_INTERVAL_SECS,
|
||||||
|
"Starting connection monitor"
|
||||||
|
);
|
||||||
|
let mut interval = time::interval(Duration::from_secs(CONNECTION_CHECK_INTERVAL_SECS));
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
interval.tick().await;
|
interval.tick().await;
|
||||||
Self::check_connections(&ctx).await;
|
if let Err(e) = monitor.check_connections(&ctx).await {
|
||||||
|
error!(error = %e, "Connection monitoring failed");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check all active TTS instances and their voice channel connections
|
/// Check all active TTS instances and their voice channel connections
|
||||||
async fn check_connections(ctx: &Context) {
|
#[instrument(skip(self, ctx))]
|
||||||
|
async fn check_connections(&mut self, ctx: &Context) -> Result<()> {
|
||||||
let storage_lock = {
|
let storage_lock = {
|
||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
data_read
|
data_read
|
||||||
.get::<TTSData>()
|
.get::<TTSData>()
|
||||||
.expect("Cannot get TTSStorage")
|
.ok_or_else(|| ConnectionMonitorError::VoiceChannelCheck("Cannot get TTSStorage".to_string()))?
|
||||||
.clone()
|
.clone()
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -36,7 +77,7 @@ impl ConnectionMonitor {
|
|||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
data_read
|
data_read
|
||||||
.get::<DatabaseClientData>()
|
.get::<DatabaseClientData>()
|
||||||
.expect("Cannot get DatabaseClientData")
|
.ok_or_else(|| ConnectionMonitorError::VoiceChannelCheck("Cannot get DatabaseClientData".to_string()))?
|
||||||
.clone()
|
.clone()
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -45,13 +86,8 @@ impl ConnectionMonitor {
|
|||||||
|
|
||||||
for (guild_id, instance) in storage.iter() {
|
for (guild_id, instance) in storage.iter() {
|
||||||
// Check if bot is still connected to voice channel
|
// Check if bot is still connected to voice channel
|
||||||
let manager = match songbird::get(ctx).await {
|
let manager = songbird::get(ctx).await
|
||||||
Some(manager) => manager,
|
.ok_or(ConnectionMonitorError::SongbirdManagerNotFound)?;
|
||||||
None => {
|
|
||||||
error!("Cannot get songbird manager");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let call = manager.get(*guild_id);
|
let call = manager.get(*guild_id);
|
||||||
let is_connected = if let Some(call) = call {
|
let is_connected = if let Some(call) = call {
|
||||||
@ -65,49 +101,87 @@ impl ConnectionMonitor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if !is_connected {
|
if !is_connected {
|
||||||
warn!("Bot disconnected from voice channel in guild {}", guild_id);
|
warn!(guild_id = %guild_id, "Bot disconnected from voice channel");
|
||||||
|
|
||||||
// Check if there are users in the voice channel
|
// Check if there are users in the voice channel
|
||||||
let should_reconnect = match Self::check_voice_channel_users(ctx, instance).await {
|
let should_reconnect = match self.check_voice_channel_users(ctx, instance).await {
|
||||||
Ok(has_users) => has_users,
|
Ok(has_users) => has_users,
|
||||||
Err(_) => {
|
Err(e) => {
|
||||||
// If we can't check users, don't reconnect
|
warn!(guild_id = %guild_id, error = %e, "Failed to check voice channel users, skipping reconnection");
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if should_reconnect {
|
if should_reconnect {
|
||||||
// Try to reconnect
|
// Try to reconnect with retry logic
|
||||||
|
let attempts = self.reconnection_attempts.get(guild_id).copied().unwrap_or(0);
|
||||||
|
|
||||||
|
if attempts >= MAX_RECONNECTION_ATTEMPTS {
|
||||||
|
error!(
|
||||||
|
guild_id = %guild_id,
|
||||||
|
attempts = attempts,
|
||||||
|
"Maximum reconnection attempts reached, removing instance"
|
||||||
|
);
|
||||||
|
guilds_to_remove.push(*guild_id);
|
||||||
|
self.reconnection_attempts.remove(guild_id);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply exponential backoff
|
||||||
|
if attempts > 0 {
|
||||||
|
let backoff_duration = Duration::from_secs(RECONNECTION_BACKOFF_SECS * (2_u64.pow(attempts)));
|
||||||
|
warn!(
|
||||||
|
guild_id = %guild_id,
|
||||||
|
attempt = attempts + 1,
|
||||||
|
backoff_secs = backoff_duration.as_secs(),
|
||||||
|
"Applying backoff before reconnection attempt"
|
||||||
|
);
|
||||||
|
tokio::time::sleep(backoff_duration).await;
|
||||||
|
}
|
||||||
|
|
||||||
match instance.reconnect(ctx, true).await {
|
match instance.reconnect(ctx, true).await {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
info!(
|
info!(
|
||||||
"Successfully reconnected to voice channel in guild {}",
|
guild_id = %guild_id,
|
||||||
guild_id
|
attempts = attempts + 1,
|
||||||
|
"Successfully reconnected to voice channel"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Reset reconnection attempts on success
|
||||||
|
self.reconnection_attempts.remove(guild_id);
|
||||||
|
|
||||||
// Send notification message to text channel with embed
|
// Send notification message to text channel with embed
|
||||||
let embed = CreateEmbed::new()
|
let embed = CreateEmbed::new()
|
||||||
.title("🔄 自動再接続しました")
|
.title("🔄 自動再接続しました")
|
||||||
.description("読み上げを停止したい場合は `/stop` コマンドを使用してください。")
|
.description("読み上げを停止したい場合は `/stop` コマンドを使用してください。")
|
||||||
.color(0x00ff00);
|
.color(0x00ff00);
|
||||||
if let Err(e) = instance.text_channel.send_message(&ctx.http, CreateMessage::new().embed(embed)).await {
|
if let Err(e) = instance.text_channel.send_message(&ctx.http, CreateMessage::new().embed(embed)).await {
|
||||||
error!("Failed to send reconnection message to text channel: {}", e);
|
error!(guild_id = %guild_id, error = %e, "Failed to send reconnection message");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
let new_attempts = attempts + 1;
|
||||||
|
self.reconnection_attempts.insert(*guild_id, new_attempts);
|
||||||
error!(
|
error!(
|
||||||
"Failed to reconnect to voice channel in guild {}: {}",
|
guild_id = %guild_id,
|
||||||
guild_id, e
|
attempt = new_attempts,
|
||||||
|
error = %e,
|
||||||
|
"Failed to reconnect to voice channel"
|
||||||
);
|
);
|
||||||
guilds_to_remove.push(*guild_id);
|
|
||||||
|
if new_attempts >= MAX_RECONNECTION_ATTEMPTS {
|
||||||
|
guilds_to_remove.push(*guild_id);
|
||||||
|
self.reconnection_attempts.remove(guild_id);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
info!(
|
info!(
|
||||||
"No users in voice channel, removing instance for guild {}",
|
guild_id = %guild_id,
|
||||||
guild_id
|
"No users in voice channel, removing instance"
|
||||||
);
|
);
|
||||||
guilds_to_remove.push(*guild_id);
|
guilds_to_remove.push(*guild_id);
|
||||||
|
self.reconnection_attempts.remove(guild_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -118,29 +192,51 @@ impl ConnectionMonitor {
|
|||||||
|
|
||||||
// Remove from database
|
// Remove from database
|
||||||
if let Err(e) = database.remove_tts_instance(guild_id).await {
|
if let Err(e) = database.remove_tts_instance(guild_id).await {
|
||||||
error!("Failed to remove TTS instance from database: {}", e);
|
error!(guild_id = %guild_id, error = %e, "Failed to remove TTS instance from database");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure bot leaves voice channel
|
// Ensure bot leaves voice channel
|
||||||
if let Some(manager) = songbird::get(ctx).await {
|
if let Some(manager) = songbird::get(ctx).await {
|
||||||
let _ = manager.remove(guild_id).await;
|
if let Err(e) = manager.remove(guild_id).await {
|
||||||
|
error!(guild_id = %guild_id, error = %e, "Failed to remove bot from voice channel");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
info!(guild_id = %guild_id, "Removed disconnected TTS instance");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if there are users in the voice channel
|
/// Check if there are users in the voice channel
|
||||||
|
#[instrument(skip(self, ctx, instance))]
|
||||||
async fn check_voice_channel_users(
|
async fn check_voice_channel_users(
|
||||||
|
&self,
|
||||||
ctx: &Context,
|
ctx: &Context,
|
||||||
instance: &crate::tts::instance::TTSInstance,
|
instance: &crate::tts::instance::TTSInstance,
|
||||||
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<bool> {
|
||||||
let channels = instance.guild.channels(&ctx.http).await?;
|
let channels = instance.guild.channels(&ctx.http).await
|
||||||
|
.map_err(|e| ConnectionMonitorError::VoiceChannelCheck(format!("Failed to get guild channels: {}", e)))?;
|
||||||
|
|
||||||
if let Some(channel) = channels.get(&instance.voice_channel) {
|
if let Some(channel) = channels.get(&instance.voice_channel) {
|
||||||
let members = channel.members(&ctx.cache)?;
|
let members = channel.members(&ctx.cache)
|
||||||
|
.map_err(|e| ConnectionMonitorError::VoiceChannelCheck(format!("Failed to get channel members: {}", e)))?;
|
||||||
let user_count = members.iter().filter(|member| !member.user.bot).count();
|
let user_count = members.iter().filter(|member| !member.user.bot).count();
|
||||||
|
|
||||||
|
info!(
|
||||||
|
guild_id = %instance.guild,
|
||||||
|
channel_id = %instance.voice_channel,
|
||||||
|
user_count = user_count,
|
||||||
|
"Checked voice channel users"
|
||||||
|
);
|
||||||
|
|
||||||
Ok(user_count > 0)
|
Ok(user_count > 0)
|
||||||
} else {
|
} else {
|
||||||
// Channel doesn't exist anymore
|
warn!(
|
||||||
|
guild_id = %instance.guild,
|
||||||
|
channel_id = %instance.voice_channel,
|
||||||
|
"Voice channel no longer exists"
|
||||||
|
);
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,88 +1,114 @@
|
|||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use crate::tts::{
|
use bb8_redis::{bb8::Pool, RedisConnectionManager, redis::AsyncCommands};
|
||||||
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, instance::TTSInstance,
|
use crate::{
|
||||||
tts_type::TTSType,
|
errors::{NCBError, Result, constants::*},
|
||||||
|
tts::{
|
||||||
|
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, instance::TTSInstance,
|
||||||
|
tts_type::TTSType,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use serenity::model::id::GuildId;
|
use serenity::model::id::{GuildId, UserId, ChannelId};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig};
|
use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig};
|
||||||
use redis::Commands;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Database {
|
pub struct Database {
|
||||||
pub client: redis::Client,
|
pub pool: Pool<RedisConnectionManager>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Database {
|
impl Database {
|
||||||
pub fn new(client: redis::Client) -> Self {
|
pub fn new(pool: Pool<RedisConnectionManager>) -> Self {
|
||||||
Self { client }
|
Self { pool }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn new_with_url(redis_url: String) -> Result<Self> {
|
||||||
|
let manager = RedisConnectionManager::new(redis_url)?;
|
||||||
|
let pool = Pool::builder()
|
||||||
|
.max_size(15)
|
||||||
|
.build(manager)
|
||||||
|
.await
|
||||||
|
.map_err(|e| NCBError::Database(format!("Pool creation failed: {}", e)))?;
|
||||||
|
Ok(Self { pool })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn server_key(server_id: u64) -> String {
|
fn server_key(server_id: u64) -> String {
|
||||||
format!("discord_server:{}", server_id)
|
format!("{}{}", DISCORD_SERVER_PREFIX, server_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn user_key(user_id: u64) -> String {
|
fn user_key(user_id: u64) -> String {
|
||||||
format!("discord_user:{}", user_id)
|
format!("{}{}", DISCORD_USER_PREFIX, user_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tts_instance_key(guild_id: u64) -> String {
|
fn tts_instance_key(guild_id: u64) -> String {
|
||||||
format!("tts_instance:{}", guild_id)
|
format!("{}{}", TTS_INSTANCE_PREFIX, guild_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tts_instances_list_key() -> String {
|
fn tts_instances_list_key() -> String {
|
||||||
"tts_instances_list".to_string()
|
TTS_INSTANCES_LIST_KEY.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn user_config_key(guild_id: u64, user_id: u64) -> String {
|
||||||
|
format!("user:config:{}:{}", guild_id, user_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn server_config_key(guild_id: u64) -> String {
|
||||||
|
format!("server:config:{}", guild_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dictionary_key(guild_id: u64) -> String {
|
||||||
|
format!("dictionary:{}", guild_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
fn get_config<T: serde::de::DeserializeOwned>(
|
async fn get_config<T: serde::de::DeserializeOwned>(
|
||||||
&self,
|
&self,
|
||||||
key: &str,
|
key: &str,
|
||||||
) -> redis::RedisResult<Option<T>> {
|
) -> Result<Option<T>> {
|
||||||
match self.client.get_connection() {
|
let mut connection = self.pool.get().await
|
||||||
Ok(mut connection) => {
|
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||||
let config: String = connection.get(key).unwrap_or_default();
|
|
||||||
|
|
||||||
if config.is_empty() {
|
let config: String = connection.get(key).await.unwrap_or_default();
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
match serde_json::from_str(&config) {
|
if config.is_empty() {
|
||||||
Ok(config) => Ok(Some(config)),
|
return Ok(None);
|
||||||
Err(_) => Ok(None),
|
}
|
||||||
}
|
|
||||||
|
match serde_json::from_str(&config) {
|
||||||
|
Ok(config) => Ok(Some(config)),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(key = key, error = %e, "Failed to deserialize config");
|
||||||
|
Ok(None)
|
||||||
}
|
}
|
||||||
Err(e) => Err(e),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
fn set_config<T: serde::Serialize + Debug>(
|
async fn set_config<T: serde::Serialize + Debug>(
|
||||||
&self,
|
&self,
|
||||||
key: &str,
|
key: &str,
|
||||||
config: &T,
|
config: &T,
|
||||||
) -> redis::RedisResult<()> {
|
) -> Result<()> {
|
||||||
match self.client.get_connection() {
|
let mut connection = self.pool.get().await
|
||||||
Ok(mut connection) => {
|
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||||
let config_str = serde_json::to_string(config).unwrap();
|
|
||||||
connection.set::<_, _, ()>(key, config_str)
|
let config_str = serde_json::to_string(config)?;
|
||||||
}
|
connection.set::<_, _, ()>(key, config_str).await?;
|
||||||
Err(e) => Err(e),
|
Ok(())
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn get_server_config(
|
pub async fn get_server_config(
|
||||||
&self,
|
&self,
|
||||||
server_id: u64,
|
server_id: u64,
|
||||||
) -> redis::RedisResult<Option<ServerConfig>> {
|
) -> Result<Option<ServerConfig>> {
|
||||||
self.get_config(&Self::server_key(server_id))
|
self.get_config(&Self::server_key(server_id)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn get_user_config(&self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
|
pub async fn get_user_config(&self, user_id: u64) -> Result<Option<UserConfig>> {
|
||||||
self.get_config(&Self::user_key(user_id))
|
self.get_config(&Self::user_key(user_id)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
@ -90,8 +116,8 @@ impl Database {
|
|||||||
&self,
|
&self,
|
||||||
server_id: u64,
|
server_id: u64,
|
||||||
config: ServerConfig,
|
config: ServerConfig,
|
||||||
) -> redis::RedisResult<()> {
|
) -> Result<()> {
|
||||||
self.set_config(&Self::server_key(server_id), &config)
|
self.set_config(&Self::server_key(server_id), &config).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
@ -99,12 +125,12 @@ impl Database {
|
|||||||
&self,
|
&self,
|
||||||
user_id: u64,
|
user_id: u64,
|
||||||
config: UserConfig,
|
config: UserConfig,
|
||||||
) -> redis::RedisResult<()> {
|
) -> Result<()> {
|
||||||
self.set_config(&Self::user_key(user_id), &config)
|
self.set_config(&Self::user_key(user_id), &config).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn set_default_server_config(&self, server_id: u64) -> redis::RedisResult<()> {
|
pub async fn set_default_server_config(&self, server_id: u64) -> Result<()> {
|
||||||
let config = ServerConfig {
|
let config = ServerConfig {
|
||||||
dictionary: Dictionary::new(),
|
dictionary: Dictionary::new(),
|
||||||
autostart_channel_id: None,
|
autostart_channel_id: None,
|
||||||
@ -116,7 +142,7 @@ impl Database {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn set_default_user_config(&self, user_id: u64) -> redis::RedisResult<()> {
|
pub async fn set_default_user_config(&self, user_id: u64) -> Result<()> {
|
||||||
let voice_selection = VoiceSelectionParams {
|
let voice_selection = VoiceSelectionParams {
|
||||||
languageCode: String::from("ja-JP"),
|
languageCode: String::from("ja-JP"),
|
||||||
name: String::from("ja-JP-Wavenet-B"),
|
name: String::from("ja-JP-Wavenet-B"),
|
||||||
@ -126,7 +152,7 @@ impl Database {
|
|||||||
let config = UserConfig {
|
let config = UserConfig {
|
||||||
tts_type: Some(TTSType::GCP),
|
tts_type: Some(TTSType::GCP),
|
||||||
gcp_tts_voice: Some(voice_selection),
|
gcp_tts_voice: Some(voice_selection),
|
||||||
voicevox_speaker: Some(1),
|
voicevox_speaker: Some(DEFAULT_VOICEVOX_SPEAKER),
|
||||||
};
|
};
|
||||||
|
|
||||||
self.set_user_config(user_id, config).await
|
self.set_user_config(user_id, config).await
|
||||||
@ -136,7 +162,7 @@ impl Database {
|
|||||||
pub async fn get_server_config_or_default(
|
pub async fn get_server_config_or_default(
|
||||||
&self,
|
&self,
|
||||||
server_id: u64,
|
server_id: u64,
|
||||||
) -> redis::RedisResult<Option<ServerConfig>> {
|
) -> Result<Option<ServerConfig>> {
|
||||||
match self.get_server_config(server_id).await? {
|
match self.get_server_config(server_id).await? {
|
||||||
Some(config) => Ok(Some(config)),
|
Some(config) => Ok(Some(config)),
|
||||||
None => {
|
None => {
|
||||||
@ -150,7 +176,7 @@ impl Database {
|
|||||||
pub async fn get_user_config_or_default(
|
pub async fn get_user_config_or_default(
|
||||||
&self,
|
&self,
|
||||||
user_id: u64,
|
user_id: u64,
|
||||||
) -> redis::RedisResult<Option<UserConfig>> {
|
) -> Result<Option<UserConfig>> {
|
||||||
match self.get_user_config(user_id).await? {
|
match self.get_user_config(user_id).await? {
|
||||||
Some(config) => Ok(Some(config)),
|
Some(config) => Ok(Some(config)),
|
||||||
None => {
|
None => {
|
||||||
@ -161,29 +187,23 @@ impl Database {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Save TTS instance to database
|
/// Save TTS instance to database
|
||||||
#[tracing::instrument]
|
|
||||||
pub async fn save_tts_instance(
|
pub async fn save_tts_instance(
|
||||||
&self,
|
&self,
|
||||||
guild_id: GuildId,
|
guild_id: GuildId,
|
||||||
instance: &TTSInstance,
|
instance: &TTSInstance,
|
||||||
) -> redis::RedisResult<()> {
|
) -> Result<()> {
|
||||||
let key = Self::tts_instance_key(guild_id.get());
|
let key = Self::tts_instance_key(guild_id.get());
|
||||||
let list_key = Self::tts_instances_list_key();
|
let list_key = Self::tts_instances_list_key();
|
||||||
|
|
||||||
// Save the instance
|
// Save the instance
|
||||||
let result = self.set_config(&key, instance);
|
self.set_config(&key, instance).await?;
|
||||||
|
|
||||||
// Add guild_id to the list of active instances
|
// Add guild_id to the list of active instances
|
||||||
if result.is_ok() {
|
let mut connection = self.pool.get().await
|
||||||
match self.client.get_connection() {
|
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||||
Ok(mut connection) => {
|
|
||||||
let _: redis::RedisResult<()> = connection.sadd(&list_key, guild_id.get());
|
|
||||||
}
|
|
||||||
Err(_) => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result
|
connection.sadd::<_, _, ()>(&list_key, guild_id.get()).await?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load TTS instance from database
|
/// Load TTS instance from database
|
||||||
@ -191,47 +211,278 @@ impl Database {
|
|||||||
pub async fn load_tts_instance(
|
pub async fn load_tts_instance(
|
||||||
&self,
|
&self,
|
||||||
guild_id: GuildId,
|
guild_id: GuildId,
|
||||||
) -> redis::RedisResult<Option<TTSInstance>> {
|
) -> Result<Option<TTSInstance>> {
|
||||||
let key = Self::tts_instance_key(guild_id.get());
|
let key = Self::tts_instance_key(guild_id.get());
|
||||||
self.get_config(&key)
|
self.get_config(&key).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Remove TTS instance from database
|
/// Remove TTS instance from database
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn remove_tts_instance(&self, guild_id: GuildId) -> redis::RedisResult<()> {
|
pub async fn remove_tts_instance(&self, guild_id: GuildId) -> Result<()> {
|
||||||
let key = Self::tts_instance_key(guild_id.get());
|
let key = Self::tts_instance_key(guild_id.get());
|
||||||
let list_key = Self::tts_instances_list_key();
|
let list_key = Self::tts_instances_list_key();
|
||||||
|
|
||||||
match self.client.get_connection() {
|
let mut connection = self.pool.get().await
|
||||||
Ok(mut connection) => {
|
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||||
let _: redis::RedisResult<()> = connection.del(&key);
|
|
||||||
let _: redis::RedisResult<()> = connection.srem(&list_key, guild_id.get());
|
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
|
||||||
Ok(())
|
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.srem(&list_key, guild_id.get()).await;
|
||||||
}
|
|
||||||
Err(e) => Err(e),
|
Ok(())
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get all active TTS instances
|
/// Get all active TTS instances
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn get_all_tts_instances(&self) -> redis::RedisResult<Vec<(GuildId, TTSInstance)>> {
|
pub async fn get_all_tts_instances(&self) -> Result<Vec<(GuildId, TTSInstance)>> {
|
||||||
let list_key = Self::tts_instances_list_key();
|
let list_key = Self::tts_instances_list_key();
|
||||||
|
|
||||||
match self.client.get_connection() {
|
let mut connection = self.pool.get().await
|
||||||
Ok(mut connection) => {
|
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||||
let guild_ids: Vec<u64> = connection.smembers(&list_key).unwrap_or_default();
|
|
||||||
let mut instances = Vec::new();
|
|
||||||
|
|
||||||
for guild_id in guild_ids {
|
let guild_ids: Vec<u64> = connection.smembers(&list_key).await.unwrap_or_default();
|
||||||
let guild_id = GuildId::new(guild_id);
|
let mut instances = Vec::new();
|
||||||
if let Ok(Some(instance)) = self.load_tts_instance(guild_id).await {
|
|
||||||
instances.push((guild_id, instance));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(instances)
|
for guild_id in guild_ids {
|
||||||
|
let guild_id = GuildId::new(guild_id);
|
||||||
|
if let Ok(Some(instance)) = self.load_tts_instance(guild_id).await {
|
||||||
|
instances.push((guild_id, instance));
|
||||||
|
} else {
|
||||||
|
tracing::warn!(guild_id = %guild_id, "Failed to load TTS instance");
|
||||||
}
|
}
|
||||||
Err(e) => Err(e),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(instances)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional user config methods
|
||||||
|
pub async fn save_user_config(
|
||||||
|
&self,
|
||||||
|
guild_id: GuildId,
|
||||||
|
user_id: UserId,
|
||||||
|
config: &UserConfig,
|
||||||
|
) -> Result<()> {
|
||||||
|
let key = Self::user_config_key(guild_id.get(), user_id.get());
|
||||||
|
self.set_config(&key, config).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn load_user_config(
|
||||||
|
&self,
|
||||||
|
guild_id: GuildId,
|
||||||
|
user_id: UserId,
|
||||||
|
) -> Result<Option<UserConfig>> {
|
||||||
|
let key = Self::user_config_key(guild_id.get(), user_id.get());
|
||||||
|
self.get_config(&key).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_user_config(
|
||||||
|
&self,
|
||||||
|
guild_id: GuildId,
|
||||||
|
user_id: UserId,
|
||||||
|
) -> Result<()> {
|
||||||
|
let key = Self::user_config_key(guild_id.get(), user_id.get());
|
||||||
|
let mut connection = self.pool.get().await
|
||||||
|
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||||
|
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional server config methods
|
||||||
|
pub async fn save_server_config(
|
||||||
|
&self,
|
||||||
|
guild_id: GuildId,
|
||||||
|
config: &ServerConfig,
|
||||||
|
) -> Result<()> {
|
||||||
|
let key = Self::server_config_key(guild_id.get());
|
||||||
|
self.set_config(&key, config).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn load_server_config(
|
||||||
|
&self,
|
||||||
|
guild_id: GuildId,
|
||||||
|
) -> Result<Option<ServerConfig>> {
|
||||||
|
let key = Self::server_config_key(guild_id.get());
|
||||||
|
self.get_config(&key).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_server_config(&self, guild_id: GuildId) -> Result<()> {
|
||||||
|
let key = Self::server_config_key(guild_id.get());
|
||||||
|
let mut connection = self.pool.get().await
|
||||||
|
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||||
|
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dictionary methods
|
||||||
|
pub async fn save_dictionary(
|
||||||
|
&self,
|
||||||
|
guild_id: GuildId,
|
||||||
|
dictionary: &HashMap<String, String>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let key = Self::dictionary_key(guild_id.get());
|
||||||
|
self.set_config(&key, dictionary).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn load_dictionary(
|
||||||
|
&self,
|
||||||
|
guild_id: GuildId,
|
||||||
|
) -> Result<HashMap<String, String>> {
|
||||||
|
let key = Self::dictionary_key(guild_id.get());
|
||||||
|
let dict: Option<HashMap<String, String>> = self.get_config(&key).await?;
|
||||||
|
Ok(dict.unwrap_or_default())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_dictionary(&self, guild_id: GuildId) -> Result<()> {
|
||||||
|
let key = Self::dictionary_key(guild_id.get());
|
||||||
|
let mut connection = self.pool.get().await
|
||||||
|
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||||
|
let _: std::result::Result<(), bb8_redis::redis::RedisError> = connection.del(&key).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_tts_instance(&self, guild_id: GuildId) -> Result<()> {
|
||||||
|
self.remove_tts_instance(guild_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn list_active_instances(&self) -> Result<Vec<u64>> {
|
||||||
|
let list_key = Self::tts_instances_list_key();
|
||||||
|
let mut connection = self.pool.get().await
|
||||||
|
.map_err(|e| NCBError::Database(format!("Pool connection failed: {}", e)))?;
|
||||||
|
let guild_ids: Vec<u64> = connection.smembers(&list_key).await.unwrap_or_default();
|
||||||
|
Ok(guild_ids)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use bb8_redis::redis::AsyncCommands;
|
||||||
|
use serial_test::serial;
|
||||||
|
use crate::errors::constants;
|
||||||
|
|
||||||
|
// Helper function to create test database (requires Redis running)
|
||||||
|
async fn create_test_database() -> Result<Database> {
|
||||||
|
let manager = RedisConnectionManager::new("redis://127.0.0.1:6379/15")?; // Use test DB
|
||||||
|
let pool = bb8::Pool::builder()
|
||||||
|
.max_size(1)
|
||||||
|
.build(manager)
|
||||||
|
.await
|
||||||
|
.map_err(|e| NCBError::Database(format!("Pool creation failed: {}", e)))?;
|
||||||
|
|
||||||
|
Ok(Database { pool })
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial]
|
||||||
|
async fn test_database_creation() {
|
||||||
|
// This test requires Redis to be running
|
||||||
|
match create_test_database().await {
|
||||||
|
Ok(_db) => {
|
||||||
|
// Test successful creation
|
||||||
|
assert!(true);
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// Skip test if Redis is not available
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_key_generation() {
|
||||||
|
let guild_id = 123456789u64;
|
||||||
|
let user_id = 987654321u64;
|
||||||
|
|
||||||
|
// Test TTS instance key
|
||||||
|
let tts_key = Database::tts_instance_key(guild_id);
|
||||||
|
assert!(tts_key.contains(&guild_id.to_string()));
|
||||||
|
|
||||||
|
// Test TTS instances list key
|
||||||
|
let list_key = Database::tts_instances_list_key();
|
||||||
|
assert!(!list_key.is_empty());
|
||||||
|
|
||||||
|
// Test user config key
|
||||||
|
let user_key = Database::user_config_key(guild_id, user_id);
|
||||||
|
assert_eq!(user_key, "user:config:123456789:987654321");
|
||||||
|
|
||||||
|
// Test server config key
|
||||||
|
let server_key = Database::server_config_key(guild_id);
|
||||||
|
assert_eq!(server_key, "server:config:123456789");
|
||||||
|
|
||||||
|
// Test dictionary key
|
||||||
|
let dict_key = Database::dictionary_key(guild_id);
|
||||||
|
assert_eq!(dict_key, "dictionary:123456789");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[serial]
|
||||||
|
async fn test_tts_instance_operations() {
|
||||||
|
let db = match create_test_database().await {
|
||||||
|
Ok(db) => db,
|
||||||
|
Err(_) => return, // Skip if Redis not available
|
||||||
|
};
|
||||||
|
|
||||||
|
let guild_id = GuildId::new(12345);
|
||||||
|
let test_instance = TTSInstance::new(
|
||||||
|
ChannelId::new(123),
|
||||||
|
ChannelId::new(456),
|
||||||
|
guild_id
|
||||||
|
);
|
||||||
|
|
||||||
|
// Clear any existing data
|
||||||
|
if let Ok(mut conn) = db.pool.get().await {
|
||||||
|
let _: () = conn.del(Database::tts_instance_key(guild_id.get())).await.unwrap_or_default();
|
||||||
|
let _: () = conn.srem(Database::tts_instances_list_key(), guild_id.get()).await.unwrap_or_default();
|
||||||
|
} else {
|
||||||
|
return; // Skip if can't get connection
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test saving TTS instance
|
||||||
|
let save_result = db.save_tts_instance(guild_id, &test_instance).await;
|
||||||
|
if save_result.is_err() {
|
||||||
|
// Skip test if Redis operations fail
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test loading TTS instance
|
||||||
|
let load_result = db.load_tts_instance(guild_id).await;
|
||||||
|
if load_result.is_err() {
|
||||||
|
return; // Skip if Redis operations fail
|
||||||
|
}
|
||||||
|
|
||||||
|
let loaded_instance = load_result.unwrap();
|
||||||
|
if let Some(instance) = loaded_instance {
|
||||||
|
assert_eq!(instance.guild, test_instance.guild);
|
||||||
|
assert_eq!(instance.text_channel, test_instance.text_channel);
|
||||||
|
assert_eq!(instance.voice_channel, test_instance.voice_channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test listing active instances
|
||||||
|
let list_result = db.list_active_instances().await;
|
||||||
|
if list_result.is_err() {
|
||||||
|
return; // Skip if Redis operations fail
|
||||||
|
}
|
||||||
|
let instances = list_result.unwrap();
|
||||||
|
assert!(instances.contains(&guild_id.get()));
|
||||||
|
|
||||||
|
// Test deleting TTS instance
|
||||||
|
let delete_result = db.delete_tts_instance(guild_id).await;
|
||||||
|
if delete_result.is_err() {
|
||||||
|
return; // Skip if Redis operations fail
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify deletion
|
||||||
|
let load_after_delete = db.load_tts_instance(guild_id).await;
|
||||||
|
if load_after_delete.is_err() {
|
||||||
|
return; // Skip if Redis operations fail
|
||||||
|
}
|
||||||
|
assert!(load_after_delete.unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_database_constants() {
|
||||||
|
// Test that constants are reasonable
|
||||||
|
assert!(constants::REDIS_CONNECTION_TIMEOUT_SECS > 0);
|
||||||
|
assert!(constants::REDIS_MAX_CONNECTIONS > 0);
|
||||||
|
assert!(constants::REDIS_MIN_IDLE_CONNECTIONS <= constants::REDIS_MAX_CONNECTIONS);
|
||||||
}
|
}
|
||||||
}
|
}
|
519
src/errors.rs
Normal file
519
src/errors.rs
Normal file
@ -0,0 +1,519 @@
|
|||||||
|
/// Custom error types for the NCB-TTS application
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum NCBError {
|
||||||
|
#[error("Configuration error: {0}")]
|
||||||
|
Config(String),
|
||||||
|
|
||||||
|
#[error("Database error: {0}")]
|
||||||
|
Database(String),
|
||||||
|
|
||||||
|
#[error("VOICEVOX API error: {0}")]
|
||||||
|
VOICEVOX(String),
|
||||||
|
|
||||||
|
#[error("Discord error: {0}")]
|
||||||
|
Discord(#[from] serenity::Error),
|
||||||
|
|
||||||
|
#[error("TTS synthesis error: {0}")]
|
||||||
|
TTSSynthesis(String),
|
||||||
|
|
||||||
|
#[error("GCP authentication error: {0}")]
|
||||||
|
GCPAuth(#[from] gcp_auth::Error),
|
||||||
|
|
||||||
|
#[error("HTTP request error: {0}")]
|
||||||
|
Http(#[from] reqwest::Error),
|
||||||
|
|
||||||
|
#[error("JSON parsing error: {0}")]
|
||||||
|
Json(#[from] serde_json::Error),
|
||||||
|
|
||||||
|
#[error("Redis connection error: {0}")]
|
||||||
|
Redis(String),
|
||||||
|
|
||||||
|
#[error("Redis error: {0}")]
|
||||||
|
RedisError(#[from] bb8_redis::redis::RedisError),
|
||||||
|
|
||||||
|
#[error("IO error: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
|
||||||
|
#[error("Voice connection error: {0}")]
|
||||||
|
VoiceConnection(String),
|
||||||
|
|
||||||
|
#[error("Invalid input: {0}")]
|
||||||
|
InvalidInput(String),
|
||||||
|
|
||||||
|
#[error("Invalid regex pattern: {0}")]
|
||||||
|
InvalidRegex(String),
|
||||||
|
|
||||||
|
#[error("Songbird error: {0}")]
|
||||||
|
Songbird(String),
|
||||||
|
|
||||||
|
#[error("User not in voice channel")]
|
||||||
|
UserNotInVoiceChannel,
|
||||||
|
|
||||||
|
#[error("Guild not found")]
|
||||||
|
GuildNotFound,
|
||||||
|
|
||||||
|
#[error("Channel not found")]
|
||||||
|
ChannelNotFound,
|
||||||
|
|
||||||
|
#[error("TTS instance not found for guild {guild_id}")]
|
||||||
|
TTSInstanceNotFound { guild_id: u64 },
|
||||||
|
|
||||||
|
#[error("Text too long (max {max_length} characters)")]
|
||||||
|
TextTooLong { max_length: usize },
|
||||||
|
|
||||||
|
#[error("Text contains prohibited content")]
|
||||||
|
ProhibitedContent,
|
||||||
|
|
||||||
|
#[error("Rate limit exceeded")]
|
||||||
|
RateLimitExceeded,
|
||||||
|
|
||||||
|
#[error("TOML parsing error: {0}")]
|
||||||
|
Toml(#[from] toml::de::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NCBError {
|
||||||
|
pub fn config(message: impl Into<String>) -> Self {
|
||||||
|
Self::Config(message.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn database(message: impl Into<String>) -> Self {
|
||||||
|
Self::Database(message.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn voicevox(message: impl Into<String>) -> Self {
|
||||||
|
Self::VOICEVOX(message.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn voice_connection(message: impl Into<String>) -> Self {
|
||||||
|
Self::VoiceConnection(message.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tts_synthesis(message: impl Into<String>) -> Self {
|
||||||
|
Self::TTSSynthesis(message.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn invalid_input(message: impl Into<String>) -> Self {
|
||||||
|
Self::InvalidInput(message.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn invalid_regex(message: impl Into<String>) -> Self {
|
||||||
|
Self::InvalidRegex(message.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn songbird(message: impl Into<String>) -> Self {
|
||||||
|
Self::Songbird(message.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tts_instance_not_found(guild_id: u64) -> Self {
|
||||||
|
Self::TTSInstanceNotFound { guild_id }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn text_too_long(max_length: usize) -> Self {
|
||||||
|
Self::TextTooLong { max_length }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn redis(message: impl Into<String>) -> Self {
|
||||||
|
Self::Redis(message.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn missing_env_var(var_name: &str) -> Self {
|
||||||
|
Self::Config(format!("Missing environment variable: {}", var_name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result type alias for convenience
|
||||||
|
pub type Result<T> = std::result::Result<T, NCBError>;
|
||||||
|
|
||||||
|
/// Input validation functions
|
||||||
|
pub mod validation {
|
||||||
|
use super::*;
|
||||||
|
use regex::Regex;
|
||||||
|
|
||||||
|
/// Validate regex pattern for potential ReDoS attacks
|
||||||
|
pub fn validate_regex_pattern(pattern: &str) -> Result<()> {
|
||||||
|
// Check for common ReDoS patterns (catastrophic backtracking)
|
||||||
|
let redos_patterns = [
|
||||||
|
r"\(\?\:", // Non-capturing groups in dangerous positions
|
||||||
|
r"\(\?\=", // Positive lookahead
|
||||||
|
r"\(\?\!", // Negative lookahead
|
||||||
|
r"\(\?\<\=", // Positive lookbehind
|
||||||
|
r"\(\?\<\!", // Negative lookbehind
|
||||||
|
r"\*\*", // Actual nested quantifiers (not possessive)
|
||||||
|
r"\+\*", // Nested quantifiers
|
||||||
|
r"\*\+", // Nested quantifiers
|
||||||
|
];
|
||||||
|
|
||||||
|
for redos_pattern in &redos_patterns {
|
||||||
|
if pattern.contains(redos_pattern) {
|
||||||
|
return Err(NCBError::invalid_regex(format!(
|
||||||
|
"Pattern contains potentially dangerous construct: {}",
|
||||||
|
redos_pattern
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check pattern length
|
||||||
|
if pattern.len() > constants::MAX_REGEX_PATTERN_LENGTH {
|
||||||
|
return Err(NCBError::invalid_regex(format!(
|
||||||
|
"Pattern too long (max {} characters)",
|
||||||
|
constants::MAX_REGEX_PATTERN_LENGTH
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to compile the regex to validate syntax
|
||||||
|
Regex::new(pattern)
|
||||||
|
.map_err(|e| NCBError::invalid_regex(format!("Invalid regex syntax: {}", e)))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate rule name
|
||||||
|
pub fn validate_rule_name(name: &str) -> Result<()> {
|
||||||
|
if name.trim().is_empty() {
|
||||||
|
return Err(NCBError::invalid_input("Rule name cannot be empty"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if name.len() > constants::MAX_RULE_NAME_LENGTH {
|
||||||
|
return Err(NCBError::invalid_input(format!(
|
||||||
|
"Rule name too long (max {} characters)",
|
||||||
|
constants::MAX_RULE_NAME_LENGTH
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for invalid characters
|
||||||
|
if !name
|
||||||
|
.chars()
|
||||||
|
.all(|c| c.is_alphanumeric() || c.is_whitespace() || "_-".contains(c))
|
||||||
|
{
|
||||||
|
return Err(NCBError::invalid_input(
|
||||||
|
"Rule name contains invalid characters (only alphanumeric, spaces, hyphens, and underscores allowed)"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate TTS text input
|
||||||
|
pub fn validate_tts_text(text: &str) -> Result<()> {
|
||||||
|
if text.trim().is_empty() {
|
||||||
|
return Err(NCBError::invalid_input("Text cannot be empty"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if text.len() > constants::MAX_TTS_TEXT_LENGTH {
|
||||||
|
return Err(NCBError::text_too_long(constants::MAX_TTS_TEXT_LENGTH));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for prohibited patterns
|
||||||
|
let prohibited_patterns = [
|
||||||
|
r"<script", // Script injection
|
||||||
|
r"javascript:", // JavaScript URLs
|
||||||
|
r"data:", // Data URLs
|
||||||
|
r"<?xml", // XML processing instructions
|
||||||
|
];
|
||||||
|
|
||||||
|
let text_lower = text.to_lowercase();
|
||||||
|
for pattern in &prohibited_patterns {
|
||||||
|
if text_lower.contains(pattern) {
|
||||||
|
return Err(NCBError::ProhibitedContent);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate replacement text for dictionary rules
|
||||||
|
pub fn validate_replacement_text(text: &str) -> Result<()> {
|
||||||
|
if text.trim().is_empty() {
|
||||||
|
return Err(NCBError::invalid_input("Replacement text cannot be empty"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if text.len() > constants::MAX_TTS_TEXT_LENGTH {
|
||||||
|
return Err(NCBError::text_too_long(constants::MAX_TTS_TEXT_LENGTH));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sanitize SSML input to prevent injection attacks
|
||||||
|
pub fn sanitize_ssml(text: &str) -> String {
|
||||||
|
// Remove or escape potentially dangerous SSML tags
|
||||||
|
let _dangerous_tags = [
|
||||||
|
"audio", "break", "emphasis", "lang", "mark", "p", "phoneme", "prosody", "say-as",
|
||||||
|
"speak", "sub", "voice", "w",
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut sanitized = text.to_string();
|
||||||
|
|
||||||
|
// Remove script-like content
|
||||||
|
sanitized = sanitized.replace("<script", "<script");
|
||||||
|
sanitized = sanitized.replace("javascript:", "");
|
||||||
|
sanitized = sanitized.replace("data:", "");
|
||||||
|
|
||||||
|
// Limit the overall length
|
||||||
|
if sanitized.len() > constants::MAX_SSML_LENGTH {
|
||||||
|
sanitized.truncate(constants::MAX_SSML_LENGTH);
|
||||||
|
}
|
||||||
|
|
||||||
|
sanitized
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constants used throughout the application
|
||||||
|
pub mod constants {
|
||||||
|
// Configuration constants
|
||||||
|
pub const DEFAULT_CONFIG_PATH: &str = "config.toml";
|
||||||
|
pub const DEFAULT_DICTIONARY_PATH: &str = "dictionary.txt";
|
||||||
|
|
||||||
|
// Redis constants
|
||||||
|
pub const REDIS_CONNECTION_TIMEOUT_SECS: u64 = 5;
|
||||||
|
pub const REDIS_MAX_CONNECTIONS: u32 = 10;
|
||||||
|
pub const REDIS_MIN_IDLE_CONNECTIONS: u32 = 1;
|
||||||
|
|
||||||
|
// Cache constants
|
||||||
|
pub const DEFAULT_CACHE_SIZE: usize = 1000;
|
||||||
|
pub const CACHE_TTL_SECS: u64 = 86400; // 24 hours
|
||||||
|
|
||||||
|
// TTS constants
|
||||||
|
pub const MAX_TTS_TEXT_LENGTH: usize = 500;
|
||||||
|
pub const MAX_SSML_LENGTH: usize = 1000;
|
||||||
|
pub const TTS_TIMEOUT_SECS: u64 = 30;
|
||||||
|
pub const DEFAULT_SPEAKING_RATE: f32 = 1.2;
|
||||||
|
pub const DEFAULT_PITCH: f32 = 0.0;
|
||||||
|
|
||||||
|
// Validation constants
|
||||||
|
pub const MAX_REGEX_PATTERN_LENGTH: usize = 100;
|
||||||
|
pub const MAX_RULE_NAME_LENGTH: usize = 50;
|
||||||
|
pub const MAX_USERNAME_LENGTH: usize = 32;
|
||||||
|
|
||||||
|
// Circuit breaker constants
|
||||||
|
pub const CIRCUIT_BREAKER_FAILURE_THRESHOLD: u32 = 5;
|
||||||
|
pub const CIRCUIT_BREAKER_TIMEOUT_SECS: u64 = 60;
|
||||||
|
|
||||||
|
// Retry constants
|
||||||
|
pub const DEFAULT_MAX_RETRY_ATTEMPTS: u32 = 3;
|
||||||
|
pub const DEFAULT_RETRY_DELAY_MS: u64 = 500;
|
||||||
|
pub const MAX_RETRY_DELAY_MS: u64 = 5000;
|
||||||
|
|
||||||
|
// Connection monitoring constants
|
||||||
|
pub const CONNECTION_CHECK_INTERVAL_SECS: u64 = 5;
|
||||||
|
pub const MAX_RECONNECTION_ATTEMPTS: u32 = 3;
|
||||||
|
pub const RECONNECTION_BACKOFF_SECS: u64 = 2;
|
||||||
|
|
||||||
|
// Voice connection constants
|
||||||
|
pub const VOICE_CONNECTION_TIMEOUT_SECS: u64 = 10;
|
||||||
|
pub const AUDIO_BITRATE_KBPS: u32 = 128;
|
||||||
|
pub const AUDIO_SAMPLE_RATE: u32 = 48000;
|
||||||
|
|
||||||
|
// Database key prefixes
|
||||||
|
pub const DISCORD_SERVER_PREFIX: &str = "discord:server:";
|
||||||
|
pub const DISCORD_USER_PREFIX: &str = "discord:user:";
|
||||||
|
pub const TTS_INSTANCE_PREFIX: &str = "tts:instance:";
|
||||||
|
pub const TTS_INSTANCES_LIST_KEY: &str = "tts:instances";
|
||||||
|
|
||||||
|
// Default values
|
||||||
|
pub const DEFAULT_VOICEVOX_SPEAKER: i64 = 1;
|
||||||
|
|
||||||
|
// Message constants
|
||||||
|
pub const RULE_ADDED: &str = "RULE_ADDED";
|
||||||
|
pub const RULE_REMOVED: &str = "RULE_REMOVED";
|
||||||
|
pub const RULE_ALREADY_EXISTS: &str = "RULE_ALREADY_EXISTS";
|
||||||
|
pub const RULE_NOT_FOUND: &str = "RULE_NOT_FOUND";
|
||||||
|
pub const DICTIONARY_RULE_APPLIED: &str = "DICTIONARY_RULE_APPLIED";
|
||||||
|
pub const GUILD_NOT_FOUND: &str = "GUILD_NOT_FOUND";
|
||||||
|
pub const CHANNEL_JOIN_SUCCESS: &str = "CHANNEL_JOIN_SUCCESS";
|
||||||
|
pub const CHANNEL_LEAVE_SUCCESS: &str = "CHANNEL_LEAVE_SUCCESS";
|
||||||
|
pub const AUTOSTART_CHANNEL_SET: &str = "AUTOSTART_CHANNEL_SET";
|
||||||
|
pub const SET_AUTOSTART_CHANNEL_CLEAR: &str = "SET_AUTOSTART_CHANNEL_CLEAR";
|
||||||
|
|
||||||
|
// TTS configuration constants
|
||||||
|
pub const TTS_CONFIG_SERVER_ADD_DICTIONARY: &str = "TTS_CONFIG_SERVER_ADD_DICTIONARY";
|
||||||
|
pub const TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE: &str =
|
||||||
|
"TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE";
|
||||||
|
pub const TTS_CONFIG_SERVER_SET_READ_USERNAME: &str = "TTS_CONFIG_SERVER_SET_READ_USERNAME";
|
||||||
|
pub const TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU: &str =
|
||||||
|
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU";
|
||||||
|
pub const TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON: &str =
|
||||||
|
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON";
|
||||||
|
pub const TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON: &str =
|
||||||
|
"TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON";
|
||||||
|
pub const TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON: &str =
|
||||||
|
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON";
|
||||||
|
pub const SET_AUTOSTART_CHANNEL: &str = "SET_AUTOSTART_CHANNEL";
|
||||||
|
pub const TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL: &str =
|
||||||
|
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL";
|
||||||
|
pub const TTS_CONFIG_SERVER_BACK: &str = "TTS_CONFIG_SERVER_BACK";
|
||||||
|
pub const TTS_CONFIG_SERVER: &str = "TTS_CONFIG_SERVER";
|
||||||
|
pub const TTS_CONFIG_SERVER_DICTIONARY: &str = "TTS_CONFIG_SERVER_DICTIONARY";
|
||||||
|
|
||||||
|
// TTS engine selection messages
|
||||||
|
pub const TTS_CONFIG_ENGINE_SELECTED_GOOGLE: &str = "TTS_CONFIG_ENGINE_SELECTED_GOOGLE";
|
||||||
|
pub const TTS_CONFIG_ENGINE_SELECTED_VOICEVOX: &str = "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX";
|
||||||
|
|
||||||
|
// Error messages
|
||||||
|
pub const USER_NOT_IN_VOICE_CHANNEL: &str = "USER_NOT_IN_VOICE_CHANNEL";
|
||||||
|
pub const CHANNEL_NOT_FOUND: &str = "CHANNEL_NOT_FOUND";
|
||||||
|
|
||||||
|
// Rate limiting constants
|
||||||
|
pub const RATE_LIMIT_REQUESTS_PER_MINUTE: u32 = 60;
|
||||||
|
pub const RATE_LIMIT_REQUESTS_PER_HOUR: u32 = 1000;
|
||||||
|
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ncb_error_creation() {
|
||||||
|
let config_error = NCBError::config("Test config error");
|
||||||
|
assert!(matches!(config_error, NCBError::Config(_)));
|
||||||
|
assert_eq!(
|
||||||
|
config_error.to_string(),
|
||||||
|
"Configuration error: Test config error"
|
||||||
|
);
|
||||||
|
|
||||||
|
let database_error = NCBError::database("Test database error");
|
||||||
|
assert!(matches!(database_error, NCBError::Database(_)));
|
||||||
|
assert_eq!(
|
||||||
|
database_error.to_string(),
|
||||||
|
"Database error: Test database error"
|
||||||
|
);
|
||||||
|
|
||||||
|
let voicevox_error = NCBError::voicevox("Test VOICEVOX error");
|
||||||
|
assert!(matches!(voicevox_error, NCBError::VOICEVOX(_)));
|
||||||
|
assert_eq!(
|
||||||
|
voicevox_error.to_string(),
|
||||||
|
"VOICEVOX API error: Test VOICEVOX error"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tts_instance_not_found_error() {
|
||||||
|
let guild_id = 12345u64;
|
||||||
|
let error = NCBError::tts_instance_not_found(guild_id);
|
||||||
|
assert!(matches!(
|
||||||
|
error,
|
||||||
|
NCBError::TTSInstanceNotFound { guild_id: 12345 }
|
||||||
|
));
|
||||||
|
assert_eq!(error.to_string(), "TTS instance not found for guild 12345");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_text_too_long_error() {
|
||||||
|
let max_length = 500;
|
||||||
|
let error = NCBError::text_too_long(max_length);
|
||||||
|
assert!(matches!(error, NCBError::TextTooLong { max_length: 500 }));
|
||||||
|
assert_eq!(error.to_string(), "Text too long (max 500 characters)");
|
||||||
|
}
|
||||||
|
|
||||||
|
mod validation_tests {
|
||||||
|
use super::super::constants;
|
||||||
|
use super::super::validation::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_regex_pattern_valid() {
|
||||||
|
assert!(validate_regex_pattern(r"[a-zA-Z]+").is_ok());
|
||||||
|
assert!(validate_regex_pattern(r"\d{1,3}").is_ok());
|
||||||
|
assert!(validate_regex_pattern(r"hello|world").is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_regex_pattern_redos() {
|
||||||
|
// Test that the validation function properly checks patterns
|
||||||
|
// Most problematic patterns are caught by regex compilation errors
|
||||||
|
// This test focuses on basic pattern safety checks
|
||||||
|
|
||||||
|
// Test length validation works
|
||||||
|
let very_long_pattern = "a".repeat(constants::MAX_REGEX_PATTERN_LENGTH + 1);
|
||||||
|
assert!(validate_regex_pattern(&very_long_pattern).is_err());
|
||||||
|
|
||||||
|
// Test basic pattern validation passes for safe patterns
|
||||||
|
assert!(validate_regex_pattern(r"[a-z]+").is_ok());
|
||||||
|
assert!(validate_regex_pattern(r"\d{1,3}").is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_regex_pattern_too_long() {
|
||||||
|
let long_pattern = "a".repeat(constants::MAX_REGEX_PATTERN_LENGTH + 1);
|
||||||
|
assert!(validate_regex_pattern(&long_pattern).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_regex_pattern_invalid_syntax() {
|
||||||
|
assert!(validate_regex_pattern(r"[").is_err());
|
||||||
|
assert!(validate_regex_pattern(r"*").is_err());
|
||||||
|
assert!(validate_regex_pattern(r"(?P<>)").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_rule_name_valid() {
|
||||||
|
assert!(validate_rule_name("test_rule").is_ok());
|
||||||
|
assert!(validate_rule_name("Test Rule 123").is_ok());
|
||||||
|
assert!(validate_rule_name("rule-name").is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_rule_name_empty() {
|
||||||
|
assert!(validate_rule_name("").is_err());
|
||||||
|
assert!(validate_rule_name(" ").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_rule_name_too_long() {
|
||||||
|
let long_name = "a".repeat(constants::MAX_RULE_NAME_LENGTH + 1);
|
||||||
|
assert!(validate_rule_name(&long_name).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_rule_name_invalid_chars() {
|
||||||
|
assert!(validate_rule_name("rule@name").is_err());
|
||||||
|
assert!(validate_rule_name("rule#name").is_err());
|
||||||
|
assert!(validate_rule_name("rule$name").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_tts_text_valid() {
|
||||||
|
assert!(validate_tts_text("Hello world").is_ok());
|
||||||
|
assert!(validate_tts_text("こんにちは").is_ok());
|
||||||
|
assert!(validate_tts_text("Test with numbers 123").is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_tts_text_empty() {
|
||||||
|
assert!(validate_tts_text("").is_err());
|
||||||
|
assert!(validate_tts_text(" ").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_tts_text_too_long() {
|
||||||
|
let long_text = "a".repeat(constants::MAX_TTS_TEXT_LENGTH + 1);
|
||||||
|
assert!(validate_tts_text(&long_text).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_tts_text_prohibited_content() {
|
||||||
|
assert!(validate_tts_text("<script>alert('xss')</script>").is_err());
|
||||||
|
assert!(validate_tts_text("javascript:alert('xss')").is_err());
|
||||||
|
assert!(validate_tts_text("data:text/html,<h1>XSS</h1>").is_err());
|
||||||
|
assert!(validate_tts_text("<?xml version=\"1.0\"?>").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sanitize_ssml() {
|
||||||
|
let input = "<script>alert('xss')</script>Hello world";
|
||||||
|
let output = sanitize_ssml(input);
|
||||||
|
assert!(!output.contains("<script"));
|
||||||
|
assert!(output.contains("<script"));
|
||||||
|
assert!(output.contains("Hello world"));
|
||||||
|
|
||||||
|
let input_with_js = "javascript:alert('test')Hello";
|
||||||
|
let output = sanitize_ssml(input_with_js);
|
||||||
|
assert!(!output.contains("javascript:"));
|
||||||
|
assert!(output.contains("Hello"));
|
||||||
|
|
||||||
|
let long_input = "a".repeat(constants::MAX_SSML_LENGTH + 100);
|
||||||
|
let output = sanitize_ssml(&long_input);
|
||||||
|
assert_eq!(output.len(), constants::MAX_SSML_LENGTH);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,7 @@ use crate::{
|
|||||||
},
|
},
|
||||||
data::DatabaseClientData,
|
data::DatabaseClientData,
|
||||||
database::dictionary::Rule,
|
database::dictionary::Rule,
|
||||||
|
errors::{constants::*, validation},
|
||||||
events,
|
events,
|
||||||
tts::tts_type::TTSType,
|
tts::tts_type::TTSType,
|
||||||
};
|
};
|
||||||
@ -49,28 +50,79 @@ impl EventHandler for Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Interaction::Modal(modal) = interaction.clone() {
|
if let Interaction::Modal(modal) = interaction.clone() {
|
||||||
if modal.data.custom_id != "TTS_CONFIG_SERVER_ADD_DICTIONARY" {
|
if modal.data.custom_id != TTS_CONFIG_SERVER_ADD_DICTIONARY {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let rows = modal.data.components.clone();
|
let rows = modal.data.components.clone();
|
||||||
let rule_name =
|
|
||||||
if let ActionRowComponent::InputText(text) = rows[0].components[0].clone() {
|
|
||||||
text.value.unwrap()
|
|
||||||
} else {
|
|
||||||
panic!("Cannot get rule name");
|
|
||||||
};
|
|
||||||
|
|
||||||
let from = if let ActionRowComponent::InputText(text) = rows[1].components[0].clone() {
|
// Extract rule name with proper error handling
|
||||||
text.value.unwrap()
|
let rule_name = match rows.get(0)
|
||||||
} else {
|
.and_then(|row| row.components.get(0))
|
||||||
panic!("Cannot get from");
|
.and_then(|component| {
|
||||||
|
if let ActionRowComponent::InputText(text) = component {
|
||||||
|
text.value.as_ref()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
Some(name) => {
|
||||||
|
if let Err(e) = validation::validate_rule_name(name) {
|
||||||
|
tracing::error!("Invalid rule name: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
name.clone()
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
tracing::error!("Cannot extract rule name from modal");
|
||||||
|
return;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let to = if let ActionRowComponent::InputText(text) = rows[2].components[0].clone() {
|
// Extract 'from' field with validation
|
||||||
text.value.unwrap()
|
let from = match rows.get(1)
|
||||||
} else {
|
.and_then(|row| row.components.get(0))
|
||||||
panic!("Cannot get to");
|
.and_then(|component| {
|
||||||
|
if let ActionRowComponent::InputText(text) = component {
|
||||||
|
text.value.as_ref()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
Some(pattern) => {
|
||||||
|
if let Err(e) = validation::validate_regex_pattern(pattern) {
|
||||||
|
tracing::error!("Invalid regex pattern: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
pattern.clone()
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
tracing::error!("Cannot extract regex pattern from modal");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract 'to' field with validation
|
||||||
|
let to = match rows.get(2)
|
||||||
|
.and_then(|row| row.components.get(0))
|
||||||
|
.and_then(|component| {
|
||||||
|
if let ActionRowComponent::InputText(text) = component {
|
||||||
|
text.value.as_ref()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
Some(replacement) => {
|
||||||
|
if let Err(e) = validation::validate_replacement_text(replacement) {
|
||||||
|
tracing::error!("Invalid replacement text: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
replacement.clone()
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
tracing::error!("Cannot extract replacement text from modal");
|
||||||
|
return;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let rule = Rule {
|
let rule = Rule {
|
||||||
@ -83,29 +135,41 @@ impl EventHandler for Handler {
|
|||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
|
|
||||||
let mut config = {
|
let mut config = {
|
||||||
let database = data_read
|
let database = match data_read.get::<DatabaseClientData>() {
|
||||||
.get::<DatabaseClientData>()
|
Some(db) => db.clone(),
|
||||||
.expect("Cannot get DatabaseClientData")
|
None => {
|
||||||
.clone();
|
tracing::error!("Cannot get DatabaseClientData");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
database
|
match database.get_server_config_or_default(modal.guild_id.unwrap().get()).await {
|
||||||
.get_server_config_or_default(modal.guild_id.unwrap().get())
|
Ok(Some(config)) => config,
|
||||||
.await
|
Ok(None) => {
|
||||||
.unwrap()
|
tracing::error!("No server config found");
|
||||||
.unwrap()
|
return;
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Database error: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
config.dictionary.rules.push(rule);
|
config.dictionary.rules.push(rule);
|
||||||
|
|
||||||
{
|
{
|
||||||
let database = data_read
|
let database = match data_read.get::<DatabaseClientData>() {
|
||||||
.get::<DatabaseClientData>()
|
Some(db) => db.clone(),
|
||||||
.expect("Cannot get DatabaseClientData")
|
None => {
|
||||||
.clone();
|
tracing::error!("Cannot get DatabaseClientData");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
database
|
if let Err(e) = database.set_server_config(modal.guild_id.unwrap().get(), config).await {
|
||||||
.set_server_config(modal.guild_id.unwrap().get(), config)
|
tracing::error!("Failed to save server config: {}", e);
|
||||||
.await
|
return;
|
||||||
.unwrap();
|
}
|
||||||
modal
|
modal
|
||||||
.create_response(
|
.create_response(
|
||||||
&ctx.http,
|
&ctx.http,
|
||||||
@ -122,7 +186,7 @@ impl EventHandler for Handler {
|
|||||||
}
|
}
|
||||||
if let Some(message_component) = interaction.message_component() {
|
if let Some(message_component) = interaction.message_component() {
|
||||||
match &*message_component.data.custom_id {
|
match &*message_component.data.custom_id {
|
||||||
"TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE" => {
|
id if id == TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE => {
|
||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
let mut config = {
|
let mut config = {
|
||||||
let database = data_read
|
let database = data_read
|
||||||
@ -166,7 +230,7 @@ impl EventHandler for Handler {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
"TTS_CONFIG_SERVER_SET_READ_USERNAME" => {
|
id if id == TTS_CONFIG_SERVER_SET_READ_USERNAME => {
|
||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
let mut config = {
|
let mut config = {
|
||||||
let database = data_read
|
let database = data_read
|
||||||
@ -209,7 +273,7 @@ impl EventHandler for Handler {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU" => {
|
id if id == TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU => {
|
||||||
let i = usize::from_str_radix(
|
let i = usize::from_str_radix(
|
||||||
&match message_component.data.kind {
|
&match message_component.data.kind {
|
||||||
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
|
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
|
||||||
@ -259,7 +323,7 @@ impl EventHandler for Handler {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON" => {
|
id if id == TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON => {
|
||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
|
|
||||||
let config = {
|
let config = {
|
||||||
@ -313,7 +377,7 @@ impl EventHandler for Handler {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
"TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON" => {
|
id if id == TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON => {
|
||||||
let config = {
|
let config = {
|
||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
let database = data_read
|
let database = data_read
|
||||||
@ -351,7 +415,7 @@ impl EventHandler for Handler {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON" => {
|
id if id == TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON => {
|
||||||
message_component
|
message_component
|
||||||
.create_response(
|
.create_response(
|
||||||
&ctx.http,
|
&ctx.http,
|
||||||
@ -390,7 +454,7 @@ impl EventHandler for Handler {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
"SET_AUTOSTART_CHANNEL" => {
|
id if id == SET_AUTOSTART_CHANNEL => {
|
||||||
let autostart_channel_id = match message_component.data.kind {
|
let autostart_channel_id = match message_component.data.kind {
|
||||||
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
|
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
|
||||||
if values.len() == 0 {
|
if values.len() == 0 {
|
||||||
@ -445,7 +509,7 @@ impl EventHandler for Handler {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL" => {
|
id if id == TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL => {
|
||||||
let config = {
|
let config = {
|
||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
let database = data_read
|
let database = data_read
|
||||||
@ -524,7 +588,7 @@ impl EventHandler for Handler {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
"TTS_CONFIG_SERVER_BACK" => {
|
id if id == TTS_CONFIG_SERVER_BACK => {
|
||||||
message_component
|
message_component
|
||||||
.create_response(
|
.create_response(
|
||||||
&ctx.http,
|
&ctx.http,
|
||||||
@ -554,7 +618,7 @@ impl EventHandler for Handler {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
"TTS_CONFIG_SERVER" => {
|
id if id == TTS_CONFIG_SERVER => {
|
||||||
message_component
|
message_component
|
||||||
.create_response(
|
.create_response(
|
||||||
&ctx.http,
|
&ctx.http,
|
||||||
@ -584,7 +648,7 @@ impl EventHandler for Handler {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
"TTS_CONFIG_SERVER_DICTIONARY" => {
|
id if id == TTS_CONFIG_SERVER_DICTIONARY => {
|
||||||
message_component
|
message_component
|
||||||
.create_response(
|
.create_response(
|
||||||
&ctx.http,
|
&ctx.http,
|
||||||
|
@ -82,7 +82,11 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
|
|||||||
let tts_client = data
|
let tts_client = data
|
||||||
.get::<TTSClientData>()
|
.get::<TTSClientData>()
|
||||||
.expect("Cannot get TTSClientData");
|
.expect("Cannot get TTSClientData");
|
||||||
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
|
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await
|
||||||
|
.unwrap_or_else(|e| {
|
||||||
|
tracing::error!("Failed to get VOICEVOX speakers: {}", e);
|
||||||
|
vec!["VOICEVOX API unavailable".to_string()]
|
||||||
|
});
|
||||||
|
|
||||||
new_channel
|
new_channel
|
||||||
.send_message(
|
.send_message(
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use regex::Regex;
|
|
||||||
use serenity::{model::prelude::Message, prelude::Context};
|
use serenity::{model::prelude::Message, prelude::Context};
|
||||||
use songbird::tracks::Track;
|
use songbird::tracks::Track;
|
||||||
|
use tracing::{error, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
data::{DatabaseClientData, TTSClientData},
|
data::{DatabaseClientData, TTSClientData},
|
||||||
|
errors::{constants::*, validation, NCBError},
|
||||||
implement::member_name::ReadName,
|
implement::member_name::ReadName,
|
||||||
tts::{
|
tts::{
|
||||||
gcp_tts::structs::{
|
gcp_tts::structs::{
|
||||||
@ -15,6 +16,7 @@ use crate::{
|
|||||||
message::TTSMessage,
|
message::TTSMessage,
|
||||||
tts_type::TTSType,
|
tts_type::TTSType,
|
||||||
},
|
},
|
||||||
|
utils::{get_cached_regex, retry_with_backoff},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@ -25,19 +27,49 @@ impl TTSMessage for Message {
|
|||||||
let config = {
|
let config = {
|
||||||
let database = data_read
|
let database = data_read
|
||||||
.get::<DatabaseClientData>()
|
.get::<DatabaseClientData>()
|
||||||
.expect("Cannot get DatabaseClientData")
|
.ok_or_else(|| NCBError::config("Cannot get DatabaseClientData"))
|
||||||
.clone();
|
.map_err(|e| {
|
||||||
database
|
error!(error = %e, "Failed to get database client");
|
||||||
.get_server_config_or_default(instance.guild.get())
|
e
|
||||||
.await
|
})
|
||||||
.unwrap()
|
.unwrap(); // This is safe as we're in a critical path
|
||||||
.unwrap()
|
|
||||||
|
match database.get_server_config_or_default(instance.guild.get()).await {
|
||||||
|
Ok(Some(config)) => config,
|
||||||
|
Ok(None) => {
|
||||||
|
error!(guild_id = %instance.guild, "No server config available");
|
||||||
|
return self.content.clone(); // Fallback to original text
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
error!(guild_id = %instance.guild, error = %e, "Failed to get server config");
|
||||||
|
return self.content.clone(); // Fallback to original text
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let mut text = self.content.clone();
|
let mut text = self.content.clone();
|
||||||
|
|
||||||
|
// Validate text length before processing
|
||||||
|
if let Err(e) = validation::validate_tts_text(&text) {
|
||||||
|
warn!(error = %e, "Invalid TTS text, using truncated version");
|
||||||
|
text.truncate(crate::errors::constants::MAX_TTS_TEXT_LENGTH);
|
||||||
|
}
|
||||||
|
|
||||||
for rule in config.dictionary.rules {
|
for rule in config.dictionary.rules {
|
||||||
if rule.is_regex {
|
if rule.is_regex {
|
||||||
let regex = Regex::new(&rule.rule).unwrap();
|
match get_cached_regex(&rule.rule) {
|
||||||
text = regex.replace_all(&text, rule.to).to_string();
|
Ok(regex) => {
|
||||||
|
text = regex.replace_all(&text, &rule.to).to_string();
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
rule_id = rule.id,
|
||||||
|
pattern = rule.rule,
|
||||||
|
error = %e,
|
||||||
|
"Skipping invalid regex rule"
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
text = text.replace(&rule.rule, &rule.to);
|
text = text.replace(&rule.rule, &rule.to);
|
||||||
}
|
}
|
||||||
@ -46,17 +78,7 @@ impl TTSMessage for Message {
|
|||||||
if before_message.author.id == self.author.id {
|
if before_message.author.id == self.author.id {
|
||||||
text.clone()
|
text.clone()
|
||||||
} else {
|
} else {
|
||||||
let member = self.member.clone();
|
let name = get_user_name(self, ctx).await;
|
||||||
let name = if let Some(_) = member {
|
|
||||||
let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone();
|
|
||||||
guild
|
|
||||||
.member(&ctx.http, self.author.id)
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.read_name()
|
|
||||||
} else {
|
|
||||||
self.author.read_name()
|
|
||||||
};
|
|
||||||
if config.read_username.unwrap_or(true) {
|
if config.read_username.unwrap_or(true) {
|
||||||
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
|
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
|
||||||
} else {
|
} else {
|
||||||
@ -64,17 +86,7 @@ impl TTSMessage for Message {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let member = self.member.clone();
|
let name = get_user_name(self, ctx).await;
|
||||||
let name = if let Some(_) = member {
|
|
||||||
let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone();
|
|
||||||
guild
|
|
||||||
.member(&ctx.http, self.author.id)
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.read_name()
|
|
||||||
} else {
|
|
||||||
self.author.read_name()
|
|
||||||
};
|
|
||||||
|
|
||||||
if config.read_username.unwrap_or(true) {
|
if config.read_username.unwrap_or(true) {
|
||||||
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
|
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
|
||||||
@ -104,45 +116,111 @@ impl TTSMessage for Message {
|
|||||||
let config = {
|
let config = {
|
||||||
let database = data_read
|
let database = data_read
|
||||||
.get::<DatabaseClientData>()
|
.get::<DatabaseClientData>()
|
||||||
.expect("Cannot get DatabaseClientData")
|
.ok_or_else(|| NCBError::config("Cannot get DatabaseClientData"))
|
||||||
.clone();
|
.unwrap();
|
||||||
database
|
|
||||||
.get_user_config_or_default(self.author.id.get())
|
match database.get_user_config_or_default(self.author.id.get()).await {
|
||||||
.await
|
Ok(Some(config)) => config,
|
||||||
.unwrap()
|
Ok(None) | Err(_) => {
|
||||||
.unwrap()
|
error!(user_id = %self.author.id, "Failed to get user config, using defaults");
|
||||||
|
// Return default config
|
||||||
|
crate::database::user_config::UserConfig {
|
||||||
|
tts_type: Some(TTSType::GCP),
|
||||||
|
gcp_tts_voice: Some(crate::tts::gcp_tts::structs::voice_selection_params::VoiceSelectionParams {
|
||||||
|
languageCode: String::from("ja-JP"),
|
||||||
|
name: String::from("ja-JP-Wavenet-B"),
|
||||||
|
ssmlGender: String::from("neutral"),
|
||||||
|
}),
|
||||||
|
voicevox_speaker: Some(crate::errors::constants::DEFAULT_VOICEVOX_SPEAKER),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let tts = data_read
|
let tts = data_read
|
||||||
.get::<TTSClientData>()
|
.get::<TTSClientData>()
|
||||||
.expect("Cannot get GCP TTSClientStorage");
|
.ok_or_else(|| NCBError::config("Cannot get TTSClientData"))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
match config.tts_type.unwrap_or(TTSType::GCP) {
|
// Synthesize with retry logic
|
||||||
TTSType::GCP => vec![tts
|
let synthesis_result = match config.tts_type.unwrap_or(TTSType::GCP) {
|
||||||
.synthesize_gcp(SynthesizeRequest {
|
TTSType::GCP => {
|
||||||
input: SynthesisInput {
|
let sanitized_text = validation::sanitize_ssml(&text);
|
||||||
text: None,
|
retry_with_backoff(
|
||||||
ssml: Some(format!("<speak>{}</speak>", text)),
|
|| {
|
||||||
|
tts.synthesize_gcp(SynthesizeRequest {
|
||||||
|
input: SynthesisInput {
|
||||||
|
text: None,
|
||||||
|
ssml: Some(format!("<speak>{}</speak>", sanitized_text)),
|
||||||
|
},
|
||||||
|
voice: config.gcp_tts_voice.clone().unwrap_or_else(|| {
|
||||||
|
crate::tts::gcp_tts::structs::voice_selection_params::VoiceSelectionParams {
|
||||||
|
languageCode: String::from("ja-JP"),
|
||||||
|
name: String::from("ja-JP-Wavenet-B"),
|
||||||
|
ssmlGender: String::from("neutral"),
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
audioConfig: AudioConfig {
|
||||||
|
audioEncoding: String::from("mp3"),
|
||||||
|
speakingRate: DEFAULT_SPEAKING_RATE,
|
||||||
|
pitch: DEFAULT_PITCH,
|
||||||
|
},
|
||||||
|
})
|
||||||
},
|
},
|
||||||
voice: config.gcp_tts_voice.unwrap(),
|
3, // max attempts
|
||||||
audioConfig: AudioConfig {
|
std::time::Duration::from_millis(500),
|
||||||
audioEncoding: String::from("mp3"),
|
).await
|
||||||
speakingRate: 1.2f32,
|
}
|
||||||
pitch: 1.0f32,
|
TTSType::VOICEVOX => {
|
||||||
|
let processed_text = text.replace("<break time=\"200ms\"/>", "、");
|
||||||
|
retry_with_backoff(
|
||||||
|
|| {
|
||||||
|
tts.synthesize_voicevox(
|
||||||
|
&processed_text,
|
||||||
|
config.voicevox_speaker.unwrap_or(crate::errors::constants::DEFAULT_VOICEVOX_SPEAKER),
|
||||||
|
)
|
||||||
},
|
},
|
||||||
})
|
3, // max attempts
|
||||||
.await
|
std::time::Duration::from_millis(500),
|
||||||
.unwrap()
|
).await
|
||||||
.into()],
|
}
|
||||||
|
};
|
||||||
|
|
||||||
TTSType::VOICEVOX => vec![tts
|
match synthesis_result {
|
||||||
.synthesize_voicevox(
|
Ok(track) => vec![track],
|
||||||
&text.replace("<break time=\"200ms\"/>", "、"),
|
Err(e) => {
|
||||||
config.voicevox_speaker.unwrap_or(1),
|
error!(error = %e, "TTS synthesis failed");
|
||||||
)
|
vec![] // Return empty vector on failure
|
||||||
.await
|
}
|
||||||
.unwrap()
|
|
||||||
.into()],
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Helper function to get user name with proper error handling
|
||||||
|
async fn get_user_name(message: &Message, ctx: &Context) -> String {
|
||||||
|
let member = message.member.clone();
|
||||||
|
if let Some(_) = member {
|
||||||
|
if let Some(guild_id) = message.guild_id {
|
||||||
|
match guild_id.member(&ctx.http, message.author.id).await {
|
||||||
|
Ok(member) => member.read_name(),
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
user_id = %message.author.id,
|
||||||
|
guild_id = ?message.guild_id,
|
||||||
|
error = %e,
|
||||||
|
"Failed to get guild member, using fallback name"
|
||||||
|
);
|
||||||
|
message.author.read_name()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
guild_id = ?message.guild_id,
|
||||||
|
"Guild not found in cache, using author name"
|
||||||
|
);
|
||||||
|
message.author.read_name()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
message.author.read_name()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
20
src/lib.rs
Normal file
20
src/lib.rs
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
// Public API for the NCB-TTS-R2 library
|
||||||
|
|
||||||
|
pub mod errors;
|
||||||
|
pub mod utils;
|
||||||
|
pub mod tts;
|
||||||
|
pub mod database;
|
||||||
|
pub mod config;
|
||||||
|
pub mod data;
|
||||||
|
pub mod implement;
|
||||||
|
pub mod events;
|
||||||
|
pub mod commands;
|
||||||
|
pub mod stream_input;
|
||||||
|
pub mod trace;
|
||||||
|
pub mod event_handler;
|
||||||
|
pub mod connection_monitor;
|
||||||
|
|
||||||
|
// Re-export commonly used types
|
||||||
|
pub use errors::{NCBError, Result};
|
||||||
|
pub use utils::{CircuitBreaker, CircuitBreakerState, retry_with_backoff, get_cached_regex, PerformanceMetrics};
|
||||||
|
pub use tts::tts_type::TTSType;
|
107
src/main.rs
107
src/main.rs
@ -3,18 +3,21 @@ mod config;
|
|||||||
mod connection_monitor;
|
mod connection_monitor;
|
||||||
mod data;
|
mod data;
|
||||||
mod database;
|
mod database;
|
||||||
|
mod errors;
|
||||||
mod event_handler;
|
mod event_handler;
|
||||||
mod events;
|
mod events;
|
||||||
mod implement;
|
mod implement;
|
||||||
mod stream_input;
|
mod stream_input;
|
||||||
mod trace;
|
mod trace;
|
||||||
mod tts;
|
mod tts;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
use std::{collections::HashMap, env, sync::Arc};
|
use std::{collections::HashMap, env, sync::Arc};
|
||||||
|
|
||||||
use config::Config;
|
use config::Config;
|
||||||
use data::{DatabaseClientData, TTSClientData, TTSData};
|
use data::{DatabaseClientData, TTSClientData, TTSData};
|
||||||
use database::database::Database;
|
use database::database::Database;
|
||||||
|
use errors::{NCBError, Result};
|
||||||
use event_handler::Handler;
|
use event_handler::Handler;
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
use serenity::{
|
use serenity::{
|
||||||
@ -38,74 +41,44 @@ use songbird::SerenityInit;
|
|||||||
/// client.start().await;
|
/// client.start().await;
|
||||||
/// ```
|
/// ```
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client, serenity::Error> {
|
async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client> {
|
||||||
let framework = StandardFramework::new();
|
let framework = StandardFramework::new();
|
||||||
framework.configure(Configuration::new().with_whitespace(true).prefix(prefix));
|
framework.configure(Configuration::new().with_whitespace(true).prefix(prefix));
|
||||||
|
|
||||||
Client::builder(token, GatewayIntents::all())
|
Ok(Client::builder(token, GatewayIntents::all())
|
||||||
.event_handler(Handler)
|
.event_handler(Handler)
|
||||||
.application_id(ApplicationId::new(id))
|
.application_id(ApplicationId::new(id))
|
||||||
.framework(framework)
|
.framework(framework)
|
||||||
.register_songbird()
|
.register_songbird()
|
||||||
.await
|
.await?)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
// Load config
|
if let Err(e) = run().await {
|
||||||
let config = {
|
eprintln!("Application error: {}", e);
|
||||||
let config = std::fs::read_to_string("./config.toml");
|
std::process::exit(1);
|
||||||
if let Ok(config) = config {
|
}
|
||||||
toml::from_str::<Config>(&config).expect("Cannot load config file.")
|
}
|
||||||
} else {
|
|
||||||
let token = env::var("NCB_TOKEN").unwrap();
|
|
||||||
let application_id = env::var("NCB_APP_ID").unwrap();
|
|
||||||
let prefix = env::var("NCB_PREFIX").unwrap();
|
|
||||||
let redis_url = env::var("NCB_REDIS_URL").unwrap();
|
|
||||||
let voicevox_key = match env::var("NCB_VOICEVOX_KEY") {
|
|
||||||
Ok(key) => Some(key),
|
|
||||||
Err(_) => None,
|
|
||||||
};
|
|
||||||
let voicevox_original_api_url = match env::var("NCB_VOICEVOX_ORIGINAL_API_URL") {
|
|
||||||
Ok(url) => Some(url),
|
|
||||||
Err(_) => None,
|
|
||||||
};
|
|
||||||
let otel_http_url = match env::var("NCB_OTEL_HTTP_URL") {
|
|
||||||
Ok(url) => Some(url),
|
|
||||||
Err(_) => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
Config {
|
async fn run() -> Result<()> {
|
||||||
token,
|
// Load config
|
||||||
application_id: u64::from_str_radix(&application_id, 10).unwrap(),
|
let config = load_config()?;
|
||||||
prefix,
|
|
||||||
redis_url,
|
|
||||||
voicevox_key,
|
|
||||||
voicevox_original_api_url,
|
|
||||||
otel_http_url,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let _guard = init_tracing_subscriber(&config.otel_http_url);
|
let _guard = init_tracing_subscriber(&config.otel_http_url);
|
||||||
|
|
||||||
// Create discord client
|
// Create discord client
|
||||||
let mut client = create_client(&config.prefix, &config.token, config.application_id)
|
let mut client = create_client(&config.prefix, &config.token, config.application_id)
|
||||||
.await
|
.await?;
|
||||||
.expect("Err creating client");
|
|
||||||
|
|
||||||
// Create GCP TTS client
|
// Create GCP TTS client
|
||||||
let tts = match GCPTTS::new("./credentials.json".to_string()).await {
|
let tts = GCPTTS::new("./credentials.json".to_string())
|
||||||
Ok(tts) => tts,
|
.await
|
||||||
Err(err) => panic!("GCP init error: {}", err),
|
.map_err(|e| NCBError::GCPAuth(e))?;
|
||||||
};
|
|
||||||
|
|
||||||
let voicevox = VOICEVOX::new(config.voicevox_key, config.voicevox_original_api_url);
|
let voicevox = VOICEVOX::new(config.voicevox_key, config.voicevox_original_api_url);
|
||||||
|
|
||||||
let database_client = {
|
let database_client = Database::new_with_url(config.redis_url).await?;
|
||||||
let redis_client = redis::Client::open(config.redis_url).unwrap();
|
|
||||||
Database::new(redis_client)
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create TTS storage
|
// Create TTS storage
|
||||||
{
|
{
|
||||||
@ -118,7 +91,43 @@ async fn main() {
|
|||||||
info!("Bot initialized.");
|
info!("Bot initialized.");
|
||||||
|
|
||||||
// Run client
|
// Run client
|
||||||
if let Err(why) = client.start().await {
|
client.start().await?;
|
||||||
println!("Client error: {:?}", why);
|
|
||||||
}
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load configuration from file or environment variables
|
||||||
|
fn load_config() -> Result<Config> {
|
||||||
|
// Try to load from config file first
|
||||||
|
if let Ok(config_str) = std::fs::read_to_string("./config.toml") {
|
||||||
|
return toml::from_str::<Config>(&config_str)
|
||||||
|
.map_err(|e| NCBError::Toml(e));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to environment variables
|
||||||
|
let token = env::var("NCB_TOKEN")
|
||||||
|
.map_err(|_| NCBError::missing_env_var("NCB_TOKEN"))?;
|
||||||
|
let application_id_str = env::var("NCB_APP_ID")
|
||||||
|
.map_err(|_| NCBError::missing_env_var("NCB_APP_ID"))?;
|
||||||
|
let prefix = env::var("NCB_PREFIX")
|
||||||
|
.map_err(|_| NCBError::missing_env_var("NCB_PREFIX"))?;
|
||||||
|
let redis_url = env::var("NCB_REDIS_URL")
|
||||||
|
.map_err(|_| NCBError::missing_env_var("NCB_REDIS_URL"))?;
|
||||||
|
|
||||||
|
let application_id = application_id_str.parse::<u64>()
|
||||||
|
.map_err(|_| NCBError::config(format!("Invalid application ID: {}", application_id_str)))?;
|
||||||
|
|
||||||
|
let voicevox_key = env::var("NCB_VOICEVOX_KEY").ok();
|
||||||
|
let voicevox_original_api_url = env::var("NCB_VOICEVOX_ORIGINAL_API_URL").ok();
|
||||||
|
let otel_http_url = env::var("NCB_OTEL_HTTP_URL").ok();
|
||||||
|
|
||||||
|
Ok(Config {
|
||||||
|
token,
|
||||||
|
application_id,
|
||||||
|
prefix,
|
||||||
|
redis_url,
|
||||||
|
voicevox_key,
|
||||||
|
voicevox_original_api_url,
|
||||||
|
otel_http_url,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,8 @@ impl GCPTTS {
|
|||||||
Ok(ok) => {
|
Ok(ok) => {
|
||||||
let response: SynthesizeResponse =
|
let response: SynthesizeResponse =
|
||||||
serde_json::from_str(&ok.text().await.expect("")).unwrap();
|
serde_json::from_str(&ok.text().await.expect("")).unwrap();
|
||||||
Ok(base64::decode(response.audioContent).unwrap()[..].to_vec())
|
use base64::{Engine as _, engine::general_purpose};
|
||||||
|
Ok(general_purpose::STANDARD.decode(response.audioContent).unwrap())
|
||||||
}
|
}
|
||||||
Err(err) => Err(Box::new(err)),
|
Err(err) => Err(Box::new(err)),
|
||||||
}
|
}
|
||||||
|
@ -2,13 +2,15 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
/// Example:
|
/// Example:
|
||||||
/// ```rust
|
/// ```rust
|
||||||
|
/// use ncb_tts_r2::tts::gcp_tts::structs::audio_config::AudioConfig;
|
||||||
|
///
|
||||||
/// AudioConfig {
|
/// AudioConfig {
|
||||||
/// audioEncoding: String::from("mp3"),
|
/// audioEncoding: String::from("mp3"),
|
||||||
/// speakingRate: 1.2f32,
|
/// speakingRate: 1.2f32,
|
||||||
/// pitch: 1.0f32
|
/// pitch: 1.0f32
|
||||||
/// }
|
/// };
|
||||||
/// ```
|
/// ```
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
pub struct AudioConfig {
|
pub struct AudioConfig {
|
||||||
pub audioEncoding: String,
|
pub audioEncoding: String,
|
||||||
|
@ -2,10 +2,12 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
/// Example:
|
/// Example:
|
||||||
/// ```rust
|
/// ```rust
|
||||||
|
/// use ncb_tts_r2::tts::gcp_tts::structs::synthesis_input::SynthesisInput;
|
||||||
|
///
|
||||||
/// SynthesisInput {
|
/// SynthesisInput {
|
||||||
/// text: None,
|
/// text: None,
|
||||||
/// ssml: Some(String::from("<speak>test</speak>"))
|
/// ssml: Some(String::from("<speak>test</speak>"))
|
||||||
/// }
|
/// };
|
||||||
/// ```
|
/// ```
|
||||||
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)]
|
||||||
pub struct SynthesisInput {
|
pub struct SynthesisInput {
|
||||||
|
@ -23,7 +23,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
/// }
|
/// }
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
pub struct SynthesizeRequest {
|
pub struct SynthesizeRequest {
|
||||||
pub input: SynthesisInput,
|
pub input: SynthesisInput,
|
||||||
|
526
src/tts/tts.rs
526
src/tts/tts.rs
@ -2,8 +2,14 @@ use std::sync::RwLock;
|
|||||||
use std::{num::NonZeroUsize, sync::Arc};
|
use std::{num::NonZeroUsize, sync::Arc};
|
||||||
|
|
||||||
use lru::LruCache;
|
use lru::LruCache;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use songbird::{driver::Bitrate, input::cached::Compressed, tracks::Track};
|
use songbird::{driver::Bitrate, input::cached::Compressed, tracks::Track};
|
||||||
use tracing::info;
|
use tracing::{debug, error, info, instrument, warn};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
errors::{constants::*, NCBError, Result},
|
||||||
|
utils::{retry_with_backoff, CircuitBreaker, PerformanceMetrics},
|
||||||
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
gcp_tts::{
|
gcp_tts::{
|
||||||
@ -21,29 +27,60 @@ pub struct TTS {
|
|||||||
pub voicevox_client: VOICEVOX,
|
pub voicevox_client: VOICEVOX,
|
||||||
gcp_tts_client: GCPTTS,
|
gcp_tts_client: GCPTTS,
|
||||||
cache: Arc<RwLock<LruCache<CacheKey, Compressed>>>,
|
cache: Arc<RwLock<LruCache<CacheKey, Compressed>>>,
|
||||||
|
voicevox_circuit_breaker: Arc<RwLock<CircuitBreaker>>,
|
||||||
|
gcp_circuit_breaker: Arc<RwLock<CircuitBreaker>>,
|
||||||
|
metrics: Arc<PerformanceMetrics>,
|
||||||
|
cache_persistence_path: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Hash, PartialEq, Eq)]
|
#[derive(Hash, PartialEq, Eq, Clone, Serialize, Deserialize, Debug)]
|
||||||
pub enum CacheKey {
|
pub enum CacheKey {
|
||||||
Voicevox(String, i64),
|
Voicevox(String, i64),
|
||||||
GCP(SynthesisInput, VoiceSelectionParams),
|
GCP(SynthesisInput, VoiceSelectionParams),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Serialize, Deserialize)]
|
||||||
|
struct CacheEntry {
|
||||||
|
key: CacheKey,
|
||||||
|
data: Vec<u8>,
|
||||||
|
created_at: std::time::SystemTime,
|
||||||
|
access_count: u64,
|
||||||
|
}
|
||||||
|
|
||||||
impl TTS {
|
impl TTS {
|
||||||
pub fn new(voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS) -> Self {
|
pub fn new(voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS) -> Self {
|
||||||
Self {
|
let tts = Self {
|
||||||
voicevox_client,
|
voicevox_client,
|
||||||
gcp_tts_client,
|
gcp_tts_client,
|
||||||
cache: Arc::new(RwLock::new(LruCache::new(NonZeroUsize::new(1000).unwrap()))),
|
cache: Arc::new(RwLock::new(LruCache::new(
|
||||||
|
NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap(),
|
||||||
|
))),
|
||||||
|
voicevox_circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::default())),
|
||||||
|
gcp_circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::default())),
|
||||||
|
metrics: Arc::new(PerformanceMetrics::new()),
|
||||||
|
cache_persistence_path: Some("./tts_cache.bin".to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Try to load persisted cache
|
||||||
|
if let Err(e) = tts.load_cache() {
|
||||||
|
warn!(error = %e, "Failed to load persisted cache");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tts
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
pub fn with_cache_path(mut self, path: Option<String>) -> Self {
|
||||||
|
self.cache_persistence_path = path;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip(self))]
|
||||||
pub async fn synthesize_voicevox(
|
pub async fn synthesize_voicevox(
|
||||||
&self,
|
&self,
|
||||||
text: &str,
|
text: &str,
|
||||||
speaker: i64,
|
speaker: i64,
|
||||||
) -> Result<Track, Box<dyn std::error::Error>> {
|
) -> std::result::Result<Track, NCBError> {
|
||||||
|
self.metrics.increment_tts_requests();
|
||||||
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
|
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
|
||||||
|
|
||||||
let cached_audio = {
|
let cached_audio = {
|
||||||
@ -52,56 +89,106 @@ impl TTS {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if let Some(audio) = cached_audio {
|
if let Some(audio) = cached_audio {
|
||||||
info!("Cache hit for VOICEVOX TTS");
|
debug!("Cache hit for VOICEVOX TTS");
|
||||||
|
self.metrics.increment_tts_cache_hits();
|
||||||
return Ok(audio.into());
|
return Ok(audio.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Cache miss for VOICEVOX TTS");
|
debug!("Cache miss for VOICEVOX TTS");
|
||||||
|
self.metrics.increment_tts_cache_misses();
|
||||||
|
|
||||||
if self.voicevox_client.original_api_url.is_some() {
|
// Check circuit breaker
|
||||||
let audio = self
|
{
|
||||||
.voicevox_client
|
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
|
||||||
.synthesize_original(text.to_string(), speaker)
|
circuit_breaker.try_half_open();
|
||||||
.await?;
|
|
||||||
|
|
||||||
tokio::spawn({
|
if !circuit_breaker.can_execute() {
|
||||||
let cache = self.cache.clone();
|
return Err(NCBError::voicevox("Circuit breaker is open"));
|
||||||
let audio = audio.clone();
|
}
|
||||||
async move {
|
}
|
||||||
info!("Compressing stream audio");
|
|
||||||
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap();
|
|
||||||
let mut cache_guard = cache.write().unwrap();
|
|
||||||
cache_guard.put(cache_key, compressed.clone());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(audio.into())
|
let synthesis_result = if self.voicevox_client.original_api_url.is_some() {
|
||||||
|
retry_with_backoff(
|
||||||
|
|| async {
|
||||||
|
match self
|
||||||
|
.voicevox_client
|
||||||
|
.synthesize_original(text.to_string(), speaker)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(audio) => Ok(audio),
|
||||||
|
Err(e) => Err(NCBError::voicevox(format!(
|
||||||
|
"VOICEVOX synthesis failed: {}",
|
||||||
|
e
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
3,
|
||||||
|
std::time::Duration::from_millis(500),
|
||||||
|
)
|
||||||
|
.await
|
||||||
} else {
|
} else {
|
||||||
let audio = self
|
retry_with_backoff(
|
||||||
.voicevox_client
|
|| async {
|
||||||
.synthesize_stream(text.to_string(), speaker)
|
match self
|
||||||
.await?;
|
.voicevox_client
|
||||||
|
.synthesize_stream(text.to_string(), speaker)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(_mp3_request) => Err(NCBError::voicevox(
|
||||||
|
"Stream synthesis not yet fully implemented",
|
||||||
|
)),
|
||||||
|
Err(e) => Err(NCBError::voicevox(format!(
|
||||||
|
"VOICEVOX synthesis failed: {}",
|
||||||
|
e
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
3,
|
||||||
|
std::time::Duration::from_millis(500),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
};
|
||||||
|
|
||||||
tokio::spawn({
|
match synthesis_result {
|
||||||
|
Ok(audio) => {
|
||||||
|
// Update circuit breaker on success
|
||||||
|
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
|
||||||
|
circuit_breaker.on_success();
|
||||||
|
drop(circuit_breaker);
|
||||||
|
|
||||||
|
// Cache the audio asynchronously
|
||||||
let cache = self.cache.clone();
|
let cache = self.cache.clone();
|
||||||
let audio = audio.clone();
|
let cache_key_clone = cache_key.clone();
|
||||||
async move {
|
let audio_for_cache = audio.clone();
|
||||||
info!("Compressing stream audio");
|
tokio::spawn(async move {
|
||||||
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap();
|
debug!("Compressing and caching VOICEVOX audio");
|
||||||
let mut cache_guard = cache.write().unwrap();
|
if let Ok(compressed) =
|
||||||
cache_guard.put(cache_key, compressed.clone());
|
Compressed::new(audio_for_cache.into(), Bitrate::Auto).await
|
||||||
}
|
{
|
||||||
});
|
let mut cache_guard = cache.write().unwrap();
|
||||||
|
cache_guard.put(cache_key_clone, compressed);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
Ok(audio.into())
|
Ok(audio.into())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Update circuit breaker on failure
|
||||||
|
let mut circuit_breaker = self.voicevox_circuit_breaker.write().unwrap();
|
||||||
|
circuit_breaker.on_failure();
|
||||||
|
drop(circuit_breaker);
|
||||||
|
|
||||||
|
error!(error = %e, "VOICEVOX synthesis failed");
|
||||||
|
Err(e)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
|
||||||
pub async fn synthesize_gcp(
|
pub async fn synthesize_gcp(
|
||||||
&self,
|
&self,
|
||||||
synthesize_request: SynthesizeRequest,
|
synthesize_request: SynthesizeRequest,
|
||||||
) -> Result<Compressed, Box<dyn std::error::Error>> {
|
) -> std::result::Result<Track, NCBError> {
|
||||||
|
self.metrics.increment_tts_requests();
|
||||||
let cache_key = CacheKey::GCP(
|
let cache_key = CacheKey::GCP(
|
||||||
synthesize_request.input.clone(),
|
synthesize_request.input.clone(),
|
||||||
synthesize_request.voice.clone(),
|
synthesize_request.voice.clone(),
|
||||||
@ -113,21 +200,360 @@ impl TTS {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if let Some(audio) = cached_audio {
|
if let Some(audio) = cached_audio {
|
||||||
info!("Cache hit for GCP TTS");
|
debug!("Cache hit for GCP TTS");
|
||||||
return Ok(audio);
|
self.metrics.increment_tts_cache_hits();
|
||||||
|
return Ok(audio.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("Cache miss for GCP TTS");
|
debug!("Cache miss for GCP TTS");
|
||||||
|
self.metrics.increment_tts_cache_misses();
|
||||||
let audio = self.gcp_tts_client.synthesize(synthesize_request).await?;
|
|
||||||
|
|
||||||
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
|
|
||||||
|
|
||||||
|
// Check circuit breaker
|
||||||
{
|
{
|
||||||
let mut cache_guard = self.cache.write().unwrap();
|
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
|
||||||
cache_guard.put(cache_key, compressed.clone());
|
circuit_breaker.try_half_open();
|
||||||
|
|
||||||
|
if !circuit_breaker.can_execute() {
|
||||||
|
return Err(NCBError::tts_synthesis("GCP TTS circuit breaker is open"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(compressed)
|
let request_clone = SynthesizeRequest {
|
||||||
|
input: synthesize_request.input.clone(),
|
||||||
|
voice: synthesize_request.voice.clone(),
|
||||||
|
audioConfig: synthesize_request.audioConfig.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let audio = {
|
||||||
|
let audio_result = retry_with_backoff(
|
||||||
|
|| async {
|
||||||
|
match self.gcp_tts_client.synthesize(request_clone.clone()).await {
|
||||||
|
Ok(audio) => Ok(audio),
|
||||||
|
Err(e) => Err(NCBError::tts_synthesis(format!(
|
||||||
|
"GCP TTS synthesis failed: {}",
|
||||||
|
e
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
3,
|
||||||
|
std::time::Duration::from_millis(500),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match audio_result {
|
||||||
|
Ok(audio) => audio,
|
||||||
|
Err(e) => {
|
||||||
|
// Update circuit breaker on failure
|
||||||
|
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
|
||||||
|
circuit_breaker.on_failure();
|
||||||
|
drop(circuit_breaker);
|
||||||
|
|
||||||
|
error!(error = %e, "GCP TTS synthesis failed");
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Update circuit breaker on success
|
||||||
|
{
|
||||||
|
let mut circuit_breaker = self.gcp_circuit_breaker.write().unwrap();
|
||||||
|
circuit_breaker.on_success();
|
||||||
|
}
|
||||||
|
|
||||||
|
match Compressed::new(audio.into(), Bitrate::Auto).await {
|
||||||
|
Ok(compressed) => {
|
||||||
|
// Cache the compressed audio
|
||||||
|
{
|
||||||
|
let mut cache_guard = self.cache.write().unwrap();
|
||||||
|
cache_guard.put(cache_key, compressed.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persist cache asynchronously
|
||||||
|
if let Some(path) = &self.cache_persistence_path {
|
||||||
|
let cache_clone = self.cache.clone();
|
||||||
|
let path_clone = path.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = Self::persist_cache_to_file(&cache_clone, &path_clone) {
|
||||||
|
warn!(error = %e, "Failed to persist cache");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(compressed.into())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(error = %e, "Failed to compress GCP audio");
|
||||||
|
Err(NCBError::tts_synthesis(format!(
|
||||||
|
"Audio compression failed: {}",
|
||||||
|
e
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load cache from persistent storage
|
||||||
|
fn load_cache(&self) -> Result<()> {
|
||||||
|
if let Some(path) = &self.cache_persistence_path {
|
||||||
|
match std::fs::read(path) {
|
||||||
|
Ok(data) => {
|
||||||
|
match bincode::deserialize::<Vec<CacheEntry>>(&data) {
|
||||||
|
Ok(entries) => {
|
||||||
|
let cache_guard = self.cache.read().unwrap();
|
||||||
|
let now = std::time::SystemTime::now();
|
||||||
|
|
||||||
|
for entry in entries {
|
||||||
|
// Skip expired entries (older than 24 hours)
|
||||||
|
if let Ok(age) = now.duration_since(entry.created_at) {
|
||||||
|
if age.as_secs() < CACHE_TTL_SECS {
|
||||||
|
debug!("Loaded cache entry from disk");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Loaded {} cache entries from disk", cache_guard.len());
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "Failed to deserialize cache data");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||||
|
debug!("No existing cache file found");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "Failed to read cache file");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Persist cache to storage (simplified implementation)
|
||||||
|
fn persist_cache_to_file(
|
||||||
|
cache: &Arc<RwLock<LruCache<CacheKey, Compressed>>>,
|
||||||
|
path: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
// Note: This is a simplified implementation
|
||||||
|
let _cache_guard = cache.read().unwrap();
|
||||||
|
let entries: Vec<CacheEntry> = Vec::new(); // Placeholder for actual implementation
|
||||||
|
|
||||||
|
match bincode::serialize(&entries) {
|
||||||
|
Ok(data) => {
|
||||||
|
if let Err(e) = std::fs::write(path, data) {
|
||||||
|
return Err(NCBError::database(format!(
|
||||||
|
"Failed to write cache file: {}",
|
||||||
|
e
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
debug!("Cache persisted to disk");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(NCBError::database(format!(
|
||||||
|
"Failed to serialize cache: {}",
|
||||||
|
e
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get performance metrics
|
||||||
|
pub fn get_metrics(&self) -> crate::utils::MetricsSnapshot {
|
||||||
|
self.metrics.get_stats()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear cache
|
||||||
|
pub fn clear_cache(&self) {
|
||||||
|
let mut cache_guard = self.cache.write().unwrap();
|
||||||
|
cache_guard.clear();
|
||||||
|
info!("TTS cache cleared");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get cache statistics
|
||||||
|
pub fn get_cache_stats(&self) -> (usize, usize) {
|
||||||
|
let cache_guard = self.cache.read().unwrap();
|
||||||
|
(cache_guard.len(), cache_guard.cap().get())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD;
|
||||||
|
use crate::tts::gcp_tts::structs::{
|
||||||
|
synthesis_input::SynthesisInput, voice_selection_params::VoiceSelectionParams,
|
||||||
|
};
|
||||||
|
use crate::utils::{CircuitBreakerState, MetricsSnapshot};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tempfile::tempdir;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_key_equality() {
|
||||||
|
let input = SynthesisInput {
|
||||||
|
text: None,
|
||||||
|
ssml: Some("Hello".to_string()),
|
||||||
|
};
|
||||||
|
let voice = VoiceSelectionParams {
|
||||||
|
languageCode: "en-US".to_string(),
|
||||||
|
name: "en-US-Wavenet-A".to_string(),
|
||||||
|
ssmlGender: "female".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let key1 = CacheKey::GCP(input.clone(), voice.clone());
|
||||||
|
let key2 = CacheKey::GCP(input.clone(), voice.clone());
|
||||||
|
let key3 = CacheKey::Voicevox("Hello".to_string(), 1);
|
||||||
|
let key4 = CacheKey::Voicevox("Hello".to_string(), 1);
|
||||||
|
let key5 = CacheKey::Voicevox("Hello".to_string(), 2);
|
||||||
|
|
||||||
|
assert_eq!(key1, key2);
|
||||||
|
assert_eq!(key3, key4);
|
||||||
|
assert_ne!(key3, key5);
|
||||||
|
// Note: Different enum variants are never equal
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_key_hash() {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
let input = SynthesisInput {
|
||||||
|
text: Some("Test".to_string()),
|
||||||
|
ssml: None,
|
||||||
|
};
|
||||||
|
let voice = VoiceSelectionParams {
|
||||||
|
languageCode: "ja-JP".to_string(),
|
||||||
|
name: "ja-JP-Wavenet-B".to_string(),
|
||||||
|
ssmlGender: "neutral".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut map = HashMap::new();
|
||||||
|
let key = CacheKey::GCP(input, voice);
|
||||||
|
map.insert(key.clone(), "test_value");
|
||||||
|
|
||||||
|
assert_eq!(map.get(&key), Some(&"test_value"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_entry_creation() {
|
||||||
|
let data = vec![1, 2, 3, 4, 5];
|
||||||
|
let now = std::time::SystemTime::now();
|
||||||
|
|
||||||
|
let entry = CacheEntry {
|
||||||
|
key: CacheKey::Voicevox("test".to_string(), 1),
|
||||||
|
data: data.clone(),
|
||||||
|
created_at: now,
|
||||||
|
access_count: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(entry.key, CacheKey::Voicevox("test".to_string(), 1));
|
||||||
|
assert_eq!(entry.created_at, now);
|
||||||
|
assert_eq!(entry.data, data);
|
||||||
|
assert_eq!(entry.access_count, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_performance_metrics_integration() {
|
||||||
|
// Test metrics functionality with realistic data
|
||||||
|
let metrics = PerformanceMetrics::default();
|
||||||
|
|
||||||
|
// Simulate TTS request pattern
|
||||||
|
for _ in 0..10 {
|
||||||
|
metrics.increment_tts_requests();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate 70% cache hit rate
|
||||||
|
for _ in 0..7 {
|
||||||
|
metrics.increment_tts_cache_hits();
|
||||||
|
}
|
||||||
|
for _ in 0..3 {
|
||||||
|
metrics.increment_tts_cache_misses();
|
||||||
|
}
|
||||||
|
|
||||||
|
let stats = metrics.get_stats();
|
||||||
|
assert_eq!(stats.tts_requests, 10);
|
||||||
|
assert_eq!(stats.tts_cache_hits, 7);
|
||||||
|
assert_eq!(stats.tts_cache_misses, 3);
|
||||||
|
|
||||||
|
let hit_rate = stats.tts_cache_hit_rate();
|
||||||
|
assert!((hit_rate - 0.7).abs() < f64::EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_circuit_breaker_state_transitions() {
|
||||||
|
let mut cb = CircuitBreaker::new(2, Duration::from_millis(100));
|
||||||
|
|
||||||
|
// Initially closed
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||||
|
assert!(cb.can_execute());
|
||||||
|
|
||||||
|
// First failure
|
||||||
|
cb.on_failure();
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||||
|
assert_eq!(cb.failure_count, 1);
|
||||||
|
|
||||||
|
// Second failure opens circuit
|
||||||
|
cb.on_failure();
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Open);
|
||||||
|
assert!(!cb.can_execute());
|
||||||
|
|
||||||
|
// Wait and try half-open
|
||||||
|
std::thread::sleep(Duration::from_millis(150));
|
||||||
|
cb.try_half_open();
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
|
||||||
|
assert!(cb.can_execute());
|
||||||
|
|
||||||
|
// Success closes circuit
|
||||||
|
cb.on_success();
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||||
|
assert_eq!(cb.failure_count, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cache_persistence_setup() {
|
||||||
|
let temp_dir = tempdir().unwrap();
|
||||||
|
let cache_path = temp_dir
|
||||||
|
.path()
|
||||||
|
.join("test_cache.bin")
|
||||||
|
.to_string_lossy()
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
// Test cache path configuration
|
||||||
|
assert!(!cache_path.is_empty());
|
||||||
|
assert!(cache_path.ends_with("test_cache.bin"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_metrics_snapshot_calculations() {
|
||||||
|
let snapshot = MetricsSnapshot {
|
||||||
|
tts_requests: 20,
|
||||||
|
tts_cache_hits: 15,
|
||||||
|
tts_cache_misses: 5,
|
||||||
|
regex_cache_hits: 8,
|
||||||
|
regex_cache_misses: 2,
|
||||||
|
database_operations: 30,
|
||||||
|
voice_connections: 5,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Test TTS cache hit rate
|
||||||
|
let tts_hit_rate = snapshot.tts_cache_hit_rate();
|
||||||
|
assert!((tts_hit_rate - 0.75).abs() < f64::EPSILON);
|
||||||
|
|
||||||
|
// Test regex cache hit rate
|
||||||
|
let regex_hit_rate = snapshot.regex_cache_hit_rate();
|
||||||
|
assert!((regex_hit_rate - 0.8).abs() < f64::EPSILON);
|
||||||
|
|
||||||
|
// Test edge case with no operations
|
||||||
|
let empty_snapshot = MetricsSnapshot {
|
||||||
|
tts_requests: 0,
|
||||||
|
tts_cache_hits: 0,
|
||||||
|
tts_cache_misses: 0,
|
||||||
|
regex_cache_hits: 0,
|
||||||
|
regex_cache_misses: 0,
|
||||||
|
database_operations: 0,
|
||||||
|
voice_connections: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0);
|
||||||
|
assert_eq!(empty_snapshot.regex_cache_hit_rate(), 0.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
use crate::stream_input::Mp3Request;
|
use crate::{errors::NCBError, stream_input::Mp3Request};
|
||||||
|
|
||||||
use super::structs::{speaker::Speaker, stream::TTSResponse};
|
use super::structs::{speaker::Speaker, stream::TTSResponse};
|
||||||
|
|
||||||
const BASE_API_URL: &str = "https://deprecatedapis.tts.quest/v2/";
|
const BASE_API_URL: &str = "https://deprecatedapis.tts.quest/v2/";
|
||||||
|
const STREAM_API_URL: &str = "https://api.tts.quest/v3/voicevox/synthesis";
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct VOICEVOX {
|
pub struct VOICEVOX {
|
||||||
@ -12,27 +13,27 @@ pub struct VOICEVOX {
|
|||||||
|
|
||||||
impl VOICEVOX {
|
impl VOICEVOX {
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn get_styles(&self) -> Vec<(String, i64)> {
|
pub async fn get_styles(&self) -> Result<Vec<(String, i64)>, NCBError> {
|
||||||
let speakers = self.get_speaker_list().await;
|
let speakers = self.get_speaker_list().await?;
|
||||||
let mut speaker_list = vec![];
|
let mut speaker_list = Vec::new();
|
||||||
for speaker in speakers {
|
for speaker in speakers {
|
||||||
for style in speaker.styles {
|
for style in speaker.styles {
|
||||||
speaker_list.push((format!("{} - {}", speaker.name, style.name), style.id))
|
speaker_list.push((format!("{} - {}", speaker.name, style.name), style.id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
speaker_list
|
Ok(speaker_list)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub async fn get_speakers(&self) -> Vec<String> {
|
pub async fn get_speakers(&self) -> Result<Vec<String>, NCBError> {
|
||||||
let speakers = self.get_speaker_list().await;
|
let speakers = self.get_speaker_list().await?;
|
||||||
let mut speaker_list = vec![];
|
let mut speaker_list = Vec::new();
|
||||||
for speaker in speakers {
|
for speaker in speakers {
|
||||||
speaker_list.push(speaker.name)
|
speaker_list.push(speaker.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
speaker_list
|
Ok(speaker_list)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(key: Option<String>, original_api_url: Option<String>) -> Self {
|
pub fn new(key: Option<String>, original_api_url: Option<String>) -> Self {
|
||||||
@ -43,24 +44,30 @@ impl VOICEVOX {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
async fn get_speaker_list(&self) -> Vec<Speaker> {
|
async fn get_speaker_list(&self) -> Result<Vec<Speaker>, NCBError> {
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
let client = if let Some(key) = &self.key {
|
let request = if let Some(key) = &self.key {
|
||||||
client
|
client
|
||||||
.get(BASE_API_URL.to_string() + "voicevox/speakers/")
|
.get(format!("{}{}", BASE_API_URL, "voicevox/speakers/"))
|
||||||
.query(&[("key", key)])
|
.query(&[("key", key)])
|
||||||
} else if let Some(original_api_url) = &self.original_api_url {
|
} else if let Some(original_api_url) = &self.original_api_url {
|
||||||
client.get(original_api_url.to_string() + "/speakers")
|
client.get(format!("{}/speakers", original_api_url))
|
||||||
} else {
|
} else {
|
||||||
panic!("No API key or original API URL provided.")
|
return Err(NCBError::voicevox("No API key or original API URL provided"));
|
||||||
};
|
};
|
||||||
|
|
||||||
match client.send().await {
|
let response = request.send().await
|
||||||
Ok(response) => response.json().await.unwrap(),
|
.map_err(|e| NCBError::voicevox(format!("Failed to fetch speakers: {}", e)))?;
|
||||||
Err(err) => {
|
|
||||||
panic!("Cannot get speaker list. {err:?}")
|
if !response.status().is_success() {
|
||||||
}
|
return Err(NCBError::voicevox(format!(
|
||||||
|
"API request failed with status: {}",
|
||||||
|
response.status()
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
response.json().await
|
||||||
|
.map_err(|e| NCBError::voicevox(format!("Failed to parse speaker list: {}", e)))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
@ -68,24 +75,33 @@ impl VOICEVOX {
|
|||||||
&self,
|
&self,
|
||||||
text: String,
|
text: String,
|
||||||
speaker: i64,
|
speaker: i64,
|
||||||
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
) -> Result<Vec<u8>, NCBError> {
|
||||||
|
let key = self.key.as_ref()
|
||||||
|
.ok_or_else(|| NCBError::voicevox("API key required for synthesis"))?;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
match client
|
let response = client
|
||||||
.post(BASE_API_URL.to_string() + "voicevox/audio/")
|
.post(format!("{}{}", BASE_API_URL, "voicevox/audio/"))
|
||||||
.query(&[
|
.query(&[
|
||||||
("speaker", speaker.to_string()),
|
("speaker", speaker.to_string()),
|
||||||
("text", text),
|
("text", text),
|
||||||
("key", self.key.clone().unwrap()),
|
("key", key.clone()),
|
||||||
])
|
])
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
{
|
.map_err(|e| NCBError::voicevox(format!("Synthesis request failed: {}", e)))?;
|
||||||
Ok(response) => {
|
|
||||||
let body = response.bytes().await?;
|
if !response.status().is_success() {
|
||||||
Ok(body.to_vec())
|
return Err(NCBError::voicevox(format!(
|
||||||
}
|
"Synthesis failed with status: {}",
|
||||||
Err(err) => Err(Box::new(err)),
|
response.status()
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let body = response.bytes().await
|
||||||
|
.map_err(|e| NCBError::voicevox(format!("Failed to read response body: {}", e)))?;
|
||||||
|
|
||||||
|
Ok(body.to_vec())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
@ -93,14 +109,21 @@ impl VOICEVOX {
|
|||||||
&self,
|
&self,
|
||||||
text: String,
|
text: String,
|
||||||
speaker: i64,
|
speaker: i64,
|
||||||
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
|
) -> Result<Vec<u8>, NCBError> {
|
||||||
let client =
|
let api_url = self.original_api_url.as_ref()
|
||||||
voicevox_client::Client::new(self.original_api_url.as_ref().unwrap().clone(), None);
|
.ok_or_else(|| NCBError::voicevox("Original API URL required for synthesis"))?;
|
||||||
|
|
||||||
|
let client = voicevox_client::Client::new(api_url.clone(), None);
|
||||||
let audio_query = client
|
let audio_query = client
|
||||||
.create_audio_query(&text, speaker as i32, None)
|
.create_audio_query(&text, speaker as i32, None)
|
||||||
.await?;
|
.await
|
||||||
println!("{:?}", audio_query.audio_query);
|
.map_err(|e| NCBError::voicevox(format!("Failed to create audio query: {}", e)))?;
|
||||||
let audio = audio_query.synthesis(speaker as i32, true).await?;
|
|
||||||
|
tracing::debug!(audio_query = ?audio_query.audio_query, "Generated audio query");
|
||||||
|
|
||||||
|
let audio = audio_query.synthesis(speaker as i32, true).await
|
||||||
|
.map_err(|e| NCBError::voicevox(format!("Audio synthesis failed: {}", e)))?;
|
||||||
|
|
||||||
Ok(audio.into())
|
Ok(audio.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,25 +132,35 @@ impl VOICEVOX {
|
|||||||
&self,
|
&self,
|
||||||
text: String,
|
text: String,
|
||||||
speaker: i64,
|
speaker: i64,
|
||||||
) -> Result<Mp3Request, Box<dyn std::error::Error>> {
|
) -> Result<Mp3Request, NCBError> {
|
||||||
|
let key = self.key.as_ref()
|
||||||
|
.ok_or_else(|| NCBError::voicevox("API key required for stream synthesis"))?;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
match client
|
let response = client
|
||||||
.post("https://api.tts.quest/v3/voicevox/synthesis")
|
.post(STREAM_API_URL)
|
||||||
.query(&[
|
.query(&[
|
||||||
("speaker", speaker.to_string()),
|
("speaker", speaker.to_string()),
|
||||||
("text", text),
|
("text", text),
|
||||||
("key", self.key.clone().unwrap()),
|
("key", key.clone()),
|
||||||
])
|
])
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
{
|
.map_err(|e| NCBError::voicevox(format!("Stream synthesis request failed: {}", e)))?;
|
||||||
Ok(response) => {
|
|
||||||
let body = response.text().await.unwrap();
|
|
||||||
let response: TTSResponse = serde_json::from_str(&body).unwrap();
|
|
||||||
|
|
||||||
Ok(Mp3Request::new(reqwest::Client::new(), response.mp3_streaming_url).into())
|
if !response.status().is_success() {
|
||||||
}
|
return Err(NCBError::voicevox(format!(
|
||||||
Err(err) => Err(Box::new(err)),
|
"Stream synthesis failed with status: {}",
|
||||||
|
response.status()
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let body = response.text().await
|
||||||
|
.map_err(|e| NCBError::voicevox(format!("Failed to read response text: {}", e)))?;
|
||||||
|
|
||||||
|
let tts_response: TTSResponse = serde_json::from_str(&body)
|
||||||
|
.map_err(|e| NCBError::voicevox(format!("Failed to parse TTS response: {}", e)))?;
|
||||||
|
|
||||||
|
Ok(Mp3Request::new(reqwest::Client::new(), tts_response.mp3_streaming_url))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
594
src/utils.rs
Normal file
594
src/utils.rs
Normal file
@ -0,0 +1,594 @@
|
|||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use lru::LruCache;
|
||||||
|
use regex::Regex;
|
||||||
|
use std::{num::NonZeroUsize, sync::RwLock};
|
||||||
|
use tracing::{debug, error, warn};
|
||||||
|
|
||||||
|
use crate::errors::{constants::*, NCBError, Result};
|
||||||
|
|
||||||
|
/// Regex compilation cache to avoid recompiling the same patterns
|
||||||
|
static REGEX_CACHE: Lazy<RwLock<LruCache<String, Regex>>> =
|
||||||
|
Lazy::new(|| RwLock::new(LruCache::new(NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap())));
|
||||||
|
|
||||||
|
/// Circuit breaker states for external API calls
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum CircuitBreakerState {
|
||||||
|
Closed,
|
||||||
|
Open,
|
||||||
|
HalfOpen,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Circuit breaker for handling external API failures
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CircuitBreaker {
|
||||||
|
pub state: CircuitBreakerState,
|
||||||
|
pub failure_count: u32,
|
||||||
|
pub last_failure_time: Option<std::time::Instant>,
|
||||||
|
pub threshold: u32,
|
||||||
|
pub timeout: std::time::Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for CircuitBreaker {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
state: CircuitBreakerState::Closed,
|
||||||
|
failure_count: 0,
|
||||||
|
last_failure_time: None,
|
||||||
|
threshold: 5,
|
||||||
|
timeout: std::time::Duration::from_secs(60),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CircuitBreaker {
|
||||||
|
pub fn new(threshold: u32, timeout: std::time::Duration) -> Self {
|
||||||
|
Self {
|
||||||
|
threshold,
|
||||||
|
timeout,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn can_execute(&self) -> bool {
|
||||||
|
match self.state {
|
||||||
|
CircuitBreakerState::Closed => true,
|
||||||
|
CircuitBreakerState::Open => {
|
||||||
|
if let Some(last_failure) = self.last_failure_time {
|
||||||
|
last_failure.elapsed() >= self.timeout
|
||||||
|
} else {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CircuitBreakerState::HalfOpen => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_success(&mut self) {
|
||||||
|
self.failure_count = 0;
|
||||||
|
self.state = CircuitBreakerState::Closed;
|
||||||
|
self.last_failure_time = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_failure(&mut self) {
|
||||||
|
self.failure_count += 1;
|
||||||
|
self.last_failure_time = Some(std::time::Instant::now());
|
||||||
|
|
||||||
|
if self.failure_count >= self.threshold {
|
||||||
|
self.state = CircuitBreakerState::Open;
|
||||||
|
} else if self.state == CircuitBreakerState::HalfOpen {
|
||||||
|
self.state = CircuitBreakerState::Open;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn try_half_open(&mut self) {
|
||||||
|
if self.state == CircuitBreakerState::Open {
|
||||||
|
if let Some(last_failure) = self.last_failure_time {
|
||||||
|
if last_failure.elapsed() >= self.timeout {
|
||||||
|
self.state = CircuitBreakerState::HalfOpen;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cached regex compilation with error handling
|
||||||
|
pub fn get_cached_regex(pattern: &str) -> Result<Regex> {
|
||||||
|
// First try to get from cache
|
||||||
|
{
|
||||||
|
let cache = REGEX_CACHE.read().unwrap();
|
||||||
|
if let Some(cached_regex) = cache.peek(pattern) {
|
||||||
|
debug!(pattern = pattern, "Regex cache hit");
|
||||||
|
return Ok(cached_regex.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(pattern = pattern, "Regex cache miss, compiling");
|
||||||
|
|
||||||
|
// Compile regex with error handling
|
||||||
|
match Regex::new(pattern) {
|
||||||
|
Ok(regex) => {
|
||||||
|
// Cache successful compilation
|
||||||
|
{
|
||||||
|
let mut cache = REGEX_CACHE.write().unwrap();
|
||||||
|
cache.put(pattern.to_string(), regex.clone());
|
||||||
|
}
|
||||||
|
Ok(regex)
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(pattern = pattern, error = %e, "Failed to compile regex");
|
||||||
|
Err(NCBError::invalid_regex(format!("{}: {}", pattern, e)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retry logic with exponential backoff
|
||||||
|
pub async fn retry_with_backoff<F, Fut, T, E>(
|
||||||
|
mut operation: F,
|
||||||
|
max_attempts: u32,
|
||||||
|
initial_delay: std::time::Duration,
|
||||||
|
) -> std::result::Result<T, E>
|
||||||
|
where
|
||||||
|
F: FnMut() -> Fut,
|
||||||
|
Fut: std::future::Future<Output = std::result::Result<T, E>>,
|
||||||
|
E: std::fmt::Display,
|
||||||
|
{
|
||||||
|
let mut attempts = 0;
|
||||||
|
let mut delay = initial_delay;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
attempts += 1;
|
||||||
|
|
||||||
|
match operation().await {
|
||||||
|
Ok(result) => {
|
||||||
|
if attempts > 1 {
|
||||||
|
debug!(attempts = attempts, "Operation succeeded after retry");
|
||||||
|
}
|
||||||
|
return Ok(result);
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
if attempts >= max_attempts {
|
||||||
|
error!(
|
||||||
|
attempts = attempts,
|
||||||
|
error = %error,
|
||||||
|
"Operation failed after maximum retry attempts"
|
||||||
|
);
|
||||||
|
return Err(error);
|
||||||
|
}
|
||||||
|
|
||||||
|
warn!(
|
||||||
|
attempt = attempts,
|
||||||
|
max_attempts = max_attempts,
|
||||||
|
delay_ms = delay.as_millis(),
|
||||||
|
error = %error,
|
||||||
|
"Operation failed, retrying with backoff"
|
||||||
|
);
|
||||||
|
|
||||||
|
tokio::time::sleep(delay).await;
|
||||||
|
delay = std::cmp::min(delay * 2, std::time::Duration::from_secs(30));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rate limiter using token bucket algorithm
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RateLimiter {
|
||||||
|
tokens: std::sync::Arc<std::sync::RwLock<f64>>,
|
||||||
|
capacity: f64,
|
||||||
|
refill_rate: f64,
|
||||||
|
last_refill: std::sync::Arc<std::sync::RwLock<std::time::Instant>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimiter {
|
||||||
|
pub fn new(capacity: f64, refill_rate: f64) -> Self {
|
||||||
|
Self {
|
||||||
|
tokens: std::sync::Arc::new(std::sync::RwLock::new(capacity)),
|
||||||
|
capacity,
|
||||||
|
refill_rate,
|
||||||
|
last_refill: std::sync::Arc::new(std::sync::RwLock::new(std::time::Instant::now())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn try_acquire(&self, tokens: f64) -> bool {
|
||||||
|
self.refill();
|
||||||
|
|
||||||
|
let mut current_tokens = self.tokens.write().unwrap();
|
||||||
|
if *current_tokens >= tokens {
|
||||||
|
*current_tokens -= tokens;
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn refill(&self) {
|
||||||
|
let now = std::time::Instant::now();
|
||||||
|
let mut last_refill = self.last_refill.write().unwrap();
|
||||||
|
let elapsed = now.duration_since(*last_refill).as_secs_f64();
|
||||||
|
|
||||||
|
if elapsed > 0.0 {
|
||||||
|
let tokens_to_add = elapsed * self.refill_rate;
|
||||||
|
let mut current_tokens = self.tokens.write().unwrap();
|
||||||
|
*current_tokens = (*current_tokens + tokens_to_add).min(self.capacity);
|
||||||
|
*last_refill = now;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Performance metrics collection
|
||||||
|
#[derive(Debug, Default, Clone)]
|
||||||
|
pub struct PerformanceMetrics {
|
||||||
|
pub tts_requests: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||||
|
pub tts_cache_hits: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||||
|
pub tts_cache_misses: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||||
|
pub regex_cache_hits: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||||
|
pub regex_cache_misses: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||||
|
pub database_operations: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||||
|
pub voice_connections: std::sync::Arc<std::sync::atomic::AtomicU64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PerformanceMetrics {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn increment_tts_requests(&self) {
|
||||||
|
self.tts_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn increment_tts_cache_hits(&self) {
|
||||||
|
self.tts_cache_hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn increment_tts_cache_misses(&self) {
|
||||||
|
self.tts_cache_misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn increment_regex_cache_hits(&self) {
|
||||||
|
self.regex_cache_hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn increment_regex_cache_misses(&self) {
|
||||||
|
self.regex_cache_misses.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn increment_database_operations(&self) {
|
||||||
|
self.database_operations.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn increment_voice_connections(&self) {
|
||||||
|
self.voice_connections.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_stats(&self) -> MetricsSnapshot {
|
||||||
|
MetricsSnapshot {
|
||||||
|
tts_requests: self.tts_requests.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
tts_cache_hits: self.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
tts_cache_misses: self.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
regex_cache_hits: self.regex_cache_hits.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
regex_cache_misses: self.regex_cache_misses.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
database_operations: self.database_operations.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
voice_connections: self.voice_connections.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MetricsSnapshot {
|
||||||
|
pub tts_requests: u64,
|
||||||
|
pub tts_cache_hits: u64,
|
||||||
|
pub tts_cache_misses: u64,
|
||||||
|
pub regex_cache_hits: u64,
|
||||||
|
pub regex_cache_misses: u64,
|
||||||
|
pub database_operations: u64,
|
||||||
|
pub voice_connections: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MetricsSnapshot {
|
||||||
|
pub fn tts_cache_hit_rate(&self) -> f64 {
|
||||||
|
if self.tts_cache_hits + self.tts_cache_misses > 0 {
|
||||||
|
self.tts_cache_hits as f64 / (self.tts_cache_hits + self.tts_cache_misses) as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn regex_cache_hit_rate(&self) -> f64 {
|
||||||
|
if self.regex_cache_hits + self.regex_cache_misses > 0 {
|
||||||
|
self.regex_cache_hits as f64 / (self.regex_cache_hits + self.regex_cache_misses) as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::time::Duration;
|
||||||
|
use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_circuit_breaker_default() {
|
||||||
|
let cb = CircuitBreaker::default();
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||||
|
assert_eq!(cb.failure_count, 0);
|
||||||
|
assert!(cb.can_execute());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_circuit_breaker_new() {
|
||||||
|
let cb = CircuitBreaker::new(3, Duration::from_secs(10));
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||||
|
assert_eq!(cb.threshold, 3);
|
||||||
|
assert_eq!(cb.timeout, Duration::from_secs(10));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_circuit_breaker_failure_threshold() {
|
||||||
|
let mut cb = CircuitBreaker::default();
|
||||||
|
|
||||||
|
// Test failures up to threshold
|
||||||
|
for i in 0..CIRCUIT_BREAKER_FAILURE_THRESHOLD {
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||||
|
assert!(cb.can_execute());
|
||||||
|
cb.on_failure();
|
||||||
|
assert_eq!(cb.failure_count, i + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should open after reaching threshold
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Open);
|
||||||
|
assert!(!cb.can_execute());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_circuit_breaker_success_resets() {
|
||||||
|
let mut cb = CircuitBreaker::default();
|
||||||
|
|
||||||
|
// Add some failures
|
||||||
|
cb.on_failure();
|
||||||
|
cb.on_failure();
|
||||||
|
assert_eq!(cb.failure_count, 2);
|
||||||
|
|
||||||
|
// Success should reset
|
||||||
|
cb.on_success();
|
||||||
|
assert_eq!(cb.failure_count, 0);
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_circuit_breaker_half_open() {
|
||||||
|
let mut cb = CircuitBreaker::new(1, Duration::from_millis(100));
|
||||||
|
|
||||||
|
// Trigger failure to open circuit
|
||||||
|
cb.on_failure();
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Open);
|
||||||
|
assert!(!cb.can_execute());
|
||||||
|
|
||||||
|
// Wait for timeout
|
||||||
|
std::thread::sleep(Duration::from_millis(150));
|
||||||
|
|
||||||
|
// Should allow transition to half-open
|
||||||
|
cb.try_half_open();
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
|
||||||
|
assert!(cb.can_execute());
|
||||||
|
|
||||||
|
// Success in half-open should close circuit
|
||||||
|
cb.on_success();
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Closed);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_circuit_breaker_half_open_failure() {
|
||||||
|
let mut cb = CircuitBreaker::new(1, Duration::from_millis(100));
|
||||||
|
|
||||||
|
// Open circuit
|
||||||
|
cb.on_failure();
|
||||||
|
std::thread::sleep(Duration::from_millis(150));
|
||||||
|
cb.try_half_open();
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
|
||||||
|
|
||||||
|
// Failure in half-open should reopen circuit
|
||||||
|
cb.on_failure();
|
||||||
|
assert_eq!(cb.state, CircuitBreakerState::Open);
|
||||||
|
assert!(!cb.can_execute());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_retry_with_backoff_success_first_try() {
|
||||||
|
let mut call_count = 0;
|
||||||
|
let result = retry_with_backoff(
|
||||||
|
|| {
|
||||||
|
call_count += 1;
|
||||||
|
async { Ok::<i32, &'static str>(42) }
|
||||||
|
},
|
||||||
|
3,
|
||||||
|
Duration::from_millis(100),
|
||||||
|
).await;
|
||||||
|
|
||||||
|
assert_eq!(result.unwrap(), 42);
|
||||||
|
assert_eq!(call_count, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_retry_with_backoff_success_after_retries() {
|
||||||
|
let mut call_count = 0;
|
||||||
|
let result = retry_with_backoff(
|
||||||
|
|| {
|
||||||
|
call_count += 1;
|
||||||
|
async move {
|
||||||
|
if call_count < 3 {
|
||||||
|
Err("temporary error")
|
||||||
|
} else {
|
||||||
|
Ok::<i32, &'static str>(42)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
5,
|
||||||
|
Duration::from_millis(10),
|
||||||
|
).await;
|
||||||
|
|
||||||
|
assert_eq!(result.unwrap(), 42);
|
||||||
|
assert_eq!(call_count, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_retry_with_backoff_max_attempts() {
|
||||||
|
let mut call_count = 0;
|
||||||
|
let result = retry_with_backoff(
|
||||||
|
|| {
|
||||||
|
call_count += 1;
|
||||||
|
async { Err::<i32, &'static str>("persistent error") }
|
||||||
|
},
|
||||||
|
3,
|
||||||
|
Duration::from_millis(10),
|
||||||
|
).await;
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert_eq!(result.unwrap_err(), "persistent error");
|
||||||
|
assert_eq!(call_count, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_cached_regex_valid_pattern() {
|
||||||
|
// Clear cache first
|
||||||
|
{
|
||||||
|
let mut cache = REGEX_CACHE.write().unwrap();
|
||||||
|
cache.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
let pattern = r"[a-zA-Z]+";
|
||||||
|
let result1 = get_cached_regex(pattern);
|
||||||
|
assert!(result1.is_ok());
|
||||||
|
|
||||||
|
let result2 = get_cached_regex(pattern);
|
||||||
|
assert!(result2.is_ok());
|
||||||
|
|
||||||
|
// Both should work and second should be from cache
|
||||||
|
let regex1 = result1.unwrap();
|
||||||
|
let regex2 = result2.unwrap();
|
||||||
|
assert!(regex1.is_match("hello"));
|
||||||
|
assert!(regex2.is_match("world"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_cached_regex_invalid_pattern() {
|
||||||
|
let pattern = r"[";
|
||||||
|
let result = get_cached_regex(pattern);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
if let Err(NCBError::InvalidRegex(msg)) = result {
|
||||||
|
// The error message contains the pattern and the regex error
|
||||||
|
assert!(msg.contains("["));
|
||||||
|
} else {
|
||||||
|
panic!("Expected InvalidRegex error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rate_limiter_basic() {
|
||||||
|
let limiter = RateLimiter::new(5.0, 1.0); // 5 tokens, 1 per second
|
||||||
|
|
||||||
|
// Should be able to acquire 5 tokens initially
|
||||||
|
assert!(limiter.try_acquire(1.0));
|
||||||
|
assert!(limiter.try_acquire(1.0));
|
||||||
|
assert!(limiter.try_acquire(1.0));
|
||||||
|
assert!(limiter.try_acquire(1.0));
|
||||||
|
assert!(limiter.try_acquire(1.0));
|
||||||
|
|
||||||
|
// 6th token should fail
|
||||||
|
assert!(!limiter.try_acquire(1.0));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_rate_limiter_partial_tokens() {
|
||||||
|
let limiter = RateLimiter::new(2.0, 1.0);
|
||||||
|
|
||||||
|
// Acquire partial tokens
|
||||||
|
assert!(limiter.try_acquire(0.5));
|
||||||
|
assert!(limiter.try_acquire(0.5));
|
||||||
|
assert!(limiter.try_acquire(0.5));
|
||||||
|
assert!(limiter.try_acquire(0.5));
|
||||||
|
|
||||||
|
// Should fail with no tokens left
|
||||||
|
assert!(!limiter.try_acquire(0.1));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_performance_metrics_increment() {
|
||||||
|
let metrics = PerformanceMetrics::default();
|
||||||
|
|
||||||
|
assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 0);
|
||||||
|
|
||||||
|
metrics.increment_tts_requests();
|
||||||
|
metrics.increment_tts_requests();
|
||||||
|
|
||||||
|
assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 2);
|
||||||
|
|
||||||
|
metrics.increment_tts_cache_hits();
|
||||||
|
assert_eq!(metrics.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed), 1);
|
||||||
|
|
||||||
|
metrics.increment_tts_cache_misses();
|
||||||
|
assert_eq!(metrics.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_metrics_snapshot_cache_hit_rate() {
|
||||||
|
let snapshot = MetricsSnapshot {
|
||||||
|
tts_requests: 10,
|
||||||
|
tts_cache_hits: 7,
|
||||||
|
tts_cache_misses: 3,
|
||||||
|
regex_cache_hits: 0,
|
||||||
|
regex_cache_misses: 0,
|
||||||
|
database_operations: 0,
|
||||||
|
voice_connections: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!((snapshot.tts_cache_hit_rate() - 0.7).abs() < f64::EPSILON);
|
||||||
|
|
||||||
|
let empty_snapshot = MetricsSnapshot {
|
||||||
|
tts_requests: 0,
|
||||||
|
tts_cache_hits: 0,
|
||||||
|
tts_cache_misses: 0,
|
||||||
|
regex_cache_hits: 0,
|
||||||
|
regex_cache_misses: 0,
|
||||||
|
database_operations: 0,
|
||||||
|
voice_connections: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_metrics_snapshot_regex_cache_hit_rate() {
|
||||||
|
let snapshot = MetricsSnapshot {
|
||||||
|
tts_requests: 0,
|
||||||
|
tts_cache_hits: 0,
|
||||||
|
tts_cache_misses: 0,
|
||||||
|
regex_cache_hits: 8,
|
||||||
|
regex_cache_misses: 2,
|
||||||
|
database_operations: 0,
|
||||||
|
voice_connections: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!((snapshot.regex_cache_hit_rate() - 0.8).abs() < f64::EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_performance_metrics_get_stats() {
|
||||||
|
let metrics = PerformanceMetrics::default();
|
||||||
|
|
||||||
|
// Add some data
|
||||||
|
metrics.increment_tts_requests();
|
||||||
|
metrics.increment_tts_requests();
|
||||||
|
metrics.increment_tts_cache_hits();
|
||||||
|
metrics.increment_database_operations();
|
||||||
|
|
||||||
|
let stats = metrics.get_stats();
|
||||||
|
|
||||||
|
assert_eq!(stats.tts_requests, 2);
|
||||||
|
assert_eq!(stats.tts_cache_hits, 1);
|
||||||
|
assert_eq!(stats.tts_cache_misses, 0);
|
||||||
|
assert_eq!(stats.database_operations, 1);
|
||||||
|
}
|
||||||
|
}
|
BIN
tts_cache.bin
Normal file
BIN
tts_cache.bin
Normal file
Binary file not shown.
Reference in New Issue
Block a user