diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 7bd3a2bff..b0beffda0 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -579,15 +579,7 @@ pub fn verify_chip_proof( is_infinity: Usize::uninit(builder), }; - if composed_cs.has_ecc_ops() { - builder.assert_nonzero(&chip_proof.has_ecc_proof); - let ecc_proof = &chip_proof.ecc_proof; - builder.assert_usize_eq(ecc_proof.sum.is_infinity.clone(), Usize::from(0)); - verify_ecc_proof(builder, challenger, ecc_proof, unipoly_extrapolator); - builder.assign(&shard_ec_sum, ecc_proof.sum.clone()); - } else { - builder.assign(&shard_ec_sum.is_infinity, Usize::from(1)); - } + builder.assign(&shard_ec_sum.is_infinity, Usize::from(1)); let tower_proof = &chip_proof.tower_proof; let num_variables: Array> = builder.dyn_array(num_batched); @@ -631,6 +623,14 @@ pub fn verify_chip_proof( }); } + if composed_cs.has_ecc_ops() { + builder.assert_nonzero(&chip_proof.has_ecc_proof); + let ecc_proof = &chip_proof.ecc_proof; + builder.assert_usize_eq(ecc_proof.sum.is_infinity.clone(), Usize::from(0)); + verify_ecc_proof(builder, challenger, ecc_proof, unipoly_extrapolator); + builder.assign(&shard_ec_sum, ecc_proof.sum.clone()); + } + let num_rw_records: Usize = builder.eval(r_counts_per_instance + w_counts_per_instance); builder.assert_usize_eq(record_evals.len(), num_rw_records.clone()); builder.assert_usize_eq(logup_p_evals.len(), lk_counts_per_instance.clone()); @@ -849,6 +849,101 @@ pub fn verify_chip_proof( } } + if composed_cs.has_ecc_ops() { + let [x_group_idx, y_group_idx, slope_group_idx] = first_layer + .ecc_bridge_group_indices() + .expect("ecc bridge selectors missing"); + + transcript_observe_label(builder, challenger, b"ecc_gkr_bridge_r"); + let sample_r: Ext = challenger.sample_ext(builder); + let one_minus_r: Ext = builder.eval(one - sample_r); + let ecc_proof = &chip_proof.ecc_proof; + + let xy_point_len: Usize = builder.eval(ecc_proof.rt.fs.len() + Usize::from(1)); + let xy_point: Array> = builder.dyn_array(xy_point_len); + builder.set(&xy_point, 0, sample_r); + builder + .range(0, ecc_proof.rt.fs.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let v = builder.get(&ecc_proof.rt.fs, idx); + let shifted_idx = Usize::Var(Var::uninit(builder)); + builder.assign(&shifted_idx, idx + Usize::from(1)); + builder.set(&xy_point, shifted_idx, v); + }); + + let s_point_len: Usize = builder.eval(ecc_proof.rt.fs.len() + Usize::from(1)); + let s_point: Array> = builder.dyn_array(s_point_len); + builder + .range(0, ecc_proof.rt.fs.len()) + .for_each(|idx_vec, builder| { + let idx = idx_vec[0]; + let v = builder.get(&ecc_proof.rt.fs, idx); + builder.set(&s_point, idx, v); + }); + builder.set(&s_point, ecc_proof.rt.fs.len(), sample_r); + + let degree = SEPTIC_EXTENSION_DEGREE; + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[x_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge x group must use EvalExpression::Single"); + }; + let x0 = builder.get(&ecc_proof.evals, 3 + degree + idx); + let x1 = builder.get(&ecc_proof.evals, 3 + degree * 3 + idx); + let eval = builder.eval(x0 * one_minus_r + x1 * sample_r); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: xy_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[y_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge y group must use EvalExpression::Single"); + }; + let y0 = builder.get(&ecc_proof.evals, 3 + degree * 2 + idx); + let y1 = builder.get(&ecc_proof.evals, 3 + degree * 4 + idx); + let eval = builder.eval(y0 * one_minus_r + y1 * sample_r); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: xy_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + + for (idx, eval_expr) in first_layer.out_sel_and_eval_exprs[slope_group_idx] + .1 + .iter() + .enumerate() + { + let EvalExpression::Single(out_idx) = eval_expr else { + panic!("ecc bridge slope group must use EvalExpression::Single"); + }; + let s1 = builder.get(&ecc_proof.evals, 3 + idx); + let eval = builder.eval(s1 * sample_r); + let claim: PointAndEvalVariable = builder.eval(PointAndEvalVariable { + point: PointVariable { + fs: s_point.clone(), + }, + eval, + }); + builder.set(&out_evals, *out_idx, claim); + } + } + builder.cycle_tracker_start("Verify GKR Circuit"); let rt = verify_gkr_circuit( builder, @@ -904,9 +999,8 @@ pub fn verify_gkr_circuit( }, } = layer_proof; - let expected_main_evals_len: Usize = Usize::from( - layer.n_witin + layer.n_fixed + layer.n_instance + layer.n_structural_witin, - ); + let expected_main_evals_len: Usize = + Usize::from(layer.n_witin + layer.n_fixed + layer.n_structural_witin); builder.assert_usize_eq(expected_main_evals_len, main_evals.len()); transcript_observe_label(builder, challenger, b"combine subset evals"); @@ -947,7 +1041,7 @@ pub fn verify_gkr_circuit( unipoly_extrapolator, ); - let structural_witin_offset = layer.n_witin + layer.n_fixed + layer.n_instance; + let structural_witin_offset = layer.n_witin + layer.n_fixed; // check selector evaluations layer diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 6a679edda..eb3430539 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -8,7 +8,10 @@ use crate::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, hal::{DeviceProvingKey, EccQuarkProver, ProofInput, TowerProverSpec}, septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, - utils::{infer_tower_logup_witness, infer_tower_product_witness}, + utils::{ + assign_group_evals, derive_ecc_bridge_claims, extract_ecc_quark_witness_inputs, + infer_tower_logup_witness, infer_tower_product_witness, split_rotation_evals, + }, }, structs::{ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs}, }; @@ -16,7 +19,6 @@ use either::Either; use ff_ext::ExtensionField; use gkr_iop::{ cpu::{CpuBackend, CpuProver}, - evaluation::EvalExpression, gkr::{self, Evaluation, GKRProof, GKRProverOutput, layer::LayerWitness}, hal::ProverBackend, selector::{SelectorContext, SelectorType}, @@ -312,19 +314,22 @@ impl> EccQuarkProver( &self, - num_instances: usize, - xs: Vec>>, - ys: Vec>>, - invs: Vec>>, + cs: &ComposedConstrainSystem, + input: &ProofInput<'a, CpuBackend>, transcript: &mut impl Transcript, - ) -> Result, ZKVMError> { - Ok(CpuEccProver::create_ecc_proof( - num_instances, - xs, - ys, - invs, + ) -> Result>, ZKVMError> { + let Some(ecc_inputs) = extract_ecc_quark_witness_inputs::>(cs, input) + else { + return Ok(None); + }; + + Ok(Some(CpuEccProver::create_ecc_proof( + input.num_instances(), + ecc_inputs.xs, + ecc_inputs.ys, + ecc_inputs.slopes, transcript, - )) + ))) } } @@ -876,6 +881,7 @@ impl> MainSumcheckProver, rotation: Option>, + ecc_proof: Option<&EccQuarkProof>, input: &'b ProofInput<'a, CpuBackend>, composed_cs: &ComposedConstrainSystem, challenges: &[E; 2], @@ -937,6 +943,7 @@ impl> MainSumcheckProver> MainSumcheckProver], - eval_exprs: &[EvalExpression], - evals: &[E], - point: &Point| { - assert_eq!( - eval_exprs.len(), - evals.len(), - "rotation eval length mismatch" - ); - for (eval_expr, eval) in eval_exprs.iter().zip_eq(evals.iter()) { - let EvalExpression::Single(index) = eval_expr else { - panic!("rotation groups must use EvalExpression::Single"); - }; - out_evals[*index] = PointAndEval::new(point.clone(), *eval); - } - }; - - assign_group( + assign_group_evals( &mut out_evals, &first_layer.out_sel_and_eval_exprs[left_group_idx].1, &left_evals, &rotation.left_point, ); - assign_group( + assign_group_evals( &mut out_evals, &first_layer.out_sel_and_eval_exprs[right_group_idx].1, &right_evals, &rotation.right_point, ); - assign_group( + assign_group_evals( &mut out_evals, &first_layer.out_sel_and_eval_exprs[point_group_idx].1, &point_evals, &rotation.point, ); } + + if let Some(ecc_proof) = ecc_proof { + let Some([x_group_idx, y_group_idx, slope_group_idx]) = + first_layer.ecc_bridge_group_indices() + else { + panic!("ecc proof provided for non-ecc layer") + }; + + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation); + + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + } let GKRProverOutput { gkr_proof, opening_evaluations, diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index 34c19ea16..b7fc3bb0b 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -5,18 +5,22 @@ use super::hal::{ use crate::{ error::ZKVMError, scheme::{ + constants::SEPTIC_EXTENSION_DEGREE, cpu::TowerRelationOutput, hal::{ DeviceProvingKey, MainSumcheckEvals, ProofInput, RotationProverOutput, TowerProverSpec, }, + utils::{ + assign_group_evals, derive_ecc_bridge_claims, extract_ecc_quark_witness_inputs, + split_rotation_evals, + }, }, - structs::{ComposedConstrainSystem, PointAndEval, TowerProofs}, + structs::{ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs}, }; use ceno_gpu::bb31::{CudaHalBB31, GpuPolynomial}; use either::Either; use ff_ext::ExtensionField; use gkr_iop::{ - evaluation::EvalExpression, gkr::{ self, Evaluation, GKRProof, GKRProverOutput, layer::{LayerWitness, gpu::utils::extract_mle_relationships_from_monomial_terms}, @@ -289,6 +293,7 @@ pub fn prove_main_constraints_impl< >( rt_tower: Vec, rotation: Option>, + ecc_proof: Option<&EccQuarkProof>, input: &ProofInput<'_, GpuBackend>, composed_cs: &ComposedConstrainSystem, challenges: &[E; 2], @@ -350,6 +355,7 @@ pub fn prove_main_constraints_impl< let mut out_evals = vec![PointAndEval::new(rt_tower.clone(), E::ZERO); gkr_circuit.n_evaluations]; + if let Some(rotation) = rotation.as_ref() { let Some([left_group_idx, right_group_idx, point_group_idx]) = first_layer.rotation_selector_group_indices() @@ -357,45 +363,21 @@ pub fn prove_main_constraints_impl< panic!("rotation proof provided for non-rotation layer") }; - let mut left_evals = Vec::new(); - let mut right_evals = Vec::new(); - let mut point_evals = Vec::new(); - for chunk in rotation.proof.evals.chunks_exact(3) { - left_evals.push(chunk[0]); - right_evals.push(chunk[1]); - point_evals.push(chunk[2]); - } - - let assign_group = |out_evals: &mut [PointAndEval], - eval_exprs: &[EvalExpression], - evals: &[E], - point: &Point| { - assert_eq!( - eval_exprs.len(), - evals.len(), - "rotation eval length mismatch" - ); - for (eval_expr, eval) in eval_exprs.iter().zip_eq(evals.iter()) { - let EvalExpression::Single(index) = eval_expr else { - panic!("rotation groups must use EvalExpression::Single"); - }; - out_evals[*index] = PointAndEval::new(point.clone(), *eval); - } - }; + let (left_evals, right_evals, point_evals) = split_rotation_evals(&rotation.proof.evals); - assign_group( + assign_group_evals( &mut out_evals, &first_layer.out_sel_and_eval_exprs[left_group_idx].1, &left_evals, &rotation.left_point, ); - assign_group( + assign_group_evals( &mut out_evals, &first_layer.out_sel_and_eval_exprs[right_group_idx].1, &right_evals, &rotation.right_point, ); - assign_group( + assign_group_evals( &mut out_evals, &first_layer.out_sel_and_eval_exprs[point_group_idx].1, &point_evals, @@ -403,6 +385,36 @@ pub fn prove_main_constraints_impl< ); } + if let Some(ecc_proof) = ecc_proof { + let Some([x_group_idx, y_group_idx, slope_group_idx]) = + first_layer.ecc_bridge_group_indices() + else { + panic!("ecc proof provided for non-ecc layer") + }; + + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation); + + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + } + let GKRProverOutput { gkr_proof, opening_evaluations, @@ -459,12 +471,20 @@ pub fn prove_main_constraints_impl< level = "trace" )] pub fn prove_ec_sum_quark_impl<'a, E: ExtensionField, PCS: PolynomialCommitmentScheme>( - num_instances: usize, - xs: Vec>>, - ys: Vec>>, - invs: Vec>>, + composed_cs: &ComposedConstrainSystem, + input: &ProofInput<'a, GpuBackend>, transcript: &mut impl Transcript, -) -> Result, ZKVMError> { +) -> Result>, ZKVMError> { + let Some(ecc_inputs) = + extract_ecc_quark_witness_inputs::>(composed_cs, input) + else { + return Ok(None); + }; + let xs = ecc_inputs.xs; + let ys = ecc_inputs.ys; + let invs = ecc_inputs.slopes; + + let num_instances = input.num_instances(); let stream = gkr_iop::gpu::get_thread_stream(); assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); assert_eq!(ys.len(), SEPTIC_EXTENSION_DEGREE); @@ -670,13 +690,13 @@ pub fn prove_ec_sum_quark_impl<'a, E: ExtensionField, PCS: PolynomialCommitmentS assert_eq!(evals.len(), 3 + SEPTIC_EXTENSION_DEGREE * 7); let final_sum = SepticPoint::from_affine(final_sum_x.clone(), final_sum_y.clone()); - Ok(EccQuarkProof { + Ok(Some(EccQuarkProof { zerocheck_proof: proof_gpu_e, num_instances, evals, rt, sum: final_sum, - }) + })) } impl> TraceCommitter> @@ -1237,6 +1257,7 @@ impl> MainSumcheckProver, rotation: Option>, + ecc_proof: Option<&EccQuarkProof>, // _records: Vec>, // not used by GPU after delegation input: &'b ProofInput<'a, GpuBackend>, composed_cs: &ComposedConstrainSystem, @@ -1257,6 +1278,7 @@ impl> MainSumcheckProver( rt_tower, rotation, + ecc_proof, input, composed_cs, challenges, @@ -1290,21 +1312,19 @@ impl> EccQuarkProver( &self, - num_instances: usize, - xs: Vec>>, - ys: Vec>>, - invs: Vec>>, + composed_cs: &ComposedConstrainSystem, + input: &ProofInput<'a, GpuBackend>, transcript: &mut impl Transcript, - ) -> Result, ZKVMError> { - // n = num_vars of the ecc quark sumcheck (xs[0].num_vars - 1) - let n = xs[0].mle.num_vars() - 1; + ) -> Result>, ZKVMError> { let cuda_hal = get_cuda_hal().expect("Failed to get CUDA HAL"); let gpu_mem_tracker = init_gpu_mem_tracker(&cuda_hal, "prove_ec_sum_quark"); - let res = prove_ec_sum_quark_impl::(num_instances, xs, ys, invs, transcript); + let res = prove_ec_sum_quark_impl::(composed_cs, input, transcript); - let estimated_bytes = estimate_ecc_quark_bytes_from_num_vars(n); - check_gpu_mem_estimation(gpu_mem_tracker, estimated_bytes); + if let Ok(Some(proof)) = &res { + let estimated_bytes = estimate_ecc_quark_bytes_from_num_vars(proof.rt.len()); + check_gpu_mem_estimation(gpu_mem_tracker, estimated_bytes); + } res } diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 483eb571b..65fe06f2d 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -108,12 +108,10 @@ pub trait TraceCommitter { pub trait EccQuarkProver { fn prove_ec_sum_quark<'a>( &self, - num_instances: usize, - xs: Vec>>, - ys: Vec>>, - invs: Vec>>, + cs: &ComposedConstrainSystem, + input: &ProofInput<'a, PB>, transcript: &mut impl Transcript, - ) -> Result, ZKVMError>; + ) -> Result>, ZKVMError>; } pub trait TowerProver { @@ -184,11 +182,12 @@ pub trait MainSumcheckProver { // the validity of read/write/logup records through sumchecks; // 2. multiple multiplication relations between witness multilinear polynomials // achieved via zerochecks. - #[allow(clippy::type_complexity)] + #[allow(clippy::type_complexity, clippy::too_many_arguments)] fn prove_main_constraints<'a, 'b>( &self, rt_tower: Vec, rotation: Option>, + ecc_proof: Option<&EccQuarkProof>, input: &'b ProofInput<'a, PB>, cs: &ComposedConstrainSystem, challenges: &[PB::E; 2], diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index d7e7c8b71..cf04c0ed7 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -8,14 +8,13 @@ use std::{collections::BTreeMap, marker::PhantomData, sync::Arc}; #[cfg(feature = "gpu")] use crate::scheme::gpu::estimate_chip_proof_memory; use crate::scheme::{ - constants::SEPTIC_EXTENSION_DEGREE, hal::MainSumcheckEvals, scheduler::{ChipScheduler, ChipTask, ChipTaskResult}, }; use either::Either; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; -use multilinear_extensions::{Expression, Instance}; +use multilinear_extensions::Instance; use p3::field::FieldAlgebra; use std::iter::Iterator; use sumcheck::{ @@ -429,39 +428,6 @@ impl< let log2_num_instances = input.log2_num_instances(); let num_var_with_rotation = log2_num_instances + cs.rotation_vars().unwrap_or(0); - // run ecc quark prover - let ecc_proof = if !cs.zkvm_v1_css.ec_final_sum.is_empty() { - let span = entered_span!("run_ecc_final_sum", profiling_2 = true); - let ec_point_exprs = &cs.zkvm_v1_css.ec_point_exprs; - assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); - let mut xs_ys = ec_point_exprs - .iter() - .map(|expr| match expr { - Expression::WitIn(id) => input.witness[*id as usize].clone(), - _ => unreachable!("ec point's expression must be WitIn"), - }) - .collect_vec(); - let ys = xs_ys.split_off(SEPTIC_EXTENSION_DEGREE); - let xs = xs_ys; - let slopes = cs - .zkvm_v1_css - .ec_slope_exprs - .iter() - .map(|expr| match expr { - Expression::WitIn(id) => input.witness[*id as usize].clone(), - _ => unreachable!("slope's expression must be WitIn"), - }) - .collect_vec(); - let ecc_proof = Some(info_span!("[ceno] prove_ec_sum_quark").in_scope(|| { - self.device - .prove_ec_sum_quark(input.num_instances(), xs, ys, slopes, transcript) - })?); - exit_span!(span); - ecc_proof - } else { - None - }; - // build main witness let records = info_span!("[ceno] build_main_witness") .in_scope(|| build_main_witness::(cs, input, challenges)); @@ -481,6 +447,11 @@ impl< num_var_with_rotation, ); + let span = entered_span!("run_ecc_final_sum", profiling_2 = true); + let ecc_proof = info_span!("[ceno] prove_ec_sum_quark") + .in_scope(|| self.device.prove_ec_sum_quark(cs, input, transcript))?; + exit_span!(span); + let span = entered_span!("prove_rotation", profiling_2 = true); let rotation = info_span!("[ceno] prove_rotation").in_scope(|| { self.device @@ -496,6 +467,7 @@ impl< self.device.prove_main_constraints( rt_tower, rotation.clone(), + ecc_proof.as_ref(), input, cs, challenges, @@ -787,38 +759,6 @@ where }); } - // run ecc quark prover using _impl function - let ecc_proof = if !cs.zkvm_v1_css.ec_final_sum.is_empty() { - let span = entered_span!("run_ecc_final_sum", profiling_2 = true); - let ec_point_exprs = &cs.zkvm_v1_css.ec_point_exprs; - assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); - let mut xs_ys = ec_point_exprs - .iter() - .map(|expr| match expr { - Expression::WitIn(id) => input.witness[*id as usize].clone(), - _ => unreachable!("ec point's expression must be WitIn"), - }) - .collect_vec(); - let ys = xs_ys.split_off(SEPTIC_EXTENSION_DEGREE); - let xs = xs_ys; - let slopes = cs - .zkvm_v1_css - .ec_slope_exprs - .iter() - .map(|expr| match expr { - Expression::WitIn(id) => input.witness[*id as usize].clone(), - _ => unreachable!("slope's expression must be WitIn"), - }) - .collect_vec(); - let ecc_proof = Some(info_span!("[ceno] prove_ec_sum_quark").in_scope(|| { - prove_ec_sum_quark_impl::(input.num_instances(), xs, ys, slopes, transcript) - })?); - exit_span!(span); - ecc_proof - } else { - None - }; - // build main witness let records = info_span!("[ceno] build_main_witness").in_scope(|| { @@ -842,6 +782,11 @@ where assert_eq!(rt_tower.len(), num_var_with_rotation,); + let span = entered_span!("run_ecc_final_sum", profiling_2 = true); + let ecc_proof = info_span!("[ceno] prove_ec_sum_quark") + .in_scope(|| prove_ec_sum_quark_impl::(cs, &input, transcript))?; + exit_span!(span); + let span = entered_span!("prove_rotation", profiling_2 = true); let rotation = info_span!("[ceno] prove_rotation").in_scope(|| { prove_rotation_impl::(cs, &input, &rt_tower, challenges, transcript) @@ -855,6 +800,7 @@ where prove_main_constraints_impl::( rt_tower, rotation.clone(), + ecc_proof.as_ref(), &input, cs, challenges, diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 8583e2710..ecac9ca97 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,9 +1,9 @@ use crate::{ scheme::{ - constants::MIN_PAR_SIZE, + constants::{MIN_PAR_SIZE, SEPTIC_EXTENSION_DEGREE}, hal::{ProofInput, ProverDevice}, }, - structs::ComposedConstrainSystem, + structs::{ComposedConstrainSystem, EccQuarkProof, PointAndEval}, }; use either::Either; use ff_ext::ExtensionField; @@ -13,9 +13,10 @@ use gkr_iop::{ hal::{MultilinearPolynomial, ProtocolWitnessGeneratorProver, ProverBackend}, }; use itertools::Itertools; -use mpcs::PolynomialCommitmentScheme; +use mpcs::{Point, PolynomialCommitmentScheme}; pub use multilinear_extensions::wit_infer_by_expr; use multilinear_extensions::{ + Expression, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, }; @@ -29,6 +30,125 @@ use rayon::{ use std::{iter, sync::Arc}; use witness::next_pow2_instance_padding; +pub(crate) struct EccBridgeClaims { + pub(crate) xy_point: Point, + pub(crate) s_point: Point, + pub(crate) x_evals: Vec, + pub(crate) y_evals: Vec, + pub(crate) s_evals: Vec, +} + +pub(crate) struct EccQuarkWitnessInputs<'a, PB: ProverBackend> { + pub(crate) xs: Vec>>, + pub(crate) ys: Vec>>, + pub(crate) slopes: Vec>>, +} + +pub(crate) fn extract_ecc_quark_witness_inputs<'a, PB: ProverBackend>( + cs: &ComposedConstrainSystem, + input: &ProofInput<'a, PB>, +) -> Option> { + let cs = &cs.zkvm_v1_css; + if cs.ec_final_sum.is_empty() { + return None; + } + + let ec_point_exprs = &cs.ec_point_exprs; + assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); + let mut xs_ys = ec_point_exprs + .iter() + .map(|expr| match expr { + Expression::WitIn(id) => input.witness[*id as usize].clone(), + _ => unreachable!("ec point's expression must be WitIn"), + }) + .collect_vec(); + let ys = xs_ys.split_off(SEPTIC_EXTENSION_DEGREE); + let xs = xs_ys; + + let slopes = cs + .ec_slope_exprs + .iter() + .map(|expr| match expr { + Expression::WitIn(id) => input.witness[*id as usize].clone(), + _ => unreachable!("slope's expression must be WitIn"), + }) + .collect_vec(); + + Some(EccQuarkWitnessInputs { xs, ys, slopes }) +} + +pub(crate) fn derive_ecc_bridge_claims( + ecc_proof: &EccQuarkProof, + sample_r: E, + num_var_with_rotation: usize, +) -> EccBridgeClaims { + let degree = SEPTIC_EXTENSION_DEGREE; + let evals = &ecc_proof.evals[3..]; + assert_eq!(evals.len(), degree * 7); + + let s1 = &evals[0..degree]; + let x0 = &evals[degree..2 * degree]; + let y0 = &evals[2 * degree..3 * degree]; + let x1 = &evals[3 * degree..4 * degree]; + let y1 = &evals[4 * degree..5 * degree]; + + let one_minus_r = E::ONE - sample_r; + let x_evals = x0 + .iter() + .zip_eq(x1.iter()) + .map(|(a, b)| *a * one_minus_r + *b * sample_r) + .collect_vec(); + let y_evals = y0 + .iter() + .zip_eq(y1.iter()) + .map(|(a, b)| *a * one_minus_r + *b * sample_r) + .collect_vec(); + let s_evals = s1.iter().map(|v| *v * sample_r).collect_vec(); + + let mut xy_point = vec![sample_r]; + xy_point.extend(ecc_proof.rt.iter().copied()); + assert_eq!(xy_point.len(), num_var_with_rotation); + + let mut s_point = ecc_proof.rt.clone(); + s_point.push(sample_r); + assert_eq!(s_point.len(), num_var_with_rotation); + + EccBridgeClaims { + xy_point, + s_point, + x_evals, + y_evals, + s_evals, + } +} + +pub(crate) fn split_rotation_evals(evals: &[E]) -> (Vec, Vec, Vec) { + let mut left_evals = Vec::new(); + let mut right_evals = Vec::new(); + let mut point_evals = Vec::new(); + for chunk in evals.chunks_exact(3) { + left_evals.push(chunk[0]); + right_evals.push(chunk[1]); + point_evals.push(chunk[2]); + } + (left_evals, right_evals, point_evals) +} + +pub(crate) fn assign_group_evals( + out_evals: &mut [PointAndEval], + eval_exprs: &[EvalExpression], + evals: &[E], + point: &Point, +) { + assert_eq!(eval_exprs.len(), evals.len(), "group eval length mismatch"); + for (eval_expr, eval) in eval_exprs.iter().zip_eq(evals.iter()) { + let EvalExpression::Single(index) = eval_expr else { + panic!("group must use EvalExpression::Single"); + }; + out_evals[*index] = PointAndEval::new(point.clone(), *eval); + } +} + /// Wrapper that asserts a shared reference is safe to send across threads. /// /// # Safety @@ -407,7 +527,7 @@ pub fn gkr_witness< phase1_witness_group: &[Arc>], structural_witness: &[Arc>], fixed: &[Arc>], - pub_io_mles: &[Arc>], + _pub_io_mles: &[Arc>], pub_io_evals: &[Either], challenges: &[E], ) -> (GKRCircuitWitness<'b, PB>, GKRCircuitOutput<'b, PB>) { @@ -438,16 +558,6 @@ pub fn gkr_witness< witness_mle_flatten[*index] = Some(fixed_mle.clone()); }); - first_layer - .in_eval_expr - .iter() - .skip(first_layer.n_witin + first_layer.n_fixed) - .take(first_layer.n_instance) - .zip_eq(pub_io_mles.iter()) - .for_each(|(index, pubio_mle)| { - witness_mle_flatten[*index] = Some(pubio_mle.clone()); - }); - // XXX currently fixed poly not support in layers > 1 // TODO process fixed (and probably short) mle // @@ -500,10 +610,7 @@ pub fn gkr_witness< assert_eq!( current_layer_wits.len(), - layer.n_witin - + layer.n_fixed - + layer.n_instance - + if i == 0 { layer.n_structural_witin } else { 0 } + layer.n_witin + layer.n_fixed + if i == 0 { layer.n_structural_witin } else { 0 } ); // infer current layer output diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 67ce8048f..619dfae28 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -17,6 +17,7 @@ use crate::{ scheme::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, septic_curve::{SepticExtension, SepticPoint}, + utils::{assign_group_evals, derive_ecc_bridge_claims}, }, structs::{ ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, @@ -536,19 +537,7 @@ impl> ZKVMVerifier assert_eq!(num_vars, log2_num_instances); }); - // verify ecc proof if exists - let shard_ec_sum: Option> = if composed_cs.has_ecc_ops() { - tracing::debug!("verifying ecc proof..."); - assert!(proof.ecc_proof.is_some()); - let ecc_proof = proof.ecc_proof.as_ref().unwrap(); - assert!(!ecc_proof.sum.is_infinity); - - EccVerifier::verify_ecc_proof(ecc_proof, transcript)?; - tracing::debug!("ecc proof verified."); - Some(ecc_proof.sum.clone()) - } else { - None - }; + let mut shard_ec_sum: Option> = None; // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; @@ -593,6 +582,17 @@ impl> ZKVMVerifier })?; } + if composed_cs.has_ecc_ops() { + tracing::debug!("verifying ecc proof..."); + assert!(proof.ecc_proof.is_some()); + let ecc_proof = proof.ecc_proof.as_ref().unwrap(); + assert!(!ecc_proof.sum.is_infinity); + + EccVerifier::verify_ecc_proof(ecc_proof, transcript)?; + tracing::debug!("ecc proof verified."); + shard_ec_sum = Some(ecc_proof.sum.clone()); + } + debug_assert!( chain!(&record_evals, &logup_p_evals, &logup_q_evals) .map(|e| &e.point) @@ -678,36 +678,19 @@ impl> ZKVMVerifier )); }; - let assign_group = |out_evals: &mut [PointAndEval], - eval_exprs: &[gkr_iop::evaluation::EvalExpression], - evals: &[E], - point: &Point| { - assert_eq!( - eval_exprs.len(), - evals.len(), - "rotation eval length mismatch" - ); - for (eval_expr, eval) in eval_exprs.iter().zip_eq(evals.iter()) { - let gkr_iop::evaluation::EvalExpression::Single(index) = eval_expr else { - panic!("rotation groups must use EvalExpression::Single"); - }; - out_evals[*index] = PointAndEval::new(point.clone(), *eval); - } - }; - - assign_group( + assign_group_evals( &mut out_evals, &first_layer.out_sel_and_eval_exprs[left_group_idx].1, &rotation_claims.left_evals, &rotation_claims.rotation_points.left, ); - assign_group( + assign_group_evals( &mut out_evals, &first_layer.out_sel_and_eval_exprs[right_group_idx].1, &rotation_claims.right_evals, &rotation_claims.rotation_points.right, ); - assign_group( + assign_group_evals( &mut out_evals, &first_layer.out_sel_and_eval_exprs[point_group_idx].1, &rotation_claims.target_evals, @@ -715,6 +698,38 @@ impl> ZKVMVerifier ); } + if let Some(ecc_proof) = proof.ecc_proof.as_ref() { + let Some([x_group_idx, y_group_idx, slope_group_idx]) = + first_layer.ecc_bridge_group_indices() + else { + return Err(ZKVMError::InvalidProof( + "ecc bridge claims expected but selectors are missing".into(), + )); + }; + + let sample_r = transcript.sample_and_append_vec(b"ecc_gkr_bridge_r", 1)[0]; + let claims = derive_ecc_bridge_claims(ecc_proof, sample_r, num_var_with_rotation); + + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[x_group_idx].1, + &claims.x_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[y_group_idx].1, + &claims.y_evals, + &claims.xy_point, + ); + assign_group_evals( + &mut out_evals, + &first_layer.out_sel_and_eval_exprs[slope_group_idx].1, + &claims.s_evals, + &claims.s_point, + ); + } + let pi = cs .instance .iter() diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 82719a7bc..31540e3bc 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -44,7 +44,7 @@ pub struct EccQuarkProof { pub zerocheck_proof: IOPProof, /// Number of EC points being summed pub num_instances: usize, - pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[0,rt], y[0,rt], s[0,rt] + pub evals: Vec, /* [sel_add, sel_bypass, sel_export] ++ [s[1,rt], x[rt,0], y[rt,0], x[rt,1], y[rt,1], x[1,rt], y[1,rt]] */ pub rt: Point, pub sum: SepticPoint, } diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 6c5ef93b5..2bbbd1825 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -418,6 +418,9 @@ impl TableCircuit for ShardRamCircuit { let selector_r = cb.create_placeholder_structural_witin(|| "selector_r"); let selector_w = cb.create_placeholder_structural_witin(|| "selector_w"); let selector_zero = cb.create_placeholder_structural_witin(|| "selector_zero"); + let selector_ecc_x = cb.create_placeholder_structural_witin(|| "selector_ecc_x"); + let selector_ecc_y = cb.create_placeholder_structural_witin(|| "selector_ecc_y"); + let selector_ecc_s = cb.create_placeholder_structural_witin(|| "selector_ecc_s"); let config = Self::construct_circuit(cb, param)?; @@ -439,6 +442,11 @@ impl TableCircuit for ShardRamCircuit { cb.cs.w_selector = Some(selector_w); cb.cs.zero_selector = Some(selector_zero.clone()); cb.cs.lk_selector = Some(selector_zero); + cb.cs.ec_bridge_selectors = Some([ + SelectorType::Whole(selector_ecc_x.expr()), + SelectorType::Whole(selector_ecc_y.expr()), + SelectorType::Whole(selector_ecc_s.expr()), + ]); // all shared the same selector let (out_evals, mut chip) = ( @@ -487,10 +495,18 @@ impl TableCircuit for ShardRamCircuit { // this is workaround, as call `construct_circuit` will not initialized selector // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` - assert_eq!(num_structural_witin, 3); + // ShardRam expects exactly these structural selectors: + // r, w, zero, ecc_x, ecc_y, ecc_s. + assert_eq!( + num_structural_witin, 6, + "ShardRam requires exactly 6 structural selectors (r,w,zero,ecc_x,ecc_y,ecc_s)" + ); let selector_r_witin = WitIn { id: 0 }; let selector_w_witin = WitIn { id: 1 }; let selector_zero_witin = WitIn { id: 2 }; + let selector_ecc_x_witin = WitIn { id: 3 }; + let selector_ecc_y_witin = WitIn { id: 4 }; + let selector_ecc_s_witin = WitIn { id: 5 }; let nthreads = max_usable_threads(); @@ -539,6 +555,15 @@ impl TableCircuit for ShardRamCircuit { ); RowMajorMatrix::new(value, num_structural_witin) }; + // ECC bridge selectors are `Whole`, so keep them active on all rows. + raw_structual_witin + .values + .par_chunks_mut(num_structural_witin) + .for_each(|row| { + set_val!(row, selector_ecc_x_witin, E::BaseField::ONE); + set_val!(row, selector_ecc_y_witin, E::BaseField::ONE); + set_val!(row, selector_ecc_s_witin, E::BaseField::ONE); + }); let raw_witin_iter = raw_witin.values[0..steps.len() * num_witin] .par_chunks_mut(num_instance_per_batch * num_witin); let raw_structual_witin_iter = raw_structual_witin.values diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index 5877ae9b8..39be9436b 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -1,6 +1,6 @@ use crate::{ circuit_builder::CircuitBuilder, - gkr::layer::{Layer, ROTATION_OPENING_COUNT}, + gkr::layer::{ECC_BRIDGE_OPENING_COUNT, Layer, ROTATION_OPENING_COUNT}, }; use ff_ext::ExtensionField; use itertools::Itertools; @@ -34,13 +34,19 @@ pub struct Chip { impl Chip { pub fn new_from_cb(cb: &CircuitBuilder) -> Chip { let rotation_eval_count = cb.cs.rotations.len() * ROTATION_OPENING_COUNT; + let ecc_eval_count = if cb.cs.ec_point_exprs.is_empty() { + 0 + } else { + cb.cs.ec_slope_exprs.len() * ECC_BRIDGE_OPENING_COUNT + }; let num_non_zero_outputs = cb.cs.w_expressions.len() + cb.cs.r_expressions.len() + cb.cs.lk_expressions.len() + cb.cs.w_table_expressions.len() + cb.cs.r_table_expressions.len() + cb.cs.lk_table_expressions.len() * 2 - + rotation_eval_count; + + rotation_eval_count + + ecc_eval_count; Self { n_fixed: cb.cs.num_fixed, n_committed: cb.cs.num_witin as usize, diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index b66d05300..319537218 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -107,6 +107,7 @@ pub struct ConstraintSystem { pub ec_point_exprs: Vec>, pub ec_slope_exprs: Vec>, pub ec_final_sum: Vec>, + pub ec_bridge_selectors: Option<[SelectorType; 3]>, pub r_selector: Option>, pub r_expressions: Vec>, @@ -179,6 +180,7 @@ impl ConstraintSystem { ec_final_sum: vec![], ec_slope_exprs: vec![], ec_point_exprs: vec![], + ec_bridge_selectors: None, r_selector: None, r_expressions: vec![], r_expressions_namespace_map: vec![], diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 068495bec..53f5a2166 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -41,6 +41,7 @@ pub type RotateExprs = ( // rotation contribute // left + right + target, overall 3 pub const ROTATION_OPENING_COUNT: usize = 3; +pub const ECC_BRIDGE_OPENING_COUNT: usize = 3; #[derive(Clone, Debug, Serialize, Deserialize)] pub enum LayerType { @@ -71,7 +72,6 @@ pub struct Layer { pub n_witin: usize, pub n_structural_witin: usize, pub n_fixed: usize, - pub n_instance: usize, pub max_expr_degree: usize, /// keep all structural witin which could be evaluated succinctly without PCS pub structural_witins: Vec, @@ -100,6 +100,7 @@ pub struct Layer { // there got 3 different eq for (left, right, target) during rotation argument // refer https://hackmd.io/HAAj1JTQQiKfu0SIwOJDRw?view#Rotation pub rotation_exprs: RotateExprs, + pub ecc_bridge_group_indices: Option<[usize; ECC_BRIDGE_OPENING_COUNT]>, pub rotation_cyclic_group_log2: usize, pub rotation_cyclic_subgroup_size: usize, @@ -140,7 +141,6 @@ impl Layer { n_witin: usize, n_structural_witin: usize, n_fixed: usize, - n_instance: usize, // exprs concat zero/non-zero expression. exprs: Vec>, in_eval_expr: Vec, @@ -169,7 +169,6 @@ impl Layer { n_witin, n_structural_witin, n_fixed, - n_instance, max_expr_degree, structural_witins, exprs, @@ -177,6 +176,7 @@ impl Layer { in_eval_expr, out_sel_and_eval_exprs, rotation_exprs: (rotation_eq, rotation_exprs), + ecc_bridge_group_indices: None, rotation_cyclic_group_log2, rotation_cyclic_subgroup_size, expr_names, @@ -339,6 +339,11 @@ impl Layer { assert_eq!(zero_evals.len(), zero_len); let rotation_expr_len = cb.cs.rotations.len() * ROTATION_OPENING_COUNT; + let ecc_bridge_expr_len = if cb.cs.ec_point_exprs.is_empty() { + 0 + } else { + cb.cs.ec_slope_exprs.len() * ECC_BRIDGE_OPENING_COUNT + }; let mut next_non_zero_eval_idx = r_record_evals .iter() .chain(w_record_evals.iter()) @@ -352,7 +357,8 @@ impl Layer { + cb.cs.r_table_expressions.len() + cb.cs.lk_expressions.len() + cb.cs.lk_table_expressions.len() * 2 - + rotation_expr_len; + + rotation_expr_len + + ecc_bridge_expr_len; let zero_expr_len = cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); @@ -369,10 +375,17 @@ impl Layer { 0 } else { ROTATION_OPENING_COUNT + } + + if cb.cs.ec_point_exprs.is_empty() { + 0 + } else { + ECC_BRIDGE_OPENING_COUNT }; let mut expr_evals = Vec::with_capacity(selector_group_capacity); let mut expr_names = Vec::with_capacity(non_zero_expr_len + zero_expr_len); let mut expressions = Vec::with_capacity(non_zero_expr_len + zero_expr_len); + let mut ecc_bridge_group_indices: Option<[usize; ECC_BRIDGE_OPENING_COUNT]> = None; + let mut ecc_bridge_eval_bases: Option<[usize; ECC_BRIDGE_OPENING_COUNT]> = None; if let Some(r_selector) = cb.cs.r_selector.as_ref() { // process r_record @@ -550,9 +563,58 @@ impl Layer { } } + if !cb.cs.ec_point_exprs.is_empty() { + let septic_degree = cb.cs.ec_slope_exprs.len(); + assert_eq!(cb.cs.ec_point_exprs.len(), septic_degree * 2); + + // ECC bridge selector groups must be explicitly supplied and independent. + // Do not fall back to (or reuse) preceding r/w/lk/zero selectors. + let [ecc_sel_x, ecc_sel_y, ecc_sel_s] = + cb.cs.ec_bridge_selectors.clone().expect( + "ecc bridge selectors must be provided when ec_point_exprs is non-empty", + ); + + let x_eval_base = next_non_zero_eval_idx; + let y_eval_base = x_eval_base + septic_degree; + let s_eval_base = y_eval_base + septic_degree; + next_non_zero_eval_idx = s_eval_base + septic_degree; + ecc_bridge_eval_bases = Some([x_eval_base, y_eval_base, s_eval_base]); + + let x_group_idx = expr_evals.len(); + expr_evals.push((ecc_sel_x, vec![])); + let y_group_idx = expr_evals.len(); + expr_evals.push((ecc_sel_y, vec![])); + let s_group_idx = expr_evals.len(); + expr_evals.push((ecc_sel_s, vec![])); + ecc_bridge_group_indices = Some([x_group_idx, y_group_idx, s_group_idx]); + + for (idx, x_expr) in cb.cs.ec_point_exprs[..septic_degree].iter().enumerate() { + expressions.push(x_expr.clone()); + expr_evals[x_group_idx] + .1 + .push(EvalExpression::Single(x_eval_base + idx)); + expr_names.push(format!("ecc_bridge/x/{idx}")); + } + + for (idx, y_expr) in cb.cs.ec_point_exprs[septic_degree..].iter().enumerate() { + expressions.push(y_expr.clone()); + expr_evals[y_group_idx] + .1 + .push(EvalExpression::Single(y_eval_base + idx)); + expr_names.push(format!("ecc_bridge/y/{idx}")); + } + + for (idx, slope_expr) in cb.cs.ec_slope_exprs.iter().enumerate() { + expressions.push(slope_expr.clone()); + expr_evals[s_group_idx] + .1 + .push(EvalExpression::Single(s_eval_base + idx)); + expr_names.push(format!("ecc_bridge/slope/{idx}")); + } + } + if let Some(zero_selector) = cb.cs.zero_selector.as_ref() { // process zero_record - // Intentionally dedup with the previous selector group when selector matches. let evals = Self::dedup_last_selector_evals(zero_selector, &mut expr_evals); for (idx, (zero_expr, name)) in izip!( 0.., @@ -583,6 +645,25 @@ impl Layer { // Drop selector groups that ended up without eval expressions. expr_evals.retain(|(_, evals)| !evals.is_empty()); + if let Some([x_base, y_base, s_base]) = ecc_bridge_eval_bases { + let find_group_by_base = |base: usize| { + expr_evals + .iter() + .enumerate() + .find_map(|(idx, (_, evals))| match evals.first() { + Some(EvalExpression::Single(pos)) if *pos == base => Some(idx), + _ => None, + }) + }; + let x_idx = find_group_by_base(x_base) + .expect("missing x ecc bridge selector group after retain"); + let y_idx = find_group_by_base(y_base) + .expect("missing y ecc bridge selector group after retain"); + let s_idx = find_group_by_base(s_base) + .expect("missing slope ecc bridge selector group after retain"); + ecc_bridge_group_indices = Some([x_idx, y_idx, s_idx]); + } + let out_eval_count = expr_evals .iter() .map(|(_, evals)| evals.len()) @@ -599,20 +680,21 @@ impl Layer { .take(cb.cs.num_witin as usize + cb.cs.num_fixed) .collect_vec(); if rotations.is_empty() { - Layer::new( + let mut layer = Layer::new( layer_name, LayerType::Zerocheck, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, - 0, expressions, in_eval_expr, expr_evals, ((None, vec![]), 0, 0), expr_names, cb.cs.structural_witins.clone(), - ) + ); + layer.ecc_bridge_group_indices = ecc_bridge_group_indices; + layer } else { let Some(RotationParams { rotation_eqs, @@ -622,13 +704,12 @@ impl Layer { else { panic!("rotation params not set"); }; - Layer::new( + let mut layer = Layer::new( layer_name, LayerType::Zerocheck, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, cb.cs.num_fixed, - 0, expressions, in_eval_expr, expr_evals, @@ -639,7 +720,9 @@ impl Layer { ), expr_names, cb.cs.structural_witins.clone(), - ) + ); + layer.ecc_bridge_group_indices = ecc_bridge_group_indices; + layer } } @@ -690,6 +773,10 @@ impl Layer { let point_idx = find_group(point_eq)?; Some([left_idx, right_idx, point_idx]) } + + pub fn ecc_bridge_group_indices(&self) -> Option<[usize; ECC_BRIDGE_OPENING_COUNT]> { + self.ecc_bridge_group_indices + } } impl<'a, PB: ProverBackend> LayerWitness<'a, PB> { diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index 303b22f4d..de12967a1 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -174,12 +174,11 @@ impl> ZerocheckLayerProver } exit_span!(span); - // `wit` := witin ++ fixed ++ pubio ++ structural + // `wit` := witin ++ fixed ++ structural // selector structural witins are replaced by computed eq MLEs in-place by witness id. - let base_wit_count = layer.n_witin + layer.n_fixed + layer.n_instance; - let mut all_witins = Vec::with_capacity( - layer.n_witin + layer.n_structural_witin + layer.n_fixed + layer.n_instance, - ); + let base_wit_count = layer.n_witin + layer.n_fixed; + let mut all_witins = + Vec::with_capacity(layer.n_witin + layer.n_structural_witin + layer.n_fixed); all_witins.extend( wit.iter() .take(base_wit_count) @@ -199,13 +198,12 @@ impl> ZerocheckLayerProver assert_eq!( all_witins.len(), - layer.n_witin + layer.n_structural_witin + layer.n_fixed + layer.n_instance, - "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {} + layer.n_instance {}", + layer.n_witin + layer.n_structural_witin + layer.n_fixed, + "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {}", all_witins.len(), layer.n_witin, layer.n_structural_witin, layer.n_fixed, - layer.n_instance, ); let builder = diff --git a/gkr_iop/src/gkr/layer/gpu/mod.rs b/gkr_iop/src/gkr/layer/gpu/mod.rs index cb0977bc5..14372729e 100644 --- a/gkr_iop/src/gkr/layer/gpu/mod.rs +++ b/gkr_iop/src/gkr/layer/gpu/mod.rs @@ -143,9 +143,9 @@ impl> ZerocheckLayerProver } } - // `wit` := witin ++ fixed ++ pubio ++ structural + // `wit` := witin ++ fixed ++ structural // selector structural witins are replaced by computed eq MLEs in-place by witness id. - let base_wit_count = layer.n_witin + layer.n_fixed + layer.n_instance; + let base_wit_count = layer.n_witin + layer.n_fixed; let all_witins_gpu = wit .iter() .take(base_wit_count) @@ -169,13 +169,12 @@ impl> ZerocheckLayerProver .collect_vec(); assert_eq!( all_witins_gpu.len(), - layer.n_witin + layer.n_structural_witin + layer.n_fixed + layer.n_instance, - "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {} + layer.n_instance {}", + layer.n_witin + layer.n_structural_witin + layer.n_fixed, + "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {}", all_witins_gpu.len(), layer.n_witin, layer.n_structural_witin, layer.n_fixed, - layer.n_instance, ); exit_span!(span_eq); diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 99ef11cb6..2ac2cb2b6 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -138,7 +138,7 @@ impl ZerocheckLayer for Layer { &expr, self.n_witin as WitnessId, self.n_fixed as WitnessId, - self.n_instance, + 0, ) }) .collect::>(); @@ -159,7 +159,7 @@ impl ZerocheckLayer for Layer { expr, self.n_witin as WitnessId, self.n_fixed as WitnessId, - self.n_instance, + 0, ) }); @@ -167,7 +167,7 @@ impl ZerocheckLayer for Layer { &mut zero_expr, self.n_witin as WitnessId, self.n_fixed as WitnessId, - self.n_instance, + 0, ); tracing::trace!("{} main sumcheck degree: {}", self.name, zero_expr.degree()); self.main_sumcheck_expression = Some(zero_expr); @@ -248,7 +248,7 @@ impl ZerocheckLayer for Layer { assert_eq!( main_evals.len(), - self.n_witin + self.n_fixed + self.n_instance + self.n_structural_witin, + self.n_witin + self.n_fixed + self.n_structural_witin, "invalid main_evals length", ); @@ -281,7 +281,7 @@ impl ZerocheckLayer for Layer { ); let in_point = in_point.into_iter().map(|c| c.elements).collect_vec(); - let structural_witin_offset = self.n_witin + self.n_fixed + self.n_instance; + let structural_witin_offset = self.n_witin + self.n_fixed; // eval selector and set to respective witin izip!( &self.out_sel_and_eval_exprs, @@ -770,7 +770,7 @@ pub fn rlc_zero_expr( layer: &Layer, alpha_pows: &[Expression], ) -> Vec> { - let offset_structural_witid = (layer.n_witin + layer.n_fixed + layer.n_instance) as WitnessId; + let offset_structural_witid = (layer.n_witin + layer.n_fixed) as WitnessId; let mut alpha_pows_iter = alpha_pows.iter(); let mut expr_iter = layer.exprs.iter(); let mut zero_check_exprs = Vec::with_capacity(layer.out_sel_and_eval_exprs.len()); diff --git a/gkr_iop/src/gkr/layer_constraint_system.rs b/gkr_iop/src/gkr/layer_constraint_system.rs index c074feedc..a76f4d927 100644 --- a/gkr_iop/src/gkr/layer_constraint_system.rs +++ b/gkr_iop/src/gkr/layer_constraint_system.rs @@ -410,7 +410,6 @@ impl LayerConstraintSystem { self.num_witin, 0, self.num_fixed, - 0, expressions, in_eval_expr, expr_evals, @@ -433,7 +432,6 @@ impl LayerConstraintSystem { self.num_witin, 0, self.num_fixed, - 0, expressions, in_eval_expr, expr_evals, diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 43c8239ff..07b5658fd 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -74,7 +74,7 @@ impl MockProver { &(sel.selector_expr() * expr), layer.n_witin as WitnessId, layer.n_fixed as WitnessId, - layer.n_instance, + 0, &[], &wits, &structural_wits, @@ -93,7 +93,6 @@ impl MockProver { out.mock_evaluate( layer.n_witin as WitnessId, layer.n_fixed as WitnessId, - layer.n_instance, &evaluations, &challenges, num_vars, @@ -148,7 +147,6 @@ impl EvalExpression { &self, n_witin: WitnessId, n_fixed: WitnessId, - n_instance: usize, evals: &[ArcMultilinearExtension<'a, E>], challenges: &[E], num_vars: usize, @@ -162,7 +160,7 @@ impl EvalExpression { &(Expression::WitIn(*i as WitnessId) * *c0.clone() + *c1.clone()), n_witin, n_fixed, - n_instance, + 0, &[], evals, &[], @@ -174,11 +172,7 @@ impl EvalExpression { assert_eq!(parts.len(), 1 << indices.len()); let parts = parts .iter() - .map(|part| { - part.mock_evaluate( - n_witin, n_fixed, n_instance, evals, challenges, num_vars, - ) - }) + .map(|part| part.mock_evaluate(n_witin, n_fixed, evals, challenges, num_vars)) .collect::, _>>()?; indices .iter()