multithread, avx

This commit is contained in:
Masato Imai
2024-06-18 10:30:17 +00:00
parent 46fa9f40d4
commit e1096eef1b
3 changed files with 26 additions and 14 deletions

1
Cargo.lock generated
View File

@ -1905,6 +1905,7 @@ name = "tfhe-mutual-friends"
version = "0.1.0"
dependencies = [
"bincode",
"rayon",
"rpassword",
"serde",
"serde_bytes",

View File

@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021"
[dependencies]
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix"] }
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix", "nightly-avx512"] }
vrchatapi = { git = "https://github.com/C0D3-M4513R/vrchatapi-rust.git", rev = "41255a7932d5626effec7421bad001703c977a31" }
serde_json = "1.0"
bincode = "1.3.3"
@ -13,3 +13,4 @@ serde = "1.0"
zstd = "0.13.1"
rpassword = "7.3"
tokio = { version = "1.38.0", features = ["full"] }
rayon = "1.10.0"

View File

@ -2,7 +2,9 @@ mod vrchat;
use std::fs::File;
use std::io::{Read, Write};
use std::sync::Mutex;
use rayon::prelude::*;
use rpassword::read_password;
use serde::{Deserialize, Serialize};
use tfhe::{prelude::*, ClientKey, CompressedFheUint128, FheBool, FheUint128, ServerKey};
@ -37,29 +39,37 @@ async fn calc_phase() -> Result<(), Box<dyn std::error::Error>> {
println!("相手の暗号化済みフレンドリストの読み込み中...");
let (remote_data, server_keys) = load_remote_data().await.unwrap();
rayon::broadcast(|_| set_server_key(server_keys.clone()));
set_server_key(server_keys);
println!("暗号文の展開中(この処理には時間がかかります)...");
let remote_friends: Vec<CompressedFheUint128> = remote_data.friends.iter().map(|f| &f.value).map(|v| bincode::deserialize(&v).unwrap()).collect();
let remote_friends: Vec<FheUint128> = remote_friends.iter().map(|f| f.decompress()).collect();
let remote_friends: Vec<CompressedFheUint128> = remote_data.friends.par_iter().map(|f| &f.value).map(|v| bincode::deserialize(&v).unwrap()).collect();
let remote_friends: Vec<FheUint128> = remote_friends.par_iter().map(|f| f.decompress()).collect();
// calc time
println!("共通フレンドの計算中...");
let mut compared: Vec<Vec<FheBool>> = Vec::new();
for i in 0..remote_friends.len() {
compared.push(Vec::new());
for j in 0..my_friends.len() {
compared[i].push(remote_friends[i].eq(my_friends[j]));
println!("{}/{}: {}/{}", i, remote_friends.len(), j, my_friends.len());
}
}
let now = std::time::Instant::now();
let all = my_friends.len() * remote_friends.len();
let count = Mutex::new(0);
let compared: Vec<Vec<FheBool>> = remote_friends.par_iter().enumerate().map(|(i, remote_friend)| {
(0..my_friends.len()).into_par_iter().map(|j| {
let result = remote_friend.eq(my_friends[j]);
let mut count = count.lock().unwrap();
*count += 1;
println!("{} / {}", *count, all);
result
}).collect()
}).collect();
println!("圧縮中...");
let sample = compared[0][0].clone();
let mut result: Vec<FheBool> = Vec::new();
for res in compared {
result.push(res.iter().fold(sample.clone(), |acc, x| (acc | x)));
result.push(res.par_iter().cloned().reduce_with(|acc, x| acc | x).unwrap());
}
let elapsed = now.elapsed();
println!("計算時間: {}.{:03}", elapsed.as_secs(), elapsed.subsec_millis());
println!("データの整形中...");
let serialized_compared_friends: SerializedComparedFriends = SerializedComparedFriends { friends: result.iter().map(|f| SerializedComparedFriend { value: bincode::serialize(f).unwrap() }).collect() };