Files
ncb-tts-r2/src/utils.rs
mii443 733646b6b8 refactor: Major overhaul with error handling, resilience patterns, and observability
- Add library configuration to support both lib and binary targets
- Implement unified error handling with NCBError throughout the codebase
- Add circuit breaker pattern for external API calls (Voicevox, GCP TTS)
- Introduce comprehensive performance metrics and monitoring
- Add cache persistence with disk storage support
- Implement retry mechanism with exponential backoff
- Add configuration file support (config.toml) with env var fallback
- Enhance logging with structured tracing (debug, warn, error levels)
- Add extensive unit tests for cache, metrics, and circuit breaker
- Update base64 decoding to use modern API
- Improve API error handling for Voicevox and GCP TTS clients

Breaking changes:
- Function signatures now return Result<T, NCBError> instead of panicking
- Cache key structure modified with serialization support
2025-05-28 01:01:12 +09:00

594 lines
18 KiB
Rust

use once_cell::sync::Lazy;
use lru::LruCache;
use regex::Regex;
use std::{num::NonZeroUsize, sync::RwLock};
use tracing::{debug, error, warn};
use crate::errors::{constants::*, NCBError, Result};
/// Process-wide LRU cache of compiled regexes, keyed by their pattern string,
/// used to avoid recompiling the same patterns.
///
/// Built lazily on first access. Initialization panics if `DEFAULT_CACHE_SIZE`
/// is zero, because the cache capacity must be a `NonZeroUsize`.
static REGEX_CACHE: Lazy<RwLock<LruCache<String, Regex>>> = Lazy::new(|| {
    let capacity = NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap();
    RwLock::new(LruCache::new(capacity))
});
/// Circuit breaker states for external API calls
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitBreakerState {
    /// Calls are allowed and failures are being counted.
    Closed,
    /// Calls are rejected until `timeout` has elapsed since the last failure.
    Open,
    /// Calls are allowed again; the next recorded failure re-opens the circuit
    /// and the next recorded success closes it.
    HalfOpen,
}

/// Circuit breaker for handling external API failures
///
/// The breaker does not run anything itself: callers ask [`CircuitBreaker::can_execute`]
/// before a call and report the outcome through [`CircuitBreaker::on_success`] or
/// [`CircuitBreaker::on_failure`].
#[derive(Debug, Clone)]
pub struct CircuitBreaker {
    /// Current state of the breaker.
    pub state: CircuitBreakerState,
    /// Failures recorded since the last success (saturates at `u32::MAX`).
    pub failure_count: u32,
    /// When the most recent failure was recorded, if any.
    pub last_failure_time: Option<std::time::Instant>,
    /// Number of recorded failures at which the breaker opens.
    pub threshold: u32,
    /// How long the breaker stays open after the last failure.
    pub timeout: std::time::Duration,
}

impl Default for CircuitBreaker {
    /// A closed breaker that opens after 5 failures and stays open for 60 seconds.
    fn default() -> Self {
        // NOTE(review): the unit tests drive a default breaker up to
        // `CIRCUIT_BREAKER_FAILURE_THRESHOLD`, so that constant is presumably
        // expected to equal this literal — confirm and consider using it here.
        Self {
            state: CircuitBreakerState::Closed,
            failure_count: 0,
            last_failure_time: None,
            threshold: 5,
            timeout: std::time::Duration::from_secs(60),
        }
    }
}

impl CircuitBreaker {
    /// Creates a closed breaker with the given failure `threshold` and open `timeout`.
    pub fn new(threshold: u32, timeout: std::time::Duration) -> Self {
        Self {
            threshold,
            timeout,
            ..Default::default()
        }
    }

    /// Reports whether a call may be attempted right now.
    ///
    /// This never changes the state: an open breaker whose timeout has elapsed
    /// answers `true` but stays `Open` until [`Self::try_half_open`],
    /// [`Self::on_success`] or [`Self::on_failure`] is called.
    pub fn can_execute(&self) -> bool {
        match self.state {
            CircuitBreakerState::Closed | CircuitBreakerState::HalfOpen => true,
            CircuitBreakerState::Open => match self.last_failure_time {
                Some(last_failure) => last_failure.elapsed() >= self.timeout,
                // Open without a recorded failure: there is nothing to wait for.
                None => true,
            },
        }
    }

