diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 5bafa70..9ac5e8d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,7 +28,7 @@ jobs: platforms: linux/amd64,linux/arm64 - name: Cache Docker layers - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: /tmp/.buildx-cache key: ${{ runner.os }}-buildx-${{ github.sha }} diff --git a/Cargo.toml b/Cargo.toml index 0baadea..050fe74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ncb-tts-r2" -version = "1.7.0" +version = "1.10.1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -13,18 +13,51 @@ gcp_auth = "0.12.3" reqwest = { version = "0.12.9", features = ["json"] } base64 = "0.22.1" async-trait = "0.1.57" -redis = "*" +redis = "0.29.2" regex = "1" -poise = "0.6.1" +tracing-subscriber = "0.3.19" +lru = "0.13.0" +tracing = "0.1.41" +opentelemetry_sdk = { version = "0.29.0", features = ["trace"] } +opentelemetry = "0.29.1" +opentelemetry-semantic-conventions = "0.29.0" +opentelemetry-otlp = { version = "0.29.0", features = ["grpc-tonic"] } +opentelemetry-stdout = "0.29.0" +tracing-opentelemetry = "0.30.0" +symphonia-core = "0.5.4" +tokio-util = { version = "0.7.14", features = ["compat"] } +futures = "0.3.31" +bytes = "1.10.1" +voicevox-client = { git = "https://github.com/mii443/rust" } [dependencies.uuid] version = "1.11.0" features = ["serde", "v4"] [dependencies.songbird] -version = "0.4.4" +version = "0.5" features = ["builtin-queue"] +[dependencies.symphonia] +version = "0.5" +features = ["mp3"] + +[dependencies.serenity] +version = "0.12" +features = [ + "builder", + "cache", + "client", + "gateway", + "model", + "utils", + "unstable_discord_api", + "collector", + "rustls_backend", + "framework", + "voice", +] + [dependencies.tokio] version = "1.0" features = ["macros", "rt-multi-thread"] diff --git a/Dockerfile b/Dockerfile index 223a9ad..92a4a12 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM lukemathwalker/cargo-chef:latest-rust-1.72 AS chef +FROM lukemathwalker/cargo-chef:latest-rust-1.82 AS chef WORKDIR app FROM chef AS planner @@ -14,6 +14,6 @@ RUN cargo build --release FROM ubuntu:22.04 AS runtime WORKDIR /ncb-tts-r2 -RUN apt-get update && apt-get install -y --no-install-recommends openssl ca-certificates ffmpeg libssl-dev libopus-dev && apt-get -y clean && mkdir audio +RUN apt-get update && apt-get install -y --no-install-recommends openssl ca-certificates ffmpeg libssl-dev libopus-dev && apt-get -y clean COPY --from=builder /app/target/release/ncb-tts-r2 /usr/local/bin ENTRYPOINT ["/usr/local/bin/ncb-tts-r2"] diff --git a/docker-compose.yml b/docker-compose.yml index 2d86c50..1bc32b9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,7 +3,7 @@ version: '3' services: ncb-tts-r2: container_name: ncb-tts-r2 - image: ghcr.io/mii443/ncb-tts-r2:1.7.3 + image: ghcr.io/mii443/ncb-tts-r2:1.10.1 environment: - NCB_TOKEN=YOUR_BOT_TOKEN - NCB_APP_ID=YOUR_BOT_ID diff --git a/manifest/ncb-tts.yaml b/manifest/ncb-tts.yaml index 0ab3749..020f776 100644 --- a/manifest/ncb-tts.yaml +++ b/manifest/ncb-tts.yaml @@ -22,7 +22,7 @@ spec: - name: ncb-redis-pvc mountPath: /data - name: tts - image: ghcr.io/morioka22/ncb-tts-r2 + image: ghcr.io/mii443/ncb-tts-r2 volumeMounts: - name: gcp-credentials mountPath: /ncb-tts-r2/credentials.json diff --git a/src/commands/config.rs b/src/commands/config.rs index 3173b40..7c88f92 100644 --- a/src/commands/config.rs +++ b/src/commands/config.rs @@ -1,7 +1,8 @@ use serenity::{ - model::prelude::{ - component::ButtonStyle, - interaction::{application_command::ApplicationCommandInteraction, MessageFlags}, + all::{ + ButtonStyle, CommandInteraction, CreateActionRow, CreateButton, CreateInteractionResponse, + CreateInteractionResponseMessage, CreateSelectMenu, CreateSelectMenuKind, + CreateSelectMenuOption, }, prelude::Context, }; @@ -11,9 +12,10 @@ use crate::{ tts::tts_type::TTSType, }; +#[tracing::instrument] pub async fn config_command( ctx: &Context, - command: &ApplicationCommandInteraction, + command: &CommandInteraction, ) -> Result<(), Box> { let data_read = ctx.data.read().await; @@ -22,9 +24,8 @@ pub async fn config_command( .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; database - .get_user_config_or_default(command.user.id.0) + .get_user_config_or_default(command.user.id.get()) .await .unwrap() .unwrap() @@ -32,84 +33,66 @@ pub async fn config_command( let tts_client = data_read .get::() - .expect("Cannot get TTSClientData") - .clone(); - let voicevox_speakers = tts_client.lock().await.1.get_styles().await; + .expect("Cannot get TTSClientData"); + let voicevox_speakers = tts_client.voicevox_client.get_styles().await; let voicevox_speaker = config.voicevox_speaker.unwrap_or(1); let tts_type = config.tts_type.unwrap_or(TTSType::GCP); + let engine_select = CreateActionRow::SelectMenu( + CreateSelectMenu::new( + "TTS_CONFIG_ENGINE", + CreateSelectMenuKind::String { + options: vec![ + CreateSelectMenuOption::new("Google TTS", "TTS_CONFIG_ENGINE_SELECTED_GOOGLE") + .default_selection(tts_type == TTSType::GCP), + CreateSelectMenuOption::new("VOICEVOX", "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX") + .default_selection(tts_type == TTSType::VOICEVOX), + ], + }, + ) + .placeholder("読み上げAPIを選択"), + ); + + let server_button = CreateActionRow::Buttons(vec![CreateButton::new("TTS_CONFIG_SERVER") + .label("サーバー設定") + .style(ButtonStyle::Primary)]); + + let mut components = vec![engine_select, server_button]; + + for (index, speaker_chunk) in voicevox_speakers[0..24].chunks(25).enumerate() { + let mut options = Vec::new(); + + for (name, id) in speaker_chunk { + options.push( + CreateSelectMenuOption::new( + name, + format!("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_{}", id), + ) + .default_selection(*id == voicevox_speaker), + ); + } + + components.push(CreateActionRow::SelectMenu( + CreateSelectMenu::new( + format!("TTS_CONFIG_VOICEVOX_SPEAKER_{}", index), + CreateSelectMenuKind::String { options }, + ) + .placeholder("VOICEVOX Speakerを指定"), + )); + } + command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("読み上げ設定") - .components(|c| { - let mut c = c; - c = c - .create_action_row(|a| { - a.create_select_menu(|m| { - m.custom_id("TTS_CONFIG_ENGINE") - .options(|o| { - o.create_option(|co| { - co.label("Google TTS") - .value("TTS_CONFIG_ENGINE_SELECTED_GOOGLE") - .default_selection(tts_type == TTSType::GCP) - }) - .create_option(|co| { - co.label("VOICEVOX") - .value("TTS_CONFIG_ENGINE_SELECTED_VOICEVOX") - .default_selection( - tts_type == TTSType::VOICEVOX, - ) - }) - }) - .placeholder("読み上げAPIを選択") - }) - }) - .create_action_row(|a| { - a.create_button(|f| { - f.label("サーバー設定") - .custom_id("TTS_CONFIG_SERVER") - .style(ButtonStyle::Primary) - }) - }); - - for (index, speaker_chunk) in - voicevox_speakers[0..24].chunks(25).enumerate() - { - c = c.create_action_row(|a| { - let mut a = a; - a = a.create_select_menu(|m| { - m.custom_id( - "TTS_CONFIG_VOICEVOX_SPEAKER_".to_string() - + &index.to_string(), - ) - .options(|o| { - let mut o = o; - for (name, id) in speaker_chunk { - o = o.create_option(|co| { - co.label(name) - .value(format!( - "TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_{}", - id - )) - .default_selection(*id == voicevox_speaker) - }) - } - o - }) - .placeholder("VOICEVOX Speakerを指定") - }); - a - }) - } - - println!("{:?}", c); - c - }) - .flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response( + &ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("読み上げ設定") + .components(components) + .ephemeral(true), + ), + ) .await?; + Ok(()) } diff --git a/src/commands/setup.rs b/src/commands/setup.rs index 21d738c..d7b8e25 100644 --- a/src/commands/setup.rs +++ b/src/commands/setup.rs @@ -1,61 +1,53 @@ use serenity::{ - model::prelude::{ - interaction::{application_command::ApplicationCommandInteraction, MessageFlags}, - UserId, + all::{ + AutoArchiveDuration, CommandInteraction, CreateEmbed, CreateInteractionResponse, CreateInteractionResponseMessage, CreateMessage, CreateThread }, + model::prelude::UserId, prelude::Context, }; +use tracing::info; use crate::{ data::{TTSClientData, TTSData}, tts::instance::TTSInstance, }; +#[tracing::instrument] pub async fn setup_command( ctx: &Context, - command: &ApplicationCommandInteraction, + command: &CommandInteraction, ) -> Result<(), Box> { - println!("Received event"); - if let None = command.guild_id { + info!("Received event"); + + if command.guild_id.is_none() { command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("このコマンドはサーバーでのみ使用可能です.") - .flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("このコマンドはサーバーでのみ使用可能です.") + .ephemeral(true) + )) .await?; return Ok(()); } - println!("Fetching guild cache"); - let guild = command.guild_id.unwrap().to_guild_cached(&ctx.cache); - if let None = guild { - command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("ギルドキャッシュを取得できませんでした.") - .flags(MessageFlags::EPHEMERAL) - }) - }) - .await?; - return Ok(()); - } - let guild = guild.unwrap(); + info!("Fetching guild cache"); + let guild_id = command.guild_id.unwrap(); + let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone(); let channel_id = guild .voice_states - .get(&UserId(command.user.id.0)) + .get(&UserId::from(command.user.id.get())) .and_then(|state| state.channel_id); - if let None = channel_id { + if channel_id.is_none() { command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("ボイスチャンネルに参加してから実行してください.") - .flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("ボイスチャンネルに参加してから実行してください.") + .ephemeral(true) + )) .await?; return Ok(()); } @@ -79,39 +71,34 @@ pub async fn setup_command( let mut storage = storage_lock.write().await; if storage.contains_key(&guild.id) { command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("すでにセットアップしています.") - .flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("すでにセットアップしています.") + .ephemeral(true) + )) .await?; return Ok(()); } let text_channel_id = { if let Some(mode) = command.data.options.get(0) { - let mode = mode.clone(); - let value = mode.value.unwrap(); - let value = value.as_str().unwrap(); - match value { - "TEXT_CHANNEL" => command.channel_id, - "NEW_THREAD" => { - let message = command - .channel_id - .send_message(&ctx.http, |f| f.content("TTS thread")) - .await - .unwrap(); - command - .channel_id - .create_public_thread(&ctx.http, message, |f| { - f.name("TTS").auto_archive_duration(60) - }) - .await - .unwrap() - .id - } - "VOICE_CHANNEL" => channel_id, + match &mode.value { + serenity::all::CommandDataOptionValue::String(value) => { + match value.as_str() { + "TEXT_CHANNEL" => command.channel_id, + "NEW_THREAD" => { + command + .channel_id + .create_thread(&ctx.http, CreateThread::new("TTS").auto_archive_duration(AutoArchiveDuration::OneHour).kind(serenity::all::ChannelType::PublicThread)) + .await + .unwrap() + .id + } + "VOICE_CHANNEL" => channel_id, + _ => channel_id, + } + }, _ => channel_id, } } else { @@ -133,27 +120,37 @@ pub async fn setup_command( }; command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content(format!("TTS Channel: <#{}>{}", text_channel_id, if text_channel_id == channel_id { "\nボイスチャンネルを右クリックし `チャットを開く` を押して開くことが出来ます。" } else { "" })) - }) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(format!( + "TTS Channel: <#{}>{}", + text_channel_id, + if text_channel_id == channel_id { + "\nボイスチャンネルを右クリックし `チャットを開く` を押して開くことが出来ます。" + } else { + "" + } + )) + )) .await?; - let _handler = manager.join(guild.id.0, channel_id.0).await; - let tts_client = ctx + let _handler = manager.join(guild.id, channel_id).await; + + let data = ctx .data .read() - .await + .await; + let tts_client = data .get::() - .expect("Cannot get TTSClientData") - .clone(); - let voicevox_speakers = tts_client.lock().await.1.get_speakers().await; + .expect("Cannot get TTSClientData"); + let voicevox_speakers = tts_client.voicevox_client.get_speakers().await; text_channel_id - .send_message(&ctx.http, |f| { - f.embed(|e| { - e.title("読み上げ (Serenity)") + .send_message(&ctx.http, CreateMessage::new() + .embed( + CreateEmbed::new() + .title("読み上げ (Serenity)") .field( "VOICEVOXクレジット", format!("```\n{}\n```", voicevox_speakers.join("\n")), @@ -161,9 +158,8 @@ pub async fn setup_command( ) .field("設定コマンド", "`/config`", false) .field("フィードバック", "https://feedback.mii.codes/", false) - }) - }) + )) .await?; Ok(()) -} +} \ No newline at end of file diff --git a/src/commands/skip.rs b/src/commands/skip.rs index 7ce2ce4..7e8e159 100644 --- a/src/commands/skip.rs +++ b/src/commands/skip.rs @@ -1,8 +1,8 @@ use serenity::{ - model::prelude::{ - interaction::{application_command::ApplicationCommandInteraction, MessageFlags}, - UserId, + all::{ + CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage }, + model::prelude::UserId, prelude::Context, }; @@ -10,47 +10,36 @@ use crate::data::TTSData; pub async fn skip_command( ctx: &Context, - command: &ApplicationCommandInteraction, + command: &CommandInteraction, ) -> Result<(), Box> { - if let None = command.guild_id { + if command.guild_id.is_none() { command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("このコマンドはサーバーでのみ使用可能です.") - .flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("このコマンドはサーバーでのみ使用可能です.") + .ephemeral(true) + )) .await?; return Ok(()); } - let guild = command.guild_id.unwrap().to_guild_cached(&ctx.cache); - if let None = guild { - command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("ギルドキャッシュを取得できませんでした.") - .flags(MessageFlags::EPHEMERAL) - }) - }) - .await?; - return Ok(()); - } - let guild = guild.unwrap(); + let guild_id = command.guild_id.unwrap(); + let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone(); let channel_id = guild .voice_states - .get(&UserId(command.user.id.0)) + .get(&UserId::from(command.user.id.get())) .and_then(|state| state.channel_id); - if let None = channel_id { + if channel_id.is_none() { command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("ボイスチャンネルに参加してから実行してください.") - .flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("ボイスチャンネルに参加してから実行してください.") + .ephemeral(true) + )) .await?; return Ok(()); } @@ -67,24 +56,26 @@ pub async fn skip_command( let mut storage = storage_lock.write().await; if !storage.contains_key(&guild.id) { command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("読み上げしていません") - .flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("読み上げしていません") + .ephemeral(true) + )) .await?; return Ok(()); } - storage.get_mut(&guild.id).unwrap().skip(&ctx).await; + storage.get_mut(&guild.id).unwrap().skip(ctx).await; } command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| d.content("スキップしました")) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("スキップしました") + )) .await?; Ok(()) -} +} \ No newline at end of file diff --git a/src/commands/stop.rs b/src/commands/stop.rs index 58fd045..6c4e20d 100644 --- a/src/commands/stop.rs +++ b/src/commands/stop.rs @@ -1,56 +1,46 @@ + use serenity::{ - model::prelude::{ - interaction::{application_command::ApplicationCommandInteraction, MessageFlags}, - UserId, + all::{ + CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditThread }, - prelude::Context, + model::prelude::UserId, + prelude::Context }; use crate::data::TTSData; pub async fn stop_command( ctx: &Context, - command: &ApplicationCommandInteraction, + command: &CommandInteraction, ) -> Result<(), Box> { - if let None = command.guild_id { + if command.guild_id.is_none() { command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("このコマンドはサーバーでのみ使用可能です.") - .flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("このコマンドはサーバーでのみ使用可能です.") + .ephemeral(true) + )) .await?; return Ok(()); } - let guild = command.guild_id.unwrap().to_guild_cached(&ctx.cache); - if let None = guild { - command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("ギルドキャッシュを取得できませんでした.") - .flags(MessageFlags::EPHEMERAL) - }) - }) - .await?; - return Ok(()); - } - let guild = guild.unwrap(); + let guild_id = command.guild_id.unwrap(); + let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone(); let channel_id = guild .voice_states - .get(&UserId(command.user.id.0)) + .get(&UserId::from(command.user.id.get())) .and_then(|state| state.channel_id); - if let None = channel_id { + if channel_id.is_none() { command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("ボイスチャンネルに参加してから実行してください.") - .flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("ボイスチャンネルに参加してから実行してください.") + .ephemeral(true) + )) .await?; return Ok(()); } @@ -70,36 +60,37 @@ pub async fn stop_command( let text_channel_id = { let mut storage = storage_lock.write().await; + if !storage.contains_key(&guild.id) { command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("すでに停止しています") - .flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("すでに停止しています") + .ephemeral(true) + )) .await?; return Ok(()); } let text_channel_id = storage.get(&guild.id).unwrap().text_channel; - storage.remove(&guild.id); - text_channel_id }; - let _handler = manager.remove(guild.id.0).await; + let _handler = manager.remove(guild.id).await; command - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| d.content("停止しました")) - }) + .create_response(&ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content("停止しました") + )) .await?; let _ = text_channel_id - .edit_thread(&ctx.http, |f| f.archived(true)) + .edit_thread(&ctx.http, EditThread::new().archived(true)) .await; Ok(()) -} +} \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 63f03be..162e1d6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,5 +6,7 @@ pub struct Config { pub token: String, pub application_id: u64, pub redis_url: String, - pub voicevox_key: String, + pub voicevox_key: Option, + pub voicevox_original_api_url: Option, + pub otel_http_url: Option, } diff --git a/src/data.rs b/src/data.rs index 7ecbc86..2409307 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,9 +1,5 @@ -use crate::{ - database::database::Database, - tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX}, -}; +use crate::{database::database::Database, tts::tts::TTS}; use serenity::{ - futures::lock::Mutex, model::id::GuildId, prelude::{RwLock, TypeMapKey}, }; @@ -22,12 +18,12 @@ impl TypeMapKey for TTSData { pub struct TTSClientData; impl TypeMapKey for TTSClientData { - type Value = Arc>; + type Value = Arc; } /// Database client data pub struct DatabaseClientData; impl TypeMapKey for DatabaseClientData { - type Value = Arc>; + type Value = Arc; } diff --git a/src/database/database.rs b/src/database/database.rs index 0cd6177..c3bbba9 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use crate::tts::{ gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType, }; @@ -5,6 +7,7 @@ use crate::tts::{ use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig}; use redis::Commands; +#[derive(Debug, Clone)] pub struct Database { pub client: redis::Client, } @@ -14,114 +17,116 @@ impl Database { Self { client } } + fn server_key(server_id: u64) -> String { + format!("discord_server:{}", server_id) + } + + fn user_key(user_id: u64) -> String { + format!("discord_user:{}", user_id) + } + + #[tracing::instrument] + fn get_config( + &self, + key: &str, + ) -> redis::RedisResult> { + match self.client.get_connection() { + Ok(mut connection) => { + let config: String = connection.get(key).unwrap_or_default(); + + if config.is_empty() { + return Ok(None); + } + + match serde_json::from_str(&config) { + Ok(config) => Ok(Some(config)), + Err(_) => Ok(None), + } + } + Err(e) => Err(e), + } + } + + #[tracing::instrument] + fn set_config( + &self, + key: &str, + config: &T, + ) -> redis::RedisResult<()> { + match self.client.get_connection() { + Ok(mut connection) => { + let config_str = serde_json::to_string(config).unwrap(); + connection.set::<_, _, ()>(key, config_str) + } + Err(e) => Err(e), + } + } + + #[tracing::instrument] pub async fn get_server_config( - &mut self, + &self, server_id: u64, ) -> redis::RedisResult> { - if let Ok(mut connection) = self.client.get_connection() { - let config: String = connection - .get(format!("discord_server:{}", server_id)) - .unwrap_or_default(); - - match serde_json::from_str(&config) { - Ok(config) => Ok(Some(config)), - Err(_) => Ok(None), - } - } else { - Ok(None) - } + self.get_config(&Self::server_key(server_id)) } - pub async fn get_user_config( - &mut self, - user_id: u64, - ) -> redis::RedisResult> { - if let Ok(mut connection) = self.client.get_connection() { - let config: String = connection - .get(format!("discord_user:{}", user_id)) - .unwrap_or_default(); - - match serde_json::from_str(&config) { - Ok(config) => Ok(Some(config)), - Err(_) => Ok(None), - } - } else { - Ok(None) - } + #[tracing::instrument] + pub async fn get_user_config(&self, user_id: u64) -> redis::RedisResult> { + self.get_config(&Self::user_key(user_id)) } + #[tracing::instrument] pub async fn set_server_config( - &mut self, + &self, server_id: u64, config: ServerConfig, ) -> redis::RedisResult<()> { - let config = serde_json::to_string(&config).unwrap(); - self.client - .get_connection() - .unwrap() - .set::(format!("discord_server:{}", server_id), config) - .unwrap(); - Ok(()) + self.set_config(&Self::server_key(server_id), &config) } + #[tracing::instrument] pub async fn set_user_config( - &mut self, + &self, user_id: u64, config: UserConfig, ) -> redis::RedisResult<()> { - let config = serde_json::to_string(&config).unwrap(); - self.client - .get_connection() - .unwrap() - .set::(format!("discord_user:{}", user_id), config) - .unwrap(); - Ok(()) + self.set_config(&Self::user_key(user_id), &config) } - pub async fn set_default_server_config(&mut self, server_id: u64) -> redis::RedisResult<()> { + #[tracing::instrument] + pub async fn set_default_server_config(&self, server_id: u64) -> redis::RedisResult<()> { let config = ServerConfig { dictionary: Dictionary::new(), autostart_channel_id: None, }; - self.client.get_connection().unwrap().set( - format!("discord_server:{}", server_id), - serde_json::to_string(&config).unwrap(), - )?; - - Ok(()) + self.set_server_config(server_id, config).await } - pub async fn set_default_user_config(&mut self, user_id: u64) -> redis::RedisResult<()> { + #[tracing::instrument] + pub async fn set_default_user_config(&self, user_id: u64) -> redis::RedisResult<()> { let voice_selection = VoiceSelectionParams { languageCode: String::from("ja-JP"), name: String::from("ja-JP-Wavenet-B"), ssmlGender: String::from("neutral"), }; - let voice_type = TTSType::GCP; - let config = UserConfig { - tts_type: Some(voice_type), + tts_type: Some(TTSType::GCP), gcp_tts_voice: Some(voice_selection), voicevox_speaker: Some(1), }; - self.client.get_connection().unwrap().set( - format!("discord_user:{}", user_id), - serde_json::to_string(&config).unwrap(), - )?; - - Ok(()) + self.set_user_config(user_id, config).await } + #[tracing::instrument] pub async fn get_server_config_or_default( - &mut self, + &self, server_id: u64, ) -> redis::RedisResult> { - let config = self.get_server_config(server_id).await?; - match config { - Some(_) => Ok(config), + match self.get_server_config(server_id).await? { + Some(config) => Ok(Some(config)), None => { self.set_default_server_config(server_id).await?; self.get_server_config(server_id).await @@ -129,13 +134,13 @@ impl Database { } } + #[tracing::instrument] pub async fn get_user_config_or_default( - &mut self, + &self, user_id: u64, ) -> redis::RedisResult> { - let config = self.get_user_config(user_id).await?; - match config { - Some(_) => Ok(config), + match self.get_user_config(user_id).await? { + Some(config) => Ok(Some(config)), None => { self.set_default_user_config(user_id).await?; self.get_user_config(user_id).await diff --git a/src/event_handler.rs b/src/event_handler.rs index e6f9589..e986ca2 100644 --- a/src/event_handler.rs +++ b/src/event_handler.rs @@ -8,34 +8,37 @@ use crate::{ tts::tts_type::TTSType, }; use serenity::{ + all::{ + ActionRowComponent, ButtonStyle, ComponentInteractionDataKind, CreateActionRow, + CreateButton, CreateEmbed, CreateInputText, CreateInteractionResponse, + CreateInteractionResponseMessage, CreateModal, CreateSelectMenu, CreateSelectMenuKind, + CreateSelectMenuOption, InputTextStyle, + }, async_trait, client::{Context, EventHandler}, model::{ - channel::Message, - gateway::Ready, - prelude::{ - component::{ActionRowComponent, ButtonStyle, InputTextStyle}, - interaction::{Interaction, InteractionResponseType, MessageFlags}, - ChannelType, - }, + application::Interaction, channel::Message, gateway::Ready, prelude::ChannelType, voice::VoiceState, }, }; +#[derive(Clone, Debug)] pub struct Handler; #[async_trait] impl EventHandler for Handler { + #[tracing::instrument] async fn message(&self, ctx: Context, message: Message) { events::message_receive::message(ctx, message).await } + #[tracing::instrument] async fn ready(&self, ctx: Context, ready: Ready) { events::ready::ready(ctx, ready).await } async fn interaction_create(&self, ctx: Context, interaction: Interaction) { - if let Interaction::ApplicationCommand(command) = interaction.clone() { + if let Interaction::Command(command) = interaction.clone() { let name = &*command.data.name; match name { "setup" => setup_command(&ctx, &command).await.unwrap(), @@ -45,7 +48,7 @@ impl EventHandler for Handler { _ => {} } } - if let Interaction::ModalSubmit(modal) = interaction.clone() { + if let Interaction::Modal(modal) = interaction.clone() { if modal.data.custom_id != "TTS_CONFIG_SERVER_ADD_DICTIONARY" { return; } @@ -53,19 +56,19 @@ impl EventHandler for Handler { let rows = modal.data.components.clone(); let rule_name = if let ActionRowComponent::InputText(text) = rows[0].components[0].clone() { - text.value + text.value.unwrap() } else { panic!("Cannot get rule name"); }; let from = if let ActionRowComponent::InputText(text) = rows[1].components[0].clone() { - text.value + text.value.unwrap() } else { panic!("Cannot get from"); }; let to = if let ActionRowComponent::InputText(text) = rows[2].components[0].clone() { - text.value + text.value.unwrap() } else { panic!("Cannot get to"); }; @@ -84,9 +87,9 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database - .get_server_config_or_default(modal.guild_id.unwrap().0) + .get_server_config_or_default(modal.guild_id.unwrap().get()) .await .unwrap() .unwrap() @@ -98,22 +101,21 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database - .set_server_config(modal.guild_id.unwrap().0, config) + .set_server_config(modal.guild_id.unwrap().get(), config) .await .unwrap(); modal - .create_interaction_response(&ctx.http, |f| { - f.kind(InteractionResponseType::UpdateMessage) - .interaction_response_data(|d| { - d.custom_id("TTS_CONFIG_SERVER_ADD_DICTIONARY_RESPONSE") - .content(format!( - "辞書を追加しました\n名前: {}\n変換元: {}\n変換後: {}", - rule_name, from, to - )) - }) - }) + .create_response( + &ctx.http, + CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new().content(format!( + "辞書を追加しました\n名前: {}\n変換元: {}\n変換後: {}", + rule_name, from, to + )), + ), + ) .await .unwrap(); } @@ -121,7 +123,16 @@ impl EventHandler for Handler { if let Some(message_component) = interaction.message_component() { match &*message_component.data.custom_id { "TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU" => { - let i = usize::from_str_radix(&message_component.data.values[0], 10).unwrap(); + let i = usize::from_str_radix( + &match message_component.data.kind { + ComponentInteractionDataKind::StringSelect { ref values, .. } => { + values[0].clone() + } + _ => panic!("Cannot get index"), + }, + 10, + ) + .unwrap(); let data_read = ctx.data.read().await; let mut config = { @@ -129,9 +140,9 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database - .get_server_config_or_default(message_component.guild_id.unwrap().0) + .get_server_config_or_default(message_component.guild_id.unwrap().get()) .await .unwrap() .unwrap() @@ -143,22 +154,21 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database - .set_server_config(message_component.guild_id.unwrap().0, config) + .set_server_config(message_component.guild_id.unwrap().get(), config) .await .unwrap(); } message_component - .create_interaction_response(&ctx, |f| { - f.kind(InteractionResponseType::UpdateMessage) - .interaction_response_data(|d| { - d.custom_id("DICTIONARY_REMOVED") - .content("辞書を削除しました") - .components(|c| c) - }) - }) + .create_response( + &ctx, + CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new() + .content("辞書を削除しました"), + ), + ) .await .unwrap(); } @@ -170,53 +180,49 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database - .get_server_config_or_default(message_component.guild_id.unwrap().0) + .get_server_config_or_default(message_component.guild_id.unwrap().get()) .await .unwrap() .unwrap() }; message_component - .create_interaction_response(&ctx.http, |f| { - f.kind(InteractionResponseType::UpdateMessage) - .interaction_response_data(|d| { - d.custom_id("TTS_CONFIG_SERVER_REMOVE_DICTIONARY") - .content("削除する辞書内容を選択してください") - .components(|c| { - c.create_action_row(|a| { - a.create_select_menu(|s| { - s.custom_id( - "TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU", - ) - .options(|o| { - let mut o = o; - for (i, rule) in config - .dictionary - .rules - .iter() - .enumerate() - { - o = o.create_option(|c| { - c.label(rule.id.clone()) - .value(i) - .description(format!( - "{} -> {}", - rule.rule.clone(), - rule.to.clone() - )) - }); - } - o - }) - .max_values(1) - .min_values(0) - }) - }) - }) - }) - }) + .create_response( + &ctx.http, + CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new() + .content("削除する辞書内容を選択してください") + .components(vec![CreateActionRow::SelectMenu( + CreateSelectMenu::new( + "TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU", + CreateSelectMenuKind::String { + options: { + let mut options = vec![]; + for (i, rule) in + config.dictionary.rules.iter().enumerate() + { + let option = CreateSelectMenuOption::new( + rule.id.clone(), + i.to_string(), + ) + .description(format!( + "{} -> {}", + rule.rule.clone(), + rule.to.clone() + )); + options.push(option); + } + options + }, + }, + ) + .max_values(1) + .min_values(0), + )]), + ), + ) .await .unwrap(); } @@ -227,80 +233,92 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database - .get_server_config_or_default(message_component.guild_id.unwrap().0) + .get_server_config_or_default(message_component.guild_id.unwrap().get()) .await .unwrap() .unwrap() }; message_component - .create_interaction_response(&ctx.http, |f| { - f.kind(InteractionResponseType::UpdateMessage) - .interaction_response_data(|d| { - d.custom_id("DICTIONARY_LIST").content("").embed(|e| { - e.title("辞書一覧"); + .create_response( + &ctx.http, + CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new().content("").embed( + CreateEmbed::new().title("辞書一覧").fields({ + let mut fields = vec![]; for rule in config.dictionary.rules { - e.field( - rule.id, + let field = ( + rule.id.clone(), format!("{} -> {}", rule.rule, rule.to), true, ); + fields.push(field); } - e - }) - }) - }) + fields + }), + ), + ), + ) .await .unwrap(); } "TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON" => { message_component - .create_interaction_response(&ctx.http, |f| { - f.kind(InteractionResponseType::Modal) - .interaction_response_data(|d| { - d.custom_id("TTS_CONFIG_SERVER_ADD_DICTIONARY") - .title("辞書追加") - .components(|c| { - c.create_action_row(|a| { - a.create_input_text(|i| { - i.style(InputTextStyle::Short) - .label("Rule name") - .custom_id("rule_name") - .required(true) - }) - }) - .create_action_row(|a| { - a.create_input_text(|i| { - i.style(InputTextStyle::Paragraph) - .label("From") - .custom_id("from") - .required(true) - }) - }) - .create_action_row(|a| { - a.create_input_text(|i| { - i.style(InputTextStyle::Short) - .label("To") - .custom_id("to") - .required(true) - }) - }) - }) - }) - }) + .create_response( + &ctx.http, + CreateInteractionResponse::Modal( + CreateModal::new("TTS_CONFIG_SERVER_ADD_DICTIONARY", "辞書追加") + .components({ + vec![ + CreateActionRow::InputText( + CreateInputText::new( + InputTextStyle::Short, + "rule_name", + "辞書名", + ) + .required(true), + ), + CreateActionRow::InputText( + CreateInputText::new( + InputTextStyle::Paragraph, + "from", + "変換元(正規表現)", + ) + .required(true), + ), + CreateActionRow::InputText( + CreateInputText::new( + InputTextStyle::Short, + "to", + "変換先", + ) + .required(true), + ), + ] + }), + ), + ) .await .unwrap(); } "SET_AUTOSTART_CHANNEL" => { - let autostart_channel_id = if message_component.data.values.len() == 0 { - None - } else { - let ch = message_component.data.values[0] - .strip_prefix("SET_AUTOSTART_CHANNEL_") - .unwrap(); - Some(u64::from_str_radix(ch, 10).unwrap()) + let autostart_channel_id = match message_component.data.kind { + ComponentInteractionDataKind::StringSelect { ref values, .. } => { + if values.len() == 0 { + None + } else { + Some( + u64::from_str_radix( + &values[0].strip_prefix("SET_AUTOSTART_CHANNEL_").unwrap(), + 10, + ) + .unwrap(), + ) + } + } + _ => panic!("Cannot get index"), }; { let data_read = ctx.data.read().await; @@ -308,27 +326,27 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + let mut config = database - .get_server_config_or_default(message_component.guild_id.unwrap().0) + .get_server_config_or_default(message_component.guild_id.unwrap().get()) .await .unwrap() .unwrap(); config.autostart_channel_id = autostart_channel_id; database - .set_server_config(message_component.guild_id.unwrap().0, config) + .set_server_config(message_component.guild_id.unwrap().get(), config) .await .unwrap(); }; message_component - .create_interaction_response(&ctx.http, |c| { - c.kind(InteractionResponseType::UpdateMessage) - .interaction_response_data(|d| { - d.content("自動参加チャンネルを設定しました。") - .components(|f| f) - }) - }) + .create_response( + &ctx.http, + CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new() + .content("自動参加チャンネルを設定しました。"), + ), + ) .await .unwrap(); } @@ -339,9 +357,9 @@ impl EventHandler for Handler { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; + database - .get_server_config_or_default(message_component.guild_id.unwrap().0) + .get_server_config_or_default(message_component.guild_id.unwrap().get()) .await .unwrap() .unwrap() @@ -356,166 +374,159 @@ impl EventHandler for Handler { .await .unwrap(); + let mut options = Vec::new(); + for (id, channel) in channels { + if channel.kind != ChannelType::Voice { + continue; + } + + let description = channel + .topic + .unwrap_or_else(|| String::from("No topic provided.")); + let option = CreateSelectMenuOption::new( + &channel.name, + format!("SET_AUTOSTART_CHANNEL_{}", id.get()), + ) + .description(description) + .default_selection(channel.id.get() == autostart_channel_id); + + options.push(option); + } + message_component - .create_interaction_response(&ctx.http, |f| { - f.kind(InteractionResponseType::UpdateMessage) - .interaction_response_data(|d| { - d.custom_id("SET_AUTOSTART_FORM") - .content("自動参加チャンネル設定") - .components(|c| { - c.create_action_row(|a| { - a.create_select_menu(|m| { - m.min_values(0) - .max_values(1) - .disabled(false) - .custom_id("SET_AUTOSTART_CHANNEL") - .options(|o| { - // Create channel list - for (id, channel) in channels { - if channel.kind != ChannelType::Voice { - continue; - } - o.create_option(|co| { - co.label(channel.name) - .description( - channel - .topic - .unwrap_or(String::from( - "No topic provided.", - )), - ) - .value(format!("SET_AUTOSTART_CHANNEL_{}", id.0)) - .default_selection(channel.id.0 == autostart_channel_id) - }); - } - o - }) - }) - }) - }) - }) - }) + .create_response( + &ctx.http, + CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new() + .content("自動参加チャンネル設定") + .components(vec![CreateActionRow::SelectMenu( + CreateSelectMenu::new( + "SET_AUTOSTART_CHANNEL", + CreateSelectMenuKind::String { options }, + ) + .min_values(0) + .max_values(1), + )]), + ), + ) .await .unwrap(); } "TTS_CONFIG_SERVER" => { message_component - .create_interaction_response(&ctx.http, |f| { - f.kind(InteractionResponseType::UpdateMessage) - .interaction_response_data(|d| { - d.content("サーバー設定") - .custom_id("TTS_CONFIG_SERVER") - .components(|c| { - c.create_action_row(|a| { - a.create_button(|b| { - b.custom_id( - "TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON", - ) - .label("辞書を追加") - .style(ButtonStyle::Primary) - }) - .create_button(|b| { - b.custom_id( - "TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON", - ) - .label("辞書を削除") - .style(ButtonStyle::Danger) - }) - .create_button(|b| { - b.custom_id( - "TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON", - ) - .label("辞書一覧") - .style(ButtonStyle::Primary) - }) - .create_button(|b| { - b.custom_id( - "TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL" - ) - .label("自動参加チャンネル") - .style(ButtonStyle::Primary) - }) - }) - }) - }) - }) + .create_response( + &ctx.http, + CreateInteractionResponse::UpdateMessage( + CreateInteractionResponseMessage::new() + .content("サーバー設定") + .components(vec![CreateActionRow::Buttons(vec![ + CreateButton::new( + "TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON", + ) + .label("辞書を追加") + .style(ButtonStyle::Primary), + CreateButton::new( + "TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON", + ) + .label("辞書を削除") + .style(ButtonStyle::Danger), + CreateButton::new( + "TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON", + ) + .label("辞書一覧") + .style(ButtonStyle::Primary), + CreateButton::new( + "TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL", + ) + .label("自動参加チャンネル") + .style(ButtonStyle::Primary), + ])]), + ), + ) .await .unwrap(); } _ => {} } - if let Some(v) = message_component.data.values.get(0) { - let data_read = ctx.data.read().await; + match message_component.data.kind { + ComponentInteractionDataKind::StringSelect { ref values, .. } + if !values.is_empty() => + { + let res = &values[0].clone(); + let data_read = ctx.data.read().await; - let mut config = { - let database = data_read - .get::() - .expect("Cannot get DatabaseClientData") - .clone(); - let mut database = database.lock().await; - database - .get_user_config_or_default(message_component.user.id.0) - .await - .unwrap() - .unwrap() - }; + let mut config = { + let database = data_read + .get::() + .expect("Cannot get DatabaseClientData") + .clone(); - let res = (*v).clone(); - let mut config_changed = false; - let mut voicevox_changed = false; - match &*res { - "TTS_CONFIG_ENGINE_SELECTED_GOOGLE" => { - config.tts_type = Some(TTSType::GCP); - config_changed = true; - } - "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX" => { - config.tts_type = Some(TTSType::VOICEVOX); - config_changed = true; - } - _ => { - if res.starts_with("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_") { - config.voicevox_speaker = Some( - i64::from_str_radix( - &res.replace("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_", ""), - 10, - ) - .unwrap(), - ); + database + .get_user_config_or_default(message_component.user.id.get()) + .await + .unwrap() + .unwrap() + }; + + let mut config_changed = false; + let mut voicevox_changed = false; + + match res.as_str() { + "TTS_CONFIG_ENGINE_SELECTED_GOOGLE" => { + config.tts_type = Some(TTSType::GCP); config_changed = true; - voicevox_changed = true; + } + "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX" => { + config.tts_type = Some(TTSType::VOICEVOX); + config_changed = true; + } + _ => { + if res.starts_with("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_") { + let speaker_id = res + .strip_prefix("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_") + .and_then(|id_str| id_str.parse::().ok()) + .expect("Invalid speaker ID format"); + + config.voicevox_speaker = Some(speaker_id); + config_changed = true; + voicevox_changed = true; + } } } - } - if config_changed { - let database = data_read - .get::() - .expect("Cannot get DatabaseClientData") - .clone(); - let mut database = database.lock().await; - database - .set_user_config(message_component.user.id.0, config.clone()) - .await - .unwrap(); + if config_changed { + let database = data_read + .get::() + .expect("Cannot get DatabaseClientData") + .clone(); + + database + .set_user_config(message_component.user.id.get(), config.clone()) + .await + .unwrap(); + + let response_content = if voicevox_changed + && config.tts_type.unwrap_or(TTSType::GCP) == TTSType::GCP + { + "設定しました\nこの音声を使うにはAPIをGoogleからVOICEVOXに変更する必要があります。" + } else { + "設定しました" + }; - if voicevox_changed && config.tts_type.unwrap_or(TTSType::GCP) == TTSType::GCP { - message_component.create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("設定しました\nこの音声を使うにはAPIをGoogleからVOICEVOXに変更する必要があります。") - .flags(MessageFlags::EPHEMERAL) - }) - }).await.unwrap(); - } else { message_component - .create_interaction_response(&ctx.http, |f| { - f.interaction_response_data(|d| { - d.content("設定しました").flags(MessageFlags::EPHEMERAL) - }) - }) + .create_response( + &ctx.http, + CreateInteractionResponse::Message( + CreateInteractionResponseMessage::new() + .content(response_content) + .ephemeral(true), + ), + ) .await .unwrap(); } } + _ => {} } } } diff --git a/src/events/message_receive.rs b/src/events/message_receive.rs index f0611ad..d80f470 100644 --- a/src/events/message_receive.rs +++ b/src/events/message_receive.rs @@ -31,7 +31,7 @@ pub async fn message(ctx: Context, message: Message) { let instance = storage.get_mut(&guild_id).unwrap(); - if instance.text_channel.0 != message.channel_id.0 { + if instance.text_channel != message.channel_id { return; } diff --git a/src/events/ready.rs b/src/events/ready.rs index 7ba1cd6..43f2e44 100644 --- a/src/events/ready.rs +++ b/src/events/ready.rs @@ -1,32 +1,33 @@ use serenity::{ - model::prelude::{command::Command, Ready}, + all::{Command, CommandOptionType, CreateCommand, CreateCommandOption}, + model::prelude::Ready, prelude::Context, }; +use tracing::info; +#[tracing::instrument] pub async fn ready(ctx: Context, ready: Ready) { - println!("{} is connected!", ready.user.name); + info!("{} is connected!", ready.user.name); - let _ = Command::set_global_application_commands(&ctx.http, |commands| { - commands - .create_application_command(|command| command.name("stop").description("Stop tts")) - .create_application_command(|command| { - command - .name("setup") - .description("Setup tts") - .create_option(|o| { - o.name("mode") - .description("TTS channel") - .add_string_choice("Text Channel", "TEXT_CHANNEL") - .add_string_choice("New Thread", "NEW_THREAD") - .add_string_choice("Voice Channel", "VOICE_CHANNEL") - .kind(serenity::model::prelude::command::CommandOptionType::String) - .required(false) - }) - }) - .create_application_command(|command| command.name("config").description("Config")) - .create_application_command(|command| { - command.name("skip").description("skip tts message") - }) - }) - .await; + Command::set_global_commands( + &ctx.http, + vec![ + CreateCommand::new("stop").description("Stop tts"), + CreateCommand::new("setup") + .description("Setup tts") + .set_options(vec![CreateCommandOption::new( + CommandOptionType::String, + "mode", + "TTS channel", + ) + .add_string_choice("Text Channel", "TEXT_CHANNEL") + .add_string_choice("New Thread", "NEW_THREAD") + .add_string_choice("Voice Channel", "VOICE_CHANNEL") + .required(false)]), + CreateCommand::new("config").description("Config"), + CreateCommand::new("skip").description("skip tts message"), + ], + ) + .await + .unwrap(); } diff --git a/src/events/voice_state_update.rs b/src/events/voice_state_update.rs index b68c13e..d20e1b1 100644 --- a/src/events/voice_state_update.rs +++ b/src/events/voice_state_update.rs @@ -6,7 +6,11 @@ use crate::{ }, tts::{instance::TTSInstance, message::AnnounceMessage}, }; -use serenity::{model::voice::VoiceState, prelude::Context}; +use serenity::{ + all::{CreateEmbed, CreateMessage, EditThread}, + model::voice::VoiceState, + prelude::Context, +}; pub async fn voice_state_update(ctx: Context, old: Option, new: VoiceState) { if new.member.clone().unwrap().user.bot { @@ -37,9 +41,8 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; database - .get_server_config_or_default(guild_id.0) + .get_server_config_or_default(guild_id.get()) .await .unwrap() .unwrap() @@ -49,7 +52,7 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic let mut storage = storage_lock.write().await; if !storage.contains_key(&guild_id) { if let Some(new_channel) = new.channel_id { - if config.autostart_channel_id.unwrap_or(0) == new_channel.0 { + if config.autostart_channel_id.unwrap_or(0) == new_channel.get() { let manager = songbird::get(&ctx) .await .expect("Cannot get songbird client.") @@ -64,29 +67,28 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic }, ); - let _handler = manager.join(guild_id.0, new_channel.0).await; - let tts_client = ctx - .data - .read() - .await + let _handler = manager.join(guild_id, new_channel).await; + let data = ctx.data.read().await; + let tts_client = data .get::() - .expect("Cannot get TTSClientData") - .clone(); - let voicevox_speakers = tts_client.lock().await.1.get_speakers().await; + .expect("Cannot get TTSClientData"); + let voicevox_speakers = tts_client.voicevox_client.get_speakers().await; new_channel - .send_message(&ctx.http, |f| { - f.embed(|e| { - e.title("自動参加 読み上げ (Serenity)") + .send_message( + &ctx.http, + CreateMessage::new().embed( + CreateEmbed::new() + .title("自動参加 読み上げ(Serenity)") .field( "VOICEVOXクレジット", format!("```\n{}\n```", voicevox_speakers.join("\n")), false, ) .field("設定コマンド", "`/config`", false) - .field("フィードバック", "https://feedback.mii.codes/", false) - }) - }) + .field("フィードバック", "https://feedback.mii.codes/", false), + ), + ) .await .unwrap(); } @@ -118,7 +120,10 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic let mut del_flag = false; for channel in guild_id.channels(&ctx.http).await.unwrap() { if channel.0 == instance.voice_channel { - del_flag = channel.1.members(&ctx.cache).await.unwrap().len() <= 1; + let members = channel.1.members(&ctx.cache).unwrap(); + let user_count = members.iter().filter(|member| !member.user.bot).count(); + + del_flag = user_count == 0; } } @@ -127,7 +132,7 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic .get(&guild_id) .unwrap() .text_channel - .edit_thread(&ctx.http, |f| f.archived(true)) + .edit_thread(&ctx.http, EditThread::new().archived(true)) .await; storage.remove(&guild_id); @@ -136,7 +141,7 @@ pub async fn voice_state_update(ctx: Context, old: Option, new: Voic .expect("Cannot get songbird client.") .clone(); - manager.remove(guild_id.0).await.unwrap(); + manager.remove(guild_id).await.unwrap(); } } } diff --git a/src/implement/member_name.rs b/src/implement/member_name.rs index e41b937..898dca1 100644 --- a/src/implement/member_name.rs +++ b/src/implement/member_name.rs @@ -1,4 +1,7 @@ -use serenity::model::guild::Member; +use serenity::model::{ + guild::{Member, PartialMember}, + user::User, +}; pub trait ReadName { fn read_name(&self) -> String; @@ -6,6 +9,20 @@ pub trait ReadName { impl ReadName for Member { fn read_name(&self) -> String { - self.nick.clone().unwrap_or(self.user.name.clone()) + self.nick.clone().unwrap_or(self.display_name().to_string()) + } +} + +impl ReadName for PartialMember { + fn read_name(&self) -> String { + self.nick + .clone() + .unwrap_or(self.user.as_ref().unwrap().display_name().to_string()) + } +} + +impl ReadName for User { + fn read_name(&self) -> String { + self.display_name().to_string() } } diff --git a/src/implement/message.rs b/src/implement/message.rs index faed4ab..6e93d1f 100644 --- a/src/implement/message.rs +++ b/src/implement/message.rs @@ -1,11 +1,11 @@ -use std::{env, fs::File, io::Write}; - use async_trait::async_trait; use regex::Regex; use serenity::{model::prelude::Message, prelude::Context}; +use songbird::tracks::Track; use crate::{ data::{DatabaseClientData, TTSClientData}, + implement::member_name::ReadName, tts::{ gcp_tts::structs::{ audio_config::AudioConfig, synthesis_input::SynthesisInput, @@ -27,9 +27,8 @@ impl TTSMessage for Message { .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; database - .get_server_config_or_default(instance.guild.0) + .get_server_config_or_default(instance.guild.get()) .await .unwrap() .unwrap() @@ -48,19 +47,29 @@ impl TTSMessage for Message { text.clone() } else { let member = self.member.clone(); - let name = if let Some(member) = member { - member.nick.unwrap_or(self.author.name.clone()) + let name = if let Some(_) = member { + let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone(); + guild + .member(&ctx.http, self.author.id) + .await + .unwrap() + .read_name() } else { - self.author.name.clone() + self.author.read_name() }; format!("{}さんの発言{}", name, text) } } else { let member = self.member.clone(); - let name = if let Some(member) = member { - member.nick.unwrap_or(self.author.name.clone()) + let name = if let Some(_) = member { + let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone(); + guild + .member(&ctx.http, self.author.id) + .await + .unwrap() + .read_name() } else { - self.author.name.clone() + self.author.read_name() }; format!("{}さんの発言{}", name, text) }; @@ -78,33 +87,30 @@ impl TTSMessage for Message { res } - async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String { + async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec { let text = self.parse(instance, ctx).await; let data_read = ctx.data.read().await; - let storage = data_read - .get::() - .expect("Cannot get GCP TTSClientStorage") - .clone(); - let mut tts = storage.lock().await; let config = { let database = data_read .get::() .expect("Cannot get DatabaseClientData") .clone(); - let mut database = database.lock().await; database - .get_user_config_or_default(self.author.id.0) + .get_user_config_or_default(self.author.id.get()) .await .unwrap() .unwrap() }; - let audio = match config.tts_type.unwrap_or(TTSType::GCP) { - TTSType::GCP => tts - .0 - .synthesize(SynthesizeRequest { + let tts = data_read + .get::() + .expect("Cannot get GCP TTSClientStorage"); + + match config.tts_type.unwrap_or(TTSType::GCP) { + TTSType::GCP => vec![tts + .synthesize_gcp(SynthesizeRequest { input: SynthesisInput { text: None, ssml: Some(format!("{}", text)), @@ -117,26 +123,17 @@ impl TTSMessage for Message { }, }) .await - .unwrap(), + .unwrap() + .into()], - TTSType::VOICEVOX => tts - .1 - .synthesize( - text.replace("", "、"), + TTSType::VOICEVOX => vec![tts + .synthesize_voicevox( + &text.replace("", "、"), config.voicevox_speaker.unwrap_or(1), ) .await - .unwrap(), - }; - - let uuid = uuid::Uuid::new_v4().to_string(); - - let path = env::current_dir().unwrap(); - let file_path = path.join("audio").join(format!("{}.mp3", uuid)); - - let mut file = File::create(file_path.clone()).unwrap(); - file.write(&audio).unwrap(); - - file_path.into_os_string().into_string().unwrap() + .unwrap() + .into()], + } } } diff --git a/src/implement/voice_move_state.rs b/src/implement/voice_move_state.rs index 2b65f51..aa6ed41 100644 --- a/src/implement/voice_move_state.rs +++ b/src/implement/voice_move_state.rs @@ -29,12 +29,10 @@ impl VoiceMoveStateTrait for VoiceState { (Some(old_channel_id), Some(new_channel_id)) => { if old_channel_id == new_channel_id { VoiceMoveState::NONE - } else if old_channel_id != new_channel_id { - if target_channel == new_channel_id { - VoiceMoveState::JOIN - } else { - VoiceMoveState::NONE - } + } else if old_channel_id == target_channel { + VoiceMoveState::LEAVE + } else if new_channel_id == target_channel { + VoiceMoveState::JOIN } else { VoiceMoveState::NONE } diff --git a/src/main.rs b/src/main.rs index b0d2e6d..b379d8e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,8 @@ mod database; mod event_handler; mod events; mod implement; +mod stream_input; +mod trace; mod tts; use std::{collections::HashMap, env, sync::Arc}; @@ -13,13 +15,16 @@ use config::Config; use data::{DatabaseClientData, TTSClientData, TTSData}; use database::database::Database; use event_handler::Handler; +#[allow(deprecated)] use serenity::{ + all::{standard::Configuration, ApplicationId}, client::Client, framework::StandardFramework, - futures::lock::Mutex, prelude::{GatewayIntents, RwLock}, }; -use tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX}; +use trace::init_tracing_subscriber; +use tracing::info; +use tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX}; use songbird::SerenityInit; @@ -31,12 +36,14 @@ use songbird::SerenityInit; /// /// client.start().await; /// ``` +#[allow(deprecated)] async fn create_client(prefix: &str, token: &str, id: u64) -> Result { - let framework = StandardFramework::new().configure(|c| c.with_whitespace(true).prefix(prefix)); + let framework = StandardFramework::new(); + framework.configure(Configuration::new().with_whitespace(true).prefix(prefix)); Client::builder(token, GatewayIntents::all()) .event_handler(Handler) - .application_id(id) + .application_id(ApplicationId::new(id)) .framework(framework) .register_songbird() .await @@ -54,7 +61,18 @@ async fn main() { let application_id = env::var("NCB_APP_ID").unwrap(); let prefix = env::var("NCB_PREFIX").unwrap(); let redis_url = env::var("NCB_REDIS_URL").unwrap(); - let voicevox_key = env::var("NCB_VOICEVOX_KEY").unwrap(); + let voicevox_key = match env::var("NCB_VOICEVOX_KEY") { + Ok(key) => Some(key), + Err(_) => None, + }; + let voicevox_original_api_url = match env::var("NCB_VOICEVOX_ORIGINAL_API_URL") { + Ok(url) => Some(url), + Err(_) => None, + }; + let otel_http_url = match env::var("NCB_OTEL_HTTP_URL") { + Ok(url) => Some(url), + Err(_) => None, + }; Config { token, @@ -62,22 +80,26 @@ async fn main() { prefix, redis_url, voicevox_key, + voicevox_original_api_url, + otel_http_url, } } }; + let _guard = init_tracing_subscriber(&config.otel_http_url); + // Create discord client let mut client = create_client(&config.prefix, &config.token, config.application_id) .await .expect("Err creating client"); // Create GCP TTS client - let tts = match TTS::new("./credentials.json".to_string()).await { + let tts = match GCPTTS::new("./credentials.json".to_string()).await { Ok(tts) => tts, Err(err) => panic!("GCP init error: {}", err), }; - let voicevox = VOICEVOX::new(config.voicevox_key); + let voicevox = VOICEVOX::new(config.voicevox_key, config.voicevox_original_api_url); let database_client = { let redis_client = redis::Client::open(config.redis_url).unwrap(); @@ -88,10 +110,12 @@ async fn main() { { let mut data = client.data.write().await; data.insert::(Arc::new(RwLock::new(HashMap::default()))); - data.insert::(Arc::new(Mutex::new((tts, voicevox)))); - data.insert::(Arc::new(Mutex::new(database_client))); + data.insert::(Arc::new(TTS::new(voicevox, tts))); + data.insert::(Arc::new(database_client)); } + info!("Bot initialized."); + // Run client if let Err(why) = client.start().await { println!("Client error: {:?}", why); diff --git a/src/stream_input.rs b/src/stream_input.rs new file mode 100644 index 0000000..ab690e1 --- /dev/null +++ b/src/stream_input.rs @@ -0,0 +1,93 @@ +use async_trait::async_trait; +use futures::TryStreamExt; +use reqwest::{header::HeaderMap, Client}; +use symphonia_core::{io::MediaSource, probe::Hint}; +use tokio_util::compat::FuturesAsyncReadCompatExt; + +use songbird::input::{ + AsyncAdapterStream, AsyncReadOnlySource, AudioStream, AudioStreamError, Compose, Input, +}; + +#[derive(Debug, Clone)] +pub struct Mp3Request { + client: Client, + request: String, + headers: HeaderMap, +} + +impl Mp3Request { + #[must_use] + pub fn new(client: Client, request: String) -> Self { + Self::new_with_headers(client, request, HeaderMap::default()) + } + + #[must_use] + pub fn new_with_headers(client: Client, request: String, headers: HeaderMap) -> Self { + Mp3Request { + client, + request, + headers, + } + } + + async fn create_stream_async(&self) -> Result { + let request = self + .client + .get(&self.request) + .headers(self.headers.clone()) + .build() + .map_err(|why| AudioStreamError::Fail(why.into()))?; + + let response = self + .client + .execute(request) + .await + .map_err(|why| AudioStreamError::Fail(why.into()))?; + + if !response.status().is_success() { + return Err(AudioStreamError::Fail( + format!("HTTP error: {}", response.status()).into(), + )); + } + + let byte_stream = response + .bytes_stream() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string())); + + let tokio_reader = byte_stream.into_async_read().compat(); + + Ok(AsyncReadOnlySource::new(tokio_reader)) + } +} + +#[async_trait] +impl Compose for Mp3Request { + fn create(&mut self) -> Result>, AudioStreamError> { + Err(AudioStreamError::Fail( + "Mp3Request::create must be called in an async context via create_async".into(), + )) + } + + async fn create_async( + &mut self, + ) -> Result>, AudioStreamError> { + let input = self.create_stream_async().await?; + let stream = AsyncAdapterStream::new(Box::new(input), 64 * 1024); + + let hint = Hint::new().with_extension("mp3").clone(); + Ok(AudioStream { + input: Box::new(stream) as Box, + hint: Some(hint), + }) + } + + fn should_create_async(&self) -> bool { + true + } +} + +impl From for Input { + fn from(val: Mp3Request) -> Self { + Input::Lazy(Box::new(val)) + } +} diff --git a/src/trace.rs b/src/trace.rs new file mode 100644 index 0000000..5ab3ad7 --- /dev/null +++ b/src/trace.rs @@ -0,0 +1,128 @@ +use opentelemetry::{ + global, + trace::{SamplingDecision, SamplingResult, TraceContextExt, TraceState, TracerProvider as _}, + KeyValue, +}; +use opentelemetry_otlp::{Protocol, WithExportConfig}; +use opentelemetry_sdk::{ + metrics::{MeterProviderBuilder, PeriodicReader, SdkMeterProvider}, + trace::{RandomIdGenerator, SdkTracerProvider, ShouldSample}, + Resource, +}; +use tracing::Level; +use tracing_opentelemetry::{MetricsLayer, OpenTelemetryLayer}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[derive(Debug, Clone)] +struct FilterSampler; + +impl ShouldSample for FilterSampler { + fn should_sample( + &self, + parent_context: Option<&opentelemetry::Context>, + _trace_id: opentelemetry::TraceId, + name: &str, + _span_kind: &opentelemetry::trace::SpanKind, + _attributes: &[KeyValue], + _links: &[opentelemetry::trace::Link], + ) -> opentelemetry::trace::SamplingResult { + let decision = if name == "dispatch" || name == "recv_event" { + SamplingDecision::Drop + } else { + SamplingDecision::RecordAndSample + }; + + SamplingResult { + decision, + attributes: vec![], + trace_state: match parent_context { + Some(ctx) => ctx.span().span_context().trace_state().clone(), + None => TraceState::default(), + }, + } + } +} + +fn resource() -> Resource { + Resource::builder().with_service_name("ncb-tts-r2").build() +} + +fn init_meter_provider(url: &str) -> SdkMeterProvider { + let exporter = opentelemetry_otlp::MetricExporter::builder() + .with_http() + .with_endpoint(url) + .with_protocol(Protocol::HttpBinary) + .with_temporality(opentelemetry_sdk::metrics::Temporality::default()) + .build() + .unwrap(); + + let reader = PeriodicReader::builder(exporter) + .with_interval(std::time::Duration::from_secs(5)) + .build(); + + let stdout_reader = + PeriodicReader::builder(opentelemetry_stdout::MetricExporter::default()).build(); + + let meter_provider = MeterProviderBuilder::default() + .with_resource(resource()) + .with_reader(reader) + .with_reader(stdout_reader) + .build(); + + global::set_meter_provider(meter_provider.clone()); + + meter_provider +} + +fn init_tracer_provider(url: &str) -> SdkTracerProvider { + let exporter = opentelemetry_otlp::SpanExporter::builder() + .with_http() + .with_endpoint(url) + .with_protocol(Protocol::HttpBinary) + .build() + .unwrap(); + + SdkTracerProvider::builder() + .with_sampler(FilterSampler) + .with_id_generator(RandomIdGenerator::default()) + .with_resource(resource()) + .with_batch_exporter(exporter) + .build() +} + +pub fn init_tracing_subscriber(otel_http_url: &Option) -> OtelGuard { + let registry = tracing_subscriber::registry() + .with(tracing_subscriber::filter::LevelFilter::from_level( + Level::INFO, + )) + .with(tracing_subscriber::fmt::layer()); + + if let Some(url) = otel_http_url { + let tracer_provider = init_tracer_provider(url); + let meter_provider = init_meter_provider(url); + + let tracer = tracer_provider.tracer("ncb-tts-r2"); + + registry + .with(MetricsLayer::new(meter_provider.clone())) + .with(OpenTelemetryLayer::new(tracer)) + .init(); + + OtelGuard { + _tracer_provider: Some(tracer_provider), + _meter_provider: Some(meter_provider), + } + } else { + registry.init(); + + OtelGuard { + _tracer_provider: None, + _meter_provider: None, + } + } +} + +pub struct OtelGuard { + _tracer_provider: Option, + _meter_provider: Option, +} diff --git a/src/tts/gcp_tts/gcp_tts.rs b/src/tts/gcp_tts/gcp_tts.rs index acf9d33..2cfe878 100644 --- a/src/tts/gcp_tts/gcp_tts.rs +++ b/src/tts/gcp_tts/gcp_tts.rs @@ -2,35 +2,40 @@ use crate::tts::gcp_tts::structs::{ synthesize_request::SynthesizeRequest, synthesize_response::SynthesizeResponse, }; use gcp_auth::Token; +use std::sync::Arc; +use tokio::sync::RwLock; -#[derive(Clone)] -pub struct TTS { - pub token: Token, +#[derive(Clone, Debug)] +pub struct GCPTTS { + pub token: Arc>, pub credentials_path: String, } -impl TTS { - pub async fn update_token(&mut self) -> Result<(), gcp_auth::Error> { - if self.token.has_expired() { +impl GCPTTS { + #[tracing::instrument] + pub async fn update_token(&self) -> Result<(), gcp_auth::Error> { + let mut token = self.token.write().await; + if token.has_expired() { let authenticator = gcp_auth::from_credentials_file(self.credentials_path.clone()).await?; - let token = authenticator + let new_token = authenticator .get_token(&["https://www.googleapis.com/auth/cloud-platform"]) .await?; - self.token = token; + *token = new_token; } Ok(()) } - pub async fn new(credentials_path: String) -> Result { + #[tracing::instrument] + pub async fn new(credentials_path: String) -> Result { let authenticator = gcp_auth::from_credentials_file(credentials_path.clone()).await?; let token = authenticator .get_token(&["https://www.googleapis.com/auth/cloud-platform"]) .await?; - Ok(TTS { - token, + Ok(Self { + token: Arc::new(RwLock::new(token)), credentials_path, }) } @@ -56,18 +61,25 @@ impl TTS { /// } /// }).await.unwrap(); /// ``` + #[tracing::instrument] pub async fn synthesize( - &mut self, + &self, request: SynthesizeRequest, ) -> Result, Box> { self.update_token().await.unwrap(); let client = reqwest::Client::new(); + + let token_string = { + let token = self.token.read().await; + token.as_str().to_string() + }; + match client .post("https://texttospeech.googleapis.com/v1/text:synthesize") .header(reqwest::header::CONTENT_TYPE, "application/json") .header( reqwest::header::AUTHORIZATION, - format!("Bearer {}", self.token.as_str()), + format!("Bearer {}", token_string), ) .body(serde_json::to_string(&request).unwrap()) .send() diff --git a/src/tts/gcp_tts/structs/synthesis_input.rs b/src/tts/gcp_tts/structs/synthesis_input.rs index d7a1464..99d0a41 100644 --- a/src/tts/gcp_tts/structs/synthesis_input.rs +++ b/src/tts/gcp_tts/structs/synthesis_input.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; /// ssml: Some(String::from("test")) /// } /// ``` -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)] pub struct SynthesisInput { pub text: Option, pub ssml: Option, diff --git a/src/tts/gcp_tts/structs/voice_selection_params.rs b/src/tts/gcp_tts/structs/voice_selection_params.rs index 37c78bd..442e94c 100644 --- a/src/tts/gcp_tts/structs/voice_selection_params.rs +++ b/src/tts/gcp_tts/structs/voice_selection_params.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; /// ssmlGender: String::from("neutral") /// } /// ``` -#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)] +#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] #[allow(non_snake_case)] pub struct VoiceSelectionParams { pub languageCode: String, diff --git a/src/tts/instance.rs b/src/tts/instance.rs index a1ef5c0..7dd9142 100644 --- a/src/tts/instance.rs +++ b/src/tts/instance.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use serenity::{ model::{ channel::Message, @@ -8,6 +10,7 @@ use serenity::{ use crate::tts::message::TTSMessage; +#[derive(Debug, Clone)] pub struct TTSInstance { pub before_message: Option, pub text_channel: ChannelId, @@ -22,23 +25,24 @@ impl TTSInstance { /// ```rust /// instance.read(message, &ctx).await; /// ``` + #[tracing::instrument] pub async fn read(&mut self, message: T, ctx: &Context) where - T: TTSMessage, + T: TTSMessage + Debug, { - let path = message.synthesize(self, ctx).await; + let audio = message.synthesize(self, ctx).await; { let manager = songbird::get(&ctx).await.unwrap(); let call = manager.get(self.guild).unwrap(); let mut call = call.lock().await; - let input = songbird::input::ffmpeg(path) - .await - .expect("File not found."); - call.enqueue_source(input); + for audio in audio { + call.enqueue(audio.into()).await; + } } } + #[tracing::instrument] pub async fn skip(&mut self, ctx: &Context) { let manager = songbird::get(&ctx).await.unwrap(); let call = manager.get(self.guild).unwrap(); diff --git a/src/tts/message.rs b/src/tts/message.rs index 59b4fcd..bd3f815 100644 --- a/src/tts/message.rs +++ b/src/tts/message.rs @@ -1,7 +1,6 @@ -use std::{env, fs::File, io::Write}; - use async_trait::async_trait; use serenity::prelude::Context; +use songbird::tracks::Track; use crate::{data::TTSClientData, tts::instance::TTSInstance}; @@ -21,15 +20,16 @@ pub trait TTSMessage { /// ``` async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String; - /// Synthesize the message and returns the path to the audio file. + /// Synthesize the message and returns the audio data. /// /// Example: /// ```rust - /// let path = message.synthesize(instance, ctx).await; + /// let audio = message.synthesize(instance, ctx).await; /// ``` - async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String; + async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec; } +#[derive(Debug, Clone)] pub struct AnnounceMessage { pub message: String, } @@ -44,18 +44,15 @@ impl TTSMessage for AnnounceMessage { ) } - async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String { + async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec { let text = self.parse(instance, ctx).await; let data_read = ctx.data.read().await; - let storage = data_read + let tts = data_read .get::() - .expect("Cannot get TTSClientStorage") - .clone(); - let mut storage = storage.lock().await; + .expect("Cannot get TTSClientStorage"); - let audio = storage - .0 - .synthesize(SynthesizeRequest { + let audio = tts + .synthesize_gcp(SynthesizeRequest { input: SynthesisInput { text: None, ssml: Some(text), @@ -74,14 +71,6 @@ impl TTSMessage for AnnounceMessage { .await .unwrap(); - let uuid = uuid::Uuid::new_v4().to_string(); - - let path = env::current_dir().unwrap(); - let file_path = path.join("audio").join(format!("{}.mp3", uuid)); - - let mut file = File::create(file_path.clone()).unwrap(); - file.write(&audio).unwrap(); - - file_path.into_os_string().into_string().unwrap() + vec![audio.into()] } } diff --git a/src/tts/mod.rs b/src/tts/mod.rs index 0ad87a4..6743cc8 100644 --- a/src/tts/mod.rs +++ b/src/tts/mod.rs @@ -1,5 +1,6 @@ pub mod gcp_tts; pub mod instance; pub mod message; +pub mod tts; pub mod tts_type; pub mod voicevox; diff --git a/src/tts/tts.rs b/src/tts/tts.rs new file mode 100644 index 0000000..9dac21c --- /dev/null +++ b/src/tts/tts.rs @@ -0,0 +1,133 @@ +use std::sync::RwLock; +use std::{num::NonZeroUsize, sync::Arc}; + +use lru::LruCache; +use songbird::{driver::Bitrate, input::cached::Compressed, tracks::Track}; +use tracing::info; + +use super::{ + gcp_tts::{ + gcp_tts::GCPTTS, + structs::{ + synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest, + voice_selection_params::VoiceSelectionParams, + }, + }, + voicevox::voicevox::VOICEVOX, +}; + +#[derive(Debug)] +pub struct TTS { + pub voicevox_client: VOICEVOX, + gcp_tts_client: GCPTTS, + cache: Arc>>, +} + +#[derive(Hash, PartialEq, Eq)] +pub enum CacheKey { + Voicevox(String, i64), + GCP(SynthesisInput, VoiceSelectionParams), +} + +impl TTS { + pub fn new(voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS) -> Self { + Self { + voicevox_client, + gcp_tts_client, + cache: Arc::new(RwLock::new(LruCache::new(NonZeroUsize::new(1000).unwrap()))), + } + } + + #[tracing::instrument] + pub async fn synthesize_voicevox( + &self, + text: &str, + speaker: i64, + ) -> Result> { + let cache_key = CacheKey::Voicevox(text.to_string(), speaker); + + let cached_audio = { + let mut cache_guard = self.cache.write().unwrap(); + cache_guard.get(&cache_key).map(|audio| audio.new_handle()) + }; + + if let Some(audio) = cached_audio { + info!("Cache hit for VOICEVOX TTS"); + return Ok(audio.into()); + } + + info!("Cache miss for VOICEVOX TTS"); + + if self.voicevox_client.original_api_url.is_some() { + let audio = self + .voicevox_client + .synthesize_original(text.to_string(), speaker) + .await?; + + tokio::spawn({ + let cache = self.cache.clone(); + let audio = audio.clone(); + async move { + info!("Compressing stream audio"); + let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap(); + let mut cache_guard = cache.write().unwrap(); + cache_guard.put(cache_key, compressed.clone()); + } + }); + + Ok(audio.into()) + } else { + let audio = self + .voicevox_client + .synthesize_stream(text.to_string(), speaker) + .await?; + + tokio::spawn({ + let cache = self.cache.clone(); + let audio = audio.clone(); + async move { + info!("Compressing stream audio"); + let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap(); + let mut cache_guard = cache.write().unwrap(); + cache_guard.put(cache_key, compressed.clone()); + } + }); + + Ok(audio.into()) + } + } + + #[tracing::instrument] + pub async fn synthesize_gcp( + &self, + synthesize_request: SynthesizeRequest, + ) -> Result> { + let cache_key = CacheKey::GCP( + synthesize_request.input.clone(), + synthesize_request.voice.clone(), + ); + + let cached_audio = { + let mut cache_guard = self.cache.write().unwrap(); + cache_guard.get(&cache_key).map(|audio| audio.new_handle()) + }; + + if let Some(audio) = cached_audio { + info!("Cache hit for GCP TTS"); + return Ok(audio); + } + + info!("Cache miss for GCP TTS"); + + let audio = self.gcp_tts_client.synthesize(synthesize_request).await?; + + let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?; + + { + let mut cache_guard = self.cache.write().unwrap(); + cache_guard.put(cache_key, compressed.clone()); + } + + Ok(compressed) + } +} diff --git a/src/tts/voicevox/structs/mod.rs b/src/tts/voicevox/structs/mod.rs index 3c24708..598d8f6 100644 --- a/src/tts/voicevox/structs/mod.rs +++ b/src/tts/voicevox/structs/mod.rs @@ -2,3 +2,4 @@ pub mod accent_phrase; pub mod audio_query; pub mod mora; pub mod speaker; +pub mod stream; diff --git a/src/tts/voicevox/structs/stream.rs b/src/tts/voicevox/structs/stream.rs new file mode 100644 index 0000000..840361f --- /dev/null +++ b/src/tts/voicevox/structs/stream.rs @@ -0,0 +1,13 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TTSResponse { + pub success: bool, + pub is_api_key_valid: bool, + pub speaker_name: String, + pub audio_status_url: String, + pub wav_download_url: String, + pub mp3_download_url: String, + pub mp3_streaming_url: String, +} diff --git a/src/tts/voicevox/voicevox.rs b/src/tts/voicevox/voicevox.rs index d9fbbb7..115629e 100644 --- a/src/tts/voicevox/voicevox.rs +++ b/src/tts/voicevox/voicevox.rs @@ -1,13 +1,17 @@ -use super::structs::speaker::Speaker; +use crate::stream_input::Mp3Request; + +use super::structs::{speaker::Speaker, stream::TTSResponse}; const BASE_API_URL: &str = "https://deprecatedapis.tts.quest/v2/"; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct VOICEVOX { - pub key: String, + pub key: Option, + pub original_api_url: Option, } impl VOICEVOX { + #[tracing::instrument] pub async fn get_styles(&self) -> Vec<(String, i64)> { let speakers = self.get_speaker_list().await; let mut speaker_list = vec![]; @@ -20,6 +24,7 @@ impl VOICEVOX { speaker_list } + #[tracing::instrument] pub async fn get_speakers(&self) -> Vec { let speakers = self.get_speaker_list().await; let mut speaker_list = vec![]; @@ -30,18 +35,27 @@ impl VOICEVOX { speaker_list } - pub fn new(key: String) -> Self { - Self { key } + pub fn new(key: Option, original_api_url: Option) -> Self { + Self { + key, + original_api_url, + } } + #[tracing::instrument] async fn get_speaker_list(&self) -> Vec { let client = reqwest::Client::new(); - match client - .post(BASE_API_URL.to_string() + "voicevox/speakers/") - .query(&[("key", self.key.clone())]) - .send() - .await - { + let client = if let Some(key) = &self.key { + client + .get(BASE_API_URL.to_string() + "voicevox/speakers/") + .query(&[("key", key)]) + } else if let Some(original_api_url) = &self.original_api_url { + client.get(original_api_url.to_string() + "/speakers") + } else { + panic!("No API key or original API URL provided.") + }; + + match client.send().await { Ok(response) => response.json().await.unwrap(), Err(err) => { panic!("Cannot get speaker list. {err:?}") @@ -49,6 +63,7 @@ impl VOICEVOX { } } + #[tracing::instrument] pub async fn synthesize( &self, text: String, @@ -60,7 +75,7 @@ impl VOICEVOX { .query(&[ ("speaker", speaker.to_string()), ("text", text), - ("key", self.key.clone()), + ("key", self.key.clone().unwrap()), ]) .send() .await @@ -72,4 +87,47 @@ impl VOICEVOX { Err(err) => Err(Box::new(err)), } } + + #[tracing::instrument] + pub async fn synthesize_original( + &self, + text: String, + speaker: i64, + ) -> Result, Box> { + let client = + voicevox_client::Client::new(self.original_api_url.as_ref().unwrap().clone(), None); + let audio_query = client + .create_audio_query(&text, speaker as i32, None) + .await?; + println!("{:?}", audio_query.audio_query); + let audio = audio_query.synthesis(speaker as i32, true).await?; + Ok(audio.into()) + } + + #[tracing::instrument] + pub async fn synthesize_stream( + &self, + text: String, + speaker: i64, + ) -> Result> { + let client = reqwest::Client::new(); + match client + .post("https://api.tts.quest/v3/voicevox/synthesis") + .query(&[ + ("speaker", speaker.to_string()), + ("text", text), + ("key", self.key.clone().unwrap()), + ]) + .send() + .await + { + Ok(response) => { + let body = response.text().await.unwrap(); + let response: TTSResponse = serde_json::from_str(&body).unwrap(); + + Ok(Mp3Request::new(reqwest::Client::new(), response.mp3_streaming_url).into()) + } + Err(err) => Err(Box::new(err)), + } + } }