Merge pull request #1 from mii443/stable

Stable
This commit is contained in:
mii443
2025-05-24 21:54:25 +09:00
committed by GitHub
32 changed files with 1255 additions and 773 deletions

View File

@ -28,7 +28,7 @@ jobs:
platforms: linux/amd64,linux/arm64
- name: Cache Docker layers
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}

View File

@ -1,6 +1,6 @@
[package]
name = "ncb-tts-r2"
version = "1.7.0"
version = "1.10.1"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@ -13,18 +13,51 @@ gcp_auth = "0.12.3"
reqwest = { version = "0.12.9", features = ["json"] }
base64 = "0.22.1"
async-trait = "0.1.57"
redis = "*"
redis = "0.29.2"
regex = "1"
poise = "0.6.1"
tracing-subscriber = "0.3.19"
lru = "0.13.0"
tracing = "0.1.41"
opentelemetry_sdk = { version = "0.29.0", features = ["trace"] }
opentelemetry = "0.29.1"
opentelemetry-semantic-conventions = "0.29.0"
opentelemetry-otlp = { version = "0.29.0", features = ["grpc-tonic"] }
opentelemetry-stdout = "0.29.0"
tracing-opentelemetry = "0.30.0"
symphonia-core = "0.5.4"
tokio-util = { version = "0.7.14", features = ["compat"] }
futures = "0.3.31"
bytes = "1.10.1"
voicevox-client = { git = "https://github.com/mii443/rust" }
[dependencies.uuid]
version = "1.11.0"
features = ["serde", "v4"]
[dependencies.songbird]
version = "0.4.4"
version = "0.5"
features = ["builtin-queue"]
[dependencies.symphonia]
version = "0.5"
features = ["mp3"]
[dependencies.serenity]
version = "0.12"
features = [
"builder",
"cache",
"client",
"gateway",
"model",
"utils",
"unstable_discord_api",
"collector",
"rustls_backend",
"framework",
"voice",
]
[dependencies.tokio]
version = "1.0"
features = ["macros", "rt-multi-thread"]

View File

@ -1,4 +1,4 @@
FROM lukemathwalker/cargo-chef:latest-rust-1.72 AS chef
FROM lukemathwalker/cargo-chef:latest-rust-1.82 AS chef
WORKDIR app
FROM chef AS planner
@ -14,6 +14,6 @@ RUN cargo build --release
FROM ubuntu:22.04 AS runtime
WORKDIR /ncb-tts-r2
RUN apt-get update && apt-get install -y --no-install-recommends openssl ca-certificates ffmpeg libssl-dev libopus-dev && apt-get -y clean && mkdir audio
RUN apt-get update && apt-get install -y --no-install-recommends openssl ca-certificates ffmpeg libssl-dev libopus-dev && apt-get -y clean
COPY --from=builder /app/target/release/ncb-tts-r2 /usr/local/bin
ENTRYPOINT ["/usr/local/bin/ncb-tts-r2"]

View File

@ -3,7 +3,7 @@ version: '3'
services:
ncb-tts-r2:
container_name: ncb-tts-r2
image: ghcr.io/mii443/ncb-tts-r2:1.7.3
image: ghcr.io/mii443/ncb-tts-r2:1.10.1
environment:
- NCB_TOKEN=YOUR_BOT_TOKEN
- NCB_APP_ID=YOUR_BOT_ID

View File

@ -22,7 +22,7 @@ spec:
- name: ncb-redis-pvc
mountPath: /data
- name: tts
image: ghcr.io/morioka22/ncb-tts-r2
image: ghcr.io/mii443/ncb-tts-r2
volumeMounts:
- name: gcp-credentials
mountPath: /ncb-tts-r2/credentials.json

View File

@ -1,7 +1,8 @@
use serenity::{
model::prelude::{
component::ButtonStyle,
interaction::{application_command::ApplicationCommandInteraction, MessageFlags},
all::{
ButtonStyle, CommandInteraction, CreateActionRow, CreateButton, CreateInteractionResponse,
CreateInteractionResponseMessage, CreateSelectMenu, CreateSelectMenuKind,
CreateSelectMenuOption,
},
prelude::Context,
};
@ -11,9 +12,10 @@ use crate::{
tts::tts_type::TTSType,
};
#[tracing::instrument]
pub async fn config_command(
ctx: &Context,
command: &ApplicationCommandInteraction,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
let data_read = ctx.data.read().await;
@ -22,9 +24,8 @@ pub async fn config_command(
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_user_config_or_default(command.user.id.0)
.get_user_config_or_default(command.user.id.get())
.await
.unwrap()
.unwrap()
@ -32,84 +33,66 @@ pub async fn config_command(
let tts_client = data_read
.get::<TTSClientData>()
.expect("Cannot get TTSClientData")
.clone();
let voicevox_speakers = tts_client.lock().await.1.get_styles().await;
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_styles().await;
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
let engine_select = CreateActionRow::SelectMenu(
CreateSelectMenu::new(
"TTS_CONFIG_ENGINE",
CreateSelectMenuKind::String {
options: vec![
CreateSelectMenuOption::new("Google TTS", "TTS_CONFIG_ENGINE_SELECTED_GOOGLE")
.default_selection(tts_type == TTSType::GCP),
CreateSelectMenuOption::new("VOICEVOX", "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX")
.default_selection(tts_type == TTSType::VOICEVOX),
],
},
)
.placeholder("読み上げAPIを選択"),
);
let server_button = CreateActionRow::Buttons(vec![CreateButton::new("TTS_CONFIG_SERVER")
.label("サーバー設定")
.style(ButtonStyle::Primary)]);
let mut components = vec![engine_select, server_button];
for (index, speaker_chunk) in voicevox_speakers[0..24].chunks(25).enumerate() {
let mut options = Vec::new();
for (name, id) in speaker_chunk {
options.push(
CreateSelectMenuOption::new(
name,
format!("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_{}", id),
)
.default_selection(*id == voicevox_speaker),
);
}
components.push(CreateActionRow::SelectMenu(
CreateSelectMenu::new(
format!("TTS_CONFIG_VOICEVOX_SPEAKER_{}", index),
CreateSelectMenuKind::String { options },
)
.placeholder("VOICEVOX Speakerを指定"),
));
}
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("読み上げ設定")
.components(|c| {
let mut c = c;
c = c
.create_action_row(|a| {
a.create_select_menu(|m| {
m.custom_id("TTS_CONFIG_ENGINE")
.options(|o| {
o.create_option(|co| {
co.label("Google TTS")
.value("TTS_CONFIG_ENGINE_SELECTED_GOOGLE")
.default_selection(tts_type == TTSType::GCP)
})
.create_option(|co| {
co.label("VOICEVOX")
.value("TTS_CONFIG_ENGINE_SELECTED_VOICEVOX")
.default_selection(
tts_type == TTSType::VOICEVOX,
)
})
})
.placeholder("読み上げAPIを選択")
})
})
.create_action_row(|a| {
a.create_button(|f| {
f.label("サーバー設定")
.custom_id("TTS_CONFIG_SERVER")
.style(ButtonStyle::Primary)
})
});
for (index, speaker_chunk) in
voicevox_speakers[0..24].chunks(25).enumerate()
{
c = c.create_action_row(|a| {
let mut a = a;
a = a.create_select_menu(|m| {
m.custom_id(
"TTS_CONFIG_VOICEVOX_SPEAKER_".to_string()
+ &index.to_string(),
)
.options(|o| {
let mut o = o;
for (name, id) in speaker_chunk {
o = o.create_option(|co| {
co.label(name)
.value(format!(
"TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_{}",
id
))
.default_selection(*id == voicevox_speaker)
})
}
o
})
.placeholder("VOICEVOX Speakerを指定")
});
a
})
}
println!("{:?}", c);
c
})
.flags(MessageFlags::EPHEMERAL)
})
})
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("読み上げ設定")
.components(components)
.ephemeral(true),
),
)
.await?;
Ok(())
}

View File

