diff --git a/ecdsa/src/gadgets/glv.rs b/ecdsa/src/gadgets/glv.rs index 4302023e..539b5de3 100644 --- a/ecdsa/src/gadgets/glv.rs +++ b/ecdsa/src/gadgets/glv.rs @@ -55,8 +55,8 @@ impl, const D: usize> CircuitBuilderGlv ) { let k1 = self.add_virtual_nonnative_target_sized::(4); let k2 = self.add_virtual_nonnative_target_sized::(4); - let k1_neg = self.add_virtual_bool_target(); - let k2_neg = self.add_virtual_bool_target(); + let k1_neg = self.add_virtual_bool_target_unsafe(); + let k2_neg = self.add_virtual_bool_target_unsafe(); self.add_simple_generator(GLVDecompositionGenerator:: { k: k.clone(), diff --git a/ecdsa/src/gadgets/nonnative.rs b/ecdsa/src/gadgets/nonnative.rs index c6ff4753..29520bed 100644 --- a/ecdsa/src/gadgets/nonnative.rs +++ b/ecdsa/src/gadgets/nonnative.rs @@ -183,7 +183,7 @@ impl, const D: usize> CircuitBuilderNonNative b: &NonNativeTarget, ) -> NonNativeTarget { let sum = self.add_virtual_nonnative_target::(); - let overflow = self.add_virtual_bool_target(); + let overflow = self.add_virtual_bool_target_unsafe(); self.add_simple_generator(NonNativeAdditionGenerator:: { a: a.clone(), @@ -282,7 +282,7 @@ impl, const D: usize> CircuitBuilderNonNative b: &NonNativeTarget, ) -> NonNativeTarget { let diff = self.add_virtual_nonnative_target::(); - let overflow = self.add_virtual_bool_target(); + let overflow = self.add_virtual_bool_target_unsafe(); self.add_simple_generator(NonNativeSubtractionGenerator:: { a: a.clone(), diff --git a/evm/src/cpu/kernel/aggregator.rs b/evm/src/cpu/kernel/aggregator.rs index 1d20fe0c..9509f484 100644 --- a/evm/src/cpu/kernel/aggregator.rs +++ b/evm/src/cpu/kernel/aggregator.rs @@ -32,6 +32,9 @@ pub(crate) fn combined_kernel() -> Kernel { include_str!("asm/curve/secp256k1/lift_x.asm"), include_str!("asm/curve/secp256k1/moddiv.asm"), include_str!("asm/exp.asm"), + include_str!("asm/fields/fp6_macros.asm"), + include_str!("asm/fields/fp6_mul.asm"), + include_str!("asm/fields/fp12_mul.asm"), include_str!("asm/halt.asm"), 
include_str!("asm/main.asm"), include_str!("asm/memory/core.asm"), diff --git a/evm/src/cpu/kernel/asm/fields/fp12_mul.asm b/evm/src/cpu/kernel/asm/fields/fp12_mul.asm new file mode 100644 index 00000000..2f4b9024 --- /dev/null +++ b/evm/src/cpu/kernel/asm/fields/fp12_mul.asm @@ -0,0 +1,166 @@ +/// Note: uncomment this to test + +/// global test_mul_Fp12: +/// // stack: f, in0 , f', g, in1 , g', in1, out, in0, out +/// DUP7 +/// // stack: in0, f, in0 , f', g, in1 , g', in1, out, in0, out +/// %store_fp6 +/// // stack: in0 , f', g, in1 , g', in1, out, in0, out +/// %add_const(6) +/// // stack: in0', f', g, in1 , g', in1, out, in0, out +/// %store_fp6 +/// // stack: g, in1 , g', in1, out, in0, out +/// DUP7 +/// // stack: in1, g, in1 , g', in1, out, in0, out +/// %store_fp6 +/// // stack: in1 , g', in1, out, in0, out +/// %add_const(6) +/// // stack: in1', g', in1, out, in0, out +/// %store_fp6 +/// // stack: in1, out, in0, out +/// PUSH ret_stack +/// // stack: ret_stack, in1, out, in0, out +/// SWAP3 +/// // stack: in0, in1, out, ret_stack, out +/// %jump(mul_Fp12) +/// ret_stack: +/// // stack: out +/// DUP1 %add_const(6) +/// // stack: out', out +/// %load_fp6 +/// // stack: h', out +/// DUP7 +/// // stack: out, h', out +/// %load_fp6 +/// // stack: h, h', out +/// %jump(0xdeadbeef) + + +/// fp6 functions: +/// fn | num | ops | cost +/// ------------------------- +/// load | 8 | 40 | 320 +/// store | 5 | 40 | 200 +/// dup | 5 | 6 | 30 +/// swap | 4 | 16 | 64 +/// add | 4 | 16 | 64 +/// subr | 1 | 17 | 17 +/// mul | 3 | 157 | 471 +/// i9 | 1 | 9 | 9 +/// +/// lone stack operations: +/// op | num +/// ------------ +/// ADD | 3 +/// SWAP | 2 +/// DUP | 6 +/// PUSH | 6 +/// POP | 2 +/// JUMP | 1 +/// +/// TOTAL: 1196 + +/// inputs: +/// F = f + f'z +/// G = g + g'z +/// +/// output: +/// H = h + h'z = FG +/// +/// h = fg + sh(f'g') +/// h' = (f+f')(g+g') - fg - f'g' +/// +/// memory pointers [ind' = ind+6] +/// {in0: f, in0': f', in1: g, in1': g', out: h, out': h'} 
+/// +/// f, f', g, g' consist of six elements on the stack + +global mul_Fp12: + // stack: in0, in1, out + DUP1 %add_const(6) + // stack: in0', in0, in1, out + %load_fp6 + // stack: f', in0, in1, out + DUP8 %add_const(6) + // stack: in1', f', in0, in1, out + %load_fp6 + // stack: g', f', in0, in1, out + PUSH ret_1 + // stack: ret_1, g', f', in0, in1, out + %dup_fp6_7 + // stack: f', ret_1, g', f', in0, in1, out + %dup_fp6_7 + // stack: g', f', ret_1, g', f', in0, in1, out + %jump(mul_fp6) +ret_1: + // stack: f'g', g' , f', in0, in1, out + %dup_fp6_0 + // stack: f'g', f'g', g' , f', in0, in1, out + %store_fp6_sh(100) + // stack: f'g', g' , f', in0, in1, out {100: sh(f'g')} + %store_fp6(106) + // stack: g' , f', in0, in1, out {100: sh(f'g'), 106: f'g'} + DUP13 + // stack: in0, g' , f', in0, in1, out {100: sh(f'g'), 106: f'g'} + DUP15 + // stack: in1, in0, g' , f', in0, in1, out {100: sh(f'g'), 106: f'g'} + %load_fp6 + // stack: g , in0, g' , f', in0, in1, out {100: sh(f'g'), 106: f'g'} + %swap_fp6_hole + // stack: g', in0, g , f', in0, in1, out {100: sh(f'g'), 106: f'g'} + %dup_fp6_7 + // stack: g,g', in0, g , f', in0, in1, out {100: sh(f'g'), 106: f'g'} + %add_fp6 + // stack: g+g', in0, g , f', in0, in1, out {100: sh(f'g'), 106: f'g'} + %swap_fp6_hole + // stack: g, in0, g+g', f', in0, in1, out {100: sh(f'g'), 106: f'g'} + PUSH ret_2 + // stack: ret_2, g, in0, g+g', f', in0, in1, out {100: sh(f'g'), 106: f'g'} + SWAP7 + // stack: in0, g, ret_2, g+g', f', in0, in1, out {100: sh(f'g'), 106: f'g'} + %load_fp6 + // stack: f, g, ret_2, g+g', f', in0, in1, out {100: sh(f'g'), 106: f'g'} + %jump(mul_fp6) +ret_2: + // stack: fg, g+g', f', in0, in1, out {100: sh(f'g'), 106: f'g'} + %store_fp6(112) + // stack: g+g', f', in0, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %swap_fp6 + // stack: f', g+g', in0, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + PUSH ret_3 + // stack: ret_3, f', g+g', in0, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + SWAP13 + // stack: in0, f', 
g+g', ret_3, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %load_fp6 + // stack: f,f', g+g', ret_3, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %add_fp6 + // stack: f+f', g+g', ret_3, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %jump(mul_fp6) +ret_3: + // stack: (f+f')(g+g'), in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %load_fp6(112) + // stack: fg, (f+f')(g+g'), in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %swap_fp6 + // stack: (f+f')(g+g'), fg, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %dup_fp6_6 + // stack: fg, (f+f')(g+g'), fg, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %load_fp6(106) + // stack: f'g',fg, (f+f')(g+g'), fg, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %add_fp6 + // stack: f'g'+fg, (f+f')(g+g'), fg, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %subr_fp6 + // stack: (f+f')(g+g') - (f'g'+fg), fg, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + DUP14 %add_const(6) + // stack: out', (f+f')(g+g') - (f'g'+fg), fg, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %store_fp6 + // stack: fg, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %load_fp6(100) + // stack: sh(f'g') , fg, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %add_fp6 + // stack: sh(f'g') + fg, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + DUP8 + // stack: out, sh(f'g') + fg, in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %store_fp6 + // stack: in1, out {100: sh(f'g'), 106: f'g', 112: fg} + %pop2 + JUMP diff --git a/evm/src/cpu/kernel/asm/fields/fp6_macros.asm b/evm/src/cpu/kernel/asm/fields/fp6_macros.asm new file mode 100644 index 00000000..b575c234 --- /dev/null +++ b/evm/src/cpu/kernel/asm/fields/fp6_macros.asm @@ -0,0 +1,314 @@ +// cost: 6 loads + 6 dup/swaps + 5 adds = 6*4 + 6*1 + 5*2 = 40 +%macro load_fp6 + // stack: ptr + DUP1 %add_const(4) + // stack: ind4, ptr + %mload_kernel_general + // stack: x4, ptr + DUP2 %add_const(3) + // stack: ind3, x4, ptr + %mload_kernel_general + // stack: x3, x4, ptr + DUP3 %add_const(2) + // stack: ind2, x3, x4, ptr + 
%mload_kernel_general + // stack: x2, x3, x4, ptr + DUP4 %add_const(1) + // stack: ind1, x2, x3, x4, ptr + %mload_kernel_general + // stack: x1, x2, x3, x4, ptr + DUP5 %add_const(5) + // stack: ind5, x1, x2, x3, x4, ptr + %mload_kernel_general + // stack: x5, x1, x2, x3, x4, ptr + SWAP5 + // stack: ind0, x1, x2, x3, x4, x5 + %mload_kernel_general + // stack: x0, x1, x2, x3, x4, x5 +%endmacro + +// cost: 6 loads + 6 pushes + 5 adds = 6*4 + 6*1 + 5*2 = 40 +%macro load_fp6(ptr) + // stack: + PUSH $ptr %add_const(5) + // stack: ind5 + %mload_kernel_general + // stack: x5 + PUSH $ptr %add_const(4) + // stack: ind4, x5 + %mload_kernel_general + // stack: x4, x5 + PUSH $ptr %add_const(3) + // stack: ind3, x4, x5 + %mload_kernel_general + // stack: x3, x4, x5 + PUSH $ptr %add_const(2) + // stack: ind2, x3, x4, x5 + %mload_kernel_general + // stack: x2, x3, x4, x5 + PUSH $ptr %add_const(1) + // stack: ind1, x2, x3, x4, x5 + %mload_kernel_general + // stack: x1, x2, x3, x4, x5 + PUSH $ptr + // stack: ind0, x1, x2, x3, x4, x5 + %mload_kernel_general + // stack: x0, x1, x2, x3, x4, x5 +%endmacro + +// cost: 6 stores + 6 swaps/dups + 5 adds = 6*4 + 6*1 + 5*2 = 40 +%macro store_fp6 + // stack: ptr, x0, x1, x2, x3, x4 , x5 + SWAP5 + // stack: x4, x0, x1, x2, x3, ptr, x5 + DUP6 %add_const(4) + // stack: ind4, x4, x0, x1, x2, x3, ptr, x5 + %mstore_kernel_general + // stack: x0, x1, x2, x3, ptr, x5 + DUP5 + // stack: ind0, x0, x1, x2, x3, ptr, x5 + %mstore_kernel_general + // stack: x1, x2, x3, ptr, x5 + DUP4 %add_const(1) + // stack: ind1, x1, x2, x3, ptr, x5 + %mstore_kernel_general + // stack: x2, x3, ptr, x5 + DUP3 %add_const(2) + // stack: ind2, x2, x3, ptr, x5 + %mstore_kernel_general + // stack: x3, ptr, x5 + DUP2 %add_const(3) + // stack: ind3, x3, ptr, x5 + %mstore_kernel_general + // stack: ptr, x5 + %add_const(5) + // stack: ind5, x5 + %mstore_kernel_general + // stack: +%endmacro + +// cost: 6 stores + 6 pushes + 5 adds = 6*4 + 6*1 + 5*2 = 40 +%macro store_fp6(ptr) + // 
stack: x0, x1, x2, x3, x4, x5 + PUSH $ptr + // stack: ind0, x0, x1, x2, x3, x4, x5 + %mstore_kernel_general + // stack: x1, x2, x3, x4, x5 + PUSH $ptr %add_const(1) + // stack: ind1, x1, x2, x3, x4, x5 + %mstore_kernel_general + // stack: x2, x3, x4, x5 + PUSH $ptr %add_const(2) + // stack: ind2, x2, x3, x4, x5 + %mstore_kernel_general + // stack: x3, x4, x5 + PUSH $ptr %add_const(3) + // stack: ind3, x3, x4, x5 + %mstore_kernel_general + // stack: x4, x5 + PUSH $ptr %add_const(4) + // stack: ind4, x4, x5 + %mstore_kernel_general + // stack: x5 + PUSH $ptr %add_const(5) + // stack: ind5, x5 + %mstore_kernel_general + // stack: +%endmacro + +// cost: store (40) + i9 (9) = 49 +%macro store_fp6_sh(ptr) + // stack: x0, x1, x2, x3, x4, x5 + PUSH $ptr %add_const(2) + // stack: ind2, x0, x1, x2, x3, x4, x5 + %mstore_kernel_general + // stack: x1, x2, x3, x4, x5 + PUSH $ptr %add_const(3) + // stack: ind3, x1, x2, x3, x4, x5 + %mstore_kernel_general + // stack: x2, x3, x4, x5 + PUSH $ptr %add_const(4) + // stack: ind4, x2, x3, x4, x5 + %mstore_kernel_general + // stack: x3, x4, x5 + PUSH $ptr %add_const(5) + // stack: ind5, x3, x4, x5 + %mstore_kernel_general + // stack: x4, x5 + %i9 + // stack: y5, y4 + PUSH $ptr %add_const(1) + // stack: ind1, y5, y4 + %mstore_kernel_general + // stack: y4 + PUSH $ptr + // stack: ind0, y4 + %mstore_kernel_general + // stack: +%endmacro + +// cost: 9; note this returns y, x for the output x + yi +%macro i9 + // stack: a , b + DUP2 + // stack: b, a, b + DUP2 + // stack: a , b, a , b + PUSH 9 MULFP254 + // stack: 9a , b, a , b + SUBFP254 + // stack: 9a - b, a , b + SWAP2 + // stack: b , a, 9a - b + PUSH 9 MULFP254 + // stack: 9b , a, 9a - b + ADDFP254 + // stack: 9b + a, 9a - b +%endmacro + +// cost: 6 +%macro dup_fp6_0 + // stack: f: 6 + DUP6 + DUP6 + DUP6 + DUP6 + DUP6 + DUP6 + // stack: f: 6, f: 6 +%endmacro + +// cost: 6 +%macro dup_fp6_6 + // stack: f: 6, g: 6 + DUP12 + DUP12 + DUP12 + DUP12 + DUP12 + DUP12 + // stack: g: 6, f: 6, g: 6 
+%endmacro + +// cost: 6 +%macro dup_fp6_7 + // stack: f: 6, X, g: 6 + DUP13 + DUP13 + DUP13 + DUP13 + DUP13 + DUP13 + // stack: g: 6, f: 6, X, g: 6 +%endmacro + +// cost: 16 +%macro swap_fp6 + // stack: f0, f1, f2, f3, f4, f5, g0, g1, g2, g3, g4, g5 + SWAP6 + // stack: g0, f1, f2, f3, f4, f5, f0, g1, g2, g3, g4, g5 + SWAP1 + SWAP7 + SWAP1 + // stack: g0, g1, f2, f3, f4, f5, f0, f1, g2, g3, g4, g5 + SWAP2 + SWAP8 + SWAP2 + // stack: g0, g1, g2, f3, f4, f5, f0, f1, f2, g3, g4, g5 + SWAP3 + SWAP9 + SWAP3 + // stack: g0, g1, g2, g3, f4, f5, f0, f1, f2, f3, g4, g5 + SWAP4 + SWAP10 + SWAP4 + // stack: g0, g1, g2, g3, g4, f5, f0, f1, f2, f3, f4, g5 + SWAP5 + SWAP11 + SWAP5 + // stack: g0, g1, g2, g3, g4, g5, f0, f1, f2, f3, f4, f5 +%endmacro + +// cost: 16 +// swap two fp6 elements with a stack term separating them +// (f: 6, x, g: 6) -> (g: 6, x, f: 6) +%macro swap_fp6_hole + // stack: f0, f1, f2, f3, f4, f5, X, g0, g1, g2, g3, g4, g5 + SWAP7 + // stack: g0, f1, f2, f3, f4, f5, X, f0, g1, g2, g3, g4, g5 + SWAP1 + SWAP8 + SWAP1 + // stack: g0, g1, f2, f3, f4, f5, X, f0, f1, g2, g3, g4, g5 + SWAP2 + SWAP9 + SWAP2 + // stack: g0, g1, g2, f3, f4, f5, X, f0, f1, f2, g3, g4, g5 + SWAP3 + SWAP10 + SWAP3 + // stack: g0, g1, g2, g3, f4, f5, X, f0, f1, f2, f3, g4, g5 + SWAP4 + SWAP11 + SWAP4 + // stack: g0, g1, g2, g3, g4, f5, X, f0, f1, f2, f3, f4, g5 + SWAP5 + SWAP12 + SWAP5 + // stack: g0, g1, g2, g3, g4, g5, X, f0, f1, f2, f3, f4, f5 +%endmacro + +// cost: 16 +%macro add_fp6 + // stack: f0, f1, f2, f3, f4, f5, g0, g1, g2, g3, g4, g5 + SWAP7 + ADDFP254 + SWAP6 + // stack: f0, f2, f3, f4, f5, g0, h1, g2, g3, g4, g5 + SWAP7 + ADDFP254 + SWAP6 + // stack: f0, f3, f4, f5, g0, h1, h2, g3, g4, g5 + SWAP7 + ADDFP254 + SWAP6 + // stack: f0, f4, f5, g0, h1, h2, h3, g4, g5 + SWAP7 + ADDFP254 + SWAP6 + // stack: f0, f5, g0, h1, h2, h3, h4, g5 + SWAP7 + ADDFP254 + SWAP6 + // stack: f0, g0, h1, h2, h3, h4, h5 + ADDFP254 + // stack: h0, h1, h2, h3, h4, h5 +%endmacro + +// *reversed argument 
subtraction* cost: 17 +%macro subr_fp6 + // stack: f0, f1, f2, f3, f4, f5, g0, g1, g2, g3, g4, g5 + SWAP7 + SUBFP254 + SWAP6 + // stack: f0, f2, f3, f4, f5, g0, h1, g2, g3, g4, g5 + SWAP7 + SUBFP254 + SWAP6 + // stack: f0, f3, f4, f5, g0, h1, h2, g3, g4, g5 + SWAP7 + SUBFP254 + SWAP6 + // stack: f0, f4, f5, g0, h1, h2, h3, g4, g5 + SWAP7 + SUBFP254 + SWAP6 + // stack: f0, f5, g0, h1, h2, h3, h4, g5 + SWAP7 + SUBFP254 + SWAP6 + // stack: f0, g0, h1, h2, h3, h4, h5 + SWAP1 + SUBFP254 + // stack: h0, h1, h2, h3, h4, h5 +%endmacro diff --git a/evm/src/cpu/kernel/asm/fields/fp6_mul.asm b/evm/src/cpu/kernel/asm/fields/fp6_mul.asm new file mode 100644 index 00000000..0fc6dbdf --- /dev/null +++ b/evm/src/cpu/kernel/asm/fields/fp6_mul.asm @@ -0,0 +1,258 @@ +/// inputs: +/// C = C0 + C1t + C2t^2 +/// = (c0 + c0_i) + (c1 + c1_i)t + (c2 + c2_i)t^2 +/// +/// D = D0 + D1t + D2t^2 +/// = (d0 + d0_i) + (d1 + d1_i)t + (d2 + d2_i)t^2 +/// +/// output: +/// E = E0 + E1t + E2t^2 = CD +/// = (e0 + e0_i) + (e1 + e1_i)t + (e2 + e2_i)t^2 +/// +/// initial stack: c0, c0_, c1, c1_, c2, c2_, d0, d0_, d1, d1_, d2, d2_, retdest +/// final stack: e0, e0_, e1, e1_, e2, e2_ + +/// computations: +/// +/// E0 = C0D0 + i9(C1D2 + C2D1) +/// +/// C0D0 = (c0d0 - c0_d0_) + (c0d0_ + c0_d0)i +/// +/// C1D2 = (c1d2 - c1_d2_) + (c1d2_ + c1_d2)i +/// C2D1 = (c2d1 - c2_d1_) + (c2d1_ + c2_d1)i +/// +/// CD12 = C1D2 + C2D1 +/// = (c1d2 + c2d1 - c1_d2_ - c2_d1_) + (c1d2_ + c1_d2 + c2d1_ + c2_d1)i +/// +/// i9(CD12) = (9CD12 - CD12_) + (CD12 + 9CD12_)i +/// +/// e0 = 9CD12 - CD12_ + C0D0 +/// e0_ = 9CD12_ + CD12 + C0D0_ +/// +/// +/// E1 = C0D1 + C1D0 + i9(C2D2) +/// +/// C0D1 = (c0d1 - c0_d1_) + (c0d1_ + c0_d1)i +/// C1D0 = (c1d0 - c1_d0_) + (c1d0_ + c1_d0)i +/// +/// CD01 = c0d1 + c1d0 - (c0_d1_ + c1_d0_) +/// CD01_ = c0d1_ + c0_d1 + c1d0_ + c1_d0 +/// +/// C2D2 = (c2d2 - c2_d2_) + (c2d2_ + c2_d2)i +/// i9(C2D2) = (9C2D2 - C2D2_) + (C2D2 + 9C2D2_)i +/// +/// e1 = 9C2D2 - C2D2_ + CD01 +/// e1_ = C2D2 + 9C2D2_ + 
CD01_ +/// +/// +/// E2 = C0D2 + C1D1 + C2D0 +/// +/// C0D2 = (c0d2 - c0_d2_) + (c0d2_ + c0_d2)i +/// C1D1 = (c1d1 - c1_d1_) + (c1d1_ + c1_d1)i +/// C2D0 = (c2d0 - c2_d0_) + (c2d0_ + c2_d0)i +/// +/// e2 = c0d2 + c1d1 + c2d0 - (c0_d2_ + c1_d1_ + c2_d0_) +/// e2_ = c0d2_ + c0_d2 + c1d1_ + c1_d1 + c2d0_ + c2_d0 + + +// cost: 157 +global mul_fp6: + // e2 + // make c0_d2_ + c1_d1_ + c2_d0_ + DUP8 + DUP7 + MULFP254 + DUP11 + DUP6 + MULFP254 + ADDFP254 + DUP13 + DUP4 + MULFP254 + ADDFP254 + // make c0d2 + c1d1 + c2d0 + DUP12 + DUP3 + MULFP254 + DUP11 + DUP6 + MULFP254 + ADDFP254 + DUP9 + DUP8 + MULFP254 + ADDFP254 + // stack: c0d2 + c1d1 + c2d0 , c0_d2_ + c1_d1_ + c2_d0_ + SUBFP254 + // stack: e2 = c0d2 + c1d1 + c2d0 - (c0_d2_ + c1_d1_ + c2_d0_) + SWAP12 + + // e0, e0_ + // make CD12_ = c1d2_ + c1_d2 + c2d1_ + c2_d1 + DUP1 + DUP5 + MULFP254 + DUP13 + DUP7 + MULFP254 + ADDFP254 + DUP12 + DUP8 + MULFP254 + ADDFP254 + DUP11 + DUP9 + MULFP254 + ADDFP254 + // make C0D0_ = c0d0_ + c0_d0 + DUP10 + DUP4 + MULFP254 + DUP10 + DUP6 + MULFP254 + ADDFP254 + // make CD12 = c1d2 + c2d1 - c1_d2_ - c2_d1_ + DUP13 + DUP10 + MULFP254 + DUP4 + DUP9 + MULFP254 + ADDFP254 + DUP15 + DUP8 + MULFP254 + DUP14 + DUP11 + MULFP254 + ADDFP254 + SUBFP254 + // make C0D0 = c0d0 - c0_d0_ + DUP12 + DUP7 + MULFP254 + DUP12 + DUP7 + MULFP254 + SUBFP254 + // stack: C0D0 , CD12 , C0D0_, CD12_ + DUP4 + DUP3 + // stack: CD12 , CD12_ , C0D0 , CD12 , C0D0_, CD12_ + PUSH 9 + MULFP254 + SUBFP254 + ADDFP254 + // stack: e0 = 9CD12 - CD12_ + C0D0 , CD12 , C0D0_, CD12_ + SWAP12 + SWAP3 + // stack: CD12_ , CD12 , C0D0_ + PUSH 9 + MULFP254 + ADDFP254 + ADDFP254 + // stack: e0_ = 9CD12_ + CD12 + C0D0_ + SWAP11 + + // e1, e1_ + // make C2D2_ = c2d2_ + c2_d2 + DUP14 + DUP10 + MULFP254 + DUP4 + DUP10 + MULFP254 + ADDFP254 + // make C2D2 = c2d2 - c2_d2_ + DUP4 + DUP11 + MULFP254 + DUP16 + DUP11 + MULFP254 + SUBFP254 + // make CD01 = c0d1 + c1d0 - (c0_d1_ + c1_d0_) + DUP4 + DUP10 + MULFP254 + DUP16 + DUP9 + MULFP254 + ADDFP254 
+ DUP13 + DUP10 + MULFP254 + DUP5 + DUP9 + MULFP254 + ADDFP254 + SUBFP254 + // stack: CD01, C2D2, C2D2_ + DUP3 + DUP3 + // stack: C2D2 , C2D2_ , CD01, C2D2, C2D2_ + PUSH 9 + MULFP254 + SUBFP254 + ADDFP254 + // stack: e1 = 9C2D2 - C2D2_ + CD01, C2D2, C2D2_ + SWAP15 + SWAP2 + // stack: C2D2_ , C2D2 + PUSH 9 + MULFP254 + ADDFP254 + // stack: 9C2D2_ + C2D2 + // make CD01_ = c0d1_ + c0_d1 + c1d0_ + c1_d0 + DUP12 + DUP10 + MULFP254 + DUP5 + DUP10 + MULFP254 + ADDFP254 + DUP4 + DUP9 + MULFP254 + ADDFP254 + DUP3 + DUP8 + MULFP254 + ADDFP254 + // stack: CD01_ , 9C2D2_ + C2D2 + ADDFP254 + // stack: e1_ = CD01_ + 9C2D2_ + C2D2 + SWAP15 + + // e2_ + // stack: d2, d1_, d1, d0_, d2_, c0, c0_, c1, c1_, c2, c2_, d0 + SWAP7 + MULFP254 + // stack: c1d1_, d1, d0_, d2_, c0, c0_, d2, c1_, c2, c2_, d0 + SWAP7 + MULFP254 + // stack: c1_d1, d0_, d2_, c0, c0_, d2, c1d1_, c2, c2_, d0 + SWAP7 + MULFP254 + // stack: c2d0_, d2_, c0, c0_, d2, c1d1_, c1_d1 , c2_, d0 + SWAP2 + MULFP254 + // stack: c0d2_ , c2d0_, c0_, d2, c1d1_, c1_d1 , c2_, d0 + ADDFP254 + // stack: c0d2_ + c2d0_, c0_, d2, c1d1_, c1_d1 , c2_, d0 + SWAP2 + MULFP254 + // stack: c0_d2 , c0d2_ + c2d0_ , c1d1_ , c1_d1 , c2_, d0 + ADDFP254 + ADDFP254 + ADDFP254 + // stack: c0_d2 + c0d2_ + c2d0_ + c1d1_ + c1_d1 , c2_, d0 + SWAP2 + MULFP254 + ADDFP254 + // stack: e2_ = c2_d0 + c0_d2 + c0d2_ + c2d0_ + c1d1_ + c1_d1 + SWAP6 + + // stack: retdest, e0, e0_, e1, e1_, e2, e2_ + JUMP diff --git a/evm/src/cpu/kernel/asm/mpt/accounts.asm b/evm/src/cpu/kernel/asm/mpt/accounts.asm index 0e49da98..08291048 100644 --- a/evm/src/cpu/kernel/asm/mpt/accounts.asm +++ b/evm/src/cpu/kernel/asm/mpt/accounts.asm @@ -1,6 +1,6 @@ // Return a pointer to the current account's data in the state trie. %macro current_account_data - ADDRESS %mpt_read_state_trie + %address %mpt_read_state_trie // stack: account_ptr // account_ptr should be non-null as long as the prover provided the proper // Merkle data. 
But a bad prover may not have, and we don't want return a diff --git a/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm b/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm index 057e4bb5..7695b0c2 100644 --- a/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm +++ b/evm/src/cpu/kernel/asm/mpt/storage/storage_write.asm @@ -39,6 +39,6 @@ after_storage_insert: // stack: new_account_ptr, retdest // Save this updated account to the state trie. - ADDRESS %addr_to_state_key + %address %addr_to_state_key // stack: state_key, new_account_ptr, retdest %jump(mpt_insert_state_trie) diff --git a/evm/src/cpu/kernel/interpreter.rs b/evm/src/cpu/kernel/interpreter.rs index 478d5413..82bd382b 100644 --- a/evm/src/cpu/kernel/interpreter.rs +++ b/evm/src/cpu/kernel/interpreter.rs @@ -237,9 +237,9 @@ impl<'a> Interpreter<'a> { 0x09 => self.run_mulmod(), // "MULMOD", 0x0a => self.run_exp(), // "EXP", 0x0b => todo!(), // "SIGNEXTEND", - 0x0c => todo!(), // "ADDFP254", - 0x0d => todo!(), // "MULFP254", - 0x0e => todo!(), // "SUBFP254", + 0x0c => self.run_addfp254(), // "ADDFP254", + 0x0d => self.run_mulfp254(), // "MULFP254", + 0x0e => self.run_subfp254(), // "SUBFP254", 0x10 => self.run_lt(), // "LT", 0x11 => self.run_gt(), // "GT", 0x12 => todo!(), // "SLT", @@ -370,6 +370,27 @@ impl<'a> Interpreter<'a> { self.push(x.overflowing_sub(y).0); } + // TODO: 107 is hardcoded as a dummy prime for testing + // should be changed to the proper implementation prime + + fn run_addfp254(&mut self) { + let x = self.pop() % 107; // reduce first: U256 addition panics on overflow + let y = self.pop() % 107; + self.push((x + y) % 107); + } + + fn run_mulfp254(&mut self) { + let x = self.pop(); + let y = self.pop(); + self.push(U256::try_from(x.full_mul(y) % 107).unwrap()); + } + + fn run_subfp254(&mut self) { + let x = self.pop() % 107; // reduce first: U256 subtraction panics on underflow + let y = self.pop() % 107; + self.push((U256::from(107) + x - y) % 107); + } + fn run_div(&mut self) { let x = self.pop(); let y = self.pop(); diff --git a/evm/src/cpu/kernel/tests/fields.rs b/evm/src/cpu/kernel/tests/fields.rs new 
file mode 100644 index 00000000..b8a38887 --- /dev/null +++ b/evm/src/cpu/kernel/tests/fields.rs @@ -0,0 +1,203 @@ +use anyhow::Result; +use ethereum_types::U256; +use rand::{thread_rng, Rng}; + +use crate::cpu::kernel::aggregator::combined_kernel; +use crate::cpu::kernel::interpreter::run_with_kernel; + +// TODO: 107 is hardcoded as a dummy prime for testing +// should be changed to the proper implementation prime +// once the run_{add, mul, sub}fp254 fns are implemented +const P254: u32 = 107; + +fn add_fp(x: u32, y: u32) -> u32 { + (x + y) % P254 +} + +fn add3_fp(x: u32, y: u32, z: u32) -> u32 { + (x + y + z) % P254 +} + +fn mul_fp(x: u32, y: u32) -> u32 { + (x * y) % P254 +} + +fn sub_fp(x: u32, y: u32) -> u32 { + (P254 + x - y) % P254 +} + +fn add_fp2(a: [u32; 2], b: [u32; 2]) -> [u32; 2] { + let [a, a_] = a; + let [b, b_] = b; + [add_fp(a, b), add_fp(a_, b_)] +} + +fn add3_fp2(a: [u32; 2], b: [u32; 2], c: [u32; 2]) -> [u32; 2] { + let [a, a_] = a; + let [b, b_] = b; + let [c, c_] = c; + [add3_fp(a, b, c), add3_fp(a_, b_, c_)] +} + +// fn sub_fp2(a: [u32; 2], b: [u32; 2]) -> [u32; 2] { +// let [a, a_] = a; +// let [b, b_] = b; +// [sub_fp(a, b), sub_fp(a_, b_)] +// } + +fn mul_fp2(a: [u32; 2], b: [u32; 2]) -> [u32; 2] { + let [a, a_] = a; + let [b, b_] = b; + [ + sub_fp(mul_fp(a, b), mul_fp(a_, b_)), + add_fp(mul_fp(a, b_), mul_fp(a_, b)), + ] +} + +fn i9(a: [u32; 2]) -> [u32; 2] { + let [a, a_] = a; + [sub_fp(mul_fp(9, a), a_), add_fp(a, mul_fp(9, a_))] +} + +// fn add_fp6(c: [[u32; 2]; 3], d: [[u32; 2]; 3]) -> [[u32; 2]; 3] { +// let [c0, c1, c2] = c; +// let [d0, d1, d2] = d; + +// let e0 = add_fp2(c0, d0); +// let e1 = add_fp2(c1, d1); +// let e2 = add_fp2(c2, d2); +// [e0, e1, e2] +// } + +// fn sub_fp6(c: [[u32; 2]; 3], d: [[u32; 2]; 3]) -> [[u32; 2]; 3] { +// let [c0, c1, c2] = c; +// let [d0, d1, d2] = d; + +// let e0 = sub_fp2(c0, d0); +// let e1 = sub_fp2(c1, d1); +// let e2 = sub_fp2(c2, d2); +// [e0, e1, e2] +// } + +fn mul_fp6(c: [[u32; 2]; 3], d: 
[[u32; 2]; 3]) -> [[u32; 2]; 3] { + let [c0, c1, c2] = c; + let [d0, d1, d2] = d; + + let c0d0 = mul_fp2(c0, d0); + let c0d1 = mul_fp2(c0, d1); + let c0d2 = mul_fp2(c0, d2); + let c1d0 = mul_fp2(c1, d0); + let c1d1 = mul_fp2(c1, d1); + let c1d2 = mul_fp2(c1, d2); + let c2d0 = mul_fp2(c2, d0); + let c2d1 = mul_fp2(c2, d1); + let c2d2 = mul_fp2(c2, d2); + let cd12 = add_fp2(c1d2, c2d1); + + [ + add_fp2(c0d0, i9(cd12)), + add3_fp2(c0d1, c1d0, i9(c2d2)), + add3_fp2(c0d2, c1d1, c2d0), + ] +} + +// fn sh(c: [[u32; 2]; 3]) -> [[u32; 2]; 3] { +// let [c0, c1, c2] = c; +// [i9(c2), c0, c1] +// } + +// fn mul_fp12(f: [[[u32; 2]; 3]; 2], g: [[[u32; 2]; 3]; 2]) -> [[[u32; 2]; 3]; 2] { +// let [f0, f1] = f; +// let [g0, g1] = g; + +// let h0 = mul_fp6(f0, g0); +// let h1 = mul_fp6(f1, g1); +// let h01 = mul_fp6(add_fp6(f0, f1), add_fp6(g0, g1)); +// [add_fp6(h0, sh(h1)), sub_fp6(h01, add_fp6(h0, h1))] +// } + +fn gen_fp6() -> [[u32; 2]; 3] { + let mut rng = thread_rng(); + [ + [rng.gen_range(0..P254), rng.gen_range(0..P254)], + [rng.gen_range(0..P254), rng.gen_range(0..P254)], + [rng.gen_range(0..P254), rng.gen_range(0..P254)], + ] +} + +fn as_stack(xs: Vec) -> Vec { + xs.iter().map(|&x| U256::from(x)).rev().collect() +} + +#[test] +fn test_fp6() -> Result<()> { + let c = gen_fp6(); + let d = gen_fp6(); + + let mut input: Vec = [c, d].into_iter().flatten().flatten().collect(); + input.push(0xdeadbeef); + + let kernel = combined_kernel(); + let initial_offset = kernel.global_labels["mul_fp6"]; + let initial_stack: Vec = as_stack(input); + let final_stack: Vec = run_with_kernel(&kernel, initial_offset, initial_stack)? 
+ .stack() + .to_vec(); + + let output: Vec = mul_fp6(c, d).into_iter().flatten().collect(); + let expected = as_stack(output); + + assert_eq!(final_stack, expected); + + Ok(()) +} + +// fn make_initial_stack( +// f0: [[u32; 2]; 3], +// f1: [[u32; 2]; 3], +// g0: [[u32; 2]; 3], +// g1: [[u32; 2]; 3], +// ) -> Vec { +// // stack: f, in0, f', g, in1, g', in1, out, in0, out +// let f0: Vec = f0.into_iter().flatten().collect(); +// let f1: Vec = f1.into_iter().flatten().collect(); +// let g0: Vec = g0.into_iter().flatten().collect(); +// let g1: Vec = g1.into_iter().flatten().collect(); + +// let mut input = f0; +// input.extend(vec![0]); +// input.extend(f1); +// input.extend(g0); +// input.extend(vec![12]); +// input.extend(g1); +// input.extend(vec![12, 24, 0, 24]); + +// as_stack(input) +// } + +// #[test] +// fn test_fp12() -> Result<()> { +// let f0 = gen_fp6(); +// let f1 = gen_fp6(); +// let g0 = gen_fp6(); +// let g1 = gen_fp6(); + +// let kernel = combined_kernel(); +// let initial_offset = kernel.global_labels["test_mul_Fp12"]; +// let initial_stack: Vec = make_initial_stack(f0, f1, g0, g1); +// let final_stack: Vec = run_with_kernel(&kernel, initial_offset, initial_stack)? 
+// .stack() +// .to_vec(); + +// let mut output: Vec = mul_fp12([f0, f1], [g0, g1]) +// .into_iter() +// .flatten() +// .flatten() +// .collect(); +// output.extend(vec![24]); +// let expected = as_stack(output); + +// assert_eq!(final_stack, expected); + +// Ok(()) +// } diff --git a/evm/src/cpu/kernel/tests/mod.rs b/evm/src/cpu/kernel/tests/mod.rs index cfccb420..4b448af8 100644 --- a/evm/src/cpu/kernel/tests/mod.rs +++ b/evm/src/cpu/kernel/tests/mod.rs @@ -3,6 +3,7 @@ mod core; mod curve_ops; mod ecrecover; mod exp; +mod fields; mod hash; mod mpt; mod packing; diff --git a/plonky2/src/fri/mod.rs b/plonky2/src/fri/mod.rs index 9c44b53b..90f1c940 100644 --- a/plonky2/src/fri/mod.rs +++ b/plonky2/src/fri/mod.rs @@ -54,7 +54,7 @@ impl FriConfig { /// FRI parameters, including generated parameters which are specific to an instance size, in /// contrast to `FriConfig` which is user-specified and independent of instance size. -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq)] pub struct FriParams { /// User-specified FRI configuration. 
pub config: FriConfig, diff --git a/plonky2/src/gadgets/arithmetic.rs b/plonky2/src/gadgets/arithmetic.rs index f4722df4..33facd74 100644 --- a/plonky2/src/gadgets/arithmetic.rs +++ b/plonky2/src/gadgets/arithmetic.rs @@ -345,7 +345,7 @@ impl, const D: usize> CircuitBuilder { pub fn is_equal(&mut self, x: Target, y: Target) -> BoolTarget { let zero = self.zero(); - let equal = self.add_virtual_bool_target(); + let equal = self.add_virtual_bool_target_unsafe(); let not_equal = self.not(equal); let inv = self.add_virtual_target(); self.add_simple_generator(EqualityGenerator { x, y, equal, inv }); diff --git a/plonky2/src/lib.rs b/plonky2/src/lib.rs index 64acfe12..8a517a11 100644 --- a/plonky2/src/lib.rs +++ b/plonky2/src/lib.rs @@ -18,4 +18,5 @@ pub mod gates; pub mod hash; pub mod iop; pub mod plonk; +pub mod recursion; pub mod util; diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 83587f2e..dfd23426 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -34,7 +34,7 @@ use crate::iop::target::{BoolTarget, Target}; use crate::iop::wire::Wire; use crate::plonk::circuit_data::{ CircuitConfig, CircuitData, CommonCircuitData, ProverCircuitData, ProverOnlyCircuitData, - VerifierCircuitData, VerifierOnlyCircuitData, + VerifierCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, }; use crate::plonk::config::{GenericConfig, Hasher}; use crate::plonk::copy_constraint::CopyConstraint; @@ -83,6 +83,15 @@ pub struct CircuitBuilder, const D: usize> { /// List of constant generators used to fill the constant wires. constant_generators: Vec>, + + /// Optional common data. When it is `Some(goal_data)`, the `build` function panics if the resulting + /// common data doesn't equal `goal_data`. + /// This is used in cyclic recursion. + pub(crate) goal_common_data: Option>, + + /// Optional verifier data that is registered as public inputs. 
+    /// This is used in cyclic recursion to hold the circuit's own verifier key.
+    pub(crate) verifier_data_public_input: Option,
 }
 
 impl, const D: usize> CircuitBuilder {
@@ -102,6 +111,8 @@ impl, const D: usize> CircuitBuilder {
             arithmetic_results: HashMap::new(),
             current_slots: HashMap::new(),
             constant_generators: Vec::new(),
+            goal_common_data: None,
+            verifier_data_public_input: None,
         };
         builder.check_config();
         builder
@@ -144,6 +155,10 @@ impl, const D: usize> CircuitBuilder {
         targets.iter().for_each(|&t| self.register_public_input(t));
     }
 
+    /// Returns the number of public inputs registered on this builder so far.
+    pub fn num_public_inputs(&self) -> usize {
+        self.public_inputs.len()
+    }
+
     /// Adds a new "virtual" target. This is not an actual wire in the witness, but just a target
     /// that help facilitate witness generation. In particular, a generator can assign a values to a
     /// virtual target, which can then be copied to other (virtual or concrete) targets. When we
@@ -198,8 +213,7 @@ impl, const D: usize> CircuitBuilder {
         PolynomialCoeffsExtTarget(coeffs)
     }
 
-    // TODO: Unsafe
-    pub fn add_virtual_bool_target(&mut self) -> BoolTarget {
+    /// Adds a new virtual target and wraps it in a `BoolTarget` through `BoolTarget::new_unsafe`.
+    /// NOTE(review): presumably nothing here constrains the target to be 0 or 1 (hence the
+    /// `_unsafe` suffix), so callers would have to enforce that themselves -- confirm.
+    pub fn add_virtual_bool_target_unsafe(&mut self) -> BoolTarget {
         BoolTarget::new_unsafe(self.add_virtual_target())
     }
 
@@ -215,6 +229,21 @@ impl, const D: usize> CircuitBuilder {
         self.register_public_input(t);
         t
     }
+    /// Add a virtual verifier data, register it as a public input and set it to `self.verifier_data_public_input`.
+    /// WARNING: Do not register any public input after calling this! TODO: relax this
+    pub(crate) fn add_verifier_data_public_input(&mut self) {
+        let verifier_data = VerifierCircuitTarget {
+            constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height),
+            circuit_digest: self.add_virtual_hash(),
+        };
+        // The verifier data are public inputs.
+        // They are registered digest first, then every cap element, which yields the
+        // `[..., circuit_digest, constants_sigmas_cap]` tail that the `from_slice` helpers in
+        // `cyclic_recursion.rs` read back.
+        self.register_public_inputs(&verifier_data.circuit_digest.elements);
+        for i in 0..self.config.fri_config.num_cap_elements() {
+            self.register_public_inputs(&verifier_data.constants_sigmas_cap.0[i].elements);
+        }
+
+        self.verifier_data_public_input = Some(verifier_data);
+    }
 
     /// Adds a gate to the circuit, and returns its index.
     pub fn add_gate>(&mut self, gate_type: G, mut constants: Vec) -> usize {
@@ -827,6 +856,9 @@ impl, const D: usize> CircuitBuilder {
             k_is,
             num_partial_products,
         };
+        // When a goal was recorded (`cyclic_recursion` stores one), the common data that was
+        // just built must be equal to it; a mismatch panics here.
+        if let Some(goal_data) = self.goal_common_data {
+            assert_eq!(goal_data, common);
+        }
 
         let prover_only = ProverOnlyCircuitData {
             generators: self.generators,
diff --git a/plonky2/src/plonk/circuit_data.rs b/plonky2/src/plonk/circuit_data.rs
index b3747159..b5e411f1 100644
--- a/plonky2/src/plonk/circuit_data.rs
+++ b/plonky2/src/plonk/circuit_data.rs
@@ -276,7 +276,7 @@ pub struct ProverOnlyCircuitData<
 }
 
 /// Circuit data required by the verifier, but not the prover.
-#[derive(Debug)]
+#[derive(Debug, Eq, PartialEq)]
 pub struct VerifierOnlyCircuitData, const D: usize> {
     /// A commitment to each constant polynomial and each permutation polynomial.
     pub constants_sigmas_cap: MerkleCap,
@@ -286,7 +286,7 @@ pub struct VerifierOnlyCircuitData, const D: usize> {
 }
 
 /// Circuit data required by both the prover and the verifier.
-#[derive(Debug, Eq, PartialEq)]
+#[derive(Debug, Clone, Eq, PartialEq)]
 pub struct CommonCircuitData, const D: usize> {
     pub config: CircuitConfig,
 
@@ -488,6 +488,7 @@ impl, const D: usize> CommonCircuitData {
 /// is intentionally missing certain fields, such as `CircuitConfig`, because we support only a
 /// limited form of dynamic inner circuits. We can't practically make things like the wire count
 /// dynamic, at least not without setting a maximum wire count and paying for the worst case.
+#[derive(Clone)]
 pub struct VerifierCircuitTarget {
     /// A commitment to each constant polynomial and each permutation polynomial.
     pub constants_sigmas_cap: MerkleCapTarget,
diff --git a/plonky2/src/plonk/mod.rs b/plonky2/src/plonk/mod.rs
index 8cd7443f..604c1f79 100644
--- a/plonky2/src/plonk/mod.rs
+++ b/plonky2/src/plonk/mod.rs
@@ -1,6 +1,5 @@
 pub mod circuit_builder;
 pub mod circuit_data;
-pub mod conditional_recursive_verifier;
 pub mod config;
 pub(crate) mod copy_constraint;
 mod get_challenges;
@@ -8,7 +7,6 @@ pub(crate) mod permutation_argument;
 pub mod plonk_common;
 pub mod proof;
 pub mod prover;
-pub mod recursive_verifier;
 mod validate_shape;
 pub(crate) mod vanishing_poly;
 pub mod vars;
diff --git a/plonky2/src/plonk/conditional_recursive_verifier.rs b/plonky2/src/recursion/conditional_recursive_verifier.rs
similarity index 99%
rename from plonky2/src/plonk/conditional_recursive_verifier.rs
rename to plonky2/src/recursion/conditional_recursive_verifier.rs
index 4f54ec3f..6bafc623 100644
--- a/plonky2/src/plonk/conditional_recursive_verifier.rs
+++ b/plonky2/src/recursion/conditional_recursive_verifier.rs
@@ -24,7 +24,6 @@ use crate::plonk::proof::{
 use crate::with_context;
 
 /// Generate a proof having a given `CommonCircuitData`.
-#[allow(unused)] // TODO: should be used soon.
 pub(crate) fn dummy_proof<
     F: RichField + Extendable,
     C: GenericConfig,
@@ -183,7 +182,7 @@ impl, const D: usize> CircuitBuilder {
             .collect()
     }
 
-    fn select_hash(
+    pub(crate) fn select_hash(
         &mut self,
         b: BoolTarget,
         h0: HashOutTarget,
diff --git a/plonky2/src/recursion/cyclic_recursion.rs b/plonky2/src/recursion/cyclic_recursion.rs
new file mode 100644
index 00000000..f2ad7eb9
--- /dev/null
+++ b/plonky2/src/recursion/cyclic_recursion.rs
@@ -0,0 +1,437 @@
+#![allow(clippy::int_plus_one)] // Makes more sense for some inequalities below.
+use anyhow::{ensure, Result}; +use itertools::Itertools; +use plonky2_field::extension::Extendable; + +use crate::gates::noop::NoopGate; +use crate::hash::hash_types::{HashOut, HashOutTarget, MerkleCapTarget, RichField}; +use crate::hash::merkle_tree::MerkleCap; +use crate::iop::target::{BoolTarget, Target}; +use crate::iop::witness::{PartialWitness, Witness}; +use crate::plonk::circuit_builder::CircuitBuilder; +use crate::plonk::circuit_data::{ + CommonCircuitData, VerifierCircuitTarget, VerifierOnlyCircuitData, +}; +use crate::plonk::config::Hasher; +use crate::plonk::config::{AlgebraicHasher, GenericConfig}; +use crate::plonk::proof::{ProofWithPublicInputs, ProofWithPublicInputsTarget}; +use crate::recursion::conditional_recursive_verifier::dummy_proof; + +pub struct CyclicRecursionData< + 'a, + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +> { + proof: &'a Option>, + verifier_data: &'a VerifierOnlyCircuitData, + common_data: &'a CommonCircuitData, +} + +pub struct CyclicRecursionTarget { + pub proof: ProofWithPublicInputsTarget, + pub verifier_data: VerifierCircuitTarget, + pub dummy_proof: ProofWithPublicInputsTarget, + pub dummy_verifier_data: VerifierCircuitTarget, + pub base_case: BoolTarget, +} + +impl, const D: usize> VerifierOnlyCircuitData { + fn from_slice(slice: &[C::F], common_data: &CommonCircuitData) -> Result + where + C::Hasher: AlgebraicHasher, + { + // The structure of the public inputs is `[..., circuit_digest, constants_sigmas_cap]`. 
+ let cap_len = common_data.config.fri_config.num_cap_elements(); + let len = slice.len(); + ensure!(len >= 4 + 4 * cap_len, "Not enough public inputs"); + let constants_sigmas_cap = MerkleCap( + (0..cap_len) + .map(|i| HashOut { + elements: std::array::from_fn(|j| slice[len - 4 * (cap_len - i) + j]), + }) + .collect(), + ); + let circuit_digest = + HashOut::from_partial(&slice[len - 4 - 4 * cap_len..len - 4 * cap_len]); + + Ok(Self { + circuit_digest, + constants_sigmas_cap, + }) + } +} + +impl VerifierCircuitTarget { + fn from_slice, C: GenericConfig, const D: usize>( + slice: &[Target], + common_data: &CommonCircuitData, + ) -> Result { + let cap_len = common_data.config.fri_config.num_cap_elements(); + let len = slice.len(); + ensure!(len >= 4 + 4 * cap_len, "Not enough public inputs"); + let constants_sigmas_cap = MerkleCapTarget( + (0..cap_len) + .map(|i| HashOutTarget { + elements: std::array::from_fn(|j| slice[len - 4 * (cap_len - i) + j]), + }) + .collect(), + ); + let circuit_digest = HashOutTarget { + elements: std::array::from_fn(|i| slice[len - 4 - 4 * cap_len + i]), + }; + + Ok(Self { + circuit_digest, + constants_sigmas_cap, + }) + } +} + +impl, const D: usize> CircuitBuilder { + /// Cyclic recursion gadget. + /// WARNING: Do not register any public input after calling this! TODO: relax this + pub fn cyclic_recursion>( + &mut self, + // Flag set to true for the base case of the cycle where we verify a dummy proof to bootstrap the cycle. Set to false otherwise. 
+ base_case: BoolTarget, + previous_virtual_public_inputs: &[Target], + common_data: &mut CommonCircuitData, + ) -> Result> + where + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, + { + if self.verifier_data_public_input.is_none() { + self.add_verifier_data_public_input(); + } + let verifier_data = self.verifier_data_public_input.clone().unwrap(); + common_data.num_public_inputs = self.num_public_inputs(); + self.goal_common_data = Some(common_data.clone()); + + let dummy_verifier_data = VerifierCircuitTarget { + constants_sigmas_cap: self.add_virtual_cap(self.config.fri_config.cap_height), + circuit_digest: self.add_virtual_hash(), + }; + + let proof = self.add_virtual_proof_with_pis::(common_data); + let dummy_proof = self.add_virtual_proof_with_pis::(common_data); + + let pis = VerifierCircuitTarget::from_slice::(&proof.public_inputs, common_data)?; + // Connect previous verifier data to current one. This guarantees that every proof in the cycle uses the same verifier data. + self.connect_hashes(pis.circuit_digest, verifier_data.circuit_digest); + for (h0, h1) in pis + .constants_sigmas_cap + .0 + .iter() + .zip_eq(&verifier_data.constants_sigmas_cap.0) + { + self.connect_hashes(*h0, *h1); + } + + for (x, y) in previous_virtual_public_inputs + .iter() + .zip(&proof.public_inputs) + { + self.connect(*x, *y); + } + + // Verify the dummy proof if `base_case` is set to true, otherwise verify the "real" proof. + self.conditionally_verify_proof::( + base_case, + &dummy_proof, + &dummy_verifier_data, + &proof, + &verifier_data, + common_data, + ); + + // Make sure we have enough gates to match `common_data`. + while self.num_gates() < (common_data.degree() / 2) { + self.add_gate(NoopGate, vec![]); + } + // Make sure we have every gate to match `common_data`. 
+ for g in &common_data.gates { + self.add_gate_to_gate_set(g.clone()); + } + + Ok(CyclicRecursionTarget { + proof, + verifier_data: verifier_data.clone(), + dummy_proof, + dummy_verifier_data, + base_case, + }) + } +} + +/// Set the targets in a `CyclicRecursionTarget` to their corresponding values in a `CyclicRecursionData`. +pub fn set_cyclic_recursion_data_target< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + pw: &mut PartialWitness, + cyclic_recursion_data_target: &CyclicRecursionTarget, + cyclic_recursion_data: &CyclicRecursionData, + // Public inputs to set in the base case to seed some initial data. + public_inputs: &[F], +) -> Result<()> +where + C::Hasher: AlgebraicHasher, + [(); C::Hasher::HASH_SIZE]:, +{ + if let Some(proof) = cyclic_recursion_data.proof { + pw.set_bool_target(cyclic_recursion_data_target.base_case, false); + pw.set_proof_with_pis_target(&cyclic_recursion_data_target.proof, proof); + pw.set_verifier_data_target( + &cyclic_recursion_data_target.verifier_data, + cyclic_recursion_data.verifier_data, + ); + pw.set_proof_with_pis_target(&cyclic_recursion_data_target.dummy_proof, proof); + pw.set_verifier_data_target( + &cyclic_recursion_data_target.dummy_verifier_data, + cyclic_recursion_data.verifier_data, + ); + } else { + let (dummy_proof, dummy_data) = dummy_proof::(cyclic_recursion_data.common_data)?; + pw.set_bool_target(cyclic_recursion_data_target.base_case, true); + let mut proof = dummy_proof.clone(); + proof.public_inputs[0..public_inputs.len()].copy_from_slice(public_inputs); + let pis_len = proof.public_inputs.len(); + // The circuit checks that the verifier data is the same throughout the cycle, so + // we set the verifier data to the "real" verifier data even though it's unused in the base case. 
+ let num_cap = cyclic_recursion_data + .common_data + .config + .fri_config + .num_cap_elements(); + let s = pis_len - 4 - 4 * num_cap; + proof.public_inputs[s..s + 4] + .copy_from_slice(&cyclic_recursion_data.verifier_data.circuit_digest.elements); + for i in 0..num_cap { + proof.public_inputs[s + 4 * (1 + i)..s + 4 * (2 + i)].copy_from_slice( + &cyclic_recursion_data.verifier_data.constants_sigmas_cap.0[i].elements, + ); + } + + pw.set_proof_with_pis_target(&cyclic_recursion_data_target.proof, &proof); + pw.set_verifier_data_target( + &cyclic_recursion_data_target.verifier_data, + cyclic_recursion_data.verifier_data, + ); + pw.set_proof_with_pis_target(&cyclic_recursion_data_target.dummy_proof, &dummy_proof); + pw.set_verifier_data_target( + &cyclic_recursion_data_target.dummy_verifier_data, + &dummy_data, + ); + } + + Ok(()) +} + +/// Additional checks to be performed on a cyclic recursive proof in addition to verifying the proof. +/// Checks that the `base_case` flag is boolean and that the purported verifier data in the public inputs +/// match the real verifier data. 
+pub fn check_cyclic_proof_verifier_data<
+    F: RichField + Extendable,
+    C: GenericConfig,
+    const D: usize,
+>(
+    proof: &ProofWithPublicInputs,
+    verifier_data: &VerifierOnlyCircuitData,
+    common_data: &CommonCircuitData,
+) -> Result<()>
+where
+    C::Hasher: AlgebraicHasher,
+{
+    // Parse the verifier data the proof exposes at the tail of its public inputs; this fails
+    // if there are too few public inputs.
+    let pis = VerifierOnlyCircuitData::::from_slice(&proof.public_inputs, common_data)?;
+    // Both parts must equal the verifier data supplied by the caller.
+    ensure!(verifier_data.constants_sigmas_cap == pis.constants_sigmas_cap);
+    ensure!(verifier_data.circuit_digest == pis.circuit_digest);
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+
+    use anyhow::Result;
+    use plonky2_field::extension::Extendable;
+    use plonky2_field::types::PrimeField64;
+
+    use crate::field::types::Field;
+    use crate::gates::noop::NoopGate;
+    use crate::hash::hash_types::RichField;
+    use crate::hash::hashing::hash_n_to_hash_no_pad;
+    use crate::hash::poseidon::{PoseidonHash, PoseidonPermutation};
+    use crate::iop::witness::PartialWitness;
+    use crate::plonk::circuit_builder::CircuitBuilder;
+    use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierCircuitTarget};
+    use crate::plonk::config::{AlgebraicHasher, GenericConfig, Hasher, PoseidonGoldilocksConfig};
+    use crate::recursion::cyclic_recursion::{
+        check_cyclic_proof_verifier_data, set_cyclic_recursion_data_target, CyclicRecursionData,
+    };
+
+    // Generates `CommonCircuitData` usable for recursion.
+    // It builds three circuits in a row: an empty one, one that verifies a proof of the first,
+    // and one that verifies a proof of the second and is padded to at least `1 << 12` gates.
+    // The common data of the last one is returned.
+    fn common_data_for_recursion<
+        F: RichField + Extendable,
+        C: GenericConfig,
+        const D: usize,
+    >() -> CommonCircuitData
+    where
+        C::Hasher: AlgebraicHasher,
+        [(); C::Hasher::HASH_SIZE]:,
+    {
+        // Layer 0: empty circuit.
+        let config = CircuitConfig::standard_recursion_config();
+        let builder = CircuitBuilder::::new(config);
+        let data = builder.build::();
+        // Layer 1: verifies a proof shaped like layer 0.
+        let config = CircuitConfig::standard_recursion_config();
+        let mut builder = CircuitBuilder::::new(config);
+        let proof = builder.add_virtual_proof_with_pis::(&data.common);
+        let verifier_data = VerifierCircuitTarget {
+            constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height),
+            circuit_digest: builder.add_virtual_hash(),
+        };
+        builder.verify_proof::(proof, &verifier_data, &data.common);
+        let data = builder.build::();
+
+        // Layer 2: verifies a proof shaped like layer 1, then pads with no-op gates.
+        let config = CircuitConfig::standard_recursion_config();
+        let mut builder = CircuitBuilder::::new(config);
+        let proof = builder.add_virtual_proof_with_pis::(&data.common);
+        let verifier_data = VerifierCircuitTarget {
+            constants_sigmas_cap: builder.add_virtual_cap(data.common.config.fri_config.cap_height),
+            circuit_digest: builder.add_virtual_hash(),
+        };
+        builder.verify_proof::(proof, &verifier_data, &data.common);
+        while builder.num_gates() < 1 << 12 {
+            builder.add_gate(NoopGate, vec![]);
+        }
+        builder.build::().common
+    }
+
+    #[test]
+    fn test_cyclic_recursion() -> Result<()> {
+        const D: usize = 2;
+        type C = PoseidonGoldilocksConfig;
+        type F = >::F;
+
+        let config = CircuitConfig::standard_recursion_config();
+        let mut pw = PartialWitness::new();
+        let mut builder = CircuitBuilder::::new(config);
+
+        // Circuit that computes a repeated hash.
+        // Public inputs registered here, in order: initial hash (4), new hash `h` (4),
+        // new counter (1). `cyclic_recursion` appends the verifier data after them.
+        let initial_hash = builder.add_virtual_hash();
+        builder.register_public_inputs(&initial_hash.elements);
+        // Hash from the previous proof.
+        let old_hash = builder.add_virtual_hash();
+        // The input hash is either the previous hash or the initial hash depending on whether
+        // the last proof was a base case.
+        let input_hash = builder.add_virtual_hash();
+        let h = builder.hash_n_to_hash_no_pad::(input_hash.elements.to_vec());
+        builder.register_public_inputs(&h.elements);
+        // Previous counter.
+        let old_counter = builder.add_virtual_target();
+        let one = builder.one();
+        let new_counter = builder.add_virtual_public_input();
+        // Targets connected to the first public inputs of the previous proof, in the same
+        // order as the public inputs registered above.
+        let old_pis = [
+            initial_hash.elements.as_slice(),
+            old_hash.elements.as_slice(),
+            [old_counter].as_slice(),
+        ]
+        .concat();
+
+        let mut common_data = common_data_for_recursion::();
+
+        let base_case = builder.add_virtual_bool_target_safe();
+        // Add cyclic recursion gadget.
+        let cyclic_data_target =
+            builder.cyclic_recursion::(base_case, &old_pis, &mut common_data)?;
+        let input_hash_bis =
+            builder.select_hash(cyclic_data_target.base_case, initial_hash, old_hash);
+        builder.connect_hashes(input_hash, input_hash_bis);
+        let not_base_case = builder.sub(one, cyclic_data_target.base_case.target);
+        // New counter is the previous counter +1 if the previous proof wasn't a base case.
+        let new_counter_bis = builder.add(old_counter, not_base_case);
+        builder.connect(new_counter, new_counter_bis);
+
+        let cyclic_circuit_data = builder.build::();
+
+        let cyclic_recursion_data = CyclicRecursionData {
+            proof: &None, // Base case: We don't have a proof to put here yet.
+            verifier_data: &cyclic_circuit_data.verifier_only,
+            common_data: &cyclic_circuit_data.common,
+        };
+        // Seed for the base case: copied into the first four public inputs of the proof
+        // placed in the "real" slot.
+        let initial_hash = [F::ZERO, F::ONE, F::TWO, F::from_canonical_usize(3)];
+        set_cyclic_recursion_data_target(
+            &mut pw,
+            &cyclic_data_target,
+            &cyclic_recursion_data,
+            &initial_hash,
+        )?;
+        let proof = cyclic_circuit_data.prove(pw)?;
+        check_cyclic_proof_verifier_data(
+            &proof,
+            cyclic_recursion_data.verifier_data,
+            cyclic_recursion_data.common_data,
+        )?;
+        cyclic_circuit_data.verify(proof.clone())?;
+
+        // 1st recursive layer.
+        let mut pw = PartialWitness::new();
+        let cyclic_recursion_data = CyclicRecursionData {
+            proof: &Some(proof), // Input previous proof.
+            verifier_data: &cyclic_circuit_data.verifier_only,
+            common_data: &cyclic_circuit_data.common,
+        };
+        // No seed is passed: `public_inputs` is only read in the base case.
+        set_cyclic_recursion_data_target(
+            &mut pw,
+            &cyclic_data_target,
+            &cyclic_recursion_data,
+            &[],
+        )?;
+        let proof = cyclic_circuit_data.prove(pw)?;
+        check_cyclic_proof_verifier_data(
+            &proof,
+            cyclic_recursion_data.verifier_data,
+            cyclic_recursion_data.common_data,
+        )?;
+        cyclic_circuit_data.verify(proof.clone())?;
+
+        // 2nd recursive layer.
+        let mut pw = PartialWitness::new();
+        let cyclic_recursion_data = CyclicRecursionData {
+            proof: &Some(proof), // Input previous proof.
+            verifier_data: &cyclic_circuit_data.verifier_only,
+            common_data: &cyclic_circuit_data.common,
+        };
+        set_cyclic_recursion_data_target(
+            &mut pw,
+            &cyclic_data_target,
+            &cyclic_recursion_data,
+            &[],
+        )?;
+        let proof = cyclic_circuit_data.prove(pw)?;
+        check_cyclic_proof_verifier_data(
+            &proof,
+            cyclic_recursion_data.verifier_data,
+            cyclic_recursion_data.common_data,
+        )?;
+
+        // Verify that the proof correctly computes a repeated hash.
+        let initial_hash = &proof.public_inputs[..4];
+        let hash = &proof.public_inputs[4..8];
+        let counter = proof.public_inputs[8];
+        let mut h: [F; 4] = initial_hash.try_into().unwrap();
+        // `nth(counter)` yields the (counter + 1)-th iterate, i.e. the hash applied
+        // `counter + 1` times to the initial hash.
+        assert_eq!(
+            hash,
+            std::iter::repeat_with(|| {
+                h = hash_n_to_hash_no_pad::(&h).elements;
+                h
+            })
+            .nth(counter.to_canonical_u64() as usize)
+            .unwrap()
+        );
+
+        cyclic_circuit_data.verify(proof)
+    }
+}
diff --git a/plonky2/src/recursion/mod.rs b/plonky2/src/recursion/mod.rs
new file mode 100644
index 00000000..33e8212e
--- /dev/null
+++ b/plonky2/src/recursion/mod.rs
@@ -0,0 +1,3 @@
+pub mod conditional_recursive_verifier;
+pub mod cyclic_recursion;
+pub mod recursive_verifier;
diff --git a/plonky2/src/plonk/recursive_verifier.rs b/plonky2/src/recursion/recursive_verifier.rs
similarity index 100%
rename from plonky2/src/plonk/recursive_verifier.rs
rename to plonky2/src/recursion/recursive_verifier.rs