diff --git a/src/lib.rs b/src/lib.rs index 63e6edd..a7f65de 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -271,17 +271,24 @@ impl Drop for PrivateKey { impl PartialEq for PrivateKey { // ⚠️ This is not a constant-time implementation fn eq(&self, other: &PrivateKey) -> bool { - if self.one_values.len() != other.one_values.len() { + if self.algorithm != other.algorithm { return false; } - if self.zero_values.len() != other.zero_values.len() { + // NOTE: The `zero_values` and `one_values` need not be of the + // the same length (and maybe this should change). + let zero_size = self.zero_values.len(); + let one_size = self.one_values.len(); + if zero_size != other.zero_values.len() || one_size != other.one_values.len() { return false; } - for i in 0..self.zero_values.len() { - if self.zero_values[i] != other.zero_values[i] || - self.one_values[i] != other.one_values[i] - { + for i in 0..zero_size { + if self.zero_values[i] != other.zero_values[i] { + return false; + } + } + for i in 0..one_size { + if self.one_values[i] != other.one_values[i] { return false; } } diff --git a/src/tests.rs b/src/tests.rs index 26505b1..5d192ed 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -88,3 +88,27 @@ fn test_serialization_panic() { assert_eq!(pub_key.one_values, recovered_pub_key.one_values); assert_eq!(pub_key.zero_values, recovered_pub_key.zero_values); } + +#[test] +fn test_private_key_equality() { + let mut pub_key = PrivateKey::new(digest_512); + let pub_key_2 = pub_key.clone(); + + assert!(pub_key == pub_key_2); + + pub_key.one_values.push(vec![0]); + + assert!(pub_key != pub_key_2); + + let mut pub_key = PrivateKey::new(digest_512); + let pub_key_2 = pub_key.clone(); + pub_key.one_values.pop(); + + assert!(pub_key != pub_key_2); + + let mut pub_key = PrivateKey::new(digest_512); + let pub_key_2 = pub_key.clone(); + pub_key.algorithm = digest_256; + + assert!(pub_key != pub_key_2); +}