diff --git a/evm/src/all_stark.rs b/evm/src/all_stark.rs index 1131d529..b687d48e 100644 --- a/evm/src/all_stark.rs +++ b/evm/src/all_stark.rs @@ -205,6 +205,19 @@ mod tests { (trace, num_ops) } + fn bits_from_opcode(opcode: u8) -> [F; 8] { + [ + F::from_bool(opcode & (1 << 0) != 0), + F::from_bool(opcode & (1 << 1) != 0), + F::from_bool(opcode & (1 << 2) != 0), + F::from_bool(opcode & (1 << 3) != 0), + F::from_bool(opcode & (1 << 4) != 0), + F::from_bool(opcode & (1 << 5) != 0), + F::from_bool(opcode & (1 << 6) != 0), + F::from_bool(opcode & (1 << 7) != 0), + ] + } + fn make_cpu_trace( num_keccak_perms: usize, num_logic_rows: usize, @@ -263,16 +276,21 @@ mod tests { [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; row.is_kernel_mode = F::ONE; + // Since these are the first cycle rows, we must start with PC=route_txn then increment. row.program_counter = F::from_canonical_usize(KERNEL.global_labels["route_txn"] + i); - row.opcode = [ - (logic::columns::IS_AND, 0x16), - (logic::columns::IS_OR, 0x17), - (logic::columns::IS_XOR, 0x18), - ] - .into_iter() - .map(|(col, opcode)| logic_trace[col].values[i] * F::from_canonical_u64(opcode)) - .sum(); + row.opcode_bits = bits_from_opcode( + if logic_trace[logic::columns::IS_AND].values[i] != F::ZERO { + 0x16 + } else if logic_trace[logic::columns::IS_OR].values[i] != F::ZERO { + 0x17 + } else if logic_trace[logic::columns::IS_XOR].values[i] != F::ZERO { + 0x18 + } else { + panic!() + }, + ); + let logic = row.general.logic_mut(); let input0_bit_cols = logic::columns::limb_bit_cols_for_input(logic::columns::INPUT0); @@ -330,7 +348,7 @@ mod tests { let last_row: cpu::columns::CpuColumnsView = cpu_trace_rows[cpu_trace_rows.len() - 1].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x0a); // `EXP` is implemented in software + row.opcode_bits = bits_from_opcode(0x0a); // `EXP` is implemented in software row.is_kernel_mode = F::ONE; row.program_counter = last_row.program_counter + F::ONE; row.general.syscalls_mut().output = [ @@ -352,7 +370,7 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0xf9); + row.opcode_bits = bits_from_opcode(0xf9); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_usize(KERNEL.global_labels["sys_exp"]); row.general.jumps_mut().input0 = [ @@ -374,7 +392,7 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x56); + row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_u16(15682); row.general.jumps_mut().input0 = [ @@ -411,7 +429,7 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0xf9); + row.opcode_bits = bits_from_opcode(0xf9); row.is_kernel_mode = F::ONE; row.program_counter = F::from_canonical_u16(15106); row.general.jumps_mut().input0 = [ @@ -433,7 +451,7 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x56); + row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(63064); row.general.jumps_mut().input0 = [ @@ -471,7 +489,7 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x57); + row.opcode_bits = bits_from_opcode(0x57); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(3754); row.general.jumps_mut().input0 = [ @@ -509,7 +527,7 @@ mod tests { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x57); + row.opcode_bits = bits_from_opcode(0x57); row.is_kernel_mode = F::ZERO; row.program_counter = F::from_canonical_u16(37543); row.general.jumps_mut().input0 = [ @@ -538,7 +556,7 @@ mod tests { let last_row: cpu::columns::CpuColumnsView = cpu_trace_rows[cpu_trace_rows.len() - 1].into(); row.is_cpu_cycle = F::ONE; - row.opcode = F::from_canonical_u8(0x56); + row.opcode_bits = bits_from_opcode(0x56); row.is_kernel_mode = F::ZERO; row.program_counter = last_row.program_counter + F::ONE; row.general.jumps_mut().input0 = [ @@ -575,7 +593,7 @@ mod tests { for i in 0..cpu_trace_rows.len().next_power_of_two() - cpu_trace_rows.len() { let mut row: cpu::columns::CpuColumnsView = [F::ZERO; CpuStark::::COLUMNS].into(); - row.opcode = F::from_canonical_u8(0xff); + row.opcode_bits = bits_from_opcode(0xff); row.is_cpu_cycle = F::ONE; row.is_kernel_mode = F::ONE; row.program_counter = diff --git a/evm/src/cpu/columns/mod.rs b/evm/src/cpu/columns/mod.rs index 3016b2fd..8f641db9 100644 --- a/evm/src/cpu/columns/mod.rs +++ b/evm/src/cpu/columns/mod.rs @@ -27,9 +27,6 @@ pub struct CpuColumnsView { /// If CPU cycle: We're in kernel (privileged) mode. pub is_kernel_mode: T, - /// If CPU cycle: The opcode being decoded, in {0, ..., 255}. - pub opcode: T, - // If CPU cycle: flags for EVM instructions. PUSHn, DUPn, and SWAPn only get one flag each. // Invalid opcodes are split between a number of flags for practical reasons. Exactly one of // these flags must be 1. diff --git a/evm/src/cpu/decode.rs b/evm/src/cpu/decode.rs index 4faf7925..e58b474d 100644 --- a/evm/src/cpu/decode.rs +++ b/evm/src/cpu/decode.rs @@ -1,6 +1,5 @@ use plonky2::field::extension::Extendable; use plonky2::field::packed::PackedField; -use plonky2::field::types::Field; use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; @@ -158,13 +157,16 @@ pub fn generate(lv: &mut CpuColumnsView) { // This assert is not _strictly_ necessary, but I include it as a sanity check. assert_eq!(cycle_filter, F::ONE, "cycle_filter should be 0 or 1"); - let opcode = lv.opcode.to_canonical_u64(); - assert!(opcode < 256, "opcode should be in {{0, ..., 255}}"); - let opcode = opcode as u8; - - for (i, bit) in lv.opcode_bits.iter_mut().enumerate() { - *bit = F::from_bool(opcode & (1 << i) != 0); + // Validate all opcode bits. + for bit in lv.opcode_bits.into_iter() { + assert!(bit.to_canonical_u64() <= 1); } + let opcode = lv + .opcode_bits + .into_iter() + .enumerate() + .map(|(i, bit)| bit.to_canonical_u64() << i) + .sum::() as u8; let top_bits: [u8; 9] = [ 0, @@ -217,23 +219,10 @@ pub fn eval_packed_generic( let kernel_mode = lv.is_kernel_mode; yield_constr.constraint(cycle_filter * kernel_mode * (kernel_mode - P::ONES)); - // Ensure that the opcode bits are valid: each has to be either 0 or 1, and they must match - // the opcode. Note that this also implicitly range-checks the opcode. - let bits = lv.opcode_bits; - // First check that the bits are either 0 or 1. - for bit in bits { + // Ensure that the opcode bits are valid: each has to be either 0 or 1. + for bit in lv.opcode_bits { yield_constr.constraint(cycle_filter * bit * (bit - P::ONES)); } - // Now check that they match the opcode. - { - let opcode = lv.opcode; - let reconstructed_opcode: P = bits - .into_iter() - .enumerate() - .map(|(i, bit)| bit * P::Scalar::from_canonical_u64(1 << i)) - .sum(); - yield_constr.constraint(cycle_filter * (opcode - reconstructed_opcode)); - } // Check that the instruction flags are valid. // First, check that they are all either 0 or 1. @@ -258,7 +247,8 @@ pub fn eval_packed_generic( Kernel => P::ONES - kernel_mode, }; // 0 if all the opcode bits match, and something in {1, ..., 8}, otherwise. - let opcode_mismatch: P = bits + let opcode_mismatch: P = lv + .opcode_bits .into_iter() .zip(bits_from_opcode(oc)) .rev() @@ -294,28 +284,12 @@ pub fn eval_ext_circuit, const D: usize>( yield_constr.constraint(builder, constr); } - // Ensure that the opcode bits are valid: each has to be either 0 or 1, and they must match - // the opcode. Note that this also implicitly range-checks the opcode. - let bits = lv.opcode_bits; - // First check that the bits are either 0 or 1. - for bit in bits { + // Ensure that the opcode bits are valid: each has to be either 0 or 1. + for bit in lv.opcode_bits { let constr = builder.mul_sub_extension(bit, bit, bit); let constr = builder.mul_extension(cycle_filter, constr); yield_constr.constraint(builder, constr); } - // Now check that they match the opcode. - { - let opcode = lv.opcode; - let reconstructed_opcode = - bits.into_iter() - .enumerate() - .fold(builder.zero_extension(), |cumul, (i, bit)| { - builder.mul_const_add_extension(F::from_canonical_u64(1 << i), bit, cumul) - }); - let diff = builder.sub_extension(opcode, reconstructed_opcode); - let constr = builder.mul_extension(cycle_filter, diff); - yield_constr.constraint(builder, constr); - } // Check that the instruction flags are valid. // First, check that they are all either 0 or 1. @@ -346,7 +320,8 @@ pub fn eval_ext_circuit, const D: usize>( Kernel => builder.sub_extension(one, kernel_mode), }; // 0 if all the opcode bits match, and something in {1, ..., 8}, otherwise. - let opcode_mismatch = bits + let opcode_mismatch = lv + .opcode_bits .into_iter() .zip(bits_from_opcode(oc)) .rev()