diff --git a/src/errors.rs b/src/errors.rs index 9c300ff..40fe950 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -64,6 +64,11 @@ error_chain! { display("The shares are incompatible with each other because they do not all have the same threshold.") } + IncompatibleDataLengths(sets: Vec>) { + description("The shares are incompatible with each other because they do not all have the same share data length.") + display("The shares are incompatible with each other because they do not all have the same share data length.") + } + MissingShares(provided: usize, required: u8) { description("The number of shares provided is insufficient to recover the secret.") display("{} shares are required to recover the secret, found only {}.", required, provided) diff --git a/src/share/validation.rs b/src/share/validation.rs index a94b4dd..93d0078 100644 --- a/src/share/validation.rs +++ b/src/share/validation.rs @@ -33,14 +33,17 @@ pub(crate) fn validate_shares(shares: Vec) -> Result<(u8, Vec) let mut result: Vec = Vec::with_capacity(shares_count); let mut k_compatibility_sets = HashMap::new(); + let mut data_len_compatibility_sets = HashMap::new(); for share in shares { - let (id, threshold) = (share.get_id(), share.get_threshold()); + let (id, threshold, data_len) = (share.get_id(), share.get_threshold(), share.get_data().len()); if id < 1 { bail!(ErrorKind::ShareParsingInvalidShareId(id)) } else if threshold < 2 { bail!(ErrorKind::ShareParsingInvalidShareThreshold(threshold, id)) + } else if data_len < 1 { + bail!(ErrorKind::ShareParsingErrorEmptyShare(id)) } k_compatibility_sets @@ -53,9 +56,12 @@ pub(crate) fn validate_shares(shares: Vec) -> Result<(u8, Vec) bail!(ErrorKind::DuplicateShareId(id)); } - if share.get_data().is_empty() { - bail!(ErrorKind::ShareParsingErrorEmptyShare(id)) - } + data_len_compatibility_sets + .entry(data_len) + .or_insert_with(HashSet::new); + let data_len_set = data_len_compatibility_sets.get_mut(&data_len).unwrap(); + data_len_set.insert(id); + result.push(share); } @@ -84,6 +90,23 @@ pub(crate) fn validate_shares(shares: Vec) -> Result<(u8, Vec) bail!(ErrorKind::MissingShares(shares_count, threshold)); } + // Validate share length consistency + let data_len_sets = data_len_compatibility_sets.keys().count(); + + match data_len_sets { + 1 => {} // All shares have the same `data` field len + _ => { + bail! { + ErrorKind::IncompatibleDataLengths( + data_len_compatibility_sets + .values() + .map(|x| x.to_owned()) + .collect(), + ) + } + } + } + Ok((threshold, result)) } diff --git a/tests/recovery_errors.rs b/tests/recovery_errors.rs index 1fd7295..c8f13e1 100644 --- a/tests/recovery_errors.rs +++ b/tests/recovery_errors.rs @@ -66,6 +66,17 @@ fn test_recover_duplicate_shares_number() { recover_secret(&shares, false).unwrap(); } +#[test] +#[should_panic(expected = "IncompatibleDataLengths")] +fn test_recover_incompatible_data_lengths() { + let share1 = "2-1-CgnlCxRNtnkzENE".to_string(); + let share2 = "2-2-ChbG46L1zRszs0PPn63XnnupmZTcgYJ3".to_string(); + + let shares = vec![share1, share2]; + + recover_secret(&shares, false).unwrap(); +} + #[test] #[should_panic(expected = "MissingShares")] fn test_recover_too_few_shares() {