    /// Records a successful call: clears the failure history and closes the breaker.
    pub fn on_success(&mut self) {
        self.failure_count = 0;
        self.state = CircuitBreakerState::Closed;
        self.last_failure_time = None;
    }

    /// Records a failed call.
    ///
    /// The breaker opens when the failure count reaches `threshold`, or
    /// immediately when the failure happened while half-open.
    pub fn on_failure(&mut self) {
        // Saturate rather than overflow: `failure_count` is a public field and
        // keeps growing while the breaker is open.
        self.failure_count = self.failure_count.saturating_add(1);
        self.last_failure_time = Some(std::time::Instant::now());

        let threshold_reached = self.failure_count >= self.threshold;
        if threshold_reached || self.state == CircuitBreakerState::HalfOpen {
            self.state = CircuitBreakerState::Open;
        }
    }

    /// Moves an open breaker to half-open once `timeout` has elapsed since the
    /// last recorded failure; does nothing in any other situation.
    pub fn try_half_open(&mut self) {
        if self.state != CircuitBreakerState::Open {
            return;
        }
        if let Some(last_failure) = self.last_failure_time {
            if last_failure.elapsed() >= self.timeout {
                self.state = CircuitBreakerState::HalfOpen;
            }
        }
    }
}
/// Returns a compiled [`Regex`] for `pattern`, reusing a cached copy when one exists.
///
/// A cache hit returns a clone of the stored regex. On a miss the pattern is
/// compiled, stored in `REGEX_CACHE`, and returned.
///
/// A poisoned cache lock is recovered instead of panicking: the cache only
/// holds compiled regexes, so a panic in another thread while it held the lock
/// should not make every later call fail.
///
/// # Errors
///
/// Returns the error built by `NCBError::invalid_regex` (message
/// `"<pattern>: <compile error>"`) when `pattern` is not a valid regex.
pub fn get_cached_regex(pattern: &str) -> Result<Regex> {
    // Fast path under the shared read lock. `peek` works through a shared
    // guard; presumably it does not refresh the entry's LRU position — see the
    // `lru` crate docs.
    {
        let cache = REGEX_CACHE
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(cached_regex) = cache.peek(pattern) {
            debug!(pattern = pattern, "Regex cache hit");
            return Ok(cached_regex.clone());
        }
    }

    debug!(pattern = pattern, "Regex cache miss, compiling");

    // Compile outside of any lock so a slow compilation does not block readers.
    let regex = Regex::new(pattern).map_err(|e| {
        error!(pattern = pattern, error = %e, "Failed to compile regex");
        NCBError::invalid_regex(format!("{}: {}", pattern, e))
    })?;

    // Only successful compilations are cached; the write guard is a temporary
    // that is released at the end of this statement.
    REGEX_CACHE
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .put(pattern.to_string(), regex.clone());

    Ok(regex)
}
/// Retry logic with exponential backoff
///
/// Calls `operation` until it returns `Ok` or `max_attempts` attempts have
/// failed. After a failed attempt the task sleeps for the current delay, which
/// starts at `initial_delay` and is doubled after every sleep, capped at 30
/// seconds. The cap is applied only after doubling, so the first sleep uses
/// `initial_delay` as given.
///
/// `operation` is always invoked at least once, even when `max_attempts` is 0.
///
/// # Errors
///
/// Returns the error of the last attempt once `max_attempts` has been reached.
pub async fn retry_with_backoff<F, Fut, T, E>(
    mut operation: F,
    max_attempts: u32,
    initial_delay: std::time::Duration,
) -> std::result::Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = std::result::Result<T, E>>,
    E: std::fmt::Display,
{
    /// Upper bound for the delay between two attempts.
    const MAX_DELAY: std::time::Duration = std::time::Duration::from_secs(30);

    let mut attempts = 0;
    let mut delay = initial_delay;

    loop {
        attempts += 1;

        let error = match operation().await {
            Ok(result) => {
                if attempts > 1 {
                    debug!(attempts = attempts, "Operation succeeded after retry");
                }
                return Ok(result);
            }
            Err(error) => error,
        };

        if attempts >= max_attempts {
            error!(
                attempts = attempts,
                error = %error,
                "Operation failed after maximum retry attempts"
            );
            return Err(error);
        }

        warn!(
            attempt = attempts,
            max_attempts = max_attempts,
            delay_ms = delay.as_millis(),
            error = %error,
            "Operation failed, retrying with backoff"
        );
        tokio::time::sleep(delay).await;

        // `saturating_mul` keeps a very large `initial_delay` from panicking
        // on `Duration` overflow before the cap is applied.
        delay = delay.saturating_mul(2).min(MAX_DELAY);
    }
}
/// Rate limiter using token bucket algorithm
///
/// The bucket starts full with `capacity` tokens and gains `refill_rate`
/// tokens per second of elapsed time, never exceeding `capacity`.
#[derive(Debug)]
pub struct RateLimiter {
    /// Tokens currently available.
    tokens: std::sync::Arc<std::sync::RwLock<f64>>,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// Instant up to which elapsed time has already been credited.
    last_refill: std::sync::Arc<std::sync::RwLock<std::time::Instant>>,
}