@ -1,61 +1,53 @@
use serenity::{
model::prelude::{
interaction::{application_command::ApplicationCommandInteraction, MessageFlags},
UserId,
all::{
AutoArchiveDuration, CommandInteraction, CreateEmbed, CreateInteractionResponse, CreateInteractionResponseMessage, CreateMessage, CreateThread
},
model::prelude::UserId,
prelude::Context,
};
use tracing::info;
use crate::{
data::{TTSClientData, TTSData},
tts::instance::TTSInstance,
};
#[tracing::instrument]
pub async fn setup_command(
ctx: &Context,
command: &ApplicationCommandInteraction,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
println!("Received event");
if let None = command.guild_id {
info!("Received event");
if command.guild_id.is_none() {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("このコマンドはサーバーでのみ使用可能です.")
.flags(MessageFlags::EPHEMERAL)
})
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("このコマンドはサーバーでのみ使用可能です.")
.ephemeral(true)
))
.await?;
return Ok(());
}
println!("Fetching guild cache");
let guild = command.guild_id.unwrap().to_guild_cached(&ctx.cache);
if let None = guild {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("ギルドキャッシュを取得できませんでした.")
.flags(MessageFlags::EPHEMERAL)
})
})
.await?;
return Ok(());
}
let guild = guild.unwrap();
info!("Fetching guild cache");
let guild_id = command.guild_id.unwrap();
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
let channel_id = guild
.voice_states
.get(&UserId(command.user.id.0))
.get(&UserId::from(command.user.id.get()))
.and_then(|state| state.channel_id);
if let None = channel_id {
if channel_id.is_none() {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("ボイスチャンネルに参加してから実行してください.")
.flags(MessageFlags::EPHEMERAL)
})
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("ボイスチャンネルに参加してから実行してください.")
.ephemeral(true)
))
.await?;
return Ok(());
}
@ -79,39 +71,34 @@ pub async fn setup_command(
let mut storage = storage_lock.write().await;
if storage.contains_key(&guild.id) {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("すでにセットアップしています.")
.flags(MessageFlags::EPHEMERAL)
})
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("すでにセットアップしています.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let text_channel_id = {
if let Some(mode) = command.data.options.get(0) {
let mode = mode.clone();
let value = mode.value.unwrap();
let value = value.as_str().unwrap();
match value {
"TEXT_CHANNEL" => command.channel_id,
"NEW_THREAD" => {
let message = command
.channel_id
.send_message(&ctx.http, |f| f.content("TTS thread"))
.await
.unwrap();
command
.channel_id
.create_public_thread(&ctx.http, message, |f| {
f.name("TTS").auto_archive_duration(60)
})
.await
.unwrap()
.id
}
"VOICE_CHANNEL" => channel_id,
match &mode.value {
serenity::all::CommandDataOptionValue::String(value) => {
match value.as_str() {
"TEXT_CHANNEL" => command.channel_id,
"NEW_THREAD" => {
command
.channel_id
.create_thread(&ctx.http, CreateThread::new("TTS").auto_archive_duration(AutoArchiveDuration::OneHour).kind(serenity::all::ChannelType::PublicThread))
.await
.unwrap()
.id
}
"VOICE_CHANNEL" => channel_id,
_ => channel_id,
}
},
_ => channel_id,
}
} else {
@ -133,27 +120,37 @@ pub async fn setup_command(
};
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content(format!("TTS Channel: <#{}>{}", text_channel_id, if text_channel_id == channel_id { "\nボイスチャンネルを右クリックし `チャットを開く` を押して開くことが出来ます。" } else { "" }))
})
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content(format!(
"TTS Channel: <#{}>{}",
text_channel_id,
if text_channel_id == channel_id {
"\nボイスチャンネルを右クリックし `チャットを開く` を押して開くことが出来ます。"
} else {
""
}
))
))
.await?;
let _handler = manager.join(guild.id.0, channel_id.0).await;
let tts_client = ctx
let _handler = manager.join(guild.id, channel_id).await;
let data = ctx
.data
.read()
.await
.await;
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData")
.clone();
let voicevox_speakers = tts_client.lock().await.1.get_speakers().await;
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
text_channel_id
.send_message(&ctx.http, |f| {
f.embed(|e| {
e.title("読み上げ (Serenity)")
.send_message(&ctx.http, CreateMessage::new()
.embed(
CreateEmbed::new()
.title("読み上げ (Serenity)")
.field(
"VOICEVOXクレジット",
format!("```\n{}\n```", voicevox_speakers.join("\n")),
@ -161,9 +158,8 @@ pub async fn setup_command(
)
.field("設定コマンド", "`/config`", false)
.field("フィードバック", "https://feedback.mii.codes/", false)
})
})
))
.await?;
Ok(())
}
}

View File

@ -1,8 +1,8 @@
use serenity::{
model::prelude::{
interaction::{application_command::ApplicationCommandInteraction, MessageFlags},
UserId,
all::{
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage
},
model::prelude::UserId,
prelude::Context,
};
@ -10,47 +10,36 @@ use crate::data::TTSData;
pub async fn skip_command(
ctx: &Context,
command: &ApplicationCommandInteraction,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
if let None = command.guild_id {
if command.guild_id.is_none() {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("このコマンドはサーバーでのみ使用可能です.")
.flags(MessageFlags::EPHEMERAL)
})
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("このコマンドはサーバーでのみ使用可能です.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let guild = command.guild_id.unwrap().to_guild_cached(&ctx.cache);
if let None = guild {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("ギルドキャッシュを取得できませんでした.")
.flags(MessageFlags::EPHEMERAL)
})
})
.await?;
return Ok(());
}
let guild = guild.unwrap();
let guild_id = command.guild_id.unwrap();
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
let channel_id = guild
.voice_states
.get(&UserId(command.user.id.0))
.get(&UserId::from(command.user.id.get()))
.and_then(|state| state.channel_id);
if let None = channel_id {
if channel_id.is_none() {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("ボイスチャンネルに参加してから実行してください.")
.flags(MessageFlags::EPHEMERAL)
})
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("ボイスチャンネルに参加してから実行してください.")
.ephemeral(true)
))
.await?;
return Ok(());
}
@ -67,24 +56,26 @@ pub async fn skip_command(
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild.id) {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("読み上げしていません")
.flags(MessageFlags::EPHEMERAL)
})
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("読み上げしていません")
.ephemeral(true)
))
.await?;
return Ok(());
}
storage.get_mut(&guild.id).unwrap().skip(&ctx).await;
storage.get_mut(&guild.id).unwrap().skip(ctx).await;
}
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| d.content("スキップしました"))
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("スキップしました")
))
.await?;
Ok(())
}
}

View File

@ -1,56 +1,46 @@
use serenity::{
model::prelude::{
interaction::{application_command::ApplicationCommandInteraction, MessageFlags},
UserId,
all::{
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditThread
},
prelude::Context,
model::prelude::UserId,
prelude::Context
};
use crate::data::TTSData;
pub async fn stop_command(
ctx: &Context,
command: &ApplicationCommandInteraction,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
if let None = command.guild_id {
if command.guild_id.is_none() {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("このコマンドはサーバーでのみ使用可能です.")
.flags(MessageFlags::EPHEMERAL)
})
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("このコマンドはサーバーでのみ使用可能です.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let guild = command.guild_id.unwrap().to_guild_cached(&ctx.cache);
if let None = guild {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("ギルドキャッシュを取得できませんでした.")
.flags(MessageFlags::EPHEMERAL)
})
})
.await?;
return Ok(());
}
let guild = guild.unwrap();
let guild_id = command.guild_id.unwrap();
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
let channel_id = guild
.voice_states
.get(&UserId(command.user.id.0))
.get(&UserId::from(command.user.id.get()))
.and_then(|state| state.channel_id);
if let None = channel_id {
if channel_id.is_none() {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("ボイスチャンネルに参加してから実行してください.")
.flags(MessageFlags::EPHEMERAL)
})
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("ボイスチャンネルに参加してから実行してください.")
.ephemeral(true)
))
.await?;
return Ok(());
}
@ -70,36 +60,37 @@ pub async fn stop_command(
let text_channel_id = {
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild.id) {
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("すでに停止しています")
.flags(MessageFlags::EPHEMERAL)
})
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("すでに停止しています")
.ephemeral(true)
))
.await?;
return Ok(());
}
let text_channel_id = storage.get(&guild.id).unwrap().text_channel;
storage.remove(&guild.id);
text_channel_id
};
let _handler = manager.remove(guild.id.0).await;
let _handler = manager.remove(guild.id).await;
command
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| d.content("停止しました"))
})
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("停止しました")
))
.await?;
let _ = text_channel_id
.edit_thread(&ctx.http, |f| f.archived(true))
.edit_thread(&ctx.http, EditThread::new().archived(true))
.await;
Ok(())
}
}

View File

@ -6,5 +6,7 @@ pub struct Config {
pub token: String,
pub application_id: u64,
pub redis_url: String,
pub voicevox_key: String,
pub voicevox_key: Option<String>,
pub voicevox_original_api_url: Option<String>,
pub otel_http_url: Option<String>,
}

View File

