mirror of
https://github.com/mii443/tfhe-mutual-friends.git
synced 2025-08-22 16:15:39 +00:00
multithread, avx
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -1905,6 +1905,7 @@ name = "tfhe-mutual-friends"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bincode",
|
"bincode",
|
||||||
|
"rayon",
|
||||||
"rpassword",
|
"rpassword",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_bytes",
|
"serde_bytes",
|
||||||
|
@ -4,7 +4,7 @@ version = "0.1.0"
|
|||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix"] }
|
tfhe = { version = "*", features = ["boolean", "shortint", "integer", "x86_64-unix", "nightly-avx512"] }
|
||||||
vrchatapi = { git = "https://github.com/C0D3-M4513R/vrchatapi-rust.git", rev = "41255a7932d5626effec7421bad001703c977a31" }
|
vrchatapi = { git = "https://github.com/C0D3-M4513R/vrchatapi-rust.git", rev = "41255a7932d5626effec7421bad001703c977a31" }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
bincode = "1.3.3"
|
bincode = "1.3.3"
|
||||||
@ -13,3 +13,4 @@ serde = "1.0"
|
|||||||
zstd = "0.13.1"
|
zstd = "0.13.1"
|
||||||
rpassword = "7.3"
|
rpassword = "7.3"
|
||||||
tokio = { version = "1.38.0", features = ["full"] }
|
tokio = { version = "1.38.0", features = ["full"] }
|
||||||
|
rayon = "1.10.0"
|
||||||
|
34
src/main.rs
34
src/main.rs
@ -2,7 +2,9 @@ mod vrchat;
|
|||||||
|
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{Read, Write};
|
use std::io::{Read, Write};
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
use rayon::prelude::*;
|
||||||
use rpassword::read_password;
|
use rpassword::read_password;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tfhe::{prelude::*, ClientKey, CompressedFheUint128, FheBool, FheUint128, ServerKey};
|
use tfhe::{prelude::*, ClientKey, CompressedFheUint128, FheBool, FheUint128, ServerKey};
|
||||||
@ -37,29 +39,37 @@ async fn calc_phase() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
println!("相手の暗号化済みフレンドリストの読み込み中...");
|
println!("相手の暗号化済みフレンドリストの読み込み中...");
|
||||||
let (remote_data, server_keys) = load_remote_data().await.unwrap();
|
let (remote_data, server_keys) = load_remote_data().await.unwrap();
|
||||||
|
|
||||||
|
rayon::broadcast(|_| set_server_key(server_keys.clone()));
|
||||||
set_server_key(server_keys);
|
set_server_key(server_keys);
|
||||||
|
|
||||||
println!("暗号文の展開中(この処理には時間がかかります)...");
|
println!("暗号文の展開中(この処理には時間がかかります)...");
|
||||||
let remote_friends: Vec<CompressedFheUint128> = remote_data.friends.iter().map(|f| &f.value).map(|v| bincode::deserialize(&v).unwrap()).collect();
|
let remote_friends: Vec<CompressedFheUint128> = remote_data.friends.par_iter().map(|f| &f.value).map(|v| bincode::deserialize(&v).unwrap()).collect();
|
||||||
let remote_friends: Vec<FheUint128> = remote_friends.iter().map(|f| f.decompress()).collect();
|
let remote_friends: Vec<FheUint128> = remote_friends.par_iter().map(|f| f.decompress()).collect();
|
||||||
|
|
||||||
|
// calc time
|
||||||
println!("共通フレンドの計算中...");
|
println!("共通フレンドの計算中...");
|
||||||
let mut compared: Vec<Vec<FheBool>> = Vec::new();
|
let now = std::time::Instant::now();
|
||||||
for i in 0..remote_friends.len() {
|
let all = my_friends.len() * remote_friends.len();
|
||||||
compared.push(Vec::new());
|
let count = Mutex::new(0);
|
||||||
for j in 0..my_friends.len() {
|
let compared: Vec<Vec<FheBool>> = remote_friends.par_iter().enumerate().map(|(i, remote_friend)| {
|
||||||
compared[i].push(remote_friends[i].eq(my_friends[j]));
|
(0..my_friends.len()).into_par_iter().map(|j| {
|
||||||
println!("{}/{}: {}/{}", i, remote_friends.len(), j, my_friends.len());
|
let result = remote_friend.eq(my_friends[j]);
|
||||||
}
|
let mut count = count.lock().unwrap();
|
||||||
}
|
*count += 1;
|
||||||
|
println!("{} / {}", *count, all);
|
||||||
|
result
|
||||||
|
}).collect()
|
||||||
|
}).collect();
|
||||||
|
|
||||||
println!("圧縮中...");
|
println!("圧縮中...");
|
||||||
let sample = compared[0][0].clone();
|
|
||||||
let mut result: Vec<FheBool> = Vec::new();
|
let mut result: Vec<FheBool> = Vec::new();
|
||||||
for res in compared {
|
for res in compared {
|
||||||
result.push(res.iter().fold(sample.clone(), |acc, x| (acc | x)));
|
result.push(res.par_iter().cloned().reduce_with(|acc, x| acc | x).unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let elapsed = now.elapsed();
|
||||||
|
println!("計算時間: {}.{:03}秒", elapsed.as_secs(), elapsed.subsec_millis());
|
||||||
|
|
||||||
println!("データの整形中...");
|
println!("データの整形中...");
|
||||||
let serialized_compared_friends: SerializedComparedFriends = SerializedComparedFriends { friends: result.iter().map(|f| SerializedComparedFriend { value: bincode::serialize(f).unwrap() }).collect() };
|
let serialized_compared_friends: SerializedComparedFriends = SerializedComparedFriends { friends: result.iter().map(|f| SerializedComparedFriend { value: bincode::serialize(f).unwrap() }).collect() };
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user