impl RateLimiter {
    /// Creates a full bucket holding `capacity` tokens that refills at
    /// `refill_rate` tokens per second.
    pub fn new(capacity: f64, refill_rate: f64) -> Self {
        Self {
            tokens: std::sync::Arc::new(std::sync::RwLock::new(capacity)),
            capacity,
            refill_rate,
            last_refill: std::sync::Arc::new(std::sync::RwLock::new(std::time::Instant::now())),
        }
    }

    /// Tries to take `tokens` from the bucket without blocking.
    ///
    /// Returns `true` and deducts the tokens when enough are available after
    /// refilling, otherwise returns `false` and leaves the bucket unchanged.
    /// Negative, NaN or infinite requests are always rejected; a negative
    /// request would otherwise add tokens and could lift the bucket above
    /// `capacity`.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal locks is poisoned.
    pub fn try_acquire(&self, tokens: f64) -> bool {
        if !tokens.is_finite() || tokens < 0.0 {
            return false;
        }

        self.refill();

        let mut available = self.tokens.write().unwrap();
        if *available >= tokens {
            *available -= tokens;
            true
        } else {
            false
        }
    }

    /// Credits the tokens earned since the previous refill, capped at `capacity`.
    fn refill(&self) {
        // Lock order is always `last_refill` then `tokens`.
        let mut last_refill = self.last_refill.write().unwrap();

        // Read the clock while holding the lock so `now` cannot be older than
        // a value another thread has just stored in `last_refill`.
        let now = std::time::Instant::now();
        let elapsed = now.saturating_duration_since(*last_refill).as_secs_f64();

        if elapsed > 0.0 {
            let tokens_to_add = elapsed * self.refill_rate;
            let mut available = self.tokens.write().unwrap();
            *available = (*available + tokens_to_add).min(self.capacity);
            *last_refill = now;
        }
    }
}
/// Performance metrics collection
///
/// Every counter is an `Arc<AtomicU64>`, so clones of a `PerformanceMetrics`
/// update and observe the same underlying counters.
#[derive(Debug, Default, Clone)]
pub struct PerformanceMetrics {
    /// Count bumped by `increment_tts_requests`.
    pub tts_requests: std::sync::Arc<std::sync::atomic::AtomicU64>,
    /// Count bumped by `increment_tts_cache_hits`.
    pub tts_cache_hits: std::sync::Arc<std::sync::atomic::AtomicU64>,
    /// Count bumped by `increment_tts_cache_misses`.
    pub tts_cache_misses: std::sync::Arc<std::sync::atomic::AtomicU64>,
    /// Count bumped by `increment_regex_cache_hits`.
    pub regex_cache_hits: std::sync::Arc<std::sync::atomic::AtomicU64>,
    /// Count bumped by `increment_regex_cache_misses`.
    pub regex_cache_misses: std::sync::Arc<std::sync::atomic::AtomicU64>,
    /// Count bumped by `increment_database_operations`.
    pub database_operations: std::sync::Arc<std::sync::atomic::AtomicU64>,
    /// Count bumped by `increment_voice_connections`.
    pub voice_connections: std::sync::Arc<std::sync::atomic::AtomicU64>,
}