@ -1,9 +1,5 @@
use crate::{
database::database::Database,
tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX},
};
use crate::{database::database::Database, tts::tts::TTS};
use serenity::{
futures::lock::Mutex,
model::id::GuildId,
prelude::{RwLock, TypeMapKey},
};
@ -22,12 +18,12 @@ impl TypeMapKey for TTSData {
pub struct TTSClientData;
impl TypeMapKey for TTSClientData {
type Value = Arc<Mutex<(TTS, VOICEVOX)>>;
type Value = Arc<TTS>;
}
/// Database client data
pub struct DatabaseClientData;
impl TypeMapKey for DatabaseClientData {
type Value = Arc<Mutex<Database>>;
type Value = Arc<Database>;
}

View File

@ -1,3 +1,5 @@
use std::fmt::Debug;
use crate::tts::{
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType,
};
@ -5,6 +7,7 @@ use crate::tts::{
use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig};
use redis::Commands;
#[derive(Debug, Clone)]
pub struct Database {
pub client: redis::Client,
}
@ -14,114 +17,116 @@ impl Database {
Self { client }
}
fn server_key(server_id: u64) -> String {
format!("discord_server:{}", server_id)
}
fn user_key(user_id: u64) -> String {
format!("discord_user:{}", user_id)
}
#[tracing::instrument]
fn get_config<T: serde::de::DeserializeOwned>(
&self,
key: &str,
) -> redis::RedisResult<Option<T>> {
match self.client.get_connection() {
Ok(mut connection) => {
let config: String = connection.get(key).unwrap_or_default();
if config.is_empty() {
return Ok(None);
}
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None),
}
}
Err(e) => Err(e),
}
}
#[tracing::instrument]
fn set_config<T: serde::Serialize + Debug>(
&self,
key: &str,
config: &T,
) -> redis::RedisResult<()> {
match self.client.get_connection() {
Ok(mut connection) => {
let config_str = serde_json::to_string(config).unwrap();
connection.set::<_, _, ()>(key, config_str)
}
Err(e) => Err(e),
}
}
#[tracing::instrument]
pub async fn get_server_config(
&mut self,
&self,
server_id: u64,
) -> redis::RedisResult<Option<ServerConfig>> {
if let Ok(mut connection) = self.client.get_connection() {
let config: String = connection
.get(format!("discord_server:{}", server_id))
.unwrap_or_default();
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None),
}
} else {
Ok(None)
}
self.get_config(&Self::server_key(server_id))
}
pub async fn get_user_config(
&mut self,
user_id: u64,
) -> redis::RedisResult<Option<UserConfig>> {
if let Ok(mut connection) = self.client.get_connection() {
let config: String = connection
.get(format!("discord_user:{}", user_id))
.unwrap_or_default();
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None),
}
} else {
Ok(None)
}
#[tracing::instrument]
pub async fn get_user_config(&self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
self.get_config(&Self::user_key(user_id))
}
#[tracing::instrument]
pub async fn set_server_config(
&mut self,
&self,
server_id: u64,
config: ServerConfig,
) -> redis::RedisResult<()> {
let config = serde_json::to_string(&config).unwrap();
self.client
.get_connection()
.unwrap()
.set::<String, String, ()>(format!("discord_server:{}", server_id), config)
.unwrap();
Ok(())
self.set_config(&Self::server_key(server_id), &config)
}
#[tracing::instrument]
pub async fn set_user_config(
&mut self,
&self,
user_id: u64,
config: UserConfig,
) -> redis::RedisResult<()> {
let config = serde_json::to_string(&config).unwrap();
self.client
.get_connection()
.unwrap()
.set::<String, String, ()>(format!("discord_user:{}", user_id), config)
.unwrap();
Ok(())
self.set_config(&Self::user_key(user_id), &config)
}
pub async fn set_default_server_config(&mut self, server_id: u64) -> redis::RedisResult<()> {
#[tracing::instrument]
pub async fn set_default_server_config(&self, server_id: u64) -> redis::RedisResult<()> {
let config = ServerConfig {
dictionary: Dictionary::new(),
autostart_channel_id: None,
};
self.client.get_connection().unwrap().set(
format!("discord_server:{}", server_id),
serde_json::to_string(&config).unwrap(),
)?;
Ok(())
self.set_server_config(server_id, config).await
}
pub async fn set_default_user_config(&mut self, user_id: u64) -> redis::RedisResult<()> {
#[tracing::instrument]
pub async fn set_default_user_config(&self, user_id: u64) -> redis::RedisResult<()> {
let voice_selection = VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral"),
};
let voice_type = TTSType::GCP;
let config = UserConfig {
tts_type: Some(voice_type),
tts_type: Some(TTSType::GCP),
gcp_tts_voice: Some(voice_selection),
voicevox_speaker: Some(1),
};
self.client.get_connection().unwrap().set(
format!("discord_user:{}", user_id),
serde_json::to_string(&config).unwrap(),
)?;
Ok(())
self.set_user_config(user_id, config).await
}
#[tracing::instrument]
pub async fn get_server_config_or_default(
&mut self,
&self,
server_id: u64,
) -> redis::RedisResult<Option<ServerConfig>> {
let config = self.get_server_config(server_id).await?;
match config {
Some(_) => Ok(config),
match self.get_server_config(server_id).await? {
Some(config) => Ok(Some(config)),
None => {
self.set_default_server_config(server_id).await?;
self.get_server_config(server_id).await
@ -129,13 +134,13 @@ impl Database {
}
}
#[tracing::instrument]
pub async fn get_user_config_or_default(
&mut self,
&self,
user_id: u64,
) -> redis::RedisResult<Option<UserConfig>> {
let config = self.get_user_config(user_id).await?;
match config {
Some(_) => Ok(config),
match self.get_user_config(user_id).await? {
Some(config) => Ok(Some(config)),
None => {
self.set_default_user_config(user_id).await?;
self.get_user_config(user_id).await

View File

@ -8,34 +8,37 @@ use crate::{
tts::tts_type::TTSType,
};
use serenity::{
all::{
ActionRowComponent, ButtonStyle, ComponentInteractionDataKind, CreateActionRow,
CreateButton, CreateEmbed, CreateInputText, CreateInteractionResponse,
CreateInteractionResponseMessage, CreateModal, CreateSelectMenu, CreateSelectMenuKind,
CreateSelectMenuOption, InputTextStyle,
},
async_trait,
client::{Context, EventHandler},
model::{
channel::Message,
gateway::Ready,
prelude::{
component::{ActionRowComponent, ButtonStyle, InputTextStyle},
interaction::{Interaction, InteractionResponseType, MessageFlags},
ChannelType,
},
application::Interaction, channel::Message, gateway::Ready, prelude::ChannelType,
voice::VoiceState,
},
};
#[derive(Clone, Debug)]
pub struct Handler;
#[async_trait]
impl EventHandler for Handler {
#[tracing::instrument]
async fn message(&self, ctx: Context, message: Message) {
events::message_receive::message(ctx, message).await
}
#[tracing::instrument]
async fn ready(&self, ctx: Context, ready: Ready) {
events::ready::ready(ctx, ready).await
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
if let Interaction::ApplicationCommand(command) = interaction.clone() {
if let Interaction::Command(command) = interaction.clone() {
let name = &*command.data.name;
match name {
"setup" => setup_command(&ctx, &command).await.unwrap(),
@ -45,7 +48,7 @@ impl EventHandler for Handler {
_ => {}
}
}
if let Interaction::ModalSubmit(modal) = interaction.clone() {
if let Interaction::Modal(modal) = interaction.clone() {
if modal.data.custom_id != "TTS_CONFIG_SERVER_ADD_DICTIONARY" {
return;
}
@ -53,19 +56,19 @@ impl EventHandler for Handler {
let rows = modal.data.components.clone();
let rule_name =
if let ActionRowComponent::InputText(text) = rows[0].components[0].clone() {
text.value
text.value.unwrap()
} else {
panic!("Cannot get rule name");
};
let from = if let ActionRowComponent::InputText(text) = rows[1].components[0].clone() {
text.value
text.value.unwrap()
} else {
panic!("Cannot get from");
};
let to = if let ActionRowComponent::InputText(text) = rows[2].components[0].clone() {
text.value
text.value.unwrap()
} else {
panic!("Cannot get to");
};
@ -84,9 +87,9 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(modal.guild_id.unwrap().0)
.get_server_config_or_default(modal.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
@ -98,22 +101,21 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.set_server_config(modal.guild_id.unwrap().0, config)
.set_server_config(modal.guild_id.unwrap().get(), config)
.await
.unwrap();
modal
.create_interaction_response(&ctx.http, |f| {
f.kind(InteractionResponseType::UpdateMessage)
.interaction_response_data(|d| {
d.custom_id("TTS_CONFIG_SERVER_ADD_DICTIONARY_RESPONSE")
.content(format!(
"辞書を追加しました\n名前: {}\n変換元: {}\n変換後: {}",
rule_name, from, to
))
})
})
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new().content(format!(
"辞書を追加しました\n名前: {}\n変換元: {}\n変換後: {}",
rule_name, from, to
)),
),
)
.await
.unwrap();
}
@ -121,7 +123,16 @@ impl EventHandler for Handler {
if let Some(message_component) = interaction.message_component() {
match &*message_component.data.custom_id {
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU" => {
let i = usize::from_str_radix(&message_component.data.values[0], 10).unwrap();
let i = usize::from_str_radix(
&match message_component.data.kind {
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
values[0].clone()
}
_ => panic!("Cannot get index"),
},
10,
)
.unwrap();
let data_read = ctx.data.read().await;
let mut config = {
@ -129,9 +140,9 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(message_component.guild_id.unwrap().0)
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
@ -143,22 +154,21 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.set_server_config(message_component.guild_id.unwrap().0, config)
.set_server_config(message_component.guild_id.unwrap().get(), config)
.await
.unwrap();
}
message_component
.create_interaction_response(&ctx, |f| {
f.kind(InteractionResponseType::UpdateMessage)
.interaction_response_data(|d| {
d.custom_id("DICTIONARY_REMOVED")
.content("辞書を削除しました")
.components(|c| c)
})
})
.create_response(
&ctx,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("辞書を削除しました"),
),
)
.await
.unwrap();
}
@ -170,53 +180,49 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(message_component.guild_id.unwrap().0)
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
};
message_component
.create_interaction_response(&ctx.http, |f| {
f.kind(InteractionResponseType::UpdateMessage)
.interaction_response_data(|d| {
d.custom_id("TTS_CONFIG_SERVER_REMOVE_DICTIONARY")
.content("削除する辞書内容を選択してください")
.components(|c| {
c.create_action_row(|a| {
a.create_select_menu(|s| {
s.custom_id(
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU",
)
.options(|o| {
let mut o = o;
for (i, rule) in config
.dictionary
.rules
.iter()
.enumerate()
{
o = o.create_option(|c| {
c.label(rule.id.clone())
.value(i)
.description(format!(
"{} -> {}",
rule.rule.clone(),
rule.to.clone()
))
});
}
o
})
.max_values(1)
.min_values(0)
})
})
})
})
})
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("削除する辞書内容を選択してください")
.components(vec![CreateActionRow::SelectMenu(
CreateSelectMenu::new(
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU",
CreateSelectMenuKind::String {
options: {
let mut options = vec![];
for (i, rule) in
config.dictionary.rules.iter().enumerate()
{
let option = CreateSelectMenuOption::new(
rule.id.clone(),
i.to_string(),
)
.description(format!(
"{} -> {}",
rule.rule.clone(),
rule.to.clone()
));
options.push(option);
}
options
},
},
)
.max_values(1)
.min_values(0),
)]),
),
)
.await
.unwrap();
}
@ -227,80 +233,92 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(message_component.guild_id.unwrap().0)
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
};
message_component
.create_interaction_response(&ctx.http, |f| {
f.kind(InteractionResponseType::UpdateMessage)
.interaction_response_data(|d| {
d.custom_id("DICTIONARY_LIST").content("").embed(|e| {
e.title("辞書一覧");
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new().content("").embed(
CreateEmbed::new().title("辞書一覧").fields({
let mut fields = vec![];
for rule in config.dictionary.rules {
e.field(
rule.id,
let field = (
rule.id.clone(),
format!("{} -> {}", rule.rule, rule.to),
true,
);
fields.push(field);
}
e
})
})
})
fields
}),
),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON" => {
message_component
.create_interaction_response(&ctx.http, |f| {
f.kind(InteractionResponseType::Modal)
.interaction_response_data(|d| {
d.custom_id("TTS_CONFIG_SERVER_ADD_DICTIONARY")
.title("辞書追加")
.components(|c| {
c.create_action_row(|a| {
a.create_input_text(|i| {
i.style(InputTextStyle::Short)
.label("Rule name")
.custom_id("rule_name")
.required(true)
})
})
.create_action_row(|a| {
a.create_input_text(|i| {
i.style(InputTextStyle::Paragraph)
.label("From")
.custom_id("from")
.required(true)
})
})
.create_action_row(|a| {
a.create_input_text(|i| {
i.style(InputTextStyle::Short)
.label("To")
.custom_id("to")
.required(true)
})
})
})
})
})
.create_response(
&ctx.http,
CreateInteractionResponse::Modal(
CreateModal::new("TTS_CONFIG_SERVER_ADD_DICTIONARY", "辞書追加")
.components({
vec![
CreateActionRow::InputText(
CreateInputText::new(
InputTextStyle::Short,
"rule_name",
"辞書名",
)
.required(true),
),
CreateActionRow::InputText(
CreateInputText::new(
InputTextStyle::Paragraph,
"from",
"変換元(正規表現)",
)
.required(true),
),
CreateActionRow::InputText(
CreateInputText::new(
InputTextStyle::Short,
"to",
"変換先",
)
.required(true),
),
]
}),
),
)
.await
.unwrap();
}
"SET_AUTOSTART_CHANNEL" => {
let autostart_channel_id = if message_component.data.values.len() == 0 {
None
} else {
let ch = message_component.data.values[0]
.strip_prefix("SET_AUTOSTART_CHANNEL_")
.unwrap();
Some(u64::from_str_radix(ch, 10).unwrap())
let autostart_channel_id = match message_component.data.kind {
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
if values.len() == 0 {
None
} else {
Some(
u64::from_str_radix(
&values[0].strip_prefix("SET_AUTOSTART_CHANNEL_").unwrap(),
10,
)
.unwrap(),
)
}
}
_ => panic!("Cannot get index"),
};
{
let data_read = ctx.data.read().await;
@ -308,27 +326,27 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
let mut config = database
.get_server_config_or_default(message_component.guild_id.unwrap().0)
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap();
config.autostart_channel_id = autostart_channel_id;
database
.set_server_config(message_component.guild_id.unwrap().0, config)
.set_server_config(message_component.guild_id.unwrap().get(), config)
.await
.unwrap();
};
message_component
.create_interaction_response(&ctx.http, |c| {
c.kind(InteractionResponseType::UpdateMessage)
.interaction_response_data(|d| {
d.content("自動参加チャンネルを設定しました。")
.components(|f| f)
})
})
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("自動参加チャンネルを設定しました。"),
),
)
.await
.unwrap();
}
@ -339,9 +357,9 @@ impl EventHandler for Handler {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(message_component.guild_id.unwrap().0)
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
@ -356,166 +374,159 @@ impl EventHandler for Handler {
.await
.unwrap();
let mut options = Vec::new();
for (id, channel) in channels {
if channel.kind != ChannelType::Voice {
continue;
}
let description = channel
.topic
.unwrap_or_else(|| String::from("No topic provided."));
let option = CreateSelectMenuOption::new(
&channel.name,
format!("SET_AUTOSTART_CHANNEL_{}", id.get()),
)
.description(description)
.default_selection(channel.id.get() == autostart_channel_id);
options.push(option);
}
message_component
.create_interaction_response(&ctx.http, |f| {
f.kind(InteractionResponseType::UpdateMessage)
.interaction_response_data(|d| {
d.custom_id("SET_AUTOSTART_FORM")
.content("自動参加チャンネル設定")
.components(|c| {
c.create_action_row(|a| {
a.create_select_menu(|m| {
m.min_values(0)
.max_values(1)
.disabled(false)
.custom_id("SET_AUTOSTART_CHANNEL")
.options(|o| {
// Create channel list
for (id, channel) in channels {
if channel.kind != ChannelType::Voice {
continue;
}
o.create_option(|co| {
co.label(channel.name)
.description(
channel
.topic
.unwrap_or(String::from(
"No topic provided.",
)),
)
.value(format!("SET_AUTOSTART_CHANNEL_{}", id.0))
.default_selection(channel.id.0 == autostart_channel_id)
});
}
o
})
})
})
})
})
})
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("自動参加チャンネル設定")
.components(vec![CreateActionRow::SelectMenu(
CreateSelectMenu::new(
"SET_AUTOSTART_CHANNEL",
CreateSelectMenuKind::String { options },
)
.min_values(0)
.max_values(1),
)]),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER" => {
message_component
.create_interaction_response(&ctx.http, |f| {
f.kind(InteractionResponseType::UpdateMessage)
.interaction_response_data(|d| {
d.content("サーバー設定")
.custom_id("TTS_CONFIG_SERVER")
.components(|c| {
c.create_action_row(|a| {
a.create_button(|b| {
b.custom_id(
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON",
)
.label("辞書を追加")
.style(ButtonStyle::Primary)
})
.create_button(|b| {
b.custom_id(
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON",
)
.label("辞書を削除")
.style(ButtonStyle::Danger)
})
.create_button(|b| {
b.custom_id(
"TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON",
)
.label("辞書一覧")
.style(ButtonStyle::Primary)
})
.create_button(|b| {
b.custom_id(
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL"
)
.label("自動参加チャンネル")
.style(ButtonStyle::Primary)
})
})
})
})
})
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("サーバー設定")
.components(vec![CreateActionRow::Buttons(vec![
CreateButton::new(
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON",
)
.label("辞書を追加")
.style(ButtonStyle::Primary),
CreateButton::new(
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON",
)
.label("辞書を削除")
.style(ButtonStyle::Danger),
CreateButton::new(
"TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON",
)
.label("辞書一覧")
.style(ButtonStyle::Primary),
CreateButton::new(
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL",
)
.label("自動参加チャンネル")
.style(ButtonStyle::Primary),
])]),
),
)
.await
.unwrap();
}
_ => {}
}
if let Some(v) = message_component.data.values.get(0) {
let data_read = ctx.data.read().await;
match message_component.data.kind {
ComponentInteractionDataKind::StringSelect { ref values, .. }
if !values.is_empty() =>
{
let res = &values[0].clone();
let data_read = ctx.data.read().await;
let mut config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_user_config_or_default(message_component.user.id.0)
.await
.unwrap()
.unwrap()
};
let mut config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let res = (*v).clone();
let mut config_changed = false;
let mut voicevox_changed = false;
match &*res {
"TTS_CONFIG_ENGINE_SELECTED_GOOGLE" => {
config.tts_type = Some(TTSType::GCP);
config_changed = true;
}
"TTS_CONFIG_ENGINE_SELECTED_VOICEVOX" => {
config.tts_type = Some(TTSType::VOICEVOX);
config_changed = true;
}
_ => {
if res.starts_with("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_") {
config.voicevox_speaker = Some(
i64::from_str_radix(
&res.replace("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_", ""),
10,
)
.unwrap(),
);
database
.get_user_config_or_default(message_component.user.id.get())
.await
.unwrap()
.unwrap()
};
let mut config_changed = false;
let mut voicevox_changed = false;
match res.as_str() {
"TTS_CONFIG_ENGINE_SELECTED_GOOGLE" => {
config.tts_type = Some(TTSType::GCP);
config_changed = true;
voicevox_changed = true;
}
"TTS_CONFIG_ENGINE_SELECTED_VOICEVOX" => {
config.tts_type = Some(TTSType::VOICEVOX);
config_changed = true;
}
_ => {
if res.starts_with("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_") {
let speaker_id = res
.strip_prefix("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_")
.and_then(|id_str| id_str.parse::<i64>().ok())
.expect("Invalid speaker ID format");
config.voicevox_speaker = Some(speaker_id);
config_changed = true;
voicevox_changed = true;
}
}
}
}
if config_changed {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.set_user_config(message_component.user.id.0, config.clone())
.await
.unwrap();
if config_changed {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.set_user_config(message_component.user.id.get(), config.clone())
.await
.unwrap();
let response_content = if voicevox_changed
&& config.tts_type.unwrap_or(TTSType::GCP) == TTSType::GCP
{
"設定しました\nこの音声を使うにはAPIをGoogleからVOICEVOXに変更する必要があります。"
} else {
"設定しました"
};
if voicevox_changed && config.tts_type.unwrap_or(TTSType::GCP) == TTSType::GCP {
message_component.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("設定しました\nこの音声を使うにはAPIをGoogleからVOICEVOXに変更する必要があります。")
.flags(MessageFlags::EPHEMERAL)
})
}).await.unwrap();
} else {
message_component
.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("設定しました").flags(MessageFlags::EPHEMERAL)
})
})
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content(response_content)
.ephemeral(true),
),
)
.await
.unwrap();
}
}
_ => {}
}
}
}

