commit 3f51136654 (parent 7df121954d)
Author: mii443
Date:   2025-07-08 21:45:58 +09:00

6 changed files with 154 additions and 47 deletions

Cargo.lock (generated, 52 additions)

@@ -35,6 +35,37 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268"
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]]
name = "getrandom"
version = "0.3.3"
@@ -55,6 +86,7 @@ dependencies = [
"rand",
"rand_chacha",
"rand_distr",
"rayon",
]
[[package]]
@@ -239,6 +271,26 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "safe_arch"
version = "0.7.4"

Cargo.toml

@@ -8,3 +8,4 @@ nalgebra = "0.33.2"
rand = "0.9.1"
rand_chacha = "0.9.0"
rand_distr = "0.5.1"
rayon = "1.10.0"
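rayon is the only new direct dependency; the crossbeam-deque, crossbeam-epoch, crossbeam-utils and either entries in Cargo.lock above arrive transitively through it. A minimal sketch of the data-parallel iterator style it enables (illustrative only, not part of the commit):

use rayon::prelude::*;

fn main() {
    // Square 0..8 on rayon's global thread pool; collect preserves the input order.
    let squares: Vec<usize> = (0..8).into_par_iter().map(|x| x * x).collect();
    println!("{squares:?}");
}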

src/client.rs

@@ -1,4 +1,4 @@
use nalgebra::SMatrix;
use nalgebra::DVector;
use crate::{
common,
@@ -7,22 +7,60 @@ use crate::{
pub struct Client {
pub i: usize,
pub projection_matrix: Vec<SMatrix<f32, GRID_SIZE, 1>>,
pub projection_matrix: Vec<DVector<f32>>,
}
pub struct ClientRef<'a> {
pub i: usize,
pub projection_matrix: &'a Vec<DVector<f32>>,
}
impl Client {
pub fn new(i: usize, projection_matrix: Vec<SMatrix<f32, GRID_SIZE, 1>>) -> Self {
pub fn new(i: usize, projection_matrix: Vec<DVector<f32>>) -> Self {
Client {
i,
projection_matrix,
}
}
pub fn new_with_ref(i: usize, projection_matrix: &Vec<DVector<f32>>) -> ClientRef {
ClientRef {
i,
projection_matrix,
}
}
pub fn observe(&mut self, current_grid: (usize, usize)) -> (usize, f32) {
let index = common::grid_to_index(current_grid);
let projection = &self.projection_matrix[self.i];
let observation = projection[(index, 0)];
let observation = projection[index];
let result = (self.i, observation);
if self.i < NUM_OF_PROJECTIONS - 1 {
self.i += 1;
} else {
self.i = 0;
}
result
}
pub fn observe_quantized(&mut self, current_grid: (usize, usize)) -> (usize, i8) {
let normal_observation = self.observe(current_grid);
let quantized_observation = common::quantize(normal_observation.1);
(normal_observation.0, quantized_observation)
}
}
impl<'a> ClientRef<'a> {
pub fn observe(&mut self, current_grid: (usize, usize)) -> (usize, f32) {
let index = common::grid_to_index(current_grid);
let projection = &self.projection_matrix[self.i];
let observation = projection[index];
let result = (self.i, observation);
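A rough usage sketch of the new borrowed variant, as it might appear inside main (hypothetical grid value; the rest of the ClientRef impl is cut off by this hunk, but main.rs below calls observe_quantized on it, so it is assumed to exist):

// Generate the projections once, then hand out borrows instead of cloning them.
let projections = server::generate_projection_matrix();
let mut client = client::Client::new_with_ref(0, &projections);
let (index, observation) = client.observe_quantized((5, 5));
println!("projection {index} observed {observation}");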

src/constant.rs

@@ -1,5 +1,5 @@
pub const GRID_HEIGHT: usize = 500;
pub const GRID_WIDTH: usize = 500;
pub const GRID_HEIGHT: usize = 1000;
pub const GRID_WIDTH: usize = 1000;
pub const GRID_SIZE: usize = GRID_HEIGHT * GRID_WIDTH;
pub const NUM_OF_PROJECTIONS: usize = 256;
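For scale (my arithmetic, not stated in the commit): the grid now has 1,000,000 cells, so each projection holds one million f32 values.

// 1_000_000 cells * 4 bytes per f32 = 4 MB per projection; a stack-allocated
// SMatrix of that size would not fit on a typical 2 MB worker stack, which is
// presumably why the code switches to heap-backed DVector and main.rs asks
// rayon for 8 MB stacks. All 256 projections together come to roughly 1 GB.
const PROJECTION_BYTES: usize = GRID_SIZE * std::mem::size_of::<f32>();     // 4_000_000
const ALL_PROJECTIONS_BYTES: usize = NUM_OF_PROJECTIONS * PROJECTION_BYTES; // 1_024_000_000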

src/main.rs

@@ -1,6 +1,8 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use rand::seq::IndexedRandom;
use rayon::prelude::*;
use crate::constant::NUM_OF_PROJECTIONS;
@@ -10,32 +12,45 @@ pub mod constant;
pub mod server;
fn main() {
// Set larger stack size for threads
rayon::ThreadPoolBuilder::new()
.stack_size(8 * 1024 * 1024) // 8MB stack
.build_global()
.unwrap();
println!("Generating projections...");
let projections = server::generate_projection_matrix();
let projections = Arc::new(server::generate_projection_matrix());
println!("Generating fingerprint database...");
let fingerprint_database = server::generate_fingerprint_database(projections.clone());
let mut client = client::Client::new(0, projections);
let fingerprint_database = server::generate_fingerprint_database((*projections).clone());
let grid_candidates: Vec<(usize, usize)> = vec![(5, 5), (120, 50), (200, 200), (10, 5)];
let mut rng = rand::rng();
let mut count: HashMap<usize, usize> = HashMap::new();
for i in 0..NUM_OF_PROJECTIONS * 3 {
let current_grid = *grid_candidates.choose(&mut rng).unwrap();
let (index, observation) = client.observe_quantized(current_grid);
println!("Observation {i}: index = {index}, observation = {observation}");
let predictions = server::predict_location_from_database_quantized(
&fingerprint_database,
(index, observation),
);
for location in predictions.iter() {
count.entry(*location).and_modify(|e| *e += 1).or_insert(1);
}
}
let count = Mutex::new(HashMap::<usize, usize>::new());
(0..NUM_OF_PROJECTIONS * 3)
.into_par_iter()
.for_each(|i| {
let mut thread_rng = rand::rng();
let current_grid = *grid_candidates.choose(&mut thread_rng).unwrap();
let thread_projections = Arc::clone(&projections);
let mut thread_client = client::Client::new_with_ref(i % NUM_OF_PROJECTIONS, &thread_projections);
let (index, observation) = thread_client.observe_quantized(current_grid);
println!("Observation {i}: index = {index}, observation = {observation}");
let predictions = server::predict_location_from_database_quantized(
&fingerprint_database,
(index, observation),
);
let mut count_guard = count.lock().unwrap();
for location in predictions.iter() {
count_guard.entry(*location).and_modify(|e| *e += 1).or_insert(1);
}
});
let count = count.into_inner().unwrap();
let sorted = count.iter().collect::<Vec<_>>();
let mut sorted = sorted
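The Mutex-guarded HashMap serializes the counting step across workers. For comparison, a lock-free aggregation with rayon's fold/reduce would look roughly like this, reusing the same bindings as main above (projections, grid_candidates, fingerprint_database); this is a sketch of an alternative, not what the commit does:

let count: HashMap<usize, usize> = (0..NUM_OF_PROJECTIONS * 3)
    .into_par_iter()
    .fold(HashMap::<usize, usize>::new, |mut local, i| {
        let mut thread_rng = rand::rng();
        let current_grid = *grid_candidates.choose(&mut thread_rng).unwrap();
        // &projections derefs from the Arc to &Vec<DVector<f32>>.
        let mut client = client::Client::new_with_ref(i % NUM_OF_PROJECTIONS, &projections);
        let (index, observation) = client.observe_quantized(current_grid);
        let predictions = server::predict_location_from_database_quantized(
            &fingerprint_database,
            (index, observation),
        );
        // Count matches in a per-thread map first...
        for location in predictions.iter() {
            *local.entry(*location).or_insert(0) += 1;
        }
        local
    })
    // ...then merge the per-thread maps once at the end.
    .reduce(HashMap::<usize, usize>::new, |mut a, b| {
        for (k, v) in b {
            *a.entry(k).or_insert(0) += v;
        }
        a
    });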

src/server.rs

@@ -1,25 +1,26 @@
use std::collections::HashMap;
use nalgebra::SMatrix;
use nalgebra::DVector;
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};
use rand_distr::{Distribution, Normal};
use rayon::prelude::*;
use crate::{
common,
constant::{GRID_SIZE, NUM_OF_PROJECTIONS},
};
pub fn generate_projection_matrix() -> Vec<SMatrix<f32, GRID_SIZE, 1>> {
let mut rng = ChaCha20Rng::from_rng(&mut rand::rng());
pub fn generate_projection_matrix() -> Vec<DVector<f32>> {
let std_dev = 1.0 / (GRID_SIZE as f64).sqrt();
let normal = Normal::new(0.0, std_dev).unwrap();
(0..NUM_OF_PROJECTIONS)
.map(|_| {
let mut projection = SMatrix::<f32, GRID_SIZE, 1>::zeros();
for i in 0..GRID_SIZE {
projection[(i, 0)] = normal.sample(&mut rng) as f32;
.into_par_iter()
.map(|i| {
let mut rng = ChaCha20Rng::from_seed([i as u8; 32]);
let mut projection = DVector::<f32>::zeros(GRID_SIZE);
for j in 0..GRID_SIZE {
projection[j] = normal.sample(&mut rng) as f32;
}
projection
})
@@ -27,20 +28,20 @@ pub fn generate_projection_matrix() -> Vec<SMatrix<f32, GRID_SIZE, 1>>
}
pub fn generate_fingerprint_database(
projections: Vec<SMatrix<f32, GRID_SIZE, 1>>,
projections: Vec<DVector<f32>>,
) -> HashMap<usize, (usize, Vec<f32>)> {
let mut fingerprint_database = HashMap::new();
for (i, projection) in projections.iter().enumerate() {
let mut fingerprint = Vec::with_capacity(GRID_SIZE);
for j in 0..GRID_SIZE {
let observation = projection[(j, 0)];
fingerprint.push(observation);
}
fingerprint_database.insert(i, (i, fingerprint));
}
fingerprint_database
projections
.into_par_iter()
.enumerate()
.map(|(i, projection)| {
let mut fingerprint = Vec::with_capacity(GRID_SIZE);
for j in 0..GRID_SIZE {
let observation = projection[j];
fingerprint.push(observation);
}
(i, (i, fingerprint))
})
.collect()
}
pub fn predict_location_from_database(
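Two side effects of the per-projection seeding above are worth noting: the projections are now deterministic across runs (the old code drew a fresh seed from rand::rng() each time), and [i as u8; 32] only yields distinct seeds while i < 256, which the current NUM_OF_PROJECTIONS = 256 just satisfies. A hedged sketch of a seed derivation that stays unique for any index (illustrative, not part of the commit):

use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};

// Spread the full 64-bit index into the seed so no two projections share a stream,
// even if NUM_OF_PROJECTIONS grows past 256.
fn rng_for_projection(i: usize) -> ChaCha20Rng {
    let mut seed = [0u8; 32];
    seed[..8].copy_from_slice(&(i as u64).to_le_bytes());
    ChaCha20Rng::from_seed(seed)
}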