95 Commits

Author SHA1 Message Date
248c88a7a2 update t v1.11.2 2025-05-25 22:27:42 +09:00
65db668e2a update Docker things 2025-05-25 00:10:08 +09:00
879644f30c fix gcp version 2025-05-25 00:02:44 +09:00
8ce6dbf57a Merge branch 'stable' of github.com:mii443/ncb-tts-r2 into stable 2025-05-24 23:56:06 +09:00
4d658b6671 WIP: ローカルの変更を保存 2025-05-24 23:49:30 +09:00
e606b29f81 add server configs 2025-05-24 23:44:59 +09:00
b452d4609a Merge branch 'master' into stable 2025-05-24 21:54:11 +09:00
599b7fb0d1 v1.10.1 2025-04-27 17:17:44 +09:00
4c76fd037a v1.9.2 2025-04-27 17:17:06 +09:00
83fde399b2 support original voicevox api 2025-04-27 17:16:17 +09:00
91044c7c25 fix user count 2025-04-24 15:32:43 +00:00
a13994c37e fix user count 2025-04-12 18:06:40 +09:00
97ae9dd9e0 optimize database lock 2025-04-11 18:07:46 +09:00
f7e08b4e2e add mp3 streaming for voicevox 2025-04-11 15:10:01 +09:00
257c8511e3 change type 2025-04-11 14:03:13 +09:00
771711e3bf update 2025-04-11 10:43:21 +09:00
40e6194942 supress warn 2025-04-10 15:06:37 +09:00
9949d501b5 fix warn, optimize 2025-04-10 14:57:09 +09:00
24de817d6f add otel http url config 2025-04-10 14:17:30 +09:00
8a1fa22074 add tracing opentelemetry 2025-04-10 13:40:29 +09:00
24c609aaf2 update 2025-04-05 17:53:51 +09:00
51f008cfdc extend cache size 2025-04-05 17:53:16 +09:00
4c176935e3 reduce Mutex lock 2025-04-04 22:14:23 +09:00
55ea223f69 gcp synthesize without mutable 2025-04-04 21:59:53 +09:00
696954395b support vec audio 2025-04-04 21:46:32 +09:00
82e3c55fd5 fix warnings 2025-04-04 21:11:46 +09:00
af83f6b6e0 change log level, remove unused import 2025-04-04 21:07:48 +09:00
c68e533133 implement compressed local audio cache 2025-04-04 20:45:20 +09:00
77b4c3e04d update workflow 2025-04-03 03:05:59 +09:00
8db4d65042 change versions 2025-04-03 03:04:35 +09:00
e12dbd7375 fix voice move state bug 2025-04-03 03:03:06 +09:00
cff30e7471 fix reading name 2025-04-03 02:58:02 +09:00
bf4b160af7 implement tts without disk I/O 2025-04-03 02:43:26 +09:00
b4de0f1ad6 fix warnings 2025-04-03 02:38:40 +09:00
1975b2e9cd fix instance.rs bug 2025-04-03 02:32:59 +09:00
5ee3c9b328 fix errors 2025-04-03 02:04:17 +09:00
df46152a12 update crates, fix event_handler, main 2025-03-30 23:34:25 +09:00
1830029231 add ARM build 2025-02-21 16:17:16 +09:00
e4dbedcbe7 Update docker-compose.yml 2025-02-21 16:15:23 +09:00
f11718cc8b Update docker-compose.yml 2025-02-21 16:14:47 +09:00
8ef5530524 update serenity to poise 2024-11-19 10:15:52 +00:00
89f66aefb0 update dependencies 2024-11-19 10:00:15 +00:00
mii
68e96ef784 change base image 2024-01-16 15:01:45 +00:00
mii
883a54f70a update rust version 2024-01-16 14:24:46 +00:00
mii
60255d7582 fix action row 2024-01-16 14:13:19 +00:00
mii
68d49772e7 update version to 1.7.0 2023-10-03 15:15:58 +00:00
mii
0fbe068f1e fix action row 2023-10-03 15:13:55 +00:00
mii
c39800da18 auto load voiecvox speaker list 2023-10-03 14:54:34 +00:00
mii
5aa5e09bb7 downgrade base image 2023-09-30 12:33:54 +00:00
mii
ec47a6f521 fix github workflow 2023-09-30 11:15:09 +00:00
mii
bb4d0a0504 fix github workflow 2023-09-30 11:05:05 +00:00
mii
4630883b28 Update version 2023-09-30 10:55:57 +00:00
mii
2249e8c213 Add autostart feature 2023-09-30 10:53:57 +00:00
mii
f9ebd8a430 wip autostart settings 2023-09-28 10:53:41 +00:00
mii
8a9817a449 update version to 1.5.1 2023-04-13 04:44:09 +00:00
mii
4c5b9cb345 update songbird 2023-04-13 04:43:20 +00:00
mii
52f86a6c16 add ignore 2023-04-06 01:46:40 +00:00
mii
b62e81dd66 add nurserobo_type_t 2023-01-15 07:50:59 +00:00
mii
60770f65b6 add skip command 2022-12-08 08:25:00 +00:00
mii
f93701a591 add dictionary, server config 2022-12-07 08:21:43 +00:00
mii
2a40e9ee16 change regex 2022-12-07 03:36:43 +00:00
mii
2f2b82857f update docker build 2022-12-07 03:24:49 +00:00
mii
8b2574e90b fix code block regex 2022-12-06 15:55:17 +00:00
mii
fb654229b0 fix command register 2022-12-06 15:35:19 +00:00
mii
7124930a9d fix code block regex 2022-12-06 14:46:03 +00:00
mii
0bd051e48c add feedback url 2022-12-06 14:39:14 +00:00
mii
ccf2c63224 change setup message 2022-12-06 14:29:50 +00:00
mii
c2a84be3a4 TTS Channel mode feature 2022-12-06 14:04:14 +00:00
mii
d525636060 add kubernetes manifest, add code block regex 2022-12-06 13:01:32 +00:00
mii
68a73a4dae support auto archive thread 2022-11-19 02:51:09 +00:00
mii
f7b7071a09 update docker-compose.yml image version 2022-11-18 09:26:19 +00:00
mii
b7a4da7f3e thread tts 2022-11-18 09:08:27 +00:00
mii
708b6fc429 fix getting guild id 2022-11-03 05:13:05 +00:00
mii
f9fd0686a7 add sample docker-compose.yml 2022-11-02 15:07:36 +00:00
mii
99d8ef9bef voice state bugfix 2022-11-02 15:04:18 +00:00
mii
af01576a99 update serenity and songbird 2022-11-02 14:38:56 +00:00
mii
ddab474d67 support env variable 2022-10-31 13:02:20 +00:00
mii
d10bfcc333 fix database, borrow 2022-10-31 12:46:59 +00:00
mii
1470612d8b clippy 2022-10-31 12:40:55 +00:00
mii
065717839b reading attachment files, fix database 2022-10-31 21:04:11 +09:00
mii
6dafc66878 . 2022-08-14 12:02:52 +09:00
mii
1789bd7c4e fix config change message, auto disconnect 2022-08-13 16:48:47 +09:00
mii
0b93c23e91 auto leave 2022-08-12 23:42:23 +09:00
mii
7fe65bc397 add validator 2022-08-12 23:07:20 +09:00
mii
51c39036c6 delete debug register 2022-08-12 22:43:21 +09:00
mii
c52429bce0 add config command 2022-08-12 22:39:12 +09:00
mii
5ca5325fbd refactoring 2022-08-12 20:25:39 +09:00
mii
4161acbd45 fix warn 2022-08-12 20:05:34 +09:00
mii
47b11262e2 refactoring 2022-08-12 19:05:02 +09:00
mii
b36bee8be8 fix database connection bug 2022-08-12 17:57:28 +09:00
mii
6aec4e4ea7 fix dockerfile 2022-08-12 00:50:12 +09:00
mii
2bf2fe05f1 dockerfile 2022-08-12 00:40:14 +09:00
mii
f99d37ea56 fix slash commands register 2022-08-12 00:25:52 +09:00
mii
3652079bab fix audio path 2022-08-12 00:00:03 +09:00
mii
8a4de65a8a fix audio path 2022-08-11 23:59:52 +09:00
53 changed files with 2645 additions and 565 deletions

2
.dockerignore Normal file
View File

@ -0,0 +1,2 @@
target
audio

View File

@ -8,22 +8,29 @@ jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: docker/metadata-action@v3
- uses: actions/checkout@v4
name: Checkout
- uses: docker/metadata-action@v4
id: meta
with:
images: ghcr.io/morioka22/ncb-tts-r2
images: ghcr.io/mii443/ncb-tts-r2
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
- uses: docker/login-action@v1
- uses: docker/login-action@v2
with:
registry: ghcr.io
username: morioka22
username: mii443
password: ${{ secrets.GITHUB_TOKEN }}
- uses: docker/build-push-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- uses: docker/build-push-action@v4
with:
context: .
push: true
platforms: linux/amd64,linux/arm64
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max

3
.gitignore vendored
View File

@ -3,4 +3,5 @@ Cargo.lock
config.toml
credentials.json
/audio
*.mp3
*.mp3
*.swp

View File

@ -1,6 +1,6 @@
[package]
name = "ncb-tts-r2"
version = "0.1.0"
version = "1.11.2"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@ -8,26 +8,56 @@ edition = "2021"
[dependencies]
serde_json = "1.0"
serde = "1.0"
toml = "0.5"
toml = "0.8.19"
gcp_auth = "0.5.0"
reqwest = { version = "0.11", features = ["json"] }
base64 = "0.13"
reqwest = { version = "0.12.9", features = ["json"] }
base64 = "0.22.1"
async-trait = "0.1.57"
redis = "*"
redis = "0.29.2"
regex = "1"
tracing-subscriber = "0.3.19"
lru = "0.13.0"
tracing = "0.1.41"
opentelemetry_sdk = { version = "0.29.0", features = ["trace"] }
opentelemetry = "0.29.1"
opentelemetry-semantic-conventions = "0.29.0"
opentelemetry-otlp = { version = "0.29.0", features = ["grpc-tonic"] }
opentelemetry-stdout = "0.29.0"
tracing-opentelemetry = "0.30.0"
symphonia-core = "0.5.4"
tokio-util = { version = "0.7.14", features = ["compat"] }
futures = "0.3.31"
bytes = "1.10.1"
voicevox-client = { git = "https://github.com/mii443/rust" }
[dependencies.uuid]
version = "0.8"
version = "1.11.0"
features = ["serde", "v4"]
[dependencies.songbird]
version = "0.2.0"
version = "0.5"
features = ["builtin-queue"]
[dependencies.serenity]
version = "0.10.9"
features = ["builder", "cache", "client", "gateway", "model", "utils", "unstable_discord_api", "collector", "rustls_backend", "framework", "voice"]
[dependencies.symphonia]
version = "0.5"
features = ["mp3"]
[dependencies.serenity]
version = "0.12"
features = [
"builder",
"cache",
"client",
"gateway",
"model",
"utils",
"unstable_discord_api",
"collector",
"rustls_backend",
"framework",
"voice",
]
[dependencies.tokio]
version = "1.0"
features = ["macros", "rt-multi-thread"]
features = ["macros", "rt-multi-thread"]

View File