View File

@ -31,7 +31,7 @@ pub async fn message(ctx: Context, message: Message) {
let instance = storage.get_mut(&guild_id).unwrap();
if instance.text_channel.0 != message.channel_id.0 {
if instance.text_channel != message.channel_id {
return;
}

View File

@ -1,32 +1,33 @@
use serenity::{
model::prelude::{command::Command, Ready},
all::{Command, CommandOptionType, CreateCommand, CreateCommandOption},
model::prelude::Ready,
prelude::Context,
};
use tracing::info;
#[tracing::instrument]
pub async fn ready(ctx: Context, ready: Ready) {
println!("{} is connected!", ready.user.name);
info!("{} is connected!", ready.user.name);
let _ = Command::set_global_application_commands(&ctx.http, |commands| {
commands
.create_application_command(|command| command.name("stop").description("Stop tts"))
.create_application_command(|command| {
command
.name("setup")
.description("Setup tts")
.create_option(|o| {
o.name("mode")
.description("TTS channel")
.add_string_choice("Text Channel", "TEXT_CHANNEL")
.add_string_choice("New Thread", "NEW_THREAD")
.add_string_choice("Voice Channel", "VOICE_CHANNEL")
.kind(serenity::model::prelude::command::CommandOptionType::String)
.required(false)
})
})
.create_application_command(|command| command.name("config").description("Config"))
.create_application_command(|command| {
command.name("skip").description("skip tts message")
})
})
.await;
Command::set_global_commands(
&ctx.http,
vec![
CreateCommand::new("stop").description("Stop tts"),
CreateCommand::new("setup")
.description("Setup tts")
.set_options(vec![CreateCommandOption::new(
CommandOptionType::String,
"mode",
"TTS channel",
)
.add_string_choice("Text Channel", "TEXT_CHANNEL")
.add_string_choice("New Thread", "NEW_THREAD")
.add_string_choice("Voice Channel", "VOICE_CHANNEL")
.required(false)]),
CreateCommand::new("config").description("Config"),
CreateCommand::new("skip").description("skip tts message"),
],
)
.await
.unwrap();
}

View File

@ -6,7 +6,11 @@ use crate::{
},
tts::{instance::TTSInstance, message::AnnounceMessage},
};
use serenity::{model::voice::VoiceState, prelude::Context};
use serenity::{
all::{CreateEmbed, CreateMessage, EditThread},
model::voice::VoiceState,
prelude::Context,
};
pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: VoiceState) {
if new.member.clone().unwrap().user.bot {
@ -37,9 +41,8 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(guild_id.0)
.get_server_config_or_default(guild_id.get())
.await
.unwrap()
.unwrap()
@ -49,7 +52,7 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild_id) {
if let Some(new_channel) = new.channel_id {
if config.autostart_channel_id.unwrap_or(0) == new_channel.0 {
if config.autostart_channel_id.unwrap_or(0) == new_channel.get() {
let manager = songbird::get(&ctx)
.await
.expect("Cannot get songbird client.")
@ -64,29 +67,28 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
},
);
let _handler = manager.join(guild_id.0, new_channel.0).await;
let tts_client = ctx
.data
.read()
.await
let _handler = manager.join(guild_id, new_channel).await;
let data = ctx.data.read().await;
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData")
.clone();
let voicevox_speakers = tts_client.lock().await.1.get_speakers().await;
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
new_channel
.send_message(&ctx.http, |f| {
f.embed(|e| {
e.title("自動参加 読み上げ (Serenity)")
.send_message(
&ctx.http,
CreateMessage::new().embed(
CreateEmbed::new()
.title("自動参加 読み上げSerenity")
.field(
"VOICEVOXクレジット",
format!("```\n{}\n```", voicevox_speakers.join("\n")),
false,
)
.field("設定コマンド", "`/config`", false)
.field("フィードバック", "https://feedback.mii.codes/", false)
})
})
.field("フィードバック", "https://feedback.mii.codes/", false),
),
)
.await
.unwrap();
}
@ -118,7 +120,10 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
let mut del_flag = false;
for channel in guild_id.channels(&ctx.http).await.unwrap() {
if channel.0 == instance.voice_channel {
del_flag = channel.1.members(&ctx.cache).await.unwrap().len() <= 1;
let members = channel.1.members(&ctx.cache).unwrap();
let user_count = members.iter().filter(|member| !member.user.bot).count();
del_flag = user_count == 0;
}
}
@ -127,7 +132,7 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
.get(&guild_id)
.unwrap()
.text_channel
.edit_thread(&ctx.http, |f| f.archived(true))
.edit_thread(&ctx.http, EditThread::new().archived(true))
.await;
storage.remove(&guild_id);
@ -136,7 +141,7 @@ pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: Voic
.expect("Cannot get songbird client.")
.clone();
manager.remove(guild_id.0).await.unwrap();
manager.remove(guild_id).await.unwrap();
}
}
}

