From 380ee3da3eb2b3f1ac12ea8718a36cbbf7d74e1d Mon Sep 17 00:00:00 2001 From: mii Date: Thu, 16 Mar 2023 13:12:24 +0000 Subject: [PATCH] support local llama --- Cargo.toml | 6 +- docker-compose.yml | 4 +- src/config.rs | 13 ++ src/data.rs | 94 ++++++++-- src/event_handler.rs | 81 +++++++- src/events/message_receive.rs | 339 +++++++++++++++++++++++++++++++--- src/events/ready.rs | 34 +++- src/main.rs | 25 ++- 8 files changed, 544 insertions(+), 52 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ef6b982..01516b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "ncb-ping" +name = "ncb-chat" version = "0.1.0" edition = "2021" @@ -11,6 +11,10 @@ serde = "1.0" toml = "0.5" async-trait = "0.1.57" chrono = "0.4.23" +reqwest = { version = "0.11", features = ["json"] } +google-translate3 = "4.0.1+20220121" +gcp_auth = "0.5.0" + [dependencies.uuid] version = "0.8" diff --git a/docker-compose.yml b/docker-compose.yml index 4b48b25..daa26f5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,8 +2,8 @@ version: '3' services: ncb-ping: - container_name: ncb-ping - image: ghcr.io/morioka22/ncb-ping:0.0.1 + container_name: ncb-chat + image: ghcr.io/morioka22/ncb-chat:0.0.1 environment: - NCB_TOKEN=YOUR_BOT_TOKEN - NCB_APP_ID=YOUR_BOT_ID diff --git a/src/config.rs b/src/config.rs index 842a2f6..7d3d4e1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,20 @@ +use std::sync::Arc; + use serde::Deserialize; +use serenity::{futures::lock::Mutex, prelude::TypeMapKey}; #[derive(Deserialize)] pub struct Config { pub token: String, pub application_id: u64, + pub llama_url: String, + pub openai_key: String, + pub chatgpt_allows: Vec, + pub chatgpt_forums: Vec, +} + +pub struct ConfigData; + +impl TypeMapKey for ConfigData { + type Value = Arc>; } diff --git a/src/data.rs b/src/data.rs index 5e2c6e4..d42efcf 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,24 +1,96 @@ use std::{collections::HashMap, sync::Arc}; -use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; use serenity::{ futures::lock::Mutex, - model::prelude::{ChannelId, Message, UserId}, + model::prelude::{ChannelId, UserId}, prelude::TypeMapKey, }; +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct TranslateRequest { + pub input: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct LlamaMessage { + pub role: String, + pub content: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct LlamaRequest { + pub messages: Vec, +} + #[derive(Debug, Clone)] -pub struct Ping { +pub struct Llama { pub channel: ChannelId, - pub user_id: UserId, - pub author: UserId, - pub message: Message, - pub time: DateTime, - pub args: Vec, + pub history: Vec, } -pub struct PingData; +pub struct LlamaData; -impl TypeMapKey for PingData { - type Value = Arc>>; +impl TypeMapKey for LlamaData { + type Value = Arc>>; +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ChatGPTMessage { + pub role: String, + pub content: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Usage { + pub prompt_tokens: usize, + pub completion_tokens: usize, + pub total_tokens: usize, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Choice { + pub index: usize, + pub finish_reason: Option, + pub message: ChatGPTMessage, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ChatGPTResponse { + pub id: String, + pub object: String, + pub created: usize, + pub model: String, + pub choices: Vec, + pub usage: Usage, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ChatGPTRequest { + pub model: String, + pub messages: Vec, +} + +#[derive(Debug, Clone)] +pub struct ChatGPT { + pub channel: ChannelId, + pub history: Vec, +} + +#[derive(Debug, Clone)] +pub struct IndividualChatGPT { + pub user: UserId, + pub history: Vec, +} + +pub struct IndividualChatGPTData; + +impl TypeMapKey for IndividualChatGPTData { + type Value = Arc>>; +} + +pub struct ChatGPTData; + +impl TypeMapKey for ChatGPTData { + type Value = Arc>>; } diff --git a/src/event_handler.rs b/src/event_handler.rs index 6365eb8..66825d2 100644 --- a/src/event_handler.rs +++ b/src/event_handler.rs @@ -1,18 +1,91 @@ +use serenity::model::prelude::interaction::Interaction; use serenity::{ async_trait, client::{Context, EventHandler}, - model::{ - channel::Message, - gateway::Ready, - }, + model::{channel::GuildChannel, channel::Message, gateway::Ready}, }; +use crate::config::ConfigData; +use crate::data::*; use crate::events; pub struct Handler; #[async_trait] impl EventHandler for Handler { + async fn interaction_create(&self, ctx: Context, interaction: Interaction) { + if let Interaction::ApplicationCommand(command) = interaction.clone() { + let name = &*command.data.name; + if name == "llama" { + let message = command + .channel_id + .send_message(&ctx.http, |f| f.content("LLaMa Chat thread")) + .await + .unwrap(); + let id = command + .channel_id + .create_public_thread(&ctx.http, message, |f| { + f.name("LLaMa Chat").auto_archive_duration(60) + }) + .await + .unwrap() + .id; + let llama = Llama { + channel: id, + history: vec![], + }; + let storage_lock = { + let data_read = ctx.data.read().await; + data_read + .get::() + .expect("Cannot get TTSStorage") + .clone() + }; + storage_lock.lock().await.insert(id, llama); + } + + if name == "chatgpt" { + let cs = { + let data_read = ctx.data.read().await; + data_read + .get::() + .expect("Cannot get ConfigData") + .clone() + }; + let config = cs.lock().await; + + if !config.chatgpt_allows.contains(&(command.user.id.0 as i64)) { + return; + } + + let message = command + .channel_id + .send_message(&ctx.http, |f| f.content("ChatGPT thread")) + .await + .unwrap(); + let id = command + .channel_id + .create_public_thread(&ctx.http, message, |f| { + f.name("ChatGPT").auto_archive_duration(60) + }) + .await + .unwrap() + .id; + let chatgpt = ChatGPT { + channel: id, + history: vec![], + }; + let storage_lock = { + let data_read = ctx.data.read().await; + data_read + .get::() + .expect("Cannot get TTSStorage") + .clone() + }; + storage_lock.lock().await.insert(id, chatgpt); + } + } + } async fn message(&self, ctx: Context, message: Message) { events::message_receive::message(ctx, message).await; } diff --git a/src/events/message_receive.rs b/src/events/message_receive.rs index dbc807c..107c998 100644 --- a/src/events/message_receive.rs +++ b/src/events/message_receive.rs @@ -1,58 +1,339 @@ use chrono::Utc; -use serenity::{model::prelude::Message, prelude::Context}; +use serenity::{ + http::CacheHttp, + model::prelude::{GuildChannel, Message}, + prelude::Context, +}; -use crate::data::{PingData, Ping}; +use crate::{ + config::*, + data::{ + ChatGPT, ChatGPTData, ChatGPTMessage, ChatGPTRequest, ChatGPTResponse, IndividualChatGPT, + IndividualChatGPTData, Llama, LlamaData, LlamaMessage, LlamaRequest, TranslateRequest, + }, +}; + +async fn translate_en_ja(input: String) -> String { + let request = TranslateRequest { input }; + let client = reqwest::Client::new(); + match client + .post("http://localhost:8008/translate/enja") + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serde_json::to_string(&request).unwrap()) + .send() + .await + { + Ok(ok) => ok.text().await.expect("ERROR"), + Err(err) => { + panic!("Error") + } + } +} + +async fn translate_ja_en(input: String) -> String { + let request = TranslateRequest { input }; + let client = reqwest::Client::new(); + match client + .post("http://localhost:8008/translate/jaen") + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serde_json::to_string(&request).unwrap()) + .send() + .await + { + Ok(ok) => ok.text().await.expect("ERROR"), + Err(err) => { + panic!("Error") + } + } +} + +async fn chatgpt_request(input: ChatGPTRequest, key: String) -> ChatGPTResponse { + let client = reqwest::Client::new(); + match client + .post("https://api.openai.com/v1/chat/completions") + .header(reqwest::header::CONTENT_TYPE, "application/json") + .header(reqwest::header::AUTHORIZATION, format!("Bearer {}", key)) + .body(serde_json::to_string(&input).unwrap()) + .send() + .await + { + Ok(ok) => { + let text = ok.text().await.unwrap(); + println!("{}", text.clone()); + let response: ChatGPTResponse = serde_json::from_str(&text).unwrap(); + response + } + Err(err) => { + panic!("Error") + } + } +} + +fn split_string_into_chunks(s: &str, chunk_size: usize) -> Vec { + let mut chunks = Vec::new(); + let mut start = 0; + let len = s.len(); + + while start < len { + let end = if start + chunk_size < len { + start + chunk_size + } else { + len + }; + chunks.push(s[start..end].to_string()); + start = end; + } + + chunks +} + +pub async fn llama(ctx: Context, message: Message) {} pub async fn message(ctx: Context, message: Message) { if message.author.bot { return; } + if message.content.starts_with(";") { + message + .reply(&ctx.http, "スキップしました。") + .await + .unwrap(); + return; + } + let guild_id = message.guild(&ctx.cache); if let None = guild_id { return; } - let storage_lock = { + println!("Event received: {}", message.content); + let config_lock = { let data_read = ctx.data.read().await; data_read - .get::() - .expect("Cannot get PingData") + .get::() + .expect("Cannot get Config") .clone() }; - if message.mentions.len() == 1 { - let user = message.mentions.first().unwrap(); - let m = format!("PING {} ({}) 56(84) bytes of data.", user.name, user.id.0); - let ping_message = message.reply(&ctx.http, m).await.unwrap(); + let config = config_lock.lock().await; - let ping = Ping { - channel: message.channel_id, - user_id: user.id, - author: message.author.id, - time: Utc::now(), - message: ping_message, - args: vec![] - }; + let llama_storage_lock = { + let data_read = ctx.data.read().await; + data_read + .get::() + .expect("Cannot get LlamaData") + .clone() + }; - let mut storage = storage_lock.lock().await; - storage.insert(user.id, ping.clone()); + let mut llama_storage = llama_storage_lock.lock().await; + + if let Some(mut llama) = llama_storage.clone().get_mut(&message.channel_id) { + let typing = message.channel_id.start_typing(&ctx.http).unwrap(); + let mut history = llama.history.clone(); + let text = translate_ja_en(message.content.clone()).await; + println!("{}", text); + history.push(LlamaMessage { + role: "user".to_string(), + content: text, + }); + llama.history = history.clone(); + llama_storage.insert(llama.channel, llama.clone()); + + let request = LlamaRequest { messages: history }; + + let client = reqwest::Client::new(); + match client + .post("http://localhost:18080/") + .header(reqwest::header::CONTENT_TYPE, "application/json") + .body(serde_json::to_string(&request).unwrap()) + .send() + .await + { + Ok(ok) => { + let response = ok.text().await.expect("ERROR"); + let response = translate_en_ja(response).await; + println!("{}", response); + message + .channel_id + .send_message(&ctx.http, |f| f.content(response)) + .await + .unwrap(); + } + Err(err) => { + panic!("Error") + } + } + + typing.stop().unwrap(); } - { - let mut storage = storage_lock.lock().await; - if !storage.contains_key(&message.author.id) { + let chatgpt_storage_lock = { + let data_read = ctx.data.read().await; + data_read + .get::() + .expect("Cannot get ChatGPTData") + .clone() + }; + + let mut chatgpt_storage = chatgpt_storage_lock.lock().await; + + // forum + let channel = message.channel(&ctx.http).await.unwrap(); + if let Some(guild_channel) = channel.clone().guild() { + if let Some(parent) = guild_channel.parent_id { + if chatgpt_storage + .clone() + .get_mut(&message.channel_id) + .is_none() + && config.chatgpt_forums.contains(&(parent.0 as i64)) + { + chatgpt_storage.insert( + guild_channel.id, + ChatGPT { + channel: channel.id(), + history: vec![], + }, + ); + } + } + } + + if let Some(mut chatgpt) = chatgpt_storage.clone().get_mut(&message.channel_id) { + if message.content == "reset".to_string() { + chatgpt.history = vec![]; + chatgpt_storage.insert(chatgpt.channel, chatgpt.clone()); + message + .channel_id + .send_message(&ctx.http, |f| f.content("会話履歴をリセットしました。")) + .await + .unwrap(); return; } - let ping = storage.get_mut(&message.author.id).unwrap(); - if ping.channel == message.channel_id { - let user = ping.user_id.to_user(&ctx.http).await.unwrap(); - let time = Utc::now() - ping.time; - ping.message.edit(&ctx.http, |f| f.content(format!("--- {} ping statistics ---\n1 packets transmitted, 1 received, 0% packet loss, time {}ms", user.name, time.num_milliseconds()))).await.unwrap(); - message.channel_id.send_message(&ctx.http, |f| f.content(format!("<@{}>", ping.author.0))).await.unwrap(); - storage.remove(&message.author.id); + let typing = message.channel_id.start_typing(&ctx.http).unwrap(); + let mut history = chatgpt.history.clone(); + + history.push(ChatGPTMessage { + role: "user".to_string(), + content: message.content.clone(), + }); + + let request = ChatGPTRequest { + model: "gpt-3.5-turbo".to_string(), + messages: history.clone(), + }; + + let response = chatgpt_request(request, config.openai_key.clone()).await; + + history.push(response.choices[0].message.clone()); + + chatgpt.history = history.clone(); + + chatgpt_storage.insert(chatgpt.channel, chatgpt.clone()); + + println!("Tokens: {:?}", response.usage.total_tokens); + let responses = format!( + "{}\ntokens: {}/4096", + response.choices[0].message.content.clone(), + response.usage.total_tokens + ); + let responses = split_string_into_chunks(&responses, 2000); + + let l = responses.len(); + for response in responses { + if l > 1 { + message.reply(&ctx.http, response.replace("```", "")).await; + } else { + message.reply(&ctx.http, response).await; + } } + + typing.stop().unwrap(); + } + + let individual_chatgpt_storage_lock = { + let data_read = ctx.data.read().await; + data_read + .get::() + .expect("Cannot get IndividualChatGPTData") + .clone() + }; + + let mut individual_storage = individual_chatgpt_storage_lock.lock().await; + + let bot_id = format!("<@{}>", config.application_id); + if message.content.starts_with(&bot_id) { + if !config + .chatgpt_allows + .contains(&(message.author.id.0 as i64)) + { + message + .reply(&ctx.http, "権限がありません。") + .await + .unwrap(); + return; + } + + let text = message.content.replacen(&bot_id, "", 1); + + let mut tmp = IndividualChatGPT { + user: message.author.id.clone(), + history: vec![], + }; + + let storage_tmp = individual_storage.clone(); + let chatgpt = storage_tmp.get(&message.author.id.clone()).unwrap_or(&tmp); + let mut chatgpt = chatgpt.clone(); + + if text.trim() == "reset".to_string() { + individual_storage.insert(message.author.id.clone(), tmp); + message + .reply(&ctx.http, "会話履歴をリセットしました。") + .await + .unwrap(); + return; + } + + let typing = message.channel_id.start_typing(&ctx.http).unwrap(); + let mut history = chatgpt.history.clone(); + + history.push(ChatGPTMessage { + role: "user".to_string(), + content: message.content.clone(), + }); + + let request = ChatGPTRequest { + model: "gpt-3.5-turbo".to_string(), + messages: history.clone(), + }; + + let response = chatgpt_request(request, config.openai_key.clone()).await; + + history.push(response.choices[0].message.clone()); + + chatgpt.history = history.clone(); + + individual_storage.insert(chatgpt.user.clone(), chatgpt.clone()); + + println!("Tokens: {:?}", response.usage.total_tokens); + let responses = format!( + "{}\ntokens: {}/4096", + response.choices[0].message.content.clone(), + response.usage.total_tokens + ); + let responses = split_string_into_chunks(&responses, 2000); + + let l = responses.len(); + for response in responses { + if l > 1 { + message.reply(&ctx.http, response.replace("```", "")).await; + } else { + message.reply(&ctx.http, response).await; + } + } + + typing.stop().unwrap(); } } diff --git a/src/events/ready.rs b/src/events/ready.rs index 354d628..0957cce 100644 --- a/src/events/ready.rs +++ b/src/events/ready.rs @@ -1,5 +1,35 @@ -use serenity::{model::prelude::Ready, prelude::Context}; +use serenity::{model::prelude::command::Command, model::prelude::Ready, prelude::Context}; -pub async fn ready(_: Context, ready: Ready) { +pub async fn ready(ctx: Context, ready: Ready) { println!("{} is connected!", ready.user.name); + + let mut cosmo = true; + let mut chatgpt = true; + + for command in Command::get_global_application_commands(&ctx.http) + .await + .unwrap() + { + if command.name == "cosmo" { + cosmo = false; + } + if command.name == "chatgpt" { + chatgpt = false; + } + } + + if cosmo { + Command::create_global_application_command(&ctx.http, |command| { + command.name("cosmo").description("Start cosmo chat.") + }) + .await + .unwrap(); + } + if chatgpt { + Command::create_global_application_command(&ctx.http, |command| { + command.name("chatgpt").description("Start ChatGPT chat.") + }) + .await + .unwrap(); + } } diff --git a/src/main.rs b/src/main.rs index 73520fd..995eb9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,8 +5,8 @@ mod events; use std::{collections::HashMap, env, sync::Arc}; -use config::Config; -use data::PingData; +use config::{Config, ConfigData}; +use data::{ChatGPTData, IndividualChatGPTData, LlamaData}; use event_handler::Handler; use serenity::{ client::Client, framework::StandardFramework, futures::lock::Mutex, prelude::GatewayIntents, @@ -40,10 +40,26 @@ async fn main() { } else { let token = env::var("NCB_TOKEN").unwrap(); let application_id = env::var("NCB_APP_ID").unwrap(); + let llama_url = env::var("LLAMA_URL").unwrap(); + let openai_key = env::var("OPENAI_KEY").unwrap(); + let chatgpt_allows = env::var("CHATGPT_ALLOWS").unwrap().to_string(); + let chatgpt_allows: Vec = chatgpt_allows + .split(",") + .map(|f| i64::from_str_radix(f, 10).unwrap()) + .collect(); + let chatgpt_forums = env::var("CHATGPT_FORUMS").unwrap().to_string(); + let chatgpt_forums: Vec = chatgpt_forums + .split(",") + .map(|f| i64::from_str_radix(f, 10).unwrap()) + .collect(); Config { token, application_id: u64::from_str_radix(&application_id, 10).unwrap(), + llama_url, + openai_key, + chatgpt_allows, + chatgpt_forums, } } }; @@ -56,7 +72,10 @@ async fn main() { // Create TTS storage { let mut data = client.data.write().await; - data.insert::(Arc::new(Mutex::new(HashMap::default()))); + data.insert::(Arc::new(Mutex::new(HashMap::default()))); + data.insert::(Arc::new(Mutex::new(HashMap::default()))); + data.insert::(Arc::new(Mutex::new(HashMap::default()))); + data.insert::(Arc::new(Mutex::new(config))); } // Run client