diff --git a/src/main.rs b/src/main.rs index 155efc0..6d0f7e5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -58,7 +58,9 @@ fn lagrange_interpolate(src: &[(u8, u8)], raw_x: u8) -> u8 { for (j, &(raw_xj, _)) in src.iter().enumerate() { if i != j { let xj = Gf256::from_byte(raw_xj); - lix = lix * (x - xj) / (xi - xj); + let delta = xi - xj; + assert!(delta.poly !=0, "Duplicate shares"); + lix = lix * (x - xj) / delta; } } sum = sum + lix * yi; @@ -135,7 +137,7 @@ fn read_shares() -> IoResult<(u8, Vec<(u8,Vec)>)> { let mut stdin = BufferedReader::new(stdio::stdin()); let mut opt_k_l: Option<(u8, usize)> = None; let mut counter = 0u8; - let mut shares = Vec::new(); + let mut shares: Vec<(u8,Vec)> = Vec::new(); for line in stdin.lines() { let line = try!(line); let parts: Vec<_> = line.split('-').collect(); @@ -167,13 +169,15 @@ fn read_shares() -> IoResult<(u8, Vec<(u8,Vec)>)> { } else { opt_k_l = Some((k,raw.len())); } - shares.push((n, raw)); - counter += 1; - if counter == k { - return Ok((k, shares)); + if shares.iter().all(|s| s.0 != n) { + shares.push((n, raw)); + counter += 1; + if counter == k { + return Ok((k, shares)); + } } } - Err(other_io_err("No shares")) + Err(other_io_err("Not enough shares provided!")) } fn perform_decode() -> IoResult<()> { @@ -213,6 +217,7 @@ fn main() { "The program secretshare is an implementation of Shamir's secret sharing scheme.\n\ It is applied byte-wise within a finite field for arbitraty long secrets.\n"); println!("{}", opts.usage("Usage: secretshare [options]")); + println!("Input is read from STDIN and output is written to STDOUT."); return; }