View File

@ -1,4 +1,7 @@
use serenity::model::guild::Member;
use serenity::model::{
guild::{Member, PartialMember},
user::User,
};
pub trait ReadName {
fn read_name(&self) -> String;
@ -6,6 +9,20 @@ pub trait ReadName {
impl ReadName for Member {
fn read_name(&self) -> String {
self.nick.clone().unwrap_or(self.user.name.clone())
self.nick.clone().unwrap_or(self.display_name().to_string())
}
}
impl ReadName for PartialMember {
fn read_name(&self) -> String {
self.nick
.clone()
.unwrap_or(self.user.as_ref().unwrap().display_name().to_string())
}
}
impl ReadName for User {
fn read_name(&self) -> String {
self.display_name().to_string()
}
}

View File

@ -1,11 +1,11 @@
use std::{env, fs::File, io::Write};
use async_trait::async_trait;
use regex::Regex;
use serenity::{model::prelude::Message, prelude::Context};
use songbird::tracks::Track;
use crate::{
data::{DatabaseClientData, TTSClientData},
implement::member_name::ReadName,
tts::{
gcp_tts::structs::{
audio_config::AudioConfig, synthesis_input::SynthesisInput,
@ -27,9 +27,8 @@ impl TTSMessage for Message {
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_server_config_or_default(instance.guild.0)
.get_server_config_or_default(instance.guild.get())
.await
.unwrap()
.unwrap()
@ -48,19 +47,29 @@ impl TTSMessage for Message {
text.clone()
} else {
let member = self.member.clone();
let name = if let Some(member) = member {
member.nick.unwrap_or(self.author.name.clone())
let name = if let Some(_) = member {
let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone();
guild
.member(&ctx.http, self.author.id)
.await
.unwrap()
.read_name()
} else {
self.author.name.clone()
self.author.read_name()
};
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
}
} else {
let member = self.member.clone();
let name = if let Some(member) = member {
member.nick.unwrap_or(self.author.name.clone())
let name = if let Some(_) = member {
let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone();
guild
.member(&ctx.http, self.author.id)
.await
.unwrap()
.read_name()
} else {
self.author.name.clone()
self.author.read_name()
};
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
};
@ -78,33 +87,30 @@ impl TTSMessage for Message {
res
}
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track> {
let text = self.parse(instance, ctx).await;
let data_read = ctx.data.read().await;
let storage = data_read
.get::<TTSClientData>()
.expect("Cannot get GCP TTSClientStorage")
.clone();
let mut tts = storage.lock().await;
let config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut database = database.lock().await;
database
.get_user_config_or_default(self.author.id.0)
.get_user_config_or_default(self.author.id.get())
.await
.unwrap()
.unwrap()
};
let audio = match config.tts_type.unwrap_or(TTSType::GCP) {
TTSType::GCP => tts
.0
.synthesize(SynthesizeRequest {
let tts = data_read
.get::<TTSClientData>()
.expect("Cannot get GCP TTSClientStorage");
match config.tts_type.unwrap_or(TTSType::GCP) {
TTSType::GCP => vec![tts
.synthesize_gcp(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(format!("<speak>{}</speak>", text)),
@ -117,26 +123,17 @@ impl TTSMessage for Message {
},
})
.await
.unwrap(),
.unwrap()
.into()],
TTSType::VOICEVOX => tts
.1
.synthesize(
text.replace("<break time=\"200ms\"/>", ""),
TTSType::VOICEVOX => vec![tts
.synthesize_voicevox(
&text.replace("<break time=\"200ms\"/>", ""),
config.voicevox_speaker.unwrap_or(1),
)
.await
.unwrap(),
};
let uuid = uuid::Uuid::new_v4().to_string();
let path = env::current_dir().unwrap();
let file_path = path.join("audio").join(format!("{}.mp3", uuid));
let mut file = File::create(file_path.clone()).unwrap();
file.write(&audio).unwrap();
file_path.into_os_string().into_string().unwrap()
.unwrap()
.into()],
}
}
}

View File

@ -29,12 +29,10 @@ impl VoiceMoveStateTrait for VoiceState {
(Some(old_channel_id), Some(new_channel_id)) => {
if old_channel_id == new_channel_id {
VoiceMoveState::NONE
} else if old_channel_id != new_channel_id {
if target_channel == new_channel_id {
VoiceMoveState::JOIN
} else {
VoiceMoveState::NONE
}
} else if old_channel_id == target_channel {
VoiceMoveState::LEAVE
} else if new_channel_id == target_channel {
VoiceMoveState::JOIN
} else {
VoiceMoveState::NONE
}

View File

@ -5,6 +5,8 @@ mod database;
mod event_handler;
mod events;
mod implement;
mod stream_input;
mod trace;
mod tts;
use std::{collections::HashMap, env, sync::Arc};
@ -13,13 +15,16 @@ use config::Config;
use data::{DatabaseClientData, TTSClientData, TTSData};
use database::database::Database;
use event_handler::Handler;
#[allow(deprecated)]
use serenity::{
all::{standard::Configuration, ApplicationId},
client::Client,
framework::StandardFramework,
futures::lock::Mutex,
prelude::{GatewayIntents, RwLock},
};
use tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX};
use trace::init_tracing_subscriber;
use tracing::info;
use tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX};
use songbird::SerenityInit;
@ -31,12 +36,14 @@ use songbird::SerenityInit;
///
/// client.start().await;
/// ```
#[allow(deprecated)]
async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client, serenity::Error> {
let framework = StandardFramework::new().configure(|c| c.with_whitespace(true).prefix(prefix));
let framework = StandardFramework::new();
framework.configure(Configuration::new().with_whitespace(true).prefix(prefix));
Client::builder(token, GatewayIntents::all())
.event_handler(Handler)
.application_id(id)
.application_id(ApplicationId::new(id))
.framework(framework)
.register_songbird()
.await
@ -54,7 +61,18 @@ async fn main() {
let application_id = env::var("NCB_APP_ID").unwrap();
let prefix = env::var("NCB_PREFIX").unwrap();
let redis_url = env::var("NCB_REDIS_URL").unwrap();
let voicevox_key = env::var("NCB_VOICEVOX_KEY").unwrap();
let voicevox_key = match env::var("NCB_VOICEVOX_KEY") {
Ok(key) => Some(key),
Err(_) => None,
};
let voicevox_original_api_url = match env::var("NCB_VOICEVOX_ORIGINAL_API_URL") {
Ok(url) => Some(url),
Err(_) => None,
};
let otel_http_url = match env::var("NCB_OTEL_HTTP_URL") {
Ok(url) => Some(url),
Err(_) => None,
};
Config {
token,
@ -62,22 +80,26 @@ async fn main() {
prefix,
redis_url,
voicevox_key,
voicevox_original_api_url,
otel_http_url,
}
}
};
let _guard = init_tracing_subscriber(&config.otel_http_url);
// Create discord client
let mut client = create_client(&config.prefix, &config.token, config.application_id)
.await
.expect("Err creating client");
// Create GCP TTS client
let tts = match TTS::new("./credentials.json".to_string()).await {
let tts = match GCPTTS::new("./credentials.json".to_string()).await {
Ok(tts) => tts,
Err(err) => panic!("GCP init error: {}", err),
};
let voicevox = VOICEVOX::new(config.voicevox_key);
let voicevox = VOICEVOX::new(config.voicevox_key, config.voicevox_original_api_url);
let database_client = {
let redis_client = redis::Client::open(config.redis_url).unwrap();
@ -88,10 +110,12 @@ async fn main() {
{
let mut data = client.data.write().await;
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
data.insert::<TTSClientData>(Arc::new(Mutex::new((tts, voicevox))));
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
data.insert::<DatabaseClientData>(Arc::new(database_client));
}
info!("Bot initialized.");
// Run client
if let Err(why) = client.start().await {
println!("Client error: {:?}", why);

93
src/stream_input.rs Normal file
View File

@ -0,0 +1,93 @@
use async_trait::async_trait;
use futures::TryStreamExt;
use reqwest::{header::HeaderMap, Client};
use symphonia_core::{io::MediaSource, probe::Hint};
use tokio_util::compat::FuturesAsyncReadCompatExt;
use songbird::input::{
AsyncAdapterStream, AsyncReadOnlySource, AudioStream, AudioStreamError, Compose, Input,
};
#[derive(Debug, Clone)]
pub struct Mp3Request {
client: Client,
request: String,
headers: HeaderMap,
}
impl Mp3Request {
#[must_use]
pub fn new(client: Client, request: String) -> Self {
Self::new_with_headers(client, request, HeaderMap::default())
}
#[must_use]
pub fn new_with_headers(client: Client, request: String, headers: HeaderMap) -> Self {
Mp3Request {
client,
request,
headers,
}
}
async fn create_stream_async(&self) -> Result<AsyncReadOnlySource, AudioStreamError> {
let request = self
.client
.get(&self.request)
.headers(self.headers.clone())
.build()
.map_err(|why| AudioStreamError::Fail(why.into()))?;
let response = self
.client
.execute(request)
.await
.map_err(|why| AudioStreamError::Fail(why.into()))?;
if !response.status().is_success() {
return Err(AudioStreamError::Fail(
format!("HTTP error: {}", response.status()).into(),
));
}
let byte_stream = response
.bytes_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()));
let tokio_reader = byte_stream.into_async_read().compat();
Ok(AsyncReadOnlySource::new(tokio_reader))
}
}
#[async_trait]
impl Compose for Mp3Request {
fn create(&mut self) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
Err(AudioStreamError::Fail(
"Mp3Request::create must be called in an async context via create_async".into(),
))
}
async fn create_async(
&mut self,
) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
let input = self.create_stream_async().await?;
let stream = AsyncAdapterStream::new(Box::new(input), 64 * 1024);
let hint = Hint::new().with_extension("mp3").clone();
Ok(AudioStream {
input: Box::new(stream) as Box<dyn MediaSource>,
hint: Some(hint),
})
}
fn should_create_async(&self) -> bool {
true
}
}
impl From<Mp3Request> for Input {
fn from(val: Mp3Request) -> Self {
Input::Lazy(Box::new(val))
}
}

128
src/trace.rs Normal file
View File

@ -0,0 +1,128 @@
use opentelemetry::{
global,
trace::{SamplingDecision, SamplingResult, TraceContextExt, TraceState, TracerProvider as _},
KeyValue,
};
use opentelemetry_otlp::{Protocol, WithExportConfig};
use opentelemetry_sdk::{
metrics::{MeterProviderBuilder, PeriodicReader, SdkMeterProvider},
trace::{RandomIdGenerator, SdkTracerProvider, ShouldSample},
Resource,
};
use tracing::Level;
use tracing_opentelemetry::{MetricsLayer, OpenTelemetryLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[derive(Debug, Clone)]
struct FilterSampler;
impl ShouldSample for FilterSampler {
fn should_sample(
&self,
parent_context: Option<&opentelemetry::Context>,
_trace_id: opentelemetry::TraceId,
name: &str,
_span_kind: &opentelemetry::trace::SpanKind,
_attributes: &[KeyValue],
_links: &[opentelemetry::trace::Link],
) -> opentelemetry::trace::SamplingResult {
let decision = if name == "dispatch" || name == "recv_event" {
SamplingDecision::Drop
} else {
SamplingDecision::RecordAndSample
};
SamplingResult {
decision,
attributes: vec![],
trace_state: match parent_context {
Some(ctx) => ctx.span().span_context().trace_state().clone(),
None => TraceState::default(),
},
}
}
}
fn resource() -> Resource {
Resource::builder().with_service_name("ncb-tts-r2").build()
}
fn init_meter_provider(url: &str) -> SdkMeterProvider {
let exporter = opentelemetry_otlp::MetricExporter::builder()
.with_http()
.with_endpoint(url)
.with_protocol(Protocol::HttpBinary)
.with_temporality(opentelemetry_sdk::metrics::Temporality::default())
.build()
.unwrap();
let reader = PeriodicReader::builder(exporter)
.with_interval(std::time::Duration::from_secs(5))
.build();
let stdout_reader =
PeriodicReader::builder(opentelemetry_stdout::MetricExporter::default()).build();
let meter_provider = MeterProviderBuilder::default()
.with_resource(resource())
.with_reader(reader)
.with_reader(stdout_reader)
.build();
global::set_meter_provider(meter_provider.clone());
meter_provider
}
fn init_tracer_provider(url: &str) -> SdkTracerProvider {
let exporter = opentelemetry_otlp::SpanExporter::builder()
.with_http()
.with_endpoint(url)
.with_protocol(Protocol::HttpBinary)
.build()
.unwrap();
SdkTracerProvider::builder()
.with_sampler(FilterSampler)
.with_id_generator(RandomIdGenerator::default())
.with_resource(resource())
.with_batch_exporter(exporter)
.build()
}
pub fn init_tracing_subscriber(otel_http_url: &Option<String>) -> OtelGuard {
let registry = tracing_subscriber::registry()
.with(tracing_subscriber::filter::LevelFilter::from_level(
Level::INFO,
))
.with(tracing_subscriber::fmt::layer());
if let Some(url) = otel_http_url {
let tracer_provider = init_tracer_provider(url);
let meter_provider = init_meter_provider(url);
let tracer = tracer_provider.tracer("ncb-tts-r2");
registry
.with(MetricsLayer::new(meter_provider.clone()))
.with(OpenTelemetryLayer::new(tracer))
.init();
OtelGuard {
_tracer_provider: Some(tracer_provider),
_meter_provider: Some(meter_provider),
}
} else {
registry.init();
OtelGuard {
_tracer_provider: None,
_meter_provider: None,
}
}
}
pub struct OtelGuard {
_tracer_provider: Option<SdkTracerProvider>,
_meter_provider: Option<SdkMeterProvider>,
}

View File

@ -2,35 +2,40 @@ use crate::tts::gcp_tts::structs::{
synthesize_request::SynthesizeRequest, synthesize_response::SynthesizeResponse,
};
use gcp_auth::Token;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Clone)]
pub struct TTS {
pub token: Token,
#[derive(Clone, Debug)]
pub struct GCPTTS {
pub token: Arc<RwLock<Token>>,
pub credentials_path: String,
}
impl TTS {
pub async fn update_token(&mut self) -> Result<(), gcp_auth::Error> {
if self.token.has_expired() {
impl GCPTTS {
#[tracing::instrument]
pub async fn update_token(&self) -> Result<(), gcp_auth::Error> {
let mut token = self.token.write().await;
if token.has_expired() {
let authenticator =
gcp_auth::from_credentials_file(self.credentials_path.clone()).await?;
let token = authenticator
let new_token = authenticator
.get_token(&["https://www.googleapis.com/auth/cloud-platform"])
.await?;
self.token = token;
*token = new_token;
}
Ok(())
}
pub async fn new(credentials_path: String) -> Result<TTS, gcp_auth::Error> {
#[tracing::instrument]
pub async fn new(credentials_path: String) -> Result<Self, gcp_auth::Error> {
let authenticator = gcp_auth::from_credentials_file(credentials_path.clone()).await?;
let token = authenticator
.get_token(&["https://www.googleapis.com/auth/cloud-platform"])
.await?;
Ok(TTS {
token,
Ok(Self {
token: Arc::new(RwLock::new(token)),
credentials_path,
})
}
@ -56,18 +61,25 @@ impl TTS {
/// }
/// }).await.unwrap();
/// ```
#[tracing::instrument]
pub async fn synthesize(
&mut self,
&self,
request: SynthesizeRequest,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
self.update_token().await.unwrap();
let client = reqwest::Client::new();
let token_string = {
let token = self.token.read().await;
token.as_str().to_string()
};
match client
.post("https://texttospeech.googleapis.com/v1/text:synthesize")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", self.token.as_str()),
format!("Bearer {}", token_string),
)
.body(serde_json::to_string(&request).unwrap())
.send()

View File

@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
/// ssml: Some(String::from("<speak>test</speak>"))
/// }
/// ```
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)]
pub struct SynthesisInput {
pub text: Option<String>,
pub ssml: Option<String>,

View File

@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
/// ssmlGender: String::from("neutral")
/// }
/// ```
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
#[allow(non_snake_case)]
pub struct VoiceSelectionParams {
pub languageCode: String,

View File

@ -1,3 +1,5 @@
use std::fmt::Debug;
use serenity::{
model::{
channel::Message,
@ -8,6 +10,7 @@ use serenity::{
use crate::tts::message::TTSMessage;
#[derive(Debug, Clone)]
pub struct TTSInstance {
pub before_message: Option<Message>,
pub text_channel: ChannelId,
@ -22,23 +25,24 @@ impl TTSInstance {
/// ```rust
/// instance.read(message, &ctx).await;
/// ```
#[tracing::instrument]
pub async fn read<T>(&mut self, message: T, ctx: &Context)
where
T: TTSMessage,
T: TTSMessage + Debug,
{
let path = message.synthesize(self, ctx).await;
let audio = message.synthesize(self, ctx).await;
{
let manager = songbird::get(&ctx).await.unwrap();
let call = manager.get(self.guild).unwrap();
let mut call = call.lock().await;
let input = songbird::input::ffmpeg(path)
.await
.expect("File not found.");
call.enqueue_source(input);
for audio in audio {
call.enqueue(audio.into()).await;
}
}
}
#[tracing::instrument]
pub async fn skip(&mut self, ctx: &Context) {
let manager = songbird::get(&ctx).await.unwrap();
let call = manager.get(self.guild).unwrap();

View File

@ -1,7 +1,6 @@
use std::{env, fs::File, io::Write};
use async_trait::async_trait;
use serenity::prelude::Context;
use songbird::tracks::Track;
use crate::{data::TTSClientData, tts::instance::TTSInstance};
@ -21,15 +20,16 @@ pub trait TTSMessage {
/// ```
async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String;
/// Synthesize the message and returns the path to the audio file.
/// Synthesize the message and returns the audio data.
///
/// Example:
/// ```rust
/// let path = message.synthesize(instance, ctx).await;
/// let audio = message.synthesize(instance, ctx).await;
/// ```
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String;
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track>;
}
#[derive(Debug, Clone)]
pub struct AnnounceMessage {
pub message: String,
}
@ -44,18 +44,15 @@ impl TTSMessage for AnnounceMessage {
)
}
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track> {
let text = self.parse(instance, ctx).await;
let data_read = ctx.data.read().await;
let storage = data_read
let tts = data_read
.get::<TTSClientData>()
.expect("Cannot get TTSClientStorage")
.clone();
let mut storage = storage.lock().await;
.expect("Cannot get TTSClientStorage");
let audio = storage
.0
.synthesize(SynthesizeRequest {
let audio = tts
.synthesize_gcp(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(text),
@ -74,14 +71,6 @@ impl TTSMessage for AnnounceMessage {
.await
.unwrap();
let uuid = uuid::Uuid::new_v4().to_string();
let path = env::current_dir().unwrap();
let file_path = path.join("audio").join(format!("{}.mp3", uuid));
let mut file = File::create(file_path.clone()).unwrap();
file.write(&audio).unwrap();
file_path.into_os_string().into_string().unwrap()
vec![audio.into()]
}
}

View File

@ -1,5 +1,6 @@
pub mod gcp_tts;
pub mod instance;
pub mod message;
pub mod tts;
pub mod tts_type;
pub mod voicevox;

133
src/tts/tts.rs Normal file
View File

@ -0,0 +1,133 @@
use std::sync::RwLock;
use std::{num::NonZeroUsize, sync::Arc};
use lru::LruCache;
use songbird::{driver::Bitrate, input::cached::Compressed, tracks::Track};
use tracing::info;
use super::{
gcp_tts::{
gcp_tts::GCPTTS,
structs::{
synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest,
voice_selection_params::VoiceSelectionParams,
},
},
voicevox::voicevox::VOICEVOX,
};
#[derive(Debug)]
pub struct TTS {
pub voicevox_client: VOICEVOX,
gcp_tts_client: GCPTTS,
cache: Arc<RwLock<LruCache<CacheKey, Compressed>>>,
}
#[derive(Hash, PartialEq, Eq)]
pub enum CacheKey {
Voicevox(String, i64),
GCP(SynthesisInput, VoiceSelectionParams),
}
impl TTS {
pub fn new(voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS) -> Self {
Self {
voicevox_client,
gcp_tts_client,
cache: Arc::new(RwLock::new(LruCache::new(NonZeroUsize::new(1000).unwrap()))),
}
}
#[tracing::instrument]
pub async fn synthesize_voicevox(
&self,
text: &str,
speaker: i64,
) -> Result<Track, Box<dyn std::error::Error>> {
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
let cached_audio = {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
};
if let Some(audio) = cached_audio {
info!("Cache hit for VOICEVOX TTS");
return Ok(audio.into());
}
info!("Cache miss for VOICEVOX TTS");
if self.voicevox_client.original_api_url.is_some() {
let audio = self
.voicevox_client
.synthesize_original(text.to_string(), speaker)
.await?;
tokio::spawn({
let cache = self.cache.clone();
let audio = audio.clone();
async move {
info!("Compressing stream audio");
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap();
let mut cache_guard = cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
});
Ok(audio.into())
} else {
let audio = self
.voicevox_client
.synthesize_stream(text.to_string(), speaker)
.await?;
tokio::spawn({
let cache = self.cache.clone();
let audio = audio.clone();
async move {
info!("Compressing stream audio");
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap();
let mut cache_guard = cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
});
Ok(audio.into())
}
}
#[tracing::instrument]
pub async fn synthesize_gcp(
&self,
synthesize_request: SynthesizeRequest,
) -> Result<Compressed, Box<dyn std::error::Error>> {
let cache_key = CacheKey::GCP(
synthesize_request.input.clone(),
synthesize_request.voice.clone(),
);
let cached_audio = {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
};
if let Some(audio) = cached_audio {
info!("Cache hit for GCP TTS");
return Ok(audio);
}
info!("Cache miss for GCP TTS");
let audio = self.gcp_tts_client.synthesize(synthesize_request).await?;
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
{
let mut cache_guard = self.cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
Ok(compressed)
}
}

View File

@ -2,3 +2,4 @@ pub mod accent_phrase;
pub mod audio_query;
pub mod mora;
pub mod speaker;
pub mod stream;

View File

@ -0,0 +1,13 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TTSResponse {
pub success: bool,
pub is_api_key_valid: bool,
pub speaker_name: String,
pub audio_status_url: String,
pub wav_download_url: String,
pub mp3_download_url: String,
pub mp3_streaming_url: String,
}

View File

@ -1,13 +1,17 @@
use super::structs::speaker::Speaker;
use crate::stream_input::Mp3Request;
use super::structs::{speaker::Speaker, stream::TTSResponse};
const BASE_API_URL: &str = "https://deprecatedapis.tts.quest/v2/";
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct VOICEVOX {
pub key: String,
pub key: Option<String>,
pub original_api_url: Option<String>,
}
impl VOICEVOX {
#[tracing::instrument]
pub async fn get_styles(&self) -> Vec<(String, i64)> {
let speakers = self.get_speaker_list().await;
let mut speaker_list = vec![];
@ -20,6 +24,7 @@ impl VOICEVOX {
speaker_list
}
#[tracing::instrument]
pub async fn get_speakers(&self) -> Vec<String> {
let speakers = self.get_speaker_list().await;
let mut speaker_list = vec![];
@ -30,18 +35,27 @@ impl VOICEVOX {
speaker_list
}
pub fn new(key: String) -> Self {
Self { key }
pub fn new(key: Option<String>, original_api_url: Option<String>) -> Self {
Self {
key,
original_api_url,
}
}
#[tracing::instrument]
async fn get_speaker_list(&self) -> Vec<Speaker> {
let client = reqwest::Client::new();
match client
.post(BASE_API_URL.to_string() + "voicevox/speakers/")
.query(&[("key", self.key.clone())])
.send()
.await
{
let client = if let Some(key) = &self.key {
client
.get(BASE_API_URL.to_string() + "voicevox/speakers/")
.query(&[("key", key)])
} else if let Some(original_api_url) = &self.original_api_url {
client.get(original_api_url.to_string() + "/speakers")
} else {
panic!("No API key or original API URL provided.")
};
match client.send().await {
Ok(response) => response.json().await.unwrap(),
Err(err) => {
panic!("Cannot get speaker list. {err:?}")
@ -49,6 +63,7 @@ impl VOICEVOX {
}
}
#[tracing::instrument]
pub async fn synthesize(
&self,
text: String,
@ -60,7 +75,7 @@ impl VOICEVOX {
.query(&[
("speaker", speaker.to_string()),
("text", text),
("key", self.key.clone()),
("key", self.key.clone().unwrap()),
])
.send()
.await
@ -72,4 +87,47 @@ impl VOICEVOX {
Err(err) => Err(Box::new(err)),
}
}
#[tracing::instrument]
pub async fn synthesize_original(
&self,
text: String,
speaker: i64,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let client =
voicevox_client::Client::new(self.original_api_url.as_ref().unwrap().clone(), None);
let audio_query = client
.create_audio_query(&text, speaker as i32, None)
.await?;
println!("{:?}", audio_query.audio_query);
let audio = audio_query.synthesis(speaker as i32, true).await?;
Ok(audio.into())
}
#[tracing::instrument]
pub async fn synthesize_stream(
&self,
text: String,
speaker: i64,
) -> Result<Mp3Request, Box<dyn std::error::Error>> {
let client = reqwest::Client::new();
match client
.post("https://api.tts.quest/v3/voicevox/synthesis")
.query(&[
("speaker", speaker.to_string()),
("text", text),
("key", self.key.clone().unwrap()),
])
.send()
.await
{
Ok(response) => {
let body = response.text().await.unwrap();
let response: TTSResponse = serde_json::from_str(&body).unwrap();
Ok(Mp3Request::new(reqwest::Client::new(), response.mp3_streaming_url).into())
}
Err(err) => Err(Box::new(err)),
}
}
}