From e1096eef1b5977ca44208ee2422c9ab071b26c76 Mon Sep 17 00:00:00 2001 From: Masato Imai Date: Tue, 18 Jun 2024 10:30:17 +0000 Subject: [PATCH] multithread, avx --- Cargo.lock | 1 + Cargo.toml | 5 +++-- src/main.rs | 34 ++++++++++++++++++++++------------ 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2cecd4b..1230649 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1905,6 +1905,7 @@ name = "tfhe-mutual-friends" version = "0.1.0" dependencies = [ "bincode", + "rayon", "rpassword", "serde", "serde_bytes", diff --git a/Cargo.toml b/Cargo.toml index b5bd1cf..f9c010d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix"] } +tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix", "nightly-avx512"] } vrchatapi = { git = "https://github.com/C0D3-M4513R/vrchatapi-rust.git", rev = "41255a7932d5626effec7421bad001703c977a31" } serde_json = "1.0" bincode = "1.3.3" @@ -12,4 +12,5 @@ serde_bytes = "0.11.14" serde = "1.0" zstd = "0.13.1" rpassword = "7.3" -tokio = { version = "1.38.0", features = ["full"] } \ No newline at end of file +tokio = { version = "1.38.0", features = ["full"] } +rayon = "1.10.0" diff --git a/src/main.rs b/src/main.rs index 4057966..bfb5d30 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,9 @@ mod vrchat; use std::fs::File; use std::io::{Read, Write}; +use std::sync::Mutex; +use rayon::prelude::*; use rpassword::read_password; use serde::{Deserialize, Serialize}; use tfhe::{prelude::*, ClientKey, CompressedFheUint128, FheBool, FheUint128, ServerKey}; @@ -37,29 +39,37 @@ async fn calc_phase() -> Result<(), Box> { println!("相手の暗号化済みフレンドリストの読み込み中..."); let (remote_data, server_keys) = load_remote_data().await.unwrap(); + rayon::broadcast(|_| set_server_key(server_keys.clone())); set_server_key(server_keys); println!("暗号文の展開中(この処理には時間がかかります)..."); - let remote_friends: Vec = remote_data.friends.iter().map(|f| &f.value).map(|v| bincode::deserialize(&v).unwrap()).collect(); - let remote_friends: Vec = remote_friends.iter().map(|f| f.decompress()).collect(); + let remote_friends: Vec = remote_data.friends.par_iter().map(|f| &f.value).map(|v| bincode::deserialize(&v).unwrap()).collect(); + let remote_friends: Vec = remote_friends.par_iter().map(|f| f.decompress()).collect(); + // calc time println!("共通フレンドの計算中..."); - let mut compared: Vec> = Vec::new(); - for i in 0..remote_friends.len() { - compared.push(Vec::new()); - for j in 0..my_friends.len() { - compared[i].push(remote_friends[i].eq(my_friends[j])); - println!("{}/{}: {}/{}", i, remote_friends.len(), j, my_friends.len()); - } - } + let now = std::time::Instant::now(); + let all = my_friends.len() * remote_friends.len(); + let count = Mutex::new(0); + let compared: Vec> = remote_friends.par_iter().enumerate().map(|(i, remote_friend)| { + (0..my_friends.len()).into_par_iter().map(|j| { + let result = remote_friend.eq(my_friends[j]); + let mut count = count.lock().unwrap(); + *count += 1; + println!("{} / {}", *count, all); + result + }).collect() + }).collect(); println!("圧縮中..."); - let sample = compared[0][0].clone(); let mut result: Vec = Vec::new(); for res in compared { - result.push(res.iter().fold(sample.clone(), |acc, x| (acc | x))); + result.push(res.par_iter().cloned().reduce_with(|acc, x| acc | x).unwrap()); } + let elapsed = now.elapsed(); + println!("計算時間: {}.{:03}秒", elapsed.as_secs(), elapsed.subsec_millis()); + println!("データの整形中..."); let serialized_compared_friends: SerializedComparedFriends = SerializedComparedFriends { friends: result.iter().map(|f| SerializedComparedFriend { value: bincode::serialize(f).unwrap() }).collect() };