Implement a mul-add circuit in the ALU (#495)

* Implement a mul-add circuit in the ALU

The inputs are assumed to be `u32`s, while the output is encoded as four `u16` limbs. Each output limb is range-checked.

So, our basic mul-add constraint looks like

    out_0 + 2^16 out_1 + 2^32 out_2 + 2^48 out_3 = in_1 * in_2 + in_3

The right-hand side can never overflow, since `u32::MAX * u32::MAX + u32::MAX = 2^64 - 2^32 = |F| - 1 < |F|`. However, the left-hand side can wrap around `|F|`: even though each limb is range-checked to be less than `2^16`, four 16-bit limbs can encode values up to `2^64 - 1 > |F|`.

For example, an operation like `0 * 0 + 0` has two limb encodings satisfying the constraint above: 0 and `|F|`, which reduces to 0 in the field. To rule out these non-canonical outputs, we need a comparison to enforce that `out < |F|`.
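
To make the ambiguity concrete, here is a small standalone sketch (not part of this change; the names are made up) verifying that both encodings reduce to 0 modulo the Goldilocks order:

    /// Goldilocks order: |F| = 2^64 - 2^32 + 1.
    const F_ORDER: u128 = 0xFFFF_FFFF_0000_0001;

    /// Combines four little-endian u16 limbs and reduces mod |F|,
    /// mirroring the constraint's left-hand side.
    fn combine_mod_f(limbs: [u16; 4]) -> u128 {
        limbs
            .iter()
            .enumerate()
            .map(|(i, &limb)| (limb as u128) << (16 * i))
            .sum::<u128>()
            % F_ORDER
    }

    fn main() {
        // The canonical encoding of 0.
        assert_eq!(combine_mod_f([0, 0, 0, 0]), 0);
        // A non-canonical encoding: the limbs of |F| itself. Each limb passes the
        // 16-bit range check, yet the combination also reduces to 0 in the field.
        assert_eq!(combine_mod_f([1, 0, 0xFFFF, 0xFFFF]), 0);
    }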

Thankfully, `F::MAX` has all zeros in its low 32 bits, so `x <= F::MAX` is equivalent to `x_lo == 0 || x_hi != u32::MAX`. Since a field element is nonzero iff it is invertible, `x_hi != u32::MAX` can be shown by exhibiting an inverse of `u32::MAX - x_hi`: if `x_hi != u32::MAX`, the prover provides this (purported) inverse in an advice column.
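
As a quick out-of-circuit sketch of the predicate being enforced (again illustrative only, assuming the Goldilocks field):

    /// `F::MAX`, i.e. |F| - 1: high 32 bits all ones, low 32 bits all zeros.
    const F_MAX: u64 = 0xFFFF_FFFF_0000_0000;

    /// A u64 encodes a canonical field element iff its low 32 bits are zero
    /// or its high 32 bits are not all ones.
    fn is_canonical(x: u64) -> bool {
        let x_lo = x as u32;
        let x_hi = (x >> 32) as u32;
        x_lo == 0 || x_hi != u32::MAX
    }

    fn main() {
        assert!(is_canonical(0));
        assert!(is_canonical(F_MAX)); // x_hi == u32::MAX, but x_lo == 0
        assert!(!is_canonical(F_MAX + 1)); // |F| itself is non-canonical
        assert!(is_canonical(0x0000_0001_FFFF_FFFF)); // x_hi != u32::MAX
    }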

See @bobbinth's [post](https://hackmd.io/NC-yRmmtRQSvToTHb96e8Q#Checking-element-validity) for details. That post calls the purported inverse column `m`; I named it `canonical_inv` in this code.

* fix

* PR feedback

* naming

Daniel Lubarov, 2022-02-21 00:39:04 -08:00, committed by GitHub
parent bc3685587c
commit 6072fab077
7 changed files with 249 additions and 63 deletions

Cargo.toml

@@ -6,6 +6,7 @@ edition = "2021"
 [dependencies]
 plonky2 = { path = "../plonky2" }
+plonky2_util = { path = "../util" }
 starky = { path = "../starky" }
 anyhow = "1.0.40"
 env_logger = "0.9.0"

src/alu/addition.rs

@@ -11,14 +11,14 @@ use crate::registers::alu::*;
 use crate::registers::NUM_COLUMNS;
 
 pub(crate) fn generate_addition<F: PrimeField64>(values: &mut [F; NUM_COLUMNS]) {
-    let in_1 = values[COL_ADD_INPUT_1].to_canonical_u64();
-    let in_2 = values[COL_ADD_INPUT_2].to_canonical_u64();
-    let in_3 = values[COL_ADD_INPUT_3].to_canonical_u64();
+    let in_1 = values[COL_ADD_INPUT_0].to_canonical_u64();
+    let in_2 = values[COL_ADD_INPUT_1].to_canonical_u64();
+    let in_3 = values[COL_ADD_INPUT_2].to_canonical_u64();
     let output = in_1 + in_2 + in_3;
 
-    values[COL_ADD_OUTPUT_1] = F::from_canonical_u16(output as u16);
-    values[COL_ADD_OUTPUT_2] = F::from_canonical_u16((output >> 16) as u16);
-    values[COL_ADD_OUTPUT_3] = F::from_canonical_u16((output >> 32) as u16);
+    values[COL_ADD_OUTPUT_0] = F::from_canonical_u16(output as u16);
+    values[COL_ADD_OUTPUT_1] = F::from_canonical_u16((output >> 16) as u16);
+    values[COL_ADD_OUTPUT_2] = F::from_canonical_u16((output >> 32) as u16);
 }
 
 pub(crate) fn eval_addition<F: Field, P: PackedField<Scalar = F>>(
@@ -26,12 +26,12 @@ pub(crate) fn eval_addition<F: Field, P: PackedField<Scalar = F>>(
     yield_constr: &mut ConstraintConsumer<P>,
 ) {
     let is_add = local_values[IS_ADD];
-    let in_1 = local_values[COL_ADD_INPUT_1];
-    let in_2 = local_values[COL_ADD_INPUT_2];
-    let in_3 = local_values[COL_ADD_INPUT_3];
-    let out_1 = local_values[COL_ADD_OUTPUT_1];
-    let out_2 = local_values[COL_ADD_OUTPUT_2];
-    let out_3 = local_values[COL_ADD_OUTPUT_3];
+    let in_1 = local_values[COL_ADD_INPUT_0];
+    let in_2 = local_values[COL_ADD_INPUT_1];
+    let in_3 = local_values[COL_ADD_INPUT_2];
+    let out_1 = local_values[COL_ADD_OUTPUT_0];
+    let out_2 = local_values[COL_ADD_OUTPUT_1];
+    let out_3 = local_values[COL_ADD_OUTPUT_2];
 
     let weight_2 = F::from_canonical_u64(1 << 16);
     let weight_3 = F::from_canonical_u64(1 << 32);
@@ -50,12 +50,12 @@ pub(crate) fn eval_addition_recursively<F: RichField + Extendable<D>, const D: usize>(
     yield_constr: &mut RecursiveConstraintConsumer<F, D>,
 ) {
     let is_add = local_values[IS_ADD];
-    let in_1 = local_values[COL_ADD_INPUT_1];
-    let in_2 = local_values[COL_ADD_INPUT_2];
-    let in_3 = local_values[COL_ADD_INPUT_3];
-    let out_1 = local_values[COL_ADD_OUTPUT_1];
-    let out_2 = local_values[COL_ADD_OUTPUT_2];
-    let out_3 = local_values[COL_ADD_OUTPUT_3];
+    let in_1 = local_values[COL_ADD_INPUT_0];
+    let in_2 = local_values[COL_ADD_INPUT_1];
+    let in_3 = local_values[COL_ADD_INPUT_2];
+    let out_1 = local_values[COL_ADD_OUTPUT_0];
+    let out_2 = local_values[COL_ADD_OUTPUT_1];
+    let out_3 = local_values[COL_ADD_OUTPUT_2];
 
     let limb_base = builder.constant(F::from_canonical_u64(1 << 16));
     // Note that this can't overflow. Since each output limb has been range checked as 16-bits,

src/alu/canonical.rs (new file)

@@ -0,0 +1,109 @@
+//! Helper methods for checking that a value is canonical, i.e. is less than `|F|`.
+//!
+//! See https://hackmd.io/NC-yRmmtRQSvToTHb96e8Q#Checking-element-validity
+
+use plonky2::field::extension_field::Extendable;
+use plonky2::field::field_types::Field;
+use plonky2::field::packed_field::PackedField;
+use plonky2::hash::hash_types::RichField;
+use plonky2::iop::ext_target::ExtensionTarget;
+use plonky2::plonk::circuit_builder::CircuitBuilder;
+use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
+
+/// Computes the helper value used in the is-canonical check.
+pub(crate) fn compute_canonical_inv<F: Field>(value_to_check: u64) -> F {
+    let value_hi_32 = (value_to_check >> 32) as u32;
+    if value_hi_32 == u32::MAX {
+        debug_assert_eq!(value_to_check as u32, 0, "Value was not canonical.");
+        // In this case it doesn't matter what we put for the purported inverse value. The
+        // constraint containing this value will get multiplied by the low u32 limb, which will be
+        // zero, satisfying the constraint regardless of what we put here.
+        F::ZERO
+    } else {
+        F::from_canonical_u32(u32::MAX - value_hi_32).inverse()
+    }
+}
+
+/// Adds constraints to require that a list of four `u16`s, in little-endian order, represent a
+/// canonical field element, i.e. that their combined value is less than `|F|`. Returns their
+/// combined value.
+pub(crate) fn combine_u16s_check_canonical<F: Field, P: PackedField<Scalar = F>>(
+    limb_0_u16: P,
+    limb_1_u16: P,
+    limb_2_u16: P,
+    limb_3_u16: P,
+    inverse: P,
+    yield_constr: &mut ConstraintConsumer<P>,
+) -> P {
+    let base = F::from_canonical_u32(1 << 16);
+    let limb_0_u32 = limb_0_u16 + limb_1_u16 * base;
+    let limb_1_u32 = limb_2_u16 + limb_3_u16 * base;
+    combine_u32s_check_canonical(limb_0_u32, limb_1_u32, inverse, yield_constr)
+}
+
+/// Adds constraints to require that a list of four `u16`s, in little-endian order, represent a
+/// canonical field element, i.e. that their combined value is less than `|F|`. Returns their
+/// combined value.
+pub(crate) fn combine_u16s_check_canonical_circuit<F: RichField + Extendable<D>, const D: usize>(
+    builder: &mut CircuitBuilder<F, D>,
+    limb_0_u16: ExtensionTarget<D>,
+    limb_1_u16: ExtensionTarget<D>,
+    limb_2_u16: ExtensionTarget<D>,
+    limb_3_u16: ExtensionTarget<D>,
+    inverse: ExtensionTarget<D>,
+    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
+) -> ExtensionTarget<D> {
+    let base = F::from_canonical_u32(1 << 16);
+    let limb_0_u32 = builder.mul_const_add_extension(base, limb_1_u16, limb_0_u16);
+    let limb_1_u32 = builder.mul_const_add_extension(base, limb_3_u16, limb_2_u16);
+    combine_u32s_check_canonical_circuit(builder, limb_0_u32, limb_1_u32, inverse, yield_constr)
+}
+
+/// Adds constraints to require that a pair of `u32`s, in little-endian order, represent a canonical
+/// field element, i.e. that their combined value is less than `|F|`. Returns their combined value.
+pub(crate) fn combine_u32s_check_canonical<F: Field, P: PackedField<Scalar = F>>(
+    limb_0_u32: P,
+    limb_1_u32: P,
+    inverse: P,
+    yield_constr: &mut ConstraintConsumer<P>,
+) -> P {
+    let u32_max = P::from(F::from_canonical_u32(u32::MAX));
+
+    // This is zero if and only if the high limb is `u32::MAX`.
+    let diff = u32_max - limb_1_u32;
+    // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`.
+    let hi_not_max = inverse * diff - F::ONE;
+    // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero.
+    let hi_not_max_or_lo_zero = hi_not_max * limb_0_u32;
+
+    yield_constr.constraint(hi_not_max_or_lo_zero);
+
+    // Return the combined value.
+    limb_0_u32 + limb_1_u32 * F::from_canonical_u64(1 << 32)
+}
+
+/// Adds constraints to require that a pair of `u32`s, in little-endian order, represent a canonical
+/// field element, i.e. that their combined value is less than `|F|`. Returns their combined value.
+pub(crate) fn combine_u32s_check_canonical_circuit<F: RichField + Extendable<D>, const D: usize>(
+    builder: &mut CircuitBuilder<F, D>,
+    limb_0_u32: ExtensionTarget<D>,
+    limb_1_u32: ExtensionTarget<D>,
+    inverse: ExtensionTarget<D>,
+    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
+) -> ExtensionTarget<D> {
+    let one = builder.one_extension();
+    let u32_max = builder.constant_extension(F::Extension::from_canonical_u32(u32::MAX));
+
+    // This is zero if and only if the high limb is `u32::MAX`.
+    let diff = builder.sub_extension(u32_max, limb_1_u32);
+    // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`.
+    let hi_not_max = builder.mul_sub_extension(inverse, diff, one);
+    // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero.
+    let hi_not_max_or_lo_zero = builder.mul_extension(hi_not_max, limb_0_u32);
+
+    yield_constr.constraint(builder, hi_not_max_or_lo_zero);
+
+    // Return the combined value.
+    builder.mul_const_add_extension(F::from_canonical_u64(1 << 32), limb_1_u32, limb_0_u32)
+}

src/alu/mod.rs

@@ -9,9 +9,7 @@ use starky::vars::StarkEvaluationVars;
 use crate::alu::addition::{eval_addition, eval_addition_recursively, generate_addition};
 use crate::alu::division::{eval_division, eval_division_recursively, generate_division};
-use crate::alu::multiplication::{
-    eval_multiplication, eval_multiplication_recursively, generate_multiplication,
-};
+use crate::alu::mul_add::{eval_mul_add, eval_mul_add_recursively, generate_mul_add};
 use crate::alu::subtraction::{
     eval_subtraction, eval_subtraction_recursively, generate_subtraction,
 };
@@ -20,8 +18,9 @@ use crate::registers::alu::*;
 use crate::registers::NUM_COLUMNS;
 
 mod addition;
+mod canonical;
 mod division;
-mod multiplication;
+mod mul_add;
 mod subtraction;
 
 pub(crate) fn generate_alu<F: PrimeField64>(values: &mut [F; NUM_COLUMNS]) {
@@ -30,7 +29,7 @@ pub(crate) fn generate_alu<F: PrimeField64>(values: &mut [F; NUM_COLUMNS]) {
     } else if values[IS_SUB].is_one() {
         generate_subtraction(values);
     } else if values[IS_MUL].is_one() {
-        generate_multiplication(values);
+        generate_mul_add(values);
     } else if values[IS_DIV].is_one() {
         generate_division(values);
     }
@@ -50,7 +49,7 @@ pub(crate) fn eval_alu<F: Field, P: PackedField<Scalar = F>>(
     eval_addition(local_values, yield_constr);
     eval_subtraction(local_values, yield_constr);
-    eval_multiplication(local_values, yield_constr);
+    eval_mul_add(local_values, yield_constr);
     eval_division(local_values, yield_constr);
 }
@@ -70,6 +69,6 @@ pub(crate) fn eval_alu_recursively<F: RichField + Extendable<D>, const D: usize>(
     eval_addition_recursively(builder, local_values, yield_constr);
     eval_subtraction_recursively(builder, local_values, yield_constr);
-    eval_multiplication_recursively(builder, local_values, yield_constr);
+    eval_mul_add_recursively(builder, local_values, yield_constr);
     eval_division_recursively(builder, local_values, yield_constr);
 }

src/alu/mul_add.rs (new file)

@@ -0,0 +1,88 @@
+use plonky2::field::extension_field::Extendable;
+use plonky2::field::field_types::{Field, PrimeField64};
+use plonky2::field::packed_field::PackedField;
+use plonky2::hash::hash_types::RichField;
+use plonky2::iop::ext_target::ExtensionTarget;
+use plonky2::plonk::circuit_builder::CircuitBuilder;
+use plonky2_util::assume;
+use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
+
+use crate::alu::canonical::*;
+use crate::registers::alu::*;
+use crate::registers::NUM_COLUMNS;
+
+pub(crate) fn generate_mul_add<F: PrimeField64>(values: &mut [F; NUM_COLUMNS]) {
+    let factor_0 = values[COL_MUL_ADD_FACTOR_0].to_canonical_u64();
+    let factor_1 = values[COL_MUL_ADD_FACTOR_1].to_canonical_u64();
+    let addend = values[COL_MUL_ADD_ADDEND].to_canonical_u64();
+
+    // Let the compiler know that each input must fit in 32 bits.
+    assume(factor_0 <= u32::MAX as u64);
+    assume(factor_1 <= u32::MAX as u64);
+    assume(addend <= u32::MAX as u64);
+
+    let output = factor_0 * factor_1 + addend;
+
+    // An advice value used to help verify that the limbs represent a canonical field element.
+    values[COL_MUL_ADD_RESULT_CANONICAL_INV] = compute_canonical_inv(output);
+
+    values[COL_MUL_ADD_OUTPUT_0] = F::from_canonical_u16(output as u16);
+    values[COL_MUL_ADD_OUTPUT_1] = F::from_canonical_u16((output >> 16) as u16);
+    values[COL_MUL_ADD_OUTPUT_2] = F::from_canonical_u16((output >> 32) as u16);
+    values[COL_MUL_ADD_OUTPUT_3] = F::from_canonical_u16((output >> 48) as u16);
+}
+
+pub(crate) fn eval_mul_add<F: Field, P: PackedField<Scalar = F>>(
+    local_values: &[P; NUM_COLUMNS],
+    yield_constr: &mut ConstraintConsumer<P>,
+) {
+    let is_mul = local_values[IS_MUL];
+    let factor_0 = local_values[COL_MUL_ADD_FACTOR_0];
+    let factor_1 = local_values[COL_MUL_ADD_FACTOR_1];
+    let addend = local_values[COL_MUL_ADD_ADDEND];
+    let output_0 = local_values[COL_MUL_ADD_OUTPUT_0];
+    let output_1 = local_values[COL_MUL_ADD_OUTPUT_1];
+    let output_2 = local_values[COL_MUL_ADD_OUTPUT_2];
+    let output_3 = local_values[COL_MUL_ADD_OUTPUT_3];
+    let result_canonical_inv = local_values[COL_MUL_ADD_RESULT_CANONICAL_INV];
+
+    let computed_output = factor_0 * factor_1 + addend;
+    let output = combine_u16s_check_canonical(
+        output_0,
+        output_1,
+        output_2,
+        output_3,
+        result_canonical_inv,
+        yield_constr,
+    );
+    // Filter by the operation flag, since these columns are shared with other ops.
+    yield_constr.constraint(is_mul * (computed_output - output));
+}
+
+pub(crate) fn eval_mul_add_recursively<F: RichField + Extendable<D>, const D: usize>(
+    builder: &mut CircuitBuilder<F, D>,
+    local_values: &[ExtensionTarget<D>; NUM_COLUMNS],
+    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
+) {
+    let is_mul = local_values[IS_MUL];
+    let factor_0 = local_values[COL_MUL_ADD_FACTOR_0];
+    let factor_1 = local_values[COL_MUL_ADD_FACTOR_1];
+    let addend = local_values[COL_MUL_ADD_ADDEND];
+    let output_0 = local_values[COL_MUL_ADD_OUTPUT_0];
+    let output_1 = local_values[COL_MUL_ADD_OUTPUT_1];
+    let output_2 = local_values[COL_MUL_ADD_OUTPUT_2];
+    let output_3 = local_values[COL_MUL_ADD_OUTPUT_3];
+    let result_canonical_inv = local_values[COL_MUL_ADD_RESULT_CANONICAL_INV];
+
+    let computed_output = builder.mul_add_extension(factor_0, factor_1, addend);
+    let output = combine_u16s_check_canonical_circuit(
+        builder,
+        output_0,
+        output_1,
+        output_2,
+        output_3,
+        result_canonical_inv,
+        yield_constr,
+    );
+    // Filter by the operation flag, since these columns are shared with other ops.
+    let diff = builder.sub_extension(computed_output, output);
+    let filtered_diff = builder.mul_extension(is_mul, diff);
+    yield_constr.constraint(builder, filtered_diff);
+}

src/alu/multiplication.rs (deleted)

@@ -1,31 +0,0 @@
-use plonky2::field::extension_field::Extendable;
-use plonky2::field::field_types::{Field, PrimeField64};
-use plonky2::field::packed_field::PackedField;
-use plonky2::hash::hash_types::RichField;
-use plonky2::iop::ext_target::ExtensionTarget;
-use plonky2::plonk::circuit_builder::CircuitBuilder;
-use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
-
-use crate::registers::alu::*;
-use crate::registers::NUM_COLUMNS;
-
-pub(crate) fn generate_multiplication<F: PrimeField64>(values: &mut [F; NUM_COLUMNS]) {
-    // TODO
-}
-
-pub(crate) fn eval_multiplication<F: Field, P: PackedField<Scalar = F>>(
-    local_values: &[P; NUM_COLUMNS],
-    yield_constr: &mut ConstraintConsumer<P>,
-) {
-    let is_mul = local_values[IS_MUL];
-    // TODO
-}
-
-pub(crate) fn eval_multiplication_recursively<F: RichField + Extendable<D>, const D: usize>(
-    builder: &mut CircuitBuilder<F, D>,
-    local_values: &[ExtensionTarget<D>; NUM_COLUMNS],
-    yield_constr: &mut RecursiveConstraintConsumer<F, D>,
-) {
-    let is_mul = local_values[IS_MUL];
-    // TODO
-}

src/registers/alu.rs

@@ -10,7 +10,7 @@ const START_SHARED_COLS: usize = IS_DIV + 1;
 /// Within the ALU, there are shared columns which can be used by any arithmetic/logic
 /// circuit, depending on which one is active this cycle.
 // Can be increased as needed as other operations are implemented.
-const NUM_SHARED_COLS: usize = 3;
+const NUM_SHARED_COLS: usize = 4;
 
 const fn shared_col(i: usize) -> usize {
     debug_assert!(i < NUM_SHARED_COLS);
@@ -18,20 +18,40 @@ const fn shared_col(i: usize) -> usize {
 }
 
 /// The first value to be added; treated as an unsigned u32.
-pub(crate) const COL_ADD_INPUT_1: usize = shared_col(0);
+pub(crate) const COL_ADD_INPUT_0: usize = shared_col(0);
 /// The second value to be added; treated as an unsigned u32.
-pub(crate) const COL_ADD_INPUT_2: usize = shared_col(1);
+pub(crate) const COL_ADD_INPUT_1: usize = shared_col(1);
 /// The third value to be added; treated as an unsigned u32.
-pub(crate) const COL_ADD_INPUT_3: usize = shared_col(2);
+pub(crate) const COL_ADD_INPUT_2: usize = shared_col(2);
 
 // Note: Addition outputs three 16-bit chunks, and since these values need to be range-checked
 // anyway, we might as well use the range check unit's columns as our addition outputs. So the
 // three following columns are basically aliases, not columns owned by the ALU.
 /// The first 16-bit chunk of the output, based on little-endian ordering.
-pub(crate) const COL_ADD_OUTPUT_1: usize = super::range_check_16::col_rc_16_input(0);
+pub(crate) const COL_ADD_OUTPUT_0: usize = super::range_check_16::col_rc_16_input(0);
 /// The second 16-bit chunk of the output, based on little-endian ordering.
-pub(crate) const COL_ADD_OUTPUT_2: usize = super::range_check_16::col_rc_16_input(1);
+pub(crate) const COL_ADD_OUTPUT_1: usize = super::range_check_16::col_rc_16_input(1);
 /// The third 16-bit chunk of the output, based on little-endian ordering.
-pub(crate) const COL_ADD_OUTPUT_3: usize = super::range_check_16::col_rc_16_input(2);
+pub(crate) const COL_ADD_OUTPUT_2: usize = super::range_check_16::col_rc_16_input(2);
 
+/// The first value to be multiplied; treated as an unsigned u32.
+pub(crate) const COL_MUL_ADD_FACTOR_0: usize = shared_col(0);
+/// The second value to be multiplied; treated as an unsigned u32.
+pub(crate) const COL_MUL_ADD_FACTOR_1: usize = shared_col(1);
+/// The value to be added to the product; treated as an unsigned u32.
+pub(crate) const COL_MUL_ADD_ADDEND: usize = shared_col(2);
+
+/// The inverse of `u32::MAX - result_hi`, where `result_hi` is the high 32 bits of the result.
+/// See https://hackmd.io/NC-yRmmtRQSvToTHb96e8Q#Checking-element-validity
+pub(crate) const COL_MUL_ADD_RESULT_CANONICAL_INV: usize = shared_col(3);
+
+/// The first 16-bit chunk of the output, based on little-endian ordering.
+pub(crate) const COL_MUL_ADD_OUTPUT_0: usize = super::range_check_16::col_rc_16_input(0);
+/// The second 16-bit chunk of the output, based on little-endian ordering.
+pub(crate) const COL_MUL_ADD_OUTPUT_1: usize = super::range_check_16::col_rc_16_input(1);
+/// The third 16-bit chunk of the output, based on little-endian ordering.
+pub(crate) const COL_MUL_ADD_OUTPUT_2: usize = super::range_check_16::col_rc_16_input(2);
+/// The fourth 16-bit chunk of the output, based on little-endian ordering.
+pub(crate) const COL_MUL_ADD_OUTPUT_3: usize = super::range_check_16::col_rc_16_input(3);
+
 pub(super) const END: usize = super::START_ALU + NUM_SHARED_COLS;