From 6072fab0770eb2f9797bdc09997e72b85282e77f Mon Sep 17 00:00:00 2001 From: Daniel Lubarov Date: Mon, 21 Feb 2022 00:39:04 -0800 Subject: [PATCH] Implement a mul-add circuit in the ALU (#495) * Implement a mul-add circuit in the ALU The inputs are assumed to be `u32`s, while the output is encoded as four `u16` limbs. Each output limb is range-checked. So, our basic mul-add constraint looks like out_0 + 2^16 out_1 + 2^32 out_2 + 2^48 out_3 = in_1 * in_2 + in_3 The right-hand side will never overflow, since `u32::MAX * u32::MAX + u32::MAX < |F|`. However, the left-hand side could overflow, even though we know each limb is less than `2^16`. For example, an operation like `0 * 0 + 0` could have two possible outputs, 0 and `|F|`, both of which would satisfy the constraint above. To prevent these non-canonical outputs, we need a comparison to enforce that `out < |F|`. Thankfully, `F::MAX` has all zeros in its low 32 bits, so `x <= F::MAX` is equivalent to `x_lo == 0 || x_hi != u32::MAX`. `x_hi != u32::MAX` can be checked by showing that `u32::MAX - x_hi` has an inverse. If `x_hi != u32::MAX`, the prover provides this (purported) inverse in an advice column. See @bobbinth's [post](https://hackmd.io/NC-yRmmtRQSvToTHb96e8Q#Checking-element-validity) for details. That post calls the purported inverse column `m`; I named it `canonical_inv` in this code. 
* fix * PR feedback * naming --- system_zero/Cargo.toml | 1 + system_zero/src/alu/addition.rs | 36 ++++----- system_zero/src/alu/canonical.rs | 109 ++++++++++++++++++++++++++ system_zero/src/alu/mod.rs | 13 ++- system_zero/src/alu/mul_add.rs | 88 +++++++++++++++++++++ system_zero/src/alu/multiplication.rs | 31 -------- system_zero/src/registers/alu.rs | 34 ++++++-- 7 files changed, 249 insertions(+), 63 deletions(-) create mode 100644 system_zero/src/alu/canonical.rs create mode 100644 system_zero/src/alu/mul_add.rs delete mode 100644 system_zero/src/alu/multiplication.rs diff --git a/system_zero/Cargo.toml b/system_zero/Cargo.toml index e5b617c9..032bfb53 100644 --- a/system_zero/Cargo.toml +++ b/system_zero/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] plonky2 = { path = "../plonky2" } +plonky2_util = { path = "../util" } starky = { path = "../starky" } anyhow = "1.0.40" env_logger = "0.9.0" diff --git a/system_zero/src/alu/addition.rs b/system_zero/src/alu/addition.rs index dc83ecb8..c2293b4a 100644 --- a/system_zero/src/alu/addition.rs +++ b/system_zero/src/alu/addition.rs @@ -11,14 +11,14 @@ use crate::registers::alu::*; use crate::registers::NUM_COLUMNS; pub(crate) fn generate_addition(values: &mut [F; NUM_COLUMNS]) { - let in_1 = values[COL_ADD_INPUT_1].to_canonical_u64(); - let in_2 = values[COL_ADD_INPUT_2].to_canonical_u64(); - let in_3 = values[COL_ADD_INPUT_3].to_canonical_u64(); + let in_1 = values[COL_ADD_INPUT_0].to_canonical_u64(); + let in_2 = values[COL_ADD_INPUT_1].to_canonical_u64(); + let in_3 = values[COL_ADD_INPUT_2].to_canonical_u64(); let output = in_1 + in_2 + in_3; - values[COL_ADD_OUTPUT_1] = F::from_canonical_u16(output as u16); - values[COL_ADD_OUTPUT_2] = F::from_canonical_u16((output >> 16) as u16); - values[COL_ADD_OUTPUT_3] = F::from_canonical_u16((output >> 32) as u16); + values[COL_ADD_OUTPUT_0] = F::from_canonical_u16(output as u16); + values[COL_ADD_OUTPUT_1] = F::from_canonical_u16((output >> 16) as u16); + 
values[COL_ADD_OUTPUT_2] = F::from_canonical_u16((output >> 32) as u16); } pub(crate) fn eval_addition>( @@ -26,12 +26,12 @@ pub(crate) fn eval_addition>( yield_constr: &mut ConstraintConsumer

, ) { let is_add = local_values[IS_ADD]; - let in_1 = local_values[COL_ADD_INPUT_1]; - let in_2 = local_values[COL_ADD_INPUT_2]; - let in_3 = local_values[COL_ADD_INPUT_3]; - let out_1 = local_values[COL_ADD_OUTPUT_1]; - let out_2 = local_values[COL_ADD_OUTPUT_2]; - let out_3 = local_values[COL_ADD_OUTPUT_3]; + let in_1 = local_values[COL_ADD_INPUT_0]; + let in_2 = local_values[COL_ADD_INPUT_1]; + let in_3 = local_values[COL_ADD_INPUT_2]; + let out_1 = local_values[COL_ADD_OUTPUT_0]; + let out_2 = local_values[COL_ADD_OUTPUT_1]; + let out_3 = local_values[COL_ADD_OUTPUT_2]; let weight_2 = F::from_canonical_u64(1 << 16); let weight_3 = F::from_canonical_u64(1 << 32); @@ -50,12 +50,12 @@ pub(crate) fn eval_addition_recursively, const D: u yield_constr: &mut RecursiveConstraintConsumer, ) { let is_add = local_values[IS_ADD]; - let in_1 = local_values[COL_ADD_INPUT_1]; - let in_2 = local_values[COL_ADD_INPUT_2]; - let in_3 = local_values[COL_ADD_INPUT_3]; - let out_1 = local_values[COL_ADD_OUTPUT_1]; - let out_2 = local_values[COL_ADD_OUTPUT_2]; - let out_3 = local_values[COL_ADD_OUTPUT_3]; + let in_1 = local_values[COL_ADD_INPUT_0]; + let in_2 = local_values[COL_ADD_INPUT_1]; + let in_3 = local_values[COL_ADD_INPUT_2]; + let out_1 = local_values[COL_ADD_OUTPUT_0]; + let out_2 = local_values[COL_ADD_OUTPUT_1]; + let out_3 = local_values[COL_ADD_OUTPUT_2]; let limb_base = builder.constant(F::from_canonical_u64(1 << 16)); // Note that this can't overflow. Since each output limb has been range checked as 16-bits, diff --git a/system_zero/src/alu/canonical.rs b/system_zero/src/alu/canonical.rs new file mode 100644 index 00000000..fb90eb0d --- /dev/null +++ b/system_zero/src/alu/canonical.rs @@ -0,0 +1,109 @@ +//! Helper methods for checking that a value is canonical, i.e. is less than `|F|`. +//! +//! 
See https://hackmd.io/NC-yRmmtRQSvToTHb96e8Q#Checking-element-validity + +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::Field; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +/// Computes the helper value used in the is-canonical check. +pub(crate) fn compute_canonical_inv(value_to_check: u64) -> F { + let value_hi_32 = (value_to_check >> 32) as u32; + + if value_hi_32 == u32::MAX { + debug_assert_eq!(value_to_check as u32, 0, "Value was not canonical."); + // In this case it doesn't matter what we put for the purported inverse value. The + // constraint containing this value will get multiplied by the low u32 limb, which will be + // zero, satisfying the constraint regardless of what we put here. + F::ZERO + } else { + F::from_canonical_u32(u32::MAX - value_hi_32).inverse() + } +} + +/// Adds constraints to require that a list of four `u16`s, in little-endian order, represent a +/// canonical field element, i.e. that their combined value is less than `|F|`. Returns their +/// combined value. +pub(crate) fn combine_u16s_check_canonical>( + limb_0_u16: P, + limb_1_u16: P, + limb_2_u16: P, + limb_3_u16: P, + inverse: P, + yield_constr: &mut ConstraintConsumer

, +) -> P { + let base = F::from_canonical_u32(1 << 16); + let limb_0_u32 = limb_0_u16 + limb_1_u16 * base; + let limb_1_u32 = limb_2_u16 + limb_3_u16 * base; + combine_u32s_check_canonical(limb_0_u32, limb_1_u32, inverse, yield_constr) +} + +/// Adds constraints to require that a list of four `u16`s, in little-endian order, represent a +/// canonical field element, i.e. that their combined value is less than `|F|`. Returns their +/// combined value. +pub(crate) fn combine_u16s_check_canonical_circuit, const D: usize>( + builder: &mut CircuitBuilder, + limb_0_u16: ExtensionTarget, + limb_1_u16: ExtensionTarget, + limb_2_u16: ExtensionTarget, + limb_3_u16: ExtensionTarget, + inverse: ExtensionTarget, + yield_constr: &mut RecursiveConstraintConsumer, +) -> ExtensionTarget { + let base = F::from_canonical_u32(1 << 16); + let limb_0_u32 = builder.mul_const_add_extension(base, limb_1_u16, limb_0_u16); + let limb_1_u32 = builder.mul_const_add_extension(base, limb_3_u16, limb_2_u16); + combine_u32s_check_canonical_circuit(builder, limb_0_u32, limb_1_u32, inverse, yield_constr) +} + +/// Adds constraints to require that a pair of `u32`s, in little-endian order, represent a canonical +/// field element, i.e. that their combined value is less than `|F|`. Returns their combined value. +pub(crate) fn combine_u32s_check_canonical>( + limb_0_u32: P, + limb_1_u32: P, + inverse: P, + yield_constr: &mut ConstraintConsumer

, +) -> P { + let u32_max = P::from(F::from_canonical_u32(u32::MAX)); + + // This is zero if and only if the high limb is `u32::MAX`. + let diff = u32_max - limb_1_u32; + // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. + let hi_not_max = inverse * diff - F::ONE; + // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. + let hi_not_max_or_lo_zero = hi_not_max * limb_0_u32; + + yield_constr.constraint(hi_not_max_or_lo_zero); + + // Return the combined value. + limb_0_u32 + limb_1_u32 * F::from_canonical_u64(1 << 32) +} + +/// Adds constraints to require that a pair of `u32`s, in little-endian order, represent a canonical +/// field element, i.e. that their combined value is less than `|F|`. Returns their combined value. +pub(crate) fn combine_u32s_check_canonical_circuit, const D: usize>( + builder: &mut CircuitBuilder, + limb_0_u32: ExtensionTarget, + limb_1_u32: ExtensionTarget, + inverse: ExtensionTarget, + yield_constr: &mut RecursiveConstraintConsumer, +) -> ExtensionTarget { + let one = builder.one_extension(); + let u32_max = builder.constant_extension(F::Extension::from_canonical_u32(u32::MAX)); + + // This is zero if and only if the high limb is `u32::MAX`. + let diff = builder.sub_extension(u32_max, limb_1_u32); + // If this is zero, the diff is invertible, so the high limb is not `u32::MAX`. + let hi_not_max = builder.mul_sub_extension(inverse, diff, one); + // If this is zero, either the high limb is not `u32::MAX`, or the low limb is zero. + let hi_not_max_or_lo_zero = builder.mul_extension(hi_not_max, limb_0_u32); + + yield_constr.constraint(builder, hi_not_max_or_lo_zero); + + // Return the combined value. 
+ builder.mul_const_add_extension(F::from_canonical_u64(1 << 32), limb_1_u32, limb_0_u32) +} diff --git a/system_zero/src/alu/mod.rs b/system_zero/src/alu/mod.rs index 4e7e09fa..730ca302 100644 --- a/system_zero/src/alu/mod.rs +++ b/system_zero/src/alu/mod.rs @@ -9,9 +9,7 @@ use starky::vars::StarkEvaluationVars; use crate::alu::addition::{eval_addition, eval_addition_recursively, generate_addition}; use crate::alu::division::{eval_division, eval_division_recursively, generate_division}; -use crate::alu::multiplication::{ - eval_multiplication, eval_multiplication_recursively, generate_multiplication, -}; +use crate::alu::mul_add::{eval_mul_add, eval_mul_add_recursively, generate_mul_add}; use crate::alu::subtraction::{ eval_subtraction, eval_subtraction_recursively, generate_subtraction, }; @@ -20,8 +18,9 @@ use crate::registers::alu::*; use crate::registers::NUM_COLUMNS; mod addition; +mod canonical; mod division; -mod multiplication; +mod mul_add; mod subtraction; pub(crate) fn generate_alu(values: &mut [F; NUM_COLUMNS]) { @@ -30,7 +29,7 @@ pub(crate) fn generate_alu(values: &mut [F; NUM_COLUMNS]) { } else if values[IS_SUB].is_one() { generate_subtraction(values); } else if values[IS_MUL].is_one() { - generate_multiplication(values); + generate_mul_add(values); } else if values[IS_DIV].is_one() { generate_division(values); } @@ -50,7 +49,7 @@ pub(crate) fn eval_alu>( eval_addition(local_values, yield_constr); eval_subtraction(local_values, yield_constr); - eval_multiplication(local_values, yield_constr); + eval_mul_add(local_values, yield_constr); eval_division(local_values, yield_constr); } @@ -70,6 +69,6 @@ pub(crate) fn eval_alu_recursively, const D: usize> eval_addition_recursively(builder, local_values, yield_constr); eval_subtraction_recursively(builder, local_values, yield_constr); - eval_multiplication_recursively(builder, local_values, yield_constr); + eval_mul_add_recursively(builder, local_values, yield_constr); eval_division_recursively(builder, 
local_values, yield_constr); } diff --git a/system_zero/src/alu/mul_add.rs b/system_zero/src/alu/mul_add.rs new file mode 100644 index 00000000..53ba34a2 --- /dev/null +++ b/system_zero/src/alu/mul_add.rs @@ -0,0 +1,88 @@ +use plonky2::field::extension_field::Extendable; +use plonky2::field::field_types::{Field, PrimeField64}; +use plonky2::field::packed_field::PackedField; +use plonky2::hash::hash_types::RichField; +use plonky2::iop::ext_target::ExtensionTarget; +use plonky2::plonk::circuit_builder::CircuitBuilder; +use plonky2_util::assume; +use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; + +use crate::alu::canonical::*; +use crate::registers::alu::*; +use crate::registers::NUM_COLUMNS; + +pub(crate) fn generate_mul_add(values: &mut [F; NUM_COLUMNS]) { + let factor_0 = values[COL_MUL_ADD_FACTOR_0].to_canonical_u64(); + let factor_1 = values[COL_MUL_ADD_FACTOR_1].to_canonical_u64(); + let addend = values[COL_MUL_ADD_ADDEND].to_canonical_u64(); + + // Let the compiler know that each input must fit in 32 bits. + assume(factor_0 <= u32::MAX as u64); + assume(factor_1 <= u32::MAX as u64); + assume(addend <= u32::MAX as u64); + + let output = factor_0 * factor_1 + addend; + + // An advice value used to help verify that the limbs represent a canonical field element. + values[COL_MUL_ADD_RESULT_CANONICAL_INV] = compute_canonical_inv(output); + + values[COL_MUL_ADD_OUTPUT_0] = F::from_canonical_u16(output as u16); + values[COL_MUL_ADD_OUTPUT_1] = F::from_canonical_u16((output >> 16) as u16); + values[COL_MUL_ADD_OUTPUT_2] = F::from_canonical_u16((output >> 32) as u16); + values[COL_MUL_ADD_OUTPUT_3] = F::from_canonical_u16((output >> 48) as u16); +} + +pub(crate) fn eval_mul_add>( + local_values: &[P; NUM_COLUMNS], + yield_constr: &mut ConstraintConsumer

, +) { + let is_mul = local_values[IS_MUL]; + let factor_0 = local_values[COL_MUL_ADD_FACTOR_0]; + let factor_1 = local_values[COL_MUL_ADD_FACTOR_1]; + let addend = local_values[COL_MUL_ADD_ADDEND]; + let output_1 = local_values[COL_MUL_ADD_OUTPUT_0]; + let output_2 = local_values[COL_MUL_ADD_OUTPUT_1]; + let output_3 = local_values[COL_MUL_ADD_OUTPUT_2]; + let output_4 = local_values[COL_MUL_ADD_OUTPUT_3]; + let result_canonical_inv = local_values[COL_MUL_ADD_RESULT_CANONICAL_INV]; + + let computed_output = factor_0 * factor_1 + addend; + let output = combine_u16s_check_canonical( + output_1, + output_2, + output_3, + output_4, + result_canonical_inv, + yield_constr, + ); + yield_constr.constraint(computed_output - output); +} + +pub(crate) fn eval_mul_add_recursively, const D: usize>( + builder: &mut CircuitBuilder, + local_values: &[ExtensionTarget; NUM_COLUMNS], + yield_constr: &mut RecursiveConstraintConsumer, +) { + let is_mul = local_values[IS_MUL]; + let factor_0 = local_values[COL_MUL_ADD_FACTOR_0]; + let factor_1 = local_values[COL_MUL_ADD_FACTOR_1]; + let addend = local_values[COL_MUL_ADD_ADDEND]; + let output_1 = local_values[COL_MUL_ADD_OUTPUT_0]; + let output_2 = local_values[COL_MUL_ADD_OUTPUT_1]; + let output_3 = local_values[COL_MUL_ADD_OUTPUT_2]; + let output_4 = local_values[COL_MUL_ADD_OUTPUT_3]; + let result_canonical_inv = local_values[COL_MUL_ADD_RESULT_CANONICAL_INV]; + + let computed_output = builder.mul_add_extension(factor_0, factor_1, addend); + let output = combine_u16s_check_canonical_circuit( + builder, + output_1, + output_2, + output_3, + output_4, + result_canonical_inv, + yield_constr, + ); + let diff = builder.sub_extension(computed_output, output); + yield_constr.constraint(builder, diff); +} diff --git a/system_zero/src/alu/multiplication.rs b/system_zero/src/alu/multiplication.rs deleted file mode 100644 index a88b42f6..00000000 --- a/system_zero/src/alu/multiplication.rs +++ /dev/null @@ -1,31 +0,0 @@ -use 
plonky2::field::extension_field::Extendable; -use plonky2::field::field_types::{Field, PrimeField64}; -use plonky2::field::packed_field::PackedField; -use plonky2::hash::hash_types::RichField; -use plonky2::iop::ext_target::ExtensionTarget; -use plonky2::plonk::circuit_builder::CircuitBuilder; -use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; - -use crate::registers::alu::*; -use crate::registers::NUM_COLUMNS; - -pub(crate) fn generate_multiplication(values: &mut [F; NUM_COLUMNS]) { - // TODO -} - -pub(crate) fn eval_multiplication>( - local_values: &[P; NUM_COLUMNS], - yield_constr: &mut ConstraintConsumer

, -) { - let is_mul = local_values[IS_MUL]; - // TODO -} - -pub(crate) fn eval_multiplication_recursively, const D: usize>( - builder: &mut CircuitBuilder, - local_values: &[ExtensionTarget; NUM_COLUMNS], - yield_constr: &mut RecursiveConstraintConsumer, -) { - let is_mul = local_values[IS_MUL]; - // TODO -} diff --git a/system_zero/src/registers/alu.rs b/system_zero/src/registers/alu.rs index b4f82dff..e678d8e4 100644 --- a/system_zero/src/registers/alu.rs +++ b/system_zero/src/registers/alu.rs @@ -10,7 +10,7 @@ const START_SHARED_COLS: usize = IS_DIV + 1; /// Within the ALU, there are shared columns which can be used by any arithmetic/logic /// circuit, depending on which one is active this cycle. // Can be increased as needed as other operations are implemented. -const NUM_SHARED_COLS: usize = 3; +const NUM_SHARED_COLS: usize = 4; const fn shared_col(i: usize) -> usize { debug_assert!(i < NUM_SHARED_COLS); @@ -18,20 +18,40 @@ const fn shared_col(i: usize) -> usize { } /// The first value to be added; treated as an unsigned u32. -pub(crate) const COL_ADD_INPUT_1: usize = shared_col(0); +pub(crate) const COL_ADD_INPUT_0: usize = shared_col(0); /// The second value to be added; treated as an unsigned u32. -pub(crate) const COL_ADD_INPUT_2: usize = shared_col(1); +pub(crate) const COL_ADD_INPUT_1: usize = shared_col(1); /// The third value to be added; treated as an unsigned u32. -pub(crate) const COL_ADD_INPUT_3: usize = shared_col(2); +pub(crate) const COL_ADD_INPUT_2: usize = shared_col(2); // Note: Addition outputs three 16-bit chunks, and since these values need to be range-checked // anyway, we might as well use the range check unit's columns as our addition outputs. So the // three proceeding columns are basically aliases, not columns owned by the ALU. /// The first 16-bit chunk of the output, based on little-endian ordering. 
-pub(crate) const COL_ADD_OUTPUT_1: usize = super::range_check_16::col_rc_16_input(0); +pub(crate) const COL_ADD_OUTPUT_0: usize = super::range_check_16::col_rc_16_input(0); /// The second 16-bit chunk of the output, based on little-endian ordering. -pub(crate) const COL_ADD_OUTPUT_2: usize = super::range_check_16::col_rc_16_input(1); +pub(crate) const COL_ADD_OUTPUT_1: usize = super::range_check_16::col_rc_16_input(1); /// The third 16-bit chunk of the output, based on little-endian ordering. -pub(crate) const COL_ADD_OUTPUT_3: usize = super::range_check_16::col_rc_16_input(2); +pub(crate) const COL_ADD_OUTPUT_2: usize = super::range_check_16::col_rc_16_input(2); + +/// The first value to be multiplied; treated as an unsigned u32. +pub(crate) const COL_MUL_ADD_FACTOR_0: usize = shared_col(0); +/// The second value to be multiplied; treated as an unsigned u32. +pub(crate) const COL_MUL_ADD_FACTOR_1: usize = shared_col(1); +/// The value to be added to the product; treated as an unsigned u32. +pub(crate) const COL_MUL_ADD_ADDEND: usize = shared_col(2); + +/// The inverse of `u32::MAX - result_hi`, where `result_hi` is the high 32 bits of the result. +/// See https://hackmd.io/NC-yRmmtRQSvToTHb96e8Q#Checking-element-validity +pub(crate) const COL_MUL_ADD_RESULT_CANONICAL_INV: usize = shared_col(3); + +/// The first 16-bit chunk of the output, based on little-endian ordering. +pub(crate) const COL_MUL_ADD_OUTPUT_0: usize = super::range_check_16::col_rc_16_input(0); +/// The second 16-bit chunk of the output, based on little-endian ordering. +pub(crate) const COL_MUL_ADD_OUTPUT_1: usize = super::range_check_16::col_rc_16_input(1); +/// The third 16-bit chunk of the output, based on little-endian ordering. +pub(crate) const COL_MUL_ADD_OUTPUT_2: usize = super::range_check_16::col_rc_16_input(2); +/// The fourth 16-bit chunk of the output, based on little-endian ordering. 
+pub(crate) const COL_MUL_ADD_OUTPUT_3: usize = super::range_check_16::col_rc_16_input(3); pub(super) const END: usize = super::START_ALU + NUM_SHARED_COLS;