Remove extra SHL/SHR CTL. (#1270)

* Remove extra shift CTL.

* Change order of inputs for the arithmetic shift operations. Add SHR test. Fix max number of bit shifts. Cleanup.

* Fix SHR in the case shift >= 256

* Limit visibility of helper functions
This commit is contained in:
Linda Guiga 2023-10-05 09:56:56 -04:00 committed by GitHub
parent 51eb7c0b52
commit 0de6f94962
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 504 additions and 107 deletions

View File

@ -104,10 +104,7 @@ pub(crate) fn all_cross_table_lookups<F: Field>() -> Vec<CrossTableLookup<F>> {
fn ctl_arithmetic<F: Field>() -> CrossTableLookup<F> {
CrossTableLookup::new(
vec![
cpu_stark::ctl_arithmetic_base_rows(),
cpu_stark::ctl_arithmetic_shift_rows(),
],
vec![cpu_stark::ctl_arithmetic_base_rows()],
arithmetic_stark::ctl_arithmetic_rows(),
)
}

View File

@ -12,6 +12,7 @@ use plonky2::util::transpose;
use static_assertions::const_assert;
use super::columns::NUM_ARITH_COLUMNS;
use super::shift;
use crate::all_stark::Table;
use crate::arithmetic::columns::{RANGE_COUNTER, RC_FREQUENCIES, SHARED_COLS};
use crate::arithmetic::{addcy, byte, columns, divmod, modular, mul, Operation};
@ -208,6 +209,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
divmod::eval_packed(lv, nv, yield_constr);
modular::eval_packed(lv, nv, yield_constr);
byte::eval_packed(lv, yield_constr);
shift::eval_packed_generic(lv, nv, yield_constr);
}
fn eval_ext_circuit(
@ -237,6 +239,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Stark<F, D> for ArithmeticSta
divmod::eval_ext_circuit(builder, lv, nv, yield_constr);
modular::eval_ext_circuit(builder, lv, nv, yield_constr);
byte::eval_ext_circuit(builder, lv, yield_constr);
shift::eval_ext_circuit(builder, lv, nv, yield_constr);
}
fn constraint_degree(&self) -> usize {

View File

@ -101,7 +101,7 @@ pub(crate) const MODULAR_OUT_AUX_RED: Range<usize> = AUX_REGISTER_0;
pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_REGISTER_1.start;
pub(crate) const MODULAR_AUX_INPUT_LO: Range<usize> = AUX_REGISTER_1.start + 1..AUX_REGISTER_1.end;
pub(crate) const MODULAR_AUX_INPUT_HI: Range<usize> = AUX_REGISTER_2;
// Must be set to MOD_IS_ZERO for DIV operation i.e. MOD_IS_ZERO * lv[IS_DIV]
// Must be set to MOD_IS_ZERO for DIV and SHR operations i.e. MOD_IS_ZERO * (lv[IS_DIV] + lv[IS_SHR]).
pub(crate) const MODULAR_DIV_DENOM_IS_ZERO: usize = AUX_REGISTER_2.end;
/// The counter column (used for the range check) starts from 0 and increments.

View File

@ -15,24 +15,19 @@ use crate::arithmetic::modular::{
use crate::arithmetic::utils::*;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
/// Generate the output and auxiliary values for modular operations.
pub(crate) fn generate<F: PrimeField64>(
/// Generates the output and auxiliary values for modular operations,
/// assuming the input, modular and output limbs are already set.
pub(crate) fn generate_divmod<F: PrimeField64>(
lv: &mut [F],
nv: &mut [F],
filter: usize,
input0: U256,
input1: U256,
result: U256,
input_limbs_range: Range<usize>,
modulus_range: Range<usize>,
) {
debug_assert!(lv.len() == NUM_ARITH_COLUMNS);
u256_to_array(&mut lv[INPUT_REGISTER_0], input0);
u256_to_array(&mut lv[INPUT_REGISTER_1], input1);
u256_to_array(&mut lv[OUTPUT_REGISTER], result);
let input_limbs = read_value_i64_limbs::<N_LIMBS, _>(lv, INPUT_REGISTER_0);
let input_limbs = read_value_i64_limbs::<N_LIMBS, _>(lv, input_limbs_range);
let pol_input = pol_extend(input_limbs);
let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, INPUT_REGISTER_1);
let (out, quo_input) = generate_modular_op(lv, nv, filter, pol_input, modulus_range);
debug_assert!(
&quo_input[N_LIMBS..].iter().all(|&x| x == F::ZERO),
"expected top half of quo_input to be zero"
@ -62,16 +57,35 @@ pub(crate) fn generate<F: PrimeField64>(
);
lv[AUX_INPUT_REGISTER_0].copy_from_slice(&quo_input[..N_LIMBS]);
}
_ => panic!("expected filter to be IS_DIV or IS_MOD but it was {filter}"),
_ => panic!("expected filter to be IS_DIV, IS_SHR or IS_MOD but it was {filter}"),
};
}
/// Writes the two operands and the expected result into `lv`, then generates
/// the output and auxiliary values for the operation selected by `filter`.
///
/// The operands go to `INPUT_REGISTER_0` and `INPUT_REGISTER_1`, the result to
/// `OUTPUT_REGISTER`; the remaining work is delegated to `generate_divmod`.
pub(crate) fn generate<F: PrimeField64>(
    lv: &mut [F],
    nv: &mut [F],
    filter: usize,
    input0: U256,
    input1: U256,
    result: U256,
) {
    debug_assert!(lv.len() == NUM_ARITH_COLUMNS);

    // The three registers are disjoint column ranges of the current row.
    let register_values = [
        (INPUT_REGISTER_0, input0),
        (INPUT_REGISTER_1, input1),
        (OUTPUT_REGISTER, result),
    ];
    for (register, value) in register_values {
        u256_to_array(&mut lv[register], value);
    }

    generate_divmod(lv, nv, filter, INPUT_REGISTER_0, INPUT_REGISTER_1);
}
/// Verify that num = quo * den + rem and 0 <= rem < den.
fn eval_packed_divmod_helper<P: PackedField>(
pub(crate) fn eval_packed_divmod_helper<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS],
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
filter: P,
num_range: Range<usize>,
den_range: Range<usize>,
quo_range: Range<usize>,
rem_range: Range<usize>,
) {
@ -80,8 +94,8 @@ fn eval_packed_divmod_helper<P: PackedField>(
yield_constr.constraint_last_row(filter);
let num = &lv[INPUT_REGISTER_0];
let den = read_value(lv, INPUT_REGISTER_1);
let num = &lv[num_range];
let den = read_value(lv, den_range);
let quo = {
let mut quo = [P::ZEROS; 2 * N_LIMBS];
quo[..N_LIMBS].copy_from_slice(&lv[quo_range]);
@ -104,14 +118,13 @@ pub(crate) fn eval_packed<P: PackedField>(
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
) {
// Constrain IS_SHR independently, so that it doesn't impact the
// constraints when combining the flag with IS_DIV.
yield_constr.constraint_last_row(lv[IS_SHR]);
eval_packed_divmod_helper(
lv,
nv,
yield_constr,
lv[IS_DIV] + lv[IS_SHR],
lv[IS_DIV],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
OUTPUT_REGISTER,
AUX_INPUT_REGISTER_0,
);
@ -120,24 +133,28 @@ pub(crate) fn eval_packed<P: PackedField>(
nv,
yield_constr,
lv[IS_MOD],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
AUX_INPUT_REGISTER_0,
OUTPUT_REGISTER,
);
}
fn eval_ext_circuit_divmod_helper<F: RichField + Extendable<D>, const D: usize>(
pub(crate) fn eval_ext_circuit_divmod_helper<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
filter: ExtensionTarget<D>,
num_range: Range<usize>,
den_range: Range<usize>,
quo_range: Range<usize>,
rem_range: Range<usize>,
) {
yield_constr.constraint_last_row(builder, filter);
let num = &lv[INPUT_REGISTER_0];
let den = read_value(lv, INPUT_REGISTER_1);
let num = &lv[num_range];
let den = read_value(lv, den_range);
let quo = {
let zero = builder.zero_extension();
let mut quo = [zero; 2 * N_LIMBS];
@ -164,14 +181,14 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
yield_constr.constraint_last_row(builder, lv[IS_SHR]);
let div_shr_flag = builder.add_extension(lv[IS_DIV], lv[IS_SHR]);
eval_ext_circuit_divmod_helper(
builder,
lv,
nv,
yield_constr,
div_shr_flag,
lv[IS_DIV],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
OUTPUT_REGISTER,
AUX_INPUT_REGISTER_0,
);
@ -181,6 +198,8 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
nv,
yield_constr,
lv[IS_MOD],
INPUT_REGISTER_0,
INPUT_REGISTER_1,
AUX_INPUT_REGISTER_0,
OUTPUT_REGISTER,
);
@ -214,7 +233,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
// Deactivate the SHR flag so that a DIV operation is not triggered.
// Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here.
lv[IS_SHR] = F::ZERO;
let mut constraint_consumer = ConstraintConsumer::new(
@ -247,6 +266,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
// Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here.
lv[IS_SHR] = F::ZERO;
lv[op_filter] = F::ONE;
@ -308,6 +328,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
// Since SHR uses the logic for DIV, `IS_SHR` should also be set to 0 here.
lv[IS_SHR] = F::ZERO;
lv[op_filter] = F::ONE;

View File

@ -9,6 +9,7 @@ mod byte;
mod divmod;
mod modular;
mod mul;
mod shift;
mod utils;
pub mod arithmetic_stark;
@ -35,15 +36,29 @@ impl BinaryOperator {
pub(crate) fn result(&self, input0: U256, input1: U256) -> U256 {
match self {
BinaryOperator::Add => input0.overflowing_add(input1).0,
BinaryOperator::Mul | BinaryOperator::Shl => input0.overflowing_mul(input1).0,
BinaryOperator::Mul => input0.overflowing_mul(input1).0,
BinaryOperator::Shl => {
if input0 < U256::from(256usize) {
input1 << input0
} else {
U256::zero()
}
}
BinaryOperator::Sub => input0.overflowing_sub(input1).0,
BinaryOperator::Div | BinaryOperator::Shr => {
BinaryOperator::Div => {
if input1.is_zero() {
U256::zero()
} else {
input0 / input1
}
}
BinaryOperator::Shr => {
if input0 < U256::from(256usize) {
input1 >> input0
} else {
U256::zero()
}
}
BinaryOperator::Mod => {
if input1.is_zero() {
U256::zero()
@ -238,15 +253,25 @@ fn binary_op_to_rows<F: PrimeField64>(
addcy::generate(&mut row, op.row_filter(), input0, input1);
(row, None)
}
BinaryOperator::Mul | BinaryOperator::Shl => {
BinaryOperator::Mul => {
mul::generate(&mut row, input0, input1);
(row, None)
}
BinaryOperator::Div | BinaryOperator::Mod | BinaryOperator::Shr => {
BinaryOperator::Shl => {
let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
shift::generate(&mut row, &mut nv, true, input0, input1, result);
(row, None)
}
BinaryOperator::Div | BinaryOperator::Mod => {
let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
divmod::generate(&mut row, &mut nv, op.row_filter(), input0, input1, result);
(row, Some(nv))
}
BinaryOperator::Shr => {
let mut nv = vec![F::ZERO; columns::NUM_ARITH_COLUMNS];
shift::generate(&mut row, &mut nv, false, input0, input1, result);
(row, Some(nv))
}
BinaryOperator::AddFp254 | BinaryOperator::MulFp254 | BinaryOperator::SubFp254 => {
ternary_op_to_rows::<F>(op.row_filter(), input0, input1, BN_BASE, result)
}

View File

@ -239,7 +239,7 @@ pub(crate) fn generate_modular_op<F: PrimeField64>(
let mut mod_is_zero = F::ZERO;
if modulus.is_zero() {
if filter == columns::IS_DIV {
if filter == columns::IS_DIV || filter == columns::IS_SHR {
// set modulus = 2^256; the condition above means we know
// it's zero at this point, so we can just set bit 256.
modulus.set_bit(256, true);
@ -330,7 +330,7 @@ pub(crate) fn generate_modular_op<F: PrimeField64>(
nv[MODULAR_MOD_IS_ZERO] = mod_is_zero;
nv[MODULAR_OUT_AUX_RED].copy_from_slice(&out_aux_red.map(F::from_canonical_i64));
nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * lv[IS_DIV];
nv[MODULAR_DIV_DENOM_IS_ZERO] = mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]);
(
output_limbs.map(F::from_canonical_i64),
@ -392,14 +392,14 @@ pub(crate) fn check_reduced<P: PackedField>(
// Verify that the output is reduced, i.e. output < modulus.
let out_aux_red = &nv[MODULAR_OUT_AUX_RED];
// This sets is_less_than to 1 unless we get mod_is_zero when
// doing a DIV; in that case, we need is_less_than=0, since
// doing a DIV or SHR; in that case, we need is_less_than=0, since
// eval_packed_generic_addcy checks
//
// modulus + out_aux_red == output + is_less_than*2^256
//
// and we are given output = out_aux_red when modulus is zero.
let mut is_less_than = [P::ZEROS; N_LIMBS];
is_less_than[0] = P::ONES - mod_is_zero * lv[IS_DIV];
is_less_than[0] = P::ONES - mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]);
// NB: output and modulus in lv while out_aux_red and
// is_less_than (via mod_is_zero) depend on nv, hence the
// 'is_two_row_op' argument is set to 'true'.
@ -448,13 +448,15 @@ pub(crate) fn modular_constr_poly<P: PackedField>(
// modulus = 0.
modulus[0] += mod_is_zero;
// Is 1 iff the operation is DIV and the denominator is zero.
// Is 1 iff the operation is DIV or SHR and the denominator is zero.
let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO];
yield_constr.constraint_transition(filter * (mod_is_zero * lv[IS_DIV] - div_denom_is_zero));
yield_constr.constraint_transition(
filter * (mod_is_zero * (lv[IS_DIV] + lv[IS_SHR]) - div_denom_is_zero),
);
// Needed to compensate for adding mod_is_zero to modulus above,
// since the call eval_packed_generic_addcy() below subtracts modulus
// to verify in the case of a DIV.
// to verify in the case of a DIV or SHR.
output[0] += div_denom_is_zero;
check_reduced(lv, nv, yield_constr, filter, output, modulus, mod_is_zero);
@ -635,7 +637,8 @@ pub(crate) fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, cons
modulus[0] = builder.add_extension(modulus[0], mod_is_zero);
let div_denom_is_zero = nv[MODULAR_DIV_DENOM_IS_ZERO];
let t = builder.mul_sub_extension(mod_is_zero, lv[IS_DIV], div_denom_is_zero);
let div_shr_filter = builder.add_extension(lv[IS_DIV], lv[IS_SHR]);
let t = builder.mul_sub_extension(mod_is_zero, div_shr_filter, div_denom_is_zero);
let t = builder.mul_extension(filter, t);
yield_constr.constraint_transition(builder, t);
output[0] = builder.add_extension(output[0], div_denom_is_zero);
@ -645,7 +648,7 @@ pub(crate) fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, cons
let zero = builder.zero_extension();
let mut is_less_than = [zero; N_LIMBS];
is_less_than[0] =
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one);
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, div_shr_filter, one);
eval_ext_circuit_addcy(
builder,
@ -834,6 +837,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_SHR] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;
@ -867,6 +871,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_SHR] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[op_filter] = F::ONE;
@ -926,6 +931,7 @@ mod tests {
for op in MODULAR_OPS {
lv[op] = F::ZERO;
}
lv[IS_SHR] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[IS_MOD] = F::ZERO;
lv[op_filter] = F::ONE;

View File

@ -67,16 +67,8 @@ use crate::arithmetic::columns::*;
use crate::arithmetic::utils::*;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
pub fn generate<F: PrimeField64>(lv: &mut [F], left_in: U256, right_in: U256) {
// TODO: It would probably be clearer/cleaner to read the U256
// into an [i64;N] and then copy that to the lv table.
u256_to_array(&mut lv[INPUT_REGISTER_0], left_in);
u256_to_array(&mut lv[INPUT_REGISTER_1], right_in);
u256_to_array(&mut lv[INPUT_REGISTER_2], U256::zero());
let input0 = read_value_i64_limbs(lv, INPUT_REGISTER_0);
let input1 = read_value_i64_limbs(lv, INPUT_REGISTER_1);
/// Given the limbs of `left_in` and `right_in`, computes `left_in * right_in`.
pub(crate) fn generate_mul<F: PrimeField64>(lv: &mut [F], left_in: [i64; 16], right_in: [i64; 16]) {
const MASK: i64 = (1i64 << LIMB_BITS) - 1i64;
// Input and output have 16-bit limbs
@ -86,7 +78,7 @@ pub fn generate<F: PrimeField64>(lv: &mut [F], left_in: U256, right_in: U256) {
// First calculate the coefficients of a(x)*b(x) (in unreduced_prod),
// then do carry propagation to obtain C = c(β) = a(β)*b(β).
let mut cy = 0i64;
let mut unreduced_prod = pol_mul_lo(input0, input1);
let mut unreduced_prod = pol_mul_lo(left_in, right_in);
for col in 0..N_LIMBS {
let t = unreduced_prod[col] + cy;
cy = t >> LIMB_BITS;
@ -115,17 +107,30 @@ pub fn generate<F: PrimeField64>(lv: &mut [F], left_in: U256, right_in: U256) {
.copy_from_slice(&aux_limbs.map(|c| F::from_canonical_u16((c >> 16) as u16)));
}
pub fn eval_packed_generic<P: PackedField>(
/// Writes the operands of a multiplication into `lv` and generates the product.
///
/// `left_in` and `right_in` are stored in `INPUT_REGISTER_0` and
/// `INPUT_REGISTER_1`, `INPUT_REGISTER_2` is zeroed, and the limbs are then
/// handed to `generate_mul`.
pub fn generate<F: PrimeField64>(lv: &mut [F], left_in: U256, right_in: U256) {
    // TODO: It would probably be clearer/cleaner to read the U256
    // into an [i64;N] and then copy that to the lv table.
    let register_values = [
        (INPUT_REGISTER_0, left_in),
        (INPUT_REGISTER_1, right_in),
        (INPUT_REGISTER_2, U256::zero()),
    ];
    for (register, value) in register_values {
        u256_to_array(&mut lv[register], value);
    }

    let left_limbs = read_value_i64_limbs(lv, INPUT_REGISTER_0);
    let right_limbs = read_value_i64_limbs(lv, INPUT_REGISTER_1);
    generate_mul(lv, left_limbs, right_limbs);
}
pub(crate) fn eval_packed_generic_mul<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS],
filter: P,
left_in_limbs: [P; 16],
right_in_limbs: [P; 16],
yield_constr: &mut ConstraintConsumer<P>,
) {
let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS);
let is_mul = lv[IS_MUL] + lv[IS_SHL];
let input0_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_0);
let input1_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_1);
let output_limbs = read_value::<N_LIMBS, _>(lv, OUTPUT_REGISTER);
let base = P::Scalar::from_canonical_u64(1 << LIMB_BITS);
let aux_limbs = {
// MUL_AUX_INPUT was offset by 2^20 in generation, so we undo
// that here
@ -153,7 +158,7 @@ pub fn eval_packed_generic<P: PackedField>(
//
// s(x) = \sum_i aux_limbs[i] * x^i
//
let mut constr_poly = pol_mul_lo(input0_limbs, input1_limbs);
let mut constr_poly = pol_mul_lo(left_in_limbs, right_in_limbs);
pol_sub_assign(&mut constr_poly, &output_limbs);
// This subtracts (x - β) * s(x) from constr_poly.
@ -164,18 +169,29 @@ pub fn eval_packed_generic<P: PackedField>(
// multiplication is valid if and only if all of those
// coefficients are zero.
for &c in &constr_poly {
yield_constr.constraint(is_mul * c);
yield_constr.constraint(filter * c);
}
}
pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
pub fn eval_packed_generic<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
) {
let is_mul = builder.add_extension(lv[IS_MUL], lv[IS_SHL]);
let is_mul = lv[IS_MUL];
let input0_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_0);
let input1_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_1);
eval_packed_generic_mul(lv, is_mul, input0_limbs, input1_limbs, yield_constr);
}
pub(crate) fn eval_ext_mul_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
filter: ExtensionTarget<D>,
left_in_limbs: [ExtensionTarget<D>; 16],
right_in_limbs: [ExtensionTarget<D>; 16],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
let output_limbs = read_value::<N_LIMBS, _>(lv, OUTPUT_REGISTER);
let aux_limbs = {
@ -192,7 +208,7 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
aux_limbs
};
let mut constr_poly = pol_mul_lo_ext_circuit(builder, input0_limbs, input1_limbs);
let mut constr_poly = pol_mul_lo_ext_circuit(builder, left_in_limbs, right_in_limbs);
pol_sub_assign_ext_circuit(builder, &mut constr_poly, &output_limbs);
let base = builder.constant_extension(F::Extension::from_canonical_u64(1 << LIMB_BITS));
@ -200,11 +216,30 @@ pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
pol_sub_assign_ext_circuit(builder, &mut constr_poly, &rhs);
for &c in &constr_poly {
let filter = builder.mul_extension(is_mul, c);
let filter = builder.mul_extension(filter, c);
yield_constr.constraint(builder, filter);
}
}
/// Adds the MUL constraints to the recursive circuit.
///
/// The factors are read from `INPUT_REGISTER_0` and `INPUT_REGISTER_1`, and
/// the constraints are filtered by `lv[IS_MUL]`.
pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
    builder: &mut CircuitBuilder<F, D>,
    lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
    let left_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_0);
    let right_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_1);

    eval_ext_mul_circuit(builder, lv, lv[IS_MUL], left_limbs, right_limbs, yield_constr);
}
#[cfg(test)]
mod tests {
use plonky2::field::goldilocks_field::GoldilocksField;
@ -229,8 +264,6 @@ mod tests {
// if `IS_MUL == 0`, then the constraints should be met even
// if all values are garbage.
lv[IS_MUL] = F::ZERO;
// Deactivate the SHL flag so that a MUL operation is not triggered.
lv[IS_SHL] = F::ZERO;
let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],

338
evm/src/arithmetic/shift.rs Normal file
View File

@ -0,0 +1,338 @@
//! Support for the EVM SHL and SHR instructions.
//!
//! This module verifies an EVM shift instruction, which takes two
//! 256-bit inputs S and A, and produces a 256-bit output C satisfying
//!
//! C = A << S (mod 2^256) for SHL or
//! C = A >> S (mod 2^256) for SHR.
//!
//! The computation is carried out by providing a third input
//! B = 1 << S (mod 2^256)
//! and then computing:
//! C = A * B (mod 2^256) for SHL or
//! C = A / B (mod 2^256) for SHR
//!
//! Inputs A, S, and B, and output C, are given as arrays of 16-bit
//! limbs. For example, if the limbs of A are a[0]...a[15], then
//!
//! A = \sum_{i=0}^15 a[i] β^i,
//!
//! where β = 2^16 = 2^LIMB_BITS. To verify that A, S, B and C satisfy
//! the equations, we proceed similarly to MUL for SHL and to DIV for SHR.
use ethereum_types::U256;
use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField;
use plonky2::field::types::PrimeField64;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use super::{divmod, mul};
use crate::arithmetic::columns::*;
use crate::arithmetic::utils::*;
use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
/// Generates a shift operation (either SHL or SHR).
/// The inputs are stored in the form `(shift, input, 1 << shift)`.
/// NB: if `shift >= 256`, then the third register holds 0.
/// We leverage the functions in mul.rs and divmod.rs to carry out
/// the computation.
pub fn generate<F: PrimeField64>(
lv: &mut [F],
nv: &mut [F],
is_shl: bool,
shift: U256,
input: U256,
result: U256,
) {
// We use the multiplication logic to generate SHL
// TODO: It would probably be clearer/cleaner to read the U256
// into an [i64;N] and then copy that to the lv table.
// The first input is the shift we need to apply.
u256_to_array(&mut lv[INPUT_REGISTER_0], shift);
// The second register holds the input which needs shifting.
u256_to_array(&mut lv[INPUT_REGISTER_1], input);
u256_to_array(&mut lv[OUTPUT_REGISTER], result);
// If `shift >= 256`, the shifted displacement is set to 0.
// Compute 1 << shift and store it in the third input register.
let shifted_displacement = if shift > U256::from(255u64) {
U256::zero()
} else {
U256::one() << shift
};
u256_to_array(&mut lv[INPUT_REGISTER_2], shifted_displacement);
let input0 = read_value_i64_limbs(lv, INPUT_REGISTER_1); // input
let input1 = read_value_i64_limbs(lv, INPUT_REGISTER_2); // 1 << shift
if is_shl {
// We generate the multiplication input0 * input1 using mul.rs.
mul::generate_mul(lv, input0, input1);
} else {
// If the operation is SHR, we compute: `input / shifted_displacement` if `shifted_displacement == 0`
// otherwise, the output is 0. We use the logic in divmod.rs to achieve that.
divmod::generate_divmod(lv, nv, IS_SHR, INPUT_REGISTER_1, INPUT_REGISTER_2);
}
}
/// Constrains an SHL row by reusing the multiplication constraints of mul.rs.
///
/// The factors are read from `INPUT_REGISTER_1` (the value being shifted) and
/// `INPUT_REGISTER_2` (`1 << shift`) rather than from `INPUT_REGISTER_0` and
/// `INPUT_REGISTER_1`, and the constraints are filtered by `lv[IS_SHL]`.
fn eval_packed_shl<P: PackedField>(
    lv: &[P; NUM_ARITH_COLUMNS],
    yield_constr: &mut ConstraintConsumer<P>,
) {
    let value_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_1);
    let displacement_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_2);

    mul::eval_packed_generic_mul(
        lv,
        lv[IS_SHL],
        value_limbs,
        displacement_limbs,
        yield_constr,
    );
}
/// Constrains an SHR row by reusing the division constraints of divmod.rs.
///
/// The logic is the same as the one for DIV, except that the numerator and
/// denominator are read from `INPUT_REGISTER_1` and `INPUT_REGISTER_2` instead
/// of `INPUT_REGISTER_0` and `INPUT_REGISTER_1`. The constraints are filtered
/// by `lv[IS_SHR]`.
fn eval_packed_shr<P: PackedField>(
    lv: &[P; NUM_ARITH_COLUMNS],
    nv: &[P; NUM_ARITH_COLUMNS],
    yield_constr: &mut ConstraintConsumer<P>,
) {
    divmod::eval_packed_divmod_helper(
        lv,
        nv,
        yield_constr,
        lv[IS_SHR],
        INPUT_REGISTER_1,     // numerator: the value being shifted
        INPUT_REGISTER_2,     // denominator: 1 << shift
        OUTPUT_REGISTER,      // quotient
        AUX_INPUT_REGISTER_0, // remainder
    );
}
/// Evaluates the constraints of both shift opcodes on the current row.
///
/// The SHL constraints are filtered by `lv[IS_SHL]` and the SHR constraints
/// by `lv[IS_SHR]`. Only the SHR constraints are given the next row `nv`.
pub fn eval_packed_generic<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS],
nv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>,
) {
eval_packed_shl(lv, yield_constr);
eval_packed_shr(lv, nv, yield_constr);
}
/// Recursive-circuit version of the SHL constraints.
///
/// Reuses the multiplication circuit of mul.rs with the factors taken from
/// `INPUT_REGISTER_1` and `INPUT_REGISTER_2`, filtered by `lv[IS_SHL]`.
fn eval_ext_circuit_shl<F: RichField + Extendable<D>, const D: usize>(
    builder: &mut CircuitBuilder<F, D>,
    lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
    let value_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_1);
    let displacement_limbs = read_value::<N_LIMBS, _>(lv, INPUT_REGISTER_2);

    mul::eval_ext_mul_circuit(
        builder,
        lv,
        lv[IS_SHL],
        value_limbs,
        displacement_limbs,
        yield_constr,
    );
}
/// Recursive-circuit version of the SHR constraints.
///
/// Reuses the division circuit of divmod.rs with the numerator and denominator
/// taken from `INPUT_REGISTER_1` and `INPUT_REGISTER_2`, filtered by
/// `lv[IS_SHR]`.
fn eval_ext_circuit_shr<F: RichField + Extendable<D>, const D: usize>(
    builder: &mut CircuitBuilder<F, D>,
    lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
    nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
    divmod::eval_ext_circuit_divmod_helper(
        builder,
        lv,
        nv,
        yield_constr,
        lv[IS_SHR],
        INPUT_REGISTER_1,     // numerator: the value being shifted
        INPUT_REGISTER_2,     // denominator: 1 << shift
        OUTPUT_REGISTER,      // quotient
        AUX_INPUT_REGISTER_0, // remainder
    );
}
/// Adds the constraints of both shift opcodes to the recursive circuit.
///
/// The SHL constraints are filtered by `lv[IS_SHL]` and the SHR constraints
/// by `lv[IS_SHR]`. Only the SHR constraints are given the next row `nv`.
pub fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
lv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
nv: &[ExtensionTarget<D>; NUM_ARITH_COLUMNS],
yield_constr: &mut RecursiveConstraintConsumer<F, D>,
) {
eval_ext_circuit_shl(builder, lv, yield_constr);
eval_ext_circuit_shr(builder, lv, nv, yield_constr);
}
#[cfg(test)]
mod tests {
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::{Field, Sample};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use super::*;
use crate::arithmetic::columns::NUM_ARITH_COLUMNS;
use crate::constraint_consumer::ConstraintConsumer;
// Number of random (shift, input) pairs checked by each randomized test.
const N_RND_TESTS: usize = 1000;
// TODO: Should be able to refactor this test to apply to all operations.
// Checks that rows filled with random values satisfy the shift constraints
// as long as both shift flags are cleared.
#[test]
fn generate_eval_consistency_not_shift() {
type F = GoldilocksField;
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
let nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
// if `IS_SHL == 0` and `IS_SHR == 0`, then the constraints should be met even
// if all values are garbage.
lv[IS_SHL] = F::ZERO;
lv[IS_SHR] = F::ZERO;
let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
GoldilocksField::ONE,
GoldilocksField::ONE,
GoldilocksField::ONE,
);
eval_packed_generic(&lv, &nv, &mut constraint_consumer);
for &acc in &constraint_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO);
}
}
// Shared body of the SHL/SHR consistency tests for shifts below 256: rows
// produced by `generate` must satisfy every constraint of `eval_packed_generic`.
fn generate_eval_consistency_shift(is_shl: bool) {
type F = GoldilocksField;
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
// set `IS_SHL == 1` or `IS_SHR == 1` and ensure all constraints are satisfied.
if is_shl {
lv[IS_SHL] = F::ONE;
lv[IS_SHR] = F::ZERO;
} else {
// Set `IS_DIV` to 0 in this case, since we're using the logic of DIV for SHR.
lv[IS_DIV] = F::ZERO;
lv[IS_SHL] = F::ZERO;
lv[IS_SHR] = F::ONE;
}
for _i in 0..N_RND_TESTS {
// A `u8` keeps the shift in 0..=255.
let shift = U256::from(rng.gen::<u8>());
let mut full_input = U256::from(0);
// set inputs to random values
// `generate` below rewrites `INPUT_REGISTER_1` from `full_input`, so the
// row stays consistent whatever limb order is used to build `full_input`.
for ai in INPUT_REGISTER_1 {
lv[ai] = F::from_canonical_u16(rng.gen());
full_input =
U256::from(lv[ai].to_canonical_u64()) + full_input * U256::from(1 << 16);
}
// Reference result computed directly on the `U256` value.
let output = if is_shl {
full_input << shift
} else {
full_input >> shift
};
generate(&mut lv, &mut nv, is_shl, shift, full_input, output);
// NOTE(review): the last argument is ZERO here but ONE in
// `generate_eval_consistency_not_shift` -- presumably the last-row
// selector; confirm against `ConstraintConsumer::new`.
let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
GoldilocksField::ONE,
GoldilocksField::ONE,
GoldilocksField::ZERO,
);
eval_packed_generic(&lv, &nv, &mut constraint_consumer);
for &acc in &constraint_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO);
}
}
}
#[test]
fn generate_eval_consistency_shl() {
generate_eval_consistency_shift(true);
}
#[test]
fn generate_eval_consistency_shr() {
generate_eval_consistency_shift(false);
}
// Shared body of the SHL/SHR consistency tests for shifts of at least 256,
// for which the expected output is 0.
fn generate_eval_consistency_shift_over_256(is_shl: bool) {
type F = GoldilocksField;
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
let mut nv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::sample(&mut rng));
// set `IS_SHL == 1` or `IS_SHR == 1` and ensure all constraints are satisfied.
if is_shl {
lv[IS_SHL] = F::ONE;
lv[IS_SHR] = F::ZERO;
} else {
// Set `IS_DIV` to 0 in this case, since we're using the logic of DIV for SHR.
lv[IS_DIV] = F::ZERO;
lv[IS_SHL] = F::ZERO;
lv[IS_SHR] = F::ONE;
}
for _i in 0..N_RND_TESTS {
let mut shift = U256::from(rng.gen::<usize>());
// Guard against overflow of the `+= 256` below. A value converted from a
// `usize` cannot exceed this bound, so the loop body is not expected to run.
while shift > U256::MAX - 256 {
shift = U256::from(rng.gen::<usize>());
}
// Move the shift into the `>= 256` range.
shift += U256::from(256);
let mut full_input = U256::from(0);
// set inputs to random values
// `generate` below rewrites `INPUT_REGISTER_1` from `full_input`, so the
// row stays consistent whatever limb order is used to build `full_input`.
for ai in INPUT_REGISTER_1 {
lv[ai] = F::from_canonical_u16(rng.gen());
full_input =
U256::from(lv[ai].to_canonical_u64()) + full_input * U256::from(1 << 16);
}
// Shifting by 256 bits or more always yields 0.
let output = 0.into();
generate(&mut lv, &mut nv, is_shl, shift, full_input, output);
let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
GoldilocksField::ONE,
GoldilocksField::ONE,
GoldilocksField::ZERO,
);
eval_packed_generic(&lv, &nv, &mut constraint_consumer);
for &acc in &constraint_consumer.constraint_accs {
assert_eq!(acc, GoldilocksField::ZERO);
}
}
}
#[test]
fn generate_eval_consistency_shl_over_256() {
generate_eval_consistency_shift_over_256(true);
}
#[test]
fn generate_eval_consistency_shr_over_256() {
generate_eval_consistency_shift_over_256(false);
}
}

