diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index ebf279b11..4e0fa3020 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -337,6 +337,22 @@ impl Layer { assert_eq!(w_record_evals.len(), w_len); assert_eq!(lookup_evals.len(), lk_len); assert_eq!(zero_evals.len(), zero_len); + // Construction of output-evaluation groups used by the main zerocheck: + // - Read group (`r_selector`): sel_r(x) * (r_i(x) - 1) = (claim^r_i - 1) + // - Write group (`w_selector`): sel_w(x) * (w_i(x) - 1) = (claim^w_i - 1) + // - Lookup group (`lk_selector`): sel_lk(x) * f_i(x) = claim^lk_i + // where f_i is normalized to absorb lookup padding alpha: + // non-negated: f_i = lookup_i - alpha, claim^lk_i = claim_i - alpha + // negated: f_i = lookup_i + alpha, claim^lk_i = alpha - claim_i + // - Rotation groups (3 groups): left/right/target claims, each one eq-selected. + // - ECC bridge groups (5 groups): x/y/slope/x3/y3 claims, each one selector-separated. + // - Zero group (`zero_selector`): sel_0(x) * z_i(x) = 0 (encoded via EvalExpression::Zero). + // + // The final batched main-sumcheck polynomial is formed as: + // p(x) = Σ_g sel_g(x) * (Σ_i α^{offset(g,i)} * (expr_{g,i}(x) - eval_{g,i})). + // `offset(g,i)` is the global challenge-power index of the i-th expression in group g, + // i.e. the contiguous position after flattening all groups in `expr_evals` order. + // Here, `expr_evals` below constructs (g -> {(expr_{g,i}, eval_{g,i})_i}) for all groups. let rotation_expr_len = cb.cs.rotations.len() * ROTATION_OPENING_COUNT; let ecc_bridge_expr_len = if cb.cs.ec_point_exprs.is_empty() { diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index de12967a1..9269ff856 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -129,7 +129,12 @@ impl> ZerocheckLayerProver selector_ctxs.len() ); - // Main sumcheck: constraints are fully unified in out_sel_and_eval_exprs. + // Main sumcheck batches smaller selector-group sumchecks. + // Per group g (from `out_sel_and_eval_exprs`): + // p_g(x) = sel_g(x) * Σ_j (α_{2+offset(g,j)} * expr_{g,j}(x)), + // S_g = Σ_{x in {0,1}^n} p_g(x). + // For zerocheck constraints (from each chip), S_g is expected to be 0. + // The batched polynomial is p(x) = Σ_g p_g(x), so Σ_x p(x) = Σ_g S_g = 0. let span = entered_span!("build_out_points_eq", profiling_4 = true); let main_sumcheck_challenges = chain!( challenges.iter().copied(), diff --git a/gkr_iop/src/gkr/layer/gpu/mod.rs b/gkr_iop/src/gkr/layer/gpu/mod.rs index 14372729e..072f89156 100644 --- a/gkr_iop/src/gkr/layer/gpu/mod.rs +++ b/gkr_iop/src/gkr/layer/gpu/mod.rs @@ -98,7 +98,12 @@ impl> ZerocheckLayerProver out_points.len(), ); - // Main sumcheck: constraints are fully unified in out_sel_and_eval_exprs. + // Main sumcheck batches smaller selector-group sumchecks. + // Per group g (from `out_sel_and_eval_exprs`): + // p_g(x) = sel_g(x) * Σ_j (α_{2+offset(g,j)} * expr_{g,j}(x)), + // S_g = Σ_{x in {0,1}^n} p_g(x). + // For zerocheck constraints (from each chip), S_g is expected to be 0. + // The batched polynomial is p(x) = Σ_g p_g(x), so Σ_x p(x) = Σ_g S_g = 0. let main_sumcheck_challenges = chain!( challenges.iter().copied(), get_challenge_pows(layer.exprs.len(), transcript) diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 2ac2cb2b6..23dd1ac8a 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -143,7 +143,15 @@ impl ZerocheckLayer for Layer { }) .collect::>(); - // build main sumcheck expression + // Build the concrete main-sumcheck polynomial by batching smaller sumchecks. + // For each selector group g with expressions expr_{g,0..k-1}, define: + // p_g(x) = sel_g(x) * Σ_j (α_{2+offset(g,j)} * expr_{g,j}(x)). + // The corresponding smaller sumcheck target is: + // S_g = Σ_{x in {0,1}^n} p_g(x). + // For zerocheck constraints contributed by each chip, the expected target is S_g = 0. + // Main sumcheck batches them into: + // p(x) = Σ_g p_g(x), so Σ_{x in {0,1}^n} p(x) = Σ_g S_g = 0. + // `rlc_zero_expr` returns the per-group p_g terms, then we sum them into p. let alpha_pows_expr = (2..) .take(self.exprs.len()) .map(|id| Expression::Challenge(id as ChallengeId, 1, E::ONE, E::ZERO))