@ -1,16 +1,45 @@
FROM ubuntu:22.04
RUN apt-get update \
&& apt-get install -y ffmpeg libssl-dev pkg-config libopus-dev wget curl gcc \
&& apt-get -y clean \
&& rm -rf /var/lib/apt/lists/*
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain stable
ENV PATH $PATH:/root/.cargo/bin/
RUN rustup install stable
WORKDIR /usr/src/ncb-tts-r2
COPY Cargo.toml .
COPY src src
RUN cargo build --release \
&& cp /usr/src/ncb-tts-r2/target/release/ncb-tts-r2 /usr/bin/ncb-tts-r2 \
&& mkdir -p /ncb-tts-r2/audio
FROM lukemathwalker/cargo-chef:latest-rust-1.82 AS chef
WORKDIR /app
FROM chef AS planner
COPY . .
RUN cargo chef prepare --recipe-path recipe.json
FROM chef AS builder
COPY --from=planner /app/recipe.json recipe.json
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ffmpeg \
libssl-dev \
pkg-config \
libopus-dev \
gcc && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN cargo chef cook --release --recipe-path recipe.json
COPY . .
RUN cargo build --release
FROM ubuntu:22.04 AS runtime
WORKDIR /ncb-tts-r2
CMD ["ncb-tts-r2"]
# 非rootユーザーの作成
RUN groupadd -r appgroup && useradd -r -g appgroup appuser
RUN apt-get update && \
apt-get install -y --no-install-recommends \
openssl \
ca-certificates \
ffmpeg \
libssl-dev \
libopus-dev && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
COPY --from=builder /app/target/release/ncb-tts-r2 /usr/local/bin/ncb-tts-r2
RUN chmod +x /usr/local/bin/ncb-tts-r2
# 非rootユーザーに切り替え
USER appuser
ENTRYPOINT ["/usr/local/bin/ncb-tts-r2"]

14
docker-compose.yml Normal file
View File

@ -0,0 +1,14 @@
version: '3'
services:
ncb-tts-r2:
container_name: ncb-tts-r2
image: ghcr.io/mii443/ncb-tts-r2:1.11.2
environment:
- NCB_TOKEN=YOUR_BOT_TOKEN
- NCB_APP_ID=YOUR_BOT_ID
- NCB_PREFIX=BOT_PREFIX
- NCB_REDIS_URL=redis://<REDIS_IP>/
- NCB_VOICEVOX_KEY=VOICEVOX_KEY
volumes:
- ./credentials.json:/ncb-tts-r2/credentials.json:ro

56
manifest/ncb-tts.yaml Normal file
View File

@ -0,0 +1,56 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: ncb-tts-deployment
spec:
replicas: 1
selector:
matchLabels:
app: ncb-tts
template:
metadata:
labels:
app: ncb-tts
spec:
containers:
- name: redis
image: redis:7.0.4-alpine
ports:
- containerPort: 6379
name: ncb-redis
volumeMounts:
- name: ncb-redis-pvc
mountPath: /data
- name: tts
image: ghcr.io/mii443/ncb-tts-r2
volumeMounts:
- name: gcp-credentials
mountPath: /ncb-tts-r2/credentials.json
subPath: credentials.json
env:
- name: NCB_REDIS_URL
value: "redis://localhost:6379/"
- name: NCB_PREFIX
value: "t2!"
- name: NCB_TOKEN
valueFrom:
secretKeyRef:
name: ncb-secret
key: BOT_TOKEN
- name: NCB_VOICEVOX_KEY
valueFrom:
secretKeyRef:
name: ncb-secret
key: VOICEVOX_KEY
- name: NCB_APP_ID
valueFrom:
secretKeyRef:
name: ncb-secret
key: APP_ID
volumes:
- name: ncb-redis-pvc
persistentVolumeClaim:
claimName: ncb-redis-pvc
- name: gcp-credentials
secret:
secretName: gcp-credentials

12
manifest/pvc.yaml Normal file
View File

@ -0,0 +1,12 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: ncb-redis-pvc
labels:
app: ncb-redis
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 3Gi

98
src/commands/config.rs Normal file
View File

@ -0,0 +1,98 @@
use serenity::{
all::{
ButtonStyle, CommandInteraction, CreateActionRow, CreateButton, CreateInteractionResponse,
CreateInteractionResponseMessage, CreateSelectMenu, CreateSelectMenuKind,
CreateSelectMenuOption,
},
prelude::Context,
};
use crate::{
data::{DatabaseClientData, TTSClientData},
tts::tts_type::TTSType,
};
#[tracing::instrument]
pub async fn config_command(
ctx: &Context,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
let data_read = ctx.data.read().await;
let config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_user_config_or_default(command.user.id.get())
.await
.unwrap()
.unwrap()
};
let tts_client = data_read
.get::<TTSClientData>()
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_styles().await;
let voicevox_speaker = config.voicevox_speaker.unwrap_or(1);
let tts_type = config.tts_type.unwrap_or(TTSType::GCP);
let engine_select = CreateActionRow::SelectMenu(
CreateSelectMenu::new(
"TTS_CONFIG_ENGINE",
CreateSelectMenuKind::String {
options: vec![
CreateSelectMenuOption::new("Google TTS", "TTS_CONFIG_ENGINE_SELECTED_GOOGLE")
.default_selection(tts_type == TTSType::GCP),
CreateSelectMenuOption::new("VOICEVOX", "TTS_CONFIG_ENGINE_SELECTED_VOICEVOX")
.default_selection(tts_type == TTSType::VOICEVOX),
],
},
)
.placeholder("読み上げAPIを選択"),
);
let server_button = CreateActionRow::Buttons(vec![CreateButton::new("TTS_CONFIG_SERVER")
.label("サーバー設定")
.style(ButtonStyle::Primary)]);
let mut components = vec![engine_select, server_button];
for (index, speaker_chunk) in voicevox_speakers[0..24].chunks(25).enumerate() {
let mut options = Vec::new();
for (name, id) in speaker_chunk {
options.push(
CreateSelectMenuOption::new(
name,
format!("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_{}", id),
)
.default_selection(*id == voicevox_speaker),
);
}
components.push(CreateActionRow::SelectMenu(
CreateSelectMenu::new(
format!("TTS_CONFIG_VOICEVOX_SPEAKER_{}", index),
CreateSelectMenuKind::String { options },
)
.placeholder("VOICEVOX Speakerを指定"),
));
}
command
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("読み上げ設定")
.components(components)
.ephemeral(true),
),
)
.await?;
Ok(())
}

4
src/commands/mod.rs Normal file
View File

@ -0,0 +1,4 @@
pub mod config;
pub mod setup;
pub mod skip;
pub mod stop;

165
src/commands/setup.rs Normal file
View File

@ -0,0 +1,165 @@
use serenity::{
all::{
AutoArchiveDuration, CommandInteraction, CreateEmbed, CreateInteractionResponse, CreateInteractionResponseMessage, CreateMessage, CreateThread
},
model::prelude::UserId,
prelude::Context,
};
use tracing::info;
use crate::{
data::{TTSClientData, TTSData},
tts::instance::TTSInstance,
};
#[tracing::instrument]
pub async fn setup_command(
ctx: &Context,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
info!("Received event");
if command.guild_id.is_none() {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("このコマンドはサーバーでのみ使用可能です.")
.ephemeral(true)
))
.await?;
return Ok(());
}
info!("Fetching guild cache");
let guild_id = command.guild_id.unwrap();
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
let channel_id = guild
.voice_states
.get(&UserId::from(command.user.id.get()))
.and_then(|state| state.channel_id);
if channel_id.is_none() {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("ボイスチャンネルに参加してから実行してください.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let channel_id = channel_id.unwrap();
let manager = songbird::get(ctx)
.await
.expect("Cannot get songbird client.")
.clone();
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.clone()
};
let text_channel_id = {
let mut storage = storage_lock.write().await;
if storage.contains_key(&guild.id) {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("すでにセットアップしています.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let text_channel_id = {
if let Some(mode) = command.data.options.get(0) {
match &mode.value {
serenity::all::CommandDataOptionValue::String(value) => {
match value.as_str() {
"TEXT_CHANNEL" => command.channel_id,
"NEW_THREAD" => {
command
.channel_id
.create_thread(&ctx.http, CreateThread::new("TTS").auto_archive_duration(AutoArchiveDuration::OneHour).kind(serenity::all::ChannelType::PublicThread))
.await
.unwrap()
.id
}
"VOICE_CHANNEL" => channel_id,
_ => channel_id,
}
},
_ => channel_id,
}
} else {
channel_id
}
};
storage.insert(
guild.id,
TTSInstance {
before_message: None,
guild: guild.id,
text_channel: text_channel_id,
voice_channel: channel_id,
},
);
text_channel_id
};
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content(format!(
"TTS Channel: <#{}>{}",
text_channel_id,
if text_channel_id == channel_id {
"\nボイスチャンネルを右クリックし `チャットを開く` を押して開くことが出来ます。"
} else {
""
}
))
))
.await?;
let _handler = manager.join(guild.id, channel_id).await;
let data = ctx
.data
.read()
.await;
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
text_channel_id
.send_message(&ctx.http, CreateMessage::new()
.embed(
CreateEmbed::new()
.title("読み上げ (Serenity)")
.field(
"VOICEVOXクレジット",
format!("```\n{}\n```", voicevox_speakers.join("\n")),
false,
)
.field("設定コマンド", "`/config`", false)
.field("フィードバック", "https://feedback.mii.codes/", false)
))
.await?;
Ok(())
}

81
src/commands/skip.rs Normal file
View File

@ -0,0 +1,81 @@
use serenity::{
all::{
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage
},
model::prelude::UserId,
prelude::Context,
};
use crate::data::TTSData;
pub async fn skip_command(
ctx: &Context,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
if command.guild_id.is_none() {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("このコマンドはサーバーでのみ使用可能です.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let guild_id = command.guild_id.unwrap();
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
let channel_id = guild
.voice_states
.get(&UserId::from(command.user.id.get()))
.and_then(|state| state.channel_id);
if channel_id.is_none() {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("ボイスチャンネルに参加してから実行してください.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.clone()
};
{
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild.id) {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("読み上げしていません")
.ephemeral(true)
))
.await?;
return Ok(());
}
storage.get_mut(&guild.id).unwrap().skip(ctx).await;
}
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("スキップしました")
))
.await?;
Ok(())
}

96
src/commands/stop.rs Normal file
View File

@ -0,0 +1,96 @@
use serenity::{
all::{
CommandInteraction, CreateInteractionResponse, CreateInteractionResponseMessage, EditThread
},
model::prelude::UserId,
prelude::Context
};
use crate::data::TTSData;
pub async fn stop_command(
ctx: &Context,
command: &CommandInteraction,
) -> Result<(), Box<dyn std::error::Error>> {
if command.guild_id.is_none() {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("このコマンドはサーバーでのみ使用可能です.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let guild_id = command.guild_id.unwrap();
let guild = guild_id.to_guild_cached(&ctx.cache).unwrap().clone();
let channel_id = guild
.voice_states
.get(&UserId::from(command.user.id.get()))
.and_then(|state| state.channel_id);
if channel_id.is_none() {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("ボイスチャンネルに参加してから実行してください.")
.ephemeral(true)
))
.await?;
return Ok(());
}
let manager = songbird::get(ctx)
.await
.expect("Cannot get songbird client.")
.clone();
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.clone()
};
let text_channel_id = {
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild.id) {
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("すでに停止しています")
.ephemeral(true)
))
.await?;
return Ok(());
}
let text_channel_id = storage.get(&guild.id).unwrap().text_channel;
storage.remove(&guild.id);
text_channel_id
};
let _handler = manager.remove(guild.id).await;
command
.create_response(&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content("停止しました")
))
.await?;
let _ = text_channel_id
.edit_thread(&ctx.http, EditThread::new().archived(true))
.await;
Ok(())
}

View File

@ -6,5 +6,7 @@ pub struct Config {
pub token: String,
pub application_id: u64,
pub redis_url: String,
pub voicevox_key: String
}
pub voicevox_key: Option<String>,
pub voicevox_original_api_url: Option<String>,
pub otel_http_url: Option<String>,
}

View File

@ -1,8 +1,11 @@
use crate::{tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX}, database::database::Database};
use serenity::{prelude::{TypeMapKey, RwLock}, model::id::GuildId, futures::lock::Mutex};
use crate::{database::database::Database, tts::tts::TTS};
use serenity::{
model::id::GuildId,
prelude::{RwLock, TypeMapKey},
};
use crate::tts::instance::TTSInstance;
use std::{sync::Arc, collections::HashMap};
use std::{collections::HashMap, sync::Arc};
/// TTSInstance data
pub struct TTSData;
@ -15,12 +18,12 @@ impl TypeMapKey for TTSData {
pub struct TTSClientData;
impl TypeMapKey for TTSClientData {
type Value = Arc<Mutex<(TTS, VOICEVOX)>>;
type Value = Arc<TTS>;
}
/// Database client data
pub struct DatabaseClientData;
impl TypeMapKey for DatabaseClientData {
type Value = Arc<Mutex<Database>>;
type Value = Arc<Database>;
}

View File

@ -1,56 +1,148 @@
use crate::tts::{gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType};
use std::fmt::Debug;
use super::user_config::UserConfig;
use crate::tts::{
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType,
};
use super::{dictionary::Dictionary, server_config::ServerConfig, user_config::UserConfig};
use redis::Commands;
#[derive(Debug, Clone)]
pub struct Database {
pub connection: redis::Connection
pub client: redis::Client,
}
impl Database {
pub fn new(connection: redis::Connection) -> Self {
Self { connection }
pub fn new(client: redis::Client) -> Self {
Self { client }
}
pub async fn get_user_config(&mut self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
let config: String = self.connection.get(format!("discord_user:{}", user_id)).unwrap_or_default();
fn server_key(server_id: u64) -> String {
format!("discord_server:{}", server_id)
}
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None)
fn user_key(user_id: u64) -> String {
format!("discord_user:{}", user_id)
}
#[tracing::instrument]
fn get_config<T: serde::de::DeserializeOwned>(
&self,
key: &str,
) -> redis::RedisResult<Option<T>> {
match self.client.get_connection() {
Ok(mut connection) => {
let config: String = connection.get(key).unwrap_or_default();
if config.is_empty() {
return Ok(None);
}
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None),
}
}
Err(e) => Err(e),
}
}
pub async fn set_user_config(&mut self, user_id: u64, config: UserConfig) -> redis::RedisResult<()> {
let config = serde_json::to_string(&config).unwrap();
self.connection.set::<String, String, ()>(format!("discord_user:{}", user_id), config).unwrap();
Ok(())
#[tracing::instrument]
fn set_config<T: serde::Serialize + Debug>(
&self,
key: &str,
config: &T,
) -> redis::RedisResult<()> {
match self.client.get_connection() {
Ok(mut connection) => {
let config_str = serde_json::to_string(config).unwrap();
connection.set::<_, _, ()>(key, config_str)
}
Err(e) => Err(e),
}
}
pub async fn set_default_user_config(&mut self, user_id: u64) -> redis::RedisResult<()> {
#[tracing::instrument]
pub async fn get_server_config(
&self,
server_id: u64,
) -> redis::RedisResult<Option<ServerConfig>> {
self.get_config(&Self::server_key(server_id))
}
#[tracing::instrument]
pub async fn get_user_config(&self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
self.get_config(&Self::user_key(user_id))
}
#[tracing::instrument]
pub async fn set_server_config(
&self,
server_id: u64,
config: ServerConfig,
) -> redis::RedisResult<()> {
self.set_config(&Self::server_key(server_id), &config)
}
#[tracing::instrument]
pub async fn set_user_config(
&self,
user_id: u64,
config: UserConfig,
) -> redis::RedisResult<()> {
self.set_config(&Self::user_key(user_id), &config)
}
#[tracing::instrument]
pub async fn set_default_server_config(&self, server_id: u64) -> redis::RedisResult<()> {
let config = ServerConfig {
dictionary: Dictionary::new(),
autostart_channel_id: None,
voice_state_announce: Some(true),
read_username: Some(true),
};
self.set_server_config(server_id, config).await
}
#[tracing::instrument]
pub async fn set_default_user_config(&self, user_id: u64) -> redis::RedisResult<()> {
let voice_selection = VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral")
ssmlGender: String::from("neutral"),
};
let voice_type = TTSType::GCP;
let config = UserConfig {
tts_type: Some(voice_type),
tts_type: Some(TTSType::GCP),
gcp_tts_voice: Some(voice_selection),
voicevox_speaker: Some(1)
voicevox_speaker: Some(1),
};
self.connection.set(format!("discord_user:{}", user_id), serde_json::to_string(&config).unwrap())?;
Ok(())
self.set_user_config(user_id, config).await
}
pub async fn get_user_config_or_default(&mut self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
let config = self.get_user_config(user_id).await?;
match config {
Some(_) => Ok(config),
#[tracing::instrument]
pub async fn get_server_config_or_default(
&self,
server_id: u64,
) -> redis::RedisResult<Option<ServerConfig>> {
match self.get_server_config(server_id).await? {
Some(config) => Ok(Some(config)),
None => {
self.set_default_server_config(server_id).await?;
self.get_server_config(server_id).await
}
}
}
#[tracing::instrument]
pub async fn get_user_config_or_default(
&self,
user_id: u64,
) -> redis::RedisResult<Option<UserConfig>> {
match self.get_user_config(user_id).await? {
Some(config) => Ok(Some(config)),
None => {
self.set_default_user_config(user_id).await?;
self.get_user_config(user_id).await

View File

@ -0,0 +1,34 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Rule {
pub id: String,
pub is_regex: bool,
pub rule: String,
pub to: String,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Dictionary {
pub rules: Vec<Rule>,
}
impl Dictionary {
pub fn new() -> Self {
let rules = vec![
Rule {
id: String::from("url"),
is_regex: true,
rule: String::from(r"(http://|https://){1}[\w\.\-/:\#\?=\&;%\~\+]+"),
to: String::from("URL"),
},
Rule {
id: String::from("code"),
is_regex: true,
rule: String::from(r"```(.|\n)*```"),
to: String::from("code"),
},
];
Self { rules }
}
}

View File

@ -1,2 +1,4 @@
pub mod database;
pub mod dictionary;
pub mod server_config;
pub mod user_config;
pub mod database;

View File

@ -0,0 +1,15 @@
use super::dictionary::Dictionary;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct DictionaryOnlyServerConfig {
pub dictionary: Dictionary,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ServerConfig {
pub dictionary: Dictionary,
pub autostart_channel_id: Option<u64>,
pub voice_state_announce: Option<bool>,
pub read_username: Option<bool>,
}

View File

@ -1,10 +1,12 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use crate::tts::{gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType};
use crate::tts::{
gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType,
};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UserConfig {
pub tts_type: Option<TTSType>,
pub gcp_tts_voice: Option<VoiceSelectionParams>,
pub voicevox_speaker: Option<i64>
pub voicevox_speaker: Option<i64>,
}

View File

@ -1,299 +1,693 @@
use serenity::{client::{EventHandler, Context}, async_trait, model::{gateway::Ready, interactions::{Interaction, application_command::ApplicationCommandInteraction, InteractionApplicationCommandCallbackDataFlags}, id::{GuildId, UserId}, channel::Message, prelude::Member, voice::VoiceState}, framework::standard::macros::group};
use crate::{data::TTSData, tts::{instance::TTSInstance, message::AnnounceMessage}, implement::member_name::ReadName};
#[group]
struct Test;
use crate::{
commands::{
config::config_command, setup::setup_command, skip::skip_command, stop::stop_command,
},
data::DatabaseClientData,
database::dictionary::Rule,
events,
tts::tts_type::TTSType,
};
use serenity::{
all::{
ActionRowComponent, ButtonStyle, ComponentInteractionDataKind, CreateActionRow,
CreateButton, CreateEmbed, CreateInputText, CreateInteractionResponse,
CreateInteractionResponseMessage, CreateModal, CreateSelectMenu, CreateSelectMenuKind,
CreateSelectMenuOption, InputTextStyle,
},
async_trait,
client::{Context, EventHandler},
model::{
application::Interaction, channel::Message, gateway::Ready, prelude::ChannelType,
voice::VoiceState,
},
};
#[derive(Clone, Debug)]
pub struct Handler;
async fn stop_command(ctx: &Context, command: &ApplicationCommandInteraction) -> Result<(), Box<dyn std::error::Error>> {
if let None = command.guild_id {
command.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("このコマンドはサーバーでのみ使用可能です.").flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
}).await?;
return Ok(());
}
let guild = command.guild_id.unwrap().to_guild_cached(&ctx.cache).await;
if let None = guild {
command.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("ギルドキャッシュを取得できませんでした.").flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
}).await?;
return Ok(());
}
let guild = guild.unwrap();
let channel_id = guild
.voice_states
.get(&UserId(command.user.id.0))
.and_then(|state| state.channel_id);
if let None = channel_id {
command.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("ボイスチャンネルに参加してから実行してください.").flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
}).await?;
return Ok(());
}
let channel_id = channel_id.unwrap();
let manager = songbird::get(ctx).await.expect("Cannot get songbird client.").clone();
let storage_lock = {
let data_read = ctx.data.read().await;
data_read.get::<TTSData>().expect("Cannot get TTSStorage").clone()
};
{
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild.id) {
command.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("すでに停止しています").flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
}).await?;
return Ok(());
}
storage.remove(&guild.id);
}
let _handler = manager.leave(guild.id.0).await;
command.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("停止しました").flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
}).await?;
Ok(())
}
async fn setup_command(ctx: &Context, command: &ApplicationCommandInteraction) -> Result<(), Box<dyn std::error::Error>> {
if let None = command.guild_id {
command.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("このコマンドはサーバーでのみ使用可能です.").flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
}).await?;
return Ok(());
}
let guild = command.guild_id.unwrap().to_guild_cached(&ctx.cache).await;
if let None = guild {
command.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("ギルドキャッシュを取得できませんでした.").flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
}).await?;
return Ok(());
}
let guild = guild.unwrap();
let channel_id = guild
.voice_states
.get(&UserId(command.user.id.0))
.and_then(|state| state.channel_id);
if let None = channel_id {
command.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("ボイスチャンネルに参加してから実行してください.").flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
}).await?;
return Ok(());
}
let channel_id = channel_id.unwrap();
let manager = songbird::get(ctx).await.expect("Cannot get songbird client.").clone();
let storage_lock = {
let data_read = ctx.data.read().await;
data_read.get::<TTSData>().expect("Cannot get TTSStorage").clone()
};
{
let mut storage = storage_lock.write().await;
if storage.contains_key(&guild.id) {
command.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("すでにセットアップしています.").flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
}).await?;
return Ok(());
}
storage.insert(guild.id, TTSInstance {
before_message: None,
guild: guild.id,
text_channel: command.channel_id,
voice_channel: channel_id
});
}
let _handler = manager.join(guild.id.0, channel_id.0).await;
command.create_interaction_response(&ctx.http, |f| {
f.interaction_response_data(|d| {
d.content("セットアップ完了").flags(InteractionApplicationCommandCallbackDataFlags::EPHEMERAL)
})
}).await?;
Ok(())
}
#[async_trait]
impl EventHandler for Handler {
#[tracing::instrument]
async fn message(&self, ctx: Context, message: Message) {
events::message_receive::message(ctx, message).await
}
#[tracing::instrument]
async fn ready(&self, ctx: Context, ready: Ready) {
events::ready::ready(ctx, ready).await
}
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
if let Interaction::ApplicationCommand(command) = interaction {
if let Interaction::Command(command) = interaction.clone() {
let name = &*command.data.name;
match name {
"setup" => setup_command(&ctx, &command).await.unwrap(),
"stop" => stop_command(&ctx, &command).await.unwrap(),
"config" => config_command(&ctx, &command).await.unwrap(),
"skip" => skip_command(&ctx, &command).await.unwrap(),
_ => {}
}
}
if let Interaction::Modal(modal) = interaction.clone() {
if modal.data.custom_id != "TTS_CONFIG_SERVER_ADD_DICTIONARY" {
return;
}
let rows = modal.data.components.clone();
let rule_name =
if let ActionRowComponent::InputText(text) = rows[0].components[0].clone() {
text.value.unwrap()
} else {
panic!("Cannot get rule name");
};
let from = if let ActionRowComponent::InputText(text) = rows[1].components[0].clone() {
text.value.unwrap()
} else {
panic!("Cannot get from");
};
let to = if let ActionRowComponent::InputText(text) = rows[2].components[0].clone() {
text.value.unwrap()
} else {
panic!("Cannot get to");
};
let rule = Rule {
id: rule_name.clone(),
is_regex: true,
rule: from.clone(),
to: to.clone(),
};
let data_read = ctx.data.read().await;
let mut config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(modal.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
};
config.dictionary.rules.push(rule);
{
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.set_server_config(modal.guild_id.unwrap().get(), config)
.await
.unwrap();
modal
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new().content(format!(
"辞書を追加しました\n名前: {}\n変換元: {}\n変換後: {}",
rule_name, from, to
)),
),
)
.await
.unwrap();
}
}
if let Some(message_component) = interaction.message_component() {
match &*message_component.data.custom_id {
"TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE" => {
let data_read = ctx.data.read().await;
let mut config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
};
config.voice_state_announce =
Some(!config.voice_state_announce.unwrap_or(true));
let state = config.voice_state_announce.unwrap_or(true);
{
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.set_server_config(message_component.guild_id.unwrap().get(), config)
.await
.unwrap();
}
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new().content(format!(
"入退出アナウンス通知を{}へ切り替えました。",
if state { "`有効`" } else { "`無効`" }
)),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER_SET_READ_USERNAME" => {
let data_read = ctx.data.read().await;
let mut config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
};
config.read_username = Some(!config.read_username.unwrap_or(true));
let state = config.read_username.unwrap_or(true);
{
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.set_server_config(message_component.guild_id.unwrap().get(), config)
.await
.unwrap();
}
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new().content(format!(
"ユーザー名読み上げを{}へ切り替えました。",
if state { "`有効`" } else { "`無効`" }
)),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU" => {
let i = usize::from_str_radix(
&match message_component.data.kind {
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
values[0].clone()
}
_ => panic!("Cannot get index"),
},
10,
)
.unwrap();
let data_read = ctx.data.read().await;
let mut config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
};
config.dictionary.rules.remove(i);
{
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.set_server_config(message_component.guild_id.unwrap().get(), config)
.await
.unwrap();
}
message_component
.create_response(
&ctx,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("辞書を削除しました"),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON" => {
let data_read = ctx.data.read().await;
let config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
};
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("削除する辞書内容を選択してください")
.components(vec![CreateActionRow::SelectMenu(
CreateSelectMenu::new(
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_MENU",
CreateSelectMenuKind::String {
options: {
let mut options = vec![];
for (i, rule) in
config.dictionary.rules.iter().enumerate()
{
let option = CreateSelectMenuOption::new(
rule.id.clone(),
i.to_string(),
)
.description(format!(
"{} -> {}",
rule.rule.clone(),
rule.to.clone()
));
options.push(option);
}
options
},
},
)
.max_values(1)
.min_values(0),
)]),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON" => {
let config = {
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
};
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new().content("").embed(
CreateEmbed::new().title("辞書一覧").fields({
let mut fields = vec![];
for rule in config.dictionary.rules {
let field = (
rule.id.clone(),
format!("{} -> {}", rule.rule, rule.to),
true,
);
fields.push(field);
}
fields
}),
),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON" => {
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::Modal(
CreateModal::new("TTS_CONFIG_SERVER_ADD_DICTIONARY", "辞書追加")
.components({
vec![
CreateActionRow::InputText(
CreateInputText::new(
InputTextStyle::Short,
"rule_name",
"辞書名",
)
.required(true),
),
CreateActionRow::InputText(
CreateInputText::new(
InputTextStyle::Paragraph,
"from",
"変換元(正規表現)",
)
.required(true),
),
CreateActionRow::InputText(
CreateInputText::new(
InputTextStyle::Short,
"to",
"変換先",
)
.required(true),
),
]
}),
),
)
.await
.unwrap();
}
"SET_AUTOSTART_CHANNEL" => {
let autostart_channel_id = match message_component.data.kind {
ComponentInteractionDataKind::StringSelect { ref values, .. } => {
if values.len() == 0 {
None
} else {
Some(
u64::from_str_radix(
&values[0].strip_prefix("SET_AUTOSTART_CHANNEL_").unwrap(),
10,
)
.unwrap(),
)
}
}
_ => panic!("Cannot get index"),
};
{
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
let mut config = database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap();
config.autostart_channel_id = autostart_channel_id;
database
.set_server_config(message_component.guild_id.unwrap().get(), config)
.await
.unwrap();
};
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("自動参加チャンネルを設定しました。"),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL" => {
let config = {
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(message_component.guild_id.unwrap().get())
.await
.unwrap()
.unwrap()
};
let autostart_channel_id = config.autostart_channel_id.unwrap_or(0);
let channels = message_component
.guild_id
.unwrap()
.channels(&ctx.http)
.await
.unwrap();
let mut options = Vec::new();
for (id, channel) in channels {
if channel.kind != ChannelType::Voice {
continue;
}
let description = channel
.topic
.unwrap_or_else(|| String::from("No topic provided."));
let option = CreateSelectMenuOption::new(
&channel.name,
format!("SET_AUTOSTART_CHANNEL_{}", id.get()),
)
.description(description)
.default_selection(channel.id.get() == autostart_channel_id);
options.push(option);
}
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("自動参加チャンネル設定")
.components(vec![
CreateActionRow::SelectMenu(
CreateSelectMenu::new(
"SET_AUTOSTART_CHANNEL",
CreateSelectMenuKind::String { options },
)
.min_values(0)
.max_values(1),
),
CreateActionRow::Buttons(vec![CreateButton::new(
"TTS_CONFIG_SERVER_BACK",
)
.label("← サーバー設定に戻る")
.style(ButtonStyle::Secondary)]),
]),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER_BACK" => {
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("サーバー設定")
.components(vec![CreateActionRow::Buttons(vec![
CreateButton::new("TTS_CONFIG_SERVER_DICTIONARY")
.label("辞書管理")
.style(ButtonStyle::Primary),
CreateButton::new(
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL",
)
.label("自動参加チャンネル")
.style(ButtonStyle::Primary),
CreateButton::new(
"TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE",
)
.label("入退出アナウンス通知切り替え")
.style(ButtonStyle::Primary),
CreateButton::new("TTS_CONFIG_SERVER_SET_READ_USERNAME")
.label("ユーザー名読み上げ切り替え")
.style(ButtonStyle::Primary),
])]),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER" => {
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("サーバー設定")
.components(vec![CreateActionRow::Buttons(vec![
CreateButton::new("TTS_CONFIG_SERVER_DICTIONARY")
.label("辞書管理")
.style(ButtonStyle::Primary),
CreateButton::new(
"TTS_CONFIG_SERVER_SET_AUTOSTART_CHANNEL",
)
.label("自動参加チャンネル")
.style(ButtonStyle::Primary),
CreateButton::new(
"TTS_CONFIG_SERVER_SET_VOICE_STATE_ANNOUNCE",
)
.label("入退出アナウンス通知切り替え")
.style(ButtonStyle::Primary),
CreateButton::new("TTS_CONFIG_SERVER_SET_READ_USERNAME")
.label("ユーザー名読み上げ切り替え")
.style(ButtonStyle::Primary),
])]),
),
)
.await
.unwrap();
}
"TTS_CONFIG_SERVER_DICTIONARY" => {
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::UpdateMessage(
CreateInteractionResponseMessage::new()
.content("辞書管理")
.components(vec![
CreateActionRow::Buttons(vec![
CreateButton::new(
"TTS_CONFIG_SERVER_ADD_DICTIONARY_BUTTON",
)
.label("辞書を追加")
.style(ButtonStyle::Primary),
CreateButton::new(
"TTS_CONFIG_SERVER_REMOVE_DICTIONARY_BUTTON",
)
.label("辞書を削除")
.style(ButtonStyle::Danger),
CreateButton::new(
"TTS_CONFIG_SERVER_SHOW_DICTIONARY_BUTTON",
)
.label("辞書一覧")
.style(ButtonStyle::Primary),
]),
CreateActionRow::Buttons(vec![CreateButton::new(
"TTS_CONFIG_SERVER_BACK",
)
.label("← サーバー設定に戻る")
.style(ButtonStyle::Secondary)]),
]),
),
)
.await
.unwrap();
}
_ => {}
}
match message_component.data.kind {
ComponentInteractionDataKind::StringSelect { ref values, .. }
if !values.is_empty() =>
{
let res = &values[0].clone();
let data_read = ctx.data.read().await;
let mut config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_user_config_or_default(message_component.user.id.get())
.await
.unwrap()
.unwrap()
};
let mut config_changed = false;
let mut voicevox_changed = false;
match res.as_str() {
"TTS_CONFIG_ENGINE_SELECTED_GOOGLE" => {
config.tts_type = Some(TTSType::GCP);
config_changed = true;
}
"TTS_CONFIG_ENGINE_SELECTED_VOICEVOX" => {
config.tts_type = Some(TTSType::VOICEVOX);
config_changed = true;
}
_ => {
if res.starts_with("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_") {
let speaker_id = res
.strip_prefix("TTS_CONFIG_VOICEVOX_SPEAKER_SELECTED_")
.and_then(|id_str| id_str.parse::<i64>().ok())
.expect("Invalid speaker ID format");
config.voicevox_speaker = Some(speaker_id);
config_changed = true;
voicevox_changed = true;
}
}
}
if config_changed {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.set_user_config(message_component.user.id.get(), config.clone())
.await
.unwrap();
let response_content = if voicevox_changed
&& config.tts_type.unwrap_or(TTSType::GCP) == TTSType::GCP
{
"設定しました\nこの音声を使うにはAPIをGoogleからVOICEVOXに変更する必要があります。"
} else {
"設定しました"
};
message_component
.create_response(
&ctx.http,
CreateInteractionResponse::Message(
CreateInteractionResponseMessage::new()
.content(response_content)
.ephemeral(true),
),
)
.await
.unwrap();
}
}
_ => {}
}
}
}
async fn voice_state_update(
&self,
ctx: Context,
guild_id: Option<GuildId>,
old: Option<VoiceState>,
new: VoiceState,
) {
let guild_id = guild_id.unwrap();
let storage_lock = {
let data_read = ctx.data.read().await;
data_read.get::<TTSData>().expect("Cannot get TTSStorage").clone()
};
{
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild_id) {
return;
}
let instance = storage.get_mut(&guild_id).unwrap();
let mut message: Option<String> = None;
match old {
Some(old) => {
match (old.channel_id, new.channel_id) {
(Some(old_channel_id), Some(new_channel_id)) => {
if old_channel_id == new_channel_id {
return;
}
if old_channel_id != new_channel_id {
if instance.voice_channel == new_channel_id {
message = Some(format!("{} さんが通話に参加しました", new.member.unwrap().read_name()));
}
} else if old_channel_id == instance.voice_channel && new_channel_id != instance.voice_channel {
message = Some(format!("{} さんが通話から退出しました", new.member.unwrap().read_name()));
} else {
return;
}
}
(Some(old_channel_id), None) => {
if old_channel_id == instance.voice_channel {
message = Some(format!("{} さんが通話から退出しました", new.member.unwrap().read_name()));
} else {
return;
}
}
(None, Some(new_channel_id)) => {
if new_channel_id == instance.voice_channel {
message = Some(format!("{} さんが通話に参加しました", new.member.unwrap().read_name()));
} else {
return;
}
}
_ => {
return;
}
}
}
None => {
match new.channel_id {
Some(channel_id) => {
if instance.voice_channel == channel_id {
message = Some(format!("{} さんが通話に参加しました", new.member.unwrap().read_name()));
}
}
None => {
return;
}
}
}
}
if let Some(message) = message {
instance.read(AnnounceMessage {
message
}, &ctx).await;
}
}
}
async fn message(&self, ctx: Context, message: Message) {
if message.author.bot {
return;
}
let guild_id = message.guild(&ctx.cache).await;
if let None = guild_id {
return;
}
let guild_id = guild_id.unwrap().id;
let storage_lock = {
let data_read = ctx.data.read().await;
data_read.get::<TTSData>().expect("Cannot get TTSStorage").clone()
};
{
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild_id) {
return;
}
let instance = storage.get_mut(&guild_id).unwrap();
if instance.text_channel.0 != message.channel_id.0 {
return;
}
instance.read(message, &ctx).await;
}
}
async fn ready(&self, ctx: Context, ready: Ready) {
println!("{} is connected!", ready.user.name);
let guild_id = GuildId(660046656934248460);
let commands = GuildId::set_application_commands(&guild_id, &ctx.http, |commands| {
commands.create_application_command(|command| {
command.name("stop")
.description("Stop tts")
});
commands.create_application_command(|command| {
command.name("setup")
.description("Setup tts")
})
}).await;
println!("{:?}", commands);
async fn voice_state_update(&self, ctx: Context, old: Option<VoiceState>, new: VoiceState) {
events::voice_state_update::voice_state_update(ctx, old, new).await
}
}

View File

@ -0,0 +1,44 @@
use serenity::{model::prelude::Message, prelude::Context};
use crate::data::TTSData;
pub async fn message(ctx: Context, message: Message) {
if message.author.bot {
return;
}
let guild_id = message.guild(&ctx.cache);
if let None = guild_id {
return;
}
let guild_id = guild_id.unwrap().id;
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.clone()
};
{
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild_id) {
return;
}
let instance = storage.get_mut(&guild_id).unwrap();
if instance.text_channel != message.channel_id {
return;
}
if message.content.starts_with(";") {
return;
}
instance.read(message, &ctx).await;
}
}

3
src/events/mod.rs Normal file
View File

@ -0,0 +1,3 @@
pub mod message_receive;
pub mod ready;
pub mod voice_state_update;

33
src/events/ready.rs Normal file
View File

@ -0,0 +1,33 @@
use serenity::{
all::{Command, CommandOptionType, CreateCommand, CreateCommandOption},
model::prelude::Ready,
prelude::Context,
};
use tracing::info;
#[tracing::instrument]
pub async fn ready(ctx: Context, ready: Ready) {
info!("{} is connected!", ready.user.name);
Command::set_global_commands(
&ctx.http,
vec![
CreateCommand::new("stop").description("Stop tts"),
CreateCommand::new("setup")
.description("Setup tts")
.set_options(vec![CreateCommandOption::new(
CommandOptionType::String,
"mode",
"TTS channel",
)
.add_string_choice("Text Channel", "TEXT_CHANNEL")
.add_string_choice("New Thread", "NEW_THREAD")
.add_string_choice("Voice Channel", "VOICE_CHANNEL")
.required(false)]),
CreateCommand::new("config").description("Config"),
CreateCommand::new("skip").description("skip tts message"),
],
)
.await
.unwrap();
}

View File

@ -0,0 +1,152 @@
use crate::{
data::{DatabaseClientData, TTSClientData, TTSData},
implement::{
member_name::ReadName,
voice_move_state::{VoiceMoveState, VoiceMoveStateTrait},
},
tts::{instance::TTSInstance, message::AnnounceMessage},
};
use serenity::{
all::{CreateEmbed, CreateMessage, EditThread},
model::voice::VoiceState,
prelude::Context,
};
pub async fn voice_state_update(ctx: Context, old: Option<VoiceState>, new: VoiceState) {
if new.member.clone().unwrap().user.bot {
return;
}
if old.is_none() && new.guild_id.is_none() {
return;
}
let guild_id = if let Some(guild_id) = new.guild_id {
guild_id
} else {
old.clone().unwrap().guild_id.unwrap()
};
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<TTSData>()
.expect("Cannot get TTSStorage")
.clone()
};
let config = {
let data_read = ctx.data.read().await;
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(guild_id.get())
.await
.unwrap()
.unwrap()
};
if !config.voice_state_announce.unwrap_or(true) {
return;
}
{
let mut storage = storage_lock.write().await;
if !storage.contains_key(&guild_id) {
if let Some(new_channel) = new.channel_id {
if config.autostart_channel_id.unwrap_or(0) == new_channel.get() {
let manager = songbird::get(&ctx)
.await
.expect("Cannot get songbird client.")
.clone();
storage.insert(
guild_id,
TTSInstance {
before_message: None,
guild: guild_id,
text_channel: new_channel,
voice_channel: new_channel,
},
);
let _handler = manager.join(guild_id, new_channel).await;
let data = ctx.data.read().await;
let tts_client = data
.get::<TTSClientData>()
.expect("Cannot get TTSClientData");
let voicevox_speakers = tts_client.voicevox_client.get_speakers().await;
new_channel
.send_message(
&ctx.http,
CreateMessage::new().embed(
CreateEmbed::new()
.title("自動参加 読み上げSerenity")
.field(
"VOICEVOXクレジット",
format!("```\n{}\n```", voicevox_speakers.join("\n")),
false,
)
.field("設定コマンド", "`/config`", false)
.field("フィードバック", "https://feedback.mii.codes/", false),
),
)
.await
.unwrap();
}
}
return;
}
let instance = storage.get_mut(&guild_id).unwrap();
let voice_move_state = new.move_state(&old, instance.voice_channel);
let message: Option<String> = match voice_move_state {
VoiceMoveState::JOIN => Some(format!(
"{} さんが通話に参加しました",
new.member.unwrap().read_name()
)),
VoiceMoveState::LEAVE => Some(format!(
"{} さんが通話から退出しました",
new.member.unwrap().read_name()
)),
_ => None,
};
if let Some(message) = message {
instance.read(AnnounceMessage { message }, &ctx).await;
}
if voice_move_state == VoiceMoveState::LEAVE {
let mut del_flag = false;
for channel in guild_id.channels(&ctx.http).await.unwrap() {
if channel.0 == instance.voice_channel {
let members = channel.1.members(&ctx.cache).unwrap();
let user_count = members.iter().filter(|member| !member.user.bot).count();
del_flag = user_count == 0;
}
}
if del_flag {
let _ = storage
.get(&guild_id)
.unwrap()
.text_channel
.edit_thread(&ctx.http, EditThread::new().archived(true))
.await;
storage.remove(&guild_id);
let manager = songbird::get(&ctx)
.await
.expect("Cannot get songbird client.")
.clone();
manager.remove(guild_id).await.unwrap();
}
}
}
}

View File

@ -1,4 +1,7 @@
use serenity::model::guild::Member;
use serenity::model::{
guild::{Member, PartialMember},
user::User,
};
pub trait ReadName {
fn read_name(&self) -> String;
@ -6,6 +9,20 @@ pub trait ReadName {
impl ReadName for Member {
fn read_name(&self) -> String {
self.nick.clone().unwrap_or(self.user.name.clone())
self.nick.clone().unwrap_or(self.display_name().to_string())
}
}
}
impl ReadName for PartialMember {
fn read_name(&self) -> String {
self.nick
.clone()
.unwrap_or(self.user.as_ref().unwrap().display_name().to_string())
}
}
impl ReadName for User {
fn read_name(&self) -> String {
self.display_name().to_string()
}
}

View File

@ -1,92 +1,148 @@
use std::{path::Path, fs::File, io::Write};
use async_trait::async_trait;
use serenity::{prelude::Context, model::prelude::Message};
use regex::Regex;
use serenity::{model::prelude::Message, prelude::Context};
use songbird::tracks::Track;
use crate::{
data::{TTSClientData, DatabaseClientData},
data::{DatabaseClientData, TTSClientData},
implement::member_name::ReadName,
tts::{
gcp_tts::structs::{
audio_config::AudioConfig, synthesis_input::SynthesisInput,
synthesize_request::SynthesizeRequest,
},
instance::TTSInstance,
message::TTSMessage,
gcp_tts::structs::{
audio_config::AudioConfig, synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest
}, tts_type::{self, TTSType}
tts_type::TTSType,
},
};
#[async_trait]
impl TTSMessage for Message {
async fn parse(&self, instance: &mut TTSInstance, _: &Context) -> String {
let res = if let Some(before_message) = &instance.before_message {
async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
let data_read = ctx.data.read().await;
let config = {
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_server_config_or_default(instance.guild.get())
.await
.unwrap()
.unwrap()
};
let mut text = self.content.clone();
for rule in config.dictionary.rules {
if rule.is_regex {
let regex = Regex::new(&rule.rule).unwrap();
text = regex.replace_all(&text, rule.to).to_string();
} else {
text = text.replace(&rule.rule, &rule.to);
}
}
let mut res = if let Some(before_message) = &instance.before_message {
if before_message.author.id == self.author.id {
self.content.clone()
text.clone()
} else {
let member = self.member.clone();
let name = if let Some(member) = member {
member.nick.unwrap_or(self.author.name.clone())
let name = if let Some(_) = member {
let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone();
guild
.member(&ctx.http, self.author.id)
.await
.unwrap()
.read_name()
} else {
self.author.name.clone()
self.author.read_name()
};
format!("{} さんの発言<break time=\"200ms\"/>{}", name, self.content)
if config.read_username.unwrap_or(true) {
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
} else {
format!("{}", text)
}
}
} else {
let member = self.member.clone();
let name = if let Some(member) = member {
member.nick.unwrap_or(self.author.name.clone())
let name = if let Some(_) = member {
let guild = ctx.cache.guild(self.guild_id.unwrap()).unwrap().clone();
guild
.member(&ctx.http, self.author.id)
.await
.unwrap()
.read_name()
} else {
self.author.name.clone()
self.author.read_name()
};
format!("{} さんの発言<break time=\"200ms\"/>{}", name, self.content)
if config.read_username.unwrap_or(true) {
format!("{}さんの発言<break time=\"200ms\"/>{}", name, text)
} else {
format!("{}", text)
}
};
if self.attachments.len() > 0 {
res = format!(
"{}<break time=\"200ms\"/>{}個の添付ファイル",
res,
self.attachments.len()
);
}
instance.before_message = Some(self.clone());
res
}
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track> {
let text = self.parse(instance, ctx).await;
let data_read = ctx.data.read().await;
let storage = data_read.get::<TTSClientData>().expect("Cannot get GCP TTSClientStorage").clone();
let mut tts = storage.lock().await;
let config = {
let database = data_read.get::<DatabaseClientData>().expect("Cannot get DatabaseClientData").clone();
let mut database = database.lock().await;
database.get_user_config_or_default(self.author.id.0).await.unwrap().unwrap()
let database = data_read
.get::<DatabaseClientData>()
.expect("Cannot get DatabaseClientData")
.clone();
database
.get_user_config_or_default(self.author.id.get())
.await
.unwrap()
.unwrap()
};
let audio = match config.tts_type.unwrap_or(TTSType::GCP) {
TTSType::GCP => {
tts.0.synthesize(SynthesizeRequest {
let tts = data_read
.get::<TTSClientData>()
.expect("Cannot get GCP TTSClientStorage");
match config.tts_type.unwrap_or(TTSType::GCP) {
TTSType::GCP => vec![tts
.synthesize_gcp(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(format!("<speak>{}</speak>", text))
ssml: Some(format!("<speak>{}</speak>", text)),
},
voice: config.gcp_tts_voice.unwrap(),
audioConfig: AudioConfig {
audioEncoding: String::from("mp3"),
speakingRate: 1.2f32,
pitch: 1.0f32
}
}).await.unwrap()
}
pitch: 1.0f32,
},
})
.await
.unwrap()
.into()],
TTSType::VOICEVOX => {
tts.1.synthesize(text.replace("<break time=\"200ms\"/>", ""), config.voicevox_speaker.unwrap_or(1)).await.unwrap()
}
};
let uuid = uuid::Uuid::new_v4().to_string();
let root = option_env!("CARGO_MANIFEST_DIR").unwrap();
let path = Path::new(root);
let file_path = path.join("audio").join(format!("{}.mp3", uuid));
let mut file = File::create(file_path.clone()).unwrap();
file.write(&audio).unwrap();
file_path.into_os_string().into_string().unwrap()
TTSType::VOICEVOX => vec![tts
.synthesize_voicevox(
&text.replace("<break time=\"200ms\"/>", ""),
config.voicevox_speaker.unwrap_or(1),
)
.await
.unwrap()
.into()],
}
}
}

View File

@ -1,2 +1,3 @@
pub mod member_name;
pub mod message;
pub mod member_name;
pub mod voice_move_state;

View File

@ -0,0 +1,50 @@
use serenity::model::{prelude::ChannelId, voice::VoiceState};
pub trait VoiceMoveStateTrait {
fn move_state(&self, old: &Option<VoiceState>, target_channel: ChannelId) -> VoiceMoveState;
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum VoiceMoveState {
JOIN,
LEAVE,
NONE,
}
impl VoiceMoveStateTrait for VoiceState {
fn move_state(&self, old: &Option<VoiceState>, target_channel: ChannelId) -> VoiceMoveState {
let new = self;
if let None = old.clone() {
return if target_channel == new.channel_id.unwrap() {
VoiceMoveState::JOIN
} else {
VoiceMoveState::NONE
};
}
let old = (*old).clone().unwrap();
match (old.channel_id, new.channel_id) {
(Some(old_channel_id), Some(new_channel_id)) => {
if old_channel_id == new_channel_id {
VoiceMoveState::NONE
} else if old_channel_id == target_channel {
VoiceMoveState::LEAVE
} else if new_channel_id == target_channel {
VoiceMoveState::JOIN
} else {
VoiceMoveState::NONE
}
}
(Some(old_channel_id), None) => {
if old_channel_id == target_channel {
VoiceMoveState::LEAVE
} else {
VoiceMoveState::NONE
}
}
_ => VoiceMoveState::NONE,
}
}
}

View File

@ -1,23 +1,32 @@
use std::{sync::Arc, collections::HashMap};
use config::Config;
use data::{TTSData, TTSClientData, DatabaseClientData};
use database::database::Database;
use event_handler::Handler;
use tts::{gcp_tts::gcp_tts::TTS, voicevox::voicevox::VOICEVOX};
use serenity::{
client::{Client, bridge::gateway::GatewayIntents},
framework::StandardFramework, prelude::RwLock, futures::lock::Mutex
};
use songbird::SerenityInit;
mod commands;
mod config;
mod event_handler;
mod tts;
mod implement;
mod data;
mod database;
mod event_handler;
mod events;
mod implement;
mod stream_input;
mod trace;
mod tts;
use std::{collections::HashMap, env, sync::Arc};
use config::Config;
use data::{DatabaseClientData, TTSClientData, TTSData};
use database::database::Database;
use event_handler::Handler;
#[allow(deprecated)]
use serenity::{
all::{standard::Configuration, ApplicationId},
client::Client,
framework::StandardFramework,
prelude::{GatewayIntents, RwLock},
};
use trace::init_tracing_subscriber;
use tracing::info;
use tts::{gcp_tts::gcp_tts::GCPTTS, tts::TTS, voicevox::voicevox::VOICEVOX};
use songbird::SerenityInit;
/// Create discord client
///
@ -27,17 +36,15 @@ mod database;
///
/// client.start().await;
/// ```
#[allow(deprecated)]
async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client, serenity::Error> {
let framework = StandardFramework::new()
.configure(|c| c
.with_whitespace(true)
.prefix(prefix));
let framework = StandardFramework::new();
framework.configure(Configuration::new().with_whitespace(true).prefix(prefix));
Client::builder(token)
Client::builder(token, GatewayIntents::all())
.event_handler(Handler)
.application_id(id)
.application_id(ApplicationId::new(id))
.framework(framework)
.intents(GatewayIntents::all())
.register_songbird()
.await
}
@ -45,34 +52,70 @@ async fn create_client(prefix: &str, token: &str, id: u64) -> Result<Client, ser
#[tokio::main]
async fn main() {
// Load config
let config = std::fs::read_to_string("./config.toml").expect("Cannot read config file.");
let config: Config = toml::from_str(&config).expect("Cannot load config file.");
let config = {
let config = std::fs::read_to_string("./config.toml");
if let Ok(config) = config {
toml::from_str::<Config>(&config).expect("Cannot load config file.")
} else {
let token = env::var("NCB_TOKEN").unwrap();
let application_id = env::var("NCB_APP_ID").unwrap();
let prefix = env::var("NCB_PREFIX").unwrap();
let redis_url = env::var("NCB_REDIS_URL").unwrap();
let voicevox_key = match env::var("NCB_VOICEVOX_KEY") {
Ok(key) => Some(key),
Err(_) => None,
};
let voicevox_original_api_url = match env::var("NCB_VOICEVOX_ORIGINAL_API_URL") {
Ok(url) => Some(url),
Err(_) => None,
};
let otel_http_url = match env::var("NCB_OTEL_HTTP_URL") {
Ok(url) => Some(url),
Err(_) => None,
};
// Create discord client
let mut client = create_client(&config.prefix, &config.token, config.application_id).await.expect("Err creating client");
// Create GCP TTS client
let tts = match TTS::new("./credentials.json".to_string()).await {
Ok(tts) => tts,
Err(err) => panic!("{}", err)
Config {
token,
application_id: u64::from_str_radix(&application_id, 10).unwrap(),
prefix,
redis_url,
voicevox_key,
voicevox_original_api_url,
otel_http_url,
}
}
};
let voicevox = VOICEVOX::new(config.voicevox_key);
let _guard = init_tracing_subscriber(&config.otel_http_url);
// Create discord client
let mut client = create_client(&config.prefix, &config.token, config.application_id)
.await
.expect("Err creating client");
// Create GCP TTS client
let tts = match GCPTTS::new("./credentials.json".to_string()).await {
Ok(tts) => tts,
Err(err) => panic!("GCP init error: {}", err),
};
let voicevox = VOICEVOX::new(config.voicevox_key, config.voicevox_original_api_url);
let database_client = {
let redis_client = redis::Client::open(config.redis_url).unwrap();
let con = redis_client.get_connection().unwrap();
Database::new(con)
Database::new(redis_client)
};
// Create TTS storage
{
let mut data = client.data.write().await;
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
data.insert::<TTSClientData>(Arc::new(Mutex::new((tts, voicevox))));
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
data.insert::<TTSClientData>(Arc::new(TTS::new(voicevox, tts)));
data.insert::<DatabaseClientData>(Arc::new(database_client));
}
info!("Bot initialized.");
// Run client
if let Err(why) = client.start().await {
println!("Client error: {:?}", why);

93
src/stream_input.rs Normal file
View File

@ -0,0 +1,93 @@
use async_trait::async_trait;
use futures::TryStreamExt;
use reqwest::{header::HeaderMap, Client};
use symphonia_core::{io::MediaSource, probe::Hint};
use tokio_util::compat::FuturesAsyncReadCompatExt;
use songbird::input::{
AsyncAdapterStream, AsyncReadOnlySource, AudioStream, AudioStreamError, Compose, Input,
};
#[derive(Debug, Clone)]
pub struct Mp3Request {
client: Client,
request: String,
headers: HeaderMap,
}
impl Mp3Request {
#[must_use]
pub fn new(client: Client, request: String) -> Self {
Self::new_with_headers(client, request, HeaderMap::default())
}
#[must_use]
pub fn new_with_headers(client: Client, request: String, headers: HeaderMap) -> Self {
Mp3Request {
client,
request,
headers,
}
}
async fn create_stream_async(&self) -> Result<AsyncReadOnlySource, AudioStreamError> {
let request = self
.client
.get(&self.request)
.headers(self.headers.clone())
.build()
.map_err(|why| AudioStreamError::Fail(why.into()))?;
let response = self
.client
.execute(request)
.await
.map_err(|why| AudioStreamError::Fail(why.into()))?;
if !response.status().is_success() {
return Err(AudioStreamError::Fail(
format!("HTTP error: {}", response.status()).into(),
));
}
let byte_stream = response
.bytes_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()));
let tokio_reader = byte_stream.into_async_read().compat();
Ok(AsyncReadOnlySource::new(tokio_reader))
}
}
#[async_trait]
impl Compose for Mp3Request {
fn create(&mut self) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
Err(AudioStreamError::Fail(
"Mp3Request::create must be called in an async context via create_async".into(),
))
}
async fn create_async(
&mut self,
) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
let input = self.create_stream_async().await?;
let stream = AsyncAdapterStream::new(Box::new(input), 64 * 1024);
let hint = Hint::new().with_extension("mp3").clone();
Ok(AudioStream {
input: Box::new(stream) as Box<dyn MediaSource>,
hint: Some(hint),
})
}
fn should_create_async(&self) -> bool {
true
}
}
impl From<Mp3Request> for Input {
fn from(val: Mp3Request) -> Self {
Input::Lazy(Box::new(val))
}
}

128
src/trace.rs Normal file
View File

@ -0,0 +1,128 @@
use opentelemetry::{
global,
trace::{SamplingDecision, SamplingResult, TraceContextExt, TraceState, TracerProvider as _},
KeyValue,
};
use opentelemetry_otlp::{Protocol, WithExportConfig};
use opentelemetry_sdk::{
metrics::{MeterProviderBuilder, PeriodicReader, SdkMeterProvider},
trace::{RandomIdGenerator, SdkTracerProvider, ShouldSample},
Resource,
};
use tracing::Level;
use tracing_opentelemetry::{MetricsLayer, OpenTelemetryLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[derive(Debug, Clone)]
struct FilterSampler;
impl ShouldSample for FilterSampler {
fn should_sample(
&self,
parent_context: Option<&opentelemetry::Context>,
_trace_id: opentelemetry::TraceId,
name: &str,
_span_kind: &opentelemetry::trace::SpanKind,
_attributes: &[KeyValue],
_links: &[opentelemetry::trace::Link],
) -> opentelemetry::trace::SamplingResult {
let decision = if name == "dispatch" || name == "recv_event" {
SamplingDecision::Drop
} else {
SamplingDecision::RecordAndSample
};
SamplingResult {
decision,
attributes: vec![],
trace_state: match parent_context {
Some(ctx) => ctx.span().span_context().trace_state().clone(),
None => TraceState::default(),
},
}
}
}
fn resource() -> Resource {
Resource::builder().with_service_name("ncb-tts-r2").build()
}
fn init_meter_provider(url: &str) -> SdkMeterProvider {
let exporter = opentelemetry_otlp::MetricExporter::builder()
.with_http()
.with_endpoint(url)
.with_protocol(Protocol::HttpBinary)
.with_temporality(opentelemetry_sdk::metrics::Temporality::default())
.build()
.unwrap();
let reader = PeriodicReader::builder(exporter)
.with_interval(std::time::Duration::from_secs(5))
.build();
let stdout_reader =
PeriodicReader::builder(opentelemetry_stdout::MetricExporter::default()).build();
let meter_provider = MeterProviderBuilder::default()
.with_resource(resource())
.with_reader(reader)
.with_reader(stdout_reader)
.build();
global::set_meter_provider(meter_provider.clone());
meter_provider
}
fn init_tracer_provider(url: &str) -> SdkTracerProvider {
let exporter = opentelemetry_otlp::SpanExporter::builder()
.with_http()
.with_endpoint(url)
.with_protocol(Protocol::HttpBinary)
.build()
.unwrap();
SdkTracerProvider::builder()
.with_sampler(FilterSampler)
.with_id_generator(RandomIdGenerator::default())
.with_resource(resource())
.with_batch_exporter(exporter)
.build()
}
pub fn init_tracing_subscriber(otel_http_url: &Option<String>) -> OtelGuard {
let registry = tracing_subscriber::registry()
.with(tracing_subscriber::filter::LevelFilter::from_level(
Level::INFO,
))
.with(tracing_subscriber::fmt::layer());
if let Some(url) = otel_http_url {
let tracer_provider = init_tracer_provider(url);
let meter_provider = init_meter_provider(url);
let tracer = tracer_provider.tracer("ncb-tts-r2");
registry
.with(MetricsLayer::new(meter_provider.clone()))
.with(OpenTelemetryLayer::new(tracer))
.init();
OtelGuard {
_tracer_provider: Some(tracer_provider),
_meter_provider: Some(meter_provider),
}
} else {
registry.init();
OtelGuard {
_tracer_provider: None,
_meter_provider: None,
}
}
}
pub struct OtelGuard {
_tracer_provider: Option<SdkTracerProvider>,
_meter_provider: Option<SdkMeterProvider>,
}

View File

@ -1,34 +1,42 @@
use gcp_auth::Token;
use crate::tts::gcp_tts::structs::{
synthesize_request::SynthesizeRequest,
synthesize_response::SynthesizeResponse,
synthesize_request::SynthesizeRequest, synthesize_response::SynthesizeResponse,
};
use gcp_auth::Token;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Clone)]
pub struct TTS {
pub token: Token,
pub credentials_path: String
#[derive(Clone, Debug)]
pub struct GCPTTS {
pub token: Arc<RwLock<Token>>,
pub credentials_path: String,
}
impl TTS {
pub async fn update_token(&mut self) -> Result<(), gcp_auth::Error> {
if self.token.has_expired() {
let authenticator = gcp_auth::from_credentials_file(self.credentials_path.clone()).await?;
let token = authenticator.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
self.token = token;
impl GCPTTS {
#[tracing::instrument]
pub async fn update_token(&self) -> Result<(), gcp_auth::Error> {
let mut token = self.token.write().await;
if token.has_expired() {
let authenticator =
gcp_auth::from_credentials_file(self.credentials_path.clone()).await?;
let new_token = authenticator
.get_token(&["https://www.googleapis.com/auth/cloud-platform"])
.await?;
*token = new_token;
}
Ok(())
}
pub async fn new(credentials_path: String) -> Result<TTS, gcp_auth::Error> {
#[tracing::instrument]
pub async fn new(credentials_path: String) -> Result<Self, gcp_auth::Error> {
let authenticator = gcp_auth::from_credentials_file(credentials_path.clone()).await?;
let token = authenticator.get_token(&["https://www.googleapis.com/auth/cloud-platform"]).await?;
let token = authenticator
.get_token(&["https://www.googleapis.com/auth/cloud-platform"])
.await?;
Ok(TTS {
token,
credentials_path
Ok(Self {
token: Arc::new(RwLock::new(token)),
credentials_path,
})
}
@ -53,19 +61,36 @@ impl TTS {
/// }
/// }).await.unwrap();
/// ```
pub async fn synthesize(&mut self, request: SynthesizeRequest) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
#[tracing::instrument]
pub async fn synthesize(
&self,
request: SynthesizeRequest,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
self.update_token().await.unwrap();
let client = reqwest::Client::new();
match client.post("https://texttospeech.googleapis.com/v1/text:synthesize")
let token_string = {
let token = self.token.read().await;
token.as_str().to_string()
};
match client
.post("https://texttospeech.googleapis.com/v1/text:synthesize")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.header(reqwest::header::AUTHORIZATION, format!("Bearer {}", self.token.as_str()))
.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {}", token_string),
)
.body(serde_json::to_string(&request).unwrap())
.send().await {
Ok(ok) => {
let response: SynthesizeResponse = serde_json::from_str(&ok.text().await.expect("")).unwrap();
Ok(base64::decode(response.audioContent).unwrap()[..].to_vec())
},
Err(err) => Err(Box::new(err))
.send()
.await
{
Ok(ok) => {
let response: SynthesizeResponse =
serde_json::from_str(&ok.text().await.expect("")).unwrap();
Ok(base64::decode(response.audioContent).unwrap()[..].to_vec())
}
Err(err) => Err(Box::new(err)),
}
}
}

View File

@ -1,2 +1,2 @@
pub mod gcp_tts;
pub mod structs;
pub mod structs;

View File

@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
/// Example:
/// ```rust
@ -13,5 +13,5 @@ use serde::{Serialize, Deserialize};
pub struct AudioConfig {
pub audioEncoding: String,
pub speakingRate: f32,
pub pitch: f32
}
pub pitch: f32,
}

View File

@ -1,5 +1,5 @@
pub mod audio_config;
pub mod synthesis_input;
pub mod synthesize_request;
pub mod synthesize_response;
pub mod voice_selection_params;
pub mod synthesize_response;

View File

@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
/// Example:
/// ```rust
@ -7,8 +7,8 @@ use serde::{Serialize, Deserialize};
/// ssml: Some(String::from("<speak>test</speak>"))
/// }
/// ```
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone)]
pub struct SynthesisInput {
pub text: Option<String>,
pub ssml: Option<String>
}
pub ssml: Option<String>,
}

View File

@ -1,9 +1,8 @@
use serde::{Serialize, Deserialize};
use crate::tts::gcp_tts::structs::{
synthesis_input::SynthesisInput,
audio_config::AudioConfig,
audio_config::AudioConfig, synthesis_input::SynthesisInput,
voice_selection_params::VoiceSelectionParams,
};
use serde::{Deserialize, Serialize};
/// Example:
/// ```rust
@ -30,4 +29,4 @@ pub struct SynthesizeRequest {
pub input: SynthesisInput,
pub voice: VoiceSelectionParams,
pub audioConfig: AudioConfig,
}
}

View File

@ -1,7 +1,7 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)]
#[allow(non_snake_case)]
pub struct SynthesizeResponse {
pub audioContent: String
}
pub audioContent: String,
}

View File

@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
/// Example:
/// ```rust
@ -8,10 +8,10 @@ use serde::{Serialize, Deserialize};
/// ssmlGender: String::from("neutral")
/// }
/// ```
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
#[allow(non_snake_case)]
pub struct VoiceSelectionParams {
pub languageCode: String,
pub name: String,
pub ssmlGender: String
}
pub ssmlGender: String,
}

View File

@ -1,12 +1,21 @@
use serenity::{model::{channel::Message, id::{ChannelId, GuildId}}, prelude::Context};
use std::fmt::Debug;
use crate::{tts::message::TTSMessage};
use serenity::{
model::{
channel::Message,
id::{ChannelId, GuildId},
},
prelude::Context,
};
use crate::tts::message::TTSMessage;
#[derive(Debug, Clone)]
pub struct TTSInstance {
pub before_message: Option<Message>,
pub text_channel: ChannelId,
pub voice_channel: ChannelId,
pub guild: GuildId
pub guild: GuildId,
}
impl TTSInstance {
@ -16,17 +25,29 @@ impl TTSInstance {
/// ```rust
/// instance.read(message, &ctx).await;
/// ```
#[tracing::instrument]
pub async fn read<T>(&mut self, message: T, ctx: &Context)
where T: TTSMessage
where
T: TTSMessage + Debug,
{
let path = message.synthesize(self, ctx).await;
let audio = message.synthesize(self, ctx).await;
{
let manager = songbird::get(&ctx).await.unwrap();
let call = manager.get(self.guild).unwrap();
let mut call = call.lock().await;
let input = songbird::input::ffmpeg(path).await.expect("File not found.");
call.enqueue_source(input);
for audio in audio {
call.enqueue(audio.into()).await;
}
}
}
#[tracing::instrument]
pub async fn skip(&mut self, ctx: &Context) {
let manager = songbird::get(&ctx).await.unwrap();
let call = manager.get(self.guild).unwrap();
let call = call.lock().await;
let queue = call.queue();
let _ = queue.skip();
}
}

View File

@ -1,16 +1,17 @@
use std::{path::Path, fs::File, io::Write};
use async_trait::async_trait;
use serenity::prelude::Context;
use songbird::tracks::Track;
use crate::{tts::instance::TTSInstance, data::TTSClientData};
use crate::{data::TTSClientData, tts::instance::TTSInstance};
use super::gcp_tts::structs::{synthesize_request::SynthesizeRequest, synthesis_input::SynthesisInput, audio_config::AudioConfig, voice_selection_params::VoiceSelectionParams};
use super::gcp_tts::structs::{
audio_config::AudioConfig, synthesis_input::SynthesisInput,
synthesize_request::SynthesizeRequest, voice_selection_params::VoiceSelectionParams,
};
/// Message trait that can be used to synthesize text to speech.
#[async_trait]
pub trait TTSMessage {
/// Parse the message for synthesis.
///
/// Example:
@ -19,58 +20,57 @@ pub trait TTSMessage {
/// ```
async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String;
/// Synthesize the message and returns the path to the audio file.
/// Synthesize the message and returns the audio data.
///
/// Example:
/// ```rust
/// let path = message.synthesize(instance, ctx).await;
/// let audio = message.synthesize(instance, ctx).await;
/// ```
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String;
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track>;
}
#[derive(Debug, Clone)]
pub struct AnnounceMessage {
pub message: String,
}
#[async_trait]
impl TTSMessage for AnnounceMessage {
async fn parse(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
async fn parse(&self, instance: &mut TTSInstance, _ctx: &Context) -> String {
instance.before_message = None;
format!(r#"<speak>アナウンス<break time="200ms"/>{}</speak>"#, self.message)
format!(
r#"<speak>アナウンス<break time="200ms"/>{}</speak>"#,
self.message
)
}
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> String {
async fn synthesize(&self, instance: &mut TTSInstance, ctx: &Context) -> Vec<Track> {
let text = self.parse(instance, ctx).await;
let data_read = ctx.data.read().await;
let storage = data_read.get::<TTSClientData>().expect("Cannot get TTSClientStorage").clone();
let mut storage = storage.lock().await;
let tts = data_read
.get::<TTSClientData>()
.expect("Cannot get TTSClientStorage");
let audio = storage.0.synthesize(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(text)
},
voice: VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral")
},
audioConfig: AudioConfig {
audioEncoding: String::from("mp3"),
speakingRate: 1.2f32,
pitch: 1.0f32
}
}).await.unwrap();
let audio = tts
.synthesize_gcp(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(text),
},
voice: VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral"),
},
audioConfig: AudioConfig {
audioEncoding: String::from("mp3"),
speakingRate: 1.2f32,
pitch: 1.0f32,
},
})
.await
.unwrap();
let uuid = uuid::Uuid::new_v4().to_string();
let root = option_env!("CARGO_MANIFEST_DIR").unwrap();
let path = Path::new(root);
let file_path = path.join("audio").join(format!("{}.mp3", uuid));
let mut file = File::create(file_path.clone()).unwrap();
file.write(&audio).unwrap();
file_path.into_os_string().into_string().unwrap()
vec![audio.into()]
}
}
}

View File

@ -1,5 +1,6 @@
pub mod gcp_tts;
pub mod voicevox;
pub mod instance;
pub mod message;
pub mod tts;
pub mod tts_type;
pub mod instance;
pub mod voicevox;

133
src/tts/tts.rs Normal file
View File

@ -0,0 +1,133 @@
use std::sync::RwLock;
use std::{num::NonZeroUsize, sync::Arc};
use lru::LruCache;
use songbird::{driver::Bitrate, input::cached::Compressed, tracks::Track};
use tracing::info;
use super::{
gcp_tts::{
gcp_tts::GCPTTS,
structs::{
synthesis_input::SynthesisInput, synthesize_request::SynthesizeRequest,
voice_selection_params::VoiceSelectionParams,
},
},
voicevox::voicevox::VOICEVOX,
};
#[derive(Debug)]
pub struct TTS {
pub voicevox_client: VOICEVOX,
gcp_tts_client: GCPTTS,
cache: Arc<RwLock<LruCache<CacheKey, Compressed>>>,
}
#[derive(Hash, PartialEq, Eq)]
pub enum CacheKey {
Voicevox(String, i64),
GCP(SynthesisInput, VoiceSelectionParams),
}
impl TTS {
pub fn new(voicevox_client: VOICEVOX, gcp_tts_client: GCPTTS) -> Self {
Self {
voicevox_client,
gcp_tts_client,
cache: Arc::new(RwLock::new(LruCache::new(NonZeroUsize::new(1000).unwrap()))),
}
}
#[tracing::instrument]
pub async fn synthesize_voicevox(
&self,
text: &str,
speaker: i64,
) -> Result<Track, Box<dyn std::error::Error>> {
let cache_key = CacheKey::Voicevox(text.to_string(), speaker);
let cached_audio = {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
};
if let Some(audio) = cached_audio {
info!("Cache hit for VOICEVOX TTS");
return Ok(audio.into());
}
info!("Cache miss for VOICEVOX TTS");
if self.voicevox_client.original_api_url.is_some() {
let audio = self
.voicevox_client
.synthesize_original(text.to_string(), speaker)
.await?;
tokio::spawn({
let cache = self.cache.clone();
let audio = audio.clone();
async move {
info!("Compressing stream audio");
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap();
let mut cache_guard = cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
});
Ok(audio.into())
} else {
let audio = self
.voicevox_client
.synthesize_stream(text.to_string(), speaker)
.await?;
tokio::spawn({
let cache = self.cache.clone();
let audio = audio.clone();
async move {
info!("Compressing stream audio");
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await.unwrap();
let mut cache_guard = cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
});
Ok(audio.into())
}
}
#[tracing::instrument]
pub async fn synthesize_gcp(
&self,
synthesize_request: SynthesizeRequest,
) -> Result<Compressed, Box<dyn std::error::Error>> {
let cache_key = CacheKey::GCP(
synthesize_request.input.clone(),
synthesize_request.voice.clone(),
);
let cached_audio = {
let mut cache_guard = self.cache.write().unwrap();
cache_guard.get(&cache_key).map(|audio| audio.new_handle())
};
if let Some(audio) = cached_audio {
info!("Cache hit for GCP TTS");
return Ok(audio);
}
info!("Cache miss for GCP TTS");
let audio = self.gcp_tts_client.synthesize(synthesize_request).await?;
let compressed = Compressed::new(audio.into(), Bitrate::Auto).await?;
{
let mut cache_guard = self.cache.write().unwrap();
cache_guard.put(cache_key, compressed.clone());
}
Ok(compressed)
}
}

View File

@ -3,5 +3,5 @@ use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum TTSType {
GCP,
VOICEVOX
}
VOICEVOX,
}

View File

@ -1,2 +1,2 @@
pub mod structs;
pub mod voicevox;
pub mod voicevox;

View File

@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use super::mora::Mora;
@ -7,5 +7,5 @@ pub struct AccentPhrase {
pub moras: Vec<Mora>,
pub accent: f64,
pub pause_mora: Option<Mora>,
pub is_interrogative: bool
}
pub is_interrogative: bool,
}

View File

@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
use super::accent_phrase::AccentPhrase;
@ -14,5 +14,5 @@ pub struct AudioQuery {
pub postPhonemeLength: f64,
pub outputSamplingRate: f64,
pub outputStereo: bool,
pub kana: Option<String>
}
pub kana: Option<String>,
}

View File

@ -1,3 +1,5 @@
pub mod mora;
pub mod accent_phrase;
pub mod audio_query;
pub mod accent_phrase;
pub mod mora;
pub mod speaker;
pub mod stream;

View File

@ -1,4 +1,4 @@
use serde::{Serialize, Deserialize};
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Mora {
@ -7,5 +7,5 @@ pub struct Mora {
pub consonant_length: Option<f64>,
pub vowel: String,
pub vowel_length: f64,
pub pitch: f64
}
pub pitch: f64,
}

View File

@ -0,0 +1,21 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Speaker {
pub supported_features: SupportedFeatures,
pub name: String,
pub speaker_uuid: String,
pub styles: Vec<Style>,
pub version: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SupportedFeatures {
pub permitted_synthesis_morphing: String,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Style {
pub name: String,
pub id: i64,
}

View File

@ -0,0 +1,13 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TTSResponse {
pub success: bool,
pub is_api_key_valid: bool,
pub speaker_name: String,
pub audio_status_url: String,
pub wav_download_url: String,
pub mp3_download_url: String,
pub mp3_streaming_url: String,
}

View File

@ -1,27 +1,133 @@
const API_URL: &str = "https://api.su-shiki.com/v2/voicevox/audio";
use crate::stream_input::Mp3Request;
#[derive(Clone)]
use super::structs::{speaker::Speaker, stream::TTSResponse};
const BASE_API_URL: &str = "https://deprecatedapis.tts.quest/v2/";
#[derive(Clone, Debug)]
pub struct VOICEVOX {
pub key: String
pub key: Option<String>,
pub original_api_url: Option<String>,
}
impl VOICEVOX {
pub fn new(key: String) -> Self {
#[tracing::instrument]
pub async fn get_styles(&self) -> Vec<(String, i64)> {
let speakers = self.get_speaker_list().await;
let mut speaker_list = vec![];
for speaker in speakers {
for style in speaker.styles {
speaker_list.push((format!("{} - {}", speaker.name, style.name), style.id))
}
}
speaker_list
}
#[tracing::instrument]
pub async fn get_speakers(&self) -> Vec<String> {
let speakers = self.get_speaker_list().await;
let mut speaker_list = vec![];
for speaker in speakers {
speaker_list.push(speaker.name)
}
speaker_list
}
pub fn new(key: Option<String>, original_api_url: Option<String>) -> Self {
Self {
key
key,
original_api_url,
}
}
pub async fn synthesize(&self, text: String, speaker: i64) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
#[tracing::instrument]
async fn get_speaker_list(&self) -> Vec<Speaker> {
let client = reqwest::Client::new();
match client.post(API_URL).query(&[("speaker", speaker.to_string()), ("text", text), ("key", self.key.clone())]).send().await {
let client = if let Some(key) = &self.key {
client
.get(BASE_API_URL.to_string() + "voicevox/speakers/")
.query(&[("key", key)])
} else if let Some(original_api_url) = &self.original_api_url {
client.get(original_api_url.to_string() + "/speakers")
} else {
panic!("No API key or original API URL provided.")
};
match client.send().await {
Ok(response) => response.json().await.unwrap(),
Err(err) => {
panic!("Cannot get speaker list. {err:?}")
}
}
}
#[tracing::instrument]
pub async fn synthesize(
&self,
text: String,
speaker: i64,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let client = reqwest::Client::new();
match client
.post(BASE_API_URL.to_string() + "voicevox/audio/")
.query(&[
("speaker", speaker.to_string()),
("text", text),
("key", self.key.clone().unwrap()),
])
.send()
.await
{
Ok(response) => {
let body = response.bytes().await?;
Ok(body.to_vec())
}
Err(err) => {
Err(Box::new(err))
}
Err(err) => Err(Box::new(err)),
}
}
}
#[tracing::instrument]
pub async fn synthesize_original(
&self,
text: String,
speaker: i64,
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
let client =
voicevox_client::Client::new(self.original_api_url.as_ref().unwrap().clone(), None);
let audio_query = client
.create_audio_query(&text, speaker as i32, None)
.await?;
println!("{:?}", audio_query.audio_query);
let audio = audio_query.synthesis(speaker as i32, true).await?;
Ok(audio.into())
}
#[tracing::instrument]
pub async fn synthesize_stream(
&self,
text: String,
speaker: i64,
) -> Result<Mp3Request, Box<dyn std::error::Error>> {
let client = reqwest::Client::new();
match client
.post("https://api.tts.quest/v3/voicevox/synthesis")
.query(&[
("speaker", speaker.to_string()),
("text", text),
("key", self.key.clone().unwrap()),
])
.send()
.await
{
Ok(response) => {
let body = response.text().await.unwrap();
let response: TTSResponse = serde_json::from_str(&body).unwrap();
Ok(Mp3Request::new(reqwest::Client::new(), response.mp3_streaming_url).into())
}
Err(err) => Err(Box::new(err)),
}
}
}