impl PerformanceMetrics {
    /// Creates a metrics set with every counter at zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Adds one to `counter` with relaxed ordering.
    fn bump_counter(counter: &std::sync::atomic::AtomicU64) {
        use std::sync::atomic::Ordering;
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// Reads the current value of `counter` with relaxed ordering.
    fn read_counter(counter: &std::sync::atomic::AtomicU64) -> u64 {
        use std::sync::atomic::Ordering;
        counter.load(Ordering::Relaxed)
    }

    /// Adds one to the TTS request counter.
    pub fn increment_tts_requests(&self) {
        Self::bump_counter(&self.tts_requests);
    }

    /// Adds one to the TTS cache hit counter.
    pub fn increment_tts_cache_hits(&self) {
        Self::bump_counter(&self.tts_cache_hits);
    }

    /// Adds one to the TTS cache miss counter.
    pub fn increment_tts_cache_misses(&self) {
        Self::bump_counter(&self.tts_cache_misses);
    }

    /// Adds one to the regex cache hit counter.
    pub fn increment_regex_cache_hits(&self) {
        Self::bump_counter(&self.regex_cache_hits);
    }

    /// Adds one to the regex cache miss counter.
    pub fn increment_regex_cache_misses(&self) {
        Self::bump_counter(&self.regex_cache_misses);
    }

    /// Adds one to the database operation counter.
    pub fn increment_database_operations(&self) {
        Self::bump_counter(&self.database_operations);
    }

    /// Adds one to the voice connection counter.
    pub fn increment_voice_connections(&self) {
        Self::bump_counter(&self.voice_connections);
    }

    /// Copies the current counter values into a plain [`MetricsSnapshot`].
    ///
    /// Each counter is loaded separately, so the snapshot is not taken
    /// atomically across counters.
    pub fn get_stats(&self) -> MetricsSnapshot {
        let tts_requests = Self::read_counter(&self.tts_requests);
        let tts_cache_hits = Self::read_counter(&self.tts_cache_hits);
        let tts_cache_misses = Self::read_counter(&self.tts_cache_misses);
        let regex_cache_hits = Self::read_counter(&self.regex_cache_hits);
        let regex_cache_misses = Self::read_counter(&self.regex_cache_misses);
        let database_operations = Self::read_counter(&self.database_operations);
        let voice_connections = Self::read_counter(&self.voice_connections);

        MetricsSnapshot {
            tts_requests,
            tts_cache_hits,
            tts_cache_misses,
            regex_cache_hits,
            regex_cache_misses,
            database_operations,
            voice_connections,
        }
    }
}
/// Plain-value copy of the performance counters, with helpers for derived ratios.
#[derive(Debug, Clone)]
pub struct MetricsSnapshot {
    pub tts_requests: u64,
    pub tts_cache_hits: u64,
    pub tts_cache_misses: u64,
    pub regex_cache_hits: u64,
    pub regex_cache_misses: u64,
    pub database_operations: u64,
    pub voice_connections: u64,
}

impl MetricsSnapshot {
    /// Fraction of TTS cache lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no hits or misses have been recorded.
    pub fn tts_cache_hit_rate(&self) -> f64 {
        Self::hit_rate(self.tts_cache_hits, self.tts_cache_misses)
    }

    /// Fraction of regex cache lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no hits or misses have been recorded.
    pub fn regex_cache_hit_rate(&self) -> f64 {
        Self::hit_rate(self.regex_cache_hits, self.regex_cache_misses)
    }

