This commit is contained in:
mii443
2025-07-08 21:45:58 +09:00
parent 7df121954d
commit 3f51136654
6 changed files with 154 additions and 47 deletions

52
Cargo.lock generated
View File

@ -35,6 +35,37 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268"
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "either"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
[[package]] [[package]]
name = "getrandom" name = "getrandom"
version = "0.3.3" version = "0.3.3"
@ -55,6 +86,7 @@ dependencies = [
"rand", "rand",
"rand_chacha", "rand_chacha",
"rand_distr", "rand_distr",
"rayon",
] ]
[[package]] [[package]]
@ -239,6 +271,26 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]] [[package]]
name = "safe_arch" name = "safe_arch"
version = "0.7.4" version = "0.7.4"

View File

@ -8,3 +8,4 @@ nalgebra = "0.33.2"
rand = "0.9.1" rand = "0.9.1"
rand_chacha = "0.9.0" rand_chacha = "0.9.0"
rand_distr = "0.5.1" rand_distr = "0.5.1"
rayon = "1.10.0"

View File

@ -1,4 +1,4 @@
use nalgebra::SMatrix; use nalgebra::DVector;
use crate::{ use crate::{
common, common,
@ -7,22 +7,60 @@ use crate::{
pub struct Client { pub struct Client {
pub i: usize, pub i: usize,
pub projection_matrix: Vec<SMatrix<f32, GRID_SIZE, 1>>, pub projection_matrix: Vec<DVector<f32>>,
}
pub struct ClientRef<'a> {
pub i: usize,
pub projection_matrix: &'a Vec<DVector<f32>>,
} }
impl Client { impl Client {
pub fn new(i: usize, projection_matrix: Vec<SMatrix<f32, GRID_SIZE, 1>>) -> Self { pub fn new(i: usize, projection_matrix: Vec<DVector<f32>>) -> Self {
Client { Client {
i, i,
projection_matrix, projection_matrix,
} }
} }
pub fn new_with_ref(i: usize, projection_matrix: &Vec<DVector<f32>>) -> ClientRef {
ClientRef {
i,
projection_matrix,
}
}
pub fn observe(&mut self, current_grid: (usize, usize)) -> (usize, f32) { pub fn observe(&mut self, current_grid: (usize, usize)) -> (usize, f32) {
let index = common::grid_to_index(current_grid); let index = common::grid_to_index(current_grid);
let projection = &self.projection_matrix[self.i]; let projection = &self.projection_matrix[self.i];
let observation = projection[(index, 0)]; let observation = projection[index];
let result = (self.i, observation);
if self.i < NUM_OF_PROJECTIONS - 1 {
self.i += 1;
} else {
self.i = 0;
}
result
}
pub fn observe_quantized(&mut self, current_grid: (usize, usize)) -> (usize, i8) {
let normal_observation = self.observe(current_grid);
let quantized_observation = common::quantize(normal_observation.1);
(normal_observation.0, quantized_observation)
}
}
impl<'a> ClientRef<'a> {
pub fn observe(&mut self, current_grid: (usize, usize)) -> (usize, f32) {
let index = common::grid_to_index(current_grid);
let projection = &self.projection_matrix[self.i];
let observation = projection[index];
let result = (self.i, observation); let result = (self.i, observation);

View File

@ -1,5 +1,5 @@
pub const GRID_HEIGHT: usize = 500; pub const GRID_HEIGHT: usize = 1000;
pub const GRID_WIDTH: usize = 500; pub const GRID_WIDTH: usize = 1000;
pub const GRID_SIZE: usize = GRID_HEIGHT * GRID_WIDTH; pub const GRID_SIZE: usize = GRID_HEIGHT * GRID_WIDTH;
pub const NUM_OF_PROJECTIONS: usize = 256; pub const NUM_OF_PROJECTIONS: usize = 256;

View File

@ -1,6 +1,8 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use rand::seq::IndexedRandom; use rand::seq::IndexedRandom;
use rayon::prelude::*;
use crate::constant::NUM_OF_PROJECTIONS; use crate::constant::NUM_OF_PROJECTIONS;
@ -10,32 +12,45 @@ pub mod constant;
pub mod server; pub mod server;
fn main() { fn main() {
// Set larger stack size for threads
rayon::ThreadPoolBuilder::new()
.stack_size(8 * 1024 * 1024) // 8MB stack
.build_global()
.unwrap();
println!("Generating projections..."); println!("Generating projections...");
let projections = server::generate_projection_matrix(); let projections = Arc::new(server::generate_projection_matrix());
println!("Generating fingerprint database..."); println!("Generating fingerprint database...");
let fingerprint_database = server::generate_fingerprint_database(projections.clone()); let fingerprint_database = server::generate_fingerprint_database((*projections).clone());
let mut client = client::Client::new(0, projections);
let grid_candidates: Vec<(usize, usize)> = vec![(5, 5), (120, 50), (200, 200), (10, 5)]; let grid_candidates: Vec<(usize, usize)> = vec![(5, 5), (120, 50), (200, 200), (10, 5)];
let mut rng = rand::rng(); let count = Mutex::new(HashMap::<usize, usize>::new());
let mut count: HashMap<usize, usize> = HashMap::new(); (0..NUM_OF_PROJECTIONS * 3)
for i in 0..NUM_OF_PROJECTIONS * 3 { .into_par_iter()
let current_grid = *grid_candidates.choose(&mut rng).unwrap(); .for_each(|i| {
let mut thread_rng = rand::rng();
let current_grid = *grid_candidates.choose(&mut thread_rng).unwrap();
let (index, observation) = client.observe_quantized(current_grid); let thread_projections = Arc::clone(&projections);
let mut thread_client = client::Client::new_with_ref(i % NUM_OF_PROJECTIONS, &thread_projections);
let (index, observation) = thread_client.observe_quantized(current_grid);
println!("Observation {i}: index = {index}, observation = {observation}"); println!("Observation {i}: index = {index}, observation = {observation}");
let predictions = server::predict_location_from_database_quantized( let predictions = server::predict_location_from_database_quantized(
&fingerprint_database, &fingerprint_database,
(index, observation), (index, observation),
); );
let mut count_guard = count.lock().unwrap();
for location in predictions.iter() { for location in predictions.iter() {
count.entry(*location).and_modify(|e| *e += 1).or_insert(1); count_guard.entry(*location).and_modify(|e| *e += 1).or_insert(1);
}
} }
});
let count = count.into_inner().unwrap();
let sorted = count.iter().collect::<Vec<_>>(); let sorted = count.iter().collect::<Vec<_>>();
let mut sorted = sorted let mut sorted = sorted

View File

@ -1,25 +1,26 @@
use std::collections::HashMap; use std::collections::HashMap;
use nalgebra::SMatrix; use nalgebra::DVector;
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};
use rand_distr::{Distribution, Normal}; use rand_distr::{Distribution, Normal};
use rayon::prelude::*;
use crate::{ use crate::{
common, common,
constant::{GRID_SIZE, NUM_OF_PROJECTIONS}, constant::{GRID_SIZE, NUM_OF_PROJECTIONS},
}; };
pub fn generate_projection_matrix() -> Vec<SMatrix<f32, GRID_SIZE, 1>> { pub fn generate_projection_matrix() -> Vec<DVector<f32>> {
let mut rng = ChaCha20Rng::from_rng(&mut rand::rng());
let std_dev = 1.0 / (GRID_SIZE as f64).sqrt(); let std_dev = 1.0 / (GRID_SIZE as f64).sqrt();
let normal = Normal::new(0.0, std_dev).unwrap(); let normal = Normal::new(0.0, std_dev).unwrap();
(0..NUM_OF_PROJECTIONS) (0..NUM_OF_PROJECTIONS)
.map(|_| { .into_par_iter()
let mut projection = SMatrix::<f32, GRID_SIZE, 1>::zeros(); .map(|i| {
for i in 0..GRID_SIZE { let mut rng = ChaCha20Rng::from_seed([i as u8; 32]);
projection[(i, 0)] = normal.sample(&mut rng) as f32; let mut projection = DVector::<f32>::zeros(GRID_SIZE);
for j in 0..GRID_SIZE {
projection[j] = normal.sample(&mut rng) as f32;
} }
projection projection
}) })
@ -27,20 +28,20 @@ pub fn generate_projection_matrix() -> Vec<SMatrix<f32, GRID_SIZE, 1>> {
} }
pub fn generate_fingerprint_database( pub fn generate_fingerprint_database(
projections: Vec<SMatrix<f32, GRID_SIZE, 1>>, projections: Vec<DVector<f32>>,
) -> HashMap<usize, (usize, Vec<f32>)> { ) -> HashMap<usize, (usize, Vec<f32>)> {
let mut fingerprint_database = HashMap::new(); projections
.into_par_iter()
for (i, projection) in projections.iter().enumerate() { .enumerate()
.map(|(i, projection)| {
let mut fingerprint = Vec::with_capacity(GRID_SIZE); let mut fingerprint = Vec::with_capacity(GRID_SIZE);
for j in 0..GRID_SIZE { for j in 0..GRID_SIZE {
let observation = projection[(j, 0)]; let observation = projection[j];
fingerprint.push(observation); fingerprint.push(observation);
} }
fingerprint_database.insert(i, (i, fingerprint)); (i, (i, fingerprint))
} })
.collect()
fingerprint_database
} }
pub fn predict_location_from_database( pub fn predict_location_from_database(