support local llama

This commit is contained in:
mii
2023-03-16 13:12:24 +00:00
parent bf90abe278
commit 380ee3da3e
8 changed files with 544 additions and 52 deletions

View File

@@ -1,5 +1,5 @@
[package]
-name = "ncb-ping"
+name = "ncb-chat"
version = "0.1.0"
edition = "2021"
@@ -11,6 +11,10 @@ serde = "1.0"
toml = "0.5"
async-trait = "0.1.57"
chrono = "0.4.23"
reqwest = { version = "0.11", features = ["json"] }
google-translate3 = "4.0.1+20220121"
gcp_auth = "0.5.0"
[dependencies.uuid]
version = "0.8"
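
The reqwest dependency (with its json feature) drives every HTTP integration added in this commit: the local translation server, the local llama server, and the OpenAI API. A minimal sketch of that shared call pattern, with a placeholder URL and payload type rather than code from this commit:

use serde::Serialize;

#[derive(Serialize)]
struct Payload {
    input: String,
}

// POST a JSON body and return the raw response text.
async fn post_json(url: &str, payload: &Payload) -> reqwest::Result<String> {
    let client = reqwest::Client::new();
    client
        .post(url)
        .json(payload) // provided by reqwest's "json" feature
        .send()
        .await?
        .text()
        .await
}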

View File

@@ -2,8 +2,8 @@ version: '3'
services:
ncb-ping:
-container_name: ncb-ping
-image: ghcr.io/morioka22/ncb-ping:0.0.1
+container_name: ncb-chat
+image: ghcr.io/morioka22/ncb-chat:0.0.1
environment:
- NCB_TOKEN=YOUR_BOT_TOKEN
- NCB_APP_ID=YOUR_BOT_ID
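
With the env-based config that main() now builds (see the last file), a containerized deployment also needs the new variables. The names come from main(); the values here are placeholders in the same style as above:
- LLAMA_URL=http://localhost:18080/
- OPENAI_KEY=YOUR_OPENAI_KEY
- CHATGPT_ALLOWS=YOUR_USER_ID
- CHATGPT_FORUMS=YOUR_FORUM_CHANNEL_ID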

View File

@@ -1,7 +1,20 @@
use std::sync::Arc;
use serde::Deserialize;
use serenity::{futures::lock::Mutex, prelude::TypeMapKey};
#[derive(Deserialize)]
pub struct Config {
pub token: String,
pub application_id: u64,
pub llama_url: String,
pub openai_key: String,
pub chatgpt_allows: Vec<i64>,
pub chatgpt_forums: Vec<i64>,
}
pub struct ConfigData;
impl TypeMapKey for ConfigData {
type Value = Arc<Mutex<Config>>;
}
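
For reference, a TOML document deserializing into this struct could look like the sketch below. The field names come from the struct itself; the sample values and the toml::from_str loading path are assumptions, since the file-based config branch is not shown in this diff:

let doc = r#"
    token = "YOUR_BOT_TOKEN"
    application_id = 0
    llama_url = "http://localhost:18080/"
    openai_key = "YOUR_OPENAI_KEY"
    chatgpt_allows = [123456789]
    chatgpt_forums = [987654321]
"#;
let config: Config = toml::from_str(doc).unwrap();
assert_eq!(config.application_id, 0);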

View File

@@ -1,24 +1,96 @@
use std::{collections::HashMap, sync::Arc};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serenity::{
futures::lock::Mutex,
-model::prelude::{ChannelId, Message, UserId},
+model::prelude::{ChannelId, UserId},
prelude::TypeMapKey,
};
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TranslateRequest {
pub input: String,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LlamaMessage {
pub role: String,
pub content: String,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LlamaRequest {
pub messages: Vec<LlamaMessage>,
}
#[derive(Debug, Clone)]
-pub struct Ping {
+pub struct Llama {
pub channel: ChannelId,
-pub user_id: UserId,
-pub author: UserId,
-pub message: Message,
-pub time: DateTime<Utc>,
-pub args: Vec<String>,
+pub history: Vec<LlamaMessage>,
}
-pub struct PingData;
-impl TypeMapKey for PingData {
-type Value = Arc<Mutex<HashMap<UserId, Ping>>>;
+pub struct LlamaData;
+impl TypeMapKey for LlamaData {
+type Value = Arc<Mutex<HashMap<ChannelId, Llama>>>;
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatGPTMessage {
pub role: String,
pub content: String,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Choice {
pub index: usize,
pub finish_reason: Option<String>,
pub message: ChatGPTMessage,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatGPTResponse {
pub id: String,
pub object: String,
pub created: usize,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Usage,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ChatGPTRequest {
pub model: String,
pub messages: Vec<ChatGPTMessage>,
}
#[derive(Debug, Clone)]
pub struct ChatGPT {
pub channel: ChannelId,
pub history: Vec<ChatGPTMessage>,
}
#[derive(Debug, Clone)]
pub struct IndividualChatGPT {
pub user: UserId,
pub history: Vec<ChatGPTMessage>,
}
pub struct IndividualChatGPTData;
impl TypeMapKey for IndividualChatGPTData {
type Value = Arc<Mutex<HashMap<UserId, IndividualChatGPT>>>;
}
pub struct ChatGPTData;
impl TypeMapKey for ChatGPTData {
type Value = Arc<Mutex<HashMap<ChannelId, ChatGPT>>>;
}
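
The ChatGPT* types mirror OpenAI's chat completions wire format, so a request built from them serializes directly into the JSON body the API expects. A small sketch (the model name matches the one used in the message handler below):

let request = ChatGPTRequest {
    model: "gpt-3.5-turbo".to_string(),
    messages: vec![ChatGPTMessage {
        role: "user".to_string(),
        content: "こんにちは".to_string(),
    }],
};
// -> {"model":"gpt-3.5-turbo","messages":[{"role":"user","content":"こんにちは"}]}
println!("{}", serde_json::to_string(&request).unwrap());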

View File

@@ -1,18 +1,91 @@
use serenity::model::prelude::interaction::Interaction;
use serenity::{
async_trait,
client::{Context, EventHandler},
-model::{
-channel::Message,
-gateway::Ready,
-},
+model::{channel::GuildChannel, channel::Message, gateway::Ready},
};
use crate::config::ConfigData;
use crate::data::*;
use crate::events;
pub struct Handler;
#[async_trait]
impl EventHandler for Handler {
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
if let Interaction::ApplicationCommand(command) = interaction.clone() {
let name = &*command.data.name;
if name == "llama" {
let message = command
.channel_id
.send_message(&ctx.http, |f| f.content("LLaMA Chat thread"))
.await
.unwrap();
let id = command
.channel_id
.create_public_thread(&ctx.http, message, |f| {
f.name("LLaMA Chat").auto_archive_duration(60)
})
.await
.unwrap()
.id;
let llama = Llama {
channel: id,
history: vec![],
};
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<LlamaData>()
.expect("Cannot get LlamaData")
.clone()
};
storage_lock.lock().await.insert(id, llama);
}
if name == "chatgpt" {
let cs = {
let data_read = ctx.data.read().await;
data_read
.get::<ConfigData>()
.expect("Cannot get ConfigData")
.clone()
};
let config = cs.lock().await;
if !config.chatgpt_allows.contains(&(command.user.id.0 as i64)) {
return;
}
let message = command
.channel_id
.send_message(&ctx.http, |f| f.content("ChatGPT thread"))
.await
.unwrap();
let id = command
.channel_id
.create_public_thread(&ctx.http, message, |f| {
f.name("ChatGPT").auto_archive_duration(60)
})
.await
.unwrap()
.id;
let chatgpt = ChatGPT {
channel: id,
history: vec![],
};
let storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<ChatGPTData>()
.expect("Cannot get ChatGPTData")
.clone()
};
storage_lock.lock().await.insert(id, chatgpt);
}
}
}
async fn message(&self, ctx: Context, message: Message) {
events::message_receive::message(ctx, message).await;
}
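
The read-lock-then-clone lookup against ctx.data recurs once per storage type here and again in the message handler. A hypothetical generic helper, not part of this commit, could collapse that boilerplate; a sketch assuming the serenity TypeMap API used above:

use serenity::prelude::{Context, TypeMapKey};

// Clone the Arc out of the TypeMap so the read lock is held only briefly.
async fn shared<T>(ctx: &Context) -> T::Value
where
    T: TypeMapKey,
    T::Value: Clone,
{
    let data_read = ctx.data.read().await;
    data_read.get::<T>().expect("TypeMap entry missing").clone()
}

// Usage: let storage_lock = shared::<LlamaData>(&ctx).await;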

View File

@@ -1,58 +1,339 @@
use chrono::Utc;
-use serenity::{model::prelude::Message, prelude::Context};
+use serenity::{
+http::CacheHttp,
+model::prelude::{GuildChannel, Message},
+prelude::Context,
+};
-use crate::data::{PingData, Ping};
+use crate::{
+config::*,
+data::{
+ChatGPT, ChatGPTData, ChatGPTMessage, ChatGPTRequest, ChatGPTResponse, IndividualChatGPT,
+IndividualChatGPTData, Llama, LlamaData, LlamaMessage, LlamaRequest, TranslateRequest,
+},
+};
async fn translate_en_ja(input: String) -> String {
let request = TranslateRequest { input };
let client = reqwest::Client::new();
match client
.post("http://localhost:8008/translate/enja")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.body(serde_json::to_string(&request).unwrap())
.send()
.await
{
Ok(ok) => ok.text().await.expect("failed to read translate response"),
Err(err) => panic!("translate request failed: {}", err),
}
}
async fn translate_ja_en(input: String) -> String {
let request = TranslateRequest { input };
let client = reqwest::Client::new();
match client
.post("http://localhost:8008/translate/jaen")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.body(serde_json::to_string(&request).unwrap())
.send()
.await
{
Ok(ok) => ok.text().await.expect("failed to read translate response"),
Err(err) => panic!("translate request failed: {}", err),
}
}
async fn chatgpt_request(input: ChatGPTRequest, key: String) -> ChatGPTResponse {
let client = reqwest::Client::new();
match client
.post("https://api.openai.com/v1/chat/completions")
.header(reqwest::header::CONTENT_TYPE, "application/json")
.header(reqwest::header::AUTHORIZATION, format!("Bearer {}", key))
.body(serde_json::to_string(&input).unwrap())
.send()
.await
{
Ok(ok) => {
let text = ok.text().await.unwrap();
println!("{}", text);
let response: ChatGPTResponse = serde_json::from_str(&text).unwrap();
response
}
Err(err) => panic!("ChatGPT request failed: {}", err),
}
}
fn split_string_into_chunks(s: &str, chunk_size: usize) -> Vec<String> {
// Chunk on char boundaries: slicing at raw byte offsets would panic on
// multi-byte UTF-8 (e.g. the Japanese replies this bot produces).
s.chars()
.collect::<Vec<char>>()
.chunks(chunk_size)
.map(|chunk| chunk.iter().collect())
.collect()
}
pub async fn llama(_ctx: Context, _message: Message) {}
pub async fn message(ctx: Context, message: Message) {
if message.author.bot {
return;
}
if message.content.starts_with(";") {
message
.reply(&ctx.http, "スキップしました。")
.await
.unwrap();
return;
}
if message.guild(&ctx.cache).is_none() {
return;
}
-let storage_lock = {
+println!("Event received: {}", message.content);
+let config_lock = {
let data_read = ctx.data.read().await;
data_read
-.get::<PingData>()
-.expect("Cannot get PingData")
+.get::<ConfigData>()
+.expect("Cannot get Config")
.clone()
};
-if message.mentions.len() == 1 {
-let user = message.mentions.first().unwrap();
-let m = format!("PING {} ({}) 56(84) bytes of data.", user.name, user.id.0);
-let ping_message = message.reply(&ctx.http, m).await.unwrap();
-let ping = Ping {
-channel: message.channel_id,
-user_id: user.id,
-author: message.author.id,
-time: Utc::now(),
-message: ping_message,
-args: vec![]
-};
-let mut storage = storage_lock.lock().await;
-storage.insert(user.id, ping.clone());
+let config = config_lock.lock().await;
+let llama_storage_lock = {
+let data_read = ctx.data.read().await;
+data_read
+.get::<LlamaData>()
+.expect("Cannot get LlamaData")
+.clone()
+};
+let mut llama_storage = llama_storage_lock.lock().await;
if let Some(mut llama) = llama_storage.get(&message.channel_id).cloned() {
let typing = message.channel_id.start_typing(&ctx.http).unwrap();
let mut history = llama.history.clone();
let text = translate_ja_en(message.content.clone()).await;
println!("{}", text);
history.push(LlamaMessage {
role: "user".to_string(),
content: text,
});
llama.history = history.clone();
llama_storage.insert(llama.channel, llama.clone());
let request = LlamaRequest { messages: history };
let client = reqwest::Client::new();
match client
.post(config.llama_url.as_str()) // llama endpoint comes from Config (LLAMA_URL)
.header(reqwest::header::CONTENT_TYPE, "application/json")
.body(serde_json::to_string(&request).unwrap())
.send()
.await
{
Ok(ok) => {
let response = ok.text().await.expect("ERROR");
let response = translate_en_ja(response).await;
println!("{}", response);
message
.channel_id
.send_message(&ctx.http, |f| f.content(response))
.await
.unwrap();
}
Err(err) => panic!("llama request failed: {}", err),
}
typing.stop().unwrap();
}
let chatgpt_storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<ChatGPTData>()
.expect("Cannot get ChatGPTData")
.clone()
};
let mut chatgpt_storage = chatgpt_storage_lock.lock().await;
// forum
let channel = message.channel(&ctx.http).await.unwrap();
if let Some(guild_channel) = channel.clone().guild() {
if let Some(parent) = guild_channel.parent_id {
if chatgpt_storage.get(&message.channel_id).is_none()
&& config.chatgpt_forums.contains(&(parent.0 as i64))
{
-let mut storage = storage_lock.lock().await;
-if !storage.contains_key(&message.author.id) {
chatgpt_storage.insert(
guild_channel.id,
ChatGPT {
channel: channel.id(),
history: vec![],
},
);
}
}
}
if let Some(mut chatgpt) = chatgpt_storage.get(&message.channel_id).cloned() {
if message.content == "reset" {
chatgpt.history = vec![];
chatgpt_storage.insert(chatgpt.channel, chatgpt.clone());
message
.channel_id
.send_message(&ctx.http, |f| f.content("会話履歴をリセットしました。"))
.await
.unwrap();
return;
}
-let ping = storage.get_mut(&message.author.id).unwrap();
-if ping.channel == message.channel_id {
-let user = ping.user_id.to_user(&ctx.http).await.unwrap();
-let time = Utc::now() - ping.time;
-ping.message.edit(&ctx.http, |f| f.content(format!("--- {} ping statistics ---\n1 packets transmitted, 1 received, 0% packet loss, time {}ms", user.name, time.num_milliseconds()))).await.unwrap();
-message.channel_id.send_message(&ctx.http, |f| f.content(format!("<@{}>", ping.author.0))).await.unwrap();
-storage.remove(&message.author.id);
let typing = message.channel_id.start_typing(&ctx.http).unwrap();
let mut history = chatgpt.history.clone();
history.push(ChatGPTMessage {
role: "user".to_string(),
content: message.content.clone(),
});
let request = ChatGPTRequest {
model: "gpt-3.5-turbo".to_string(),
messages: history.clone(),
};
let response = chatgpt_request(request, config.openai_key.clone()).await;
history.push(response.choices[0].message.clone());
chatgpt.history = history.clone();
chatgpt_storage.insert(chatgpt.channel, chatgpt.clone());
println!("Tokens: {:?}", response.usage.total_tokens);
let responses = format!(
"{}\ntokens: {}/4096",
response.choices[0].message.content.clone(),
response.usage.total_tokens
);
let responses = split_string_into_chunks(&responses, 2000);
let l = responses.len();
for response in responses {
if l > 1 {
// A chunk boundary can split a ``` pair, so strip code fences from multi-part replies.
message.reply(&ctx.http, response.replace("```", "")).await.ok();
} else {
message.reply(&ctx.http, response).await.ok();
}
}
typing.stop().unwrap();
}
let individual_chatgpt_storage_lock = {
let data_read = ctx.data.read().await;
data_read
.get::<IndividualChatGPTData>()
.expect("Cannot get IndividualChatGPTData")
.clone()
};
let mut individual_storage = individual_chatgpt_storage_lock.lock().await;
let bot_id = format!("<@{}>", config.application_id);
if message.content.starts_with(&bot_id) {
if !config
.chatgpt_allows
.contains(&(message.author.id.0 as i64))
{
message
.reply(&ctx.http, "権限がありません。")
.await
.unwrap();
return;
}
let text = message.content.replacen(&bot_id, "", 1);
let tmp = IndividualChatGPT {
user: message.author.id,
history: vec![],
};
// Clone only this user's entry instead of the whole map.
let mut chatgpt = individual_storage
.get(&message.author.id)
.cloned()
.unwrap_or_else(|| tmp.clone());
if text.trim() == "reset" {
individual_storage.insert(message.author.id.clone(), tmp);
message
.reply(&ctx.http, "会話履歴をリセットしました。")
.await
.unwrap();
return;
}
let typing = message.channel_id.start_typing(&ctx.http).unwrap();
let mut history = chatgpt.history.clone();
history.push(ChatGPTMessage {
role: "user".to_string(),
content: text.trim().to_string(), // send the prompt without the leading bot mention
});
let request = ChatGPTRequest {
model: "gpt-3.5-turbo".to_string(),
messages: history.clone(),
};
let response = chatgpt_request(request, config.openai_key.clone()).await;
history.push(response.choices[0].message.clone());
chatgpt.history = history.clone();
individual_storage.insert(chatgpt.user.clone(), chatgpt.clone());
println!("Tokens: {:?}", response.usage.total_tokens);
let responses = format!(
"{}\ntokens: {}/4096",
response.choices[0].message.content.clone(),
response.usage.total_tokens
);
let responses = split_string_into_chunks(&responses, 2000);
let l = responses.len();
for response in responses {
if l > 1 {
// A chunk boundary can split a ``` pair, so strip code fences from multi-part replies.
message.reply(&ctx.http, response.replace("```", "")).await.ok();
} else {
message.reply(&ctx.http, response).await.ok();
}
}
typing.stop().unwrap();
}
}
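
Discord rejects messages longer than 2000 characters, which is why every reply is routed through split_string_into_chunks after the token footer is attached. Because the splitter above counts chars rather than bytes, long Japanese output chunks safely; for example:

let long_reply = "あ".repeat(2500);
let chunks = split_string_into_chunks(&long_reply, 2000);
// 2000 chars + 500 chars, with no split inside a multi-byte code point.
assert_eq!(chunks.len(), 2);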

View File

@@ -1,5 +1,35 @@
-use serenity::{model::prelude::Ready, prelude::Context};
+use serenity::{model::prelude::command::Command, model::prelude::Ready, prelude::Context};
-pub async fn ready(_: Context, ready: Ready) {
+pub async fn ready(ctx: Context, ready: Ready) {
println!("{} is connected!", ready.user.name);
let mut llama = true;
let mut chatgpt = true;
for command in Command::get_global_application_commands(&ctx.http)
.await
.unwrap()
{
if command.name == "llama" {
llama = false;
}
if command.name == "chatgpt" {
chatgpt = false;
}
}
if llama {
Command::create_global_application_command(&ctx.http, |command| {
// Register the /llama command that interaction_create handles.
command.name("llama").description("Start LLaMA chat.")
})
.await
.unwrap();
}
if chatgpt {
Command::create_global_application_command(&ctx.http, |command| {
command.name("chatgpt").description("Start ChatGPT chat.")
})
.await
.unwrap();
}
}
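
Global slash commands can take a while to propagate (historically up to an hour); serenity also supports per-guild registration, which applies almost immediately and is convenient while testing. A sketch with a placeholder guild id:

use serenity::model::prelude::GuildId;

GuildId(123456789)
    .create_application_command(&ctx.http, |command| {
        command.name("llama").description("Start LLaMA chat.")
    })
    .await
    .unwrap();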

View File

@@ -5,8 +5,8 @@ mod events
use std::{collections::HashMap, env, sync::Arc};
-use config::Config;
-use data::PingData;
+use config::{Config, ConfigData};
+use data::{ChatGPTData, IndividualChatGPTData, LlamaData};
use event_handler::Handler;
use serenity::{
client::Client, framework::StandardFramework, futures::lock::Mutex, prelude::GatewayIntents,
@@ -40,10 +40,26 @@ async fn main() {
} else {
let token = env::var("NCB_TOKEN").unwrap();
let application_id = env::var("NCB_APP_ID").unwrap();
let llama_url = env::var("LLAMA_URL").unwrap();
let openai_key = env::var("OPENAI_KEY").unwrap();
let chatgpt_allows: Vec<i64> = env::var("CHATGPT_ALLOWS")
.unwrap()
.split(',')
.map(|id| id.parse().unwrap())
.collect();
let chatgpt_forums: Vec<i64> = env::var("CHATGPT_FORUMS")
.unwrap()
.split(',')
.map(|id| id.parse().unwrap())
.collect();
Config {
token,
application_id: application_id.parse().unwrap(),
llama_url,
openai_key,
chatgpt_allows,
chatgpt_forums,
}
}
};
@@ -56,7 +72,10 @@ async fn main() {
// Create shared llama / ChatGPT / config storage
{
let mut data = client.data.write().await;
-data.insert::<PingData>(Arc::new(Mutex::new(HashMap::default())));
+data.insert::<LlamaData>(Arc::new(Mutex::new(HashMap::default())));
+data.insert::<ChatGPTData>(Arc::new(Mutex::new(HashMap::default())));
+data.insert::<IndividualChatGPTData>(Arc::new(Mutex::new(HashMap::default())));
+data.insert::<ConfigData>(Arc::new(Mutex::new(config)));
}
// Run client
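
For reference, the comma-separated allow-lists above parse into plain i64 IDs. A standalone sketch of the same split-and-parse pipeline, with placeholder IDs:

// CHATGPT_ALLOWS="123,456" becomes vec![123, 456]
let ids: Vec<i64> = "123,456"
    .split(',')
    .map(|id| id.parse().unwrap())
    .collect();
assert_eq!(ids, vec![123, 456]);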