add database integration

This commit is contained in:
mii
2022-08-06 11:40:50 +09:00
parent 6c3e1700bd
commit adb111f2d1
11 changed files with 112 additions and 13 deletions

View File

@ -13,6 +13,7 @@ gcp_auth = "0.5.0"
reqwest = { version = "0.11", features = ["json"] }
base64 = "0.13"
async-trait = "0.1.57"
redis = "*"
[dependencies.uuid]
version = "0.8"

View File

@ -4,5 +4,6 @@ use serde::Deserialize;
pub struct Config {
pub prefix: String,
pub token: String,
pub application_id: u64
pub application_id: u64,
pub redis_url: String,
}

View File

@ -1,4 +1,4 @@
use crate::tts::gcp_tts::gcp_tts::TTS;
use crate::{tts::gcp_tts::gcp_tts::TTS, database::database::Database};
use serenity::{prelude::{TypeMapKey, RwLock}, model::id::GuildId, futures::lock::Mutex};
use crate::tts::instance::TTSInstance;
@ -17,3 +17,10 @@ pub struct TTSClientData;
impl TypeMapKey for TTSClientData {
type Value = Arc<Mutex<TTS>>;
}
/// Database client data
pub struct DatabaseClientData;
impl TypeMapKey for DatabaseClientData {
type Value = Arc<Mutex<Database>>;
}

59
src/database/database.rs Normal file
View File

@ -0,0 +1,59 @@
use crate::tts::{gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType};
use super::user_config::UserConfig;
use redis::Commands;
pub struct Database {
pub connection: redis::Connection
}
impl Database {
pub fn new(connection: redis::Connection) -> Self {
Self { connection }
}
pub async fn get_user_config(&mut self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
let config: String = self.connection.get(format!("discord_user:{}", user_id)).unwrap_or_default();
match serde_json::from_str(&config) {
Ok(config) => Ok(Some(config)),
Err(_) => Ok(None)
}
}
pub async fn set_user_config(&mut self, user_id: u64, config: UserConfig) -> redis::RedisResult<()> {
let config = serde_json::to_string(&config).unwrap();
self.connection.set::<String, String, ()>(format!("discord_user:{}", user_id), config).unwrap();
Ok(())
}
pub async fn set_default_user_config(&mut self, user_id: u64) -> redis::RedisResult<()> {
let voice_selection = VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral")
};
let voice_type = TTSType::GCP;
let config = UserConfig {
tts_type: Some(voice_type),
gcp_tts_voice: Some(voice_selection)
};
self.connection.set(format!("discord_user:{}", user_id), serde_json::to_string(&config).unwrap())?;
Ok(())
}
pub async fn get_user_config_or_default(&mut self, user_id: u64) -> redis::RedisResult<Option<UserConfig>> {
let config = self.get_user_config(user_id).await?;
match config {
Some(_) => Ok(config),
None => {
self.set_default_user_config(user_id).await?;
self.get_user_config(user_id).await
}
}
}
}

View File

@ -0,0 +1,2 @@
pub mod user_config;
pub mod database;

View File

@ -0,0 +1,9 @@
use serde::{Serialize, Deserialize};
use crate::tts::{gcp_tts::structs::voice_selection_params::VoiceSelectionParams, tts_type::TTSType};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UserConfig {
pub tts_type: Option<TTSType>,
pub gcp_tts_voice: Option<VoiceSelectionParams>
}

View File

@ -93,7 +93,13 @@ impl EventHandler for Handler {
}
async fn message(&self, ctx: Context, message: Message) {
let guild_id = message.guild(&ctx.cache).await.unwrap().id;
let guild_id = message.guild(&ctx.cache).await;
if let None = guild_id {
return;
}
let guild_id = guild_id.unwrap().id;
let storage_lock = {
let data_read = ctx.data.read().await;
@ -115,7 +121,7 @@ impl EventHandler for Handler {
async fn ready(&self, ctx: Context, ready: Ready) {
println!("{} is connected!", ready.user.name);
let guild_id = GuildId(696782998799909024);
let guild_id = GuildId(660046656934248460);
let commands = GuildId::set_application_commands(&guild_id, &ctx.http, |commands| {
commands.create_application_command(|command| {

View File

@ -4,7 +4,7 @@ use async_trait::async_trait;
use serenity::{prelude::Context, model::prelude::Message};
use crate::{
data::TTSClientData,
data::{TTSClientData, DatabaseClientData},
tts::{
instance::TTSInstance,
message::TTSMessage,
@ -39,16 +39,18 @@ impl TTSMessage for Message {
let storage = data_read.get::<TTSClientData>().expect("Cannot get TTSClientStorage").clone();
let storage = storage.lock().await;
let config = {
let database = data_read.get::<DatabaseClientData>().expect("Cannot get DatabaseClientData").clone();
let mut database = database.lock().await;
database.get_user_config_or_default(self.author.id.0).await.unwrap().unwrap()
};
let audio = storage.synthesize(SynthesizeRequest {
input: SynthesisInput {
text: None,
ssml: Some(text)
},
voice: VoiceSelectionParams {
languageCode: String::from("ja-JP"),
name: String::from("ja-JP-Wavenet-B"),
ssmlGender: String::from("neutral")
},
voice: config.gcp_tts_voice.unwrap(),
audioConfig: AudioConfig {
audioEncoding: String::from("mp3"),
speakingRate: 1.2f32,

View File

@ -1,7 +1,8 @@
use std::{sync::Arc, collections::HashMap};
use config::Config;
use data::{TTSData, TTSClientData};
use data::{TTSData, TTSClientData, DatabaseClientData};
use database::database::Database;
use event_handler::Handler;
use tts::gcp_tts::gcp_tts::TTS;
use serenity::{
@ -16,6 +17,7 @@ mod event_handler;
mod tts;
mod implement;
mod data;
mod database;
/// Create discord client
///
@ -55,11 +57,18 @@ async fn main() {
Err(err) => panic!("{}", err)
};
let database_client = {
let redis_client = redis::Client::open(config.redis_url).unwrap();
let con = redis_client.get_connection().unwrap();
Database::new(con)
};
// Create TTS storage
{
let mut data = client.data.write().await;
data.insert::<TTSData>(Arc::new(RwLock::new(HashMap::default())));
data.insert::<TTSClientData>(Arc::new(Mutex::new(tts)));
data.insert::<DatabaseClientData>(Arc::new(Mutex::new(database_client)));
}
// Run client

View File

@ -8,7 +8,7 @@ use serde::{Serialize, Deserialize};
/// ssmlGender: String::from("neutral")
/// }
/// ```
#[derive(Serialize, Deserialize, Debug)]
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
#[allow(non_snake_case)]
pub struct VoiceSelectionParams {
pub languageCode: String,

View File

@ -1,3 +1,6 @@
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum TTSType {
GCP,
VOICEVOX