mirror of
https://github.com/mii443/ncb-tts-r2.git
synced 2025-08-22 16:15:29 +00:00
add database integration
This commit is contained in:
@ -13,6 +13,7 @@ gcp_auth = "0.5.0"
|
||||
reqwest = { version = "0.11", features = ["json"] }
|
||||
base64 = "0.13"
|
||||
async-trait = "0.1.57"
|
||||
redis = "*"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "0.8"
|
||||
|
@ -4,5 +4,6 @@ use serde::Deserialize;
|
||||
pub struct Config {
|
||||
pub prefix: String,
|
||||
pub token: String,
|
||||
pub application_id: u64
|
||||
pub application_id: u64,
|
||||
pub redis_url: String,
|
||||
}
|
11
src/data.rs
11
src/data.rs
@ -1,4 +1,4 @@
|
||||
use crate::tts::gcp_tts::gcp_tts::TTS;
|
||||
use crate::{tts::gcp_tts::gcp_tts::TTS, database::database::Database};
|
||||
use serenity::{prelude::{TypeMapKey, RwLock}, model::id::GuildId, futures::lock::Mutex};
|
||||
|
||||
use crate::tts::instance::TTSInstance;
|
||||
@ -16,4 +16,11 @@ pub struct TTSClientData;
|
||||
|
||||
impl TypeMapKey for TTSClientData {
|
||||
type Value = Arc<Mutex<TTS>>;
|
||||
}
|
||||
}
|
||||
|
||||
/// Database client data
|
||||
pub struct DatabaseClientData;
|
||||
|
||||
impl TypeMapKey for DatabaseClientData {
|
||||
type Value = Arc<Mutex<Database>>;
|
||||
}
|
||||
|
59
src/database/database.rs
Normal file
59
src/database/database.rs
Normal file
@ -0,0 +1,59 @@
|
||||
use crate::tts::{gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType};
|
||||
|
||||
use super::user_config::UserConfig;
|
||||
use redis::Commands;
|
||||
|
||||
pub struct Database {
|
||||
pub connection: redis::Connection
|
||||
}
|
||||
|
||||
impl Database {
|
||||
pub fn new(connection: redis::Connection) -> Self {
|
||||
Self { connection }
|
||||
}
|
||||
|
||||
pub async fn get_user_config(&mut self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
|
||||
let config: String = self.connection.get(format!("discord_user:{}", user_id)).unwrap_or_default();
|
||||
|
||||
match serde_json::from_str(&config) {
|
||||
Ok(config) => Ok(Some(config)),
|
||||
Err(_) => Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn set_user_config(&mut self, user_id: u64, config: UserConfig) -> redis::RedisResult<()> {
|
||||
let config = serde_json::to_string(&config).unwrap();
|
||||
self.connection.set::<String, String, ()>(format!("discord_user:{}", user_id), config).unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set_default_user_config(&mut self, user_id: u64) -> redis::RedisResult<()> {
|
||||
let voice_selection = VoiceSelectionParams {
|
||||
languageCode: String::from("ja-JP"),
|
||||
name: String::from("ja-JP-Wavenet-B"),
|
||||
ssmlGender: String::from("neutral")
|
||||
};
|
||||
|
||||
let voice_type = TTSType::GCP;
|
||||
|
||||
let config = UserConfig {
|
||||
tts_type: Some(voice_type),
|
||||
gcp_tts_voice: Some(voice_selection)
|
||||
};
|
||||
|
||||
self.connection.set(format!("discord_user:{}", user_id), serde_json::to_string(&config).unwrap())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_user_config_or_default(&mut self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
|
||||
let config = self.get_user_config(user_id).await?;
|
||||
match config {
|
||||
Some(_) => Ok(config),
|
||||
None => {
|
||||
self.set_default_user_config(user_id).await?;
|
||||
self.get_user_config(user_id).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,2 @@
|
||||
pub mod user_config;
|
||||
pub mod database;
|
9
src/database/user_config.rs
Normal file
9
src/database/user_config.rs
Normal file
@ -0,0 +1,9 @@
|
||||
use serde::{Serialize, Deserialize};
|
||||
|
||||
use crate::tts::{gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct UserConfig {
|
||||
pub tts_type: Option<TTSType>,
|
||||
pub gcp_tts_voice: Option<VoiceSelectionParams>
|
||||
}
|
@ -93,7 +93,13 @@ impl EventHandler for Handler {
|
||||
}
|
||||
|
||||
async fn message(&self, ctx: Context, message: Message) {
|
||||
let guild_id = message.guild(&ctx.cache).await.unwrap().id;
|
||||
let guild_id = message.guild(&ctx.cache).await;
|
||||
|
||||
if let None = guild_id {
|
||||
return;
|
||||
}
|
||||
|
||||
let guild_id = guild_id.unwrap().id;
|
||||
|
||||
let storage_lock = {
|
||||
let data_read = ctx.data.read().await;
|
||||
@ -115,7 +121,7 @@ impl EventHandler for Handler {
|
||||
async fn ready(&self, ctx: Context, ready: Ready) {
|
||||
println!("{} is connected!", ready.user.name);
|
||||
|
||||
let guild_id = GuildId(696782998799909024);
|
||||
let guild_id = GuildId(660046656934248460);
|
||||
|
||||
let commands = GuildId::set_application_commands(&guild_id, &ctx.http, |commands| {
|
||||
commands.create_application_command(|command| {
|
||||
|
@ -4,7 +4,7 @@ use async_trait::async_trait;
|
||||
use serenity::{prelude::Context, model::prelude::Message};
|
||||
|
||||
use crate::{
|
||||
data::TTSClientData,
|
||||
data::{TTSClientData, DatabaseClientData},
|
||||
tts::{
|
||||
instance::TTSInstance,
|
||||
message::TTSMessage,
|
||||
@ -39,16 +39,18 @@ impl TTSMessage for Message {
|
||||
let storage = data_read.get::<TTSClientData>().expect("Cannot get TTSClientStorage").clone();
|
||||
let storage = storage.lock().await;
|
||||
|
||||
let config = {
|
||||
let database = data_read.get::<DatabaseClientData>().expect("Cannot get DatabaseClientData").clone();
|
||||
let mut database = database.lock().await;
|
||||
database.get_user_config_or_default(self.author.id.0).await.unwrap().unwrap()
|
||||
};
|
||||
|
||||
let audio = storage.synthesize(SynthesizeRequest {
|
||||
input: SynthesisInput {
|
||||
text: None,
|
||||
ssml: Some(text)
|
||||
},
|
||||
voice: VoiceSelectionParams {
|
||||
languageCode: String::from("ja-JP"),
|
||||
name: String::from("ja-JP-Wavenet-B"),
|
||||
ssmlGender: String::from("neutral")
|
||||
},
|
||||
voice: config.gcp_tts_voice.unwrap(),
|
||||
audioConfig: AudioConfig {
|
||||
audioEncoding: String::from("mp3"),
|
||||
speakingRate: 1.2f32,
|
||||
|
11
src/main.rs
11
src/main.rs
@ -1,7 +1,8 @@
|
||||
use std::{sync::Arc, collections::HashMap};
|
||||
|
||||
use config::Config;
|
||||
use data::{TTSData, TTSClientData};
|
||||
use data::{TTSData, TTSClientData, DatabaseClientData};
|
||||
use database::database::Database;
|
||||
use event_handler::Handler;
|
||||
use tts::gcp_tts::gcp_tts::TTS;
|
||||
use serenity::{
|
||||
@ -16,6 +17,7 @@ mod event_handler;
|
||||
mod tts;
|
||||
mod implement;
|
||||
mod data;
|
||||
mod database;
|
||||
|
||||
/// Create discord client
|
||||
///
|
||||
@ -55,11 +57,18 @@ async fn main() {
|
||||
Err(err) => panic!("{}", err)
|
||||
};
|
||||
|
||||
let database_client = {
|
||||
let redis_client = redis::Client::open(config.redis_url).unwrap();
|
||||
let con = redis_client.get_connection().unwrap();
|
||||
Database::new(con)
|
||||
};
|
||||
|
||||
// Create TTS storage
|
||||
{
|
||||
let mut data = client.data.write().await;
|
||||
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
|
||||
data.insert::<TTSClientData>(Arc::new(Mutex::new(tts)));
|
||||
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
|
||||
}
|
||||
|
||||
// Run client
|
||||
|
@ -8,7 +8,7 @@ use serde::{Serialize, Deserialize};
|
||||
/// ssmlGender: String::from("neutral")
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
|
||||
#[allow(non_snake_case)]
|
||||
pub struct VoiceSelectionParams {
|
||||
pub languageCode: String,
|
||||
|
@ -1,3 +1,6 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub enum TTSType {
|
||||
GCP,
|
||||
VOICEVOX
|
||||
|
Reference in New Issue
Block a user