From 1bb98a36b1571a8dd0b6184f4a81010ce129172b Mon Sep 17 00:00:00 2001 From: Srinath Setty Date: Tue, 1 Sep 2020 10:13:53 -0700 Subject: [PATCH] additional error checking --- Cargo.toml | 2 +- src/errors.rs | 15 ++------- src/lib.rs | 87 ++++++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 80 insertions(+), 24 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4f5cdeb..f41f559 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "spartan" -version = "0.2.0" +version = "0.2.1" authors = ["Srinath Setty "] edition = "2018" description = "High-speed zkSNARKs without trusted setup" diff --git a/src/errors.rs b/src/errors.rs index 8b04c9a..4917979 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -14,6 +14,7 @@ impl fmt::Debug for ProofVerifyError { } } +#[derive(Clone, Debug, Eq, PartialEq)] pub enum R1CSError { /// returned if the number of constraints is not a power of 2 NonPowerOfTwoCons, @@ -25,16 +26,6 @@ pub enum R1CSError { InvalidNumberOfVars, /// returned if a [u8;32] does not parse into a valid Scalar in the field of ristretto255 InvalidScalar, -} - -impl fmt::Display for R1CSError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "R1CSError") - } -} - -impl fmt::Debug for R1CSError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{{ file: {}, line: {} }}", file!(), line!()) - } + /// returned if the supplied row or col in (row,col,val) tuple is out of range + InvalidIndex, } diff --git a/src/lib.rs b/src/lib.rs index a090d38..be781d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -129,6 +129,17 @@ impl Instance { let mut mat: Vec<(usize, usize, Scalar)> = Vec::new(); for i in 0..tups.len() { let (row, col, val_bytes) = tups[i]; + + // row must be smaller than num_cons + if row >= num_cons { + return Err(R1CSError::InvalidIndex); + } + + // col must be smaller than num_vars + 1 + num_inputs + if col >= num_vars + 1 + num_inputs { + return Err(R1CSError::InvalidIndex); + } + let val = Scalar::from_bytes(&val_bytes); if val.is_some().unwrap_u8() == 1 { mat.push((row, col, val.unwrap())); @@ -140,12 +151,18 @@ impl Instance { }; let A_scalar = bytes_to_scalar(A); - let B_scalar = bytes_to_scalar(B); - let C_scalar = bytes_to_scalar(C); + if A_scalar.is_err() { + return Err(A_scalar.err().unwrap()); + } - // check for any parsing errors - if A_scalar.is_err() || B_scalar.is_err() || C_scalar.is_err() { - return Err(R1CSError::InvalidScalar); + let B_scalar = bytes_to_scalar(B); + if B_scalar.is_err() { + return Err(B_scalar.err().unwrap()); + } + + let C_scalar = bytes_to_scalar(C); + if C_scalar.is_err() { + return Err(C_scalar.err().unwrap()); } let inst = R1CSInstance::new( @@ -161,16 +178,19 @@ impl Instance { } /// Checks if a given R1CSInstance is satisfiable with a given variables and inputs assignments - pub fn is_sat(&self, vars: &VarsAssignment, inputs: &InputsAssignment) -> Result { - + pub fn is_sat( + &self, + vars: &VarsAssignment, + inputs: &InputsAssignment, + ) -> Result { if vars.assignment.len() != self.inst.get_num_vars() { - return Err(R1CSError::InvalidNumberOfVars) + return Err(R1CSError::InvalidNumberOfVars); } - + if inputs.assignment.len() != self.inst.get_num_inputs() { - return Err(R1CSError::InvalidNumberOfInputs) + return Err(R1CSError::InvalidNumberOfInputs); } - + Ok(self.inst.is_sat(&vars.assignment, &inputs.assignment)) } @@ -485,4 +505,49 @@ mod tests { .verify(&comm, &inputs, &mut verifier_transcript, &gens) .is_ok()); } + + #[test] + pub fn check_r1cs_invalid_index() { + let num_cons = 4; + let num_vars = 8; + let num_inputs = 1; + + let zero: [u8; 32] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, + ]; + + let A = vec![(0, 0, zero)]; + let B = vec![(100, 1, zero)]; + let C = vec![(1, 1, zero)]; + + let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C); + assert_eq!(inst.is_err(), true); + assert_eq!(inst.err(), Some(R1CSError::InvalidIndex)); + } + + #[test] + pub fn check_r1cs_invalid_scalar() { + let num_cons = 4; + let num_vars = 8; + let num_inputs = 1; + + let zero: [u8; 32] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, + ]; + + let larger_than_mod = [ + 3, 0, 0, 0, 255, 255, 255, 255, 254, 91, 254, 255, 2, 164, 189, 83, 5, 216, 161, 9, 8, 216, + 57, 51, 72, 125, 157, 41, 83, 167, 237, 115, + ]; + + let A = vec![(0, 0, zero)]; + let B = vec![(1, 1, larger_than_mod)]; + let C = vec![(1, 1, zero)]; + + let inst = Instance::new(num_cons, num_vars, num_inputs, &A, &B, &C); + assert_eq!(inst.is_err(), true); + assert_eq!(inst.err(), Some(R1CSError::InvalidScalar)); + } }