Refactor to idiomatic Result/Option patterns (#25)
This:
- introduces a small [thiserror](https://github.com/dtolnay/thiserror)-powered enum to improve ProofVerifyError's messages,
- refactors point decompression errors into a variant of that enum, thereby suppressing the panics that occur when decompression fails,
- folds other panics into the Error cases of their enclosing `Result` returns.
This commit is contained in:
Parent: 7b102a241f
Commit: 9e4c166edb
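For context, here is a minimal, self-contained sketch of the pattern the diff below adopts: a thiserror-derived error enum, an extension trait that converts the `Option` returned by point decompression into a `Result`, and `?`-based propagation at call sites. The type and variant names mirror the diff; the `points_equal` caller and the `main` function are hypothetical illustrations, not part of this commit, and the snippet assumes `thiserror = "1"` and `curve25519-dalek = "3"` as dependencies.

```rust
// Sketch of the commit's error-handling pattern (assumes `thiserror = "1"` and
// `curve25519-dalek = "3"` in Cargo.toml; names mirror the diff below).
use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
use thiserror::Error;

// The enum replaces the old unit struct, so failures carry a message
// (and, for decompression, the offending bytes) instead of panicking.
#[derive(Error, Debug)]
pub enum ProofVerifyError {
  #[error("Proof verification failed")]
  InternalError,
  #[error("Compressed group element failed to decompress: {0:?}")]
  DecompressionError([u8; 32]),
}

// Extension trait: turn the Option returned by decompress() into a Result.
pub trait CompressedGroupExt {
  type Group;
  fn unpack(&self) -> Result<Self::Group, ProofVerifyError>;
}

impl CompressedGroupExt for CompressedRistretto {
  type Group = RistrettoPoint;
  fn unpack(&self) -> Result<Self::Group, ProofVerifyError> {
    self
      .decompress()
      .ok_or_else(|| ProofVerifyError::DecompressionError(self.to_bytes()))
  }
}

// Hypothetical caller showing `?`-propagation in place of unwrap()/expect().
fn points_equal(
  lhs: &CompressedRistretto,
  rhs: &CompressedRistretto,
) -> Result<(), ProofVerifyError> {
  if lhs.unpack()? == rhs.unpack()? {
    Ok(())
  } else {
    Err(ProofVerifyError::InternalError)
  }
}

fn main() -> Result<(), ProofVerifyError> {
  let b = RISTRETTO_BASEPOINT_POINT.compress();
  points_equal(&b, &b)?;

  // An all-0xff encoding is not a canonical Ristretto point, so unpack()
  // reports a DecompressionError rather than panicking.
  let bad = CompressedRistretto([0xff; 32]);
  assert!(matches!(
    bad.unpack(),
    Err(ProofVerifyError::DecompressionError(_))
  ));
  Ok(())
}
```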
@@ -26,6 +26,7 @@ zeroize = { version = "1", default-features = false }
 itertools = "0.9.0"
 colored = "1.9.3"
 flate2 = "1.0.14"
+thiserror = "1.0"
 
 [dev-dependencies]
 criterion = "0.3.1"
@@ -27,7 +27,7 @@ impl MultiCommitGens {
     MultiCommitGens {
       n,
-      G: gens[0..n].to_vec(),
+      G: gens[..n].to_vec(),
       h: gens[n],
     }
   }
@@ -90,7 +90,7 @@ impl EqPolynomial {
     let ell = self.r.len();
     let (left_num_vars, _right_num_vars) = EqPolynomial::compute_factored_lens(ell);
 
-    let L = EqPolynomial::new(self.r[0..left_num_vars].to_vec()).evals();
+    let L = EqPolynomial::new(self.r[..left_num_vars].to_vec()).evals();
     let R = EqPolynomial::new(self.r[left_num_vars..ell].to_vec()).evals();
 
     (L, R)
@@ -137,7 +137,7 @@ impl DensePolynomial {
   pub fn split(&self, idx: usize) -> (DensePolynomial, DensePolynomial) {
     assert!(idx < self.len());
     (
-      DensePolynomial::new(self.Z[0..idx].to_vec()),
+      DensePolynomial::new(self.Z[..idx].to_vec()),
       DensePolynomial::new(self.Z[idx..2 * idx].to_vec()),
     )
   }
@@ -326,18 +326,12 @@ impl PolyEvalProof {
     let default_blinds = PolyCommitmentBlinds {
       blinds: vec![Scalar::zero(); L_size],
     };
-    let blinds = match blinds_opt {
-      Some(p) => p,
-      None => &default_blinds,
-    };
+    let blinds = blinds_opt.map_or(&default_blinds, |p| p);
 
     assert_eq!(blinds.blinds.len(), L_size);
 
     let zero = Scalar::zero();
-    let blind_Zr = match blind_Zr_opt {
-      Some(p) => p,
-      None => &zero,
-    };
+    let blind_Zr = blind_Zr_opt.map_or(&zero, |p| p);
 
     // compute the L and R vectors
     let eq = EqPolynomial::new(r.to_vec());
@@ -1,16 +1,17 @@
-use core::fmt;
-use core::fmt::Debug;
+use thiserror::Error;
 
-pub struct ProofVerifyError;
-
-impl fmt::Display for ProofVerifyError {
-  fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-    write!(f, "Proof verification failed")
-  }
+#[derive(Error, Debug)]
+pub enum ProofVerifyError {
+  #[error("Proof verification failed")]
+  InternalError,
+  #[error("Compressed group element failed to decompress: {0:?}")]
+  DecompressionError([u8; 32]),
 }
 
-impl fmt::Debug for ProofVerifyError {
-  fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-    write!(f, "{{ file: {}, line: {} }}", file!(), line!())
+impl Default for ProofVerifyError {
+  fn default() -> Self {
+    ProofVerifyError::InternalError
   }
 }
src/group.rs (16 changed lines)
@@ -1,9 +1,25 @@
+use super::errors::ProofVerifyError;
 use super::scalar::{Scalar, ScalarBytes, ScalarBytesFromScalar};
 use core::borrow::Borrow;
 use core::ops::{Mul, MulAssign};
 
 pub type GroupElement = curve25519_dalek::ristretto::RistrettoPoint;
 pub type CompressedGroup = curve25519_dalek::ristretto::CompressedRistretto;
 
+pub trait CompressedGroupExt {
+  type Group;
+  fn unpack(&self) -> Result<Self::Group, ProofVerifyError>;
+}
+
+impl CompressedGroupExt for CompressedGroup {
+  type Group = curve25519_dalek::ristretto::RistrettoPoint;
+  fn unpack(&self) -> Result<Self::Group, ProofVerifyError> {
+    self
+      .decompress()
+      .ok_or_else(|| ProofVerifyError::DecompressionError(self.to_bytes()))
+  }
+}
+
 pub const GROUP_BASEPOINT_COMPRESSED: CompressedGroup =
   curve25519_dalek::constants::RISTRETTO_BASEPOINT_COMPRESSED;
src/lib.rs (38 changed lines)
@@ -339,17 +339,14 @@ impl SNARK {
 
     let timer_sat_proof = Timer::new("verify_sat_proof");
     assert_eq!(input.assignment.len(), comm.comm.get_num_inputs());
-    let (rx, ry) = self
-      .r1cs_sat_proof
-      .verify(
-        comm.comm.get_num_vars(),
-        comm.comm.get_num_cons(),
-        &input.assignment,
-        &self.inst_evals,
-        transcript,
-        &gens.gens_r1cs_sat,
-      )
-      .unwrap();
+    let (rx, ry) = self.r1cs_sat_proof.verify(
+      comm.comm.get_num_vars(),
+      comm.comm.get_num_cons(),
+      &input.assignment,
+      &self.inst_evals,
+      transcript,
+      &gens.gens_r1cs_sat,
+    )?;
     timer_sat_proof.stop();
 
     let timer_eval_proof = Timer::new("verify_eval_proof");
@@ -454,17 +451,14 @@ impl NIZK {
 
     let timer_sat_proof = Timer::new("verify_sat_proof");
     assert_eq!(input.assignment.len(), inst.inst.get_num_inputs());
-    let (rx, ry) = self
-      .r1cs_sat_proof
-      .verify(
-        inst.inst.get_num_vars(),
-        inst.inst.get_num_cons(),
-        &input.assignment,
-        &inst_evals,
-        transcript,
-        &gens.gens_r1cs_sat,
-      )
-      .unwrap();
+    let (rx, ry) = self.r1cs_sat_proof.verify(
+      inst.inst.get_num_vars(),
+      inst.inst.get_num_cons(),
+      &input.assignment,
+      &inst_evals,
+      transcript,
+      &gens.gens_r1cs_sat,
+    )?;
 
     // verify if claimed rx and ry are correct
     assert_eq!(rx, *claimed_rx);
@@ -148,10 +148,10 @@ impl BulletReductionProof {
     if lg_n >= 32 {
       // 4 billion multiplications should be enough for anyone
       // and this check prevents overflow in 1<<lg_n below.
-      return Err(ProofVerifyError);
+      return Err(ProofVerifyError::InternalError);
     }
     if n != (1 << lg_n) {
-      return Err(ProofVerifyError);
+      return Err(ProofVerifyError::InternalError);
     }
 
     // 1. Recompute x_k,...,x_1 based on the proof transcript
@@ -206,13 +206,13 @@ impl BulletReductionProof {
     let Ls = self
       .L_vec
       .iter()
-      .map(|p| p.decompress().ok_or(ProofVerifyError))
+      .map(|p| p.decompress().ok_or(ProofVerifyError::InternalError))
       .collect::<Result<Vec<_>, _>>()?;
 
     let Rs = self
       .R_vec
       .iter()
-      .map(|p| p.decompress().ok_or(ProofVerifyError))
+      .map(|p| p.decompress().ok_or(ProofVerifyError::InternalError))
       .collect::<Result<Vec<_>, _>>()?;
 
     let G_hat = GroupElement::vartime_multiscalar_mul(s.iter(), G.iter());
@@ -1,7 +1,7 @@
 #![allow(clippy::too_many_arguments)]
 use super::commitments::{Commitments, MultiCommitGens};
 use super::errors::ProofVerifyError;
-use super::group::CompressedGroup;
+use super::group::{CompressedGroup, CompressedGroupExt};
 use super::math::Math;
 use super::random::RandomTape;
 use super::scalar::Scalar;
@@ -64,17 +64,12 @@ impl KnowledgeProof {
     let c = transcript.challenge_scalar(b"c");
 
     let lhs = self.z1.commit(&self.z2, gens_n).compress();
-    let rhs = (c * C.decompress().expect("Could not decompress C")
-      + self
-        .alpha
-        .decompress()
-        .expect("Could not decompress self.alpha"))
-    .compress();
+    let rhs = (c * C.unpack()? + self.alpha.unpack()?).compress();
 
     if lhs == rhs {
       Ok(())
     } else {
-      Err(ProofVerifyError)
+      Err(ProofVerifyError::InternalError)
     }
   }
 }
@@ -134,8 +129,8 @@ impl EqualityProof {
 
     let c = transcript.challenge_scalar(b"c");
     let rhs = {
-      let C = C1.decompress().unwrap() - C2.decompress().unwrap();
-      (c * C + self.alpha.decompress().unwrap()).compress()
+      let C = C1.unpack()? - C2.unpack()?;
+      (c * C + self.alpha.unpack()?).compress()
     };
 
     let lhs = (self.z * gens_n.h).compress();
@@ -143,7 +138,7 @@ impl EqualityProof {
     if lhs == rhs {
       Ok(())
     } else {
-      Err(ProofVerifyError)
+      Err(ProofVerifyError::InternalError)
     }
   }
 }
@@ -280,7 +275,7 @@ impl ProductProof {
       &c,
       &MultiCommitGens {
         n: 1,
-        G: vec![X.decompress().unwrap()],
+        G: vec![X.unpack()?],
         h: gens_n.h,
       },
       &z3,
@@ -289,7 +284,7 @@ impl ProductProof {
     {
       Ok(())
     } else {
-      Err(ProofVerifyError)
+      Err(ProofVerifyError::InternalError)
     }
   }
 }
@@ -392,17 +387,16 @@ impl DotProductProof {
 
     let c = transcript.challenge_scalar(b"c");
 
-    let mut result = c * Cx.decompress().unwrap() + self.delta.decompress().unwrap()
-      == self.z.commit(&self.z_delta, gens_n);
+    let mut result =
+      c * Cx.unpack()? + self.delta.unpack()? == self.z.commit(&self.z_delta, gens_n);
 
     let dotproduct_z_a = DotProductProof::compute_dotproduct(&self.z, &a);
-    result &= c * Cy.decompress().unwrap() + self.beta.decompress().unwrap()
-      == dotproduct_z_a.commit(&self.z_beta, gens_1);
+    result &= c * Cy.unpack()? + self.beta.unpack()? == dotproduct_z_a.commit(&self.z_beta, gens_1);
 
     if result {
       Ok(())
     } else {
-      Err(ProofVerifyError)
+      Err(ProofVerifyError::InternalError)
     }
   }
 }
@@ -534,7 +528,7 @@ impl DotProductProofLog {
     Cx.append_to_transcript(b"Cx", transcript);
     Cy.append_to_transcript(b"Cy", transcript);
 
-    let Gamma = Cx.decompress().unwrap() + Cy.decompress().unwrap();
+    let Gamma = Cx.unpack()? + Cy.unpack()?;
 
     let (g_hat, Gamma_hat, a_hat) = self
       .bullet_reduction_proof
@@ -547,9 +541,9 @@ impl DotProductProofLog {
     let c = transcript.challenge_scalar(b"c");
 
     let c_s = &c;
-    let beta_s = self.beta.decompress().unwrap();
+    let beta_s = self.beta.unpack()?;
     let a_hat_s = &a_hat;
-    let delta_s = self.delta.decompress().unwrap();
+    let delta_s = self.delta.unpack()?;
     let z1_s = &self.z1;
     let z2_s = &self.z2;
 
@@ -561,7 +555,7 @@ impl DotProductProofLog {
     if lhs == rhs {
       Ok(())
     } else {
-      Err(ProofVerifyError)
+      Err(ProofVerifyError::InternalError)
     }
   }
 }
@@ -211,11 +211,11 @@ impl R1CSInstance {
     };
 
     assert_eq!(
-      inst.is_sat(&Z[0..num_vars].to_vec(), &Z[num_vars + 1..].to_vec()),
+      inst.is_sat(&Z[..num_vars].to_vec(), &Z[num_vars + 1..].to_vec()),
       true,
     );
 
-    (inst, Z[0..num_vars].to_vec(), Z[num_vars + 1..].to_vec())
+    (inst, Z[..num_vars].to_vec(), Z[num_vars + 1..].to_vec())
   }
 
   pub fn is_sat(&self, vars: &[Scalar], input: &[Scalar]) -> bool {
@@ -370,18 +370,14 @@ impl R1CSProof {
     let claim_phase1 = Scalar::zero()
       .commit(&Scalar::zero(), &gens.gens_sc.gens_1)
       .compress();
-    let (comm_claim_post_phase1, rx) = self
-      .sc_proof_phase1
-      .verify(
-        &claim_phase1,
-        num_rounds_x,
-        3,
-        &gens.gens_sc.gens_1,
-        &gens.gens_sc.gens_4,
-        transcript,
-      )
-      .unwrap();
-
+    let (comm_claim_post_phase1, rx) = self.sc_proof_phase1.verify(
+      &claim_phase1,
+      num_rounds_x,
+      3,
+      &gens.gens_sc.gens_1,
+      &gens.gens_sc.gens_4,
+      transcript,
+    )?;
     // perform the intermediate sum-check test with claimed Az, Bz, and Cz
     let (comm_Az_claim, comm_Bz_claim, comm_Cz_claim, comm_prod_Az_Bz_claims) = &self.claims_phase2;
     let (pok_Cz_claim, proof_prod) = &self.pok_claims_phase2;
@@ -398,7 +398,7 @@ impl Scalar {
   pub fn from_bytes(bytes: &[u8; 32]) -> CtOption<Scalar> {
     let mut tmp = Scalar([0, 0, 0, 0]);
 
-    tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap());
+    tmp.0[0] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[..8]).unwrap());
     tmp.0[1] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap());
     tmp.0[2] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap());
     tmp.0[3] = u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap());
@@ -429,7 +429,7 @@ impl Scalar {
     let tmp = Scalar::montgomery_reduce(self.0[0], self.0[1], self.0[2], self.0[3], 0, 0, 0, 0);
 
     let mut res = [0; 32];
-    res[0..8].copy_from_slice(&tmp.0[0].to_le_bytes());
+    res[..8].copy_from_slice(&tmp.0[0].to_le_bytes());
     res[8..16].copy_from_slice(&tmp.0[1].to_le_bytes());
     res[16..24].copy_from_slice(&tmp.0[2].to_le_bytes());
     res[24..32].copy_from_slice(&tmp.0[3].to_le_bytes());
@@ -441,7 +441,7 @@ impl Scalar {
   /// a `Scalar` by reducing by the modulus.
   pub fn from_bytes_wide(bytes: &[u8; 64]) -> Scalar {
     Scalar::from_u512([
-      u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[0..8]).unwrap()),
+      u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[..8]).unwrap()),
       u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[8..16]).unwrap()),
      u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[16..24]).unwrap()),
       u64::from_le_bytes(<[u8; 8]>::try_from(&bytes[24..32]).unwrap()),
@@ -1400,9 +1400,7 @@ impl PolyEvalNetworkProof {
 
     let (claims_mem, rand_mem, mut claims_ops, claims_dotp, rand_ops) = self
       .proof_prod_layer
-      .verify(num_ops, num_cells, evals, transcript)
-      .unwrap();
-
+      .verify(num_ops, num_cells, evals, transcript)?;
     assert_eq!(claims_mem.len(), 4);
     assert_eq!(claims_ops.len(), 4 * num_instances);
     assert_eq!(claims_dotp.len(), 3 * num_instances);
@@ -80,7 +80,7 @@ impl UniPoly {
   }
 
   pub fn compress(&self) -> CompressedUniPoly {
-    let coeffs_except_linear_term = [&self.coeffs[0..1], &self.coeffs[2..]].concat();
+    let coeffs_except_linear_term = [&self.coeffs[..1], &self.coeffs[2..]].concat();
     assert_eq!(coeffs_except_linear_term.len() + 1, self.coeffs.len());
     CompressedUniPoly {
       coeffs_except_linear_term,