Implement DIV instruction (#790)

* Implement DIV instruction.

* cargo fmt, clippy, minor doc update.

* Add implementation of circuit version.
This commit is contained in:
Hamish Ivey-Law 2022-10-21 16:25:38 +11:00 committed by GitHub
parent f55e07659c
commit 4af2ede6e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 28 deletions

View File

@ -59,6 +59,8 @@ impl<F: RichField, const D: usize> ArithmeticStark<F, D> {
modular::generate(local_values, columns::IS_MULMOD); modular::generate(local_values, columns::IS_MULMOD);
} else if local_values[columns::IS_MOD].is_one() { } else if local_values[columns::IS_MOD].is_one() {
modular::generate(local_values, columns::IS_MOD); modular::generate(local_values, columns::IS_MOD);
} else if local_values[columns::IS_DIV].is_one() {
modular::generate(local_values, columns::IS_DIV);
} else { } else {
todo!("the requested operation has not yet been implemented"); todo!("the requested operation has not yet been implemented");
} }

View File

@ -85,4 +85,11 @@ pub(crate) const MODULAR_AUX_INPUT: Range<usize> = AUX_INPUT_1;
pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_INPUT_1.end - 1; pub(crate) const MODULAR_MOD_IS_ZERO: usize = AUX_INPUT_1.end - 1;
pub(crate) const MODULAR_OUT_AUX_RED: Range<usize> = AUX_INPUT_2; pub(crate) const MODULAR_OUT_AUX_RED: Range<usize> = AUX_INPUT_2;
#[allow(unused)] // TODO: Will be used when hooking into the CPU
pub(crate) const DIV_NUMERATOR: Range<usize> = MODULAR_INPUT_0;
#[allow(unused)] // TODO: Will be used when hooking into the CPU
pub(crate) const DIV_DENOMINATOR: Range<usize> = MODULAR_MODULUS;
#[allow(unused)] // TODO: Will be used when hooking into the CPU
pub(crate) const DIV_OUTPUT: Range<usize> = MODULAR_QUO_INPUT.start..MODULAR_QUO_INPUT.start + 16;
pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS; pub const NUM_ARITH_COLUMNS: usize = START_SHARED_COLS + NUM_SHARED_COLS;

View File