    /// Computes `hits / (hits + misses)`, or `0.0` when both are zero.
    fn hit_rate(hits: u64, misses: u64) -> f64 {
        // Sum in f64: `hits + misses` in u64 could overflow, which panics in
        // debug builds and wraps to a wrong result in release builds.
        let total = hits as f64 + misses as f64;
        if total > 0.0 {
            hits as f64 / total
        } else {
            0.0
        }
    }
}
// Unit tests for the circuit breaker, retry helper, regex cache, rate limiter
// and metrics types defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
use crate::errors::constants::CIRCUIT_BREAKER_FAILURE_THRESHOLD;
// A default breaker starts closed, with no failures, and permits calls.
#[test]
fn test_circuit_breaker_default() {
let cb = CircuitBreaker::default();
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert_eq!(cb.failure_count, 0);
assert!(cb.can_execute());
}
// `new` stores the given threshold and timeout and starts closed.
#[test]
fn test_circuit_breaker_new() {
let cb = CircuitBreaker::new(3, Duration::from_secs(10));
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert_eq!(cb.threshold, 3);
assert_eq!(cb.timeout, Duration::from_secs(10));
}
// The breaker stays closed until the failure count reaches the threshold.
// NOTE(review): this drives `CircuitBreaker::default()` (hard-coded threshold 5)
// with CIRCUIT_BREAKER_FAILURE_THRESHOLD, so it presumably only passes while
// the two values are equal — confirm.
#[test]
fn test_circuit_breaker_failure_threshold() {
let mut cb = CircuitBreaker::default();
// Test failures up to threshold
for i in 0..CIRCUIT_BREAKER_FAILURE_THRESHOLD {
assert_eq!(cb.state, CircuitBreakerState::Closed);
assert!(cb.can_execute());
cb.on_failure();
assert_eq!(cb.failure_count, i + 1);
}
// Should open after reaching threshold
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
}
// A success clears the failure count and leaves the breaker closed.
#[test]
fn test_circuit_breaker_success_resets() {
let mut cb = CircuitBreaker::default();
// Add some failures
cb.on_failure();
cb.on_failure();
assert_eq!(cb.failure_count, 2);
// Success should reset
cb.on_success();
assert_eq!(cb.failure_count, 0);
assert_eq!(cb.state, CircuitBreakerState::Closed);
}
// Open -> HalfOpen after the timeout, then HalfOpen -> Closed on success.
// Uses a real 150 ms sleep against a 100 ms timeout.
#[test]
fn test_circuit_breaker_half_open() {
let mut cb = CircuitBreaker::new(1, Duration::from_millis(100));
// Trigger failure to open circuit
cb.on_failure();
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
// Wait for timeout
std::thread::sleep(Duration::from_millis(150));
// Should allow transition to half-open
cb.try_half_open();
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
assert!(cb.can_execute());
// Success in half-open should close circuit
cb.on_success();
assert_eq!(cb.state, CircuitBreakerState::Closed);
}
// A failure while half-open re-opens the breaker and blocks calls again.
#[test]
fn test_circuit_breaker_half_open_failure() {
let mut cb = CircuitBreaker::new(1, Duration::from_millis(100));
// Open circuit
cb.on_failure();
std::thread::sleep(Duration::from_millis(150));
cb.try_half_open();
assert_eq!(cb.state, CircuitBreakerState::HalfOpen);
// Failure in half-open should reopen circuit
cb.on_failure();
assert_eq!(cb.state, CircuitBreakerState::Open);
assert!(!cb.can_execute());
}
// An operation that succeeds immediately is called exactly once.
#[tokio::test]
async fn test_retry_with_backoff_success_first_try() {
let mut call_count = 0;
let result = retry_with_backoff(
|| {
call_count += 1;
async { Ok::<i32, &'static str>(42) }
},
3,
Duration::from_millis(100),
).await;
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count, 1);
}
// An operation that fails twice and then succeeds is called three times.
// The `async move` block captures a copy of `call_count` taken after the increment.
#[tokio::test]
async fn test_retry_with_backoff_success_after_retries() {
let mut call_count = 0;
let result = retry_with_backoff(
|| {
call_count += 1;
async move {
if call_count < 3 {
Err("temporary error")
} else {
Ok::<i32, &'static str>(42)
}
}
},
5,
Duration::from_millis(10),
).await;
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count, 3);
}
// An operation that always fails is called `max_attempts` times and the
// last error is returned.
#[tokio::test]
async fn test_retry_with_backoff_max_attempts() {
let mut call_count = 0;
let result = retry_with_backoff(
|| {
call_count += 1;
async { Err::<i32, &'static str>("persistent error") }
},
3,
Duration::from_millis(10),
).await;
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "persistent error");
assert_eq!(call_count, 3);
}
// A valid pattern compiles on first use and is still usable on a second lookup.
// Note: this clears the shared global REGEX_CACHE, which other tests also use.
#[test]
fn test_get_cached_regex_valid_pattern() {
// Clear cache first
{
let mut cache = REGEX_CACHE.write().unwrap();
cache.clear();
}
let pattern = r"[a-zA-Z]+";
let result1 = get_cached_regex(pattern);
assert!(result1.is_ok());
let result2 = get_cached_regex(pattern);
assert!(result2.is_ok());
// Both should work and second should be from cache
let regex1 = result1.unwrap();
let regex2 = result2.unwrap();
assert!(regex1.is_match("hello"));
assert!(regex2.is_match("world"));
}
// An invalid pattern yields `NCBError::InvalidRegex` whose message includes the pattern.
#[test]
fn test_get_cached_regex_invalid_pattern() {
let pattern = r"[";
let result = get_cached_regex(pattern);
assert!(result.is_err());
if let Err(NCBError::InvalidRegex(msg)) = result {
// The error message contains the pattern and the regex error
assert!(msg.contains("["));
} else {
panic!("Expected InvalidRegex error");
}
}
// A full bucket of 5 tokens allows five single-token acquisitions, then rejects.
// Timing-sensitive: at 1 token/second the final assertion assumes the
// preceding calls complete in well under a second.
#[test]
fn test_rate_limiter_basic() {
let limiter = RateLimiter::new(5.0, 1.0); // 5 tokens, 1 per second
// Should be able to acquire 5 tokens initially
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
assert!(limiter.try_acquire(1.0));
// 6th token should fail
assert!(!limiter.try_acquire(1.0));
}
// Fractional requests drain the bucket in steps.
// Timing-sensitive: the final assertion assumes less than 0.1 token
// (about 100 ms at 1 token/second) has been refilled by then.
#[test]
fn test_rate_limiter_partial_tokens() {
let limiter = RateLimiter::new(2.0, 1.0);
// Acquire partial tokens
assert!(limiter.try_acquire(0.5));
assert!(limiter.try_acquire(0.5));
assert!(limiter.try_acquire(0.5));
assert!(limiter.try_acquire(0.5));
// Should fail with no tokens left
assert!(!limiter.try_acquire(0.1));
}
// Each increment method bumps its own counter by one.
#[test]
fn test_performance_metrics_increment() {
let metrics = PerformanceMetrics::default();
assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 0);
metrics.increment_tts_requests();
metrics.increment_tts_requests();
assert_eq!(metrics.tts_requests.load(std::sync::atomic::Ordering::Relaxed), 2);
metrics.increment_tts_cache_hits();
assert_eq!(metrics.tts_cache_hits.load(std::sync::atomic::Ordering::Relaxed), 1);
metrics.increment_tts_cache_misses();
assert_eq!(metrics.tts_cache_misses.load(std::sync::atomic::Ordering::Relaxed), 1);
}
// TTS hit rate is hits / (hits + misses), and 0.0 when nothing was recorded.
#[test]
fn test_metrics_snapshot_cache_hit_rate() {
let snapshot = MetricsSnapshot {
tts_requests: 10,
tts_cache_hits: 7,
tts_cache_misses: 3,
regex_cache_hits: 0,
regex_cache_misses: 0,
database_operations: 0,
voice_connections: 0,
};
assert!((snapshot.tts_cache_hit_rate() - 0.7).abs() < f64::EPSILON);
let empty_snapshot = MetricsSnapshot {
tts_requests: 0,
tts_cache_hits: 0,
tts_cache_misses: 0,
regex_cache_hits: 0,
regex_cache_misses: 0,
database_operations: 0,
voice_connections: 0,
};
assert_eq!(empty_snapshot.tts_cache_hit_rate(), 0.0);
}
// Regex hit rate uses the regex counters only: 8 hits and 2 misses give 0.8.
#[test]
fn test_metrics_snapshot_regex_cache_hit_rate() {
let snapshot = MetricsSnapshot {
tts_requests: 0,
tts_cache_hits: 0,
tts_cache_misses: 0,
regex_cache_hits: 8,
regex_cache_misses: 2,
database_operations: 0,
voice_connections: 0,
};
assert!((snapshot.regex_cache_hit_rate() - 0.8).abs() < f64::EPSILON);
}
// `get_stats` reports the values accumulated through the increment methods.
#[test]
fn test_performance_metrics_get_stats() {
let metrics = PerformanceMetrics::default();
// Add some data
metrics.increment_tts_requests();
metrics.increment_tts_requests();
metrics.increment_tts_cache_hits();
metrics.increment_database_operations();
let stats = metrics.get_stats();
assert_eq!(stats.tts_requests, 2);
assert_eq!(stats.tts_cache_hits, 1);
assert_eq!(stats.tts_cache_misses, 0);
assert_eq!(stats.database_operations, 1);
}
}