mirror of
https://github.com/mii443/ncb-chat.git
synced 2025-08-22 16:15:27 +00:00
support local llama
This commit is contained in:
@ -1,5 +1,5 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "ncb-ping"
|
name = "ncb-chat"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
@ -11,6 +11,10 @@ serde = "1.0"
|
|||||||
toml = "0.5"
|
toml = "0.5"
|
||||||
async-trait = "0.1.57"
|
async-trait = "0.1.57"
|
||||||
chrono = "0.4.23"
|
chrono = "0.4.23"
|
||||||
|
reqwest = { version = "0.11", features = ["json"] }
|
||||||
|
google-translate3 = "4.0.1+20220121"
|
||||||
|
gcp_auth = "0.5.0"
|
||||||
|
|
||||||
|
|
||||||
[dependencies.uuid]
|
[dependencies.uuid]
|
||||||
version = "0.8"
|
version = "0.8"
|
||||||
|
@ -2,8 +2,8 @@ version: '3'
|
|||||||
|
|
||||||
services:
|
services:
|
||||||
ncb-ping:
|
ncb-ping:
|
||||||
container_name: ncb-ping
|
container_name: ncb-chat
|
||||||
image: ghcr.io/morioka22/ncb-ping:0.0.1
|
image: ghcr.io/morioka22/ncb-chat:0.0.1
|
||||||
environment:
|
environment:
|
||||||
- NCB_TOKEN=YOUR_BOT_TOKEN
|
- NCB_TOKEN=YOUR_BOT_TOKEN
|
||||||
- NCB_APP_ID=YOUR_BOT_ID
|
- NCB_APP_ID=YOUR_BOT_ID
|
||||||
|
@ -1,7 +1,20 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use serenity::{futures::lock::Mutex, prelude::TypeMapKey};
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub token: String,
|
pub token: String,
|
||||||
pub application_id: u64,
|
pub application_id: u64,
|
||||||
|
pub llama_url: String,
|
||||||
|
pub openai_key: String,
|
||||||
|
pub chatgpt_allows: Vec<i64>,
|
||||||
|
pub chatgpt_forums: Vec<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ConfigData;
|
||||||
|
|
||||||
|
impl TypeMapKey for ConfigData {
|
||||||
|
type Value = Arc<Mutex<Config>>;
|
||||||
}
|
}
|
||||||
|
94
src/data.rs
94
src/data.rs
@ -1,24 +1,96 @@
|
|||||||
use std::{collections::HashMap, sync::Arc};
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
use chrono::{DateTime, Utc};
|
use serde::{Deserialize, Serialize};
|
||||||
use serenity::{
|
use serenity::{
|
||||||
futures::lock::Mutex,
|
futures::lock::Mutex,
|
||||||
model::prelude::{ChannelId, Message, UserId},
|
model::prelude::{ChannelId, UserId},
|
||||||
prelude::TypeMapKey,
|
prelude::TypeMapKey,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
|
pub struct TranslateRequest {
|
||||||
|
pub input: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
|
pub struct LlamaMessage {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
|
pub struct LlamaRequest {
|
||||||
|
pub messages: Vec<LlamaMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Ping {
|
pub struct Llama {
|
||||||
pub channel: ChannelId,
|
pub channel: ChannelId,
|
||||||
pub user_id: UserId,
|
pub history: Vec<LlamaMessage>,
|
||||||
pub author: UserId,
|
|
||||||
pub message: Message,
|
|
||||||
pub time: DateTime<Utc>,
|
|
||||||
pub args: Vec<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct PingData;
|
pub struct LlamaData;
|
||||||
|
|
||||||
impl TypeMapKey for PingData {
|
impl TypeMapKey for LlamaData {
|
||||||
type Value = Arc<Mutex<HashMap<UserId, Ping>>>;
|
type Value = Arc<Mutex<HashMap<ChannelId, Llama>>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
|
pub struct ChatGPTMessage {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: usize,
|
||||||
|
pub completion_tokens: usize,
|
||||||
|
pub total_tokens: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
|
pub struct Choice {
|
||||||
|
pub index: usize,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
pub message: ChatGPTMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
|
pub struct ChatGPTResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: usize,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<Choice>,
|
||||||
|
pub usage: Usage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
|
pub struct ChatGPTRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<ChatGPTMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ChatGPT {
|
||||||
|
pub channel: ChannelId,
|
||||||
|
pub history: Vec<ChatGPTMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct IndividualChatGPT {
|
||||||
|
pub user: UserId,
|
||||||
|
pub history: Vec<ChatGPTMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct IndividualChatGPTData;
|
||||||
|
|
||||||
|
impl TypeMapKey for IndividualChatGPTData {
|
||||||
|
type Value = Arc<Mutex<HashMap<UserId, IndividualChatGPT>>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ChatGPTData;
|
||||||
|
|
||||||
|
impl TypeMapKey for ChatGPTData {
|
||||||
|
type Value = Arc<Mutex<HashMap<ChannelId, ChatGPT>>>;
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,91 @@
|
|||||||
|
use serenity::model::prelude::interaction::Interaction;
|
||||||
use serenity::{
|
use serenity::{
|
||||||
async_trait,
|
async_trait,
|
||||||
client::{Context, EventHandler},
|
client::{Context, EventHandler},
|
||||||
model::{
|
model::{channel::GuildChannel, channel::Message, gateway::Ready},
|
||||||
channel::Message,
|
|
||||||
gateway::Ready,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::config::ConfigData;
|
||||||
|
use crate::data::*;
|
||||||
use crate::events;
|
use crate::events;
|
||||||
|
|
||||||
pub struct Handler;
|
pub struct Handler;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl EventHandler for Handler {
|
impl EventHandler for Handler {
|
||||||
|
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
|
||||||
|
if let Interaction::ApplicationCommand(command) = interaction.clone() {
|
||||||
|
let name = &*command.data.name;
|
||||||
|
if name == "llama" {
|
||||||
|
let message = command
|
||||||
|
.channel_id
|
||||||
|
.send_message(&ctx.http, |f| f.content("LLaMa Chat thread"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let id = command
|
||||||
|
.channel_id
|
||||||
|
.create_public_thread(&ctx.http, message, |f| {
|
||||||
|
f.name("LLaMa Chat").auto_archive_duration(60)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.id;
|
||||||
|
let llama = Llama {
|
||||||
|
channel: id,
|
||||||
|
history: vec![],
|
||||||
|
};
|
||||||
|
let storage_lock = {
|
||||||
|
let data_read = ctx.data.read().await;
|
||||||
|
data_read
|
||||||
|
.get::<LlamaData>()
|
||||||
|
.expect("Cannot get TTSStorage")
|
||||||
|
.clone()
|
||||||
|
};
|
||||||
|
storage_lock.lock().await.insert(id, llama);
|
||||||
|
}
|
||||||
|
|
||||||
|
if name == "chatgpt" {
|
||||||
|
let cs = {
|
||||||
|
let data_read = ctx.data.read().await;
|
||||||
|
data_read
|
||||||
|
.get::<ConfigData>()
|
||||||
|
.expect("Cannot get ConfigData")
|
||||||
|
.clone()
|
||||||
|
};
|
||||||
|
let config = cs.lock().await;
|
||||||
|
|
||||||
|
if !config.chatgpt_allows.contains(&(command.user.id.0 as i64)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let message = command
|
||||||
|
.channel_id
|
||||||
|
.send_message(&ctx.http, |f| f.content("ChatGPT thread"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let id = command
|
||||||
|
.channel_id
|
||||||
|
.create_public_thread(&ctx.http, message, |f| {
|
||||||
|
f.name("ChatGPT").auto_archive_duration(60)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.id;
|
||||||
|
let chatgpt = ChatGPT {
|
||||||
|
channel: id,
|
||||||
|
history: vec![],
|
||||||
|
};
|
||||||
|
let storage_lock = {
|
||||||
|
let data_read = ctx.data.read().await;
|
||||||
|
data_read
|
||||||
|
.get::<ChatGPTData>()
|
||||||
|
.expect("Cannot get TTSStorage")
|
||||||
|
.clone()
|
||||||
|
};
|
||||||
|
storage_lock.lock().await.insert(id, chatgpt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
async fn message(&self, ctx: Context, message: Message) {
|
async fn message(&self, ctx: Context, message: Message) {
|
||||||
events::message_receive::message(ctx, message).await;
|
events::message_receive::message(ctx, message).await;
|
||||||
}
|
}
|
||||||
|
@ -1,58 +1,339 @@
|
|||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use serenity::{model::prelude::Message, prelude::Context};
|
use serenity::{
|
||||||
|
http::CacheHttp,
|
||||||
|
model::prelude::{GuildChannel, Message},
|
||||||
|
prelude::Context,
|
||||||
|
};
|
||||||
|
|
||||||
use crate::data::{PingData, Ping};
|
use crate::{
|
||||||
|
config::*,
|
||||||
|
data::{
|
||||||
|
ChatGPT, ChatGPTData, ChatGPTMessage, ChatGPTRequest, ChatGPTResponse, IndividualChatGPT,
|
||||||
|
IndividualChatGPTData, Llama, LlamaData, LlamaMessage, LlamaRequest, TranslateRequest,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
async fn translate_en_ja(input: String) -> String {
|
||||||
|
let request = TranslateRequest { input };
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
match client
|
||||||
|
.post("http://localhost:8008/translate/enja")
|
||||||
|
.header(reqwest::header::CONTENT_TYPE, "application/json")
|
||||||
|
.body(serde_json::to_string(&request).unwrap())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(ok) => ok.text().await.expect("ERROR"),
|
||||||
|
Err(err) => {
|
||||||
|
panic!("Error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn translate_ja_en(input: String) -> String {
|
||||||
|
let request = TranslateRequest { input };
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
match client
|
||||||
|
.post("http://localhost:8008/translate/jaen")
|
||||||
|
.header(reqwest::header::CONTENT_TYPE, "application/json")
|
||||||
|
.body(serde_json::to_string(&request).unwrap())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(ok) => ok.text().await.expect("ERROR"),
|
||||||
|
Err(err) => {
|
||||||
|
panic!("Error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chatgpt_request(input: ChatGPTRequest, key: String) -> ChatGPTResponse {
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
match client
|
||||||
|
.post("https://api.openai.com/v1/chat/completions")
|
||||||
|
.header(reqwest::header::CONTENT_TYPE, "application/json")
|
||||||
|
.header(reqwest::header::AUTHORIZATION, format!("Bearer {}", key))
|
||||||
|
.body(serde_json::to_string(&input).unwrap())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(ok) => {
|
||||||
|
let text = ok.text().await.unwrap();
|
||||||
|
println!("{}", text.clone());
|
||||||
|
let response: ChatGPTResponse = serde_json::from_str(&text).unwrap();
|
||||||
|
response
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
panic!("Error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn split_string_into_chunks(s: &str, chunk_size: usize) -> Vec<String> {
|
||||||
|
let mut chunks = Vec::new();
|
||||||
|
let mut start = 0;
|
||||||
|
let len = s.len();
|
||||||
|
|
||||||
|
while start < len {
|
||||||
|
let end = if start + chunk_size < len {
|
||||||
|
start + chunk_size
|
||||||
|
} else {
|
||||||
|
len
|
||||||
|
};
|
||||||
|
chunks.push(s[start..end].to_string());
|
||||||
|
start = end;
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn llama(ctx: Context, message: Message) {}
|
||||||
|
|
||||||
pub async fn message(ctx: Context, message: Message) {
|
pub async fn message(ctx: Context, message: Message) {
|
||||||
if message.author.bot {
|
if message.author.bot {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if message.content.starts_with(";") {
|
||||||
|
message
|
||||||
|
.reply(&ctx.http, "スキップしました。")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let guild_id = message.guild(&ctx.cache);
|
let guild_id = message.guild(&ctx.cache);
|
||||||
|
|
||||||
if let None = guild_id {
|
if let None = guild_id {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let storage_lock = {
|
println!("Event received: {}", message.content);
|
||||||
|
let config_lock = {
|
||||||
let data_read = ctx.data.read().await;
|
let data_read = ctx.data.read().await;
|
||||||
data_read
|
data_read
|
||||||
.get::<PingData>()
|
.get::<ConfigData>()
|
||||||
.expect("Cannot get PingData")
|
.expect("Cannot get Config")
|
||||||
.clone()
|
.clone()
|
||||||
};
|
};
|
||||||
|
|
||||||
if message.mentions.len() == 1 {
|
let config = config_lock.lock().await;
|
||||||
let user = message.mentions.first().unwrap();
|
|
||||||
let m = format!("PING {} ({}) 56(84) bytes of data.", user.name, user.id.0);
|
|
||||||
let ping_message = message.reply(&ctx.http, m).await.unwrap();
|
|
||||||
|
|
||||||
let ping = Ping {
|
let llama_storage_lock = {
|
||||||
channel: message.channel_id,
|
let data_read = ctx.data.read().await;
|
||||||
user_id: user.id,
|
data_read
|
||||||
author: message.author.id,
|
.get::<LlamaData>()
|
||||||
time: Utc::now(),
|
.expect("Cannot get LlamaData")
|
||||||
message: ping_message,
|
.clone()
|
||||||
args: vec![]
|
};
|
||||||
};
|
|
||||||
|
|
||||||
let mut storage = storage_lock.lock().await;
|
let mut llama_storage = llama_storage_lock.lock().await;
|
||||||
storage.insert(user.id, ping.clone());
|
|
||||||
|
if let Some(mut llama) = llama_storage.clone().get_mut(&message.channel_id) {
|
||||||
|
let typing = message.channel_id.start_typing(&ctx.http).unwrap();
|
||||||
|
let mut history = llama.history.clone();
|
||||||
|
let text = translate_ja_en(message.content.clone()).await;
|
||||||
|
println!("{}", text);
|
||||||
|
history.push(LlamaMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: text,
|
||||||
|
});
|
||||||
|
llama.history = history.clone();
|
||||||
|
llama_storage.insert(llama.channel, llama.clone());
|
||||||
|
|
||||||
|
let request = LlamaRequest { messages: history };
|
||||||
|
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
match client
|
||||||
|
.post("http://localhost:18080/")
|
||||||
|
.header(reqwest::header::CONTENT_TYPE, "application/json")
|
||||||
|
.body(serde_json::to_string(&request).unwrap())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(ok) => {
|
||||||
|
let response = ok.text().await.expect("ERROR");
|
||||||
|
let response = translate_en_ja(response).await;
|
||||||
|
println!("{}", response);
|
||||||
|
message
|
||||||
|
.channel_id
|
||||||
|
.send_message(&ctx.http, |f| f.content(response))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
panic!("Error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typing.stop().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
let chatgpt_storage_lock = {
|
||||||
let mut storage = storage_lock.lock().await;
|
let data_read = ctx.data.read().await;
|
||||||
if !storage.contains_key(&message.author.id) {
|
data_read
|
||||||
|
.get::<ChatGPTData>()
|
||||||
|
.expect("Cannot get ChatGPTData")
|
||||||
|
.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut chatgpt_storage = chatgpt_storage_lock.lock().await;
|
||||||
|
|
||||||
|
// forum
|
||||||
|
let channel = message.channel(&ctx.http).await.unwrap();
|
||||||
|
if let Some(guild_channel) = channel.clone().guild() {
|
||||||
|
if let Some(parent) = guild_channel.parent_id {
|
||||||
|
if chatgpt_storage
|
||||||
|
.clone()
|
||||||
|
.get_mut(&message.channel_id)
|
||||||
|
.is_none()
|
||||||
|
&& config.chatgpt_forums.contains(&(parent.0 as i64))
|
||||||
|
{
|
||||||
|
chatgpt_storage.insert(
|
||||||
|
guild_channel.id,
|
||||||
|
ChatGPT {
|
||||||
|
channel: channel.id(),
|
||||||
|
history: vec![],
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(mut chatgpt) = chatgpt_storage.clone().get_mut(&message.channel_id) {
|
||||||
|
if message.content == "reset".to_string() {
|
||||||
|
chatgpt.history = vec![];
|
||||||
|
chatgpt_storage.insert(chatgpt.channel, chatgpt.clone());
|
||||||
|
message
|
||||||
|
.channel_id
|
||||||
|
.send_message(&ctx.http, |f| f.content("会話履歴をリセットしました。"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
let ping = storage.get_mut(&message.author.id).unwrap();
|
|
||||||
|
|
||||||
if ping.channel == message.channel_id {
|
let typing = message.channel_id.start_typing(&ctx.http).unwrap();
|
||||||
let user = ping.user_id.to_user(&ctx.http).await.unwrap();
|
let mut history = chatgpt.history.clone();
|
||||||
let time = Utc::now() - ping.time;
|
|
||||||
ping.message.edit(&ctx.http, |f| f.content(format!("--- {} ping statistics ---\n1 packets transmitted, 1 received, 0% packet loss, time {}ms", user.name, time.num_milliseconds()))).await.unwrap();
|
history.push(ChatGPTMessage {
|
||||||
message.channel_id.send_message(&ctx.http, |f| f.content(format!("<@{}>", ping.author.0))).await.unwrap();
|
role: "user".to_string(),
|
||||||
storage.remove(&message.author.id);
|
content: message.content.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let request = ChatGPTRequest {
|
||||||
|
model: "gpt-3.5-turbo".to_string(),
|
||||||
|
messages: history.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = chatgpt_request(request, config.openai_key.clone()).await;
|
||||||
|
|
||||||
|
history.push(response.choices[0].message.clone());
|
||||||
|
|
||||||
|
chatgpt.history = history.clone();
|
||||||
|
|
||||||
|
chatgpt_storage.insert(chatgpt.channel, chatgpt.clone());
|
||||||
|
|
||||||
|
println!("Tokens: {:?}", response.usage.total_tokens);
|
||||||
|
let responses = format!(
|
||||||
|
"{}\ntokens: {}/4096",
|
||||||
|
response.choices[0].message.content.clone(),
|
||||||
|
response.usage.total_tokens
|
||||||
|
);
|
||||||
|
let responses = split_string_into_chunks(&responses, 2000);
|
||||||
|
|
||||||
|
let l = responses.len();
|
||||||
|
for response in responses {
|
||||||
|
if l > 1 {
|
||||||
|
message.reply(&ctx.http, response.replace("```", "")).await;
|
||||||
|
} else {
|
||||||
|
message.reply(&ctx.http, response).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
typing.stop().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let individual_chatgpt_storage_lock = {
|
||||||
|
let data_read = ctx.data.read().await;
|
||||||
|
data_read
|
||||||
|
.get::<IndividualChatGPTData>()
|
||||||
|
.expect("Cannot get IndividualChatGPTData")
|
||||||
|
.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut individual_storage = individual_chatgpt_storage_lock.lock().await;
|
||||||
|
|
||||||
|
let bot_id = format!("<@{}>", config.application_id);
|
||||||
|
if message.content.starts_with(&bot_id) {
|
||||||
|
if !config
|
||||||
|
.chatgpt_allows
|
||||||
|
.contains(&(message.author.id.0 as i64))
|
||||||
|
{
|
||||||
|
message
|
||||||
|
.reply(&ctx.http, "権限がありません。")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let text = message.content.replacen(&bot_id, "", 1);
|
||||||
|
|
||||||
|
let mut tmp = IndividualChatGPT {
|
||||||
|
user: message.author.id.clone(),
|
||||||
|
history: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let storage_tmp = individual_storage.clone();
|
||||||
|
let chatgpt = storage_tmp.get(&message.author.id.clone()).unwrap_or(&tmp);
|
||||||
|
let mut chatgpt = chatgpt.clone();
|
||||||
|
|
||||||
|
if text.trim() == "reset".to_string() {
|
||||||
|
individual_storage.insert(message.author.id.clone(), tmp);
|
||||||
|
message
|
||||||
|
.reply(&ctx.http, "会話履歴をリセットしました。")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let typing = message.channel_id.start_typing(&ctx.http).unwrap();
|
||||||
|
let mut history = chatgpt.history.clone();
|
||||||
|
|
||||||
|
history.push(ChatGPTMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let request = ChatGPTRequest {
|
||||||
|
model: "gpt-3.5-turbo".to_string(),
|
||||||
|
messages: history.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = chatgpt_request(request, config.openai_key.clone()).await;
|
||||||
|
|
||||||
|
history.push(response.choices[0].message.clone());
|
||||||
|
|
||||||
|
chatgpt.history = history.clone();
|
||||||
|
|
||||||
|
individual_storage.insert(chatgpt.user.clone(), chatgpt.clone());
|
||||||
|
|
||||||
|
println!("Tokens: {:?}", response.usage.total_tokens);
|
||||||
|
let responses = format!(
|
||||||
|
"{}\ntokens: {}/4096",
|
||||||
|
response.choices[0].message.content.clone(),
|
||||||
|
response.usage.total_tokens
|
||||||
|
);
|
||||||
|
let responses = split_string_into_chunks(&responses, 2000);
|
||||||
|
|
||||||
|
let l = responses.len();
|
||||||
|
for response in responses {
|
||||||
|
if l > 1 {
|
||||||
|
message.reply(&ctx.http, response.replace("```", "")).await;
|
||||||
|
} else {
|
||||||
|
message.reply(&ctx.http, response).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typing.stop().unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,35 @@
|
|||||||
use serenity::{model::prelude::Ready, prelude::Context};
|
use serenity::{model::prelude::command::Command, model::prelude::Ready, prelude::Context};
|
||||||
|
|
||||||
pub async fn ready(_: Context, ready: Ready) {
|
pub async fn ready(ctx: Context, ready: Ready) {
|
||||||
println!("{} is connected!", ready.user.name);
|
println!("{} is connected!", ready.user.name);
|
||||||
|
|
||||||
|
let mut cosmo = true;
|
||||||
|
let mut chatgpt = true;
|
||||||
|
|
||||||
|
for command in Command::get_global_application_commands(&ctx.http)
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
{
|
||||||
|
if command.name == "cosmo" {
|
||||||
|
cosmo = false;
|
||||||
|
}
|
||||||
|
if command.name == "chatgpt" {
|
||||||
|
chatgpt = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cosmo {
|
||||||
|
Command::create_global_application_command(&ctx.http, |command| {
|
||||||
|
command.name("cosmo").description("Start cosmo chat.")
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
if chatgpt {
|
||||||
|
Command::create_global_application_command(&ctx.http, |command| {
|
||||||
|
command.name("chatgpt").description("Start ChatGPT chat.")
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
25
src/main.rs
25
src/main.rs
@ -5,8 +5,8 @@ mod events;
|
|||||||
|
|
||||||
use std::{collections::HashMap, env, sync::Arc};
|
use std::{collections::HashMap, env, sync::Arc};
|
||||||
|
|
||||||
use config::Config;
|
use config::{Config, ConfigData};
|
||||||
use data::PingData;
|
use data::{ChatGPTData, IndividualChatGPTData, LlamaData};
|
||||||
use event_handler::Handler;
|
use event_handler::Handler;
|
||||||
use serenity::{
|
use serenity::{
|
||||||
client::Client, framework::StandardFramework, futures::lock::Mutex, prelude::GatewayIntents,
|
client::Client, framework::StandardFramework, futures::lock::Mutex, prelude::GatewayIntents,
|
||||||
@ -40,10 +40,26 @@ async fn main() {
|
|||||||
} else {
|
} else {
|
||||||
let token = env::var("NCB_TOKEN").unwrap();
|
let token = env::var("NCB_TOKEN").unwrap();
|
||||||
let application_id = env::var("NCB_APP_ID").unwrap();
|
let application_id = env::var("NCB_APP_ID").unwrap();
|
||||||
|
let llama_url = env::var("LLAMA_URL").unwrap();
|
||||||
|
let openai_key = env::var("OPENAI_KEY").unwrap();
|
||||||
|
let chatgpt_allows = env::var("CHATGPT_ALLOWS").unwrap().to_string();
|
||||||
|
let chatgpt_allows: Vec<i64> = chatgpt_allows
|
||||||
|
.split(",")
|
||||||
|
.map(|f| i64::from_str_radix(f, 10).unwrap())
|
||||||
|
.collect();
|
||||||
|
let chatgpt_forums = env::var("CHATGPT_FORUMS").unwrap().to_string();
|
||||||
|
let chatgpt_forums: Vec<i64> = chatgpt_forums
|
||||||
|
.split(",")
|
||||||
|
.map(|f| i64::from_str_radix(f, 10).unwrap())
|
||||||
|
.collect();
|
||||||
|
|
||||||
Config {
|
Config {
|
||||||
token,
|
token,
|
||||||
application_id: u64::from_str_radix(&application_id, 10).unwrap(),
|
application_id: u64::from_str_radix(&application_id, 10).unwrap(),
|
||||||
|
llama_url,
|
||||||
|
openai_key,
|
||||||
|
chatgpt_allows,
|
||||||
|
chatgpt_forums,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -56,7 +72,10 @@ async fn main() {
|
|||||||
// Create TTS storage
|
// Create TTS storage
|
||||||
{
|
{
|
||||||
let mut data = client.data.write().await;
|
let mut data = client.data.write().await;
|
||||||
data.insert::<PingData>(Arc::new(Mutex::new(HashMap::default())));
|
data.insert::<LlamaData>(Arc::new(Mutex::new(HashMap::default())));
|
||||||
|
data.insert::<ChatGPTData>(Arc::new(Mutex::new(HashMap::default())));
|
||||||
|
data.insert::<IndividualChatGPTData>(Arc::new(Mutex::new(HashMap::default())));
|
||||||
|
data.insert::<ConfigData>(Arc::new(Mutex::new(config)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run client
|
// Run client
|
||||||
|
Reference in New Issue
Block a user