multithread, avx

This commit is contained in:
Masato Imai
2024-06-18 10:30:17 +00:00
parent 46fa9f40d4
commit e1096eef1b
3 changed files with 26 additions and 14 deletions

1
Cargo.lock generated
View File

@ -1905,6 +1905,7 @@ name = "tfhe-mutual-friends"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bincode", "bincode",
"rayon",
"rpassword", "rpassword",
"serde", "serde",
"serde_bytes", "serde_bytes",

View File

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix"] } tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix", "nightly-avx512"] }
vrchatapi = { git = "https://github.com/C0D3-M4513R/vrchatapi-rust.git", rev = "41255a7932d5626effec7421bad001703c977a31" } vrchatapi = { git = "https://github.com/C0D3-M4513R/vrchatapi-rust.git", rev = "41255a7932d5626effec7421bad001703c977a31" }
serde_json = "1.0" serde_json = "1.0"
bincode = "1.3.3" bincode = "1.3.3"
@ -12,4 +12,5 @@ serde_bytes = "0.11.14"
serde = "1.0" serde = "1.0"
zstd = "0.13.1" zstd = "0.13.1"
rpassword = "7.3" rpassword = "7.3"
tokio = { version = "1.38.0", features = ["full"] } tokio = { version = "1.38.0", features = ["full"] }
rayon = "1.10.0"

View File

@ -2,7 +2,9 @@ mod vrchat;
use std::fs::File; use std::fs::File;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::sync::Mutex;
use rayon::prelude::*;
use rpassword::read_password; use rpassword::read_password;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tfhe::{prelude::*, ClientKey, CompressedFheUint128, FheBool, FheUint128, ServerKey}; use tfhe::{prelude::*, ClientKey, CompressedFheUint128, FheBool, FheUint128, ServerKey};
@ -37,29 +39,37 @@ async fn calc_phase() -> Result<(), Box<dyn std::error::Error>> {
println!("相手の暗号化済みフレンドリストの読み込み中..."); println!("相手の暗号化済みフレンドリストの読み込み中...");
let (remote_data, server_keys) = load_remote_data().await.unwrap(); let (remote_data, server_keys) = load_remote_data().await.unwrap();
rayon::broadcast(|_| set_server_key(server_keys.clone()));
set_server_key(server_keys); set_server_key(server_keys);
println!("暗号文の展開中(この処理には時間がかかります)..."); println!("暗号文の展開中(この処理には時間がかかります)...");
let remote_friends: Vec<CompressedFheUint128> = remote_data.friends.iter().map(|f| &f.value).map(|v| bincode::deserialize(&v).unwrap()).collect(); let remote_friends: Vec<CompressedFheUint128> = remote_data.friends.par_iter().map(|f| &f.value).map(|v| bincode::deserialize(&v).unwrap()).collect();
let remote_friends: Vec<FheUint128> = remote_friends.iter().map(|f| f.decompress()).collect(); let remote_friends: Vec<FheUint128> = remote_friends.par_iter().map(|f| f.decompress()).collect();
// calc time
println!("共通フレンドの計算中..."); println!("共通フレンドの計算中...");
let mut compared: Vec<Vec<FheBool>> = Vec::new(); let now = std::time::Instant::now();
for i in 0..remote_friends.len() { let all = my_friends.len() * remote_friends.len();
compared.push(Vec::new()); let count = Mutex::new(0);
for j in 0..my_friends.len() { let compared: Vec<Vec<FheBool>> = remote_friends.par_iter().enumerate().map(|(i, remote_friend)| {
compared[i].push(remote_friends[i].eq(my_friends[j])); (0..my_friends.len()).into_par_iter().map(|j| {
println!("{}/{}: {}/{}", i, remote_friends.len(), j, my_friends.len()); let result = remote_friend.eq(my_friends[j]);
} let mut count = count.lock().unwrap();
} *count += 1;
println!("{} / {}", *count, all);
result
}).collect()
}).collect();
println!("圧縮中..."); println!("圧縮中...");
let sample = compared[0][0].clone();
let mut result: Vec<FheBool> = Vec::new(); let mut result: Vec<FheBool> = Vec::new();
for res in compared { for res in compared {
result.push(res.iter().fold(sample.clone(), |acc, x| (acc | x))); result.push(res.par_iter().cloned().reduce_with(|acc, x| acc | x).unwrap());
} }
let elapsed = now.elapsed();
println!("計算時間: {}.{:03}", elapsed.as_secs(), elapsed.subsec_millis());
println!("データの整形中..."); println!("データの整形中...");
let serialized_compared_friends: SerializedComparedFriends = SerializedComparedFriends { friends: result.iter().map(|f| SerializedComparedFriend { value: bincode::serialize(f).unwrap() }).collect() }; let serialized_compared_friends: SerializedComparedFriends = SerializedComparedFriends { friends: result.iter().map(|f| SerializedComparedFriend { value: bincode::serialize(f).unwrap() }).collect() };