mirror of
https://github.com/mii443/tfhe-mutual-friends.git
synced 2025-08-22 16:15:39 +00:00
multithread, avx
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -1905,6 +1905,7 @@ name = "tfhe-mutual-friends"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"rayon",
|
||||
"rpassword",
|
||||
"serde",
|
||||
"serde_bytes",
|
||||
|
@ -4,7 +4,7 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix"] }
|
||||
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix", "nightly-avx512"] }
|
||||
vrchatapi = { git = "https://github.com/C0D3-M4513R/vrchatapi-rust.git", rev = "41255a7932d5626effec7421bad001703c977a31" }
|
||||
serde_json = "1.0"
|
||||
bincode = "1.3.3"
|
||||
@ -13,3 +13,4 @@ serde = "1.0"
|
||||
zstd = "0.13.1"
|
||||
rpassword = "7.3"
|
||||
tokio = { version = "1.38.0", features = ["full"] }
|
||||
rayon = "1.10.0"
|
||||
|
34
src/main.rs
34
src/main.rs
@ -2,7 +2,9 @@ mod vrchat;
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::{Read, Write};
|
||||
use std::sync::Mutex;
|
||||
|
||||
use rayon::prelude::*;
|
||||
use rpassword::read_password;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tfhe::{prelude::*, ClientKey, CompressedFheUint128, FheBool, FheUint128, ServerKey};
|
||||
@ -37,29 +39,37 @@ async fn calc_phase() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("相手の暗号化済みフレンドリストの読み込み中...");
|
||||
let (remote_data, server_keys) = load_remote_data().await.unwrap();
|
||||
|
||||
rayon::broadcast(|_| set_server_key(server_keys.clone()));
|
||||
set_server_key(server_keys);
|
||||
|
||||
println!("暗号文の展開中(この処理には時間がかかります)...");
|
||||
let remote_friends: Vec<CompressedFheUint128> = remote_data.friends.iter().map(|f| &f.value).map(|v| bincode::deserialize(&v).unwrap()).collect();
|
||||
let remote_friends: Vec<FheUint128> = remote_friends.iter().map(|f| f.decompress()).collect();
|
||||
let remote_friends: Vec<CompressedFheUint128> = remote_data.friends.par_iter().map(|f| &f.value).map(|v| bincode::deserialize(&v).unwrap()).collect();
|
||||
let remote_friends: Vec<FheUint128> = remote_friends.par_iter().map(|f| f.decompress()).collect();
|
||||
|
||||
// calc time
|
||||
println!("共通フレンドの計算中...");
|
||||
let mut compared: Vec<Vec<FheBool>> = Vec::new();
|
||||
for i in 0..remote_friends.len() {
|
||||
compared.push(Vec::new());
|
||||
for j in 0..my_friends.len() {
|
||||
compared[i].push(remote_friends[i].eq(my_friends[j]));
|
||||
println!("{}/{}: {}/{}", i, remote_friends.len(), j, my_friends.len());
|
||||
}
|
||||
}
|
||||
let now = std::time::Instant::now();
|
||||
let all = my_friends.len() * remote_friends.len();
|
||||
let count = Mutex::new(0);
|
||||
let compared: Vec<Vec<FheBool>> = remote_friends.par_iter().enumerate().map(|(i, remote_friend)| {
|
||||
(0..my_friends.len()).into_par_iter().map(|j| {
|
||||
let result = remote_friend.eq(my_friends[j]);
|
||||
let mut count = count.lock().unwrap();
|
||||
*count += 1;
|
||||
println!("{} / {}", *count, all);
|
||||
result
|
||||
}).collect()
|
||||
}).collect();
|
||||
|
||||
println!("圧縮中...");
|
||||
let sample = compared[0][0].clone();
|
||||
let mut result: Vec<FheBool> = Vec::new();
|
||||
for res in compared {
|
||||
result.push(res.iter().fold(sample.clone(), |acc, x| (acc | x)));
|
||||
result.push(res.par_iter().cloned().reduce_with(|acc, x| acc | x).unwrap());
|
||||
}
|
||||
|
||||
let elapsed = now.elapsed();
|
||||
println!("計算時間: {}.{:03}秒", elapsed.as_secs(), elapsed.subsec_millis());
|
||||
|
||||
println!("データの整形中...");
|
||||
let serialized_compared_friends: SerializedComparedFriends = SerializedComparedFriends { friends: result.iter().map(|f| SerializedComparedFriend { value: bincode::serialize(f).unwrap() }).collect() };
|
||||
|
||||
|
Reference in New Issue
Block a user