@ -1,4 +1,5 @@
//! Support for the EVM modular instructions ADDMOD, MULMOD and MOD. //! Support for the EVM modular instructions ADDMOD, MULMOD and MOD,
//! as well as DIV.
//! //!
//! This crate verifies an EVM modular instruction, which takes three //! This crate verifies an EVM modular instruction, which takes three
//! 256-bit inputs A, B and M, and produces a 256-bit output C satisfying //! 256-bit inputs A, B and M, and produces a 256-bit output C satisfying
@ -82,8 +83,11 @@
//! - if modulus is non-zero, correct output is obtained //! - if modulus is non-zero, correct output is obtained
//! - if modulus is 0, then the test output < modulus, checking that //! - if modulus is 0, then the test output < modulus, checking that
//! the output is reduced, will fail, because output is non-negative. //! the output is reduced, will fail, because output is non-negative.
//!
//! In the case of DIV, we do something similar, except that we "replace"
//! the modulus with "2^256" to force the quotient to be zero.
use num::{bigint::Sign, BigInt, Zero}; use num::{bigint::Sign, BigInt, One, Zero};
use plonky2::field::extension::Extendable; use plonky2::field::extension::Extendable;
use plonky2::field::packed::PackedField; use plonky2::field::packed::PackedField;
use plonky2::field::types::Field; use plonky2::field::types::Field;
@ -166,6 +170,7 @@ fn bigint_to_columns<const N: usize>(num: &BigInt) -> [i64; N] {
/// zero if they are not used. /// zero if they are not used.
fn generate_modular_op<F: RichField>( fn generate_modular_op<F: RichField>(
lv: &mut [F; NUM_ARITH_COLUMNS], lv: &mut [F; NUM_ARITH_COLUMNS],
filter: usize,
operation: fn([i64; N_LIMBS], [i64; N_LIMBS]) -> [i64; 2 * N_LIMBS - 1], operation: fn([i64; N_LIMBS], [i64; N_LIMBS]) -> [i64; 2 * N_LIMBS - 1],
) { ) {
// Inputs are all range-checked in [0, 2^16), so the "as i64" // Inputs are all range-checked in [0, 2^16), so the "as i64"
@ -185,17 +190,31 @@ fn generate_modular_op<F: RichField>(
let mut constr_poly = [0i64; 2 * N_LIMBS]; let mut constr_poly = [0i64; 2 * N_LIMBS];
constr_poly[..2 * N_LIMBS - 1].copy_from_slice(&operation(input0_limbs, input1_limbs)); constr_poly[..2 * N_LIMBS - 1].copy_from_slice(&operation(input0_limbs, input1_limbs));
// two_exp_256 == 2^256
let two_exp_256 = {
let mut t = BigInt::zero();
t.set_bit(256, true);
t
};
let mut mod_is_zero = F::ZERO; let mut mod_is_zero = F::ZERO;
if modulus.is_zero() { if modulus.is_zero() {
modulus += 1u32; if filter == columns::IS_DIV {
modulus_limbs[0] += 1i64; // set modulus = 2^256
modulus = two_exp_256.clone();
// modulus_limbs don't play a role below
} else {
// set modulus = 1
modulus = BigInt::one();
modulus_limbs[0] = 1i64;
}
mod_is_zero = F::ONE; mod_is_zero = F::ONE;
} }
let input = columns_to_bigint(&constr_poly); let input = columns_to_bigint(&constr_poly);
// modulus != 0 here, because, if the given modulus was zero, then // modulus != 0 here, because, if the given modulus was zero, then
// we added 1 to it above. // it was set to 1 or 2^256 above
let mut output = &input % &modulus; let mut output = &input % &modulus;
// output will be -ve (but > -modulus) if input was -ve, so we can // output will be -ve (but > -modulus) if input was -ve, so we can
// add modulus to obtain a "canonical" +ve output. // add modulus to obtain a "canonical" +ve output.
@ -206,9 +225,6 @@ fn generate_modular_op<F: RichField>(
let quot = (&input - &output) / &modulus; // exact division; can be -ve let quot = (&input - &output) / &modulus; // exact division; can be -ve
let quot_limbs = bigint_to_columns::<{ 2 * N_LIMBS }>(&quot); let quot_limbs = bigint_to_columns::<{ 2 * N_LIMBS }>(&quot);
// two_exp_256 == 2^256
let mut two_exp_256 = BigInt::zero();
two_exp_256.set_bit(256, true);
// output < modulus here, so the proof requires (output - modulus) % 2^256: // output < modulus here, so the proof requires (output - modulus) % 2^256:
let out_aux_red = bigint_to_columns::<N_LIMBS>(&(two_exp_256 + output - modulus)); let out_aux_red = bigint_to_columns::<N_LIMBS>(&(two_exp_256 + output - modulus));
@ -240,10 +256,10 @@ fn generate_modular_op<F: RichField>(
/// `filter` must be one of `columns::IS_{ADDMOD,MULMOD,MOD}`. /// `filter` must be one of `columns::IS_{ADDMOD,MULMOD,MOD}`.
pub(crate) fn generate<F: RichField>(lv: &mut [F; NUM_ARITH_COLUMNS], filter: usize) { pub(crate) fn generate<F: RichField>(lv: &mut [F; NUM_ARITH_COLUMNS], filter: usize) {
match filter { match filter {
columns::IS_ADDMOD => generate_modular_op(lv, pol_add), columns::IS_ADDMOD => generate_modular_op(lv, filter, pol_add),
columns::IS_SUBMOD => generate_modular_op(lv, pol_sub), columns::IS_SUBMOD => generate_modular_op(lv, filter, pol_sub),
columns::IS_MULMOD => generate_modular_op(lv, pol_mul_wide), columns::IS_MULMOD => generate_modular_op(lv, filter, pol_mul_wide),
columns::IS_MOD => generate_modular_op(lv, |a, _| pol_extend(a)), columns::IS_MOD | columns::IS_DIV => generate_modular_op(lv, filter, |a, _| pol_extend(a)),
_ => panic!("generate modular operation called with unknown opcode"), _ => panic!("generate modular operation called with unknown opcode"),
} }
} }
@ -256,7 +272,6 @@ pub(crate) fn generate<F: RichField>(lv: &mut [F; NUM_ARITH_COLUMNS], filter: us
/// c(x) + q(x) * m(x) + (x - β) * s(x) /// c(x) + q(x) * m(x) + (x - β) * s(x)
/// ///
/// and check consistency when m = 0, and that c is reduced. /// and check consistency when m = 0, and that c is reduced.
#[allow(clippy::needless_range_loop)]
fn modular_constr_poly<P: PackedField>( fn modular_constr_poly<P: PackedField>(
lv: &[P; NUM_ARITH_COLUMNS], lv: &[P; NUM_ARITH_COLUMNS],
yield_constr: &mut ConstraintConsumer<P>, yield_constr: &mut ConstraintConsumer<P>,
@ -284,19 +299,33 @@ fn modular_constr_poly<P: PackedField>(
// modulus = 0. // modulus = 0.
modulus[0] += mod_is_zero; modulus[0] += mod_is_zero;
let output = &lv[MODULAR_OUTPUT]; let mut output = read_value::<N_LIMBS, _>(lv, MODULAR_OUTPUT);
// Needed to compensate for adding mod_is_zero to modulus above,
// since the call eval_packed_generic_lt() below subtracts modulus
// verify in the case of a DIV.
output[0] += mod_is_zero * lv[IS_DIV];
// Verify that the output is reduced, i.e. output < modulus. // Verify that the output is reduced, i.e. output < modulus.
let out_aux_red = &lv[MODULAR_OUT_AUX_RED]; let out_aux_red = &lv[MODULAR_OUT_AUX_RED];
let is_less_than = P::ONES; // this sets is_less_than to 1 unless we get mod_is_zero when
// doing a DIV; in that case, we need is_less_than=0, since the
// function checks
//
// output - modulus == out_aux_red + is_less_than*2^256
//
// and we were given output = out_aux_red
let is_less_than = P::ONES - mod_is_zero * lv[IS_DIV];
eval_packed_generic_lt( eval_packed_generic_lt(
yield_constr, yield_constr,
filter, filter,
output, &output,
&modulus, &modulus,
out_aux_red, out_aux_red,
is_less_than, is_less_than,
); );
// restore output[0]
output[0] -= mod_is_zero * lv[IS_DIV];
// prod = q(x) * m(x) // prod = q(x) * m(x)
let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT);
@ -308,7 +337,7 @@ fn modular_constr_poly<P: PackedField>(
// constr_poly = c(x) + q(x) * m(x) // constr_poly = c(x) + q(x) * m(x)
let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap();
pol_add_assign(&mut constr_poly, output); pol_add_assign(&mut constr_poly, &output);
// constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x) // constr_poly = c(x) + q(x) * m(x) + (x - β) * s(x)
let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT); let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT);
@ -329,7 +358,8 @@ pub(crate) fn eval_packed_generic<P: PackedField>(
let filter = lv[columns::IS_ADDMOD] let filter = lv[columns::IS_ADDMOD]
+ lv[columns::IS_MULMOD] + lv[columns::IS_MULMOD]
+ lv[columns::IS_MOD] + lv[columns::IS_MOD]
+ lv[columns::IS_SUBMOD]; + lv[columns::IS_SUBMOD]
+ lv[columns::IS_DIV];
// constr_poly has 2*N_LIMBS limbs // constr_poly has 2*N_LIMBS limbs
let constr_poly = modular_constr_poly(lv, yield_constr, filter); let constr_poly = modular_constr_poly(lv, yield_constr, filter);
@ -346,7 +376,7 @@ pub(crate) fn eval_packed_generic<P: PackedField>(
(&add_input, &lv[columns::IS_ADDMOD]), (&add_input, &lv[columns::IS_ADDMOD]),
(&sub_input, &lv[columns::IS_SUBMOD]), (&sub_input, &lv[columns::IS_SUBMOD]),
(&mul_input, &lv[columns::IS_MULMOD]), (&mul_input, &lv[columns::IS_MULMOD]),
(&mod_input, &lv[columns::IS_MOD]), (&mod_input, &(lv[columns::IS_MOD] + lv[columns::IS_DIV])),
] { ] {
// Need constr_poly_copy to be the first argument to // Need constr_poly_copy to be the first argument to
// pol_sub_assign, since it is the longer of the two // pol_sub_assign, since it is the longer of the two
@ -388,18 +418,25 @@ fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, const D: usize>
modulus[0] = builder.add_extension(modulus[0], mod_is_zero); modulus[0] = builder.add_extension(modulus[0], mod_is_zero);
let output = &lv[MODULAR_OUTPUT]; let mut output = read_value::<N_LIMBS, _>(lv, MODULAR_OUTPUT);
output[0] = builder.mul_add_extension(mod_is_zero, lv[IS_DIV], output[0]);
let out_aux_red = &lv[MODULAR_OUT_AUX_RED]; let out_aux_red = &lv[MODULAR_OUT_AUX_RED];
let is_less_than = builder.one_extension(); let one = builder.one_extension();
let is_less_than =
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], one);
eval_ext_circuit_lt( eval_ext_circuit_lt(
builder, builder,
yield_constr, yield_constr,
filter, filter,
output, &output,
&modulus, &modulus,
out_aux_red, out_aux_red,
is_less_than, is_less_than,
); );
output[0] =
builder.arithmetic_extension(F::NEG_ONE, F::ONE, mod_is_zero, lv[IS_DIV], output[0]);
let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT); let quot = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_QUO_INPUT);
let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus); let prod = pol_mul_wide2_ext_circuit(builder, quot, modulus);
@ -409,7 +446,7 @@ fn modular_constr_poly_ext_circuit<F: RichField + Extendable<D>, const D: usize>
} }
let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap(); let mut constr_poly: [_; 2 * N_LIMBS] = prod[0..2 * N_LIMBS].try_into().unwrap();
pol_add_assign_ext_circuit(builder, &mut constr_poly, output); pol_add_assign_ext_circuit(builder, &mut constr_poly, &output);
let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT); let mut aux = read_value::<{ 2 * N_LIMBS }, _>(lv, MODULAR_AUX_INPUT);
aux[2 * N_LIMBS - 1] = builder.zero_extension(); aux[2 * N_LIMBS - 1] = builder.zero_extension();
@ -430,6 +467,7 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
lv[columns::IS_SUBMOD], lv[columns::IS_SUBMOD],
lv[columns::IS_MULMOD], lv[columns::IS_MULMOD],
lv[columns::IS_MOD], lv[columns::IS_MOD],
lv[columns::IS_DIV],
]); ]);
let constr_poly = modular_constr_poly_ext_circuit(lv, builder, yield_constr, filter); let constr_poly = modular_constr_poly_ext_circuit(lv, builder, yield_constr, filter);
@ -442,11 +480,12 @@ pub(crate) fn eval_ext_circuit<F: RichField + Extendable<D>, const D: usize>(
let mul_input = pol_mul_wide_ext_circuit(builder, input0, input1); let mul_input = pol_mul_wide_ext_circuit(builder, input0, input1);
let mod_input = pol_extend_ext_circuit(builder, input0); let mod_input = pol_extend_ext_circuit(builder, input0);
let mod_div_filter = builder.add_extension(lv[columns::IS_MOD], lv[columns::IS_DIV]);
for (input, &filter) in [ for (input, &filter) in [
(&add_input, &lv[columns::IS_ADDMOD]), (&add_input, &lv[columns::IS_ADDMOD]),
(&sub_input, &lv[columns::IS_SUBMOD]), (&sub_input, &lv[columns::IS_SUBMOD]),
(&mul_input, &lv[columns::IS_MULMOD]), (&mul_input, &lv[columns::IS_MULMOD]),
(&mod_input, &lv[columns::IS_MOD]), (&mod_input, &mod_div_filter),
] { ] {
let mut constr_poly_copy = constr_poly; let mut constr_poly_copy = constr_poly;
pol_sub_assign_ext_circuit(builder, &mut constr_poly_copy, input); pol_sub_assign_ext_circuit(builder, &mut constr_poly_copy, input);
@ -485,6 +524,7 @@ mod tests {
lv[IS_SUBMOD] = F::ZERO; lv[IS_SUBMOD] = F::ZERO;
lv[IS_MULMOD] = F::ZERO; lv[IS_MULMOD] = F::ZERO;
lv[IS_MOD] = F::ZERO; lv[IS_MOD] = F::ZERO;
lv[IS_DIV] = F::ZERO;
let mut constraint_consumer = ConstraintConsumer::new( let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
@ -505,12 +545,13 @@ mod tests {
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::rand_from_rng(&mut rng)); let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::rand_from_rng(&mut rng));
for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_MOD, IS_MULMOD] { for op_filter in [IS_ADDMOD, IS_DIV, IS_SUBMOD, IS_MOD, IS_MULMOD] {
// Reset operation columns, then select one // Reset operation columns, then select one
lv[IS_ADDMOD] = F::ZERO; lv[IS_ADDMOD] = F::ZERO;
lv[IS_SUBMOD] = F::ZERO; lv[IS_SUBMOD] = F::ZERO;
lv[IS_MULMOD] = F::ZERO; lv[IS_MULMOD] = F::ZERO;
lv[IS_MOD] = F::ZERO; lv[IS_MOD] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[op_filter] = F::ONE; lv[op_filter] = F::ONE;
for i in 0..N_RND_TESTS { for i in 0..N_RND_TESTS {
@ -555,12 +596,13 @@ mod tests {
let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25); let mut rng = ChaCha8Rng::seed_from_u64(0x6feb51b7ec230f25);
let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::rand_from_rng(&mut rng)); let mut lv = [F::default(); NUM_ARITH_COLUMNS].map(|_| F::rand_from_rng(&mut rng));
for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_MOD, IS_MULMOD] { for op_filter in [IS_ADDMOD, IS_SUBMOD, IS_DIV, IS_MOD, IS_MULMOD] {
// Reset operation columns, then select one // Reset operation columns, then select one
lv[IS_ADDMOD] = F::ZERO; lv[IS_ADDMOD] = F::ZERO;
lv[IS_SUBMOD] = F::ZERO; lv[IS_SUBMOD] = F::ZERO;
lv[IS_MULMOD] = F::ZERO; lv[IS_MULMOD] = F::ZERO;
lv[IS_MOD] = F::ZERO; lv[IS_MOD] = F::ZERO;
lv[IS_DIV] = F::ZERO;
lv[op_filter] = F::ONE; lv[op_filter] = F::ONE;
for _i in 0..N_RND_TESTS { for _i in 0..N_RND_TESTS {
@ -575,7 +617,11 @@ mod tests {
generate(&mut lv, op_filter); generate(&mut lv, op_filter);
// check that the correct output was generated // check that the correct output was generated
assert!(lv[MODULAR_OUTPUT].iter().all(|&c| c == F::ZERO)); if op_filter == IS_DIV {
assert!(lv[DIV_OUTPUT].iter().all(|&c| c == F::ZERO));
} else {
assert!(lv[MODULAR_OUTPUT].iter().all(|&c| c == F::ZERO));
}
let mut constraint_consumer = ConstraintConsumer::new( let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)], vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
@ -590,7 +636,11 @@ mod tests {
.all(|&acc| acc == F::ZERO)); .all(|&acc| acc == F::ZERO));
// Corrupt one output limb by setting it to a non-zero value // Corrupt one output limb by setting it to a non-zero value
let random_oi = MODULAR_OUTPUT.start + rng.gen::<usize>() % N_LIMBS; let random_oi = if op_filter == IS_DIV {
DIV_OUTPUT.start + rng.gen::<usize>() % N_LIMBS
} else {
MODULAR_OUTPUT.start + rng.gen::<usize>() % N_LIMBS
};
lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX)); lv[random_oi] = F::from_canonical_u16(rng.gen_range(1..u16::MAX));
eval_packed_generic(&lv, &mut constraint_consumer); eval_packed_generic(&lv, &mut constraint_consumer);