View File

@ -63,19 +63,10 @@ fn ctl_data_binops<F: Field>() -> Vec<Column<F>> {
/// one output of a ternary operation. By default, ternary operations use
/// the first three memory channels, and the last one for the result (binary
/// operations do not use the third inputs).
///
/// Shift operations are different, as they are simulated with `MUL` or `DIV`
/// on the arithmetic side. We first convert the shift into the multiplicand
/// (in case of `SHL`) or the divisor (in case of `SHR`), making the first memory
/// channel not directly usable. We overcome this by adding an offset of 1 in
/// case of shift operations, which will skip the first memory channel and use the
/// next three as ternary inputs. Because both `MUL` and `DIV` are binary operations,
/// the last memory channel used for the inputs will be safely ignored.
fn ctl_data_ternops<F: Field>(is_shift: bool) -> Vec<Column<F>> {
let offset = is_shift as usize;
let mut res = Column::singles(COL_MAP.mem_channels[offset].value).collect_vec();
res.extend(Column::singles(COL_MAP.mem_channels[offset + 1].value));
res.extend(Column::singles(COL_MAP.mem_channels[offset + 2].value));
fn ctl_data_ternops<F: Field>() -> Vec<Column<F>> {
let mut res = Column::singles(COL_MAP.mem_channels[0].value).collect_vec();
res.extend(Column::singles(COL_MAP.mem_channels[1].value));
res.extend(Column::singles(COL_MAP.mem_channels[2].value));
res.extend(Column::singles(
COL_MAP.mem_channels[NUM_GP_CHANNELS - 1].value,
));
@ -96,7 +87,7 @@ pub fn ctl_filter_logic<F: Field>() -> Column<F> {
pub fn ctl_arithmetic_base_rows<F: Field>() -> TableWithColumns<F> {
// Instead of taking single columns, we reconstruct the entire opcode value directly.
let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)];
columns.extend(ctl_data_ternops(false));
columns.extend(ctl_data_ternops());
// Create the CPU Table whose columns are those with the three
// inputs and one output of the ternary operations listed in `ops`
// (also `ops` is used as the operation filter). The list of
@ -109,22 +100,11 @@ pub fn ctl_arithmetic_base_rows<F: Field>() -> TableWithColumns<F> {
COL_MAP.op.binary_op,
COL_MAP.op.fp254_op,
COL_MAP.op.ternary_op,
COL_MAP.op.shift,
])),
)
}
pub fn ctl_arithmetic_shift_rows<F: Field>() -> TableWithColumns<F> {
// Instead of taking single columns, we reconstruct the entire opcode value directly.
let mut columns = vec![Column::le_bits(COL_MAP.opcode_bits)];
columns.extend(ctl_data_ternops(true));
// Create the CPU Table whose columns are those with the three
// inputs and one output of the ternary operations listed in `ops`
// (also `ops` is used as the operation filter). The list of
// operations includes binary operations which will simply ignore
// the third input.
TableWithColumns::new(Table::Cpu, columns, Some(Column::single(COL_MAP.op.shift)))
}
pub fn ctl_data_byte_packing<F: Field>() -> Vec<Column<F>> {
ctl_data_keccak_sponge()
}

View File

@ -499,18 +499,12 @@ fn append_shift<F: Field>(
channel.addr_virtual = F::from_canonical_usize(lookup_addr.virt);
}
// Convert the shift, and log the corresponding arithmetic operation.
let input0 = if input0 > U256::from(255u64) {
U256::zero()
} else {
U256::one() << input0
};
let operator = if is_shl {
BinaryOperator::Shl
} else {
BinaryOperator::Shr
};
let operation = arithmetic::Operation::binary(operator, input1, input0);
let operation = arithmetic::Operation::binary(operator, input0, input1);
state.traces.push_arithmetic(operation);
state.traces.push_memory(log_in0);