diff --git a/Cargo.toml b/Cargo.toml index 01516b2..bf905d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,8 @@ chrono = "0.4.23" reqwest = { version = "0.11", features = ["json"] } google-translate3 = "4.0.1+20220121" gcp_auth = "0.5.0" - +tungstenite = "0.18.0" +url = "*" [dependencies.uuid] version = "0.8" @@ -26,5 +27,5 @@ features = ["builder", "cache", "client", "gateway", "model", "utils", "unstable [dependencies.tokio] -version = "1.0" +version = "1.26.0" features = ["macros", "rt-multi-thread", "sync"] diff --git a/src/events/message_receive.rs b/src/events/message_receive.rs index 107c998..a9a651f 100644 --- a/src/events/message_receive.rs +++ b/src/events/message_receive.rs @@ -1,9 +1,10 @@ use chrono::Utc; use serenity::{ http::CacheHttp, - model::prelude::{GuildChannel, Message}, + model::prelude::{GuildChannel, Message, MessageId, ReactionType}, prelude::Context, }; +use url::Url; use crate::{ config::*, @@ -130,9 +131,21 @@ pub async fn message(ctx: Context, message: Message) { let mut llama_storage = llama_storage_lock.lock().await; if let Some(mut llama) = llama_storage.clone().get_mut(&message.channel_id) { - let typing = message.channel_id.start_typing(&ctx.http).unwrap(); let mut history = llama.history.clone(); - let text = translate_ja_en(message.content.clone()).await; + + if message.content.trim() == "reset".to_string() { + llama.history = vec![]; + llama_storage.insert(llama.channel, llama.clone()); + message + .reply(&ctx.http, "会話履歴をリセットしました。") + .await + .unwrap(); + return; + } + + let typing = message.channel_id.start_typing(&ctx.http).unwrap(); + //let text = translate_ja_en(message.content.clone()).await; + let text = message.content.clone(); println!("{}", text); history.push(LlamaMessage { role: "user".to_string(), @@ -141,32 +154,87 @@ pub async fn message(ctx: Context, message: Message) { llama.history = history.clone(); llama_storage.insert(llama.channel, llama.clone()); - let request = LlamaRequest { messages: history }; + let request = LlamaRequest { + messages: history.clone(), + }; + let (mut socket, response) = tungstenite::connect(Url::parse("ws://192.168.0.19:18080/").unwrap()).expect("Can't connect to websocket server"); + + socket.write_message(tungstenite::Message::Text(serde_json::to_string(&request).unwrap().into())).unwrap(); + + let mut buffer = String::default(); + let rate = 3; + let mut count = 0; + + let mut response_message: Message = if let Ok(s) = socket.read_message() { + if let tungstenite::Message::Text(msg) = s { + buffer = buffer + &msg.to_string(); + message.channel_id.send_message(&ctx.http, |f| f.content(msg)).await.unwrap() + } else { + panic!("cannot read message"); + } + } else { + panic!("cannot read message"); + }; + + loop { + if let Ok(s) = socket.read_message() { + match s { + tungstenite::Message::Text(msg) => { + buffer = buffer + &msg.to_string(); + println!("{}", msg.to_string()); + + if count == rate { + response_message.edit(&ctx.http, |f| f.content(buffer.clone())).await.unwrap(); + count = 0; + } + count += 1; + } + _ => { + break; + } + } + } + } + + response_message.edit(&ctx.http, |f| f.content(buffer.clone())).await.unwrap(); + response_message.react(&ctx.http, ReactionType::Unicode("✅".to_string())).await.unwrap(); + + typing.stop().unwrap(); +/* let client = reqwest::Client::new(); match client - .post("http://localhost:18080/") + .get("http://localhost:18080/") .header(reqwest::header::CONTENT_TYPE, "application/json") .body(serde_json::to_string(&request).unwrap()) .send() .await { Ok(ok) => { - let response = ok.text().await.expect("ERROR"); - let response = translate_en_ja(response).await; - println!("{}", response); + let response_en = ok.text().await.expect("ERROR").trim().to_string(); + //let response = translate_en_ja(response_en.clone()).await; + //println!("JA: {}", response); + println!("EN: {}", response_en); + + history.push(LlamaMessage { + role: "ai".to_string(), + content: response_en.clone(), + }); + + llama.history = history; + + llama_storage.insert(llama.channel, llama.clone()); + message .channel_id - .send_message(&ctx.http, |f| f.content(response)) + .send_message(&ctx.http, |f| f.content(response_en)) .await .unwrap(); } Err(err) => { panic!("Error") } - } - - typing.stop().unwrap(); + } */ } let chatgpt_storage_lock = { diff --git a/src/events/ready.rs b/src/events/ready.rs index 0957cce..f907b29 100644 --- a/src/events/ready.rs +++ b/src/events/ready.rs @@ -3,24 +3,24 @@ use serenity::{model::prelude::command::Command, model::prelude::Ready, prelude: pub async fn ready(ctx: Context, ready: Ready) { println!("{} is connected!", ready.user.name); - let mut cosmo = true; + let mut llama = true; let mut chatgpt = true; for command in Command::get_global_application_commands(&ctx.http) .await .unwrap() { - if command.name == "cosmo" { - cosmo = false; + if command.name == "llama" { + llama = false; } if command.name == "chatgpt" { chatgpt = false; } } - if cosmo { + if llama { Command::create_global_application_command(&ctx.http, |command| { - command.name("cosmo").description("Start cosmo chat.") + command.name("llama").description("Start llama chat.") }) .await .unwrap();