From 788a57a6e74ceca4276c2739cc82f564c4788e35 Mon Sep 17 00:00:00 2001 From: thomaslavaur Date: Fri, 20 Sep 2024 10:11:08 +0200 Subject: [PATCH 01/60] Move experiments here --- Stwo_wrapper/Cargo.toml | 22 + Stwo_wrapper/LICENSE | 201 +++ Stwo_wrapper/README.md | 62 + Stwo_wrapper/Untitled1.ipynb | 429 +++++ Stwo_wrapper/WORKSPACE | 0 Stwo_wrapper/crates/prover/Cargo.toml | 111 ++ Stwo_wrapper/crates/prover/benches/README.md | 2 + Stwo_wrapper/crates/prover/benches/bit_rev.rs | 39 + .../crates/prover/benches/eval_at_point.rs | 35 + Stwo_wrapper/crates/prover/benches/fft.rs | 131 ++ Stwo_wrapper/crates/prover/benches/field.rs | 150 ++ Stwo_wrapper/crates/prover/benches/fri.rs | 35 + Stwo_wrapper/crates/prover/benches/lookups.rs | 104 ++ Stwo_wrapper/crates/prover/benches/matrix.rs | 63 + Stwo_wrapper/crates/prover/benches/merkle.rs | 38 + Stwo_wrapper/crates/prover/benches/pcs.rs | 81 + .../crates/prover/benches/poseidon.rs | 18 + .../crates/prover/benches/prefix_sum.rs | 19 + .../crates/prover/benches/quotients.rs | 55 + Stwo_wrapper/crates/prover/proof.json | 348 ++++ .../prover/src/constraint_framework/assert.rs | 84 + .../src/constraint_framework/component.rs | 210 +++ .../constraint_framework/constant_columns.rs | 37 + .../prover/src/constraint_framework/info.rs | 48 + .../prover/src/constraint_framework/logup.rs | 315 ++++ .../prover/src/constraint_framework/mod.rs | 97 ++ .../prover/src/constraint_framework/point.rs | 57 + .../src/constraint_framework/simd_domain.rs | 106 ++ .../prover/src/core/air/accumulation.rs | 297 ++++ .../crates/prover/src/core/air/components.rs | 80 + .../crates/prover/src/core/air/mask.rs | 91 ++ .../crates/prover/src/core/air/mod.rs | 76 + .../src/core/backend/cpu/accumulation.rs | 12 + .../prover/src/core/backend/cpu/blake2s.rs | 24 + .../prover/src/core/backend/cpu/circle.rs | 376 +++++ .../crates/prover/src/core/backend/cpu/fri.rs | 144 ++ .../prover/src/core/backend/cpu/grind.rs | 18 + .../src/core/backend/cpu/lookups/gkr.rs | 
448 ++++++ .../src/core/backend/cpu/lookups/mle.rs | 66 + .../src/core/backend/cpu/lookups/mod.rs | 2 + .../crates/prover/src/core/backend/cpu/mod.rs | 105 ++ .../src/core/backend/cpu/poseidon252.rs | 24 + .../src/core/backend/cpu/poseidon_bls.rs | 24 + .../prover/src/core/backend/cpu/quotients.rs | 210 +++ .../crates/prover/src/core/backend/mod.rs | 66 + .../src/core/backend/simd/accumulation.rs | 12 + .../src/core/backend/simd/bit_reverse.rs | 203 +++ .../prover/src/core/backend/simd/blake2s.rs | 412 +++++ .../prover/src/core/backend/simd/circle.rs | 436 +++++ .../prover/src/core/backend/simd/cm31.rs | 230 +++ .../prover/src/core/backend/simd/column.rs | 656 ++++++++ .../prover/src/core/backend/simd/domain.rs | 86 + .../prover/src/core/backend/simd/fft/ifft.rs | 712 ++++++++ .../prover/src/core/backend/simd/fft/mod.rs | 120 ++ .../prover/src/core/backend/simd/fft/rfft.rs | 742 +++++++++ .../prover/src/core/backend/simd/fri.rs | 261 +++ .../prover/src/core/backend/simd/grind.rs | 95 ++ .../src/core/backend/simd/lookups/gkr.rs | 684 ++++++++ .../src/core/backend/simd/lookups/mle.rs | 132 ++ .../src/core/backend/simd/lookups/mod.rs | 2 + .../prover/src/core/backend/simd/m31.rs | 666 ++++++++ .../prover/src/core/backend/simd/mod.rs | 41 + .../src/core/backend/simd/poseidon252.rs | 36 + .../src/core/backend/simd/poseidon_bls.rs | 36 + .../src/core/backend/simd/prefix_sum.rs | 188 +++ .../prover/src/core/backend/simd/qm31.rs | 357 +++++ .../prover/src/core/backend/simd/quotients.rs | 314 ++++ .../prover/src/core/backend/simd/utils.rs | 52 + .../src/core/backend/simd/very_packed_m31.rs | 222 +++ .../crates/prover/src/core/channel/blake2s.rs | 186 +++ .../crates/prover/src/core/channel/mod.rs | 57 + .../prover/src/core/channel/poseidon252.rs | 190 +++ .../prover/src/core/channel/poseidon_bls.rs | 590 +++++++ Stwo_wrapper/crates/prover/src/core/circle.rs | 561 +++++++ .../crates/prover/src/core/constraints.rs | 251 +++ Stwo_wrapper/crates/prover/src/core/fft.rs | 21 + 
.../crates/prover/src/core/fields/cm31.rs | 137 ++ .../crates/prover/src/core/fields/m31.rs | 258 +++ .../crates/prover/src/core/fields/mod.rs | 489 ++++++ .../crates/prover/src/core/fields/qm31.rs | 195 +++ .../prover/src/core/fields/secure_column.rs | 111 ++ Stwo_wrapper/crates/prover/src/core/fri.rs | 1425 +++++++++++++++++ .../prover/src/core/lookups/gkr_prover.rs | 566 +++++++ .../prover/src/core/lookups/gkr_verifier.rs | 357 +++++ .../crates/prover/src/core/lookups/mle.rs | 106 ++ .../crates/prover/src/core/lookups/mod.rs | 5 + .../prover/src/core/lookups/sumcheck.rs | 292 ++++ .../crates/prover/src/core/lookups/utils.rs | 356 ++++ Stwo_wrapper/crates/prover/src/core/mod.rs | 59 + .../crates/prover/src/core/pcs/mod.rs | 40 + .../crates/prover/src/core/pcs/prover.rs | 256 +++ .../crates/prover/src/core/pcs/quotients.rs | 218 +++ .../crates/prover/src/core/pcs/utils.rs | 158 ++ .../crates/prover/src/core/pcs/verifier.rs | 132 ++ .../prover/src/core/poly/circle/canonic.rs | 77 + .../prover/src/core/poly/circle/domain.rs | 188 +++ .../prover/src/core/poly/circle/evaluation.rs | 218 +++ .../crates/prover/src/core/poly/circle/mod.rs | 56 + .../crates/prover/src/core/poly/circle/ops.rs | 48 + .../prover/src/core/poly/circle/poly.rs | 118 ++ .../src/core/poly/circle/secure_poly.rs | 118 ++ .../crates/prover/src/core/poly/line.rs | 408 +++++ .../crates/prover/src/core/poly/mod.rs | 14 + .../crates/prover/src/core/poly/twiddles.rs | 13 + .../crates/prover/src/core/poly/utils.rs | 115 ++ .../crates/prover/src/core/proof_of_work.rs | 7 + .../crates/prover/src/core/prover/mod.rs | 186 +++ .../crates/prover/src/core/queries.rs | 237 +++ .../crates/prover/src/core/test_utils.rs | 17 + Stwo_wrapper/crates/prover/src/core/utils.rs | 327 ++++ .../crates/prover/src/core/vcs/blake2_hash.rs | 139 ++ .../prover/src/core/vcs/blake2_merkle.rs | 148 ++ .../crates/prover/src/core/vcs/blake2s_ref.rs | 217 +++ .../crates/prover/src/core/vcs/blake3_hash.rs | 132 ++ 
.../crates/prover/src/core/vcs/hash.rs | 15 + .../crates/prover/src/core/vcs/mod.rs | 20 + .../crates/prover/src/core/vcs/ops.rs | 47 + .../prover/src/core/vcs/poseidon252_merkle.rs | 182 +++ .../src/core/vcs/poseidon_bls_merkle.rs | 581 +++++++ .../crates/prover/src/core/vcs/prover.rs | 223 +++ .../crates/prover/src/core/vcs/test_utils.rs | 60 + .../crates/prover/src/core/vcs/utils.rs | 20 + .../crates/prover/src/core/vcs/verifier.rs | 194 +++ .../crates/prover/src/examples/blake/air.rs | 483 ++++++ .../crates/prover/src/examples/blake/mod.rs | 126 ++ .../src/examples/blake/round/constraints.rs | 164 ++ .../prover/src/examples/blake/round/gen.rs | 281 ++++ .../prover/src/examples/blake/round/mod.rs | 110 ++ .../examples/blake/scheduler/constraints.rs | 64 + .../src/examples/blake/scheduler/gen.rs | 171 ++ .../src/examples/blake/scheduler/mod.rs | 106 ++ .../examples/blake/xor_table/constraints.rs | 52 + .../src/examples/blake/xor_table/gen.rs | 168 ++ .../src/examples/blake/xor_table/mod.rs | 158 ++ .../crates/prover/src/examples/mod.rs | 5 + .../crates/prover/src/examples/plonk/mod.rs | 300 ++++ .../prover/src/examples/poseidon/mod.rs | 508 ++++++ .../prover/src/examples/wide_fibonacci/mod.rs | 619 +++++++ .../examples/xor/gkr_lookups/accumulation.rs | 186 +++ .../src/examples/xor/gkr_lookups/mle_eval.rs | 571 +++++++ .../src/examples/xor/gkr_lookups/mod.rs | 2 + .../crates/prover/src/examples/xor/mod.rs | 1 + Stwo_wrapper/crates/prover/src/lib.rs | 23 + Stwo_wrapper/crates/prover/src/math/matrix.rs | 67 + Stwo_wrapper/crates/prover/src/math/mod.rs | 2 + Stwo_wrapper/crates/prover/src/math/utils.rs | 24 + Stwo_wrapper/poseidon_benchmark.sh | 3 + Stwo_wrapper/resources/img/logo.png | Bin 0 -> 19360 bytes Stwo_wrapper/rust-toolchain.toml | 2 + Stwo_wrapper/rustfmt.toml | 12 + Stwo_wrapper/scripts/bench.sh | 5 + Stwo_wrapper/scripts/clippy.sh | 3 + Stwo_wrapper/scripts/rust_fmt.sh | 3 + Stwo_wrapper/scripts/test_avx.sh | 4 + 154 files changed, 27019 insertions(+) 
create mode 100644 Stwo_wrapper/Cargo.toml create mode 100644 Stwo_wrapper/LICENSE create mode 100644 Stwo_wrapper/README.md create mode 100644 Stwo_wrapper/Untitled1.ipynb create mode 100644 Stwo_wrapper/WORKSPACE create mode 100644 Stwo_wrapper/crates/prover/Cargo.toml create mode 100644 Stwo_wrapper/crates/prover/benches/README.md create mode 100644 Stwo_wrapper/crates/prover/benches/bit_rev.rs create mode 100644 Stwo_wrapper/crates/prover/benches/eval_at_point.rs create mode 100644 Stwo_wrapper/crates/prover/benches/fft.rs create mode 100644 Stwo_wrapper/crates/prover/benches/field.rs create mode 100644 Stwo_wrapper/crates/prover/benches/fri.rs create mode 100644 Stwo_wrapper/crates/prover/benches/lookups.rs create mode 100644 Stwo_wrapper/crates/prover/benches/matrix.rs create mode 100644 Stwo_wrapper/crates/prover/benches/merkle.rs create mode 100644 Stwo_wrapper/crates/prover/benches/pcs.rs create mode 100644 Stwo_wrapper/crates/prover/benches/poseidon.rs create mode 100644 Stwo_wrapper/crates/prover/benches/prefix_sum.rs create mode 100644 Stwo_wrapper/crates/prover/benches/quotients.rs create mode 100644 Stwo_wrapper/crates/prover/proof.json create mode 100644 Stwo_wrapper/crates/prover/src/constraint_framework/assert.rs create mode 100644 Stwo_wrapper/crates/prover/src/constraint_framework/component.rs create mode 100644 Stwo_wrapper/crates/prover/src/constraint_framework/constant_columns.rs create mode 100644 Stwo_wrapper/crates/prover/src/constraint_framework/info.rs create mode 100644 Stwo_wrapper/crates/prover/src/constraint_framework/logup.rs create mode 100644 Stwo_wrapper/crates/prover/src/constraint_framework/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/constraint_framework/point.rs create mode 100644 Stwo_wrapper/crates/prover/src/constraint_framework/simd_domain.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/air/accumulation.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/air/components.rs create mode 100644 
Stwo_wrapper/crates/prover/src/core/air/mask.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/air/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/accumulation.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/blake2s.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/circle.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/fri.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/grind.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/gkr.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mle.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon252.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon_bls.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/cpu/quotients.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/accumulation.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/bit_reverse.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/blake2s.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/circle.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/cm31.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/column.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/domain.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/fft/ifft.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/fft/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/fft/rfft.rs create mode 100644 
Stwo_wrapper/crates/prover/src/core/backend/simd/fri.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/grind.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/gkr.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mle.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/m31.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon252.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon_bls.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/prefix_sum.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/qm31.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/quotients.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/utils.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/backend/simd/very_packed_m31.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/channel/blake2s.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/channel/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/channel/poseidon252.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/channel/poseidon_bls.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/circle.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/constraints.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/fft.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/fields/cm31.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/fields/m31.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/fields/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/fields/qm31.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/fields/secure_column.rs create mode 100644 
Stwo_wrapper/crates/prover/src/core/fri.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/lookups/gkr_prover.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/lookups/gkr_verifier.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/lookups/mle.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/lookups/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/lookups/sumcheck.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/lookups/utils.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/pcs/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/pcs/prover.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/pcs/utils.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/circle/canonic.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/circle/domain.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/circle/evaluation.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/circle/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/circle/ops.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/circle/poly.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/circle/secure_poly.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/line.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/twiddles.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/poly/utils.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/proof_of_work.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/prover/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/queries.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/test_utils.rs create mode 
100644 Stwo_wrapper/crates/prover/src/core/utils.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/blake2_hash.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/blake2_merkle.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/blake2s_ref.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/blake3_hash.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/hash.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/ops.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/poseidon252_merkle.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/poseidon_bls_merkle.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/prover.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/test_utils.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/utils.rs create mode 100644 Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/air.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/round/constraints.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/round/gen.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/round/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/scheduler/constraints.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/scheduler/gen.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/scheduler/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/xor_table/constraints.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/xor_table/gen.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/blake/xor_table/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/mod.rs create mode 100644 
Stwo_wrapper/crates/prover/src/examples/plonk/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/poseidon/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/examples/xor/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/lib.rs create mode 100644 Stwo_wrapper/crates/prover/src/math/matrix.rs create mode 100644 Stwo_wrapper/crates/prover/src/math/mod.rs create mode 100644 Stwo_wrapper/crates/prover/src/math/utils.rs create mode 100755 Stwo_wrapper/poseidon_benchmark.sh create mode 100644 Stwo_wrapper/resources/img/logo.png create mode 100644 Stwo_wrapper/rust-toolchain.toml create mode 100644 Stwo_wrapper/rustfmt.toml create mode 100755 Stwo_wrapper/scripts/bench.sh create mode 100755 Stwo_wrapper/scripts/clippy.sh create mode 100755 Stwo_wrapper/scripts/rust_fmt.sh create mode 100755 Stwo_wrapper/scripts/test_avx.sh diff --git a/Stwo_wrapper/Cargo.toml b/Stwo_wrapper/Cargo.toml new file mode 100644 index 0000000..0f314a4 --- /dev/null +++ b/Stwo_wrapper/Cargo.toml @@ -0,0 +1,22 @@ +[workspace] +members = ["crates/prover"] +resolver = "2" + +[workspace.package] +version = "0.1.1" +edition = "2021" + +[workspace.dependencies] +blake2 = "0.10.6" +blake3 = "1.5.0" +educe = "0.5.0" +hex = "0.4.3" +itertools = "0.12.0" +num-traits = "0.2.17" +thiserror = "1.0.56" +bytemuck = "1.14.3" +tracing = "0.1.40" + +[profile.bench] +codegen-units = 1 +lto = true diff --git a/Stwo_wrapper/LICENSE b/Stwo_wrapper/LICENSE new file mode 100644 index 0000000..2e0cecd --- /dev/null +++ b/Stwo_wrapper/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND 
CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2024 StarkWare Industries Ltd. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/Stwo_wrapper/README.md b/Stwo_wrapper/README.md new file mode 100644 index 0000000..ece38cc --- /dev/null +++ b/Stwo_wrapper/README.md @@ -0,0 +1,62 @@ +
+ +![STWO](resources/img/logo.png) + +GitHub Workflow Status (with event) + + + +Project license +StarkWare +
+ +
+

+ + Paper + + | + + Documentation + + | + + Benchmarks + +

+
+ +# Stwo + +## 🌟 About + +Stwo is a next generation implementation of a [CSTARK](https://eprint.iacr.org/2024/278) prover and verifier, written in Rust 🦀. + +> **Stwo is a work in progress.** +> +> It is not recommended to use it in a production setting yet. + +## 🚀 Key Features + +- **Circle STARKs:** Based on the latest cryptographic research and innovations in the ZK field. +- **High performance:** Stwo is designed to be extremely fast and efficient. +- **Flexible:** Adaptable for various validity proof applications. + +## 📊 Benchmarks + +Run `poseidon_benchmark.sh` to run a single-threaded poseidon2 hash proof benchmark. + +Further benchmarks can be run using `cargo bench`. + +Visual representation of benchmarks can be found [here](https://starkware-libs.github.io/stwo/dev/bench/index.html). + +## 📜 License + +This project is licensed under the **Apache 2.0 license**. + +See [LICENSE](LICENSE) for more information. + + + + + diff --git a/Stwo_wrapper/Untitled1.ipynb b/Stwo_wrapper/Untitled1.ipynb new file mode 100644 index 0000000..e504f91 --- /dev/null +++ b/Stwo_wrapper/Untitled1.ipynb @@ -0,0 +1,429 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "43f7fdc0", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "with open(\"crates/prover/proof.json\", \"r\") as f:\n", + " proof = json.load(f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88f232b5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "12ef3767", + "metadata": {}, + "outputs": [], + "source": [ + "F = FiniteField(52435875175126190479447740508185965837690552500527637822603658699938581184513)\n", + "\n", + "poseidon_comp_consts = [50207570499218320245539736680169582180207201335688461025883902752909290481781,\n", + "24448666467656506447555018649749346340705294023832615387641453784702583464707,\n", + 
"34092944507611308604157957266676007619644244199372265837364557849561670729974,\n", + "46954129210702959446093971191783182601726081775951103310666314834569091037713,\n", + "38612156878839717097806285947575477749087608521505464809942918879152074545066,\n", + "19752610610343814834081989345964253902282700341539483876504601969121084774539,\n", + "46567545048462867923299713424766325689670511126407629551256255807498976196546,\n", + "9520793415506326549109545537894287560752519598132096386048093015534488804808,\n", + "22814234098357034097599682726494820560934925862581927123816510593532324971186,\n", + "3277621627834606517208177071759088097855048183641615082769528872043050020787,\n", + "29230456498980145088774069819561206654397510279226264474986155631775387918911,\n", + "19087113294497892618475669593723876605785307026981218038380435259594863105240,\n", + "39932371919358015185769877859035474336011770016475087638554815294278664040916,\n", + "17645770319151120318035258350885823104235488352935695302274836429012504407725,\n", + "17990728141399065004015538797609951295983853332644474801890158217822768128628,\n", + "12607949331462269429981198199999740921418125994747028428126661151190418292729,\n", + "33067617079394435172767143524489677593390850035349407507374659268468278200906,\n", + "10025233623562179533044093426455032352895184661359005809314430689113735312874,\n", + "20398677688057466110325934731430812468657996794663167456321709689030080949228,\n", + "32085671199853825909918260218834827339732598508827083525700252644622592932757,\n", + "36451986593067827349794003109666944974266236856145879921902940325507228739480,\n", + "51835224419566813714481533481210630888564327175625175437244377303858990291964,\n", + "1944662263588038198375346521900053780907777056656211622999059135594196413076,\n", + "12995068374816903282074967132431954020410301768622808407703775963080983755183,\n", + "13278128079226679628648689279705910775020794457648431336050464485837924986341,\n", + 
"39207195481789228835625472428521288347432218258431761869689775532020546099642,\n", + "21081768833381902942114733002158882075348844281359283013642620389621494952015,\n", + "20751788049060260683191405008569080723662271828149227137187075968560831545739,\n", + "20820291785607398388900832350860967875629907105847554413318238165275470374689,\n", + "6971878585215744613467847324629115462668098071102846520957717612260531709386,\n", + "42421164250058173810994728364144776180689735894673627964404703973460802099146,\n", + "32890116643831560295329417521056875595733120141391587236744387068135440602102,\n", + "42670005614507618780436482775021159957307712089310941922452133588875084445464,\n", + "21120353743307986506720883740380468652053382764895882204680310593048134053982,\n", + "7853308243263055176258751393326645428041138029306706980470113526802326214700,\n", + "17545076036297840030021082424260289805456380863517895917265467158332801090765,\n", + "29526223376722400691172584788126610514669516909826971155598997488361793726636,\n", + "48421712782536172546302502401679048379568171245541707202282458591545347755349,\n", + "10740853637774754893036062076749871837371049036966225040269105665447180116170,\n", + "34042041521558704677804677569712674569738576001717295340556848855085089618161,\n", + "24290796201833228559129233924595614281891670608675107544294264860003803501509,\n", + "26722678647461522072509896114724736555938247563993442152746954157222882824350,\n", + "20252491387019425681551488261397157776479297799360691728406809731508542196845,\n", + "50322025264206689090790987370440439179141270613911973034521438238687587958097,\n", + "17070806525931584028449131949070191143344166668070820337429561524629464200550,\n", + "25856554324149146992239414502939942208580094928192925471532421030223074525051,\n", + "17714998974036855356530338446243137421735047395517260588250413348153258772076,\n", + "44833315250334176776685835079382312848180252180173884969157994737319426976437,\n", + 
"35603718839327251012037553292043899153393807438387129505923567878785822738162,\n", + "20515196301761603016197694845695272699608637099106794944737311528118558777570,\n", + "10100400556460905874275078234698187530913105549037797180493988678937053918124,\n", + "29943022708270799252522211109308629054849337552699067311814388215768905671554,\n", + "33400164627534996188947689774080657908147988421361870074239537729877153299092,\n", + "45574161704098228712016716221086232277248798839906622903502141601878895917316,\n", + "40623265267364613450776577487319920007897396936924051398790906883872334022964,\n", + "37929176440858430683261948300797278761072096845318183419284347376614069989808,\n", + "12242010394227909997626655999345208835040087302065045201635069094289920778463,\n", + "38947272924417356803622776795797899233194116520680026665045628837194239730633,\n", + "6838505804652359252670794375725267665530548946030641535297433541475260948424,\n", + "21345718918993308853491352363460625447157796362108157527364130872100101143328,\n", + "26397988737034501095129796920971941795766209722106383463197090306632188634870,\n", + "47092791129593573928369881528796435131623991381197863072979392492232678100884,\n", + "36850972241154890671857874025605504779963735054128436776319531005864791472123,\n", + "27893799443241349360688137159923920340185830261519093384488134540544971987330,\n", + "34031071010517479317003393843135868322188010660871691856659878788331169912272,\n", + "3102550735908358465878301372253437950829524988677083749179431098369388780259,\n", + "2963742902601529003553690631564645593518709846059084207036841793643477514707,\n", + "34538583661636382515652368664945657625216404085453317149263146639486246251503,\n", + "49179786922858759927440465310900376749726765337268308911471491527044937447403,\n", + "31668552784983283483593666924944066737680315058069542500069213700768949573692,\n", + "47303630019147536941220901582952982856517915740884282232588733470564849742080,\n", + 
"41561182787858915334837446901194440640033856888621022207410120224293681204923,\n", + "40208795410444394963490428737133513683110766973508056822474493355065333491217,\n", + "24620569969402072776192280888011017497854992833864712509770555543278833718751,\n", + "31418811028946653724823259636547682581071379929451162101915628592655152015310,\n", + "25964807298150242099204032696543021731332498792173212422070959505270506288817,\n", + "31766013031271106581980804902159064978010553325475976472264348555438361464655,\n", + "15107529391758643095716794813038523751713309080738989300826699946985294497278,\n", + "26149402682269665088314773514719203730233986608723938665192802061570851149320,\n", + "35053126320072620250684851851709987160095640397875384355477447570643983599564]\n", + "\n", + "def mix(state):\n", + " state[0] = state[0] + state[1] + state[2];\n", + " state[1] = state[0] + state[1];\n", + " state[2] = state[0] + state[2];\n", + " return state\n", + "\n", + "def round_comp(state, idx, full):\n", + " if full:\n", + " state[0] += poseidon_comp_consts[idx]\n", + " state[1] += poseidon_comp_consts[idx + 1]\n", + " state[2] += poseidon_comp_consts[idx + 2]\n", + " # Optimize multiplication\n", + " state[0] = state[0] * state[0] * state[0] * state[0] * state[0]\n", + " state[1] = state[1] * state[1] * state[1] * state[1] * state[1]\n", + " state[2] = state[2] * state[2] * state[2] * state[2] * state[2]\n", + " else:\n", + " state[0] += poseidon_comp_consts[idx]\n", + " state[2] = state[2] * state[2] * state[2] * state[2] * state[2]\n", + " state = mix(state)\n", + " return state\n", + "\n", + "def poseidon_permute_comp_bls(state):\n", + " idx = 0;\n", + " state = mix(state);\n", + "\n", + " # Full rounds\n", + " for i in range(4):\n", + " state = round_comp(state, idx, true);\n", + " idx += 3;\n", + "\n", + " # Partial rounds\n", + " for i in range(56):\n", + " state = round_comp(state, idx, false);\n", + " idx += 1;\n", + "\n", + " # Full rounds\n", + " for i in range(4):\n", 
+ " state= round_comp(state, idx, true);\n", + " idx += 3;\n", + " return state\n", + "\n", + " \n", + "def poseidon_hash_bls(x, y):\n", + " state = [F(x), F(y), F(0)];\n", + " poseidon_permute_comp_bls(state);\n", + " return state[0] + F(x)\n", + "\n", + "\n", + "def poseidon_hash_many_bls(msgs):\n", + " state = [F(0), F(0), F(0)];\n", + " for i in range(len(msgs)//2):\n", + " state[0] += F(msgs[2*i])\n", + " state[1] += F(msgs[2*i + 1])\n", + " state = poseidon_permute_comp_bls(state)\n", + " if len(msgs) % 2 == 1:\n", + " state[0] += F(msgs[len(msgs)-1])\n", + " state[len(msgs)%2] += F(1);\n", + " state = poseidon_permute_comp_bls(state)\n", + " return state[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "67a09953", + "metadata": {}, + "outputs": [], + "source": [ + "def draw_felt252(digest, n_sent):\n", + " res = poseidon_hash_bls(digest, F(n_sent));\n", + " return res\n", + "\n", + "def draw_base_felts(digest, n_sent):\n", + " shift = 1 << 31\n", + " cur = int(draw_felt252(digest, n_sent))\n", + " n_sent += 1\n", + " u32s = []\n", + " quotient = 1\n", + " while(quotient != 0):\n", + " quotient = cur // shift\n", + " if quotient != 0:\n", + " remainder = cur % shift\n", + " cur = quotient\n", + " u32s.append(M31(remainder))\n", + " return u32s" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1bce81d8", + "metadata": {}, + "outputs": [], + "source": [ + "M31 = FiniteField(2**31 - 1)\n", + "Poly_1. = PolynomialRing(M31)\n", + "poly_1 = x**2+1\n", + "CM31. = M31.extension(poly_1)\n", + "CM31\n", + "\n", + "Poly_2. = PolynomialRing(CM31)\n", + "poly_2 = y**2 - 2 - i\n", + "QM31. 
= CM31.extension(poly_2)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "16aa579c", + "metadata": {}, + "outputs": [], + "source": [ + "# A class to represent points on a circle using fields.\n", + "class CirclePoint:\n", + " def __init__(self, x, y):\n", + " self.x = x\n", + " self.y = y\n", + "\n", + " @staticmethod\n", + " def zero():\n", + " \"\"\"Returns the identity element of the circle.\"\"\"\n", + " return CirclePoint(1, 0)\n", + "\n", + " def double(self):\n", + " \"\"\"Returns the point after doubling.\"\"\"\n", + " return self + self\n", + "\n", + " @staticmethod\n", + " def double_x(x):\n", + " \"\"\"Applies the x-coordinate doubling map.\"\"\"\n", + " sx = x**2\n", + " return sx + sx - 1\n", + "\n", + " def log_order(self):\n", + " \"\"\"Returns the log order of the point (as a power of 2).\"\"\"\n", + " res = 0\n", + " cur = self.x\n", + " while cur != 1:\n", + " cur = self.double_x(cur)\n", + " res += 1\n", + " return res\n", + "\n", + " def mul(self, scalar):\n", + " \"\"\"Multiplies the point by a scalar.\"\"\"\n", + " res = CirclePoint.zero()\n", + " cur = self\n", + " while scalar > 0:\n", + " if scalar & 1 == 1:\n", + " res = res + cur\n", + " cur = cur + cur\n", + " scalar = int(scalar / 2)\n", + " return res\n", + "\n", + " def repeated_double(self, n):\n", + " \"\"\"Returns the point after repeated doubling.\"\"\"\n", + " res = self\n", + " for _ in range(n):\n", + " res = res.double()\n", + " return res\n", + "\n", + " def conjugate(self):\n", + " \"\"\"Returns the conjugate of the point.\"\"\"\n", + " return CirclePoint(self.x, -self.y)\n", + "\n", + " def antipode(self):\n", + " \"\"\"Returns the antipode of the point.\"\"\"\n", + " return CirclePoint(-self.x, -self.y)\n", + "\n", + " def mul_signed(self, off):\n", + " \"\"\"Multiplies the point by a signed scalar.\"\"\"\n", + " if off > 0:\n", + " return self.mul(off)\n", + " else:\n", + " return self.conjugate().mul(-off)\n", + "\n", + " def __add__(self, 
other):\n", + " \"\"\"Adds two circle points.\"\"\"\n", + " return CirclePoint(self.x * other.x - self.y * other.y, self.x * other.y + self.y * other.x)\n", + "\n", + " def __repr__(self):\n", + " return f\"CirclePoint(x={self.x}, y={self.y})\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "d9371b02", + "metadata": {}, + "outputs": [], + "source": [ + "# First Fiat-Shamir\n", + "digest = poseidon_hash_bls(0,proof[\"commitments\"][0])\n", + "\n", + "# Draw first random coeff\n", + "random_coeff = QM31([draw_base_felts(digest,0)[:2],draw_base_felts(digest,0)[2:4]])\n", + "\n", + "# Second Fiat-Shamir\n", + "digest = poseidon_hash_bls(digest,proof[\"commitments\"][1])\n", + "\n", + "# Draw OODS sample point\n", + "t = QM31([draw_base_felts(digest,0)[:2],draw_base_felts(digest,0)[2:4]])\n", + "t_square = t**2\n", + "one_plus_tsquared_inv = (t_square + QM31(1)) ** (-1)\n", + "x = (1 - t_square) * one_plus_tsquared_inv\n", + "y = 2*t*one_plus_tsquared_inv\n", + "oods_point = CirclePoint(x,y)\n", + "\n", + "# Derive sample points for now offset is always 0 so sampled points are all equals oods_point\n", + "\n", + "# Load sampled_values_1 and verify that they are in QM31\n", + "\n", + "# Then we verify the composition polynomial evaluation:\n", + "point = oods_point\n", + "mask_values = proof[\"sampled_values_0\"] + proof[\"sampled_values_1\"]\n", + "\n", + "accumulator = random_coeff\n", + "\n", + "# Evaluate random point on the vanishing polynomial of the coset\n", + "evaluation = point.x\n", + "for i in range(5):\n", + " evaluation = CirclePoint.double_x(evaluation)\n", + "evaluation_inverse = evaluation ** (-1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3adb682e", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "59b44767", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1001989877*i + 1100649138)*u + 462165474*i + 673026348" + 
] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluation_inverse" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "7790a0f9", + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'PolynomialQuotientRing_field_with_category.element_class' object has no attribute 'double'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/tmp/ipykernel_10017/940279797.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[0ma\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mQM31\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mb\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mQM31\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2129160320\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1109509513\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m787887008\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1676461964\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0ma\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdouble\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdouble\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;32m/usr/lib/python3/dist-packages/sage/structure/element.pyx\u001b[0m in \u001b[0;36msage.structure.element.Element.__getattr__ (build/cythonized/sage/structure/element.c:4827)\u001b[1;34m()\u001b[0m\n\u001b[0;32m 492\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;34m'LeftZeroSemigroup_with_category.element_class'\u001b[0m \u001b[0mobject\u001b[0m \u001b[0mhas\u001b[0m \u001b[0mno\u001b[0m \u001b[0mattribute\u001b[0m \u001b[1;34m'blah_blah'\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 493\u001b[0m \"\"\"\n\u001b[1;32m--> 494\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgetattr_from_category\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 495\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 496\u001b[0m 
\u001b[0mcdef\u001b[0m \u001b[0mgetattr_from_category\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m/usr/lib/python3/dist-packages/sage/structure/element.pyx\u001b[0m in \u001b[0;36msage.structure.element.Element.getattr_from_category (build/cythonized/sage/structure/element.c:4939)\u001b[1;34m()\u001b[0m\n\u001b[0;32m 505\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 506\u001b[0m \u001b[0mcls\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mP\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_abstract_element_class\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 507\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mgetattr_from_other_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcls\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 508\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 509\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__dir__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m/usr/lib/python3/dist-packages/sage/cpython/getattr.pyx\u001b[0m in \u001b[0;36msage.cpython.getattr.getattr_from_other_class (build/cythonized/sage/cpython/getattr.c:2636)\u001b[1;34m()\u001b[0m\n\u001b[0;32m 354\u001b[0m \u001b[0mdummy_error_message\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcls\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 355\u001b[0m \u001b[0mdummy_error_message\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mname\u001b[0m 
\u001b[1;33m=\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 356\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdummy_error_message\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 357\u001b[0m \u001b[0mcdef\u001b[0m \u001b[0mPyObject\u001b[0m\u001b[1;33m*\u001b[0m \u001b[0mattr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minstance_getattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 358\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mattr\u001b[0m \u001b[1;32mis\u001b[0m \u001b[0mNULL\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mAttributeError\u001b[0m: 'PolynomialQuotientRing_field_with_category.element_class' object has no attribute 'double'" + ] + } + ], + "source": [ + "a = QM31([[1,0],[0,0]])\n", + "b = QM31([[2129160320,1109509513],[787887008,1676461964]])\n", + "a.double() + b.double()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f95a975", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "[[\"1\",\"0\",\"0\",\"0\"],\n", + "\t\t[\"2129160320\",\"1109509513\",\"787887008\",\"1676461964\"],\n", + "\t\t[\"262908602\",\"915488457\",\"1893945291\",\"1774327476\"]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "SageMath 9.5", + "language": "sage", + "name": "sagemath" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Stwo_wrapper/WORKSPACE b/Stwo_wrapper/WORKSPACE new file mode 100644 index 
0000000..e69de29 diff --git a/Stwo_wrapper/crates/prover/Cargo.toml b/Stwo_wrapper/crates/prover/Cargo.toml new file mode 100644 index 0000000..f316cf7 --- /dev/null +++ b/Stwo_wrapper/crates/prover/Cargo.toml @@ -0,0 +1,111 @@ +[package] +name = "stwo-prover" +version.workspace = true +edition.workspace = true + +[features] +parallel = ["rayon"] +slow-tests = [] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +blake2.workspace = true +blake3.workspace = true +bytemuck = { workspace = true, features = ["derive", "extern_crate_alloc"] } +cfg-if = "1.0.0" +downcast-rs = "1.2" +educe.workspace = true +hex.workspace = true +itertools.workspace = true +num-traits.workspace = true +rand = { version = "0.8.5", default-features = false, features = ["small_rng"] } +starknet-crypto = "0.6.2" +starknet-ff = "0.3.7" +ark-bls12-381 = "0.4.0" +ark-ff = "0.4.0" +thiserror.workspace = true +tracing.workspace = true +rayon = { version = "1.10.0", optional = true } +serde = { version = "1.0", features = ["derive"] } +light-poseidon = {path = "../../../../../save/rust/snarks/light-poseidon/light-poseidon"} +crypto-bigint = "0.5.5" +ark-serialize = "0.4.0" +serde_json = "1.0.116" + +[dev-dependencies] +aligned = "0.4.2" +test-log = { version = "0.2.15", features = ["trace"] } +tracing-subscriber = "0.3.18" +[target.'cfg(all(target_family = "wasm", not(target_os = "wasi")))'.dev-dependencies] +wasm-bindgen-test = "0.3.43" + +[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies.criterion] +features = ["html_reports"] +version = "0.5.1" + +# Default features cause compile error: +# "Rayon cannot be used when targeting wasi32. Try disabling default features." 
+[target.'cfg(target_arch = "wasm32")'.dev-dependencies.criterion] +default-features = false +features = ["html_reports"] +version = "0.5.1" + +[lib] +bench = false +crate-type = ["cdylib", "lib"] + +[lints.rust] +warnings = "deny" +future-incompatible = "deny" +nonstandard-style = "deny" +rust-2018-idioms = "deny" +unused = "deny" + +[[bench]] +harness = false +name = "bit_rev" + +[[bench]] +harness = false +name = "eval_at_point" + +[[bench]] +harness = false +name = "fft" + +[[bench]] +harness = false +name = "field" + +[[bench]] +harness = false +name = "fri" + +[[bench]] +harness = false +name = "lookups" + +[[bench]] +harness = false +name = "matrix" + +[[bench]] +harness = false +name = "merkle" + +[[bench]] +harness = false +name = "poseidon" + +[[bench]] +harness = false +name = "prefix_sum" + +[[bench]] +harness = false +name = "quotients" + +[[bench]] +harness = false +name = "pcs" diff --git a/Stwo_wrapper/crates/prover/benches/README.md b/Stwo_wrapper/crates/prover/benches/README.md new file mode 100644 index 0000000..8e6d73f --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/README.md @@ -0,0 +1,2 @@ +dev benchmark results can be seen at +https://starkware-libs.github.io/stwo/dev/bench/index.html diff --git a/Stwo_wrapper/crates/prover/benches/bit_rev.rs b/Stwo_wrapper/crates/prover/benches/bit_rev.rs new file mode 100644 index 0000000..6e287e6 --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/bit_rev.rs @@ -0,0 +1,39 @@ +#![feature(iter_array_chunks)] + +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use itertools::Itertools; +use stwo_prover::core::fields::m31::BaseField; + +pub fn cpu_bit_rev(c: &mut Criterion) { + use stwo_prover::core::utils::bit_reverse; + // TODO(andrew): Consider using same size for all. 
+ const SIZE: usize = 1 << 24; + let data = (0..SIZE).map(BaseField::from).collect_vec(); + c.bench_function("cpu bit_rev 24bit", |b| { + b.iter_batched( + || data.clone(), + |mut data| bit_reverse(&mut data), + BatchSize::LargeInput, + ); + }); +} + +pub fn simd_bit_rev(c: &mut Criterion) { + use stwo_prover::core::backend::simd::bit_reverse::bit_reverse_m31; + use stwo_prover::core::backend::simd::column::BaseColumn; + const SIZE: usize = 1 << 26; + let data = (0..SIZE).map(BaseField::from).collect::(); + c.bench_function("simd bit_rev 26bit", |b| { + b.iter_batched( + || data.data.clone(), + |mut data| bit_reverse_m31(&mut data), + BatchSize::LargeInput, + ); + }); +} + +criterion_group!( + name = bit_rev; + config = Criterion::default().sample_size(10); + targets = simd_bit_rev, cpu_bit_rev); +criterion_main!(bit_rev); diff --git a/Stwo_wrapper/crates/prover/benches/eval_at_point.rs b/Stwo_wrapper/crates/prover/benches/eval_at_point.rs new file mode 100644 index 0000000..64d1eec --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/eval_at_point.rs @@ -0,0 +1,35 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; +use stwo_prover::core::backend::cpu::CpuBackend; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::circle::CirclePoint; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::poly::circle::{CirclePoly, PolyOps}; + +const LOG_SIZE: u32 = 20; + +fn bench_eval_at_secure_point(c: &mut Criterion, id: &str) { + let poly = CirclePoly::new((0..1 << LOG_SIZE).map(BaseField::from).collect()); + let mut rng = SmallRng::seed_from_u64(0); + let x = rng.gen(); + let y = rng.gen(); + let point = CirclePoint { x, y }; + c.bench_function( + &format!("{id} eval_at_secure_field_point 2^{LOG_SIZE}"), + |b| { + b.iter(|| B::eval_at_point(black_box(&poly), black_box(point))); + }, + ); +} + +fn eval_at_secure_point_benches(c: &mut Criterion) { + 
bench_eval_at_secure_point::(c, "simd"); + bench_eval_at_secure_point::(c, "cpu"); +} + +criterion_group!( + name = benches; + config = Criterion::default().sample_size(10); + targets = eval_at_secure_point_benches); +criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/fft.rs b/Stwo_wrapper/crates/prover/benches/fft.rs new file mode 100644 index 0000000..35841d7 --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/fft.rs @@ -0,0 +1,131 @@ +#![feature(iter_array_chunks)] + +use std::hint::black_box; +use std::mem::{size_of_val, transmute}; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, Throughput}; +use itertools::Itertools; +use stwo_prover::core::backend::simd::column::BaseColumn; +use stwo_prover::core::backend::simd::fft::ifft::{ + get_itwiddle_dbls, ifft, ifft3_loop, ifft_vecwise_loop, +}; +use stwo_prover::core::backend::simd::fft::rfft::{fft, get_twiddle_dbls}; +use stwo_prover::core::backend::simd::fft::transpose_vecs; +use stwo_prover::core::backend::simd::m31::PackedBaseField; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::poly::circle::CanonicCoset; + +pub fn simd_ifft(c: &mut Criterion) { + let mut group = c.benchmark_group("iffts"); + + for log_size in 16..=28 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); + let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(); + let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect(); + group.throughput(Throughput::Bytes(size_of_val(&*values.data) as u64)); + group.bench_function(BenchmarkId::new("simd ifft", log_size), |b| { + b.iter_batched( + || values.clone().data, + |mut data| unsafe { + ifft( + transmute(data.as_mut_ptr()), + black_box(&twiddle_dbls_refs), + black_box(log_size as usize), + ); + }, + BatchSize::LargeInput, + ) + }); + } +} + +pub fn simd_ifft_parts(c: &mut Criterion) { + const LOG_SIZE: u32 = 
14; + + let domain = CanonicCoset::new(LOG_SIZE).circle_domain(); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); + let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(); + let values: BaseColumn = (0..domain.size()).map(BaseField::from).collect(); + + let mut group = c.benchmark_group("ifft parts"); + + // Note: These benchmarks run only on 2^LOG_SIZE elements because of their parameters. + // Increasing the figure above won't change the runtime of these benchmarks. + group.throughput(Throughput::Bytes(4 << LOG_SIZE)); + group.bench_function(format!("simd ifft_vecwise_loop 2^{LOG_SIZE}"), |b| { + b.iter_batched( + || values.clone().data, + |mut values| unsafe { + ifft_vecwise_loop( + transmute(values.as_mut_ptr()), + black_box(&twiddle_dbls_refs), + black_box(9), + black_box(0), + ) + }, + BatchSize::LargeInput, + ); + }); + group.bench_function(format!("simd ifft3_loop 2^{LOG_SIZE}"), |b| { + b.iter_batched( + || values.clone().data, + |mut values| unsafe { + ifft3_loop( + transmute(values.as_mut_ptr()), + black_box(&twiddle_dbls_refs[3..]), + black_box(7), + black_box(4), + black_box(0), + ) + }, + BatchSize::LargeInput, + ); + }); + + const TRANSPOSE_LOG_SIZE: u32 = 20; + let transpose_values: BaseColumn = (0..1 << TRANSPOSE_LOG_SIZE).map(BaseField::from).collect(); + group.throughput(Throughput::Bytes(4 << TRANSPOSE_LOG_SIZE)); + group.bench_function(format!("simd transpose_vecs 2^{TRANSPOSE_LOG_SIZE}"), |b| { + b.iter_batched( + || transpose_values.clone().data, + |mut values| unsafe { + transpose_vecs( + transmute(values.as_mut_ptr()), + black_box(TRANSPOSE_LOG_SIZE as usize - 4), + ) + }, + BatchSize::LargeInput, + ); + }); +} + +pub fn simd_rfft(c: &mut Criterion) { + const LOG_SIZE: u32 = 20; + + let domain = CanonicCoset::new(LOG_SIZE).circle_domain(); + let twiddle_dbls = get_twiddle_dbls(domain.half_coset); + let twiddle_dbls_refs = twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(); + let values: BaseColumn = 
(0..domain.size()).map(BaseField::from).collect(); + + c.bench_function("simd rfft 20bit", |b| { + b.iter_with_large_drop(|| unsafe { + let mut target = Vec::::with_capacity(values.data.len()); + #[allow(clippy::uninit_vec)] + target.set_len(values.data.len()); + + fft( + black_box(transmute(values.data.as_ptr())), + transmute(target.as_mut_ptr()), + black_box(&twiddle_dbls_refs), + black_box(LOG_SIZE as usize), + ) + }); + }); +} + +criterion_group!( + name = benches; + config = Criterion::default().sample_size(10); + targets = simd_ifft, simd_ifft_parts, simd_rfft); +criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/field.rs b/Stwo_wrapper/crates/prover/benches/field.rs new file mode 100644 index 0000000..acb318c --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/field.rs @@ -0,0 +1,150 @@ +use criterion::{criterion_group, criterion_main, Criterion}; +use num_traits::One; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; +use stwo_prover::core::backend::simd::m31::{PackedBaseField, N_LANES}; +use stwo_prover::core::fields::cm31::CM31; +use stwo_prover::core::fields::m31::{BaseField, M31}; +use stwo_prover::core::fields::qm31::SecureField; + +pub const N_ELEMENTS: usize = 1 << 16; +pub const N_STATE_ELEMENTS: usize = 8; + +pub fn m31_operations_bench(c: &mut Criterion) { + let mut rng = SmallRng::seed_from_u64(0); + let elements: Vec = (0..N_ELEMENTS).map(|_| rng.gen()).collect(); + let mut state: [M31; N_STATE_ELEMENTS] = rng.gen(); + + c.bench_function("M31 mul", |b| { + b.iter(|| { + for elem in &elements { + for _ in 0..128 { + for state_elem in &mut state { + *state_elem *= *elem; + } + } + } + }) + }); + + c.bench_function("M31 add", |b| { + b.iter(|| { + for elem in &elements { + for _ in 0..128 { + for state_elem in &mut state { + *state_elem += *elem; + } + } + } + }) + }); +} + +pub fn cm31_operations_bench(c: &mut Criterion) { + let mut rng = SmallRng::seed_from_u64(0); + let elements: Vec = (0..N_ELEMENTS).map(|_| 
rng.gen()).collect(); + let mut state: [CM31; N_STATE_ELEMENTS] = rng.gen(); + + c.bench_function("CM31 mul", |b| { + b.iter(|| { + for elem in &elements { + for _ in 0..128 { + for state_elem in &mut state { + *state_elem *= *elem; + } + } + } + }) + }); + + c.bench_function("CM31 add", |b| { + b.iter(|| { + for elem in &elements { + for _ in 0..128 { + for state_elem in &mut state { + *state_elem += *elem; + } + } + } + }) + }); +} + +pub fn qm31_operations_bench(c: &mut Criterion) { + let mut rng = SmallRng::seed_from_u64(0); + let elements: Vec = (0..N_ELEMENTS).map(|_| rng.gen()).collect(); + let mut state: [SecureField; N_STATE_ELEMENTS] = rng.gen(); + + c.bench_function("SecureField mul", |b| { + b.iter(|| { + for elem in &elements { + for _ in 0..128 { + for state_elem in &mut state { + *state_elem *= *elem; + } + } + } + }) + }); + + c.bench_function("SecureField add", |b| { + b.iter(|| { + for elem in &elements { + for _ in 0..128 { + for state_elem in &mut state { + *state_elem += *elem; + } + } + } + }) + }); +} + +pub fn simd_m31_operations_bench(c: &mut Criterion) { + let mut rng = SmallRng::seed_from_u64(0); + let elements: Vec = (0..N_ELEMENTS / N_LANES).map(|_| rng.gen()).collect(); + let mut states = vec![PackedBaseField::broadcast(BaseField::one()); N_STATE_ELEMENTS]; + + c.bench_function("mul_simd", |b| { + b.iter(|| { + for elem in elements.iter() { + for _ in 0..128 { + for state in states.iter_mut() { + *state *= *elem; + } + } + } + }) + }); + + c.bench_function("add_simd", |b| { + b.iter(|| { + for elem in elements.iter() { + for _ in 0..128 { + for state in states.iter_mut() { + *state += *elem; + } + } + } + }) + }); + + c.bench_function("sub_simd", |b| { + b.iter(|| { + for elem in elements.iter() { + for _ in 0..128 { + for state in states.iter_mut() { + *state -= *elem; + } + } + } + }) + }); +} + +criterion_group!( + name = benches; + config = Criterion::default().sample_size(10); + targets = m31_operations_bench, 
cm31_operations_bench, qm31_operations_bench, + simd_m31_operations_bench); +criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/fri.rs b/Stwo_wrapper/crates/prover/benches/fri.rs new file mode 100644 index 0000000..1c38a0e --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/fri.rs @@ -0,0 +1,35 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use stwo_prover::core::backend::CpuBackend; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::fields::secure_column::SecureColumnByCoords; +use stwo_prover::core::fri::FriOps; +use stwo_prover::core::poly::circle::{CanonicCoset, PolyOps}; +use stwo_prover::core::poly::line::{LineDomain, LineEvaluation}; + +fn folding_benchmark(c: &mut Criterion) { + const LOG_SIZE: u32 = 12; + let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset()); + let evals = LineEvaluation::new( + domain, + SecureColumnByCoords { + columns: std::array::from_fn(|i| { + vec![BaseField::from_u32_unchecked(i as u32); 1 << LOG_SIZE] + }), + }, + ); + let alpha = SecureField::from_u32_unchecked(2213980, 2213981, 2213982, 2213983); + let twiddles = CpuBackend::precompute_twiddles(domain.coset()); + c.bench_function("fold_line", |b| { + b.iter(|| { + black_box(CpuBackend::fold_line( + black_box(&evals), + black_box(alpha), + &twiddles, + )); + }) + }); +} + +criterion_group!(benches, folding_benchmark); +criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/lookups.rs b/Stwo_wrapper/crates/prover/benches/lookups.rs new file mode 100644 index 0000000..ac45a95 --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/lookups.rs @@ -0,0 +1,104 @@ +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use rand::distributions::{Distribution, Standard}; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; +use stwo_prover::core::backend::simd::SimdBackend; +use 
stwo_prover::core::backend::CpuBackend; +use stwo_prover::core::channel::Blake2sChannel; +use stwo_prover::core::fields::Field; +use stwo_prover::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; +use stwo_prover::core::lookups::mle::{Mle, MleOps}; + +const LOG_N_ROWS: u32 = 16; + +fn bench_gkr_grand_product(c: &mut Criterion, id: &str) { + let mut rng = SmallRng::seed_from_u64(0); + let layer = Layer::::GrandProduct(gen_random_mle(&mut rng, LOG_N_ROWS)); + c.bench_function(&format!("{id} grand product lookup 2^{LOG_N_ROWS}"), |b| { + b.iter_batched( + || layer.clone(), + |layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]), + BatchSize::LargeInput, + ) + }); + c.bench_function( + &format!("{id} grand product lookup batch 4x 2^{LOG_N_ROWS}"), + |b| { + b.iter_batched( + || vec![layer.clone(), layer.clone(), layer.clone(), layer.clone()], + |layers| prove_batch(&mut Blake2sChannel::default(), layers), + BatchSize::LargeInput, + ) + }, + ); +} + +fn bench_gkr_logup_generic(c: &mut Criterion, id: &str) { + let mut rng = SmallRng::seed_from_u64(0); + let generic_layer = Layer::::LogUpGeneric { + numerators: gen_random_mle(&mut rng, LOG_N_ROWS), + denominators: gen_random_mle(&mut rng, LOG_N_ROWS), + }; + c.bench_function(&format!("{id} generic logup lookup 2^{LOG_N_ROWS}"), |b| { + b.iter_batched( + || generic_layer.clone(), + |layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]), + BatchSize::LargeInput, + ) + }); +} + +fn bench_gkr_logup_multiplicities(c: &mut Criterion, id: &str) { + let mut rng = SmallRng::seed_from_u64(0); + let multiplicities_layer = Layer::::LogUpMultiplicities { + numerators: gen_random_mle(&mut rng, LOG_N_ROWS), + denominators: gen_random_mle(&mut rng, LOG_N_ROWS), + }; + c.bench_function( + &format!("{id} multiplicities logup lookup 2^{LOG_N_ROWS}"), + |b| { + b.iter_batched( + || multiplicities_layer.clone(), + |layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]), + BatchSize::LargeInput, + ) + }, + 
); +} + +fn bench_gkr_logup_singles(c: &mut Criterion, id: &str) { + let mut rng = SmallRng::seed_from_u64(0); + let singles_layer = Layer::::LogUpSingles { + denominators: gen_random_mle(&mut rng, LOG_N_ROWS), + }; + c.bench_function(&format!("{id} singles logup lookup 2^{LOG_N_ROWS}"), |b| { + b.iter_batched( + || singles_layer.clone(), + |layer| prove_batch(&mut Blake2sChannel::default(), vec![layer]), + BatchSize::LargeInput, + ) + }); +} + +/// Generates a random multilinear polynomial. +fn gen_random_mle, F: Field>(rng: &mut impl Rng, n_variables: u32) -> Mle +where + Standard: Distribution, +{ + Mle::new((0..1 << n_variables).map(|_| rng.gen()).collect()) +} + +fn gkr_lookup_benches(c: &mut Criterion) { + bench_gkr_grand_product::(c, "simd"); + bench_gkr_logup_generic::(c, "simd"); + bench_gkr_logup_multiplicities::(c, "simd"); + bench_gkr_logup_singles::(c, "simd"); + + bench_gkr_grand_product::(c, "cpu"); + bench_gkr_logup_generic::(c, "cpu"); + bench_gkr_logup_multiplicities::(c, "cpu"); + bench_gkr_logup_singles::(c, "cpu"); +} + +criterion_group!(benches, gkr_lookup_benches); +criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/matrix.rs b/Stwo_wrapper/crates/prover/benches/matrix.rs new file mode 100644 index 0000000..8e44a98 --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/matrix.rs @@ -0,0 +1,63 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; +use stwo_prover::core::fields::m31::{M31, P}; +use stwo_prover::core::fields::qm31::QM31; +use stwo_prover::math::matrix::{RowMajorMatrix, SquareMatrix}; + +const MATRIX_SIZE: usize = 24; +const QM31_MATRIX_SIZE: usize = 6; + +// TODO(ShaharS): Share code with other benchmarks. 
+fn row_major_matrix_multiplication_bench(c: &mut Criterion) { + let mut rng = SmallRng::seed_from_u64(0); + + let matrix_m31 = RowMajorMatrix::::new( + (0..MATRIX_SIZE.pow(2)) + .map(|_| rng.gen()) + .collect::>(), + ); + + let matrix_qm31 = RowMajorMatrix::::new( + (0..QM31_MATRIX_SIZE.pow(2)) + .map(|_| rng.gen()) + .collect::>(), + ); + + // Create vector M31. + let vec: [M31; MATRIX_SIZE] = rng.gen(); + + // Create vector QM31. + let vec_qm31: [QM31; QM31_MATRIX_SIZE] = [(); QM31_MATRIX_SIZE].map(|_| { + QM31::from_u32_unchecked( + rng.gen::() % P, + rng.gen::() % P, + rng.gen::() % P, + rng.gen::() % P, + ) + }); + + // bench matrix multiplication. + c.bench_function( + &format!("RowMajorMatrix M31 {size}x{size} mul", size = MATRIX_SIZE), + |b| { + b.iter(|| { + black_box(matrix_m31.mul(vec)); + }) + }, + ); + c.bench_function( + &format!( + "QM31 RowMajorMatrix {size}x{size} mul", + size = QM31_MATRIX_SIZE + ), + |b| { + b.iter(|| { + black_box(matrix_qm31.mul(vec_qm31)); + }) + }, + ); +} + +criterion_group!(benches, row_major_matrix_multiplication_bench); +criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/merkle.rs b/Stwo_wrapper/crates/prover/benches/merkle.rs new file mode 100644 index 0000000..c039be7 --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/merkle.rs @@ -0,0 +1,38 @@ +#![feature(iter_array_chunks)] + +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use itertools::Itertools; +use num_traits::Zero; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::backend::{Col, CpuBackend}; +use stwo_prover::core::fields::m31::{BaseField, N_BYTES_FELT}; +use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleHasher; +use stwo_prover::core::vcs::ops::MerkleOps; + +const LOG_N_ROWS: u32 = 16; + +const LOG_N_COLS: u32 = 8; + +fn bench_blake2s_merkle>(c: &mut Criterion, id: &str) { + let col: Col = (0..1 << LOG_N_ROWS).map(|_| BaseField::zero()).collect(); + let cols = (0..1 << 
LOG_N_COLS).map(|_| col.clone()).collect_vec(); + let col_refs = cols.iter().collect_vec(); + let mut group = c.benchmark_group("merkle throughput"); + let n_elements = 1 << (LOG_N_COLS + LOG_N_ROWS); + group.throughput(Throughput::Elements(n_elements)); + group.throughput(Throughput::Bytes(N_BYTES_FELT as u64 * n_elements)); + group.bench_function(&format!("{id} merkle"), |b| { + b.iter_with_large_drop(|| B::commit_on_layer(LOG_N_ROWS, None, &col_refs)) + }); +} + +fn blake2s_merkle_benches(c: &mut Criterion) { + bench_blake2s_merkle::(c, "simd"); + bench_blake2s_merkle::(c, "cpu"); +} + +criterion_group!( + name = benches; + config = Criterion::default().sample_size(10); + targets = blake2s_merkle_benches); +criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/pcs.rs b/Stwo_wrapper/crates/prover/benches/pcs.rs new file mode 100644 index 0000000..da185d7 --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/pcs.rs @@ -0,0 +1,81 @@ +use std::iter; + +use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::backend::{BackendForChannel, CpuBackend}; +use stwo_prover::core::channel::Blake2sChannel; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::pcs::CommitmentTreeProver; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::twiddles::TwiddleTree; +use stwo_prover::core::poly::BitReversedOrder; +use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; + +const LOG_COSET_SIZE: u32 = 20; +const LOG_BLOWUP_FACTOR: u32 = 1; +const N_POLYS: usize = 16; + +fn benched_fn>( + evals: Vec>, + channel: &mut Blake2sChannel, + twiddles: &TwiddleTree, +) { + let polys = evals + .into_iter() + .map(|eval| eval.interpolate_with_twiddles(twiddles)) + .collect(); + + CommitmentTreeProver::::new( + polys, + LOG_BLOWUP_FACTOR, + 
channel, + twiddles, + ); +} + +fn bench_pcs>(c: &mut Criterion, id: &str) { + let small_domain = CanonicCoset::new(LOG_COSET_SIZE); + let big_domain = CanonicCoset::new(LOG_COSET_SIZE + LOG_BLOWUP_FACTOR); + let twiddles = B::precompute_twiddles(big_domain.half_coset()); + let mut channel = Blake2sChannel::default(); + let mut rng = SmallRng::seed_from_u64(0); + + let evals: Vec> = iter::repeat_with(|| { + CircleEvaluation::new( + small_domain.circle_domain(), + (0..1 << LOG_COSET_SIZE).map(|_| rng.gen()).collect(), + ) + }) + .take(N_POLYS) + .collect(); + + c.bench_function( + &format!("{id} polynomial commitment 2^{LOG_COSET_SIZE}"), + |b| { + b.iter_batched( + || evals.clone(), + |evals| { + benched_fn::( + black_box(evals), + black_box(&mut channel), + black_box(&twiddles), + ) + }, + BatchSize::LargeInput, + ); + }, + ); +} + +fn pcs_benches(c: &mut Criterion) { + bench_pcs::(c, "simd"); + bench_pcs::(c, "cpu"); +} + +criterion_group!( + name = benches; + config = Criterion::default().sample_size(10); + targets = pcs_benches); +criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/poseidon.rs b/Stwo_wrapper/crates/prover/benches/poseidon.rs new file mode 100644 index 0000000..bc796c6 --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/poseidon.rs @@ -0,0 +1,18 @@ +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; +use stwo_prover::core::pcs::PcsConfig; +use stwo_prover::examples::poseidon::prove_poseidon; + +pub fn simd_poseidon(c: &mut Criterion) { + const LOG_N_INSTANCES: u32 = 18; + let mut group = c.benchmark_group("poseidon2"); + group.throughput(Throughput::Elements(1u64 << LOG_N_INSTANCES)); + group.bench_function(format!("poseidon2 2^{} instances", LOG_N_INSTANCES), |b| { + b.iter(|| prove_poseidon(LOG_N_INSTANCES, PcsConfig::default())); + }); +} + +criterion_group!( + name = bit_rev; + config = Criterion::default().sample_size(10); + targets = simd_poseidon); +criterion_main!(bit_rev); diff --git 
a/Stwo_wrapper/crates/prover/benches/prefix_sum.rs b/Stwo_wrapper/crates/prover/benches/prefix_sum.rs new file mode 100644 index 0000000..7faf4ac --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/prefix_sum.rs @@ -0,0 +1,19 @@ +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use stwo_prover::core::backend::simd::column::BaseColumn; +use stwo_prover::core::backend::simd::prefix_sum::inclusive_prefix_sum; +use stwo_prover::core::fields::m31::BaseField; + +pub fn simd_prefix_sum_bench(c: &mut Criterion) { + const LOG_SIZE: u32 = 24; + let evals: BaseColumn = (0..1 << LOG_SIZE).map(BaseField::from).collect(); + c.bench_function(&format!("simd prefix_sum 2^{LOG_SIZE}"), |b| { + b.iter_batched( + || evals.clone(), + inclusive_prefix_sum, + BatchSize::LargeInput, + ); + }); +} + +criterion_group!(benches, simd_prefix_sum_bench); +criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/benches/quotients.rs b/Stwo_wrapper/crates/prover/benches/quotients.rs new file mode 100644 index 0000000..fc2949a --- /dev/null +++ b/Stwo_wrapper/crates/prover/benches/quotients.rs @@ -0,0 +1,55 @@ +#![feature(iter_array_chunks)] + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use itertools::Itertools; +use stwo_prover::core::backend::cpu::CpuBackend; +use stwo_prover::core::backend::simd::SimdBackend; +use stwo_prover::core::circle::SECURE_FIELD_CIRCLE_GEN; +use stwo_prover::core::fields::m31::BaseField; +use stwo_prover::core::fields::qm31::SecureField; +use stwo_prover::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; +use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use stwo_prover::core::poly::BitReversedOrder; + +// TODO(andrew): Consider removing const generics and making all sizes the same. 
+fn bench_quotients( + c: &mut Criterion, + id: &str, +) { + let domain = CanonicCoset::new(LOG_N_ROWS).circle_domain(); + let values = (0..domain.size()).map(BaseField::from).collect(); + let col = CircleEvaluation::::new(domain, values); + let cols = (0..1 << LOG_N_COLS).map(|_| col.clone()).collect_vec(); + let col_refs = cols.iter().collect_vec(); + let random_coeff = SecureField::from_u32_unchecked(0, 1, 2, 3); + let a = SecureField::from_u32_unchecked(5, 6, 7, 8); + let samples = vec![ColumnSampleBatch { + point: SECURE_FIELD_CIRCLE_GEN, + columns_and_values: (0..1 << LOG_N_COLS).map(|i| (i, a)).collect(), + }]; + c.bench_function( + &format!("{id} quotients 2^{LOG_N_COLS} x 2^{LOG_N_ROWS}"), + |b| { + b.iter_with_large_drop(|| { + B::accumulate_quotients( + black_box(domain), + black_box(&col_refs), + black_box(random_coeff), + black_box(&samples), + 1, + ) + }) + }, + ); +} + +fn quotients_benches(c: &mut Criterion) { + bench_quotients::(c, "simd"); + bench_quotients::(c, "cpu"); +} + +criterion_group!( + name = benches; + config = Criterion::default().sample_size(10); + targets = quotients_benches); +criterion_main!(benches); diff --git a/Stwo_wrapper/crates/prover/proof.json b/Stwo_wrapper/crates/prover/proof.json new file mode 100644 index 0000000..aecce7d --- /dev/null +++ b/Stwo_wrapper/crates/prover/proof.json @@ -0,0 +1,348 @@ +{ + "commitments" : + ["34328580272026076035687604093297365442785733592720865218001799813393342152908", + "38388381845372648579572899115609862601821983406101214230086519922780265042634"], + + "sampled_values_0" : + [["1","0","0","0"], + ["2129160320","1109509513","787887008","1676461964"], + ["262908602","915488457","1893945291","1774327476"], + ["894719153","1570509766","1424186619","204092576"], + ["397490811","836398274","1615765624","2013800563"], + ["1022303904","276983775","1064742229","165204856"], + ["1200363525","170838026","524999776","156116441"], + ["850733526","448725560","1521962209","1318190714"], + 
["1187866075","1705588092","924088348","490002418"], + ["2033565088","996780784","1820235518","2048788344"], + ["2061590372","1150986157","711772586","1511398564"], + ["1066623954","530384603","1890251380","1699008129"], + ["734047580","1685768538","505142109","787113212"], + ["2030904700","99932423","695391286","1736941035"], + ["1580330105","932031717","1705998668","146411959"], + ["1585732224","1556242253","941668238","1998570239"], + ["199481433","2123320403","1257464748","1663811899"], + ["2139019524","1547107722","728449250","1941851166"], + ["752079023","268472135","1465850435","16510773"], + ["1279312817","63252415","442230579","1560954631"], + ["1074859131","137997593","2118329011","652535723"], + ["297567647","1483381078","1941495981","599737348"], + ["1735543786","1420676479","1354982762","1114211268"], + ["1691705401","1143446295","1748115479","1666756627"], + ["955696743","2077778309","736065989","1319443838"], + ["1076874307","1001483910","1702287354","819727011"], + ["1134989244","1823710400","2067694105","1098263343"], + ["1793642608","961404475","1279773056","1815400043"], + ["739677274","1827877577","838562378","171296720"], + ["2036367121","1901888610","289723252","2014426907"], + ["330020507","436937516","2113056521","1828501207"], + ["1359068814","583899921","734628376","1223217137"], + ["1319501520","1242972089","1202216521","1285024997"], + ["681182370","1569622309","1574376904","1563950435"], + ["1204519566","483612224","1677731115","1667757584"], + ["330284364","917877098","57538161","179869993"], + ["2056561198","119768893","740294154","1454562198"], + ["79009084","545196641","13388962","1973400144"], + ["885977898","1973300145","37115619","957100699"], + ["1937449867","1777683674","1983002799","757662558"], + ["344927561","357845689","26887161","664585634"], + ["1462268220","615463524","209500386","44308852"], + ["570984705","2022111132","1404632615","2119081660"], + ["13183327","1584451280","1216116653","316345540"], + 
["1497965915","705236857","1892466476","1068567492"], + ["1758694676","1408790161","1140545981","315723937"], + ["645308461","1125824784","1786470558","1240927727"], + ["1213464061","470930291","1718629724","1149088875"], + ["214577693","1578610321","2133720991","226291629"], + ["1357706729","2097875841","1767996253","1478111500"], + ["1154658683","752162439","2018723944","163997560"], + ["1051993583","703716977","379706674","487262860"], + ["1017692573","2060296775","2001023083","1064213951"], + ["1042587725","1701108370","204550428","904590130"], + ["1115340870","743420370","1927225111","1276396551"], + ["493638626","1874789377","47342513","209203758"], + ["1558586505","83459476","247638703","1975504267"], + ["2097068784","954319448","367516919","1545761518"], + ["1655645294","352838520","1307263981","1110198118"], + ["1169856046","1925368371","1362317240","1926032147"], + ["1940113709","885624001","1395047654","80053995"], + ["1778932990","25092730","201117282","1724571908"], + ["2096327738","233411984","1247443120","713989449"], + ["808532602","136577890","1015579288","38900716"], + ["1182257782","1186245376","1451332036","2080170103"], + ["1662610758","1505542080","1038243031","1889715771"], + ["440146119","942837214","1440484295","1593949278"], + ["46258268","1884246120","164930024","2050584510"], + ["1198954868","1079638495","1424072583","1028611344"], + ["2112984649","1531382496","1873151714","1818301795"], + ["1554382282","253920307","1641628530","1378998084"], + ["857898234","686236793","2091871553","184978860"], + ["2049153599","6111471","1579475775","32492894"], + ["1371356596","679072793","1547377985","354305233"], + ["1799882226","1201472049","1592617716","125534957"], + ["1277144880","253726080","1800145982","1125162267"], + ["1577717920","440984421","1377891036","846453148"], + ["1952731919","1710992214","673668053","1871913638"], + ["1559011028","2060945859","719954448","1356468891"], + ["1961642242","1693473944","1300152522","412222111"], + 
["861208187","1242659514","977183954","38730935"], + ["1016984917","1368361439","2106430139","1225979890"], + ["1427754325","1482206106","1465316380","1096279813"], + ["566051043","2025874544","234976335","1482256978"], + ["1750543495","1494541462","374330732","411642241"], + ["230654343","55625728","136463431","1099606808"], + ["1172218793","1260458608","1314942990","75527287"], + ["1824515276","916178746","1300275105","370626746"], + ["915931367","987018043","56193044","617907884"], + ["1934695822","1112844637","609268252","1972086910"], + ["619631651","152029630","1979976905","292597437"], + ["62258350","1890115432","1373605674","1505619938"], + ["1770422019","1398189304","1773172351","1576001433"], + ["650940868","1756047014","1764798953","1146887875"], + ["1746945043","528205234","778346028","1797468521"], + ["760802416","1479409742","1556974632","1307498378"], + ["102511022","1787975482","968854748","1010240763"], + ["330722054","2046294448","14132125","1822414050"], + ["943548871","1770900623","1861740461","1290634078"], + ["1402661415","1361511065","1784889120","837615360"]], + + "sampled_values_1" : + [["712066144","1576368753","626134398","426337436"], + ["160634493","1096735733","992622982","964509862"], + ["208900621","1128739590","1423579079","1688318061"], + ["1029182234","1152361165","571476481","1593867154"]], + + "decommitment_0" : + ["24311567319749512546399129581715033328970605051392227451685196018312506896509", + "8450134967305372517473027560161707471995673370792264153422077885080332622841", + "6431507699794114682519586182713221908058047520896405293833270087934517909753", + "303109001984349840640377328716025252051982378448629744935456455431709129012", + "47167328465744900593371601186109726758160197572292632388959155138584359581158", + "50584492046778480438774038937088410409133167768957478525289857065775850658491", + "30584798499699103841624545814425958941934653399588880797257122471101102880636", + 
"33441256878213890325682161124370878299436204406591246133637659120215439522803", + "3288124068330032280185519028600654292250668929588668389702892483946668251740", + "29852774919556057664485671676242264613416836486089146650713214180894511265116", + "12482060975231949385592255321766253365687502822944549845564491620341379321204", + "46285234163162336949700608657672147469543559995399282843606812790099228411758", + "20807128972645591294726020136444795908525656782422245307591812614900798799914"], + + "decommitment_1" : + ["18063303111481257844109225560025890393366258018933166919604543575686388632162", + "1676364734386980395984608216327451243278421019544108756198322792517099196249", + "5278661052518480850653886996628582549184134231869598116690316714367933376948", + "21983822689977371558234298346357617674436224016274009820764238516520240403273", + "605332543427376153930374757063581881998320956602375739165671986207155079359", + "33771702906041565783498389127165108212044382608172583325407071671862086994048", + "6930451780154275491146135719028766497977496109537963233244808739657647563071", + "4564117668410212714684125903928600765456322272915099403587425756488534507713", + "337808767671877648828499299861821973796749820854854708379479049898835100991", + "2725840354457305623692571800192492803162041315546256970381708693201407812833", + "34716495111790106826563330917176360656701717867702196654990744866499300990003", + "49719870445464463616785535809529171382800153139923763422202182182379572350737", + "48664746464641275030915461677298150155193593333108431337941029583245720868695", + "32557886668297237033601675259512842580727821006475208499640136322794706303894", + "15014835402414421167586357788116276188694467622586221351644991310645286648480", + "12973705814659120511327850727547995427054971555827754777395732787633567627149", + "15379912850866472398958956306527914195058439699787840041152620034933267404138", + 
"19859070819439084101412868355121176941090844577507824922960697697668791429525", + "38273559034692361632775489704953448699371080776239846995670381153186834620044"], + + "queried_values_0" : + [["1","1","1","1","1","1"], + ["730457281","730490049","28918683","28885915","1656126010","1656093242"], + ["855614122","1238037465","1836504291","355791428","757095818","467806903"], + ["674179888","1530445315","1720543014","76190330","1475912409","1017215862"], + ["1142290008","1148671853","1619097781","938511401","904357795","257652679"], + ["1679234056","1355264641","2139729457","574756654","604307234","1146556949"], + ["1000500309","2008905806","1442759180","598876729","1786070690","1072293976"], + ["1119085545","2133345582","135683580","216214405","1049766224","943727969"], + ["206423262","2047139937","305085364","1422472664","1826554088","1032095092"], + ["501238882","1656305868","724710382","1949772461","1426787917","585368894"], + ["1005468045","1775577441","1042182076","415631363","1067013227","1635705270"], + ["1776076392","216798814","1525036520","1160666510","1212132211","1915058776"], + ["859923105","1633989410","182110635","2060185314","1084464822","1129902257"], + ["489437802","313401022","271315129","357612175","2050381179","647577687"], + ["1495302158","2052264981","1498165299","1164417520","1050104037","450244199"], + ["1084986392","398966983","808449145","1733554138","2068501028","659474347"], + ["399458768","1789245133","1698759035","188433436","1794535430","364419824"], + ["2013965647","722839714","928854328","124488895","1378959529","952886009"], + ["1334765706","193402268","471076108","640800921","1998121783","961582406"], + ["1067762968","381831281","560459357","1025929344","181659877","1922040224"], + ["1993303462","467991218","849673597","744722836","239634354","329631295"], + ["785794488","1649178388","672964420","1281255462","900602801","271501809"], + ["857859728","1325395820","985014020","1094321795","259553347","774587048"], + 
["1214640090","1588569866","871717820","1131833706","1625896842","1635087550"], + ["796549205","931495223","2018253108","1395065060","158209751","1160478135"], + ["883143962","729115354","190207821","839273168","1668931939","2074584689"], + ["1490296658","1846956206","1610364850","56422972","160482417","681872093"], + ["1270585092","1910190167","464113273","613529242","1027101122","1014185686"], + ["1456043179","1999662961","193940913","678382864","39040067","1236859818"], + ["1626243617","901735777","1703169024","911300891","1640727682","1121874896"], + ["492192896","15672698","319327174","1727120334","1965889437","114404366"], + ["407079019","949462637","255390508","1753162095","501134776","1457122467"], + ["1478573872","1439193434","1053200675","1001140887","1553935777","1253681552"], + ["183135520","946237525","1802924023","1831496784","1893117930","1830486286"], + ["234902670","1169030504","196055115","1323151968","855748623","1328842866"], + ["1150999776","1338824346","2072101698","774206263","1967350016","1808817867"], + ["924341552","1430286424","511268814","825025920","1061850574","1954646566"], + ["302634890","314434153","1692670768","1822915313","1244352075","1953834230"], + ["1576167467","687837005","2116136752","144109400","1590157548","1634932462"], + ["396756275","1272134898","1207308240","818219166","1314182589","109494000"], + ["846425160","897737569","757312164","826009489","1019831588","1977463051"], + ["2065801114","1918982367","1548689186","2082631803","298112070","383438809"], + ["1034102289","461735180","2115581275","1343026598","1229979058","1021418523"], + ["1784173874","166635387","547550115","1094693960","573193735","451367040"], + ["119818313","659105018","1741377697","8940733","911200334","511474518"], + ["1511949880","1315119529","1267019200","2134944693","878254810","375758264"], + ["1203050254","156394547","1348568635","412863443","1068659960","1407913814"], + ["361779719","130417374","89109096","117994876","1151322919","863143484"], 
+ ["1007476533","989566160","138644964","1672742874","540141118","1296408100"], + ["1824241144","2051199719","1863718547","2109877864","36689613","1055926854"], + ["791693003","1433717239","991140958","1565955371","1839976870","1163947838"], + ["1267320759","2102593211","1831360854","1691591439","1672201908","61327345"], + ["301343164","277158258","627925439","577508975","1896464649","907629062"], + ["1964268932","929590164","1529686876","68630644","1663063136","254082844"], + ["693529348","1815295486","1660565870","1226857377","156312343","1500907098"], + ["1723158753","252348225","253985470","52424437","1605949937","576572581"], + ["1781792048","1497492716","1951824572","1156925855","863650708","156447987"], + ["876432605","458503399","283092867","247883110","1227074181","966219235"], + ["1581118191","66527915","1577039825","1227961402","1738412997","1862462297"], + ["679458448","338624032","34185999","253532412","65409631","563033132"], + ["1011967612","1898273226","2124401156","105282260","1188226330","123913515"], + ["1432513005","1083162825","1299704150","1184276814","339749370","1064298821"], + ["145070927","1250457746","1977306722","1035124433","322154361","1782232869"], + ["1934256728","1269667423","248999825","267009036","132507662","474021315"], + ["1694871453","535187899","981724703","1180312550","74370795","277702656"], + ["1850841912","1037634121","1967377497","1755127193","1566449422","1939039785"], + ["1315699598","476111459","2058733537","332263289","1592057567","874912616"], + ["774536537","170060616","2086574090","47894465","778021586","2115296942"], + ["322468558","24934377","637275739","1596346002","1896623296","1814433409"], + ["766517428","1263076038","358941187","2070217919","2108397185","1587546402"], + ["345404490","1065320570","1231275245","1037359122","1286389839","2070140848"], + ["746521574","835067673","311114030","1586400488","1406022058","1284151326"], + ["1857315969","431410759","825259098","1717904860","503708539","2097758215"], + 
["1879479734","1863555039","2108235515","1833922769","1562707156","49484002"], + ["1366768987","1050390036","1491845132","666041968","74368055","1254335623"], + ["188857287","1161878039","1771805176","1457666227","1157868840","486461459"], + ["261764705","1577846886","1332322961","10423372","640027252","1086814656"], + ["111907709","542625019","2021749229","2013008690","523703611","1328833940"], + ["1270684472","1675989474","394214608","538100201","1984625073","1560563159"], + ["666555709","852557426","651115051","1878827907","953346499","619017191"], + ["747972907","149382079","1393306586","1394823957","960994901","536632180"], + ["80535893","229602380","1817483938","1455260088","484432","1869486290"], + ["443556931","253108261","1609174393","1245931188","752691602","1668543792"], + ["745497042","854686466","1834097777","642389535","1284043061","896553209"], + ["532777064","1491985134","200157005","1378855967","1159213374","1797221037"], + ["933463176","813761538","1124049829","1988347055","2115297439","1836576920"], + ["194436043","1437728625","43998833","786326005","1130428925","1424571033"], + ["2108272636","1410841489","753065553","2020187193","1644376367","670324352"], + ["1362448669","752702510","1740531646","47989265","588634","1940480814"], + ["439960422","528245604","179496898","235775013","59000527","1903150726"], + ["103605138","249162711","1971219628","1958189530","423278905","1318354885"], + ["321504059","1595801356","596911575","1361967073","459661104","599048233"], + ["1610552125","73166668","444776743","1820306524","1180674369","1570356756"], + ["1283703846","1024562975","958477092","1329464736","1758672211","899108631"], + ["713626137","904634570","1902566483","1938333063","447549083","703262660"], + ["1417291696","1717451368","354584524","832751684","930128006","1037860604"], + ["1618745108","79533863","301008038","2091942909","1221962725","1524945081"], + ["1490224031","349040760","393684137","484089443","1912848485","1790207999"], + 
["1930411520","642009628","1138820074","31855314","1177766391","1913457637"], + ["618628623","139430131","904498895","925273128","2111653256","1012250155"]], + + "queried_values_1" : [["316772341","1526280133","663010112","224983897","510598760","1109503351"], + ["754832207","435790299","883623752","553207508","154784232","199176676"], + ["689603315","1763523007","1720552945","1983603154","367841669","319325418"], + ["1290247052","1120744584","193500372","294491115","951360807","891034447"]], + + "proof of work" : "43", + + "coeffs" : ["329725079","667313404","2083859876","1645693780"], + + "inner_commitment_0" : "19048435553851756854966583714228494720706220237110109487981675465058006706934", + + "inner_decommitment_0" : + ["5462985033728555575703006689913665598917262836577853690370070265073174719979", + "1439463537898028163672031322449473184361454650351831891678420729829634422052", + "3457992889375232257419443221173459860069710025424105705342440558952480618733", + "1560236756242436900530359200127475252176167690738804367783218449678388748008", + "44704035195642947074853484428073358106137689624692539529545008114744235429869", + "1823722396825684657353300460130959150013626547749061669644787096209462088081", + "36173961213090661587166983586606381085726394354099487131775975739295244107044", + "44982414287882497869227873977178787463294164719732705041425254748922385395330", + "45309152665928987599895709328729600158686240484613667382186228529382344291065", + "6727671460449390719545303211873485738734289534911560838410979450236173250575", + "20261836592905556803606993653450529841562498576647326913889589677099044443740", + "5092741370842944151567575040255811031129982749889308756570945791769905531735", + "25669118913473215056777887930215946759071827607816131638248897691395809170466"], + + "inner_evals_subset_0" : + [["1779738283","1440487135","622229563","2027928845"], + ["3310770","2003547458","1663490902","2105455978"], + ["824310523","1757518542","231582441","427507918"]], 
+ + "inner_commitment_1" : "18169576490546341141767046171814645645044172101340402424909425947818564934087", + + "inner_decommitment_1" : + ["45836684762941279488847688946861562185184567552524644394596887832754938056979", + "19073399120361742411025793373596546929954434607182120440274945820526786807031", + "10149408627738268983458395386544516732232569888429637086606165543488301475287", + "1134561121197915913438467336074096605141057500441242329860673894339623319838", + "4117331370612733116088255919125911273783755080775248279590234839282346655209", + "10513576479117646185507147947418059201301955276798193127377481002238952180861", + "23342730846670478933566004300392843380760787769231395287565422468403932432226", + "18740447275326464369699774196183138103532792867824812360537906002797944316369", + "14414140432147963417180883803716587948502362440108737731008978168633980398219", + "12250577863311462987061070846596950168057013378171600204130942328863618210853"], + + "inner_evals_subset_1" : + [["1993244979","993480712","1202910330","51538179"], + ["1550507975","1444313548","117070947","1740854590"], + ["342709844","601149328","1436490544","1384381104"]], + + "inner_commitment_2" : "29215190960441505077529177935027039424488733477246157993605964254003539615792", + + "inner_decommitment_2" : + ["3578061893060038632121836895066391994380789049478144565795630968916166958170", + "37159117331259094338569510749131708143302385613991325356583099983565421599794", + "37757292341028790385725896863222501721662517262004109998718347744172059822092", + "32039058093629437909283983486827138804118789892152561698628703527746673098335", + "11826129489683399268430621966121477905203281789801973078750629518404118320344", + "12209931412475259378582945309327673629248630080862144675946167350377528046777", + "23652788355891929123107872712785512551735081591713325931570432712643436668697"], + + "inner_evals_subset_2" : + [["1184374976","203117688","803515935","1781630737"], + 
["716686867","138132852","2024080584","392488646"], + ["606958913","308986056","258114411","2075401741"]], + + "inner_commitment_3" : "33865913108527976063513611264849688092773721821534885464172308921150897546371", + + "inner_decommitment_3" : + ["46510393127320994984678433868779353751380750916123221099220042551249644407575", + "39205406309151711272632862117612882165913578213907083849769292791373274291092", + "1794620160371318211300608665903463515892826349097148444913826356018179256828", + "30764929521910551485060681298929856016753563990361023701791813032517030674324"], + + "inner_evals_subset_3" : + [["1682558787","129089003","784689440","491206249"], + ["2038210709","1600238918","655676259","1542271403"], + ["1014579052","1384080403","862591487","1941843578"]], + + "inner_commitment_4" : "233063347401903348432987619086774265933828676513621676247992038308030708388", + + "inner_decommitment_4" : + ["19324767949149751902195880760061491860991545124249052461044586503970003688610"], + + "inner_evals_subset_4" : + [["683258805","1002722262","1583421272","1748673499"], + ["2101847208","689925082","1602280602","1942656531"], + ["1952987285","1995490213","2082219584","1620868519"]], + + "inner_commitment_5" : "693572572477915449222328871061568563307509948116711167874020917953971074657", + + "inner_decommitment_5" : + [], + + "inner_evals_subset_5" : + [["1862940297","2014478284","1043383827","560191545"]] +} \ No newline at end of file diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/assert.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/assert.rs new file mode 100644 index 0000000..9e5530b --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/constraint_framework/assert.rs @@ -0,0 +1,84 @@ +use num_traits::{One, Zero}; + +use super::EvalAtRow; +use crate::core::backend::{Backend, Column}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use 
crate::core::pcs::TreeVec; +use crate::core::poly::circle::{CanonicCoset, CirclePoly}; +use crate::core::utils::circle_domain_order_to_coset_order; + +/// Evaluates expressions at a trace domain row, and asserts constraints. Mainly used for testing. +pub struct AssertEvaluator<'a> { + pub trace: &'a TreeVec>>, + pub col_index: TreeVec, + pub row: usize, +} +impl<'a> AssertEvaluator<'a> { + pub fn new(trace: &'a TreeVec>>, row: usize) -> Self { + Self { + trace, + col_index: TreeVec::new(vec![0; trace.len()]), + row, + } + } +} +impl<'a> EvalAtRow for AssertEvaluator<'a> { + type F = BaseField; + type EF = SecureField; + + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + let col_index = self.col_index[interaction]; + self.col_index[interaction] += 1; + offsets.map(|off| { + // The mask row might wrap around the column size. + let col_size = self.trace[interaction][col_index].len() as isize; + self.trace[interaction][col_index] + [(self.row as isize + off).rem_euclid(col_size) as usize] + }) + } + + fn add_constraint(&mut self, constraint: G) + where + Self::EF: std::ops::Mul, + { + // Cast to SecureField. + let res = SecureField::one() * constraint; + // The constraint should be zero at the given row, since we are evaluating on the trace + // domain. 
+ assert_eq!(res, SecureField::zero(), "row: {}", self.row); + } + + fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + SecureField::from_m31_array(values) + } +} + +pub fn assert_constraints( + trace_polys: &TreeVec>>, + trace_domain: CanonicCoset, + assert_func: impl Fn(AssertEvaluator<'_>), +) { + let traces = trace_polys.as_ref().map(|tree| { + tree.iter() + .map(|poly| { + circle_domain_order_to_coset_order( + &poly + .evaluate(trace_domain.circle_domain()) + .bit_reverse() + .values + .to_cpu(), + ) + }) + .collect() + }); + for row in 0..trace_domain.size() { + let eval = AssertEvaluator::new(&traces, row); + assert_func(eval); + } +} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/component.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/component.rs new file mode 100644 index 0000000..c0d8319 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/constraint_framework/component.rs @@ -0,0 +1,210 @@ +use std::borrow::Cow; +use std::iter::zip; +use std::ops::Deref; + +use itertools::Itertools; +use tracing::{span, Level}; + +use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; +use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; +use crate::core::air::{Component, ComponentProver, Trace}; +use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; +use crate::core::backend::simd::SimdBackend; +use crate::core::circle::CirclePoint; +use crate::core::constraints::coset_vanishing; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; +use crate::core::pcs::{TreeSubspan, TreeVec}; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; +use crate::core::poly::BitReversedOrder; +use crate::core::{utils, 
ColumnVec}; + +// TODO(andrew): Docs. +// TODO(andrew): Consider better location for this. +#[derive(Debug, Default)] +pub struct TraceLocationAllocator { + /// Mapping of tree index to next available column offset. + next_tree_offsets: TreeVec, +} + +impl TraceLocationAllocator { + fn next_for_structure(&mut self, structure: &TreeVec>) -> TreeVec { + if structure.len() > self.next_tree_offsets.len() { + self.next_tree_offsets.resize(structure.len(), 0); + } + + TreeVec::new( + zip(&mut *self.next_tree_offsets, &**structure) + .enumerate() + .map(|(tree_index, (offset, cols))| { + let col_start = *offset; + let col_end = col_start + cols.len(); + *offset = col_end; + TreeSubspan { + tree_index, + col_start, + col_end, + } + }) + .collect(), + ) + } +} + +/// A component defined solely in means of the constraints framework. +/// Implementing this trait introduces implementations for [`Component`] and [`ComponentProver`] for +/// the SIMD backend. +/// Note that the constraint framework only support components with columns of the same size. 
+pub trait FrameworkEval { + fn log_size(&self) -> u32; + + fn max_constraint_log_degree_bound(&self) -> u32; + + fn evaluate(&self, eval: E) -> E; +} + +pub struct FrameworkComponent { + eval: C, + trace_locations: TreeVec, +} + +impl FrameworkComponent { + pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self { + let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets; + let trace_locations = provider.next_for_structure(&eval_tree_structure); + Self { + eval, + trace_locations, + } + } +} + +impl Component for FrameworkComponent { + fn n_constraints(&self) -> usize { + self.eval.evaluate(InfoEvaluator::default()).n_constraints + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.eval.max_constraint_log_degree_bound() + } + + fn trace_log_degree_bounds(&self) -> TreeVec> { + TreeVec::new( + self.eval + .evaluate(InfoEvaluator::default()) + .mask_offsets + .iter() + .map(|tree_masks| vec![self.eval.log_size(); tree_masks.len()]) + .collect(), + ) + } + + fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + let info = self.eval.evaluate(InfoEvaluator::default()); + let trace_step = CanonicCoset::new(self.eval.log_size()).step(); + info.mask_offsets.map_cols(|col_mask| { + col_mask + .iter() + .map(|off| point + trace_step.mul_signed(*off).into_ef()) + .collect() + }) + } + + fn evaluate_constraint_quotients_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + evaluation_accumulator: &mut PointEvaluationAccumulator, + ) { + self.eval.evaluate(PointEvaluator::new( + mask.sub_tree(&self.trace_locations), + evaluation_accumulator, + coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(), + )); + } +} + +impl ComponentProver for FrameworkComponent { + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &Trace<'_, SimdBackend>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, + ) { + let eval_domain = 
CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); + let trace_domain = CanonicCoset::new(self.eval.log_size()); + + let component_polys = trace.polys.sub_tree(&self.trace_locations); + let component_evals = trace.evals.sub_tree(&self.trace_locations); + + // Extend trace if necessary. + // TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good + // subdomain. + let need_to_extend = component_evals + .iter() + .flatten() + .any(|c| c.domain != eval_domain); + let trace: TreeVec< + Vec>>, + > = if need_to_extend { + let _span = span!(Level::INFO, "Extension").entered(); + let twiddles = SimdBackend::precompute_twiddles(eval_domain.half_coset); + component_polys + .as_cols_ref() + .map_cols(|col| Cow::Owned(col.evaluate_with_twiddles(eval_domain, &twiddles))) + } else { + component_evals.clone().map_cols(|c| Cow::Borrowed(*c)) + }; + + // Denom inverses. + let log_expand = eval_domain.log_size() - trace_domain.log_size(); + let mut denom_inv = (0..1 << log_expand) + .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) + .collect_vec(); + utils::bit_reverse(&mut denom_inv); + + // Accumulator. + let [mut accum] = + evaluation_accumulator.columns([(eval_domain.log_size(), self.n_constraints())]); + accum.random_coeff_powers.reverse(); + + let _span = span!(Level::INFO, "Constraint pointwise eval").entered(); + let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) }; + + for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)) { + let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref()); + + // Evaluate constrains at row. + let eval = SimdDomainEvaluator::new( + &trace_cols, + vec_row, + &accum.random_coeff_powers, + trace_domain.log_size(), + eval_domain.log_size(), + ); + let row_res = self.eval.evaluate(eval).row_res; + + // Finalize row. 
+ unsafe { + let denom_inv = VeryPackedBaseField::broadcast( + denom_inv[vec_row + >> (trace_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS)], + ); + col.set_packed(vec_row, col.packed_at(vec_row) + row_res * denom_inv) + } + } + } +} + +impl Deref for FrameworkComponent { + type Target = E; + + fn deref(&self) -> &E { + &self.eval + } +} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/constant_columns.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/constant_columns.rs new file mode 100644 index 0000000..e57df28 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/constraint_framework/constant_columns.rs @@ -0,0 +1,37 @@ +use num_traits::One; + +use crate::core::backend::{Backend, Col, Column}; +use crate::core::fields::m31::BaseField; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; + +/// Generates a column with a single one at the first position, and zeros elsewhere. +pub fn gen_is_first(log_size: u32) -> CircleEvaluation { + let mut col = Col::::zeros(1 << log_size); + col.set(0, BaseField::one()); + CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) +} + +/// Generates a column with `1` at every `2^log_step` positions, `0` elsewhere, shifted by offset. +// TODO(andrew): Consider optimizing. Is a quotients of two coset_vanishing (use succinct rep for +// verifier). 
+pub fn gen_is_step_with_offset( + log_size: u32, + log_step: u32, + offset: usize, +) -> CircleEvaluation { + let mut col = Col::::zeros(1 << log_size); + + let size = 1 << log_size; + let step = 1 << log_step; + let step_offset = offset % step; + + for i in (step_offset..size).step_by(step) { + let circle_domain_index = coset_index_to_circle_domain_index(i, log_size); + let circle_domain_index_bit_rev = bit_reverse_index(circle_domain_index, log_size); + col.set(circle_domain_index_bit_rev, BaseField::one()); + } + + CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) +} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/info.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/info.rs new file mode 100644 index 0000000..05da93f --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/constraint_framework/info.rs @@ -0,0 +1,48 @@ +use std::ops::Mul; + +use num_traits::One; + +use super::EvalAtRow; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::TreeVec; + +/// Collects information about the constraints. +/// This includes mask offsets and columns at each interaction, and the number of constraints. +#[derive(Default)] +pub struct InfoEvaluator { + pub mask_offsets: TreeVec>>, + pub n_constraints: usize, +} +impl InfoEvaluator { + pub fn new() -> Self { + Self::default() + } +} +impl EvalAtRow for InfoEvaluator { + type F = BaseField; + type EF = SecureField; + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + // Check if requested a mask from a new interaction + if self.mask_offsets.len() <= interaction { + // Extend `mask_offsets` so that `interaction` is the last index. 
+ self.mask_offsets.resize(interaction + 1, vec![]); + } + self.mask_offsets[interaction].push(offsets.into_iter().collect()); + [BaseField::one(); N] + } + fn add_constraint(&mut self, _constraint: G) + where + Self::EF: Mul, + { + self.n_constraints += 1; + } + + fn combine_ef(_values: [Self::F; 4]) -> Self::EF { + SecureField::one() + } +} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/logup.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/logup.rs new file mode 100644 index 0000000..696a7b9 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/constraint_framework/logup.rs @@ -0,0 +1,315 @@ +use std::ops::{Mul, Sub}; + +use itertools::Itertools; +use num_traits::{One, Zero}; + +use super::EvalAtRow; +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::channel::Channel; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; +use crate::core::fields::FieldExpOps; +use crate::core::lookups::utils::Fraction; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::ColumnVec; + +/// Evaluates constraints for batched logups. +/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum. +/// BATCH_SIZE is the number of fractions to batch together. The degree of the resulting constraints +/// will be BATCH_SIZE + 1. +pub struct LogupAtRow { + /// The index of the interaction used for the cumulative sum columns. + pub interaction: usize, + /// Queue of fractions waiting to be batched together. 
+ pub queue: [(E::EF, E::EF); BATCH_SIZE], + /// Number of fractions in the queue. + pub queue_size: usize, + /// A constant to subtract from each row, to make the totall sum of the last column zero. + /// In other words, claimed_sum / 2^log_size. + /// This is used to make the constraint uniform. + pub cumsum_shift: SecureField, + /// The evaluation of the last cumulative sum column. + pub prev_col_cumsum: E::EF, + is_finalized: bool, +} +impl LogupAtRow { + pub fn new(interaction: usize, claimed_sum: SecureField, log_size: u32) -> Self { + Self { + interaction, + queue: [(E::EF::zero(), E::EF::zero()); BATCH_SIZE], + queue_size: 0, + cumsum_shift: claimed_sum / BaseField::from_u32_unchecked(1 << log_size), + prev_col_cumsum: E::EF::zero(), + is_finalized: false, + } + } + pub fn push_lookup( + &mut self, + eval: &mut E, + numerator: E::EF, + values: &[E::F], + lookup_elements: &LookupElements, + ) { + let shifted_value = lookup_elements.combine(values); + self.push_frac(eval, numerator, shifted_value); + } + + pub fn push_frac(&mut self, eval: &mut E, numerator: E::EF, denominator: E::EF) { + if self.queue_size < BATCH_SIZE { + self.queue[self.queue_size] = (numerator, denominator); + self.queue_size += 1; + return; + } + + // Compute sum_i pi/qi over batch, as a fraction, num/denom. + let (num, denom) = self.fold_queue(); + + self.queue[0] = (numerator, denominator); + self.queue_size = 1; + + // Add a constraint that num / denom = diff. + let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0]; + let diff = cur_cumsum - self.prev_col_cumsum; + self.prev_col_cumsum = cur_cumsum; + eval.add_constraint(diff * denom - num); + } + + pub fn add_frac(&mut self, eval: &mut E, fraction: Fraction) { + // Add a constraint that num / denom = diff. 
+ let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0]; + let diff = cur_cumsum - self.prev_col_cumsum; + self.prev_col_cumsum = cur_cumsum; + eval.add_constraint(diff * fraction.denominator - fraction.numerator); + } + + pub fn finalize(mut self, eval: &mut E) { + assert!(!self.is_finalized, "LogupAtRow was already finalized"); + let (num, denom) = self.fold_queue(); + + let [cur_cumsum, prev_row_cumsum] = + eval.next_extension_interaction_mask(self.interaction, [0, -1]); + + let diff = cur_cumsum - prev_row_cumsum - self.prev_col_cumsum; + // Instead of checking diff = num / denom, check diff = num / denom - cumsum_shift. + // This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint + // uniform - apply on all rows. + let fixed_diff = diff + self.cumsum_shift; + + eval.add_constraint(fixed_diff * denom - num); + + self.is_finalized = true; + } + + fn fold_queue(&self) -> (E::EF, E::EF) { + self.queue[0..self.queue_size] + .iter() + .copied() + .fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| { + (p0 * qi + pi * q0, qi * q0) + }) + } +} + +/// Ensures that the LogupAtRow is finalized. +/// LogupAtRow should be finalized exactly once. +impl Drop for LogupAtRow { + fn drop(&mut self) { + assert!(self.is_finalized, "LogupAtRow was not finalized"); + } +} + +/// Interaction elements for the logup protocol. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct LookupElements { + pub z: SecureField, + pub alpha: SecureField, + alpha_powers: [SecureField; N], +} +impl LookupElements { + pub fn draw(channel: &mut impl Channel) -> Self { + let [z, alpha] = channel.draw_felts(2).try_into().unwrap(); + let mut cur = SecureField::one(); + let alpha_powers = std::array::from_fn(|_| { + let res = cur; + cur *= alpha; + res + }); + Self { + z, + alpha, + alpha_powers, + } + } + pub fn combine(&self, values: &[F]) -> EF + where + EF: Copy + Zero + From + From + Mul + Sub, + { + EF::from(values[0]) + + values[1..] 
+ .iter() + .zip(self.alpha_powers.iter()) + .fold(EF::zero(), |acc, (&value, &power)| { + acc + EF::from(power) * value + }) + - EF::from(self.z) + } + // TODO(spapini): Try to remove this. + pub fn dummy() -> Self { + Self { + z: SecureField::one(), + alpha: SecureField::one(), + alpha_powers: [SecureField::one(); N], + } + } +} + +// SIMD backend generator for logup interaction trace. +pub struct LogupTraceGenerator { + log_size: u32, + /// Current allocated interaction columns. + trace: Vec>, + /// Denominator expressions (z + sum_i alpha^i * x_i) being generated for the current lookup. + denom: SecureColumn, + /// Preallocated buffer for the Inverses of the denominators. + denom_inv: SecureColumn, +} +impl LogupTraceGenerator { + pub fn new(log_size: u32) -> Self { + let trace = vec![]; + let denom = SecureColumn::zeros(1 << log_size); + let denom_inv = SecureColumn::zeros(1 << log_size); + Self { + log_size, + trace, + denom, + denom_inv, + } + } + + /// Allocate a new lookup column. + pub fn new_col(&mut self) -> LogupColGenerator<'_> { + let log_size = self.log_size; + LogupColGenerator { + gen: self, + numerator: SecureColumnByCoords::::zeros(1 << log_size), + } + } + + /// Finalize the trace. Returns the trace and the claimed sum of the last column. + pub fn finalize( + mut self, + ) -> ( + ColumnVec>, + SecureField, + ) { + // Compute claimed sum. + let mut last_col_coords = self.trace.pop().unwrap().columns; + let packed_sums: [PackedBaseField; SECURE_EXTENSION_DEGREE] = last_col_coords + .each_ref() + .map(|c| c.data.iter().copied().sum()); + let base_sums = packed_sums.map(|s| s.pointwise_sum()); + let claimed_sum = SecureField::from_m31_array(base_sums); + + // Shift the last column to make the sum zero. 
+ let cumsum_shift = claimed_sum / BaseField::from_u32_unchecked(1 << self.log_size); + last_col_coords.iter_mut().enumerate().for_each(|(i, c)| { + c.data + .iter_mut() + .for_each(|x| *x -= PackedBaseField::broadcast(cumsum_shift.to_m31_array()[i])) + }); + + // Prefix sum the last column. + let coord_prefix_sum = last_col_coords.map(inclusive_prefix_sum); + self.trace.push(SecureColumnByCoords { + columns: coord_prefix_sum, + }); + + let trace = self + .trace + .into_iter() + .flat_map(|eval| { + eval.columns.map(|c| { + CircleEvaluation::::new( + CanonicCoset::new(self.log_size).circle_domain(), + c, + ) + }) + }) + .collect_vec(); + (trace, claimed_sum) + } +} + +/// Trace generator for a single lookup column. +pub struct LogupColGenerator<'a> { + gen: &'a mut LogupTraceGenerator, + /// Numerator expressions (i.e. multiplicities) being generated for the current lookup. + numerator: SecureColumnByCoords, +} +impl<'a> LogupColGenerator<'a> { + /// Write a fraction to the column at a row. + pub fn write_frac( + &mut self, + vec_row: usize, + numerator: PackedSecureField, + denom: PackedSecureField, + ) { + debug_assert!( + denom.to_array().iter().all(|x| *x != SecureField::zero()), + "{:?}", + ("denom at vec_row {} is zero {}", denom, vec_row) + ); + unsafe { + self.numerator.set_packed(vec_row, numerator); + *self.gen.denom.data.get_unchecked_mut(vec_row) = denom; + } + } + + /// Finalizes generating the column. 
+ pub fn finalize_col(mut self) { + FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data); + + for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) { + unsafe { + let value = self.numerator.packed_at(vec_row) + * *self.gen.denom_inv.data.get_unchecked(vec_row); + let prev_value = self + .gen + .trace + .last() + .map(|col| col.packed_at(vec_row)) + .unwrap_or_else(PackedSecureField::zero); + self.numerator.set_packed(vec_row, value + prev_value) + }; + } + + self.gen.trace.push(self.numerator) + } +} + +#[cfg(test)] +mod tests { + use num_traits::One; + + use super::LogupAtRow; + use crate::constraint_framework::InfoEvaluator; + use crate::core::fields::qm31::SecureField; + + #[test] + #[should_panic] + fn test_logup_not_finalized_panic() { + let mut logup = LogupAtRow::<2, InfoEvaluator>::new(1, SecureField::one(), 7); + logup.push_frac( + &mut InfoEvaluator::default(), + SecureField::one(), + SecureField::one(), + ); + } +} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/mod.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/mod.rs new file mode 100644 index 0000000..87069d3 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/constraint_framework/mod.rs @@ -0,0 +1,97 @@ +/// ! This module contains helpers to express and use constraints for components. 
+mod assert; +mod component; +pub mod constant_columns; +mod info; +pub mod logup; +mod point; +mod simd_domain; + +use std::array; +use std::fmt::Debug; +use std::ops::{Add, AddAssign, Mul, Neg, Sub}; + +pub use assert::{assert_constraints, AssertEvaluator}; +pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator}; +pub use info::InfoEvaluator; +use num_traits::{One, Zero}; +pub use point::PointEvaluator; +pub use simd_domain::SimdDomainEvaluator; + +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::fields::FieldExpOps; + +/// A trait for evaluating expressions at some point or row. +pub trait EvalAtRow { + // TODO(spapini): Use a better trait for these, like 'Algebra' or something. + /// The field type holding values of columns for the component. These are the inputs to the + /// constraints. It might be [BaseField] packed types, or even [SecureField], when evaluating + /// the columns out of domain. + type F: FieldExpOps + + Copy + + Debug + + Zero + + Neg + + AddAssign + + AddAssign + + Add + + Sub + + Mul + + Add + + Mul + + Neg + + From; + + /// A field type representing the closure of `F` with multiplying by [SecureField]. Constraints + /// usually get multiplied by [SecureField] values for security. + type EF: One + + Copy + + Debug + + Zero + + From + + Neg + + AddAssign + + Add + + Sub + + Mul + + Add + + Mul + + Sub + + Mul + + From + + From; + + /// Returns the next mask value for the first interaction at offset 0. + fn next_trace_mask(&mut self) -> Self::F { + let [mask_item] = self.next_interaction_mask(0, [0]); + mask_item + } + + /// Returns the mask values of the given offsets for the next column in the interaction. 
+ fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N]; + + /// Returns the extension mask values of the given offsets for the next extension degree many + /// columns in the interaction. + fn next_extension_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::EF; N] { + let res_col_major = array::from_fn(|_| self.next_interaction_mask(interaction, offsets)); + array::from_fn(|i| Self::combine_ef(res_col_major.map(|c| c[i]))) + } + + /// Adds a constraint to the component. + fn add_constraint(&mut self, constraint: G) + where + Self::EF: Mul; + + /// Combines 4 base field values into a single extension field value. + fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF; +} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/point.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/point.rs new file mode 100644 index 0000000..6c6f72f --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/constraint_framework/point.rs @@ -0,0 +1,57 @@ +use std::ops::Mul; + +use super::EvalAtRow; +use crate::core::air::accumulation::PointEvaluationAccumulator; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::pcs::TreeVec; +use crate::core::ColumnVec; + +/// Evaluates expressions at a point out of domain. 
+pub struct PointEvaluator<'a> { + pub mask: TreeVec>>, + pub evaluation_accumulator: &'a mut PointEvaluationAccumulator, + pub col_index: Vec, + pub denom_inverse: SecureField, +} +impl<'a> PointEvaluator<'a> { + pub fn new( + mask: TreeVec>>, + evaluation_accumulator: &'a mut PointEvaluationAccumulator, + denom_inverse: SecureField, + ) -> Self { + let col_index = vec![0; mask.len()]; + Self { + mask, + evaluation_accumulator, + col_index, + denom_inverse, + } + } +} +impl<'a> EvalAtRow for PointEvaluator<'a> { + type F = SecureField; + type EF = SecureField; + + fn next_interaction_mask( + &mut self, + interaction: usize, + _offsets: [isize; N], + ) -> [Self::F; N] { + let col_index = self.col_index[interaction]; + self.col_index[interaction] += 1; + let mask = self.mask[interaction][col_index].clone(); + assert_eq!(mask.len(), N); + mask.try_into().unwrap() + } + fn add_constraint(&mut self, constraint: G) + where + Self::EF: Mul, + { + self.evaluation_accumulator + .accumulate(self.denom_inverse * constraint); + } + fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + SecureField::from_partial_evals(values) + } +} diff --git a/Stwo_wrapper/crates/prover/src/constraint_framework/simd_domain.rs b/Stwo_wrapper/crates/prover/src/constraint_framework/simd_domain.rs new file mode 100644 index 0000000..ef3662a --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/constraint_framework/simd_domain.rs @@ -0,0 +1,106 @@ +use std::ops::Mul; + +use num_traits::Zero; + +use super::EvalAtRow; +use crate::core::backend::simd::column::VeryPackedBaseColumn; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::very_packed_m31::{ + VeryPackedBaseField, VeryPackedSecureField, LOG_N_VERY_PACKED_ELEMS, +}; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use 
crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::pcs::TreeVec; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::offset_bit_reversed_circle_domain_index; + +/// Evaluates constraints at an evaluation domain points. +pub struct SimdDomainEvaluator<'a> { + pub trace_eval: + &'a TreeVec>>, + pub column_index_per_interaction: Vec, + /// The row index of the simd-vector row to evaluate the constraints at. + pub vec_row: usize, + pub random_coeff_powers: &'a [SecureField], + pub row_res: VeryPackedSecureField, + pub constraint_index: usize, + pub domain_log_size: u32, + pub eval_domain_log_size: u32, +} +impl<'a> SimdDomainEvaluator<'a> { + pub fn new( + trace_eval: &'a TreeVec>>, + vec_row: usize, + random_coeff_powers: &'a [SecureField], + domain_log_size: u32, + eval_log_size: u32, + ) -> Self { + Self { + trace_eval, + column_index_per_interaction: vec![0; trace_eval.len()], + vec_row, + random_coeff_powers, + row_res: VeryPackedSecureField::zero(), + constraint_index: 0, + domain_log_size, + eval_domain_log_size: eval_log_size, + } + } +} +impl<'a> EvalAtRow for SimdDomainEvaluator<'a> { + type F = VeryPackedBaseField; + type EF = VeryPackedSecureField; + + // TODO(spapini): Remove all boundary checks. + fn next_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::F; N] { + let col_index = self.column_index_per_interaction[interaction]; + self.column_index_per_interaction[interaction] += 1; + offsets.map(|off| { + // If the offset is 0, we can just return the value directly from this row. + if off == 0 { + unsafe { + let col = &self + .trace_eval + .get_unchecked(interaction) + .get_unchecked(col_index) + .values; + let very_packed_col = VeryPackedBaseColumn::transform_under_ref(col); + return *very_packed_col.data.get_unchecked(self.vec_row); + }; + } + // Otherwise, we need to look up the value at the offset. 
+ // Since the domain is bit-reversed circle domain ordered, we need to look up the value + // at the bit-reversed natural order index at an offset. + VeryPackedBaseField::from_array(std::array::from_fn(|i| { + let row_index = offset_bit_reversed_circle_domain_index( + (self.vec_row << (LOG_N_LANES + LOG_N_VERY_PACKED_ELEMS)) + i, + self.domain_log_size, + self.eval_domain_log_size, + off, + ); + self.trace_eval[interaction][col_index].at(row_index) + })) + }) + } + fn add_constraint(&mut self, constraint: G) + where + Self::EF: Mul, + { + self.row_res += + VeryPackedSecureField::broadcast(self.random_coeff_powers[self.constraint_index]) + * constraint; + self.constraint_index += 1; + } + + fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF { + VeryPackedSecureField::from_very_packed_m31s(values) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/air/accumulation.rs b/Stwo_wrapper/crates/prover/src/core/air/accumulation.rs new file mode 100644 index 0000000..8fcf575 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/air/accumulation.rs @@ -0,0 +1,297 @@ +//! Accumulators for a random linear combination of circle polynomials. +//! Given N polynomials, u_0(P), ... u_{N-1}(P), and a random alpha, the combined polynomial is +//! defined as +//! f(p) = sum_i alpha^{N-1-i} u_i(P). + +use itertools::Itertools; +use tracing::{span, Level}; + +use crate::core::backend::{Backend, Col, Column, CpuBackend}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SecureColumnByCoords; +use crate::core::fields::FieldOps; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, SecureCirclePoly}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::generate_secure_powers; + +/// Accumulates N evaluations of u_i(P0) at a single point. +/// Computes f(P0), the combined polynomial at that point. 
+/// For n accumulated evaluations, the i'th evaluation is multiplied by alpha^(N-1-i). +pub struct PointEvaluationAccumulator { + random_coeff: SecureField, + accumulation: SecureField, +} + +impl PointEvaluationAccumulator { + /// Creates a new accumulator. + /// `random_coeff` should be a secure random field element, drawn from the channel. + pub fn new(random_coeff: SecureField) -> Self { + Self { + random_coeff, + accumulation: SecureField::default(), + } + } + + /// Accumulates u_i(P0), a polynomial evaluation at a P0 in reverse order. + pub fn accumulate(&mut self, evaluation: SecureField) { + self.accumulation = self.accumulation * self.random_coeff + evaluation; + } + + pub fn finalize(self) -> SecureField { + self.accumulation + } +} + +// TODO(ShaharS), rename terminology to constraints instead of columns. +/// Accumulates evaluations of u_i(P), each at an evaluation domain of the size of that polynomial. +/// Computes the coefficients of f(P). +pub struct DomainEvaluationAccumulator { + random_coeff_powers: Vec, + /// Accumulated evaluations for each log_size. + /// Each `sub_accumulation` holds the sum over all columns i of that log_size, of + /// `evaluation_i * alpha^(N - 1 - i)` + /// where `N` is the total number of evaluations. + sub_accumulations: Vec>>, +} + +impl DomainEvaluationAccumulator { + /// Creates a new accumulator. + /// `random_coeff` should be a secure random field element, drawn from the channel. + /// `max_log_size` is the maximum log_size of the accumulated evaluations. + pub fn new(random_coeff: SecureField, max_log_size: u32, total_columns: usize) -> Self { + let max_log_size = max_log_size as usize; + Self { + random_coeff_powers: generate_secure_powers(random_coeff, total_columns), + sub_accumulations: (0..(max_log_size + 1)).map(|_| None).collect(), + } + } + + /// Gets accumulators for some sizes. + /// `n_cols_per_size` is an array of pairs (log_size, n_cols). 
+ /// For each entry, a [ColumnAccumulator] is returned, expecting to accumulate `n_cols` + /// evaluations of size `log_size`. + /// The array size, `N`, is the number of different sizes. + pub fn columns( + &mut self, + n_cols_per_size: [(u32, usize); N], + ) -> [ColumnAccumulator<'_, B>; N] { + self.sub_accumulations + .get_many_mut(n_cols_per_size.map(|(log_size, _)| log_size as usize)) + .unwrap_or_else(|e| panic!("invalid log_sizes: {}", e)) + .into_iter() + .zip(n_cols_per_size) + .map(|(col, (log_size, n_cols))| { + let random_coeffs = self + .random_coeff_powers + .split_off(self.random_coeff_powers.len() - n_cols); + ColumnAccumulator { + random_coeff_powers: random_coeffs, + col: col.get_or_insert_with(|| SecureColumnByCoords::zeros(1 << log_size)), + } + }) + .collect_vec() + .try_into() + .unwrap_or_else(|_| unreachable!()) + } + + /// Returns the log size of the resulting polynomial. + pub fn log_size(&self) -> u32 { + (self.sub_accumulations.len() - 1) as u32 + } +} + +pub trait AccumulationOps: FieldOps + Sized { + /// Accumulates other into column: + /// column = column + other. + fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords); +} + +impl DomainEvaluationAccumulator { + /// Computes f(P) as coefficients. 
+ pub fn finalize(self) -> SecureCirclePoly { + assert_eq!( + self.random_coeff_powers.len(), + 0, + "not all random coefficients were used" + ); + let log_size = self.log_size(); + let _span = span!(Level::INFO, "Constraints interpolation").entered(); + let mut cur_poly: Option> = None; + let twiddles = B::precompute_twiddles( + CanonicCoset::new(self.log_size()) + .circle_domain() + .half_coset, + ); + + for (log_size, values) in self.sub_accumulations.into_iter().enumerate().skip(1) { + let Some(mut values) = values else { + continue; + }; + if let Some(prev_poly) = cur_poly { + let eval = SecureColumnByCoords { + columns: prev_poly.0.map(|c| { + c.evaluate_with_twiddles( + CanonicCoset::new(log_size as u32).circle_domain(), + &twiddles, + ) + .values + }), + }; + B::accumulate(&mut values, &eval); + } + cur_poly = Some(SecureCirclePoly(values.columns.map(|c| { + CircleEvaluation::::new( + CanonicCoset::new(log_size as u32).circle_domain(), + c, + ) + .interpolate_with_twiddles(&twiddles) + }))); + } + cur_poly.unwrap_or_else(|| { + SecureCirclePoly(std::array::from_fn(|_| { + CirclePoly::new(Col::::zeros(1 << log_size)) + })) + }) + } +} + +/// A domain accumulator for polynomials of a single size. +pub struct ColumnAccumulator<'a, B: Backend> { + pub random_coeff_powers: Vec, + pub col: &'a mut SecureColumnByCoords, +} +impl<'a> ColumnAccumulator<'a, CpuBackend> { + pub fn accumulate(&mut self, index: usize, evaluation: SecureField) { + let val = self.col.at(index) + evaluation; + self.col.set(index, val); + } +} + +#[cfg(test)] +mod tests { + use std::array; + + use num_traits::Zero; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::*; + use crate::core::backend::cpu::CpuCircleEvaluation; + use crate::core::circle::CirclePoint; + use crate::core::fields::m31::{M31, P}; + use crate::qm31; + + #[test] + fn test_point_evaluation_accumulator() { + // Generate a vector of random sizes with a constant seed. 
+ let mut rng = SmallRng::seed_from_u64(0); + const MAX_LOG_SIZE: u32 = 10; + const MASK: u32 = P; + let log_sizes = (0..100) + .map(|_| rng.gen_range(4..MAX_LOG_SIZE)) + .collect::>(); + + // Generate random evaluations. + let evaluations = log_sizes + .iter() + .map(|_| M31::from_u32_unchecked(rng.gen::() & MASK)) + .collect::>(); + let alpha = qm31!(2, 3, 4, 5); + + // Use accumulator. + let mut accumulator = PointEvaluationAccumulator::new(alpha); + for (_, evaluation) in log_sizes.iter().zip(evaluations.iter()) { + accumulator.accumulate((*evaluation).into()); + } + let accumulator_res = accumulator.finalize(); + + // Use direct computation. + let mut res = SecureField::default(); + for evaluation in evaluations.iter() { + res = res * alpha + *evaluation; + } + + assert_eq!(accumulator_res, res); + } + + #[test] + fn test_domain_evaluation_accumulator() { + // Generate a vector of random sizes with a constant seed. + let mut rng = SmallRng::seed_from_u64(0); + const LOG_SIZE_MIN: u32 = 4; + const LOG_SIZE_BOUND: u32 = 10; + const MASK: u32 = P; + let mut log_sizes = (0..100) + .map(|_| rng.gen_range(LOG_SIZE_MIN..LOG_SIZE_BOUND)) + .collect::>(); + log_sizes.sort(); + + // Generate random evaluations. + let evaluations = log_sizes + .iter() + .map(|log_size| { + (0..(1 << *log_size)) + .map(|_| M31::from_u32_unchecked(rng.gen::() & MASK)) + .collect::>() + }) + .collect::>(); + let alpha = qm31!(2, 3, 4, 5); + + // Use accumulator. 
+ let mut accumulator = DomainEvaluationAccumulator::::new( + alpha, + LOG_SIZE_BOUND, + evaluations.len(), + ); + let n_cols_per_size: [(u32, usize); (LOG_SIZE_BOUND - LOG_SIZE_MIN) as usize] = + array::from_fn(|i| { + let current_log_size = LOG_SIZE_MIN + i as u32; + let n_cols = log_sizes + .iter() + .copied() + .filter(|&log_size| log_size == current_log_size) + .count(); + (current_log_size, n_cols) + }); + let mut cols = accumulator.columns(n_cols_per_size); + let mut eval_chunk_offset = 0; + for (log_size, n_cols) in n_cols_per_size.iter() { + for index in 0..(1 << log_size) { + let mut val = SecureField::zero(); + for (eval_index, (col_log_size, evaluation)) in + log_sizes.iter().zip(evaluations.iter()).enumerate() + { + if *log_size != *col_log_size { + continue; + } + + // The random coefficient powers chunk is in regular order. + let random_coeff_chunk = + &cols[(log_size - LOG_SIZE_MIN) as usize].random_coeff_powers; + val += random_coeff_chunk + [random_coeff_chunk.len() - 1 - (eval_index - eval_chunk_offset)] + * evaluation[index]; + } + cols[(log_size - LOG_SIZE_MIN) as usize].accumulate(index, val); + } + eval_chunk_offset += n_cols; + } + let accumulator_poly = accumulator.finalize(); + + // Pick an arbitrary sample point. + let point = CirclePoint::::get_point(98989892); + let accumulator_res = accumulator_poly.eval_at_point(point); + + // Use direct computation. 
+ let mut res = SecureField::default(); + for (log_size, values) in log_sizes.into_iter().zip(evaluations) { + res = res * alpha + + CpuCircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), values) + .interpolate() + .eval_at_point(point); + } + + assert_eq!(accumulator_res, res); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/air/components.rs b/Stwo_wrapper/crates/prover/src/core/air/components.rs new file mode 100644 index 0000000..a7e0129 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/air/components.rs @@ -0,0 +1,80 @@ +use itertools::Itertools; + +use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; +use super::{Component, ComponentProver, Trace}; +use crate::core::backend::Backend; +use crate::core::circle::CirclePoint; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::TreeVec; +use crate::core::poly::circle::SecureCirclePoly; +use crate::core::ColumnVec; + +pub struct Components<'a>(pub Vec<&'a dyn Component>); + +impl<'a> Components<'a> { + pub fn composition_log_degree_bound(&self) -> u32 { + self.0 + .iter() + .map(|component| component.max_constraint_log_degree_bound()) + .max() + .unwrap() + } + + pub fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + TreeVec::concat_cols(self.0.iter().map(|component| component.mask_points(point))) + } + + pub fn eval_composition_polynomial_at_point( + &self, + point: CirclePoint, + mask_values: &TreeVec>>, + random_coeff: SecureField, + ) -> SecureField { + //accumulator for the random linear comination over powers of random_coeff + let mut evaluation_accumulator = PointEvaluationAccumulator::new(random_coeff); + + for component in &self.0 { + component.evaluate_constraint_quotients_at_point( + point, + mask_values, + &mut evaluation_accumulator, + ) + } + evaluation_accumulator.finalize() + } + + pub fn column_log_sizes(&self) -> TreeVec> { + TreeVec::concat_cols( + self.0 + .iter() + .map(|component| 
component.trace_log_degree_bounds()), + ) + } +} + +pub struct ComponentProvers<'a, B: Backend>(pub Vec<&'a dyn ComponentProver>); + +impl<'a, B: Backend> ComponentProvers<'a, B> { + pub fn components(&self) -> Components<'_> { + Components(self.0.iter().map(|c| *c as &dyn Component).collect_vec()) + } + pub fn compute_composition_polynomial( + &self, + random_coeff: SecureField, + trace: &Trace<'_, B>, + ) -> SecureCirclePoly { + let total_constraints: usize = self.0.iter().map(|c| c.n_constraints()).sum(); + let mut accumulator = DomainEvaluationAccumulator::new( + random_coeff, + self.components().composition_log_degree_bound(), + total_constraints, + ); + for component in &self.0 { + component.evaluate_constraint_quotients_on_domain(trace, &mut accumulator) + } + accumulator.finalize() + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/air/mask.rs b/Stwo_wrapper/crates/prover/src/core/air/mask.rs new file mode 100644 index 0000000..e2748a6 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/air/mask.rs @@ -0,0 +1,91 @@ +use std::collections::HashSet; +use std::vec; + +use itertools::Itertools; + +use crate::core::circle::CirclePoint; +use crate::core::fields::qm31::SecureField; +use crate::core::poly::circle::CanonicCoset; +use crate::core::ColumnVec; + +/// Mask holds a vector with an entry for each column. +/// Each entry holds a list of mask items, which are the offsets of the mask at that column. +type Mask = ColumnVec>; + +/// Returns the same point for each mask item. +/// Should be used where all the mask items has no shift from the constraint point. 
+pub fn fixed_mask_points( + mask: &Mask, + point: CirclePoint, +) -> ColumnVec>> { + assert_eq!( + mask.iter() + .flat_map(|mask_entry| mask_entry.iter().collect::>()) + .collect::>() + .into_iter() + .collect_vec(), + vec![&0] + ); + mask.iter() + .map(|mask_entry| mask_entry.iter().map(|_| point).collect()) + .collect() +} + +/// For each mask item returns the point shifted by the domain initial point of the column. +/// Should be used where the mask items are shifted from the constraint point. +pub fn shifted_mask_points( + mask: &Mask, + domains: &[CanonicCoset], + point: CirclePoint, +) -> ColumnVec>> { + mask.iter() + .zip(domains.iter()) + .map(|(mask_entry, domain)| { + mask_entry + .iter() + .map(|mask_item| point + domain.at(*mask_item).into_ef()) + .collect() + }) + .collect() +} + +#[cfg(test)] +mod tests { + use crate::core::air::mask::{fixed_mask_points, shifted_mask_points}; + use crate::core::circle::CirclePoint; + use crate::core::poly::circle::CanonicCoset; + + #[test] + fn test_mask_fixed_points() { + let mask = vec![vec![0], vec![0]]; + let constraint_point = CirclePoint::get_point(1234); + + let points = fixed_mask_points(&mask, constraint_point); + + assert_eq!(points.len(), 2); + assert_eq!(points[0].len(), 1); + assert_eq!(points[1].len(), 1); + assert_eq!(points[0][0], constraint_point); + assert_eq!(points[1][0], constraint_point); + } + + #[test] + fn test_mask_shifted_points() { + let mask = vec![vec![0, 1], vec![0, 1, 2]]; + let constraint_point = CirclePoint::get_point(1234); + let domains = (0..mask.len() as u32) + .map(|i| CanonicCoset::new(7 + i)) + .collect::>(); + + let points = shifted_mask_points(&mask, &domains, constraint_point); + + assert_eq!(points.len(), 2); + assert_eq!(points[0].len(), 2); + assert_eq!(points[1].len(), 3); + assert_eq!(points[0][0], constraint_point + domains[0].at(0).into_ef()); + assert_eq!(points[0][1], constraint_point + domains[0].at(1).into_ef()); + assert_eq!(points[1][0], constraint_point + 
domains[1].at(0).into_ef()); + assert_eq!(points[1][1], constraint_point + domains[1].at(1).into_ef()); + assert_eq!(points[1][2], constraint_point + domains[1].at(2).into_ef()); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/air/mod.rs b/Stwo_wrapper/crates/prover/src/core/air/mod.rs new file mode 100644 index 0000000..fcdd4d5 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/air/mod.rs @@ -0,0 +1,76 @@ +pub use components::{ComponentProvers, Components}; + +use self::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; +use super::backend::Backend; +use super::circle::CirclePoint; +use super::fields::m31::BaseField; +use super::fields::qm31::SecureField; +use super::pcs::TreeVec; +use super::poly::circle::{CircleEvaluation, CirclePoly}; +use super::poly::BitReversedOrder; +use super::ColumnVec; + +pub mod accumulation; +mod components; +pub mod mask; + +/// Arithmetic Intermediate Representation (AIR). +/// An Air instance is assumed to already contain all the information needed to +/// evaluate the constraints. +/// For instance, all interaction elements are assumed to be present in it. +/// Therefore, an AIR is generated only after the initial trace commitment phase. +// TODO(spapini): consider renaming this struct. +pub trait Air { + fn components(&self) -> Vec<&dyn Component>; +} + +pub trait AirProver: Air { + fn component_provers(&self) -> Vec<&dyn ComponentProver>; +} + +/// A component is a set of trace columns of various sizes along with a set of +/// constraints on them. +pub trait Component { + fn n_constraints(&self) -> usize; + + fn max_constraint_log_degree_bound(&self) -> u32; + + /// Returns the degree bounds of each trace column. The returned TreeVec should be of size + /// `n_interaction_phases`. + fn trace_log_degree_bounds(&self) -> TreeVec>; + + /// Returns the mask points for each trace column. The returned TreeVec should be of size + /// `n_interaction_phases`. 
+ fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>>; + + /// Evaluates the constraint quotients combination of the component at a point. + fn evaluate_constraint_quotients_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + evaluation_accumulator: &mut PointEvaluationAccumulator, + ); +} + +pub trait ComponentProver: Component { + /// Evaluates the constraint quotients of the component on the evaluation domain. + /// Accumulates quotients in `evaluation_accumulator`. + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &Trace<'_, B>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, + ); +} + +/// The set of polynomials that make up the trace. +/// +/// Each polynomial is stored both in a coefficients, and evaluations form (for efficiency) +pub struct Trace<'a, B: Backend> { + /// Polynomials for each column. + pub polys: TreeVec>>, + /// Evaluations for each column (evaluated on their commitment domains). + pub evals: TreeVec>>, +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/accumulation.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/accumulation.rs new file mode 100644 index 0000000..63a49bf --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/cpu/accumulation.rs @@ -0,0 +1,12 @@ +use super::CpuBackend; +use crate::core::air::accumulation::AccumulationOps; +use crate::core::fields::secure_column::SecureColumnByCoords; + +impl AccumulationOps for CpuBackend { + fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords) { + for i in 0..column.len() { + let res_coeff = column.at(i) + other.at(i); + column.set(i, res_coeff); + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/blake2s.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/blake2s.rs new file mode 100644 index 0000000..a87a5ae --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/cpu/blake2s.rs @@ -0,0 +1,24 @@ +use itertools::Itertools; + +use 
crate::core::backend::CpuBackend; +use crate::core::fields::m31::BaseField; +use crate::core::vcs::blake2_hash::Blake2sHash; +use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; +use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; + +impl MerkleOps for CpuBackend { + fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Vec>, + columns: &[&Vec], + ) -> Vec { + (0..(1 << log_size)) + .map(|i| { + Blake2sMerkleHasher::hash_node( + prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), + &columns.iter().map(|column| column[i]).collect_vec(), + ) + }) + .collect() + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/circle.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/circle.rs new file mode 100644 index 0000000..c37ffe2 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/cpu/circle.rs @@ -0,0 +1,376 @@ +use num_traits::Zero; + +use super::CpuBackend; +use crate::core::backend::{Col, ColumnOps}; +use crate::core::circle::{CirclePoint, Coset}; +use crate::core::fft::{butterfly, ibutterfly}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::{ExtensionOf, FieldExpOps}; +use crate::core::poly::circle::{ + CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, +}; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; + +impl PolyOps for CpuBackend { + type Twiddles = Vec; + + fn new_canonical_ordered( + coset: CanonicCoset, + values: Col, + ) -> CircleEvaluation { + let domain = coset.circle_domain(); + assert_eq!(values.len(), domain.size()); + let mut new_values = coset_order_to_circle_domain_order(&values); + CpuBackend::bit_reverse_column(&mut new_values); + CircleEvaluation::new(domain, new_values) + } + + fn interpolate( + eval: CircleEvaluation, + 
twiddles: &TwiddleTree, + ) -> CirclePoly { + assert!(eval.domain.half_coset.is_doubling_of(twiddles.root_coset)); + + let mut values = eval.values; + + if eval.domain.log_size() == 1 { + let y = eval.domain.half_coset.initial.y; + let n = BaseField::from(2); + let yn_inv = (y * n).inverse(); + let y_inv = yn_inv * n; + let n_inv = yn_inv * y; + let (mut v0, mut v1) = (values[0], values[1]); + ibutterfly(&mut v0, &mut v1, y_inv); + return CirclePoly::new(vec![v0 * n_inv, v1 * n_inv]); + } + + if eval.domain.log_size() == 2 { + let CirclePoint { x, y } = eval.domain.half_coset.initial; + let n = BaseField::from(4); + let xyn_inv = (x * y * n).inverse(); + let x_inv = xyn_inv * y * n; + let y_inv = xyn_inv * x * n; + let n_inv = xyn_inv * x * y; + let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]); + ibutterfly(&mut v0, &mut v1, y_inv); + ibutterfly(&mut v2, &mut v3, -y_inv); + ibutterfly(&mut v0, &mut v2, x_inv); + ibutterfly(&mut v1, &mut v3, x_inv); + return CirclePoly::new(vec![v0 * n_inv, v1 * n_inv, v2 * n_inv, v3 * n_inv]); + } + + let line_twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles); + let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]); + + for (h, t) in circle_twiddles.enumerate() { + fft_layer_loop(&mut values, 0, h, t, ibutterfly); + } + for (layer, layer_twiddles) in line_twiddles.into_iter().enumerate() { + for (h, &t) in layer_twiddles.iter().enumerate() { + fft_layer_loop(&mut values, layer + 1, h, t, ibutterfly); + } + } + + // Divide all values by 2^log_size. 
+ let inv = BaseField::from_u32_unchecked(eval.domain.size() as u32).inverse(); + for val in &mut values { + *val *= inv; + } + + CirclePoly::new(values) + } + + fn eval_at_point(poly: &CirclePoly, point: CirclePoint) -> SecureField { + if poly.log_size() == 0 { + return poly.coeffs[0].into(); + } + + let mut mappings = vec![point.y]; + let mut x = point.x; + for _ in 1..poly.log_size() { + mappings.push(x); + x = CirclePoint::double_x(x); + } + mappings.reverse(); + + fold(&poly.coeffs, &mappings) + } + + fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { + assert!(log_size >= poly.log_size()); + let mut coeffs = Vec::with_capacity(1 << log_size); + coeffs.extend_from_slice(&poly.coeffs); + coeffs.resize(1 << log_size, BaseField::zero()); + CirclePoly::new(coeffs) + } + + fn evaluate( + poly: &CirclePoly, + domain: CircleDomain, + twiddles: &TwiddleTree, + ) -> CircleEvaluation { + assert!(domain.half_coset.is_doubling_of(twiddles.root_coset)); + + let mut values = poly.extend(domain.log_size()).coeffs; + + if domain.log_size() == 1 { + let (mut v0, mut v1) = (values[0], values[1]); + butterfly(&mut v0, &mut v1, domain.half_coset.initial.y); + return CircleEvaluation::new(domain, vec![v0, v1]); + } + + if domain.log_size() == 2 { + let (mut v0, mut v1, mut v2, mut v3) = (values[0], values[1], values[2], values[3]); + let CirclePoint { x, y } = domain.half_coset.initial; + butterfly(&mut v0, &mut v2, x); + butterfly(&mut v1, &mut v3, x); + butterfly(&mut v0, &mut v1, y); + butterfly(&mut v2, &mut v3, -y); + return CircleEvaluation::new(domain, vec![v0, v1, v2, v3]); + } + + let line_twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles); + let circle_twiddles = circle_twiddles_from_line_twiddles(line_twiddles[0]); + + for (layer, layer_twiddles) in line_twiddles.iter().enumerate().rev() { + for (h, &t) in layer_twiddles.iter().enumerate() { + fft_layer_loop(&mut values, layer + 1, h, t, butterfly); + } + } + for (h, t) in 
circle_twiddles.enumerate() { + fft_layer_loop(&mut values, 0, h, t, butterfly); + } + + CircleEvaluation::new(domain, values) + } + + fn precompute_twiddles(mut coset: Coset) -> TwiddleTree { + const CHUNK_LOG_SIZE: usize = 12; + const CHUNK_SIZE: usize = 1 << CHUNK_LOG_SIZE; + + let root_coset = coset; + let mut twiddles = Vec::with_capacity(coset.size()); + for _ in 0..coset.log_size() { + let i0 = twiddles.len(); + twiddles.extend( + coset + .iter() + .take(coset.size() / 2) + .map(|p| p.x) + .collect::>(), + ); + bit_reverse(&mut twiddles[i0..]); + coset = coset.double(); + } + twiddles.push(1.into()); + + // Inverse twiddles. + // Fallback to the non-chunked version if the domain is not big enough. + if CHUNK_SIZE > coset.size() { + let itwiddles = twiddles.iter().map(|&t| t.inverse()).collect(); + return TwiddleTree { + root_coset, + twiddles, + itwiddles, + }; + } + + let mut itwiddles = vec![BaseField::zero(); twiddles.len()]; + twiddles + .array_chunks::() + .zip(itwiddles.array_chunks_mut::()) + .for_each(|(src, dst)| { + BaseField::batch_inverse(src, dst); + }); + + TwiddleTree { + root_coset, + twiddles, + itwiddles, + } + } +} + +fn fft_layer_loop( + values: &mut [BaseField], + i: usize, + h: usize, + t: BaseField, + butterfly_fn: impl Fn(&mut BaseField, &mut BaseField, BaseField), +) { + for l in 0..(1 << i) { + let idx0 = (h << (i + 1)) + l; + let idx1 = idx0 + (1 << i); + let (mut val0, mut val1) = (values[idx0], values[idx1]); + butterfly_fn(&mut val0, &mut val1, t); + (values[idx0], values[idx1]) = (val0, val1); + } +} + +/// Computes the circle twiddles layer (layer 0) from the first line twiddles layer (layer 1). +/// +/// Only works for line twiddles generated from a domain with size `>4`. +fn circle_twiddles_from_line_twiddles( + first_line_twiddles: &[BaseField], +) -> impl Iterator + '_ { + // The twiddles for layer 0 can be computed from the twiddles for layer 1. 
+ // Since the twiddles are bit reversed, we consider the circle domain in bit reversed order. + // Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4. + // A circle coset of size 4 in bit reversed order looks like this: + // [(x, y), (-x, -y), (y, -x), (-y, x)] + // Note: This relation is derived from the fact that `M31_CIRCLE_GEN`.repeated_double(ORDER / 4) + // == (-1,0), and not (0,1). (0,1) would yield another relation. + // The twiddles for layer 0 are the y coordinates: + // [y, -y, -x, x] + // The twiddles for layer 1 in bit reversed order are the x coordinates of the even indices + // points: + // [x, y] + // Works also for inverse of the twiddles. + first_line_twiddles + .iter() + .array_chunks() + .flat_map(|[&x, &y]| [y, -y, -x, x]) +} + +impl, EvalOrder> IntoIterator + for CircleEvaluation +{ + type Item = F; + type IntoIter = std::vec::IntoIter; + + /// Creates a consuming iterator over the evaluations. + /// + /// Evaluations are returned in the same order as elements of the domain. + fn into_iter(self) -> Self::IntoIter { + self.values.into_iter() + } +} + +#[cfg(test)] +mod tests { + use std::iter::zip; + + use num_traits::One; + + use crate::core::backend::cpu::CpuCirclePoly; + use crate::core::circle::CirclePoint; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::poly::circle::CanonicCoset; + + #[test] + fn test_eval_at_point_with_4_coeffs() { + // Represents the polynomial `1 + 2y + 3x + 4xy`. + // Note coefficients are passed in bit reversed order. 
+ let poly = CpuCirclePoly::new([1, 3, 2, 4].map(BaseField::from).to_vec()); + let x = BaseField::from(5).into(); + let y = BaseField::from(8).into(); + + let eval = poly.eval_at_point(CirclePoint { x, y }); + + assert_eq!( + eval, + poly.coeffs[0] + poly.coeffs[1] * y + poly.coeffs[2] * x + poly.coeffs[3] * x * y + ); + } + + #[test] + fn test_eval_at_point_with_2_coeffs() { + // Represents the polynomial `1 + 2y`. + let poly = CpuCirclePoly::new(vec![BaseField::from(1), BaseField::from(2)]); + let x = BaseField::from(5).into(); + let y = BaseField::from(8).into(); + + let eval = poly.eval_at_point(CirclePoint { x, y }); + + assert_eq!(eval, poly.coeffs[0] + poly.coeffs[1] * y); + } + + #[test] + fn test_eval_at_point_with_1_coeff() { + // Represents the polynomial `1`. + let poly = CpuCirclePoly::new(vec![BaseField::one()]); + let x = BaseField::from(5).into(); + let y = BaseField::from(8).into(); + + let eval = poly.eval_at_point(CirclePoint { x, y }); + + assert_eq!(eval, SecureField::one()); + } + + #[test] + fn test_evaluate_2_coeffs() { + let domain = CanonicCoset::new(1).circle_domain(); + let poly = CpuCirclePoly::new((1..=2).map(BaseField::from).collect()); + + let evaluation = poly.clone().evaluate(domain).bit_reverse(); + + for (i, (p, eval)) in zip(domain, evaluation).enumerate() { + let eval: SecureField = eval.into(); + assert_eq!(eval, poly.eval_at_point(p.into_ef()), "mismatch at i={i}"); + } + } + + #[test] + fn test_evaluate_4_coeffs() { + let domain = CanonicCoset::new(2).circle_domain(); + let poly = CpuCirclePoly::new((1..=4).map(BaseField::from).collect()); + + let evaluation = poly.clone().evaluate(domain).bit_reverse(); + + for (i, (x, eval)) in zip(domain, evaluation).enumerate() { + let eval: SecureField = eval.into(); + assert_eq!(eval, poly.eval_at_point(x.into_ef()), "mismatch at i={i}"); + } + } + + #[test] + fn test_evaluate_8_coeffs() { + let domain = CanonicCoset::new(3).circle_domain(); + let poly = 
CpuCirclePoly::new((1..=8).map(BaseField::from).collect()); + + let evaluation = poly.clone().evaluate(domain).bit_reverse(); + + for (i, (x, eval)) in zip(domain, evaluation).enumerate() { + let eval: SecureField = eval.into(); + assert_eq!(eval, poly.eval_at_point(x.into_ef()), "mismatch at i={i}"); + } + } + + #[test] + fn test_interpolate_2_evals() { + let poly = CpuCirclePoly::new(vec![BaseField::one(), BaseField::from(2)]); + let domain = CanonicCoset::new(1).circle_domain(); + let evals = poly.clone().evaluate(domain); + + let interpolated_poly = evals.interpolate(); + + assert_eq!(interpolated_poly.coeffs, poly.coeffs); + } + + #[test] + fn test_interpolate_4_evals() { + let poly = CpuCirclePoly::new((1..=4).map(BaseField::from).collect()); + let domain = CanonicCoset::new(2).circle_domain(); + let evals = poly.clone().evaluate(domain); + + let interpolated_poly = evals.interpolate(); + + assert_eq!(interpolated_poly.coeffs, poly.coeffs); + } + + #[test] + fn test_interpolate_8_evals() { + let poly = CpuCirclePoly::new((1..=8).map(BaseField::from).collect()); + let domain = CanonicCoset::new(3).circle_domain(); + let evals = poly.clone().evaluate(domain); + + let interpolated_poly = evals.interpolate(); + + assert_eq!(interpolated_poly.coeffs, poly.coeffs); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/fri.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/fri.rs new file mode 100644 index 0000000..693fb99 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/cpu/fri.rs @@ -0,0 +1,144 @@ +use super::CpuBackend; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SecureColumnByCoords; +use crate::core::fri::{fold_circle_into_line, fold_line, FriOps}; +use crate::core::poly::circle::SecureEvaluation; +use crate::core::poly::line::LineEvaluation; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::BitReversedOrder; + +// 
TODO(spapini): Optimized these functions as well. +impl FriOps for CpuBackend { + fn fold_line( + eval: &LineEvaluation, + alpha: SecureField, + _twiddles: &TwiddleTree, + ) -> LineEvaluation { + fold_line(eval, alpha) + } + fn fold_circle_into_line( + dst: &mut LineEvaluation, + src: &SecureEvaluation, + alpha: SecureField, + _twiddles: &TwiddleTree, + ) { + fold_circle_into_line(dst, src, alpha) + } + + fn decompose( + eval: &SecureEvaluation, + ) -> (SecureEvaluation, SecureField) { + let lambda = Self::decomposition_coefficient(eval); + let mut g_values = unsafe { SecureColumnByCoords::::uninitialized(eval.len()) }; + + let domain_size = eval.len(); + let half_domain_size = domain_size / 2; + + for i in 0..half_domain_size { + let x = eval.values.at(i); + let val = x - lambda; + g_values.set(i, val); + } + for i in half_domain_size..domain_size { + let x = eval.values.at(i); + let val = x + lambda; + g_values.set(i, val); + } + + let g = SecureEvaluation::new(eval.domain, g_values); + (g, lambda) + } +} + +impl CpuBackend { + /// Used to decompose a general polynomial to a polynomial inside the fft-space, and + /// the remainder terms. + /// A coset-diff on a [`CirclePoly`] that is in the FFT space will return zero. + /// + /// Let N be the domain size, Let h be a coset size N/2. Using lemma #7 from the CircleStark + /// paper, = lambda = lambda\*N => lambda = f(0)\*V_h(0) + f(1)*V_h(1) + .. + + /// f(N-1)\*V_h(N-1). The Vanishing polynomial of a cannonic coset sized half the circle + /// domain,evaluated on the circle domain, is [(1, -1, -1, 1)] repeating. This becomes + /// alternating [+-1] in our NaturalOrder, and [(+, +, +, ... , -, -)] in bit reverse. + /// Explicitly, lambda\*N = sum(+f(0..N/2)) + sum(-f(N/2..)). 
+ /// + /// # Warning + /// This function assumes the blowupfactor is 2 + /// + /// [`CirclePoly`]: crate::core::poly::circle::CirclePoly + fn decomposition_coefficient(eval: &SecureEvaluation) -> SecureField { + let domain_size = 1 << eval.domain.log_size(); + let half_domain_size = domain_size / 2; + + // eval is in bit-reverse, hence all the positive factors are in the first half, opposite to + // the latter. + let a_sum = (0..half_domain_size) + .map(|i| eval.values.at(i)) + .sum::(); + let b_sum = (half_domain_size..domain_size) + .map(|i| eval.values.at(i)) + .sum::(); + + // lambda = sum(+-f(p)) / 2N. + (a_sum - b_sum) / BaseField::from_u32_unchecked(domain_size as u32) + } +} + +#[cfg(test)] +mod tests { + use num_traits::Zero; + + use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; + use crate::core::backend::CpuBackend; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::secure_column::SecureColumnByCoords; + use crate::core::fri::FriOps; + use crate::core::poly::circle::{CanonicCoset, SecureEvaluation}; + use crate::core::poly::BitReversedOrder; + use crate::m31; + + #[test] + fn decompose_coeff_out_fft_space_test() { + for domain_log_size in 5..12 { + let domain_log_half_size = domain_log_size - 1; + let s = CanonicCoset::new(domain_log_size); + let domain = s.circle_domain(); + + let mut coeffs = vec![BaseField::zero(); 1 << domain_log_size]; + + // Polynomial is out of FFT space. 
+ coeffs[1 << domain_log_half_size] = m31!(1); + assert!(!CpuCirclePoly::new(coeffs.clone()).is_in_fft_space(domain_log_half_size)); + + let poly = CpuCirclePoly::new(coeffs); + let values = poly.evaluate(domain); + let secure_column = SecureColumnByCoords { + columns: [ + values.values.clone(), + values.values.clone(), + values.values.clone(), + values.values.clone(), + ], + }; + let secure_eval = SecureEvaluation::::new( + domain, + secure_column.clone(), + ); + + let (g, lambda) = CpuBackend::decompose(&secure_eval); + + // Sanity check. + assert_ne!(lambda, SecureField::zero()); + + // Assert the new polynomial is in the FFT space. + for i in 0..4 { + let basefield_column = g.columns[i].clone(); + let eval = CpuCircleEvaluation::new(domain, basefield_column); + let coeffs = eval.interpolate().coeffs; + assert!(CpuCirclePoly::new(coeffs).is_in_fft_space(domain_log_half_size)); + } + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/grind.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/grind.rs new file mode 100644 index 0000000..c5d27a1 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/cpu/grind.rs @@ -0,0 +1,18 @@ +use super::CpuBackend; +use crate::core::channel::Channel; +use crate::core::proof_of_work::GrindOps; + +impl GrindOps for CpuBackend { + fn grind(channel: &C, pow_bits: u32) -> u64 { + // TODO(spapini): This is a naive implementation. Optimize it. 
+ let mut nonce = 0; + loop { + let mut channel = channel.clone(); + channel.mix_nonce(nonce); + if channel.trailing_zeros() >= pow_bits { + return nonce; + } + nonce += 1; + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/gkr.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/gkr.rs new file mode 100644 index 0000000..ae9ab6b --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/gkr.rs @@ -0,0 +1,448 @@ +use std::ops::Index; + +use num_traits::{One, Zero}; + +use crate::core::backend::CpuBackend; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::{ExtensionOf, Field}; +use crate::core::lookups::gkr_prover::{ + correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer, +}; +use crate::core::lookups::mle::{Mle, MleOps}; +use crate::core::lookups::sumcheck::MultivariatePolyOracle; +use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly}; + +impl GkrOps for CpuBackend { + fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle { + Mle::new(gen_eq_evals(y, v)) + } + + fn next_layer(layer: &Layer) -> Layer { + match layer { + Layer::GrandProduct(layer) => next_grand_product_layer(layer), + Layer::LogUpGeneric { + numerators, + denominators, + } => next_logup_layer(MleExpr::Mle(numerators), denominators), + Layer::LogUpMultiplicities { + numerators, + denominators, + } => next_logup_layer(MleExpr::Mle(numerators), denominators), + Layer::LogUpSingles { denominators } => { + next_logup_layer(MleExpr::Constant(BaseField::one()), denominators) + } + } + } + + fn sum_as_poly_in_first_variable( + h: &GkrMultivariatePolyOracle<'_, Self>, + claim: SecureField, + ) -> UnivariatePoly { + let n_variables = h.n_variables(); + assert!(!n_variables.is_zero()); + let n_terms = 1 << (n_variables - 1); + let eq_evals = h.eq_evals.as_ref(); + // Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean 
hypercube. + let y = eq_evals.y(); + let lambda = h.lambda; + + let (mut eval_at_0, mut eval_at_2) = match &h.input_layer { + Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_terms), + Layer::LogUpGeneric { + numerators, + denominators, + } => eval_logup_sum(eq_evals, numerators, denominators, n_terms, lambda), + Layer::LogUpMultiplicities { + numerators, + denominators, + } => eval_logup_sum(eq_evals, numerators, denominators, n_terms, lambda), + Layer::LogUpSingles { denominators } => { + eval_logup_singles_sum(eq_evals, denominators, n_terms, lambda) + } + }; + + eval_at_0 *= h.eq_fixed_var_correction; + eval_at_2 *= h.eq_fixed_var_correction; + correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables) + } +} + +/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`. +/// +/// Output of the form: `(eval_at_0, eval_at_2)`. +fn eval_grand_product_sum( + eq_evals: &EqEvals, + input_layer: &Mle, + n_terms: usize, +) -> (SecureField, SecureField) { + let mut eval_at_0 = SecureField::zero(); + let mut eval_at_2 = SecureField::zero(); + + for i in 0..n_terms { + // Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`. + let inp_at_r0i0 = input_layer[i * 2]; + let inp_at_r0i1 = input_layer[i * 2 + 1]; + let inp_at_r1i0 = input_layer[(n_terms + i) * 2]; + let inp_at_r1i1 = input_layer[(n_terms + i) * 2 + 1]; + // Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)` + // => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)` + // TODO(andrew): Consider evaluation at `1/2` to save an addition operation since + // `inp(r, 1/2, x) = 1/2 * (inp(r, 1, x) + inp(r, 0, x))`. `1/2 * ...` can be factored + // outside the loop. + let inp_at_r2i0 = inp_at_r1i0.double() - inp_at_r0i0; + let inp_at_r2i1 = inp_at_r1i1.double() - inp_at_r0i1; + + // Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i))`. 
+ let prod_at_r2i = inp_at_r2i0 * inp_at_r2i1; + let prod_at_r0i = inp_at_r0i0 * inp_at_r0i1; + + let eq_eval_at_0i = eq_evals[i]; + eval_at_0 += eq_eval_at_0i * prod_at_r0i; + eval_at_2 += eq_eval_at_0i * prod_at_r2i; + } + + (eval_at_0, eval_at_2) +} + +/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_numer(r, t, x, 0) * inp_denom(r, t, x, 1) + +/// inp_numer(r, t, x, 1) * inp_denom(r, t, x, 0) + lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, +/// x, 1))` at `t=0` and `t=2`. +/// +/// Output of the form: `(eval_at_0, eval_at_2)`. +fn eval_logup_sum( + eq_evals: &EqEvals, + input_numerators: &Mle, + input_denominators: &Mle, + n_terms: usize, + lambda: SecureField, +) -> (SecureField, SecureField) +where + SecureField: ExtensionOf + Field, +{ + let mut eval_at_0 = SecureField::zero(); + let mut eval_at_2 = SecureField::zero(); + + for i in 0..n_terms { + // Input polynomials at points `(r, {0, 1, 2}, bits(i), {0, 1})`. + let inp_numer_at_r0i0 = input_numerators[i * 2]; + let inp_denom_at_r0i0 = input_denominators[i * 2]; + let inp_numer_at_r0i1 = input_numerators[i * 2 + 1]; + let inp_denom_at_r0i1 = input_denominators[i * 2 + 1]; + let inp_numer_at_r1i0 = input_numerators[(n_terms + i) * 2]; + let inp_denom_at_r1i0 = input_denominators[(n_terms + i) * 2]; + let inp_numer_at_r1i1 = input_numerators[(n_terms + i) * 2 + 1]; + let inp_denom_at_r1i1 = input_denominators[(n_terms + i) * 2 + 1]; + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_numer_at_r2i0 = inp_numer_at_r1i0.double() - inp_numer_at_r0i0; + let inp_denom_at_r2i0 = inp_denom_at_r1i0.double() - inp_denom_at_r0i0; + let inp_numer_at_r2i1 = inp_numer_at_r1i1.double() - inp_numer_at_r0i1; + let inp_denom_at_r2i1 = inp_denom_at_r1i1.double() - inp_denom_at_r0i1; + + // Fraction addition polynomials: + // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * 
inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 1) * inp_denom(x, 0)` + // at points `(r, {0, 2}, bits(i))`. + let Fraction { + numerator: numer_at_r0i, + denominator: denom_at_r0i, + } = Fraction::new(inp_numer_at_r0i0, inp_denom_at_r0i0) + + Fraction::new(inp_numer_at_r0i1, inp_denom_at_r0i1); + let Fraction { + numerator: numer_at_r2i, + denominator: denom_at_r2i, + } = Fraction::new(inp_numer_at_r2i0, inp_denom_at_r2i0) + + Fraction::new(inp_numer_at_r2i1, inp_denom_at_r2i1); + + let eq_eval_at_0i = eq_evals[i]; + eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i); + eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i); + } + + (eval_at_0, eval_at_2) +} + +/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) + +/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`. +/// +/// Output of the form: `(eval_at_0, eval_at_2)`. +fn eval_logup_singles_sum( + eq_evals: &EqEvals, + input_denominators: &Mle, + n_terms: usize, + lambda: SecureField, +) -> (SecureField, SecureField) { + let mut eval_at_0 = SecureField::zero(); + let mut eval_at_2 = SecureField::zero(); + + for i in 0..n_terms { + // Input polynomial at points `(r, {0, 1, 2}, bits(i), {0, 1})`. 
+ let inp_denom_at_r0i0 = input_denominators[i * 2]; + let inp_denom_at_r0i1 = input_denominators[i * 2 + 1]; + let inp_denom_at_r1i0 = input_denominators[(n_terms + i) * 2]; + let inp_denom_at_r1i1 = input_denominators[(n_terms + i) * 2 + 1]; + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_denom_at_r2i0 = inp_denom_at_r1i0.double() - inp_denom_at_r0i0; + let inp_denom_at_r2i1 = inp_denom_at_r1i1.double() - inp_denom_at_r0i1; + + // Fraction addition polynomials at points: + // - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 1) * inp_denom(x, 0)` + // at points `(r, {0, 2}, bits(i))`. + let Fraction { + numerator: numer_at_r0i, + denominator: denom_at_r0i, + } = Reciprocal::new(inp_denom_at_r0i0) + Reciprocal::new(inp_denom_at_r0i1); + let Fraction { + numerator: numer_at_r2i, + denominator: denom_at_r2i, + } = Reciprocal::new(inp_denom_at_r2i0) + Reciprocal::new(inp_denom_at_r2i1); + + let eq_eval_at_0i = eq_evals[i]; + eval_at_0 += eq_eval_at_0i * (numer_at_r0i + lambda * denom_at_r0i); + eval_at_2 += eq_eval_at_0i * (numer_at_r2i + lambda * denom_at_r2i); + } + + (eval_at_0, eval_at_2) +} + +/// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`. +/// +/// Evaluations are returned in bit-reversed order. 
+pub fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Vec { + let mut evals = Vec::with_capacity(1 << y.len()); + evals.push(v); + + for &y_i in y.iter().rev() { + for j in 0..evals.len() { + // `lhs[j] = eq(0, y_i) * c[i]` + // `rhs[j] = eq(1, y_i) * c[i]` + let tmp = evals[j] * y_i; + evals.push(tmp); + evals[j] -= tmp; + } + } + + evals +} + +fn next_grand_product_layer(layer: &Mle) -> Layer { + let res = layer.array_chunks().map(|&[a, b]| a * b).collect(); + Layer::GrandProduct(Mle::new(res)) +} + +fn next_logup_layer( + numerators: MleExpr<'_, F>, + denominators: &Mle, +) -> Layer +where + F: Field, + SecureField: ExtensionOf, + CpuBackend: MleOps, +{ + let half_n = 1 << (denominators.n_variables() - 1); + let mut next_numerators = Vec::with_capacity(half_n); + let mut next_denominators = Vec::with_capacity(half_n); + + for i in 0..half_n { + let a = Fraction::new(numerators[i * 2], denominators[i * 2]); + let b = Fraction::new(numerators[i * 2 + 1], denominators[i * 2 + 1]); + let res = a + b; + next_numerators.push(res.numerator); + next_denominators.push(res.denominator); + } + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +enum MleExpr<'a, F: Field> { + Constant(F), + Mle(&'a Mle), +} + +impl<'a, F: Field> Index for MleExpr<'a, F> { + type Output = F; + + fn index(&self, index: usize) -> &F { + match self { + Self::Constant(v) => v, + Self::Mle(mle) => &mle[index], + } + } +} + +#[cfg(test)] +mod tests { + use std::iter::zip; + + use num_traits::{One, Zero}; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::core::backend::CpuBackend; + use crate::core::channel::Channel; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; + use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; + use 
#[cfg(test)]
mod tests {
    use std::iter::zip;

    use num_traits::{One, Zero};
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use crate::core::backend::CpuBackend;
    use crate::core::channel::Channel;
    use crate::core::fields::m31::BaseField;
    use crate::core::fields::qm31::SecureField;
    use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer};
    use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError};
    use crate::core::lookups::mle::Mle;
    use crate::core::lookups::utils::{eq, Fraction};
    use crate::core::test_utils::test_channel;

    #[test]
    fn gen_eq_evals() {
        let zero = SecureField::zero();
        let one = SecureField::one();
        let two = BaseField::from(2).into();
        let y = [7, 3].map(|v| BaseField::from(v).into());

        let eq_evals = CpuBackend::gen_eq_evals(&y, two);

        assert_eq!(
            *eq_evals,
            [
                eq(&[zero, zero], &y) * two,
                eq(&[zero, one], &y) * two,
                eq(&[one, zero], &y) * two,
                eq(&[one, one], &y) * two,
            ]
        );
    }

    #[test]
    fn grand_product_works() -> Result<(), GkrError> {
        const N: usize = 1 << 5;
        let values = test_channel().draw_felts(N);
        let product = values.iter().product::<SecureField>();
        let col = Mle::<CpuBackend, SecureField>::new(values);
        let input_layer = Layer::GrandProduct(col.clone());
        let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]);

        let GkrArtifact {
            ood_point: r,
            claims_to_verify_by_instance,
            n_variables_by_instance: _,
        } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?;

        assert_eq!(proof.output_claims_by_instance, [vec![product]]);
        assert_eq!(claims_to_verify_by_instance, [vec![col.eval_at_point(&r)]]);
        Ok(())
    }

    #[test]
    fn logup_with_generic_trace_works() -> Result<(), GkrError> {
        const N: usize = 1 << 5;
        let mut rng = SmallRng::seed_from_u64(0);
        let numerator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
        let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
        let sum = zip(&numerator_values, &denominator_values)
            .map(|(&n, &d)| Fraction::new(n, d))
            .sum::<Fraction<SecureField, SecureField>>();
        let numerators = Mle::<CpuBackend, SecureField>::new(numerator_values);
        let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
        let top_layer = Layer::LogUpGeneric {
            numerators: numerators.clone(),
            denominators: denominators.clone(),
        };
        let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]);

        let GkrArtifact {
            ood_point,
            claims_to_verify_by_instance,
            n_variables_by_instance: _,
        } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

        assert_eq!(claims_to_verify_by_instance.len(), 1);
        assert_eq!(proof.output_claims_by_instance.len(), 1);
        assert_eq!(
            claims_to_verify_by_instance[0],
            [
                numerators.eval_at_point(&ood_point),
                denominators.eval_at_point(&ood_point)
            ]
        );
        assert_eq!(
            proof.output_claims_by_instance[0],
            [sum.numerator, sum.denominator]
        );
        Ok(())
    }

    #[test]
    fn logup_with_singles_trace_works() -> Result<(), GkrError> {
        const N: usize = 1 << 5;
        let mut rng = SmallRng::seed_from_u64(0);
        let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
        let sum = denominator_values
            .iter()
            .map(|&d| Fraction::new(SecureField::one(), d))
            .sum::<Fraction<SecureField, SecureField>>();
        let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
        let top_layer = Layer::LogUpSingles {
            denominators: denominators.clone(),
        };
        let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]);

        let GkrArtifact {
            ood_point,
            claims_to_verify_by_instance,
            n_variables_by_instance: _,
        } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

        assert_eq!(claims_to_verify_by_instance.len(), 1);
        assert_eq!(proof.output_claims_by_instance.len(), 1);
        assert_eq!(
            claims_to_verify_by_instance[0],
            [SecureField::one(), denominators.eval_at_point(&ood_point)]
        );
        assert_eq!(
            proof.output_claims_by_instance[0],
            [sum.numerator, sum.denominator]
        );
        Ok(())
    }

    #[test]
    fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> {
        const N: usize = 1 << 5;
        let mut rng = SmallRng::seed_from_u64(0);
        let numerator_values = (0..N).map(|_| rng.gen()).collect::<Vec<BaseField>>();
        let denominator_values = (0..N).map(|_| rng.gen()).collect::<Vec<SecureField>>();
        let sum = zip(&numerator_values, &denominator_values)
            .map(|(&n, &d)| Fraction::new(n.into(), d))
            .sum::<Fraction<SecureField, SecureField>>();
        let numerators = Mle::<CpuBackend, BaseField>::new(numerator_values);
        let denominators = Mle::<CpuBackend, SecureField>::new(denominator_values);
        let top_layer = Layer::LogUpMultiplicities {
            numerators: numerators.clone(),
            denominators: denominators.clone(),
        };
        let (proof, _) = prove_batch(&mut test_channel(), vec![top_layer]);

        let GkrArtifact {
            ood_point,
            claims_to_verify_by_instance,
            n_variables_by_instance: _,
        } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?;

        assert_eq!(claims_to_verify_by_instance.len(), 1);
        assert_eq!(proof.output_claims_by_instance.len(), 1);
        assert_eq!(
            claims_to_verify_by_instance[0],
            [
                numerators.eval_at_point(&ood_point),
                denominators.eval_at_point(&ood_point)
            ]
        );
        assert_eq!(
            proof.output_claims_by_instance[0],
            [sum.numerator, sum.denominator]
        );
        Ok(())
    }
}
fold_mle_evals(assignment, lhs_eval, rhs_eval); + } + + evals.truncate(midpoint); + + Mle::new(evals) + } +} + +impl MultivariatePolyOracle for Mle { + fn n_variables(&self) -> usize { + self.n_variables() + } + + fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly { + let x0 = SecureField::zero(); + let x1 = SecureField::one(); + + let y0 = self[0..self.len() / 2].iter().sum(); + let y1 = claim - y0; + + UnivariatePoly::interpolate_lagrange(&[x0, x1], &[y0, y1]) + } + + fn fix_first_variable(self, challenge: SecureField) -> Self { + self.fix_first_variable(challenge) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mod.rs new file mode 100644 index 0000000..cd8dedf --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/cpu/lookups/mod.rs @@ -0,0 +1,2 @@ +pub mod gkr; +mod mle; diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/mod.rs new file mode 100644 index 0000000..579b735 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/cpu/mod.rs @@ -0,0 +1,105 @@ +mod accumulation; +mod blake2s; +mod circle; +mod fri; +mod grind; +pub mod lookups; +#[cfg(not(target_arch = "wasm32"))] +mod poseidon252; +pub mod quotients; +#[cfg(not(target_arch = "wasm32"))] +mod poseidon_bls; + +use std::fmt::Debug; + +use serde::{Deserialize, Serialize}; + +use super::{Backend, BackendForChannel, Column, ColumnOps, FieldOps}; +use crate::core::fields::Field; +use crate::core::lookups::mle::Mle; +use crate::core::poly::circle::{CircleEvaluation, CirclePoly}; +use crate::core::utils::bit_reverse; +use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; +#[cfg(not(target_arch = "wasm32"))] +use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; + +#[cfg(not(target_arch = "wasm32"))] +use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel; + +#[derive(Copy, 
Clone, Debug, Deserialize, Serialize)] +pub struct CpuBackend; + +impl Backend for CpuBackend {} +impl BackendForChannel for CpuBackend {} +#[cfg(not(target_arch = "wasm32"))] +impl BackendForChannel for CpuBackend {} + +#[cfg(not(target_arch = "wasm32"))] +impl BackendForChannel for CpuBackend {} + +impl ColumnOps for CpuBackend { + type Column = Vec; + + fn bit_reverse_column(column: &mut Self::Column) { + bit_reverse(column) + } +} + +impl FieldOps for CpuBackend { + /// Batch inversion using the Montgomery's trick. + // TODO(Ohad): Benchmark this function. + fn batch_inverse(column: &Self::Column, dst: &mut Self::Column) { + F::batch_inverse(column, &mut dst[..]); + } +} + +impl Column for Vec { + fn zeros(len: usize) -> Self { + vec![T::default(); len] + } + #[allow(clippy::uninit_vec)] + unsafe fn uninitialized(length: usize) -> Self { + let mut data = Vec::with_capacity(length); + data.set_len(length); + data + } + fn to_cpu(&self) -> Vec { + self.clone() + } + fn len(&self) -> usize { + self.len() + } + fn at(&self, index: usize) -> T { + self[index].clone() + } + fn set(&mut self, index: usize, value: T) { + self[index] = value; + } +} + +pub type CpuCirclePoly = CirclePoly; +pub type CpuCircleEvaluation = CircleEvaluation; +pub type CpuMle = Mle; + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use rand::prelude::*; + use rand::rngs::SmallRng; + + use crate::core::backend::{Column, CpuBackend, FieldOps}; + use crate::core::fields::qm31::QM31; + use crate::core::fields::FieldExpOps; + + #[test] + fn batch_inverse_test() { + let mut rng = SmallRng::seed_from_u64(0); + let column = rng.gen::<[QM31; 16]>().to_vec(); + let expected = column.iter().map(|e| e.inverse()).collect_vec(); + let mut dst = Column::zeros(column.len()); + + CpuBackend::batch_inverse(&column, &mut dst); + + assert_eq!(expected, dst); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon252.rs 
b/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon252.rs new file mode 100644 index 0000000..8cc5dd9 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon252.rs @@ -0,0 +1,24 @@ +use itertools::Itertools; +use starknet_ff::FieldElement as FieldElement252; + +use super::CpuBackend; +use crate::core::fields::m31::BaseField; +use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; +use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher; + +impl MerkleOps for CpuBackend { + fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Vec>, + columns: &[&Vec], + ) -> Vec { + (0..(1 << log_size)) + .map(|i| { + Poseidon252MerkleHasher::hash_node( + prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), + &columns.iter().map(|column| column[i]).collect_vec(), + ) + }) + .collect() + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon_bls.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon_bls.rs new file mode 100644 index 0000000..a6cca33 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/cpu/poseidon_bls.rs @@ -0,0 +1,24 @@ +use itertools::Itertools; +use ark_bls12_381::Fr as BlsFr; + +use super::CpuBackend; +use crate::core::fields::m31::BaseField; +use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; +use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher; + +impl MerkleOps for CpuBackend { + fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Vec>, + columns: &[&Vec], + ) -> Vec { + (0..(1 << log_size)) + .map(|i| { + PoseidonBLSMerkleHasher::hash_node( + prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), + &columns.iter().map(|column| column[i]).collect_vec(), + ) + }) + .collect() + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/cpu/quotients.rs b/Stwo_wrapper/crates/prover/src/core/backend/cpu/quotients.rs new file mode 100644 index 0000000..17cc007 --- /dev/null +++ 
b/Stwo_wrapper/crates/prover/src/core/backend/cpu/quotients.rs @@ -0,0 +1,210 @@ +use itertools::{izip, zip_eq}; +use num_traits::{One, Zero}; + +use super::CpuBackend; +use crate::core::circle::CirclePoint; +use crate::core::constraints::complex_conjugate_line_coeffs; +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SecureColumnByCoords; +use crate::core::fields::FieldExpOps; +use crate::core::pcs::quotients::{ColumnSampleBatch, PointSample, QuotientOps}; +use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::{bit_reverse, bit_reverse_index}; + +impl QuotientOps for CpuBackend { + fn accumulate_quotients( + domain: CircleDomain, + columns: &[&CircleEvaluation], + random_coeff: SecureField, + sample_batches: &[ColumnSampleBatch], + _log_blowup_factor: u32, + ) -> SecureEvaluation { + let mut values = unsafe { SecureColumnByCoords::uninitialized(domain.size()) }; + let quotient_constants = quotient_constants(sample_batches, random_coeff, domain); + + for row in 0..domain.size() { + let domain_point = domain.at(bit_reverse_index(row, domain.log_size())); + let row_value = accumulate_row_quotients( + sample_batches, + columns, + "ient_constants, + row, + domain_point, + ); + values.set(row, row_value); + } + SecureEvaluation::new(domain, values) + } +} + +pub fn accumulate_row_quotients( + sample_batches: &[ColumnSampleBatch], + columns: &[&CircleEvaluation], + quotient_constants: &QuotientConstants, + row: usize, + domain_point: CirclePoint, +) -> SecureField { + let mut row_accumulator = SecureField::zero(); + for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!( + sample_batches, + "ient_constants.line_coeffs, + "ient_constants.batch_random_coeffs, + "ient_constants.denominator_inverses + ) { + let mut numerator = 
SecureField::zero(); + for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs) + { + let column = &columns[*column_index]; + let value = column[row] * *c; + // The numerator is a line equation passing through + // (sample_point.y, sample_value), (conj(sample_point), conj(sample_value)) + // evaluated at (domain_point.y, value). + // When substituting a polynomial in this line equation, we get a polynomial with a root + // at sample_point and conj(sample_point) if the original polynomial had the values + // sample_value and conj(sample_value) at these points. + let linear_term = *a * domain_point.y + *b; + numerator += value - linear_term; + } + + row_accumulator = + row_accumulator * *batch_coeff + numerator.mul_cm31(denominator_inverses[row]); + } + row_accumulator +} + +/// Precompute the complex conjugate line coefficients for each column in each sample batch. +/// Specifically, for the i-th (in a sample batch) column's numerator term +/// `alpha^i * (c * F(p) - (a * p.y + b))`, we precompute and return the constants: +/// (`alpha^i * a`, `alpha^i * b`, `alpha^i * c`). +pub fn column_line_coeffs( + sample_batches: &[ColumnSampleBatch], + random_coeff: SecureField, +) -> Vec> { + sample_batches + .iter() + .map(|sample_batch| { + let mut alpha = SecureField::one(); + sample_batch + .columns_and_values + .iter() + .map(|(_, sampled_value)| { + alpha *= random_coeff; + let sample = PointSample { + point: sample_batch.point, + value: *sampled_value, + }; + complex_conjugate_line_coeffs(&sample, alpha) + }) + .collect() + }) + .collect() +} + +/// Precompute the random coefficients used to linearly combine the batched quotients. +/// Specifically, for each sample batch we compute random_coeff^(number of columns in the batch), +/// which is used to linearly combine the batch with the next one. 
+pub fn batch_random_coeffs( + sample_batches: &[ColumnSampleBatch], + random_coeff: SecureField, +) -> Vec { + sample_batches + .iter() + .map(|sb| random_coeff.pow(sb.columns_and_values.len() as u128)) + .collect() +} + +fn denominator_inverses( + sample_batches: &[ColumnSampleBatch], + domain: CircleDomain, +) -> Vec> { + let mut flat_denominators = Vec::with_capacity(sample_batches.len() * domain.size()); + // We want a P to be on a line that passes through a point Pr + uPi in QM31^2, and its conjugate + // Pr - uPi. Thus, Pr - P is parallel to Pi. Or, (Pr - P).x * Pi.y - (Pr - P).y * Pi.x = 0. + for sample_batch in sample_batches { + // Extract Pr, Pi. + let prx = sample_batch.point.x.0; + let pry = sample_batch.point.y.0; + let pix = sample_batch.point.x.1; + let piy = sample_batch.point.y.1; + for row in 0..domain.size() { + let domain_point = domain.at(row); + flat_denominators.push((prx - domain_point.x) * piy - (pry - domain_point.y) * pix); + } + } + + let mut flat_denominator_inverses = vec![CM31::zero(); flat_denominators.len()]; + CM31::batch_inverse(&flat_denominators, &mut flat_denominator_inverses); + + flat_denominator_inverses + .chunks_mut(domain.size()) + .map(|denominator_inverses| { + bit_reverse(denominator_inverses); + denominator_inverses.to_vec() + }) + .collect() +} + +pub fn quotient_constants( + sample_batches: &[ColumnSampleBatch], + random_coeff: SecureField, + domain: CircleDomain, +) -> QuotientConstants { + let line_coeffs = column_line_coeffs(sample_batches, random_coeff); + let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff); + let denominator_inverses = denominator_inverses(sample_batches, domain); + QuotientConstants { + line_coeffs, + batch_random_coeffs, + denominator_inverses, + } +} + +/// Holds the precomputed constant values used in each quotient evaluation. +pub struct QuotientConstants { + /// The line coefficients for each quotient numerator term. 
For more details see + /// [self::column_line_coeffs]. + pub line_coeffs: Vec>, + /// The random coefficients used to linearly combine the batched quotients For more details see + /// [self::batch_random_coeffs]. + pub batch_random_coeffs: Vec, + /// The inverses of the denominators of the quotients. + pub denominator_inverses: Vec>, +} + +#[cfg(test)] +mod tests { + use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; + use crate::core::backend::CpuBackend; + use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; + use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; + use crate::core::poly::circle::CanonicCoset; + use crate::{m31, qm31}; + + #[test] + fn test_quotients_are_low_degree() { + const LOG_SIZE: u32 = 7; + const LOG_BLOWUP_FACTOR: u32 = 1; + let polynomial = CpuCirclePoly::new((0..1 << LOG_SIZE).map(|i| m31!(i)).collect()); + let eval_domain = CanonicCoset::new(LOG_SIZE + 1).circle_domain(); + let eval = polynomial.evaluate(eval_domain); + let point = SECURE_FIELD_CIRCLE_GEN; + let value = polynomial.eval_at_point(point); + let coeff = qm31!(1, 2, 3, 4); + let quot_eval = CpuBackend::accumulate_quotients( + eval_domain, + &[&eval], + coeff, + &[ColumnSampleBatch { + point, + columns_and_values: vec![(0, value)], + }], + LOG_BLOWUP_FACTOR, + ); + let quot_poly_base_field = + CpuCircleEvaluation::new(eval_domain, quot_eval.columns[0].clone()).interpolate(); + assert!(quot_poly_base_field.is_in_fri_space(LOG_SIZE)); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/mod.rs new file mode 100644 index 0000000..f6eae91 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/mod.rs @@ -0,0 +1,66 @@ +use std::fmt::Debug; + +pub use cpu::CpuBackend; + +use super::air::accumulation::AccumulationOps; +use super::channel::MerkleChannel; +use super::fields::m31::BaseField; +use super::fields::qm31::SecureField; +use super::fields::FieldOps; +use super::fri::FriOps; 
+use super::lookups::gkr_prover::GkrOps; +use super::pcs::quotients::QuotientOps; +use super::poly::circle::PolyOps; +use super::proof_of_work::GrindOps; +use super::vcs::ops::MerkleOps; + +pub mod cpu; +pub mod simd; + +pub trait Backend: + Copy + + Clone + + Debug + + FieldOps + + FieldOps + + PolyOps + + QuotientOps + + FriOps + + AccumulationOps + + GkrOps +{ +} + +pub trait BackendForChannel: + Backend + MerkleOps + GrindOps +{ +} + +pub trait ColumnOps { + type Column: Column; + fn bit_reverse_column(column: &mut Self::Column); +} + +pub type Col = >::Column; + +// TODO(spapini): Consider removing the generic parameter and only support BaseField. +pub trait Column: Clone + Debug + FromIterator { + /// Creates a new column of zeros with the given length. + fn zeros(len: usize) -> Self; + /// Creates a new column of uninitialized values with the given length. + /// # Safety + /// The caller must ensure that the column is populated before being used. + unsafe fn uninitialized(len: usize) -> Self; + /// Returns a cpu vector of the column. + fn to_cpu(&self) -> Vec; + /// Returns the length of the column. + fn len(&self) -> usize; + /// Returns true if the column is empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Retrieves the element at the given index. + fn at(&self, index: usize) -> T; + /// Sets the element at the given index. 
+ fn set(&mut self, index: usize, value: T); +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/accumulation.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/accumulation.rs new file mode 100644 index 0000000..c9705df --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/accumulation.rs @@ -0,0 +1,12 @@ +use super::SimdBackend; +use crate::core::air::accumulation::AccumulationOps; +use crate::core::fields::secure_column::SecureColumnByCoords; + +impl AccumulationOps for SimdBackend { + fn accumulate(column: &mut SecureColumnByCoords, other: &SecureColumnByCoords) { + for i in 0..column.packed_len() { + let res_coeff = unsafe { column.packed_at(i) + other.packed_at(i) }; + unsafe { column.set_packed(i, res_coeff) }; + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/bit_reverse.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/bit_reverse.rs new file mode 100644 index 0000000..13d6585 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/bit_reverse.rs @@ -0,0 +1,203 @@ +use std::array; + +use super::column::{BaseColumn, SecureColumn}; +use super::m31::PackedBaseField; +use super::SimdBackend; +use crate::core::backend::ColumnOps; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index}; + +const VEC_BITS: u32 = 4; + +const W_BITS: u32 = 3; + +pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS; + +impl ColumnOps for SimdBackend { + type Column = BaseColumn; + + fn bit_reverse_column(column: &mut Self::Column) { + // Fallback to cpu bit_reverse. + if column.data.len().ilog2() < MIN_LOG_SIZE { + cpu_bit_reverse(column.as_mut_slice()); + return; + } + + bit_reverse_m31(&mut column.data); + } +} + +impl ColumnOps for SimdBackend { + type Column = SecureColumn; + + fn bit_reverse_column(_column: &mut SecureColumn) { + todo!() + } +} + +/// Bit reverses M31 values. 
+/// +/// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`. +pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { + assert!(data.len().is_power_of_two()); + assert!(data.len().ilog2() >= MIN_LOG_SIZE); + + // Indices in the array are of the form v_h w_h a w_l v_l, with + // |v_h| = |v_l| = VEC_BITS, |w_h| = |w_l| = W_BITS, |a| = n - 2*W_BITS - VEC_BITS. + // The loops go over a, w_l, w_h, and then swaps the 16 by 16 values at: + // * w_h a w_l * <-> * rev(w_h a w_l) *. + // These are 1 or 2 chunks of 2^W_BITS contiguous `u32x16` vectors. + + let log_size = data.len().ilog2(); + let a_bits = log_size - 2 * W_BITS - VEC_BITS; + + // TODO(spapini): when doing multithreading, do it over a. + for a in 0u32..1 << a_bits { + for w_l in 0u32..1 << W_BITS { + let w_l_rev = w_l.reverse_bits() >> (u32::BITS - W_BITS); + for w_h in 0..w_l_rev + 1 { + let idx = ((((w_h << a_bits) | a) << W_BITS) | w_l) as usize; + let idx_rev = bit_reverse_index(idx, log_size - VEC_BITS); + + // In order to not swap twice, only swap if idx <= idx_rev. + if idx > idx_rev { + continue; + } + + // Read first chunk. + // TODO(spapini): Think about optimizing a_bits. + let chunk0 = array::from_fn(|i| unsafe { + *data.get_unchecked(idx + (i << (2 * W_BITS + a_bits))) + }); + let values0 = bit_reverse16(chunk0); + + if idx == idx_rev { + // Palindrome index. Write into the same chunk. + #[allow(clippy::needless_range_loop)] + for i in 0..16 { + unsafe { + *data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) = + values0[i]; + } + } + continue; + } + + // Read bit reversed chunk. 
+ let chunk1 = array::from_fn(|i| unsafe { + *data.get_unchecked(idx_rev + (i << (2 * W_BITS + a_bits))) + }); + let values1 = bit_reverse16(chunk1); + + for i in 0..16 { + unsafe { + *data.get_unchecked_mut(idx + (i << (2 * W_BITS + a_bits))) = values1[i]; + *data.get_unchecked_mut(idx_rev + (i << (2 * W_BITS + a_bits))) = + values0[i]; + } + } + } + } + } +} + +/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each. +fn bit_reverse16(mut data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { + // Denote the index of each element in the 16 packed M31 words as abcd:0123, + // where abcd is the index of the packed word and 0123 is the index of the element in the word. + // Bit reversal is achieved by applying the following permutation to the index for 4 times: + // abcd:0123 => 0abc:123d + // This is how it looks like at each iteration. + // abcd:0123 + // 0abc:123d + // 10ab:23dc + // 210a:3dcb + // 3210:dcba + for _ in 0..4 { + // Apply the abcd:0123 => 0abc:123d permutation. + // `interleave` allows us to interleave the first half of 2 words. + // For example, the second call interleaves 0010:0xyz (low half of register 2) with + // 0011:0xyz (low half of register 3), and stores the result in register 1 (0001). + // This results in + // 0001:xyz0 (even indices of register 1) <= 0010:0xyz (low half of register2), and + // 0001:xyz1 (odd indices of register 1) <= 0011:0xyz (low half of register 3) + // or 0001:xyzw <= 001w:0xyz. 
+ let (d0, d8) = data[0].interleave(data[1]); + let (d1, d9) = data[2].interleave(data[3]); + let (d2, d10) = data[4].interleave(data[5]); + let (d3, d11) = data[6].interleave(data[7]); + let (d4, d12) = data[8].interleave(data[9]); + let (d5, d13) = data[10].interleave(data[11]); + let (d6, d14) = data[12].interleave(data[13]); + let (d7, d15) = data[14].interleave(data[15]); + data = [ + d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15, + ]; + } + + data +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use super::{bit_reverse16, bit_reverse_m31, MIN_LOG_SIZE}; + use crate::core::backend::simd::column::BaseColumn; + use crate::core::backend::simd::m31::{PackedM31, N_LANES}; + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Column, ColumnOps}; + use crate::core::fields::m31::BaseField; + use crate::core::utils::bit_reverse as cpu_bit_reverse; + + #[test] + fn test_bit_reverse16() { + let values: BaseColumn = (0..N_LANES * 16).map(BaseField::from).collect(); + let mut expected = values.to_cpu(); + cpu_bit_reverse(&mut expected); + + let res = bit_reverse16(values.data.try_into().unwrap()); + + assert_eq!(res.map(PackedM31::to_array).flatten(), expected); + } + + #[test] + fn bit_reverse_m31_works() { + const SIZE: usize = 1 << 15; + let data: Vec<_> = (0..SIZE).map(BaseField::from).collect(); + let mut expected = data.clone(); + cpu_bit_reverse(&mut expected); + + let mut res: BaseColumn = data.into_iter().collect(); + bit_reverse_m31(&mut res.data[..]); + + assert_eq!(res.to_cpu(), expected); + } + + #[test] + fn bit_reverse_small_column_works() { + const LOG_SIZE: u32 = MIN_LOG_SIZE - 1; + let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec(); + let mut expected = column.clone(); + cpu_bit_reverse(&mut expected); + + let mut res = column.iter().copied().collect::(); + >::bit_reverse_column(&mut res); + + assert_eq!(res.to_cpu(), expected); + } + + #[test] + fn 
bit_reverse_large_column_works() { + const LOG_SIZE: u32 = MIN_LOG_SIZE; + let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec(); + let mut expected = column.clone(); + cpu_bit_reverse(&mut expected); + + let mut res = column.iter().copied().collect::(); + >::bit_reverse_column(&mut res); + + assert_eq!(res.to_cpu(), expected); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/blake2s.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/blake2s.rs new file mode 100644 index 0000000..fbcfe89 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/blake2s.rs @@ -0,0 +1,412 @@ +//! A SIMD implementation of the BLAKE2s compression function. +//! Based on . + +use std::array; +use std::iter::repeat; +use std::mem::transmute; +use std::simd::u32x16; + +use bytemuck::cast_slice; +use itertools::Itertools; +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +use super::m31::{LOG_N_LANES, N_LANES}; +use super::SimdBackend; +use crate::core::backend::{Col, Column, ColumnOps}; +use crate::core::fields::m31::BaseField; +use crate::core::vcs::blake2_hash::Blake2sHash; +use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher; +use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; + +const IV: [u32; 8] = [ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, +]; + +pub const SIGMA: [[u8; 16]; 10] = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3], + [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4], + [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8], + [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13], + [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9], + [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11], + [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10], + [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5], + [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0], +]; + +impl 
ColumnOps for SimdBackend { + type Column = Vec; + + fn bit_reverse_column(_column: &mut Self::Column) { + unimplemented!() + } +} + +impl MerkleOps for SimdBackend { + fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Vec>, + columns: &[&Col], + ) -> Vec { + if log_size < LOG_N_LANES { + #[cfg(not(feature = "parallel"))] + let iter = 0..1 << log_size; + + #[cfg(feature = "parallel")] + let iter = (0..1 << log_size).into_par_iter(); + + return iter + .map(|i| { + Blake2sMerkleHasher::hash_node( + prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), + &columns.iter().map(|column| column.at(i)).collect_vec(), + ) + }) + .collect(); + } + + if let Some(prev_layer) = prev_layer { + assert_eq!(prev_layer.len(), 1 << (log_size + 1)); + } + + let zeros = u32x16::splat(0); + + // Commit to columns. + let mut res = vec![Blake2sHash::default(); 1 << log_size]; + #[cfg(not(feature = "parallel"))] + let iter = res.chunks_mut(1 << LOG_N_LANES); + + #[cfg(feature = "parallel")] + let iter = res.par_chunks_mut(1 << LOG_N_LANES); + + iter.enumerate().for_each(|(i, chunk)| { + let mut state: [u32x16; 8] = unsafe { std::mem::zeroed() }; + // Hash prev_layer, if exists. + if let Some(prev_layer) = prev_layer { + let prev_chunk_u32s = cast_slice::<_, u32>(&prev_layer[(i << 5)..((i + 1) << 5)]); + // Note: prev_layer might be unaligned. + let msgs: [u32x16; 16] = array::from_fn(|j| { + u32x16::from_array(std::array::from_fn(|k| prev_chunk_u32s[16 * j + k])) + }); + state = compress16(state, transpose_msgs(msgs), zeros, zeros, zeros, zeros); + } + + // Hash columns in chunks of 16. + let mut col_chunk_iter = columns.array_chunks(); + for col_chunk in &mut col_chunk_iter { + let msgs = col_chunk.map(|column| column.data[i].into_simd()); + state = compress16(state, msgs, zeros, zeros, zeros, zeros); + } + + // Hash remaining columns. 
+ let remainder = col_chunk_iter.remainder(); + if !remainder.is_empty() { + let msgs = remainder + .iter() + .map(|column| column.data[i].into_simd()) + .chain(repeat(zeros)) + .take(N_LANES) + .collect_vec() + .try_into() + .unwrap(); + state = compress16(state, msgs, zeros, zeros, zeros, zeros); + } + let state: [Blake2sHash; 16] = unsafe { transmute(untranspose_states(state)) }; + chunk.copy_from_slice(&state); + }); + res + } +} + +/// Applies [`u32::rotate_right(N)`] to each element of the vector +/// +/// [`u32::rotate_right(N)`]: u32::rotate_right +#[inline(always)] +fn rotate(x: u32x16) -> u32x16 { + (x >> N) | (x << (u32::BITS - N)) +} + +// `inline(always)` can cause code parsing errors for wasm: "locals exceed maximum". +#[cfg_attr(not(target_arch = "wasm32"), inline(always))] +pub fn round(v: &mut [u32x16; 16], m: [u32x16; 16], r: usize) { + v[0] += m[SIGMA[r][0] as usize]; + v[1] += m[SIGMA[r][2] as usize]; + v[2] += m[SIGMA[r][4] as usize]; + v[3] += m[SIGMA[r][6] as usize]; + v[0] += v[4]; + v[1] += v[5]; + v[2] += v[6]; + v[3] += v[7]; + v[12] ^= v[0]; + v[13] ^= v[1]; + v[14] ^= v[2]; + v[15] ^= v[3]; + v[12] = rotate::<16>(v[12]); + v[13] = rotate::<16>(v[13]); + v[14] = rotate::<16>(v[14]); + v[15] = rotate::<16>(v[15]); + v[8] += v[12]; + v[9] += v[13]; + v[10] += v[14]; + v[11] += v[15]; + v[4] ^= v[8]; + v[5] ^= v[9]; + v[6] ^= v[10]; + v[7] ^= v[11]; + v[4] = rotate::<12>(v[4]); + v[5] = rotate::<12>(v[5]); + v[6] = rotate::<12>(v[6]); + v[7] = rotate::<12>(v[7]); + v[0] += m[SIGMA[r][1] as usize]; + v[1] += m[SIGMA[r][3] as usize]; + v[2] += m[SIGMA[r][5] as usize]; + v[3] += m[SIGMA[r][7] as usize]; + v[0] += v[4]; + v[1] += v[5]; + v[2] += v[6]; + v[3] += v[7]; + v[12] ^= v[0]; + v[13] ^= v[1]; + v[14] ^= v[2]; + v[15] ^= v[3]; + v[12] = rotate::<8>(v[12]); + v[13] = rotate::<8>(v[13]); + v[14] = rotate::<8>(v[14]); + v[15] = rotate::<8>(v[15]); + v[8] += v[12]; + v[9] += v[13]; + v[10] += v[14]; + v[11] += v[15]; + v[4] ^= v[8]; + v[5] 
^= v[9]; + v[6] ^= v[10]; + v[7] ^= v[11]; + v[4] = rotate::<7>(v[4]); + v[5] = rotate::<7>(v[5]); + v[6] = rotate::<7>(v[6]); + v[7] = rotate::<7>(v[7]); + + v[0] += m[SIGMA[r][8] as usize]; + v[1] += m[SIGMA[r][10] as usize]; + v[2] += m[SIGMA[r][12] as usize]; + v[3] += m[SIGMA[r][14] as usize]; + v[0] += v[5]; + v[1] += v[6]; + v[2] += v[7]; + v[3] += v[4]; + v[15] ^= v[0]; + v[12] ^= v[1]; + v[13] ^= v[2]; + v[14] ^= v[3]; + v[15] = rotate::<16>(v[15]); + v[12] = rotate::<16>(v[12]); + v[13] = rotate::<16>(v[13]); + v[14] = rotate::<16>(v[14]); + v[10] += v[15]; + v[11] += v[12]; + v[8] += v[13]; + v[9] += v[14]; + v[5] ^= v[10]; + v[6] ^= v[11]; + v[7] ^= v[8]; + v[4] ^= v[9]; + v[5] = rotate::<12>(v[5]); + v[6] = rotate::<12>(v[6]); + v[7] = rotate::<12>(v[7]); + v[4] = rotate::<12>(v[4]); + v[0] += m[SIGMA[r][9] as usize]; + v[1] += m[SIGMA[r][11] as usize]; + v[2] += m[SIGMA[r][13] as usize]; + v[3] += m[SIGMA[r][15] as usize]; + v[0] += v[5]; + v[1] += v[6]; + v[2] += v[7]; + v[3] += v[4]; + v[15] ^= v[0]; + v[12] ^= v[1]; + v[13] ^= v[2]; + v[14] ^= v[3]; + v[15] = rotate::<8>(v[15]); + v[12] = rotate::<8>(v[12]); + v[13] = rotate::<8>(v[13]); + v[14] = rotate::<8>(v[14]); + v[10] += v[15]; + v[11] += v[12]; + v[8] += v[13]; + v[9] += v[14]; + v[5] ^= v[10]; + v[6] ^= v[11]; + v[7] ^= v[8]; + v[4] ^= v[9]; + v[5] = rotate::<7>(v[5]); + v[6] = rotate::<7>(v[6]); + v[7] = rotate::<7>(v[7]); + v[4] = rotate::<7>(v[4]); +} + +/// Transposes input chunks (16 chunks of 16 `u32`s each), to get 16 `u32x16`, each +/// representing 16 packed instances of a message word. +fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] { + // Index abcd:xyzw, refers to a specific word in data as follows: + // abcd - chunk index (in base 2) + // xyzw - word offset (in base 2) + // Transpose by applying 4 times the index permutation: + // abcd:xyzw => wabc:dxyz + // In other words, rotate the index to the right by 1. 
+ for _ in 0..4 { + let (d0, d8) = data[0].deinterleave(data[1]); + let (d1, d9) = data[2].deinterleave(data[3]); + let (d2, d10) = data[4].deinterleave(data[5]); + let (d3, d11) = data[6].deinterleave(data[7]); + let (d4, d12) = data[8].deinterleave(data[9]); + let (d5, d13) = data[10].deinterleave(data[11]); + let (d6, d14) = data[12].deinterleave(data[13]); + let (d7, d15) = data[14].deinterleave(data[15]); + data = [ + d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15, + ]; + } + + data +} + +fn untranspose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { + // Index abc:xyzw, refers to a specific word in data as follows: + // abc - chunk index (in base 2) + // xyzw - word offset (in base 2) + // Transpose by applying 3 times the index permutation: + // abc:xyzw => bcx:yzwa + // In other words, rotate the index to the left by 1. + for _ in 0..3 { + let (d0, d1) = states[0].interleave(states[4]); + let (d2, d3) = states[1].interleave(states[5]); + let (d4, d5) = states[2].interleave(states[6]); + let (d6, d7) = states[3].interleave(states[7]); + states = [d0, d1, d2, d3, d4, d5, d6, d7]; + } + states +} + +/// Compresses 16 blake2s instances. 
+pub fn compress16( + h_vecs: [u32x16; 8], + msg_vecs: [u32x16; 16], + count_low: u32x16, + count_high: u32x16, + lastblock: u32x16, + lastnode: u32x16, +) -> [u32x16; 8] { + let mut v = [ + h_vecs[0], + h_vecs[1], + h_vecs[2], + h_vecs[3], + h_vecs[4], + h_vecs[5], + h_vecs[6], + h_vecs[7], + u32x16::splat(IV[0]), + u32x16::splat(IV[1]), + u32x16::splat(IV[2]), + u32x16::splat(IV[3]), + u32x16::splat(IV[4]) ^ count_low, + u32x16::splat(IV[5]) ^ count_high, + u32x16::splat(IV[6]) ^ lastblock, + u32x16::splat(IV[7]) ^ lastnode, + ]; + + round(&mut v, msg_vecs, 0); + round(&mut v, msg_vecs, 1); + round(&mut v, msg_vecs, 2); + round(&mut v, msg_vecs, 3); + round(&mut v, msg_vecs, 4); + round(&mut v, msg_vecs, 5); + round(&mut v, msg_vecs, 6); + round(&mut v, msg_vecs, 7); + round(&mut v, msg_vecs, 8); + round(&mut v, msg_vecs, 9); + + [ + h_vecs[0] ^ v[0] ^ v[8], + h_vecs[1] ^ v[1] ^ v[9], + h_vecs[2] ^ v[2] ^ v[10], + h_vecs[3] ^ v[3] ^ v[11], + h_vecs[4] ^ v[4] ^ v[12], + h_vecs[5] ^ v[5] ^ v[13], + h_vecs[6] ^ v[6] ^ v[14], + h_vecs[7] ^ v[7] ^ v[15], + ] +} + +#[cfg(test)] +mod tests { + use std::array; + use std::mem::transmute; + use std::simd::u32x16; + + use aligned::{Aligned, A64}; + + use super::{compress16, transpose_msgs, untranspose_states}; + use crate::core::vcs::blake2s_ref::compress; + + #[test] + fn compress16_works() { + let states: Aligned = + Aligned(array::from_fn(|i| array::from_fn(|j| (i + j) as u32))); + let msgs: Aligned = + Aligned(array::from_fn(|i| array::from_fn(|j| (i + j + 20) as u32))); + let count_low = 1; + let count_high = 2; + let lastblock = 3; + let lastnode = 4; + let res_unvectorized = array::from_fn(|i| { + compress( + states[i], msgs[i], count_low, count_high, lastblock, lastnode, + ) + }); + + let res_vectorized: [[u32; 8]; 16] = unsafe { + transmute(untranspose_states(compress16( + transpose_states(transmute(states)), + transpose_msgs(transmute(msgs)), + u32x16::splat(count_low), + u32x16::splat(count_high), + 
u32x16::splat(lastblock), + u32x16::splat(lastnode), + ))) + }; + + assert_eq!(res_vectorized, res_unvectorized); + } + + #[test] + fn untranspose_states_is_transpose_states_inverse() { + let states = array::from_fn(|i| u32x16::from(array::from_fn(|j| (i + j) as u32))); + let transposed_states = transpose_states(states); + + let untrasponsed_transposed_states = untranspose_states(transposed_states); + + assert_eq!(untrasponsed_transposed_states, states) + } + + /// Transposes states, from 8 packed words, to get 16 results, each of size 32B. + fn transpose_states(mut states: [u32x16; 8]) -> [u32x16; 8] { + // Index abc:xyzw, refers to a specific word in data as follows: + // abc - chunk index (in base 2) + // xyzw - word offset (in base 2) + // Transpose by applying 3 times the index permutation: + // abc:xyzw => wab:cxyz + // In other words, rotate the index to the right by 1. + for _ in 0..3 { + let (s0, s4) = states[0].deinterleave(states[1]); + let (s1, s5) = states[2].deinterleave(states[3]); + let (s2, s6) = states[4].deinterleave(states[5]); + let (s3, s7) = states[6].deinterleave(states[7]); + states = [s0, s1, s2, s3, s4, s5, s6, s7]; + } + + states + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/circle.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/circle.rs new file mode 100644 index 0000000..e930f77 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/circle.rs @@ -0,0 +1,436 @@ +use std::iter::zip; +use std::mem::transmute; + +use bytemuck::{cast_slice, Zeroable}; +use num_traits::One; + +use super::fft::{ifft, rfft, CACHED_FFT_LOG_SIZE}; +use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; +use super::qm31::PackedSecureField; +use super::SimdBackend; +use crate::core::backend::simd::column::BaseColumn; +use crate::core::backend::{Col, CpuBackend}; +use crate::core::circle::{CirclePoint, Coset}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use 
crate::core::fields::{Field, FieldExpOps}; +use crate::core::poly::circle::{ + CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, +}; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold}; +use crate::core::poly::BitReversedOrder; + +impl SimdBackend { + // TODO(Ohad): optimize. + fn twiddle_at(mappings: &[F], mut index: usize) -> F { + debug_assert!( + (1 << mappings.len()) as usize >= index, + "Index out of bounds. mappings log len = {}, index = {index}", + mappings.len().ilog2() + ); + + let mut product = F::one(); + for &num in mappings.iter() { + if index & 1 == 1 { + product *= num; + } + index >>= 1; + if index == 0 { + break; + } + } + + product + } + + // TODO(Ohad): consider moving this to to a more general place. + // Note: CACHED_FFT_LOG_SIZE is specific to the backend. + fn generate_evaluation_mappings(point: CirclePoint, log_size: u32) -> Vec { + // Mappings are the factors used to compute the evaluation twiddle. + // Every twiddle (i) is of the form (m[0])^b_0 * (m[1])^b_1 * ... * (m[log_size - + // 1])^b_log_size. + // Where (m)_j are the mappings, and b_i is the j'th bit of i. + let mut mappings = vec![point.y, point.x]; + let mut x = point.x; + for _ in 2..log_size { + x = CirclePoint::double_x(x); + mappings.push(x); + } + + // The caller function expects the mapping in natural order. i.e. (y,x,h(x),h(h(x)),...). + // If the polynomial is large, the fft does a transpose in the middle in a granularity of 16 + // (avx512). The coefficients would then be in tranposed order of 16-sized chunks. + // i.e. (a_(n-15), a_(n-14), ..., a_(n-1), a_(n-31), ..., a_(n-16), a_(n-32), ...). + // To compute the twiddles in the correct order, we need to transpose the coprresponding + // 'transposed bits' in the mappings. The result order of the mappings would then be + // (y, x, h(x), h^2(x), h^(log_n-1)(x), h^(log_n-2)(x) ...). 
To avoid code + // complexity for now, we just reverse the mappings, transpose, then reverse back. + // TODO(Ohad): optimize. consider changing the caller to expect the mappings in + // reversed-tranposed order. + if log_size > CACHED_FFT_LOG_SIZE { + mappings.reverse(); + let n = mappings.len(); + let n0 = (n - LOG_N_LANES as usize) / 2; + let n1 = (n - LOG_N_LANES as usize + 1) / 2; + let (ab, c) = mappings.split_at_mut(n1); + let (a, _b) = ab.split_at_mut(n0); + // Swap content of a,c. + a.swap_with_slice(&mut c[0..n0]); + mappings.reverse(); + } + + mappings + } + + // Generates twiddle steps for efficiently computing the twiddles. + // steps[i] = t_i/(t_0*t_1*...*t_i-1). + fn twiddle_steps(mappings: &[F]) -> Vec + where + F: FieldExpOps, + { + let mut denominators: Vec = vec![mappings[0]]; + + for i in 1..mappings.len() { + denominators.push(denominators[i - 1] * mappings[i]); + } + + let mut denom_inverses = vec![F::zero(); denominators.len()]; + F::batch_inverse(&denominators, &mut denom_inverses); + + let mut steps = vec![mappings[0]]; + + mappings + .iter() + .skip(1) + .zip(denom_inverses.iter()) + .for_each(|(&m, &d)| { + steps.push(m * d); + }); + steps.push(F::one()); + steps + } + + // Advances the twiddle by multiplying it by the next step. e.g: + // If idx(t) = 0b100..1010 , then f(t) = t * step[0] + // If idx(t) = 0b100..0111 , then f(t) = t * step[3] + fn advance_twiddle(twiddle: F, steps: &[F], curr_idx: usize) -> F { + twiddle * steps[curr_idx.trailing_ones() as usize] + } +} + +// TODO(spapini): Everything is returned in redundant representation, where values can also be P. +// Decide if and when it's ok and what to do if it's not. +impl PolyOps for SimdBackend { + // The twiddles type is i32, and not BaseField. This is because the fast AVX mul implementation + // requries one of the numbers to be shifted left by 1 bit. This is not a reduced + // representation of the field. 
+ type Twiddles = Vec; + + fn new_canonical_ordered( + coset: CanonicCoset, + values: Col, + ) -> CircleEvaluation { + // TODO(spapini): Optimize. + let eval = CpuBackend::new_canonical_ordered(coset, values.into_cpu_vec()); + CircleEvaluation::new( + eval.domain, + Col::::from_iter(eval.values), + ) + } + + fn interpolate( + eval: CircleEvaluation, + twiddles: &TwiddleTree, + ) -> CirclePoly { + let mut values = eval.values; + let log_size = values.length.ilog2(); + + let twiddles = domain_line_twiddles_from_tree(eval.domain, &twiddles.itwiddles); + + // Safe because [PackedBaseField] is aligned on 64 bytes. + unsafe { + ifft::ifft( + transmute(values.data.as_mut_ptr()), + &twiddles, + log_size as usize, + ); + } + + // TODO(spapini): Fuse this multiplication / rotation. + let inv = PackedBaseField::broadcast(BaseField::from(eval.domain.size()).inverse()); + values.data.iter_mut().for_each(|x| *x *= inv); + + CirclePoly::new(values) + } + + fn eval_at_point(poly: &CirclePoly, point: CirclePoint) -> SecureField { + // If the polynomial is small, fallback to evaluate directly. + // TODO(Ohad): it's possible to avoid falling back. Consider fixing. + if poly.log_size() <= 8 { + return slow_eval_at_point(poly, point); + } + + let mappings = Self::generate_evaluation_mappings(point, poly.log_size()); + + // 8 lowest mappings produce the first 2^8 twiddles. Separate to optimize each calculation. + let (map_low, map_high) = mappings.split_at(4); + let twiddle_lows = + PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_low, i))); + let (map_mid, map_high) = map_high.split_at(4); + let twiddle_mids = + PackedSecureField::from_array(std::array::from_fn(|i| Self::twiddle_at(map_mid, i))); + + // Compute the high twiddle steps. + let twiddle_steps = Self::twiddle_steps(map_high); + + // Every twiddle is a product of mappings that correspond to '1's in the bit representation + // of the current index. 
For every 2^n alligned chunk of 2^n elements, the twiddle + // array is the same, denoted twiddle_low. Use this to compute sums of (coeff * + // twiddle_high) mod 2^n, then multiply by twiddle_low, and sum to get the final result. + let mut sum = PackedSecureField::zeroed(); + let mut twiddle_high = SecureField::one(); + for (i, coeff_chunk) in poly.coeffs.data.array_chunks::().enumerate() { + // For every chunk of 2 ^ 4 * 2 ^ 4 = 2 ^ 8 elements, the twiddle high is the same. + // Multiply it by every mid twiddle factor to get the factors for the current chunk. + let high_twiddle_factors = + (PackedSecureField::broadcast(twiddle_high) * twiddle_mids).to_array(); + + // Sum the coefficients multiplied by each corrseponsing twiddle. Result is effectivley + // an array[16] where the value at index 'i' is the sum of all coefficients at indices + // that are i mod 16. + for (&packed_coeffs, mid_twiddle) in zip(coeff_chunk, high_twiddle_factors) { + sum += PackedSecureField::broadcast(mid_twiddle) * packed_coeffs; + } + + // Advance twiddle high. + twiddle_high = Self::advance_twiddle(twiddle_high, &twiddle_steps, i); + } + + (sum * twiddle_lows).pointwise_sum() + } + + fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly { + // TODO(spapini): Optimize or get rid of extend. + poly.evaluate(CanonicCoset::new(log_size).circle_domain()) + .interpolate() + } + + fn evaluate( + poly: &CirclePoly, + domain: CircleDomain, + twiddles: &TwiddleTree, + ) -> CircleEvaluation { + // TODO(spapini): Precompute twiddles. + // TODO(spapini): Handle small cases. + let log_size = domain.log_size(); + let fft_log_size = poly.log_size(); + assert!( + log_size >= fft_log_size, + "Can only evaluate on larger domains" + ); + + let twiddles = domain_line_twiddles_from_tree(domain, &twiddles.twiddles); + + // Evaluate on a big domains by evaluating on several subdomains. + let log_subdomains = log_size - fft_log_size; + + // Allocate the destination buffer without initializing. 
+ let mut values = Vec::with_capacity(domain.size() >> LOG_N_LANES); + #[allow(clippy::uninit_vec)] + unsafe { + values.set_len(domain.size() >> LOG_N_LANES) + }; + + for i in 0..(1 << log_subdomains) { + // The subdomain twiddles are a slice of the large domain twiddles. + let subdomain_twiddles = (0..(fft_log_size - 1)) + .map(|layer_i| { + &twiddles[layer_i as usize] + [i << (fft_log_size - 2 - layer_i)..(i + 1) << (fft_log_size - 2 - layer_i)] + }) + .collect::>(); + + // FFT from the coefficients buffer to the values chunk. + unsafe { + rfft::fft( + transmute(poly.coeffs.data.as_ptr()), + transmute( + values[i << (fft_log_size - LOG_N_LANES) + ..(i + 1) << (fft_log_size - LOG_N_LANES)] + .as_mut_ptr(), + ), + &subdomain_twiddles, + fft_log_size as usize, + ); + } + } + + CircleEvaluation::new( + domain, + BaseColumn { + data: values, + length: domain.size(), + }, + ) + } + + fn precompute_twiddles(coset: Coset) -> TwiddleTree { + let mut twiddles = Vec::with_capacity(coset.size()); + let mut itwiddles = Vec::with_capacity(coset.size()); + + // TODO(spapini): Optimize. + for layer in &rfft::get_twiddle_dbls(coset) { + twiddles.extend(layer); + } + // Pad by any value, to make the size a power of 2. + twiddles.push(1); + assert_eq!(twiddles.len(), coset.size()); + for layer in &ifft::get_itwiddle_dbls(coset) { + itwiddles.extend(layer); + } + // Pad by any value, to make the size a power of 2. + itwiddles.push(1); + assert_eq!(itwiddles.len(), coset.size()); + + TwiddleTree { + root_coset: coset, + twiddles, + itwiddles, + } + } +} + +fn slow_eval_at_point( + poly: &CirclePoly, + point: CirclePoint, +) -> SecureField { + let mut mappings = vec![point.y, point.x]; + let mut x = point.x; + for _ in 2..poly.log_size() { + x = CirclePoint::double_x(x); + mappings.push(x); + } + mappings.reverse(); + + // If the polynomial is large, the fft does a transpose in the middle. 
+ if poly.log_size() > CACHED_FFT_LOG_SIZE { + let n = mappings.len(); + let n0 = (n - LOG_N_LANES as usize) / 2; + let n1 = (n - LOG_N_LANES as usize + 1) / 2; + let (ab, c) = mappings.split_at_mut(n1); + let (a, _b) = ab.split_at_mut(n0); + // Swap content of a,c. + a.swap_with_slice(&mut c[0..n0]); + } + fold(cast_slice::<_, BaseField>(&poly.coeffs.data), &mappings) +} + +#[cfg(test)] +mod tests { + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::core::backend::simd::circle::slow_eval_at_point; + use crate::core::backend::simd::fft::{CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::Column; + use crate::core::circle::CirclePoint; + use crate::core::fields::m31::BaseField; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, CirclePoly, PolyOps}; + use crate::core::poly::{BitReversedOrder, NaturalOrder}; + + #[test] + fn test_interpolate_and_eval() { + for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 4 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let evaluation = CircleEvaluation::::new( + domain, + (0..1 << log_size).map(BaseField::from).collect(), + ); + + let poly = evaluation.clone().interpolate(); + let evaluation2 = poly.evaluate(domain); + + assert_eq!(evaluation.values.to_cpu(), evaluation2.values.to_cpu()); + } + } + + #[test] + fn test_eval_extension() { + for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let domain_ext = CanonicCoset::new(log_size + 2).circle_domain(); + let evaluation = CircleEvaluation::::new( + domain, + (0..1 << log_size).map(BaseField::from).collect(), + ); + let poly = evaluation.clone().interpolate(); + + let evaluation2 = poly.evaluate(domain_ext); + + assert_eq!( + poly.extend(log_size + 2).coeffs.to_cpu(), + evaluation2.interpolate().coeffs.to_cpu() + ); + } + } + + #[test] + fn test_eval_at_point() { + for log_size in 
MIN_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 4 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let evaluation = CircleEvaluation::::new( + domain, + (0..1 << log_size).map(BaseField::from).collect(), + ); + let poly = evaluation.bit_reverse().interpolate(); + for i in [0, 1, 3, 1 << (log_size - 1), 1 << (log_size - 2)] { + let p = domain.at(i); + + let eval = poly.eval_at_point(p.into_ef()); + + assert_eq!( + eval, + BaseField::from(i).into(), + "log_size={log_size}, i={i}" + ); + } + } + } + + #[test] + fn test_circle_poly_extend() { + for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 { + let poly = + CirclePoly::::new((0..1 << log_size).map(BaseField::from).collect()); + let eval0 = poly.evaluate(CanonicCoset::new(log_size + 2).circle_domain()); + + let eval1 = poly + .extend(log_size + 2) + .evaluate(CanonicCoset::new(log_size + 2).circle_domain()); + + assert_eq!(eval0.values.to_cpu(), eval1.values.to_cpu()); + } + } + + #[test] + fn test_eval_securefield() { + let mut rng = SmallRng::seed_from_u64(0); + for log_size in MIN_FFT_LOG_SIZE..CACHED_FFT_LOG_SIZE + 2 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let evaluation = CircleEvaluation::::new( + domain, + (0..1 << log_size).map(BaseField::from).collect(), + ); + let poly = evaluation.bit_reverse().interpolate(); + let x = rng.gen(); + let y = rng.gen(); + let p = CirclePoint { x, y }; + + let eval = PolyOps::eval_at_point(&poly, p); + + assert_eq!(eval, slow_eval_at_point(&poly, p), "log_size = {log_size}"); + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/cm31.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/cm31.rs new file mode 100644 index 0000000..31aba0a --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/cm31.rs @@ -0,0 +1,230 @@ +use std::array; +use std::ops::{Add, Mul, MulAssign, Neg, Sub}; + +use bytemuck::{Pod, Zeroable}; +use num_traits::{One, Zero}; + +use super::m31::{PackedM31, N_LANES}; +use 
crate::core::fields::cm31::CM31; +use crate::core::fields::FieldExpOps; + +/// SIMD implementation of [`CM31`]. +#[derive(Copy, Clone, Debug)] +pub struct PackedCM31(pub [PackedM31; 2]); + +impl PackedCM31 { + /// Constructs a new instance with all vector elements set to `value`. + pub fn broadcast(value: CM31) -> Self { + Self([PackedM31::broadcast(value.0), PackedM31::broadcast(value.1)]) + } + + /// Returns all `a` values such that each vector element is represented as `a + bi`. + pub fn a(&self) -> PackedM31 { + self.0[0] + } + + /// Returns all `b` values such that each vector element is represented as `a + bi`. + pub fn b(&self) -> PackedM31 { + self.0[1] + } + + pub fn to_array(&self) -> [CM31; N_LANES] { + let a = self.a().to_array(); + let b = self.b().to_array(); + array::from_fn(|i| CM31(a[i], b[i])) + } + + pub fn from_array(values: [CM31; N_LANES]) -> Self { + Self([ + PackedM31::from_array(values.map(|v| v.0)), + PackedM31::from_array(values.map(|v| v.1)), + ]) + } + + /// Interleaves two vectors. + pub fn interleave(self, other: Self) -> (Self, Self) { + let Self([a_evens, b_evens]) = self; + let Self([a_odds, b_odds]) = other; + let (a_lhs, a_rhs) = a_evens.interleave(a_odds); + let (b_lhs, b_rhs) = b_evens.interleave(b_odds); + (Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs])) + } + + /// Deinterleaves two vectors. + pub fn deinterleave(self, other: Self) -> (Self, Self) { + let Self([a_self, b_self]) = self; + let Self([a_other, b_other]) = other; + let (a_evens, a_odds) = a_self.deinterleave(a_other); + let (b_evens, b_odds) = b_self.deinterleave(b_other); + (Self([a_evens, b_evens]), Self([a_odds, b_odds])) + } + + /// Doubles each element in the vector. 
+ pub fn double(self) -> Self { + let Self([a, b]) = self; + Self([a.double(), b.double()]) + } +} + +impl Add for PackedCM31 { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self([self.a() + rhs.a(), self.b() + rhs.b()]) + } +} + +impl Sub for PackedCM31 { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self([self.a() - rhs.a(), self.b() - rhs.b()]) + } +} + +impl Mul for PackedCM31 { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + // Compute using Karatsuba. + let ac = self.a() * rhs.a(); + let bd = self.b() * rhs.b(); + // Computes (a + b) * (c + d). + let ab_t_cd = (self.a() + self.b()) * (rhs.a() + rhs.b()); + // (ac - bd) + (ad + bc)i. + Self([ac - bd, ab_t_cd - ac - bd]) + } +} + +impl Zero for PackedCM31 { + fn zero() -> Self { + Self([PackedM31::zero(), PackedM31::zero()]) + } + + fn is_zero(&self) -> bool { + self.a().is_zero() && self.b().is_zero() + } +} + +unsafe impl Pod for PackedCM31 {} + +unsafe impl Zeroable for PackedCM31 { + fn zeroed() -> Self { + unsafe { core::mem::zeroed() } + } +} + +impl One for PackedCM31 { + fn one() -> Self { + Self([PackedM31::one(), PackedM31::zero()]) + } +} + +impl MulAssign for PackedCM31 { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl FieldExpOps for PackedCM31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + // 1 / (a + bi) = (a - bi) / (a^2 + b^2). 
+ Self([self.a(), -self.b()]) * (self.a().square() + self.b().square()).inverse() + } +} + +impl Add for PackedCM31 { + type Output = Self; + + fn add(self, rhs: PackedM31) -> Self::Output { + Self([self.a() + rhs, self.b()]) + } +} + +impl Sub for PackedCM31 { + type Output = Self; + + fn sub(self, rhs: PackedM31) -> Self::Output { + let Self([a, b]) = self; + Self([a - rhs, b]) + } +} + +impl Mul for PackedCM31 { + type Output = Self; + + fn mul(self, rhs: PackedM31) -> Self::Output { + let Self([a, b]) = self; + Self([a * rhs, b * rhs]) + } +} + +impl Neg for PackedCM31 { + type Output = Self; + + fn neg(self) -> Self::Output { + let Self([a, b]) = self; + Self([-a, -b]) + } +} + +#[cfg(test)] +mod tests { + use std::array; + + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::core::backend::simd::cm31::PackedCM31; + + #[test] + fn addition_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedCM31::from_array(lhs); + let packed_rhs = PackedCM31::from_array(rhs); + + let res = packed_lhs + packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i])); + } + + #[test] + fn subtraction_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedCM31::from_array(lhs); + let packed_rhs = PackedCM31::from_array(rhs); + + let res = packed_lhs - packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i])); + } + + #[test] + fn multiplication_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedCM31::from_array(lhs); + let packed_rhs = PackedCM31::from_array(rhs); + + let res = packed_lhs * packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i])); + } + + #[test] + fn negation_works() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen(); + let packed_values = 
PackedCM31::from_array(values); + + let res = -packed_values; + + assert_eq!(res.to_array(), values.map(|v| -v)); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/column.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/column.rs new file mode 100644 index 0000000..6486940 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/column.rs @@ -0,0 +1,656 @@ +use std::iter::zip; +use std::{array, mem}; + +use bytemuck::allocation::cast_vec; +use bytemuck::{cast_slice, cast_slice_mut, Zeroable}; +use itertools::{izip, Itertools}; +use num_traits::Zero; + +use super::cm31::PackedCM31; +use super::m31::{PackedBaseField, N_LANES}; +use super::qm31::{PackedQM31, PackedSecureField}; +use super::very_packed_m31::{VeryPackedBaseField, VeryPackedSecureField, N_VERY_PACKED_ELEMS}; +use super::SimdBackend; +use crate::core::backend::{Column, CpuBackend}; +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; +use crate::core::fields::{FieldExpOps, FieldOps}; + +impl FieldOps for SimdBackend { + fn batch_inverse(column: &BaseColumn, dst: &mut BaseColumn) { + PackedBaseField::batch_inverse(&column.data, &mut dst.data); + } +} + +impl FieldOps for SimdBackend { + fn batch_inverse(column: &SecureColumn, dst: &mut SecureColumn) { + PackedSecureField::batch_inverse(&column.data, &mut dst.data); + } +} + +/// An efficient structure for storing and operating on a arbitrary number of [`BaseField`] values. +#[derive(Clone, Debug)] +pub struct BaseColumn { + pub data: Vec, + /// The number of [`BaseField`]s in the vector. + pub length: usize, +} + +impl BaseColumn { + /// Extracts a slice containing the entire vector of [`BaseField`]s. 
+ pub fn as_slice(&self) -> &[BaseField] { + &cast_slice(&self.data)[..self.length] + } + + /// Extracts a mutable slice containing the entire vector of [`BaseField`]s. + pub fn as_mut_slice(&mut self) -> &mut [BaseField] { + &mut cast_slice_mut(&mut self.data)[..self.length] + } + + pub fn into_cpu_vec(mut self) -> Vec { + let capacity = self.data.capacity() * N_LANES; + let length = self.length; + let ptr = self.data.as_mut_ptr() as *mut BaseField; + let res = unsafe { Vec::from_raw_parts(ptr, length, capacity) }; + mem::forget(self); + res + } + + /// Returns a vector of `BaseColumnMutSlice`s, each mutably owning + /// `chunk_size` `PackedBaseField`s (i.e, `chuck_size` * `N_LANES` elements). + pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec> { + self.data + .chunks_mut(chunk_size) + .map(BaseColumnMutSlice) + .collect_vec() + } + + pub fn into_secure_column(self) -> SecureColumn { + let length = self.len(); + let data = self.data.into_iter().map(PackedSecureField::from).collect(); + SecureColumn { data, length } + } +} + +impl Column for BaseColumn { + fn zeros(length: usize) -> Self { + let data = vec![PackedBaseField::zeroed(); length.div_ceil(N_LANES)]; + Self { data, length } + } + + #[allow(clippy::uninit_vec)] + unsafe fn uninitialized(length: usize) -> Self { + let mut data = Vec::with_capacity(length.div_ceil(N_LANES)); + data.set_len(length.div_ceil(N_LANES)); + Self { data, length } + } + + fn to_cpu(&self) -> Vec { + self.as_slice().to_vec() + } + + fn len(&self) -> usize { + self.length + } + + fn at(&self, index: usize) -> BaseField { + self.data[index / N_LANES].to_array()[index % N_LANES] + } + + fn set(&mut self, index: usize, value: BaseField) { + let mut packed = self.data[index / N_LANES].to_array(); + packed[index % N_LANES] = value; + self.data[index / N_LANES] = PackedBaseField::from_array(packed) + } +} + +impl FromIterator for BaseColumn { + fn from_iter>(iter: I) -> Self { + let mut chunks = iter.into_iter().array_chunks(); + let 
mut data = (&mut chunks).map(PackedBaseField::from_array).collect_vec(); + let mut length = data.len() * N_LANES; + + if let Some(remainder) = chunks.into_remainder() { + if !remainder.is_empty() { + length += remainder.len(); + let mut last = [BaseField::zero(); N_LANES]; + last[..remainder.len()].copy_from_slice(remainder.as_slice()); + data.push(PackedBaseField::from_array(last)); + } + } + + Self { data, length } + } +} + +// A efficient structure for storing and operating on a arbitrary number of [`SecureField`] values. +#[derive(Clone, Debug)] +pub struct CM31Column { + pub data: Vec, + pub length: usize, +} + +impl Column for CM31Column { + fn zeros(length: usize) -> Self { + Self { + data: vec![PackedCM31::zeroed(); length.div_ceil(N_LANES)], + length, + } + } + + #[allow(clippy::uninit_vec)] + unsafe fn uninitialized(length: usize) -> Self { + let mut data = Vec::with_capacity(length.div_ceil(N_LANES)); + data.set_len(length.div_ceil(N_LANES)); + Self { data, length } + } + + fn to_cpu(&self) -> Vec { + self.data + .iter() + .flat_map(|x| x.to_array()) + .take(self.length) + .collect() + } + + fn len(&self) -> usize { + self.length + } + + fn at(&self, index: usize) -> CM31 { + self.data[index / N_LANES].to_array()[index % N_LANES] + } + + fn set(&mut self, index: usize, value: CM31) { + let mut packed = self.data[index / N_LANES].to_array(); + packed[index % N_LANES] = value; + self.data[index / N_LANES] = PackedCM31::from_array(packed) + } +} + +impl FromIterator for CM31Column { + fn from_iter>(iter: I) -> Self { + let mut chunks = iter.into_iter().array_chunks(); + let mut data = (&mut chunks).map(PackedCM31::from_array).collect_vec(); + let mut length = data.len() * N_LANES; + + if let Some(remainder) = chunks.into_remainder() { + if !remainder.is_empty() { + length += remainder.len(); + let mut last = [CM31::zero(); N_LANES]; + last[..remainder.len()].copy_from_slice(remainder.as_slice()); + data.push(PackedCM31::from_array(last)); + } + } + + Self { 
data, length } + } +} + +impl FromIterator for CM31Column { + fn from_iter>(iter: I) -> Self { + let data = (&mut iter.into_iter()).collect_vec(); + let length = data.len() * N_LANES; + + Self { data, length } + } +} + +/// A mutable slice of a BaseColumn. +pub struct BaseColumnMutSlice<'a>(pub &'a mut [PackedBaseField]); + +impl<'a> BaseColumnMutSlice<'a> { + pub fn at(&self, index: usize) -> BaseField { + self.0[index / N_LANES].to_array()[index % N_LANES] + } + + pub fn set(&mut self, index: usize, value: BaseField) { + let mut packed = self.0[index / N_LANES].to_array(); + packed[index % N_LANES] = value; + self.0[index / N_LANES] = PackedBaseField::from_array(packed) + } +} + +/// An efficient structure for storing and operating on a arbitrary number of [`SecureField`] +/// values. +#[derive(Clone, Debug)] +pub struct SecureColumn { + pub data: Vec, + /// The number of [`SecureField`]s in the vector. + pub length: usize, +} + +impl SecureColumn { + // Separates a single column of `PackedSecureField` elements into `SECURE_EXTENSION_DEGREE` many + // `PackedBaseField` coordinate columns. 
+ pub fn into_secure_column_by_coords(self) -> SecureColumnByCoords { + if self.len() < N_LANES { + return self.to_cpu().into_iter().collect(); + } + + let length = self.length; + let packed_length = self.data.len(); + let mut columns = array::from_fn(|_| Vec::with_capacity(packed_length)); + + for v in self.data { + let packed_coords = v.into_packed_m31s(); + zip(&mut columns, packed_coords).for_each(|(col, packed_coord)| col.push(packed_coord)); + } + + SecureColumnByCoords { + columns: columns.map(|col| BaseColumn { data: col, length }), + } + } +} + +impl Column for SecureColumn { + fn zeros(length: usize) -> Self { + Self { + data: vec![PackedSecureField::zeroed(); length.div_ceil(N_LANES)], + length, + } + } + + #[allow(clippy::uninit_vec)] + unsafe fn uninitialized(length: usize) -> Self { + let mut data = Vec::with_capacity(length.div_ceil(N_LANES)); + data.set_len(length.div_ceil(N_LANES)); + Self { data, length } + } + + fn to_cpu(&self) -> Vec { + self.data + .iter() + .flat_map(|x| x.to_array()) + .take(self.length) + .collect() + } + + fn len(&self) -> usize { + self.length + } + + fn at(&self, index: usize) -> SecureField { + self.data[index / N_LANES].to_array()[index % N_LANES] + } + + fn set(&mut self, index: usize, value: SecureField) { + let mut packed = self.data[index / N_LANES].to_array(); + packed[index % N_LANES] = value; + self.data[index / N_LANES] = PackedSecureField::from_array(packed) + } +} + +impl FromIterator for SecureColumn { + fn from_iter>(iter: I) -> Self { + let mut chunks = iter.into_iter().array_chunks(); + let mut data = (&mut chunks) + .map(PackedSecureField::from_array) + .collect_vec(); + let mut length = data.len() * N_LANES; + + if let Some(remainder) = chunks.into_remainder() { + if !remainder.is_empty() { + length += remainder.len(); + let mut last = [SecureField::zero(); N_LANES]; + last[..remainder.len()].copy_from_slice(remainder.as_slice()); + data.push(PackedSecureField::from_array(last)); + } + } + + Self { 
data, length } + } +} + +impl FromIterator for SecureColumn { + fn from_iter>(iter: I) -> Self { + let data = iter.into_iter().collect_vec(); + let length = data.len() * N_LANES; + Self { data, length } + } +} + +/// A mutable slice of a SecureColumnByCoords. +pub struct SecureColumnByCoordsMutSlice<'a>(pub [BaseColumnMutSlice<'a>; SECURE_EXTENSION_DEGREE]); + +impl<'a> SecureColumnByCoordsMutSlice<'a> { + /// # Safety + /// + /// `vec_index` must be a valid index. + pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField { + PackedQM31([ + PackedCM31([ + *self.0[0].0.get_unchecked(vec_index), + *self.0[1].0.get_unchecked(vec_index), + ]), + PackedCM31([ + *self.0[2].0.get_unchecked(vec_index), + *self.0[3].0.get_unchecked(vec_index), + ]), + ]) + } + + /// # Safety + /// + /// `vec_index` must be a valid index. + pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) { + let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value; + *self.0[0].0.get_unchecked_mut(vec_index) = a; + *self.0[1].0.get_unchecked_mut(vec_index) = b; + *self.0[2].0.get_unchecked_mut(vec_index) = c; + *self.0[3].0.get_unchecked_mut(vec_index) = d; + } +} + +impl SecureColumnByCoords { + pub fn packed_len(&self) -> usize { + self.columns[0].data.len() + } + + /// # Safety + /// + /// `vec_index` must be a valid index. + pub unsafe fn packed_at(&self, vec_index: usize) -> PackedSecureField { + PackedQM31([ + PackedCM31([ + *self.columns[0].data.get_unchecked(vec_index), + *self.columns[1].data.get_unchecked(vec_index), + ]), + PackedCM31([ + *self.columns[2].data.get_unchecked(vec_index), + *self.columns[3].data.get_unchecked(vec_index), + ]), + ]) + } + + /// # Safety + /// + /// `vec_index` must be a valid index. 
+ pub unsafe fn set_packed(&mut self, vec_index: usize, value: PackedSecureField) { + let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value; + *self.columns[0].data.get_unchecked_mut(vec_index) = a; + *self.columns[1].data.get_unchecked_mut(vec_index) = b; + *self.columns[2].data.get_unchecked_mut(vec_index) = c; + *self.columns[3].data.get_unchecked_mut(vec_index) = d; + } + + pub fn to_vec(&self) -> Vec { + izip!( + self.columns[0].to_cpu(), + self.columns[1].to_cpu(), + self.columns[2].to_cpu(), + self.columns[3].to_cpu(), + ) + .map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d])) + .collect() + } + + /// Returns a vector of `SecureColumnByCoordsMutSlice`s, each mutably owning + /// `SECURE_EXTENSION_DEGREE` slices of `chunk_size` `PackedBaseField`s + /// (i.e, `chuck_size` * `N_LANES` secure field elements, by coordinates). + pub fn chunks_mut(&mut self, chunk_size: usize) -> Vec> { + let [a, b, c, d] = self + .columns + .get_many_mut([0, 1, 2, 3]) + .unwrap() + .map(|x| x.chunks_mut(chunk_size)); + izip!(a, b, c, d) + .map(|(a, b, c, d)| SecureColumnByCoordsMutSlice([a, b, c, d])) + .collect_vec() + } +} + +impl FromIterator for SecureColumnByCoords { + fn from_iter>(iter: I) -> Self { + let cpu_col = SecureColumnByCoords::::from_iter(iter); + let columns = cpu_col.columns.map(|col| col.into_iter().collect()); + SecureColumnByCoords { columns } + } +} + +#[derive(Clone, Debug)] +pub struct VeryPackedBaseColumn { + pub data: Vec, + /// The number of [`BaseField`]s in the vector. + pub length: usize, +} + +impl VeryPackedBaseColumn { + /// Transforms a `&BaseColumn` to a `&VeryPackedBaseColumn`. + /// # Safety + /// + /// The resulting pointer does not update the underlying `data`'s length. 
+ pub unsafe fn transform_under_ref(value: &BaseColumn) -> &Self { + &*(std::ptr::addr_of!(*value) as *const VeryPackedBaseColumn) + } +} + +impl From for VeryPackedBaseColumn { + fn from(value: BaseColumn) -> Self { + Self { + data: cast_vec(value.data), + length: value.length, + } + } +} + +impl From for BaseColumn { + fn from(value: VeryPackedBaseColumn) -> Self { + Self { + data: cast_vec(value.data), + length: value.length, + } + } +} + +impl FromIterator for VeryPackedBaseColumn { + fn from_iter>(iter: I) -> Self { + BaseColumn::from_iter(iter).into() + } +} + +impl Column for VeryPackedBaseColumn { + fn zeros(length: usize) -> Self { + BaseColumn::zeros(length).into() + } + + #[allow(clippy::uninit_vec)] + unsafe fn uninitialized(length: usize) -> Self { + BaseColumn::uninitialized(length).into() + } + + fn to_cpu(&self) -> Vec { + todo!() + } + + fn len(&self) -> usize { + self.length + } + + fn at(&self, index: usize) -> BaseField { + let chunk_size = N_LANES * N_VERY_PACKED_ELEMS; + self.data[index / chunk_size].to_array()[index % chunk_size] + } + + fn set(&mut self, index: usize, value: BaseField) { + let chunk_size = N_LANES * N_VERY_PACKED_ELEMS; + let mut packed = self.data[index / chunk_size].to_array(); + packed[index % chunk_size] = value; + self.data[index / chunk_size] = VeryPackedBaseField::from_array(packed) + } +} + +#[derive(Clone, Debug)] +pub struct VeryPackedSecureColumnByCoords { + pub columns: [VeryPackedBaseColumn; SECURE_EXTENSION_DEGREE], +} + +impl From> for VeryPackedSecureColumnByCoords { + fn from(value: SecureColumnByCoords) -> Self { + Self { + columns: value + .columns + .into_iter() + .map(VeryPackedBaseColumn::from) + .collect_vec() + .try_into() + .unwrap(), + } + } +} + +impl From for SecureColumnByCoords { + fn from(value: VeryPackedSecureColumnByCoords) -> Self { + Self { + columns: value + .columns + .into_iter() + .map(BaseColumn::from) + .collect_vec() + .try_into() + .unwrap(), + } + } +} + +impl 
VeryPackedSecureColumnByCoords { + pub fn packed_len(&self) -> usize { + self.columns[0].data.len() + } + + /// # Safety + /// + /// `vec_index` must be a valid index. + pub unsafe fn packed_at(&self, vec_index: usize) -> VeryPackedSecureField { + VeryPackedSecureField::from_fn(|i| { + PackedQM31([ + PackedCM31([ + self.columns[0].data.get_unchecked(vec_index).0[i], + self.columns[1].data.get_unchecked(vec_index).0[i], + ]), + PackedCM31([ + self.columns[2].data.get_unchecked(vec_index).0[i], + self.columns[3].data.get_unchecked(vec_index).0[i], + ]), + ]) + }) + } + + /// # Safety + /// + /// `vec_index` must be a valid index. + pub unsafe fn set_packed(&mut self, vec_index: usize, value: VeryPackedSecureField) { + for i in 0..N_VERY_PACKED_ELEMS { + let PackedQM31([PackedCM31([a, b]), PackedCM31([c, d])]) = value.0[i]; + self.columns[0].data.get_unchecked_mut(vec_index).0[i] = a; + self.columns[1].data.get_unchecked_mut(vec_index).0[i] = b; + self.columns[2].data.get_unchecked_mut(vec_index).0[i] = c; + self.columns[3].data.get_unchecked_mut(vec_index).0[i] = d; + } + } + + pub fn to_vec(&self) -> Vec { + izip!( + self.columns[0].to_cpu(), + self.columns[1].to_cpu(), + self.columns[2].to_cpu(), + self.columns[3].to_cpu(), + ) + .map(|(a, b, c, d)| SecureField::from_m31_array([a, b, c, d])) + .collect() + } + + /// Transforms a `&mut SecureColumnByCoords` to a + /// `&mut VeryPackedSecureColumnByCoords`. + /// + /// # Safety + /// + /// The resulting pointer does not update the underlying columns' `data`'s lengths. 
+ pub unsafe fn transform_under_mut(value: &mut SecureColumnByCoords) -> &mut Self { + &mut *(std::ptr::addr_of!(*value) as *mut VeryPackedSecureColumnByCoords) + } +} + +#[cfg(test)] +mod tests { + use std::array; + + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::BaseColumn; + use crate::core::backend::simd::column::SecureColumn; + use crate::core::backend::simd::m31::N_LANES; + use crate::core::backend::simd::qm31::PackedQM31; + use crate::core::backend::Column; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::secure_column::SecureColumnByCoords; + + #[test] + fn base_field_vec_from_iter_works() { + let values: [BaseField; 30] = array::from_fn(BaseField::from); + + let res = values.into_iter().collect::(); + + assert_eq!(res.to_cpu(), values); + } + + #[test] + fn secure_field_vec_from_iter_works() { + let mut rng = SmallRng::seed_from_u64(0); + let values: [SecureField; 30] = rng.gen(); + + let res = values.into_iter().collect::(); + + assert_eq!(res.to_cpu(), values); + } + + #[test] + fn test_base_column_chunks_mut() { + let values: [BaseField; N_LANES * 7] = array::from_fn(BaseField::from); + let mut col = values.into_iter().collect::(); + + const CHUNK_SIZE: usize = 2; + let mut chunks = col.chunks_mut(CHUNK_SIZE); + chunks[2].set(19, BaseField::from(1234)); + chunks[3].set(1, BaseField::from(5678)); + + assert_eq!(col.at(2 * CHUNK_SIZE * N_LANES + 19), BaseField::from(1234)); + assert_eq!(col.at(3 * CHUNK_SIZE * N_LANES + 1), BaseField::from(5678)); + } + + #[test] + fn test_secure_column_by_coords_chunks_mut() { + const COL_PACKED_SIZE: usize = 16; + let a: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from); + let b: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from); + let c: [BaseField; N_LANES * COL_PACKED_SIZE] = array::from_fn(BaseField::from); + let d: [BaseField; N_LANES * COL_PACKED_SIZE] = 
array::from_fn(BaseField::from); + let mut col = SecureColumnByCoords { + columns: [a, b, c, d].map(|values| values.into_iter().collect::()), + }; + + let mut rng = SmallRng::seed_from_u64(0); + let rand0 = PackedQM31::from_array(rng.gen()); + let rand1 = PackedQM31::from_array(rng.gen()); + + const CHUNK_SIZE: usize = 4; + let mut chunks = col.chunks_mut(CHUNK_SIZE); + unsafe { + chunks[2].set_packed(3, rand0); + chunks[3].set_packed(1, rand1); + + assert_eq!( + col.packed_at(2 * CHUNK_SIZE + 3).to_array(), + rand0.to_array() + ); + assert_eq!( + col.packed_at(3 * CHUNK_SIZE + 1).to_array(), + rand1.to_array() + ); + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/domain.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/domain.rs new file mode 100644 index 0000000..2093141 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/domain.rs @@ -0,0 +1,86 @@ +use std::simd::{simd_swizzle, u32x2, Simd}; + +use super::m31::{PackedM31, LOG_N_LANES}; +use crate::core::circle::{CirclePoint, M31_CIRCLE_LOG_ORDER}; +use crate::core::fields::m31::M31; +use crate::core::poly::circle::CircleDomain; +use crate::core::utils::bit_reverse_index; + +pub struct CircleDomainBitRevIterator { + domain: CircleDomain, + i: usize, + current: CirclePoint, + flips: [CirclePoint; (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize], +} +impl CircleDomainBitRevIterator { + pub fn new(domain: CircleDomain) -> Self { + let log_size = domain.log_size(); + assert!(log_size >= LOG_N_LANES); + + let initial_points = std::array::from_fn(|i| domain.at(bit_reverse_index(i, log_size))); + let current = CirclePoint { + x: PackedM31::from_array(initial_points.each_ref().map(|p| p.x)), + y: PackedM31::from_array(initial_points.each_ref().map(|p| p.y)), + }; + + let mut flips = [CirclePoint::zero(); (M31_CIRCLE_LOG_ORDER - LOG_N_LANES) as usize]; + for i in 0..(log_size - LOG_N_LANES) { + // L i + // 000111000000 -> + // 000000100000 + let prev_mul = bit_reverse_index((1 << 
i) - 1, log_size - LOG_N_LANES); + let new_mul = bit_reverse_index(1 << i, log_size - LOG_N_LANES); + let flip = domain.half_coset.step.mul(new_mul as u128) + - domain.half_coset.step.mul(prev_mul as u128); + flips[i as usize] = flip; + } + Self { + domain, + i: 0, + current, + flips, + } + } +} +impl Iterator for CircleDomainBitRevIterator { + type Item = CirclePoint; + + fn next(&mut self) -> Option { + if self.i << LOG_N_LANES >= self.domain.size() { + return None; + } + let res = self.current; + let flip = self.flips[self.i.trailing_ones() as usize]; + let flipx = Simd::splat(flip.x.0); + let flipy = u32x2::from_array([flip.y.0, (-flip.y).0]); + let flipy = simd_swizzle!(flipy, [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]); + let flip = unsafe { + CirclePoint { + x: PackedM31::from_simd_unchecked(flipx), + y: PackedM31::from_simd_unchecked(flipy), + } + }; + self.current = self.current + flip; + self.i += 1; + Some(res) + } +} + +#[test] +fn test_circle_domain_bit_rev_iterator() { + let domain = CircleDomain::new(crate::core::circle::Coset::new( + crate::core::circle::CirclePointIndex::generator(), + 5, + )); + let mut expected = domain.iter().collect::>(); + crate::core::utils::bit_reverse(&mut expected); + let actual = CircleDomainBitRevIterator::new(domain) + .flat_map(|c| -> [_; 16] { + std::array::from_fn(|i| CirclePoint { + x: c.x.to_array()[i], + y: c.y.to_array()[i], + }) + }) + .collect::>(); + assert_eq!(actual, expected); +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/ifft.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/ifft.rs new file mode 100644 index 0000000..eb34da4 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -0,0 +1,712 @@ +//! Inverse fft. 
+ +use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; + +use itertools::Itertools; + +use super::{ + compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, +}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::circle::Coset; +use crate::core::fields::FieldExpOps; +use crate::core::utils::bit_reverse; + +/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. +/// +/// # Arguments +/// +/// - `values`: A mutable pointer to the values on which the ICFFT is to be performed. +/// - `twiddle_dbl`: A reference to the doubles of the twiddle factors. +/// - `log_n_elements`: The log of the number of elements in the `values` array. +/// +/// # Panics +/// +/// Panics if `log_n_elements` is less than [`MIN_FFT_LOG_SIZE`]. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { + assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize); + let log_n_vecs = log_n_elements - LOG_N_LANES as usize; + if log_n_elements <= CACHED_FFT_LOG_SIZE as usize { + ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements); + return; + } + + let fft_layers_pre_transpose = log_n_vecs.div_ceil(2); + let fft_layers_post_transpose = log_n_vecs / 2; + ifft_lower_with_vecwise( + values, + &twiddle_dbl[..3 + fft_layers_pre_transpose], + log_n_elements, + fft_layers_pre_transpose + LOG_N_LANES as usize, + ); + transpose_vecs(values, log_n_vecs); + ifft_lower_without_vecwise( + values, + &twiddle_dbl[3 + fft_layers_pre_transpose..], + log_n_elements, + fft_layers_post_transpose, + ); +} + +/// Computes partial ifft on `2^log_size` M31 elements. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the ifft. 
Layer i +/// holds `2^(log_size - 1 - i)` twiddles. +/// - `log_size`: The log of the number of M31 elements in the array. +/// - `fft_layers`: The number of ifft layers to apply, out of log_size. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 5. +/// +/// # Safety +/// +/// `values` must have the same alignment as [`PackedBaseField`]. +/// `fft_layers` must be at least 5. +pub unsafe fn ifft_lower_with_vecwise( + values: *mut u32, + twiddle_dbl: &[&[u32]], + log_size: usize, + fft_layers: usize, +) { + const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1; + assert!(log_size >= VECWISE_FFT_BITS); + + assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); + + for index_h in 0..1 << (log_size - fft_layers) { + ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); + for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) { + match fft_layers - layer { + 1 => { + ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + 2 => { + ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + _ => { + ifft3_loop( + values, + &twiddle_dbl[(layer - 1)..], + fft_layers - layer - 3, + layer, + index_h, + ); + } + } + } + } +} + +/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of +/// the index). +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the ifft. +/// - `log_size`: The log of the number of M31 elements in the array. +/// - `fft_layers`: The number of ifft layers to apply, out of `log_size - LOG_N_LANES`. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 4. +/// +/// # Safety +/// +/// `values` must have the same alignment as [`PackedBaseField`]. +/// `fft_layers` must be at least 4. 
+pub unsafe fn ifft_lower_without_vecwise( + values: *mut u32, + twiddle_dbl: &[&[u32]], + log_size: usize, + fft_layers: usize, +) { + assert!(log_size >= LOG_N_LANES as usize); + + for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { + for layer in (0..fft_layers).step_by(3) { + let fixed_layer = layer + LOG_N_LANES as usize; + match fft_layers - layer { + 1 => { + ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); + } + 2 => { + ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); + } + _ => { + ifft3_loop( + values, + &twiddle_dbl[layer..], + fft_layers - layer - 3, + fixed_layer, + index_h, + ); + } + } + } + } +} + +/// Runs the first 5 ifft layers across the entire array. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 5 ifft layers. +/// - `high_bits`: The number of bits this loops needs to run on. +/// - `index_h`: The higher part of the index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. 
+pub unsafe fn ifft_vecwise_loop( + values: *mut u32, + twiddle_dbl: &[&[u32]], + loop_bits: usize, + index_h: usize, +) { + for index_l in 0..1 << loop_bits { + let index = (index_h << loop_bits) + index_l; + let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const()); + let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const()); + (val0, val1) = vecwise_ibutterflies( + val0, + val1, + std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)), + std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), + std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), + ); + (val0, val1) = simd_ibutterfly( + val0, + val1, + u32x16::splat(*twiddle_dbl[3].get_unchecked(index)), + ); + val0.store(values.add(index * 32)); + val1.store(values.add(index * 32 + 16)); + } +} + +/// Runs 3 ifft layers across the entire array. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 3 ifft layers. +/// - `loop_bits`: The number of bits this loops needs to run on. +/// - `layer`: The layer number of the first ifft layer to apply. The layers `layer`, `layer + 1`, +/// `layer + 2` are applied. +/// - `index_h`: The higher part of the index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. 
+pub unsafe fn ifft3_loop( + values: *mut u32, + twiddle_dbl: &[&[u32]], + loop_bits: usize, + layer: usize, + index_h: usize, +) { + for index_l in 0..1 << loop_bits { + let index = (index_h << loop_bits) + index_l; + let offset = index << (layer + 3); + for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { + ifft3( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1)) + }), + ); + } + } +} + +/// Runs 2 ifft layers across the entire array. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 2 ifft layers. +/// - `loop_bits`: The number of bits this loops needs to run on. +/// - `layer`: The layer number of the first ifft layer to apply. The layers `layer`, `layer + 1` +/// are applied. +/// - `index`: The index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +unsafe fn ifft2_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) { + let offset = index << (layer + 2); + for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { + ifft2( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1)) + }), + ); + } +} + +/// Runs 1 ifft layer across the entire array. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for the ifft layer. 
+/// - `layer`: The layer number of the ifft layer to apply. +/// - `index_h`: The higher part of the index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +unsafe fn ifft1_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) { + let offset = index << (layer + 1); + for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { + ifft1( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1)) + }), + ); + } +} + +/// Computes the ibutterfly operation for packed M31 elements. +/// +/// Returns `val0 + val1, t (val0 - val1)`. `val0, val1` are packed M31 elements. 16 M31 words at +/// each. Each value is assumed to be in unreduced form, [0, P] including P. `twiddle_dbl` holds 16 +/// values, each is a *double* of a twiddle factor, in unreduced form. +pub fn simd_ibutterfly( + val0: PackedBaseField, + val1: PackedBaseField, + twiddle_dbl: u32x16, +) -> (PackedBaseField, PackedBaseField) { + let r0 = val0 + val1; + let r1 = val0 - val1; + let prod = mul_twiddle(r1, twiddle_dbl); + (r0, prod) +} + +/// Runs ifft on 2 vectors of 16 M31 elements. +/// +/// This amounts to 4 butterfly layers, each with 16 butterflies. +/// Each of the vectors represents a bit reversed evaluation. +/// Each value in a vectors is in unreduced form: [0, P] including P. +/// Takes 3 twiddle arrays, one for each layer after the first, holding the double of the +/// corresponding twiddle. +/// The first layer's twiddles (lower bit of the index) are computed from the second layer's +/// twiddles. The second layer takes 8 twiddles. +/// The third layer takes 4 twiddles. +/// The fourth layer takes 2 twiddles. 
+pub fn vecwise_ibutterflies( + mut val0: PackedBaseField, + mut val1: PackedBaseField, + twiddle1_dbl: [u32; 8], + twiddle2_dbl: [u32; 4], + twiddle3_dbl: [u32; 2], +) -> (PackedBaseField, PackedBaseField) { + // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. + + // Each `ibutterfly` take 2 512-bit registers, and does 16 butterflies element by element. + // We need to permute the 512-bit registers to get the right order for the butterflies. + // Denote the index of the 16 M31 elements in register i as i:abcd. + // At each layer we apply the following permutation to the index: + // i:abcd => d:iabc + // This is how it looks like at each iteration. + // i:abcd + // d:iabc + // ifft on d + // c:diab + // ifft on c + // b:cdia + // ifft on b + // a:bcid + // ifft on a + // i:abcd + + let (t0, t1) = compute_first_twiddles(twiddle1_dbl.into()); + + // Apply the permutation, resulting in indexing d:iabc. + (val0, val1) = val0.deinterleave(val1); + (val0, val1) = simd_ibutterfly(val0, val1, t0); + + // Apply the permutation, resulting in indexing c:diab. + (val0, val1) = val0.deinterleave(val1); + (val0, val1) = simd_ibutterfly(val0, val1, t1); + + let t = simd_swizzle!( + u32x4::from(twiddle2_dbl), + [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + ); + // Apply the permutation, resulting in indexing b:cdia. + (val0, val1) = val0.deinterleave(val1); + (val0, val1) = simd_ibutterfly(val0, val1, t); + + let t = simd_swizzle!( + u32x2::from(twiddle3_dbl), + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + ); + // Apply the permutation, resulting in indexing a:bcid. + (val0, val1) = val0.deinterleave(val1); + (val0, val1) = simd_ibutterfly(val0, val1, t); + + // Apply the permutation, resulting in indexing i:abcd. + val0.deinterleave(val1) +} + +/// Returns the line twiddles (x points) for an ifft on a coset. 
+pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec> { + let mut res = vec![]; + for _ in 0..coset.log_size() { + res.push( + coset + .iter() + .take(coset.size() / 2) + .map(|p| p.x.inverse().0 * 2) + .collect_vec(), + ); + bit_reverse(res.last_mut().unwrap()); + coset = coset.double(); + } + + res +} + +/// Applies 3 ibutterfly layers on 8 vectors of 16 M31 elements. +/// +/// Vectorized over the 16 elements of the vectors. +/// Used for radix-8 ifft. +/// Each butterfly layer, has 3 SIMD butterflies. +/// Total of 12 SIMD butterflies. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array. +/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer i this is i - 4. +/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of ibutterflies. Each layer +/// has 4/2/1 twiddles. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft3( + values: *mut u32, + offset: usize, + log_step: usize, + twiddles_dbl0: [u32; 4], + twiddles_dbl1: [u32; 2], + twiddles_dbl2: [u32; 1], +) { + // Load the 8 SIMD vectors from the array. 
+ let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); + let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const()); + let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const()); + let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const()); + let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const()); + + // Apply the first layer of ibutterflies. + (val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + (val2, val3) = simd_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); + (val4, val5) = simd_ibutterfly(val4, val5, u32x16::splat(twiddles_dbl0[2])); + (val6, val7) = simd_ibutterfly(val6, val7, u32x16::splat(twiddles_dbl0[3])); + + // Apply the second layer of ibutterflies. + (val0, val2) = simd_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); + (val1, val3) = simd_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); + (val4, val6) = simd_ibutterfly(val4, val6, u32x16::splat(twiddles_dbl1[1])); + (val5, val7) = simd_ibutterfly(val5, val7, u32x16::splat(twiddles_dbl1[1])); + + // Apply the third layer of ibutterflies. + (val0, val4) = simd_ibutterfly(val0, val4, u32x16::splat(twiddles_dbl2[0])); + (val1, val5) = simd_ibutterfly(val1, val5, u32x16::splat(twiddles_dbl2[0])); + (val2, val6) = simd_ibutterfly(val2, val6, u32x16::splat(twiddles_dbl2[0])); + (val3, val7) = simd_ibutterfly(val3, val7, u32x16::splat(twiddles_dbl2[0])); + + // Store the 8 SIMD vectors back to the array. 
+ val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); + val2.store(values.add(offset + (2 << log_step))); + val3.store(values.add(offset + (3 << log_step))); + val4.store(values.add(offset + (4 << log_step))); + val5.store(values.add(offset + (5 << log_step))); + val6.store(values.add(offset + (6 << log_step))); + val7.store(values.add(offset + (7 << log_step))); +} + +/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements. +/// +/// Vectorized over the 16 elements of the vectors. +/// Used for radix-4 ifft. +/// Each ibutterfly layer, has 2 SIMD butterflies. +/// Total of 4 SIMD butterflies. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array. +/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer `i` this is `i - 4`. +/// - `twiddles_dbl0/1`: The double of the twiddles for the 2 layers of ibutterflies. Each layer has +/// 2/1 twiddles. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft2( + values: *mut u32, + offset: usize, + log_step: usize, + twiddles_dbl0: [u32; 2], + twiddles_dbl1: [u32; 1], +) { + // Load the 4 SIMD vectors from the array. + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); + + // Apply the first layer of butterflies. 
+ (val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + (val2, val3) = simd_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); + + // Apply the second layer of butterflies. + (val0, val2) = simd_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); + (val1, val3) = simd_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); + + // Store the 4 SIMD vectors back to the array. + val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); + val2.store(values.add(offset + (2 << log_step))); + val3.store(values.add(offset + (3 << log_step))); +} + +/// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements. +/// +/// Vectorized over the 16 elements of the vectors. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array. +/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer `i` this is `i - 4`. +/// - `twiddles_dbl0`: The double of the twiddles for the ibutterfly layer. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft1(values: *mut u32, offset: usize, log_step: usize, twiddles_dbl0: [u32; 1]) { + // Load the 2 SIMD vectors from the array. + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + + (val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + + // Store the 2 SIMD vectors back to the array. 
+ val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); +} + +#[cfg(test)] +mod tests { + use std::mem::transmute; + use std::simd::u32x16; + + use itertools::Itertools; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::{ + get_itwiddle_dbls, ifft, ifft3, ifft_lower_with_vecwise, simd_ibutterfly, + vecwise_ibutterflies, + }; + use crate::core::backend::cpu::CpuCircleEvaluation; + use crate::core::backend::simd::column::BaseColumn; + use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE}; + use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; + use crate::core::backend::Column; + use crate::core::fft::ibutterfly as ground_truth_ibutterfly; + use crate::core::fields::m31::BaseField; + use crate::core::poly::circle::{CanonicCoset, CircleDomain}; + + #[test] + fn test_ibutterfly() { + let mut rng = SmallRng::seed_from_u64(0); + let mut v0: [BaseField; N_LANES] = rng.gen(); + let mut v1: [BaseField; N_LANES] = rng.gen(); + let twiddle: [BaseField; N_LANES] = rng.gen(); + let twiddle_dbl = twiddle.map(|v| v.0 * 2); + + let (r0, r1) = simd_ibutterfly(v0.into(), v1.into(), twiddle_dbl.into()); + + let r0 = r0.to_array(); + let r1 = r1.to_array(); + for i in 0..N_LANES { + ground_truth_ibutterfly(&mut v0[i], &mut v1[i], twiddle[i]); + assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}"); + } + } + + #[test] + fn test_ifft3() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast); + let twiddles0: [BaseField; 4] = rng.gen(); + let twiddles1: [BaseField; 2] = rng.gen(); + let twiddles2: [BaseField; 1] = rng.gen(); + let twiddles0_dbl = twiddles0.map(|v| v.0 * 2); + let twiddles1_dbl = twiddles1.map(|v| v.0 * 2); + let twiddles2_dbl = twiddles2.map(|v| v.0 * 2); + + let mut res = values; + unsafe { + ifft3( + transmute(res.as_mut_ptr()), + 0, + LOG_N_LANES as usize, + twiddles0_dbl, + 
twiddles1_dbl, + twiddles2_dbl, + ) + }; + + let mut expected = values.map(|v| v.to_array()[0]); + for i in 0..8 { + let j = i ^ 1; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 2; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 4; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_ibutterfly(&mut v0, &mut v1, twiddles2[0]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + assert_eq!( + res[i].to_array(), + [expected[i]; N_LANES], + "mismatch at i={i}" + ); + } + } + + #[test] + fn test_vecwise_ibutterflies() { + let domain = CanonicCoset::new(5).circle_domain(); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); + assert_eq!(twiddle_dbls.len(), 4); + let mut rng = SmallRng::seed_from_u64(0); + let values: [[BaseField; 16]; 2] = rng.gen(); + + let res = { + let (val0, val1) = vecwise_ibutterflies( + values[0].into(), + values[1].into(), + twiddle_dbls[0].clone().try_into().unwrap(), + twiddle_dbls[1].clone().try_into().unwrap(), + twiddle_dbls[2].clone().try_into().unwrap(), + ); + let (val0, val1) = simd_ibutterfly(val0, val1, u32x16::splat(twiddle_dbls[3][0])); + [val0.to_array(), val1.to_array()].concat() + }; + + assert_eq!(res, ground_truth_ifft(domain, values.flatten())); + } + + #[test] + fn test_ifft_lower_with_vecwise() { + for log_size in 5..12 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); + + let mut res = values.iter().copied().collect::(); + unsafe { + 
ifft_lower_with_vecwise( + transmute(res.data.as_mut_ptr()), + &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), + log_size as usize, + log_size as usize, + ); + } + + assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); + } + } + + #[test] + fn test_ifft_full() { + for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); + + let mut res = values.iter().copied().collect::(); + unsafe { + ifft( + transmute(res.data.as_mut_ptr()), + &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), + log_size as usize, + ); + transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4); + } + + assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); + } + } + + fn ground_truth_ifft(domain: CircleDomain, values: &[BaseField]) -> Vec { + let eval = CpuCircleEvaluation::new(domain, values.to_vec()); + let mut res = eval.interpolate().coeffs; + let denorm = BaseField::from(domain.size()); + res.iter_mut().for_each(|v| *v *= denorm); + res + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/mod.rs new file mode 100644 index 0000000..ca44979 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/mod.rs @@ -0,0 +1,120 @@ +use std::simd::{simd_swizzle, u32x16, u32x8}; + +use super::m31::PackedBaseField; +use crate::core::fields::m31::P; + +pub mod ifft; +pub mod rfft; + +pub const CACHED_FFT_LOG_SIZE: u32 = 16; + +pub const MIN_FFT_LOG_SIZE: u32 = 5; + +// TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce +// it somewhere. + +/// Transposes the SIMD vectors in the given array. 
+/// +/// Swaps the bit index abc <-> cba, where |a|=|c| and |b| = 0 or 1, according to the parity of +/// `log_n_vecs`. +/// When `log_n_vecs` is odd, the middle bit `b` stays in place while `a` and `c` are swapped. +/// +/// # Arguments +/// +/// - `values`: A mutable pointer to the values that are to be transposed. +/// - `log_n_vecs`: The log of the number of SIMD vectors in the `values` array. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`u32x16`]. +pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { + let half = log_n_vecs / 2; + for b in 0..1 << (log_n_vecs & 1) { + for a in 0..1 << half { + for c in 0..1 << half { + let i = (a << (log_n_vecs - half)) | (b << half) | c; + let j = (c << (log_n_vecs - half)) | (b << half) | a; + if i >= j { + continue; + } + let val0 = load(values.add(i << 4).cast_const()); + let val1 = load(values.add(j << 4).cast_const()); + store(values.add(i << 4), val1); + store(values.add(j << 4), val0); + } + } + } +} + +/// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers. +/// +/// Returns the twiddles for the first layer and the twiddles for the second layer. +pub fn compute_first_twiddles(twiddle1_dbl: u32x8) -> (u32x16, u32x16) { + // Start by loading the twiddles for the second layer (layer 1): + let t1 = simd_swizzle!( + twiddle1_dbl, + twiddle1_dbl, + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7] + ); + + // The twiddles for layer 0 can be computed from the twiddles for layer 1. + // Since the twiddles are bit reversed, we consider the circle domain in bit reversed order. + // Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4. + // A circle coset of size 4 in bit reversed order looks like this: + // [(x, y), (-x, -y), (y, -x), (-y, x)] + // Note: This is related to the choice of M31_CIRCLE_GEN, and the fact that a quarter rotation + // is (0,-1) and not (0,1). (0,1) would yield another relation. 
+ // The twiddles for layer 0 are the y coordinates: + // [y, -y, -x, x] + // The twiddles for layer 1 in bit reversed order are the x coordinates: + // [x, y] + // Works also for inverse of the twiddles. + + // The twiddles for layer 0 are computed like this: + // t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]] + // Xoring a double twiddle with P*2 transforms it to the double of it negation. + // Note that this keeps the values as a double of a value in the range [0, P]. + const P2: u32 = P * 2; + const NEGATION_MASK: u32x16 = + u32x16::from_array([0, P2, P2, 0, 0, P2, P2, 0, 0, P2, P2, 0, 0, P2, P2, 0]); + let t0 = simd_swizzle!( + t1, + [ + 0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101, 0b0100, + 0b0100, 0b0111, 0b0111, 0b0110, 0b0110, + ] + ) ^ NEGATION_MASK; + (t0, t1) +} + +#[inline] +unsafe fn load(mem_addr: *const u32) -> u32x16 { + std::ptr::read(mem_addr as *const u32x16) +} + +#[inline] +unsafe fn store(mem_addr: *mut u32, a: u32x16) { + std::ptr::write(mem_addr as *mut u32x16, a); +} + +/// Computes `v * twiddle` +fn mul_twiddle(v: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + // TODO: Come up with a better approach than `cfg`ing on target_feature. + // TODO: Ensure all these branches get tested in the CI. + cfg_if::cfg_if! { + if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + // TODO: For architectures that when multiplying require doubling then the twiddles + // should be precomputed as double. For other architectures, the twiddle should be + // precomputed without doubling. 
+ crate::core::backend::simd::m31::_mul_doubled_neon(v, twiddle_dbl) + } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { + crate::core::backend::simd::m31::_mul_doubled_wasm(v, twiddle_dbl) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + crate::core::backend::simd::m31::_mul_doubled_avx512(v, twiddle_dbl) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { + crate::core::backend::simd::m31::_mul_doubled_avx2(v, twiddle_dbl) + } else { + crate::core::backend::simd::m31::_mul_doubled_simd(v, twiddle_dbl) + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/rfft.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/rfft.rs new file mode 100644 index 0000000..6d51fd0 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -0,0 +1,742 @@ +//! Regular (forward) fft. + +use std::array; +use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8}; + +use itertools::Itertools; + +use super::{ + compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, +}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::circle::Coset; +use crate::core::utils::bit_reverse; + +/// Performs a Circle Fast Fourier Transform (CFFT) on the given values. +/// +/// # Arguments +/// +/// * `src`: A pointer to the values to transform. +/// * `dst`: A pointer to the destination array. +/// * `twiddle_dbl`: A reference to the doubles of the twiddle factors. +/// * `log_n_elements`: The log of the number of elements in the `values` array. +/// +/// # Panics +/// +/// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`. +/// +/// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. 
+pub unsafe fn fft(src: *const u32, dst: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { + assert!(log_n_elements >= MIN_FFT_LOG_SIZE as usize); + let log_n_vecs = log_n_elements - LOG_N_LANES as usize; + if log_n_elements <= CACHED_FFT_LOG_SIZE as usize { + fft_lower_with_vecwise(src, dst, twiddle_dbl, log_n_elements, log_n_elements); + return; + } + + let fft_layers_pre_transpose = log_n_vecs.div_ceil(2); + let fft_layers_post_transpose = log_n_vecs / 2; + fft_lower_without_vecwise( + src, + dst, + &twiddle_dbl[(3 + fft_layers_pre_transpose)..], + log_n_elements, + fft_layers_post_transpose, + ); + transpose_vecs(dst, log_n_vecs); + fft_lower_with_vecwise( + dst, + dst, + &twiddle_dbl[..3 + fft_layers_pre_transpose], + log_n_elements, + fft_layers_pre_transpose + LOG_N_LANES as usize, + ); +} + +/// Computes partial fft on `2^log_size` M31 elements. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the fft. Layer `i` +/// holds `2^(log_size - 1 - i)` twiddles. +/// - `log_size`: The log of the number of M31 elements in the array. +/// - `fft_layers`: The number of fft layers to apply, out of log_size. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 5. +/// +/// # Safety +/// +/// `src` and `dst` must have same alignment as [`PackedBaseField`]. +/// `fft_layers` must be at least 5. 
+pub unsafe fn fft_lower_with_vecwise( + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], + log_size: usize, + fft_layers: usize, +) { + const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1; + assert!(log_size >= VECWISE_FFT_BITS); + + assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); + + for index_h in 0..1 << (log_size - fft_layers) { + let mut src = src; + for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() { + match fft_layers - layer { + 1 => { + fft1_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + 2 => { + fft2_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + _ => { + fft3_loop( + src, + dst, + &twiddle_dbl[(layer - 1)..], + fft_layers - layer - 3, + layer, + index_h, + ); + } + } + src = dst; + } + fft_vecwise_loop( + src, + dst, + twiddle_dbl, + fft_layers - VECWISE_FFT_BITS, + index_h, + ); + } +} + +/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of +/// the index). +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the the fft. +/// - `log_size`: The log of the number of number of M31 elements in the array. +/// - `fft_layers`: The number of fft layers to apply, out of log_size - VEC_LOG_SIZE. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 4. +/// +/// # Safety +/// +/// `src` and `dst` must have same alignment as [`PackedBaseField`]. +/// `fft_layers` must be at least 4. 
pub unsafe fn fft_lower_without_vecwise(
    src: *const u32,
    dst: *mut u32,
    twiddle_dbl: &[&[u32]],
    log_size: usize,
    fft_layers: usize,
) {
    assert!(log_size >= LOG_N_LANES as usize);

    for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) {
        let mut src = src;
        // Apply the layers in chunks of up to 3 (radix-8), iterating the chunk
        // start positions in descending order.
        for layer in (0..fft_layers).step_by(3).rev() {
            // Shift by LOG_N_LANES since the vecwise (lowest) layers are skipped here.
            let fixed_layer = layer + LOG_N_LANES as usize;
            match fft_layers - layer {
                1 => {
                    fft1_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h);
                }
                2 => {
                    fft2_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h);
                }
                _ => {
                    fft3_loop(
                        src,
                        dst,
                        &twiddle_dbl[layer..],
                        fft_layers - layer - 3,
                        fixed_layer,
                        index_h,
                    );
                }
            }
            // After the first chunk the data lives in `dst`; keep transforming in place.
            src = dst;
        }
    }
}

/// Runs the last 5 fft layers across the entire array.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 5 fft layers.
/// - `loop_bits`: The number of bits this loop needs to run on.
/// - `index_h`: The higher part of the index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
unsafe fn fft_vecwise_loop(
    src: *const u32,
    dst: *mut u32,
    twiddle_dbl: &[&[u32]],
    loop_bits: usize,
    index_h: usize,
) {
    for index_l in 0..1 << loop_bits {
        let index = (index_h << loop_bits) + index_l;
        // Each iteration transforms 2 packed vectors (32 M31 words).
        let mut val0 = PackedBaseField::load(src.add(index * 32));
        let mut val1 = PackedBaseField::load(src.add(index * 32 + 16));
        // One whole-vector butterfly, then the 4 in-vector butterfly layers.
        (val0, val1) = simd_butterfly(
            val0,
            val1,
            u32x16::splat(*twiddle_dbl[3].get_unchecked(index)),
        );
        (val0, val1) = vecwise_butterflies(
            val0,
            val1,
            array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)),
            array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)),
            array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)),
        );
        val0.store(dst.add(index * 32));
        val1.store(dst.add(index * 32 + 16));
    }
}

/// Runs 3 fft layers across the entire array.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 3 fft layers.
/// - `loop_bits`: The number of bits this loop needs to run on.
/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1`,
///   `layer + 2` are applied.
/// - `index_h`: The higher part of the index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
unsafe fn fft3_loop(
    src: *const u32,
    dst: *mut u32,
    twiddle_dbl: &[&[u32]],
    loop_bits: usize,
    layer: usize,
    index_h: usize,
) {
    for index_l in 0..1 << loop_bits {
        let index = (index_h << loop_bits) + index_l;
        let offset = index << (layer + 3);
        for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
            fft3(
                src,
                dst,
                offset + l,
                layer,
                // Masking by (len - 1) wraps the twiddle index; each twiddle slice
                // length is a power of two.
                array::from_fn(|i| {
                    *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1))
                }),
                array::from_fn(|i| {
                    *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1))
                }),
                array::from_fn(|i| {
                    *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1))
                }),
            );
        }
    }
}

/// Runs 2 fft layers across the entire array.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 2 fft layers.
/// - `loop_bits`: The number of bits this loop needs to run on.
/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1` are
///   applied.
/// - `index`: The index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
unsafe fn fft2_loop(
    src: *const u32,
    dst: *mut u32,
    twiddle_dbl: &[&[u32]],
    layer: usize,
    index: usize,
) {
    let offset = index << (layer + 2);
    for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
        fft2(
            src,
            dst,
            offset + l,
            layer,
            array::from_fn(|i| {
                *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1))
            }),
            array::from_fn(|i| {
                *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1))
            }),
        );
    }
}

/// Runs 1 fft layer across the entire array.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `twiddle_dbl`: The doubles of the twiddle factors for the fft layer.
/// - `layer`: The layer number of the fft layer to apply.
/// - `index`: The index, iterated by the caller.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
unsafe fn fft1_loop(
    src: *const u32,
    dst: *mut u32,
    twiddle_dbl: &[&[u32]],
    layer: usize,
    index: usize,
) {
    let offset = index << (layer + 1);
    for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) {
        fft1(
            src,
            dst,
            offset + l,
            layer,
            array::from_fn(|i| {
                *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1))
            }),
        );
    }
}

/// Computes the butterfly operation for packed M31 elements.
///
/// Returns `val0 + t val1, val0 - t val1`. `val0, val1` are packed M31 elements. 16 M31 words at
/// each. Each value is assumed to be in unreduced form, [0, P] including P. Returned values are in
/// unreduced form, [0, P] including P. twiddle_dbl holds 16 values, each is a *double* of a twiddle
/// factor, in unreduced form, [0, 2*P].
pub fn simd_butterfly(
    val0: PackedBaseField,
    val1: PackedBaseField,
    twiddle_dbl: u32x16,
) -> (PackedBaseField, PackedBaseField) {
    let prod = mul_twiddle(val1, twiddle_dbl);
    (val0 + prod, val0 - prod)
}

/// Runs fft on 2 vectors of 16 M31 elements.
///
/// This amounts to 4 butterfly layers, each with 16 butterflies.
/// Each of the vectors represents naturally ordered polynomial coefficients.
/// Each value in a vector is in unreduced form: [0, P] including P.
/// Takes 4 twiddle arrays, one for each layer, holding the double of the corresponding twiddle.
/// The first layer (higher bit of the index) takes 2 twiddles.
/// The second layer takes 4 twiddles.
/// etc.
pub fn vecwise_butterflies(
    mut val0: PackedBaseField,
    mut val1: PackedBaseField,
    twiddle1_dbl: [u32; 8],
    twiddle2_dbl: [u32; 4],
    twiddle3_dbl: [u32; 2],
) -> (PackedBaseField, PackedBaseField) {
    // TODO(spapini): Compute twiddle0 from twiddle1.
    // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly.
    // The implementation is the exact reverse of vecwise_ibutterflies().
    // See the comments in its body for more info.
    let t = simd_swizzle!(
        u32x2::from(twiddle3_dbl),
        [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
    );
    (val0, val1) = val0.interleave(val1);
    (val0, val1) = simd_butterfly(val0, val1, t);

    let t = simd_swizzle!(
        u32x4::from(twiddle2_dbl),
        [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
    );
    (val0, val1) = val0.interleave(val1);
    (val0, val1) = simd_butterfly(val0, val1, t);

    let (t0, t1) = compute_first_twiddles(u32x8::from(twiddle1_dbl));
    (val0, val1) = val0.interleave(val1);
    (val0, val1) = simd_butterfly(val0, val1, t1);

    (val0, val1) = val0.interleave(val1);
    (val0, val1) = simd_butterfly(val0, val1, t0);

    val0.interleave(val1)
}

/// Returns the line twiddles (x points) for an fft on a coset.
pub fn get_twiddle_dbls(mut coset: Coset) -> Vec<Vec<u32>> {
    let mut res = vec![];
    for _ in 0..coset.log_size() {
        // Doubles of the x coordinates of the first half of the coset.
        res.push(
            coset
                .iter()
                .take(coset.size() / 2)
                .map(|p| p.x.0 * 2)
                .collect_vec(),
        );
        // The fft loops consume each layer's twiddles in bit-reversed order.
        bit_reverse(res.last_mut().unwrap());
        coset = coset.double();
    }

    res
}

/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements.
///
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-8 fft.
/// Each butterfly layer has 4 SIMD butterflies.
/// Total of 12 SIMD butterflies.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `offset`: The offset of the first value in the array.
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
///   that need to be transformed. For layer i this is i - 4.
/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of butterflies. Each layer
///   has 4/2/1 twiddles.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
pub unsafe fn fft3(
    src: *const u32,
    dst: *mut u32,
    offset: usize,
    log_step: usize,
    twiddles_dbl0: [u32; 4],
    twiddles_dbl1: [u32; 2],
    twiddles_dbl2: [u32; 1],
) {
    // Load the 8 SIMD vectors from the array.
    let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)));
    let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)));
    let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)));
    let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)));
    let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step)));
    let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step)));
    let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step)));
    let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step)));

    // Apply the third layer of butterflies.
    (val0, val4) = simd_butterfly(val0, val4, u32x16::splat(twiddles_dbl2[0]));
    (val1, val5) = simd_butterfly(val1, val5, u32x16::splat(twiddles_dbl2[0]));
    (val2, val6) = simd_butterfly(val2, val6, u32x16::splat(twiddles_dbl2[0]));
    (val3, val7) = simd_butterfly(val3, val7, u32x16::splat(twiddles_dbl2[0]));

    // Apply the second layer of butterflies.
    (val0, val2) = simd_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0]));
    (val1, val3) = simd_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0]));
    (val4, val6) = simd_butterfly(val4, val6, u32x16::splat(twiddles_dbl1[1]));
    (val5, val7) = simd_butterfly(val5, val7, u32x16::splat(twiddles_dbl1[1]));

    // Apply the first layer of butterflies.
    (val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
    (val2, val3) = simd_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1]));
    (val4, val5) = simd_butterfly(val4, val5, u32x16::splat(twiddles_dbl0[2]));
    (val6, val7) = simd_butterfly(val6, val7, u32x16::splat(twiddles_dbl0[3]));

    // Store the 8 SIMD vectors back to the array.
    val0.store(dst.add(offset + (0 << log_step)));
    val1.store(dst.add(offset + (1 << log_step)));
    val2.store(dst.add(offset + (2 << log_step)));
    val3.store(dst.add(offset + (3 << log_step)));
    val4.store(dst.add(offset + (4 << log_step)));
    val5.store(dst.add(offset + (5 << log_step)));
    val6.store(dst.add(offset + (6 << log_step)));
    val7.store(dst.add(offset + (7 << log_step)));
}

/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements.
///
/// Vectorized over the 16 elements of the vectors.
/// Used for radix-4 fft.
/// Each butterfly layer has 2 SIMD butterflies.
/// Total of 4 SIMD butterflies.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `offset`: The offset of the first value in the array.
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
///   that need to be transformed. For layer i this is i - 4.
/// - `twiddles_dbl0/1`: The double of the twiddles for the 2 layers of butterflies. Each layer has
///   2/1 twiddles.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
pub unsafe fn fft2(
    src: *const u32,
    dst: *mut u32,
    offset: usize,
    log_step: usize,
    twiddles_dbl0: [u32; 2],
    twiddles_dbl1: [u32; 1],
) {
    // Load the 4 SIMD vectors from the array.
    let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)));
    let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)));
    let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)));
    let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)));

    // Apply the second layer of butterflies.
    (val0, val2) = simd_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0]));
    (val1, val3) = simd_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0]));

    // Apply the first layer of butterflies.
    (val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));
    (val2, val3) = simd_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1]));

    // Store the 4 SIMD vectors back to the array.
    val0.store(dst.add(offset + (0 << log_step)));
    val1.store(dst.add(offset + (1 << log_step)));
    val2.store(dst.add(offset + (2 << log_step)));
    val3.store(dst.add(offset + (3 << log_step)));
}

/// Applies 1 butterfly layer on 2 vectors of 16 M31 elements.
///
/// Vectorized over the 16 elements of the vectors.
///
/// # Arguments
///
/// - `src`: A pointer to the values to transform, aligned to 64 bytes.
/// - `dst`: A pointer to the destination array, aligned to 64 bytes.
/// - `offset`: The offset of the first value in the array.
/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values
///   that need to be transformed. For layer i this is i - 4.
/// - `twiddles_dbl0`: The double of the twiddles for the butterfly layer.
///
/// # Safety
///
/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
pub unsafe fn fft1(
    src: *const u32,
    dst: *mut u32,
    offset: usize,
    log_step: usize,
    twiddles_dbl0: [u32; 1],
) {
    // Load the 2 SIMD vectors from the array.
    let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)));
    let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)));

    (val0, val1) = simd_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0]));

    // Store the 2 SIMD vectors back to the array.
    val0.store(dst.add(offset + (0 << log_step)));
    val1.store(dst.add(offset + (1 << log_step)));
}

#[cfg(test)]
mod tests {
    use std::mem::transmute;
    use std::simd::u32x16;

    use itertools::Itertools;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::{
        fft, fft3, fft_lower_with_vecwise, get_twiddle_dbls, simd_butterfly, vecwise_butterflies,
    };
    use crate::core::backend::cpu::CpuCirclePoly;
    use crate::core::backend::simd::column::BaseColumn;
    use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE};
    use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
    use crate::core::backend::Column;
    use crate::core::fft::butterfly as ground_truth_butterfly;
    use crate::core::fields::m31::BaseField;
    use crate::core::poly::circle::{CanonicCoset, CircleDomain};

    // Compares the packed butterfly against the scalar ground-truth butterfly, lane by lane.
    #[test]
    fn test_butterfly() {
        let mut rng = SmallRng::seed_from_u64(0);
        let mut v0: [BaseField; N_LANES] = rng.gen();
        let mut v1: [BaseField; N_LANES] = rng.gen();
        let twiddle: [BaseField; N_LANES] = rng.gen();
        let twiddle_dbl = twiddle.map(|v| v.0 * 2);

        let (r0, r1) = simd_butterfly(v0.into(), v1.into(), twiddle_dbl.into());

        let r0 = r0.to_array();
        let r1 = r1.to_array();
        for i in 0..N_LANES {
            ground_truth_butterfly(&mut v0[i], &mut v1[i], twiddle[i]);
            assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}");
        }
    }

    // Applies fft3 to 8 broadcast vectors and recomputes the 3 butterfly layers
    // with the scalar butterfly as ground truth.
    #[test]
    fn test_fft3() {
        let mut rng = SmallRng::seed_from_u64(0);
        let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast);
        let twiddles0: [BaseField; 4] = rng.gen();
        let twiddles1: [BaseField; 2] = rng.gen();
        let twiddles2: [BaseField; 1] = rng.gen();
        let twiddles0_dbl = twiddles0.map(|v| v.0 * 2);
        let twiddles1_dbl = twiddles1.map(|v| v.0 * 2);
        let twiddles2_dbl = twiddles2.map(|v| v.0 * 2);

        let mut res = values;
        unsafe {
            fft3(
                transmute(res.as_ptr()),
                transmute(res.as_mut_ptr()),
                0,
                LOG_N_LANES as usize,
                twiddles0_dbl,
                twiddles1_dbl,
                twiddles2_dbl,
            )
        };

        let mut expected = values.map(|v| v.to_array()[0]);
        for i in 0..8 {
            let j = i ^ 4;
            if i > j {
                continue;
            }
            let (mut v0, mut v1) = (expected[i], expected[j]);
            ground_truth_butterfly(&mut v0, &mut v1, twiddles2[0]);
            (expected[i], expected[j]) = (v0, v1);
        }
        for i in 0..8 {
            let j = i ^ 2;
            if i > j {
                continue;
            }
            let (mut v0, mut v1) = (expected[i], expected[j]);
            ground_truth_butterfly(&mut v0, &mut v1, twiddles1[i / 4]);
            (expected[i], expected[j]) = (v0, v1);
        }
        for i in 0..8 {
            let j = i ^ 1;
            if i > j {
                continue;
            }
            let (mut v0, mut v1) = (expected[i], expected[j]);
            ground_truth_butterfly(&mut v0, &mut v1, twiddles0[i / 2]);
            (expected[i], expected[j]) = (v0, v1);
        }
        for i in 0..8 {
            assert_eq!(
                res[i].to_array(),
                [expected[i]; N_LANES],
                "mismatch at i={i}"
            );
        }
    }

    // Checks the 5 vecwise layers against the CPU circle-poly evaluation.
    #[test]
    fn test_vecwise_butterflies() {
        let domain = CanonicCoset::new(5).circle_domain();
        let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
        assert_eq!(twiddle_dbls.len(), 4);
        let mut rng = SmallRng::seed_from_u64(0);
        let values: [[BaseField; 16]; 2] = rng.gen();

        let res = {
            let (val0, val1) = simd_butterfly(
                values[0].into(),
                values[1].into(),
                u32x16::splat(twiddle_dbls[3][0]),
            );
            let (val0, val1) = vecwise_butterflies(
                val0,
                val1,
                twiddle_dbls[0].clone().try_into().unwrap(),
                twiddle_dbls[1].clone().try_into().unwrap(),
                twiddle_dbls[2].clone().try_into().unwrap(),
            );
            [val0.to_array(), val1.to_array()].concat()
        };

        assert_eq!(res, ground_truth_fft(domain, values.flatten()));
    }

    #[test]
    fn test_fft_lower() {
        for log_size in 5..12 {
            let domain = CanonicCoset::new(log_size).circle_domain();
            let mut rng = SmallRng::seed_from_u64(0);
            let values = (0..domain.size()).map(|_| rng.gen()).collect_vec();
            let twiddle_dbls = get_twiddle_dbls(domain.half_coset);

            let mut res = values.iter().copied().collect::<BaseColumn>();
            unsafe {
                fft_lower_with_vecwise(
                    transmute(res.data.as_ptr()),
                    transmute(res.data.as_mut_ptr()),
                    &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
                    log_size as usize,
                    log_size as usize,
                )
            }

            assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values));
        }
    }

    // Exercises the cache-splitting path: sizes above CACHED_FFT_LOG_SIZE.
    #[test]
    fn test_fft_full() {
        for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 {
            let domain = CanonicCoset::new(log_size).circle_domain();
            let mut rng = SmallRng::seed_from_u64(0);
            let values = (0..domain.size()).map(|_| rng.gen()).collect_vec();
            let twiddle_dbls = get_twiddle_dbls(domain.half_coset);

            let mut res = values.iter().copied().collect::<BaseColumn>();
            unsafe {
                // `fft` expects pre-transposed input for the large-size path.
                transpose_vecs(transmute(res.data.as_mut_ptr()), log_size as usize - 4);
                fft(
                    transmute(res.data.as_ptr()),
                    transmute(res.data.as_mut_ptr()),
                    &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
                    log_size as usize,
                );
            }

            assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values));
        }
    }

    /// Evaluates the polynomial with the scalar CPU backend as ground truth.
    fn ground_truth_fft(domain: CircleDomain, values: &[BaseField]) -> Vec<BaseField> {
        let poly = CpuCirclePoly::new(values.to_vec());
        poly.evaluate(domain).values
    }
}
diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/fri.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/fri.rs
new file mode 100644
index 0000000..9721249
--- /dev/null
+++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/fri.rs
@@ -0,0 +1,261 @@
use std::array;
use std::simd::u32x8;

use num_traits::Zero;

use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::SimdBackend;
use crate::core::backend::simd::fft::compute_first_twiddles;
use crate::core::backend::simd::fft::ifft::simd_ibutterfly;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fri::{self, FriOps};
use crate::core::poly::circle::SecureEvaluation;
use crate::core::poly::line::LineEvaluation;
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::domain_line_twiddles_from_tree;
use crate::core::poly::BitReversedOrder;

impl FriOps for SimdBackend {
    fn fold_line(
        eval: &LineEvaluation<Self>,
        alpha: SecureField,
        twiddles: &TwiddleTree<Self>,
    ) -> LineEvaluation<Self> {
        let log_size = eval.len().ilog2();
        // Too small to fill a SIMD vector: fall back to the scalar implementation.
        if log_size <= LOG_N_LANES {
            let eval = fri::fold_line(&eval.to_cpu(), alpha);
            return LineEvaluation::new(eval.domain(), eval.values.into_iter().collect());
        }

        let domain = eval.domain();
        let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];

        let mut folded_values = SecureColumnByCoords::<Self>::zeros(1 << (log_size - 1));

        for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
            let value = unsafe {
                let twiddle_dbl: [u32; 16] =
                    array::from_fn(|i| *itwiddles.get_unchecked(vec_index * 16 + i));
                let val0 = eval.values.packed_at(vec_index * 2).into_packed_m31s();
                let val1 = eval.values.packed_at(vec_index * 2 + 1).into_packed_m31s();
                // Fold each of the 4 secure-field coordinates with an inverse butterfly.
                let pairs: [_; 4] = array::from_fn(|i| {
                    let (a, b) = val0[i].deinterleave(val1[i]);
                    simd_ibutterfly(a, b, std::mem::transmute(twiddle_dbl))
                });
                let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
                let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
                val0 + PackedSecureField::broadcast(alpha) * val1
            };
            unsafe { folded_values.set_packed(vec_index, value) };
        }

        LineEvaluation::new(domain.double(), folded_values)
    }

    fn fold_circle_into_line(
        dst: &mut LineEvaluation<Self>,
        src: &SecureEvaluation<Self, BitReversedOrder>,
        alpha: SecureField,
        twiddles: &TwiddleTree<Self>,
    ) {
        let log_size = src.len().ilog2();
        assert!(log_size > LOG_N_LANES, "Evaluation too small");

        let domain = src.domain;
        let alpha_sq = alpha * alpha;
        let itwiddles = domain_line_twiddles_from_tree(domain, &twiddles.itwiddles)[0];

        for vec_index in 0..(1 << (log_size - 1 - LOG_N_LANES)) {
            let value = unsafe {
                // The 16 twiddles of the circle domain can be derived from the 8 twiddles of the
                // next line domain. See `compute_first_twiddles()`.
                let twiddle_dbl = u32x8::from_array(array::from_fn(|i| {
                    *itwiddles.get_unchecked(vec_index * 8 + i)
                }));
                let (t0, _) = compute_first_twiddles(twiddle_dbl);
                let val0 = src.values.packed_at(vec_index * 2).into_packed_m31s();
                let val1 = src.values.packed_at(vec_index * 2 + 1).into_packed_m31s();
                let pairs: [_; 4] = array::from_fn(|i| {
                    let (a, b) = val0[i].deinterleave(val1[i]);
                    simd_ibutterfly(a, b, t0)
                });
                let val0 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].0));
                let val1 = PackedSecureField::from_packed_m31s(array::from_fn(|i| pairs[i].1));
                val0 + PackedSecureField::broadcast(alpha) * val1
            };
            unsafe {
                // Accumulate into dst: dst = dst * alpha^2 + folded value.
                dst.values.set_packed(
                    vec_index,
                    dst.values.packed_at(vec_index) * PackedSecureField::broadcast(alpha_sq)
                        + value,
                )
            };
        }
    }

    fn decompose(
        eval: &SecureEvaluation<Self, BitReversedOrder>,
    ) -> (SecureEvaluation<Self, BitReversedOrder>, SecureField) {
        let lambda = decomposition_coefficient(eval);
        let broadcasted_lambda = PackedSecureField::broadcast(lambda);
        let mut g_values = SecureColumnByCoords::<Self>::zeros(eval.len());

        let range = eval.len().div_ceil(N_LANES);
        let half_range = range / 2;
        // NOTE(review): lambda is subtracted from the first half and added to the
        // second half — presumably matching the sign pattern of the out-of-FFT-space
        // component in bit-reversed order; confirm against CpuBackend::decompose.
        for i in 0..half_range {
            let val = unsafe { eval.packed_at(i) } - broadcasted_lambda;
            unsafe { g_values.set_packed(i, val) }
        }
        for i in half_range..range {
            let val = unsafe { eval.packed_at(i) } + broadcasted_lambda;
            unsafe { g_values.set_packed(i, val) }
        }

        let g = SecureEvaluation::new(eval.domain, g_values);
        (g, lambda)
    }
}

/// See [`decomposition_coefficient`].
///
/// [`decomposition_coefficient`]: crate::core::backend::cpu::CpuBackend::decomposition_coefficient
fn decomposition_coefficient(
    eval: &SecureEvaluation<SimdBackend, BitReversedOrder>,
) -> SecureField {
    let cols = &eval.values.columns;
    let [mut x_sum, mut y_sum, mut z_sum, mut w_sum] = [PackedBaseField::zero(); 4];

    let range = cols[0].len() / N_LANES;
    let (half_a, half_b) = (range / 2, range);

    // Sum the first half and subtract the second half, per coordinate column.
    for i in 0..half_a {
        x_sum += cols[0].data[i];
        y_sum += cols[1].data[i];
        z_sum += cols[2].data[i];
        w_sum += cols[3].data[i];
    }
    for i in half_a..half_b {
        x_sum -= cols[0].data[i];
        y_sum -= cols[1].data[i];
        z_sum -= cols[2].data[i];
        w_sum -= cols[3].data[i];
    }

    let x = x_sum.pointwise_sum();
    let y = y_sum.pointwise_sum();
    let z = z_sum.pointwise_sum();
    let w = w_sum.pointwise_sum();

    SecureField::from_m31(x, y, z, w) / BaseField::from_u32_unchecked(1 << eval.domain.log_size())
}

#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use num_traits::One;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use crate::core::backend::simd::column::BaseColumn;
    use crate::core::backend::simd::SimdBackend;
    use crate::core::backend::{Column, CpuBackend};
    use crate::core::fields::m31::BaseField;
    use crate::core::fields::qm31::SecureField;
    use crate::core::fields::secure_column::SecureColumnByCoords;
    use crate::core::fri::FriOps;
    use crate::core::poly::circle::{CanonicCoset, CirclePoly, PolyOps, SecureEvaluation};
    use crate::core::poly::line::{LineDomain, LineEvaluation};
    use crate::core::poly::BitReversedOrder;
    use crate::qm31;

    // SIMD fold_line must agree with the CPU backend on the same input.
    #[test]
    fn test_fold_line() {
        const LOG_SIZE: u32 = 7;
        let mut rng = SmallRng::seed_from_u64(0);
        let values = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec();
        let alpha = qm31!(1, 3, 5, 7);
        let domain = LineDomain::new(CanonicCoset::new(LOG_SIZE + 1).half_coset());
        let cpu_fold = CpuBackend::fold_line(
            &LineEvaluation::new(domain, values.iter().copied().collect()),
            alpha,
            &CpuBackend::precompute_twiddles(domain.coset()),
        );

        let avx_fold = SimdBackend::fold_line(
            &LineEvaluation::new(domain, values.iter().copied().collect()),
            alpha,
            &SimdBackend::precompute_twiddles(domain.coset()),
        );

        assert_eq!(cpu_fold.values.to_vec(), avx_fold.values.to_vec());
    }

    // SIMD fold_circle_into_line must agree with the CPU backend on the same input.
    #[test]
    fn test_fold_circle_into_line() {
        const LOG_SIZE: u32 = 7;
        let values: Vec<SecureField> = (0..(1 << LOG_SIZE))
            .map(|i| qm31!(4 * i, 4 * i + 1, 4 * i + 2, 4 * i + 3))
            .collect();
        let alpha = qm31!(1, 3, 5, 7);
        let circle_domain = CanonicCoset::new(LOG_SIZE).circle_domain();
        let line_domain = LineDomain::new(circle_domain.half_coset);
        let mut cpu_fold = LineEvaluation::new(
            line_domain,
            SecureColumnByCoords::zeros(1 << (LOG_SIZE - 1)),
        );
        CpuBackend::fold_circle_into_line(
            &mut cpu_fold,
            &SecureEvaluation::new(circle_domain, values.iter().copied().collect()),
            alpha,
            &CpuBackend::precompute_twiddles(line_domain.coset()),
        );

        let mut simd_fold = LineEvaluation::new(
            line_domain,
            SecureColumnByCoords::zeros(1 << (LOG_SIZE - 1)),
        );
        SimdBackend::fold_circle_into_line(
            &mut simd_fold,
            &SecureEvaluation::new(circle_domain, values.iter().copied().collect()),
            alpha,
            &SimdBackend::precompute_twiddles(line_domain.coset()),
        );

        assert_eq!(cpu_fold.values.to_vec(), simd_fold.values.to_vec());
    }

    // SIMD decompose must agree with the CPU backend, both in lambda and in g.
    #[test]
    fn decomposition_test() {
        const DOMAIN_LOG_SIZE: u32 = 5;
        const DOMAIN_LOG_HALF_SIZE: u32 = DOMAIN_LOG_SIZE - 1;
        let s = CanonicCoset::new(DOMAIN_LOG_SIZE);
        let domain = s.circle_domain();
        let mut coeffs = BaseColumn::zeros(1 << DOMAIN_LOG_SIZE);
        // Polynomial is out of FFT space.
        coeffs.as_mut_slice()[1 << DOMAIN_LOG_HALF_SIZE] = BaseField::one();
        let poly = CirclePoly::<SimdBackend>::new(coeffs);
        let values = poly.evaluate(domain);
        let avx_column = SecureColumnByCoords::<SimdBackend> {
            columns: [
                values.values.clone(),
                values.values.clone(),
                values.values.clone(),
                values.values.clone(),
            ],
        };
        let avx_eval = SecureEvaluation::new(domain, avx_column.clone());
        let cpu_eval =
            SecureEvaluation::<CpuBackend, BitReversedOrder>::new(domain, avx_eval.to_cpu());
        let (cpu_g, cpu_lambda) = CpuBackend::decompose(&cpu_eval);
        let (avx_g, avx_lambda) = SimdBackend::decompose(&avx_eval);

        assert_eq!(avx_lambda, cpu_lambda);
        for i in 0..1 << DOMAIN_LOG_SIZE {
            assert_eq!(avx_g.values.at(i), cpu_g.values.at(i));
        }
    }
}
diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/grind.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/grind.rs
new file mode 100644
index 0000000..36721dc
--- /dev/null
+++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/grind.rs
@@ -0,0 +1,95 @@
use std::simd::cmp::SimdPartialOrd;
use std::simd::num::SimdUint;
use std::simd::u32x16;

use bytemuck::cast_slice;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::blake2s::compress16;
use super::SimdBackend;
use crate::core::backend::simd::m31::N_LANES;
use crate::core::channel::Blake2sChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::channel::{Channel, Poseidon252Channel, PoseidonBLSChannel};
use crate::core::proof_of_work::GrindOps;

// Note: GRIND_LOW_BITS is a cap on how much extra time we need to wait for all threads to finish.
const GRIND_LOW_BITS: u32 = 20;
const GRIND_HI_BITS: u32 = 64 - GRIND_LOW_BITS;

impl GrindOps<Blake2sChannel> for SimdBackend {
    fn grind(channel: &Blake2sChannel, pow_bits: u32) -> u64 {
        // TODO(spapini): support more than 32 bits.
        assert!(pow_bits <= 32, "pow_bits > 32 is not supported");
        let digest = channel.digest();
        let digest: &[u32] = cast_slice(&digest.0[..]);

        // Scan windows of 2^GRIND_LOW_BITS nonces, keyed by their high bits.
        #[cfg(not(feature = "parallel"))]
        let res = (0..=(1 << GRIND_HI_BITS))
            .find_map(|hi| grind_blake(digest, hi, pow_bits))
            .expect("Grind failed to find a solution.");

        #[cfg(feature = "parallel")]
        let res = (0..=(1 << GRIND_HI_BITS))
            .into_par_iter()
            .find_map_any(|hi| grind_blake(digest, hi, pow_bits))
            .expect("Grind failed to find a solution.");

        res
    }
}

/// Scans the `2^GRIND_LOW_BITS` nonces whose high bits are `hi`, 16 candidates per
/// blake2s compression, returning the first nonce whose digest word 0 has at least
/// `pow_bits` trailing zero bits.
fn grind_blake(digest: &[u32], hi: u64, pow_bits: u32) -> Option<u64> {
    let zero: u32x16 = u32x16::default();
    let pow_bits = u32x16::splat(pow_bits);

    let state: [u32x16; 8] = std::array::from_fn(|i| u32x16::splat(digest[i]));

    // attempt[0] holds the low 32 bits of the 16 candidate nonces (lanes offset 0..15),
    // attempt[1] the high bits shared by the whole window.
    let mut attempt = [zero; 16];
    attempt[0] = u32x16::splat((hi << GRIND_LOW_BITS) as u32);
    attempt[0] += u32x16::from(std::array::from_fn(|i| i as u32));
    attempt[1] = u32x16::splat((hi >> (32 - GRIND_LOW_BITS)) as u32);
    for low in (0..(1 << GRIND_LOW_BITS)).step_by(N_LANES) {
        let res = compress16(state, attempt, zero, zero, zero, zero);
        let success_mask = res[0].trailing_zeros().simd_ge(pow_bits);
        if success_mask.any() {
            let i = success_mask.to_array().iter().position(|&x| x).unwrap();
            return Some((hi << GRIND_LOW_BITS) + low as u64 + i as u64);
        }
        attempt[0] += u32x16::splat(N_LANES as u32);
    }
    None
}

// TODO(spapini): This is a naive implementation. Optimize it.
#[cfg(not(target_arch = "wasm32"))]
impl GrindOps<Poseidon252Channel> for SimdBackend {
    fn grind(channel: &Poseidon252Channel, pow_bits: u32) -> u64 {
        let mut nonce = 0;
        loop {
            let mut channel = channel.clone();
            channel.mix_nonce(nonce);
            if channel.trailing_zeros() >= pow_bits {
                return nonce;
            }
            nonce += 1;
        }
    }
}

// TODO(spapini): This is a naive implementation. Optimize it.
#[cfg(not(target_arch = "wasm32"))]
impl GrindOps<PoseidonBLSChannel> for SimdBackend {
    fn grind(channel: &PoseidonBLSChannel, pow_bits: u32) -> u64 {
        let mut nonce = 0;
        loop {
            let mut channel = channel.clone();
            channel.mix_nonce(nonce);
            if channel.trailing_zeros() >= pow_bits {
                return nonce;
            }
            nonce += 1;
        }
    }
}
diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/gkr.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/gkr.rs
new file mode 100644
index 0000000..017948d
--- /dev/null
+++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/gkr.rs
@@ -0,0 +1,684 @@
use std::iter::zip;

use num_traits::Zero;

use crate::core::backend::cpu::lookups::gkr::gen_eq_evals as cpu_gen_eq_evals;
use crate::core::backend::simd::column::SecureColumn;
use crate::core::backend::simd::m31::{LOG_N_LANES, N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::gkr_prover::{
    correct_sum_as_poly_in_first_variable, EqEvals, GkrMultivariatePolyOracle, GkrOps, Layer,
};
use crate::core::lookups::mle::Mle;
use crate::core::lookups::sumcheck::MultivariatePolyOracle;
use crate::core::lookups::utils::{Fraction, Reciprocal, UnivariatePoly};

impl GkrOps for SimdBackend {
    #[allow(clippy::uninit_vec)]
    fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle<Self, SecureField> {
        if y.len() < LOG_N_LANES as usize {
            return Mle::new(cpu_gen_eq_evals(y, v).into_iter().collect());
        }

        // Start DP with CPU backend to avoid dealing with instances smaller than a SIMD vector.
        let (y_last_chunk, y_rem) = y.split_last_chunk::<{ LOG_N_LANES as usize }>().unwrap();
        let initial = SecureColumn::from_iter(cpu_gen_eq_evals(y_last_chunk, v));
        assert_eq!(initial.len(), N_LANES);

        let packed_len = 1 << y_rem.len();
        let mut data = initial.data;

        data.reserve(packed_len - data.len());
        // Safety: the DP loop below initializes every entry before it is read —
        // iteration i reads only the first 2^i entries and writes the next 2^i.
        unsafe { data.set_len(packed_len) };

        for (i, &y_j) in y_rem.iter().rev().enumerate() {
            let packed_y_j = PackedSecureField::broadcast(y_j);

            let (lhs_evals, rhs_evals) = data.split_at_mut(1 << i);

            for (lhs, rhs) in zip(lhs_evals, rhs_evals) {
                // Equivalent to:
                // `rhs = eq(1, y_j) * lhs`,
                // `lhs = eq(0, y_j) * lhs`
                *rhs = *lhs * packed_y_j;
                *lhs -= *rhs;
            }
        }

        let length = packed_len * N_LANES;
        Mle::new(SecureColumn { data, length })
    }

    fn next_layer(layer: &Layer<Self>) -> Layer<Self> {
        // Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector.
        if layer.n_variables() as u32 <= LOG_N_LANES {
            return into_simd_layer(layer.to_cpu().next_layer().unwrap());
        }

        match layer {
            Layer::GrandProduct(col) => next_grand_product_layer(col),
            Layer::LogUpGeneric {
                numerators,
                denominators,
            } => next_logup_generic_layer(numerators, denominators),
            Layer::LogUpMultiplicities {
                numerators,
                denominators,
            } => next_logup_multiplicities_layer(numerators, denominators),
            Layer::LogUpSingles { denominators } => next_logup_singles_layer(denominators),
        }
    }

    fn sum_as_poly_in_first_variable(
        h: &GkrMultivariatePolyOracle<'_, Self>,
        claim: SecureField,
    ) -> UnivariatePoly<SecureField> {
        let n_variables = h.n_variables();
        let n_terms = 1 << n_variables.saturating_sub(1);
        let eq_evals = h.eq_evals.as_ref();
        // Vector used to generate evaluations of `eq(x, y)` for `x` in the boolean hypercube.
        let y = eq_evals.y();

        // Offload to CPU backend to avoid dealing with instances smaller than a SIMD vector.
        if n_terms < N_LANES {
            return h.to_cpu().sum_as_poly_in_first_variable(claim);
        }

        let n_packed_terms = n_terms / N_LANES;
        let packed_lambda = PackedSecureField::broadcast(h.lambda);

        // Evaluate the sum at 0 and 2; the evaluation at 1 is recovered from `claim`.
        let (mut eval_at_0, mut eval_at_2) = match &h.input_layer {
            Layer::GrandProduct(col) => eval_grand_product_sum(eq_evals, col, n_packed_terms),
            Layer::LogUpGeneric {
                numerators,
                denominators,
            } => eval_logup_generic_sum(
                eq_evals,
                numerators,
                denominators,
                n_packed_terms,
                packed_lambda,
            ),
            Layer::LogUpMultiplicities {
                numerators,
                denominators,
            } => eval_logup_multiplicities_sum(
                eq_evals,
                numerators,
                denominators,
                n_packed_terms,
                packed_lambda,
            ),
            Layer::LogUpSingles { denominators } => {
                eval_logup_singles_sum(eq_evals, denominators, n_packed_terms, packed_lambda)
            }
        };

        eval_at_0 *= h.eq_fixed_var_correction;
        eval_at_2 *= h.eq_fixed_var_correction;
        correct_sum_as_poly_in_first_variable(eval_at_0, eval_at_2, claim, y, n_variables)
    }
}

/// Generates the next GKR layer for Grand Product.
///
/// Assumption: `len(layer) > N_LANES`.
fn next_grand_product_layer(layer: &Mle<SimdBackend, SecureField>) -> Layer<SimdBackend> {
    assert!(layer.len() > N_LANES);
    let next_layer_len = layer.len() / 2;

    // Multiply adjacent (even, odd) pairs, two packed vectors at a time.
    let data = layer
        .data
        .array_chunks()
        .map(|&[a, b]| {
            let (evens, odds) = a.deinterleave(b);
            evens * odds
        })
        .collect();

    Layer::GrandProduct(Mle::new(SecureColumn {
        data,
        length: next_layer_len,
    }))
}

/// Generates the next GKR layer for LogUp.
///
/// Assumption: `len(denominators) > N_LANES`.
+fn next_logup_generic_layer( + numerators: &Mle, + denominators: &Mle, +) -> Layer { + assert!(denominators.len() > N_LANES); + assert_eq!(numerators.len(), denominators.len()); + + let next_layer_len = denominators.len() / 2; + let next_layer_packed_len = next_layer_len / N_LANES; + + let mut next_numerators = Vec::with_capacity(next_layer_packed_len); + let mut next_denominators = Vec::with_capacity(next_layer_packed_len); + + for i in 0..next_layer_packed_len { + let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]); + let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); + + let Fraction { + numerator, + denominator, + } = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd); + + next_numerators.push(numerator); + next_denominators.push(denominator); + } + + let next_numerators = SecureColumn { + data: next_numerators, + length: next_layer_len, + }; + + let next_denominators = SecureColumn { + data: next_denominators, + length: next_layer_len, + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +/// Generates the next GKR layer for LogUp. +/// +/// Assumption: `len(denominators) > N_LANES`. +// TODO(andrew): Code duplication of `next_logup_generic_layer`. Consider unifying these. 
+fn next_logup_multiplicities_layer( + numerators: &Mle, + denominators: &Mle, +) -> Layer { + assert!(denominators.len() > N_LANES); + assert_eq!(numerators.len(), denominators.len()); + + let next_layer_len = denominators.len() / 2; + let next_layer_packed_len = next_layer_len / N_LANES; + + let mut next_numerators = Vec::with_capacity(next_layer_packed_len); + let mut next_denominators = Vec::with_capacity(next_layer_packed_len); + + for i in 0..next_layer_packed_len { + let (n_even, n_odd) = numerators.data[i * 2].deinterleave(numerators.data[i * 2 + 1]); + let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); + + let Fraction { + numerator, + denominator, + } = Fraction::new(n_even, d_even) + Fraction::new(n_odd, d_odd); + + next_numerators.push(numerator); + next_denominators.push(denominator); + } + + let next_numerators = SecureColumn { + data: next_numerators, + length: next_layer_len, + }; + + let next_denominators = SecureColumn { + data: next_denominators, + length: next_layer_len, + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +/// Generates the next GKR layer for LogUp. +/// +/// Assumption: `len(denominators) > N_LANES`. 
+fn next_logup_singles_layer(denominators: &Mle) -> Layer { + assert!(denominators.len() > N_LANES); + + let next_layer_len = denominators.len() / 2; + let next_layer_packed_len = next_layer_len / N_LANES; + + let mut next_numerators = Vec::with_capacity(next_layer_packed_len); + let mut next_denominators = Vec::with_capacity(next_layer_packed_len); + + for i in 0..next_layer_packed_len { + let (d_even, d_odd) = denominators.data[i * 2].deinterleave(denominators.data[i * 2 + 1]); + + let Fraction { + numerator, + denominator, + } = Reciprocal::new(d_even) + Reciprocal::new(d_odd); + + next_numerators.push(numerator); + next_denominators.push(denominator); + } + + let next_numerators = SecureColumn { + data: next_numerators, + length: next_layer_len, + }; + + let next_denominators = SecureColumn { + data: next_denominators, + length: next_layer_len, + }; + + Layer::LogUpGeneric { + numerators: Mle::new(next_numerators), + denominators: Mle::new(next_denominators), + } +} + +/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * inp(r, t, x, 0) * inp(r, t, x, 1)` at `t=0` and `t=2`. +/// +/// Output of the form: `(eval_at_0, eval_at_2)`. +fn eval_grand_product_sum( + eq_evals: &EqEvals, + col: &Mle, + n_packed_terms: usize, +) -> (SecureField, SecureField) { + let mut packed_eval_at_0 = PackedSecureField::zero(); + let mut packed_eval_at_2 = PackedSecureField::zero(); + + for i in 0..n_packed_terms { + // Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. 
+ let (inp_at_r0iv0, inp_at_r0iv1) = col.data[i * 2].deinterleave(col.data[i * 2 + 1]); + let (inp_at_r1iv0, inp_at_r1iv1) = + col.data[(n_packed_terms + i) * 2].deinterleave(col.data[(n_packed_terms + i) * 2 + 1]); + // Note `inp(r, t, x) = eq(t, 0) * inp(r, 0, x) + eq(t, 1) * inp(r, 1, x)` + // => `inp(r, 2, x) = 2 * inp(r, 1, x) - inp(r, 0, x)` + let inp_at_r2iv0 = inp_at_r1iv0.double() - inp_at_r0iv0; + let inp_at_r2iv1 = inp_at_r1iv1.double() - inp_at_r0iv1; + + // Product polynomial `prod(x) = inp(x, 0) * inp(x, 1)` at points `(r, {0, 2}, bits(i), v)`. + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let prod_at_r2iv = inp_at_r2iv0 * inp_at_r2iv1; + let prod_at_r0iv = inp_at_r0iv0 * inp_at_r0iv1; + + let eq_eval_at_0iv = eq_evals.data[i]; + packed_eval_at_0 += eq_eval_at_0iv * prod_at_r0iv; + packed_eval_at_2 += eq_eval_at_0iv * prod_at_r2iv; + } + + ( + packed_eval_at_0.pointwise_sum(), + packed_eval_at_2.pointwise_sum(), + ) +} + +fn eval_logup_generic_sum( + eq_evals: &EqEvals, + numerators: &Mle, + denominators: &Mle, + n_packed_terms: usize, + packed_lambda: PackedSecureField, +) -> (SecureField, SecureField) { + let mut packed_eval_at_0 = PackedSecureField::zero(); + let mut packed_eval_at_2 = PackedSecureField::zero(); + + let inp_numer = &numerators.data; + let inp_denom = &denominators.data; + + for i in 0..n_packed_terms { + // Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. 
+ let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) = + inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]); + let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = + inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); + let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2] + .deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]); + let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] + .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0; + let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1; + let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; + let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; + + // Fraction addition polynomials: + // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. + // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let Fraction { + numerator: numer_at_r0iv, + denominator: denom_at_r0iv, + } = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0) + + Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1); + let Fraction { + numerator: numer_at_r2iv, + denominator: denom_at_r2iv, + } = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0) + + Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1); + + let eq_eval_at_0iv = eq_evals.data[i]; + packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); + packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); + } + + ( + packed_eval_at_0.pointwise_sum(), + packed_eval_at_2.pointwise_sum(), + ) +} + +// TODO(andrew): Code duplication of `eval_logup_generic_sum`. 
Consider unifying these. +fn eval_logup_multiplicities_sum( + eq_evals: &EqEvals, + numerators: &Mle, + denominators: &Mle, + n_packed_terms: usize, + packed_lambda: PackedSecureField, +) -> (SecureField, SecureField) { + let mut packed_eval_at_0 = PackedSecureField::zero(); + let mut packed_eval_at_2 = PackedSecureField::zero(); + + let inp_numer = &numerators.data; + let inp_denom = &denominators.data; + + for i in 0..n_packed_terms { + // Input polynomials at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let (inp_numer_at_r0iv0, inp_numer_at_r0iv1) = + inp_numer[i * 2].deinterleave(inp_numer[i * 2 + 1]); + let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = + inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); + let (inp_numer_at_r1iv0, inp_numer_at_r1iv1) = inp_numer[(n_packed_terms + i) * 2] + .deinterleave(inp_numer[(n_packed_terms + i) * 2 + 1]); + let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] + .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_numer_at_r2iv0 = inp_numer_at_r1iv0.double() - inp_numer_at_r0iv0; + let inp_numer_at_r2iv1 = inp_numer_at_r1iv1.double() - inp_numer_at_r0iv1; + let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; + let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; + + // Fraction addition polynomials: + // - `numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. + // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. 
+ let Fraction { + numerator: numer_at_r0iv, + denominator: denom_at_r0iv, + } = Fraction::new(inp_numer_at_r0iv0, inp_denom_at_r0iv0) + + Fraction::new(inp_numer_at_r0iv1, inp_denom_at_r0iv1); + let Fraction { + numerator: numer_at_r2iv, + denominator: denom_at_r2iv, + } = Fraction::new(inp_numer_at_r2iv0, inp_denom_at_r2iv0) + + Fraction::new(inp_numer_at_r2iv1, inp_denom_at_r2iv1); + + let eq_eval_at_0iv = eq_evals.data[i]; + packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); + packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); + } + + ( + packed_eval_at_0.pointwise_sum(), + packed_eval_at_2.pointwise_sum(), + ) +} + +/// Evaluates `sum_x eq(({0}^|r|, 0, x), y) * (inp_denom(r, t, x, 1) + inp_denom(r, t, x, 0) + +/// lambda * inp_denom(r, t, x, 0) * inp_denom(r, t, x, 1))` at `t=0` and `t=2`. +/// +/// Output of the form: `(eval_at_0, eval_at_2)`. +fn eval_logup_singles_sum( + eq_evals: &EqEvals, + denominators: &Mle, + n_packed_terms: usize, + packed_lambda: PackedSecureField, +) -> (SecureField, SecureField) { + let mut packed_eval_at_0 = PackedSecureField::zero(); + let mut packed_eval_at_2 = PackedSecureField::zero(); + + let inp_denom = &denominators.data; + + for i in 0..n_packed_terms { + // Input polynomial at points `(r, {0, 1, 2}, bits(i), v, {0, 1})` + // for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. 
+ let (inp_denom_at_r0iv0, inp_denom_at_r0iv1) = + inp_denom[i * 2].deinterleave(inp_denom[i * 2 + 1]); + let (inp_denom_at_r1iv0, inp_denom_at_r1iv1) = inp_denom[(n_packed_terms + i) * 2] + .deinterleave(inp_denom[(n_packed_terms + i) * 2 + 1]); + // Note `inp_denom(r, t, x) = eq(t, 0) * inp_denom(r, 0, x) + eq(t, 1) * inp_denom(r, 1, x)` + // => `inp_denom(r, 2, x) = 2 * inp_denom(r, 1, x) - inp_denom(r, 0, x)` + let inp_denom_at_r2iv0 = inp_denom_at_r1iv0.double() - inp_denom_at_r0iv0; + let inp_denom_at_r2iv1 = inp_denom_at_r1iv1.double() - inp_denom_at_r0iv1; + + // Fraction addition polynomials: + // - `numer(x) = inp_denom(x, 1) + inp_denom(x, 0)` + // - `denom(x) = inp_denom(x, 0) * inp_denom(x, 1)`. + // at points `(r, {0, 2}, bits(i), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let Fraction { + numerator: numer_at_r0iv, + denominator: denom_at_r0iv, + } = Reciprocal::new(inp_denom_at_r0iv0) + Reciprocal::new(inp_denom_at_r0iv1); + let Fraction { + numerator: numer_at_r2iv, + denominator: denom_at_r2iv, + } = Reciprocal::new(inp_denom_at_r2iv0) + Reciprocal::new(inp_denom_at_r2iv1); + + let eq_eval_at_0iv = eq_evals.data[i]; + packed_eval_at_0 += eq_eval_at_0iv * (numer_at_r0iv + packed_lambda * denom_at_r0iv); + packed_eval_at_2 += eq_eval_at_0iv * (numer_at_r2iv + packed_lambda * denom_at_r2iv); + } + + ( + packed_eval_at_0.pointwise_sum(), + packed_eval_at_2.pointwise_sum(), + ) +} + +fn into_simd_layer(cpu_layer: Layer) -> Layer { + match cpu_layer { + Layer::GrandProduct(mle) => { + Layer::GrandProduct(Mle::new(mle.into_evals().into_iter().collect())) + } + Layer::LogUpGeneric { + numerators, + denominators, + } => Layer::LogUpGeneric { + numerators: Mle::new(numerators.into_evals().into_iter().collect()), + denominators: Mle::new(denominators.into_evals().into_iter().collect()), + }, + Layer::LogUpMultiplicities { + numerators, + denominators, + } => Layer::LogUpMultiplicities { + numerators: 
Mle::new(numerators.into_evals().into_iter().collect()), + denominators: Mle::new(denominators.into_evals().into_iter().collect()), + }, + Layer::LogUpSingles { denominators } => Layer::LogUpSingles { + denominators: Mle::new(denominators.into_evals().into_iter().collect()), + }, + } +} + +#[cfg(test)] +mod tests { + use std::iter::zip; + + use num_traits::One; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Column, CpuBackend}; + use crate::core::channel::Channel; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::gkr_prover::{prove_batch, GkrOps, Layer}; + use crate::core::lookups::gkr_verifier::{partially_verify_batch, Gate, GkrArtifact, GkrError}; + use crate::core::lookups::mle::Mle; + use crate::core::lookups::utils::Fraction; + use crate::core::test_utils::test_channel; + + #[test] + fn gen_eq_evals_matches_cpu() { + let two = BaseField::from(2).into(); + let y = [7, 3, 5, 6, 1, 1, 9].map(|v| BaseField::from(v).into()); + let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two); + + let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two); + + assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu); + } + + #[test] + fn gen_eq_evals_with_small_assignment_matches_cpu() { + let two = BaseField::from(2).into(); + let y = [7, 3, 5].map(|v| BaseField::from(v).into()); + let eq_evals_cpu = CpuBackend::gen_eq_evals(&y, two); + + let eq_evals_simd = SimdBackend::gen_eq_evals(&y, two); + + assert_eq!(eq_evals_simd.to_cpu(), *eq_evals_cpu); + } + + #[test] + fn grand_product_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let values = test_channel().draw_felts(N); + let product = values.iter().product(); + let col = Mle::::new(values.into_iter().collect()); + let input_layer = Layer::GrandProduct(col.clone()); + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + 
ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::GrandProduct], &proof, &mut test_channel())?; + + assert_eq!(proof.output_claims_by_instance, [vec![product]]); + assert_eq!( + claims_to_verify_by_instance, + [vec![col.eval_at_point(&ood_point)]] + ); + Ok(()) + } + + #[test] + fn logup_with_generic_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let numerators = (0..N).map(|_| rng.gen()).collect::>(); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerators, &denominators) + .map(|(&n, &d)| Fraction::new(n, d)) + .sum::>(); + let numerators = Mle::::new(numerators.into_iter().collect()); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpGeneric { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + numerators.eval_at_point(&ood_point), + denominators.eval_at_point(&ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } + + #[test] + fn logup_with_multiplicities_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let numerators = (0..N).map(|_| rng.gen()).collect::>(); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = zip(&numerators, &denominators) + .map(|(&n, &d)| Fraction::new(n.into(), d)) + .sum::>(); + let numerators = Mle::::new(numerators.into_iter().collect()); + let 
denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpMultiplicities { + numerators: numerators.clone(), + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [ + numerators.eval_at_point(&ood_point), + denominators.eval_at_point(&ood_point) + ] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } + + #[test] + fn logup_with_singles_trace_works() -> Result<(), GkrError> { + const N: usize = 1 << 8; + let mut rng = SmallRng::seed_from_u64(0); + let denominators = (0..N).map(|_| rng.gen()).collect::>(); + let sum = denominators + .iter() + .map(|&d| Fraction::new(SecureField::one(), d)) + .sum::>(); + let denominators = Mle::::new(denominators.into_iter().collect()); + let input_layer = Layer::LogUpSingles { + denominators: denominators.clone(), + }; + let (proof, _) = prove_batch(&mut test_channel(), vec![input_layer]); + + let GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: _, + } = partially_verify_batch(vec![Gate::LogUp], &proof, &mut test_channel())?; + + assert_eq!(claims_to_verify_by_instance.len(), 1); + assert_eq!(proof.output_claims_by_instance.len(), 1); + assert_eq!( + claims_to_verify_by_instance[0], + [SecureField::one(), denominators.eval_at_point(&ood_point)] + ); + assert_eq!( + proof.output_claims_by_instance[0], + [sum.numerator, sum.denominator] + ); + Ok(()) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mle.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mle.rs new file mode 
100644 index 0000000..0e2fe73 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mle.rs @@ -0,0 +1,132 @@ +use core::ops::Sub; +use std::iter::zip; +use std::ops::{Add, Mul}; + +use crate::core::backend::simd::column::SecureColumn; +use crate::core::backend::simd::m31::N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Column, CpuBackend}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::mle::{Mle, MleOps}; + +impl MleOps for SimdBackend { + fn fix_first_variable( + mle: Mle, + assignment: SecureField, + ) -> Mle { + let midpoint = mle.len() / 2; + + // Use CPU backend to avoid dealing with instances smaller than `PackedSecureField`. + if midpoint < N_LANES { + let cpu_mle = Mle::::new(mle.to_cpu()); + let cpu_res = cpu_mle.fix_first_variable(assignment); + return Mle::new(cpu_res.into_evals().into_iter().collect()); + } + + let packed_assignment = PackedSecureField::broadcast(assignment); + let packed_midpoint = midpoint / N_LANES; + let (evals_at_0x, evals_at_1x) = mle.data.split_at(packed_midpoint); + + let res = zip(evals_at_0x, evals_at_1x) + .enumerate() + // MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + .map(|(_i, (&packed_eval_at_0iv, &packed_eval_at_1iv))| { + fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv) + }) + .collect(); + + Mle::new(res) + } +} + +impl MleOps for SimdBackend { + fn fix_first_variable( + mle: Mle, + assignment: SecureField, + ) -> Mle { + let midpoint = mle.len() / 2; + + // Use CPU backend to avoid dealing with instances smaller than `PackedSecureField`. 
+ if midpoint < N_LANES { + let cpu_mle = Mle::::new(mle.to_cpu()); + let cpu_res = cpu_mle.fix_first_variable(assignment); + return Mle::new(cpu_res.into_evals().into_iter().collect()); + } + + let packed_midpoint = midpoint / N_LANES; + let packed_assignment = PackedSecureField::broadcast(assignment); + let mut packed_evals = mle.into_evals().data; + + for i in 0..packed_midpoint { + // MLE at points `({0, 1}, rev(bits(i)), v)` for all `v` in `{0, 1}^LOG_N_SIMD_LANES`. + let packed_eval_at_0iv = packed_evals[i]; + let packed_eval_at_1iv = packed_evals[i + packed_midpoint]; + packed_evals[i] = + fold_packed_mle_evals(packed_assignment, packed_eval_at_0iv, packed_eval_at_1iv); + } + + packed_evals.truncate(packed_midpoint); + + let length = packed_evals.len() * N_LANES; + let data = packed_evals; + + Mle::new(SecureColumn { data, length }) + } +} + +/// Computes all `eq(0, assignment_i) * eval0_i + eq(1, assignment_i) * eval1_i`. +// TODO(andrew): Remove complex trait bounds once we have something like +// AbstractField/AbstractExtensionField traits. 
+fn fold_packed_mle_evals< + PackedF: Sub + Copy, + PackedEF: Mul + Add, +>( + assignment: PackedEF, + eval0: PackedF, + eval1: PackedF, +) -> PackedEF { + assignment * (eval1 - eval0) + eval0 +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Column, CpuBackend}; + use crate::core::channel::Channel; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::lookups::mle::Mle; + use crate::core::test_utils::test_channel; + + #[test] + fn fix_first_variable_with_secure_field_mle_matches_cpu() { + const N_VARIABLES: u32 = 8; + let values = test_channel().draw_felts(1 << N_VARIABLES); + let mle_simd = Mle::::new(values.iter().copied().collect()); + let mle_cpu = Mle::::new(values); + let random_assignment = SecureField::from_u32_unchecked(7, 12, 3, 2); + let mle_fixed_cpu = mle_cpu.fix_first_variable(random_assignment); + + let mle_fixed_simd = mle_simd.fix_first_variable(random_assignment); + + assert_eq!(mle_fixed_simd.into_evals().to_cpu(), *mle_fixed_cpu) + } + + #[test] + fn fix_first_variable_with_base_field_mle_matches_cpu() { + const N_VARIABLES: u32 = 8; + let values = (0..1 << N_VARIABLES).map(BaseField::from).collect_vec(); + let mle_simd = Mle::::new(values.iter().copied().collect()); + let mle_cpu = Mle::::new(values); + let random_assignment = SecureField::from_u32_unchecked(7, 12, 3, 2); + let mle_fixed_cpu = mle_cpu.fix_first_variable(random_assignment); + + let mle_fixed_simd = mle_simd.fix_first_variable(random_assignment); + + assert_eq!(mle_fixed_simd.into_evals().to_cpu(), *mle_fixed_cpu) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mod.rs new file mode 100644 index 0000000..34395e9 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/lookups/mod.rs @@ -0,0 +1,2 @@ +mod gkr; +mod mle; diff --git 
a/Stwo_wrapper/crates/prover/src/core/backend/simd/m31.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/m31.rs new file mode 100644 index 0000000..f629162 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/m31.rs @@ -0,0 +1,666 @@ +use std::iter::Sum; +use std::mem::transmute; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::ptr; +use std::simd::cmp::SimdOrd; +use std::simd::{u32x16, Simd, Swizzle}; + +use bytemuck::{Pod, Zeroable}; +use num_traits::{One, Zero}; +use rand::distributions::{Distribution, Standard}; + +use super::qm31::PackedQM31; +use crate::core::backend::simd::utils::{InterleaveEvens, InterleaveOdds}; +use crate::core::fields::m31::{pow2147483645, BaseField, M31, P}; +use crate::core::fields::qm31::QM31; +use crate::core::fields::FieldExpOps; + +pub const LOG_N_LANES: u32 = 4; + +pub const N_LANES: usize = 1 << LOG_N_LANES; + +pub const MODULUS: Simd = Simd::from_array([P; N_LANES]); + +pub type PackedBaseField = PackedM31; + +/// Holds a vector of unreduced [`M31`] elements in the range `[0, P]`. +/// +/// Implemented with [`std::simd`] to support multiple targets (avx512, neon, wasm etc.). +// TODO: Remove `pub` visibility +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct PackedM31(Simd); + +impl PackedM31 { + /// Constructs a new instance with all vector elements set to `value`. + pub fn broadcast(M31(value): M31) -> Self { + Self(Simd::splat(value)) + } + + pub fn from_array(values: [M31; N_LANES]) -> PackedM31 { + Self(Simd::from_array(values.map(|M31(v)| v))) + } + + pub fn to_array(self) -> [M31; N_LANES] { + self.reduce().0.to_array().map(M31) + } + + /// Reduces each element of the vector to the range `[0, P)`. + fn reduce(self) -> PackedM31 { + Self(Simd::simd_min(self.0, self.0 - MODULUS)) + } + + /// Interleaves two vectors. 
+ pub fn interleave(self, other: Self) -> (Self, Self) { + let (a, b) = self.0.interleave(other.0); + (Self(a), Self(b)) + } + + /// Deinterleaves two vectors. + pub fn deinterleave(self, other: Self) -> (Self, Self) { + let (a, b) = self.0.deinterleave(other.0); + (Self(a), Self(b)) + } + + /// Reverses the order of the elements in the vector. + pub fn reverse(self) -> Self { + Self(self.0.reverse()) + } + + /// Sums all the elements in the vector. + pub fn pointwise_sum(self) -> M31 { + self.to_array().into_iter().sum() + } + + /// Doubles each element in the vector. + pub fn double(self) -> Self { + // TODO: Make more optimal. + self + self + } + + pub fn into_simd(self) -> Simd { + self.0 + } + + /// # Safety + /// + /// Vector elements must be in the range `[0, P]`. + pub unsafe fn from_simd_unchecked(v: Simd) -> Self { + Self(v) + } + + /// # Safety + /// + /// Behavior is undefined if the pointer does not have the same alignment as + /// [`PackedM31`]. The loaded `u32` values must be in the range `[0, P]`. + pub unsafe fn load(mem_addr: *const u32) -> Self { + Self(ptr::read(mem_addr as *const u32x16)) + } + + /// # Safety + /// + /// Behavior is undefined if the pointer does not have the same alignment as + /// [`PackedM31`]. + pub unsafe fn store(self, dst: *mut u32) { + ptr::write(dst as *mut u32x16, self.0) + } +} + +impl Add for PackedM31 { + type Output = Self; + + #[inline(always)] + fn add(self, rhs: Self) -> Self::Output { + // Add word by word. Each word is in the range [0, 2P]. + let c = self.0 + rhs.0; + // Apply min(c, c-P) to each word. + // When c in [P,2P], then c-P in [0,P] which is always less than [P,2P]. + // When c in [0,P-1], then c-P in [2^32-P,2^32-1] which is always greater than [0,P-1]. 
+ Self(Simd::simd_min(c, c - MODULUS)) + } +} + +impl AddAssign for PackedM31 { + #[inline(always)] + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl AddAssign for PackedM31 { + #[inline(always)] + fn add_assign(&mut self, rhs: M31) { + *self = *self + PackedM31::broadcast(rhs); + } +} + +impl Mul for PackedM31 { + type Output = Self; + + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + // TODO: Come up with a better approach than `cfg`ing on target_feature. + // TODO: Ensure all these branches get tested in the CI. + cfg_if::cfg_if! { + if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + _mul_neon(self, rhs) + } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { + _mul_wasm(self, rhs) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + _mul_avx512(self, rhs) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { + _mul_avx2(self, rhs) + } else { + _mul_simd(self, rhs) + } + } + } +} + +impl Mul for PackedM31 { + type Output = Self; + + #[inline(always)] + fn mul(self, rhs: M31) -> Self::Output { + self * PackedM31::broadcast(rhs) + } +} + +impl Add for PackedM31 { + type Output = PackedM31; + + #[inline(always)] + fn add(self, rhs: M31) -> Self::Output { + PackedM31::broadcast(rhs) + self + } +} + +impl Add for PackedM31 { + type Output = PackedQM31; + + #[inline(always)] + fn add(self, rhs: QM31) -> Self::Output { + PackedQM31::broadcast(rhs) + self + } +} + +impl Mul for PackedM31 { + type Output = PackedQM31; + + #[inline(always)] + fn mul(self, rhs: QM31) -> Self::Output { + PackedQM31::broadcast(rhs) * self + } +} + +impl MulAssign for PackedM31 { + #[inline(always)] + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl Neg for PackedM31 { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self(MODULUS - self.0) + } +} + +impl Sub for PackedM31 { + type Output = Self; + + 
#[inline(always)] + fn sub(self, rhs: Self) -> Self::Output { + // Subtract word by word. Each word is in the range [-P, P]. + let c = self.0 - rhs.0; + // Apply min(c, c+P) to each word. + // When c in [0,P], then c+P in [P,2P] which is always greater than [0,P]. + // When c in [2^32-P,2^32-1], then c+P in [0,P-1] which is always less than + // [2^32-P,2^32-1]. + Self(Simd::simd_min(c + MODULUS, c)) + } +} + +impl SubAssign for PackedM31 { + #[inline(always)] + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl Zero for PackedM31 { + fn zero() -> Self { + Self(Simd::from_array([0; N_LANES])) + } + + fn is_zero(&self) -> bool { + self.to_array().iter().all(M31::is_zero) + } +} + +impl One for PackedM31 { + fn one() -> Self { + Self(Simd::::from_array([1; N_LANES])) + } +} + +impl FieldExpOps for PackedM31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + pow2147483645(*self) + } +} + +unsafe impl Pod for PackedM31 {} + +unsafe impl Zeroable for PackedM31 { + fn zeroed() -> Self { + unsafe { core::mem::zeroed() } + } +} + +impl From<[BaseField; N_LANES]> for PackedM31 { + fn from(v: [BaseField; N_LANES]) -> Self { + Self::from_array(v) + } +} + +impl From for PackedM31 { + fn from(v: BaseField) -> Self { + Self::broadcast(v) + } +} + +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> PackedM31 { + PackedM31::from_array(rng.gen()) + } +} + +impl Sum for PackedM31 { + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), Add::add) + } +} + +/// Returns `a * b`. 
+#[cfg(target_arch = "aarch64")] +pub(crate) fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 { + use core::arch::aarch64::{int32x2_t, vqdmull_s32}; + use std::simd::u32x4; + + let [a0, a1, a2, a3, a4, a5, a6, a7]: [int32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [int32x2_t; 8] = unsafe { transmute(b) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vqdmull_s32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vqdmull_s32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vqdmull_s32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vqdmull_s32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vqdmull_s32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vqdmull_s32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vqdmull_s32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vqdmull_s32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi +} + +/// Returns `a * b`. +/// +/// `b_double` should be in the range `[0, 2P]`. 
+#[cfg(target_arch = "aarch64")] +pub(crate) fn _mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 { + use core::arch::aarch64::{uint32x2_t, vmull_u32}; + use std::simd::u32x4; + + let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi +} + +/// Returns `a * b`. +#[cfg(target_arch = "wasm32")] +pub(crate) fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 { + _mul_doubled_wasm(a, b.0 + b.0) +} + +/// Returns `a * b`. +/// +/// `b_double` should be in the range `[0, 2P]`. 
+#[cfg(target_arch = "wasm32")] +pub(crate) fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 { + use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; + use std::simd::u32x4; + + let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; + let [b_double0, b_double1, b_double2, b_double3]: [v128; 4] = unsafe { transmute(b_double) }; + + let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_double0)) }; + let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_double0)) }; + let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_double1)) }; + let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_double1)) }; + let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_double2)) }; + let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_double2)) }; + let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_double3)) }; + let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_double3)) }; + + let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); + let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); + let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); + let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); + c0_even >>= 1; + c1_even >>= 1; + c2_even >>= 1; + c3_even >>= 1; + + let even: PackedM31 = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; + let odd: PackedM31 = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; + + even + odd +} + +/// Returns `a * b`. +#[cfg(target_arch = "x86_64")] +pub(crate) fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 { + _mul_doubled_avx512(a, b.0 + b.0) +} + +/// Returns `a * b`. +/// +/// `b_double` should be in the range `[0, 2P]`. 
+#[cfg(target_arch = "x86_64")] +pub(crate) fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 { + use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; + + let a: __m512i = unsafe { transmute(a) }; + let b_double: __m512i = unsafe { transmute(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = a; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { _mm512_srli_epi64(a, 32) }; + + let b_dbl_e = b_double; + let b_dbl_o = unsafe { _mm512_srli_epi64(b_double, 32) }; + + // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. + let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; + let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) +} + +/// Returns `a * b`. +#[cfg(target_arch = "x86_64")] +pub(crate) fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 { + _mul_doubled_avx2(a, b.0 + b.0) +} + +/// Returns `a * b`. +/// +/// `b_double` should be in the range `[0, 2P]`. 
+#[cfg(target_arch = "x86_64")] +pub(crate) fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 { + use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; + + let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) }; + let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(b_double) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a0_e = a0; + let a1_e = a1; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; + let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; + + let b0_dbl_e = b0_dbl; + let b1_dbl_e = b1_dbl; + let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; + let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; + + // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. + let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; + let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; + let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; + let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; + + let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) }; + let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) }; + + // The result of a multiplication holds a*b in as 64-bits. 
+ // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = InterleaveEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = InterleaveOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + PackedM31(prod_lo) + PackedM31(prod_hi) +} + +/// Returns `a * b`. +/// +/// Should only be used in the absence of a platform specific implementation. +pub(crate) fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 { + _mul_doubled_simd(a, b.0 + b.0) +} + +/// Returns `a * b`. +/// +/// Should only be used in the absence of a platform specific implementation. +/// +/// `b_double` should be in the range `[0, 2P]`. +pub(crate) fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 { + const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = unsafe { transmute::<_, Simd>(a.0) & MASK_EVENS }; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { transmute::<_, Simd>(a) >> 32 }; + + let b_dbl_e = unsafe { transmute::<_, Simd>(b_double) & MASK_EVENS }; + let b_dbl_o = unsafe { transmute::<_, Simd>(b_double) >> 32 }; + + // To compute prod = a * b start by multiplying + // a_e/o by b_dbl_e/o. + let prod_e_dbl = a_e * b_dbl_e; + let prod_o_dbl = a_o * b_dbl_o; + + // The result of a multiplication holds a*b in as 64-bits. 
+ // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_e_dbl - |0|prod_e_h|prod_e_l|0| + // prod_o_dbl - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: + // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, + // prod_o_dbl); + // prod_ls - |prod_o_l|0|prod_e_l|0| + let mut prod_lows = InterleaveEvens::concat_swizzle( + unsafe { transmute::<_, Simd>(prod_e_dbl) }, + unsafe { transmute::<_, Simd>(prod_o_dbl) }, + ); + // Divide by 2: + prod_lows >>= 1; + // prod_ls - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: + let prod_highs = InterleaveOdds::concat_swizzle( + unsafe { transmute::<_, Simd>(prod_e_dbl) }, + unsafe { transmute::<_, Simd>(prod_o_dbl) }, + ); + + // prod_hs - |0|prod_o_h|0|prod_e_h| + PackedM31(prod_lows) + PackedM31(prod_highs) +} + +#[cfg(test)] +mod tests { + use std::array; + + use aligned::{Aligned, A64}; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::PackedM31; + use crate::core::fields::m31::BaseField; + use crate::core::fields::FieldExpOps; + + #[test] + fn addition_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedM31::from_array(lhs); + let packed_rhs = PackedM31::from_array(rhs); + + let res = packed_lhs + packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i])); + } + + #[test] + fn subtraction_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedM31::from_array(lhs); + let packed_rhs = PackedM31::from_array(rhs); + + let res = packed_lhs - packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i])); + } + + #[test] + fn multiplication_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = 
PackedM31::from_array(lhs); + let packed_rhs = PackedM31::from_array(rhs); + + let res = packed_lhs * packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i])); + } + + #[test] + fn negation_works() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen(); + let packed_values = PackedM31::from_array(values); + + let res = -packed_values; + + assert_eq!(res.to_array(), array::from_fn(|i| -values[i])); + } + + #[test] + fn load_works() { + let v: Aligned = Aligned(array::from_fn(|i| i as u32)); + + let res = unsafe { PackedM31::load(v.as_ptr()) }; + + assert_eq!(res.to_array().map(|v| v.0), *v); + } + + #[test] + fn store_works() { + let v = PackedM31::from_array(array::from_fn(BaseField::from)); + + let mut res: Aligned = Aligned([0; 16]); + unsafe { v.store(res.as_mut_ptr()) }; + + assert_eq!(*res, v.to_array().map(|v| v.0)); + } + + #[test] + fn inverse_works() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen(); + let packed_values = PackedM31::from_array(values); + + let res = packed_values.inverse(); + + assert_eq!(res.to_array(), array::from_fn(|i| values[i].inverse())); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/mod.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/mod.rs new file mode 100644 index 0000000..49c7f4a --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/mod.rs @@ -0,0 +1,41 @@ +use serde::{Deserialize, Serialize}; + +use super::{Backend, BackendForChannel}; +use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; +#[cfg(not(target_arch = "wasm32"))] +use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; +#[cfg(not(target_arch = "wasm32"))] +use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel; + +pub mod accumulation; +pub mod bit_reverse; +pub mod blake2s; +pub mod circle; +pub mod cm31; +pub mod column; +pub mod domain; +pub mod fft; +pub mod fri; +mod grind; +pub mod lookups; +pub mod m31; +#[cfg(not(target_arch 
= "wasm32"))] +pub mod poseidon252; +pub mod prefix_sum; +pub mod qm31; +pub mod quotients; +mod utils; +pub mod very_packed_m31; +#[cfg(not(target_arch = "wasm32"))] +pub mod poseidon_bls; + +#[derive(Copy, Clone, Debug, Deserialize, Serialize)] +pub struct SimdBackend; + +impl Backend for SimdBackend {} +impl BackendForChannel for SimdBackend {} +#[cfg(not(target_arch = "wasm32"))] +impl BackendForChannel for SimdBackend {} + +#[cfg(not(target_arch = "wasm32"))] +impl BackendForChannel for SimdBackend {} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon252.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon252.rs new file mode 100644 index 0000000..b001481 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon252.rs @@ -0,0 +1,36 @@ +use itertools::Itertools; +use starknet_ff::FieldElement as FieldElement252; + +use super::SimdBackend; +use crate::core::backend::{Col, Column, ColumnOps}; +use crate::core::fields::m31::BaseField; +#[cfg(not(target_arch = "wasm32"))] +use crate::core::vcs::ops::MerkleHasher; +use crate::core::vcs::ops::MerkleOps; +use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher; + +impl ColumnOps for SimdBackend { + type Column = Vec; + + fn bit_reverse_column(_column: &mut Self::Column) { + unimplemented!() + } +} + +impl MerkleOps for SimdBackend { + // TODO(ShaharS): replace with SIMD implementation. 
+ fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Vec>, + columns: &[&Col], + ) -> Vec { + (0..(1 << log_size)) + .map(|i| { + Poseidon252MerkleHasher::hash_node( + prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), + &columns.iter().map(|column| column.at(i)).collect_vec(), + ) + }) + .collect() + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon_bls.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon_bls.rs new file mode 100644 index 0000000..10c5ec9 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/poseidon_bls.rs @@ -0,0 +1,36 @@ +use itertools::Itertools; +use ark_bls12_381::Fr as BlsFr; + +use super::SimdBackend; +use crate::core::backend::{Col, Column, ColumnOps}; +use crate::core::fields::m31::BaseField; +#[cfg(not(target_arch = "wasm32"))] +use crate::core::vcs::ops::MerkleHasher; +use crate::core::vcs::ops::MerkleOps; +use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher; + +impl ColumnOps for SimdBackend { + type Column = Vec; + + fn bit_reverse_column(_column: &mut Self::Column) { + unimplemented!() + } +} + +impl MerkleOps for SimdBackend { + // TODO(ShaharS): replace with SIMD implementation. 
+ fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Vec>, + columns: &[&Col], + ) -> Vec { + (0..(1 << log_size)) + .map(|i| { + PoseidonBLSMerkleHasher::hash_node( + prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])), + &columns.iter().map(|column| column.at(i)).collect_vec(), + ) + }) + .collect() + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/prefix_sum.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/prefix_sum.rs new file mode 100644 index 0000000..652b484 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/prefix_sum.rs @@ -0,0 +1,188 @@ +use std::iter::zip; +use std::ops::{AddAssign, Sub}; + +use itertools::{izip, Itertools}; +use num_traits::Zero; + +use crate::core::backend::simd::m31::{PackedBaseField, N_LANES}; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Col, Column}; +use crate::core::fields::m31::BaseField; +use crate::core::utils::{ + bit_reverse, circle_domain_order_to_coset_order, coset_order_to_circle_domain_order, +}; + +/// Performs a inclusive prefix sum on values in `Coset` order when provided +/// with evaluations in bit-reversed `CircleDomain` order. +/// +/// Based on parallel Blelloch prefix sum: +/// +pub fn inclusive_prefix_sum( + bit_rev_circle_domain_evals: Col, +) -> Col { + if bit_rev_circle_domain_evals.len() < N_LANES * 4 { + return inclusive_prefix_sum_slow(bit_rev_circle_domain_evals); + } + + let mut res = bit_rev_circle_domain_evals; + let packed_len = res.data.len(); + let (l_half, r_half) = res.data.split_at_mut(packed_len / 2); + + // Up Sweep + // ======== + // Handle the first two up sweep rounds manually. + // Required due different ordering of `CircleDomain` and `Coset`. + // Evaluations are provided in bit-reversed `CircleDomain` order. 
+ for ([l0, l1], [r0, r1]) in izip!(l_half.array_chunks_mut(), r_half.array_chunks_mut().rev()) { + let (mut half_coset0_lo, half_coset1_hi_rev) = l0.deinterleave(*l1); + let half_coset1_hi = half_coset1_hi_rev.reverse(); + let (mut half_coset0_hi, half_coset1_lo_rev) = r0.deinterleave(*r1); + let half_coset1_lo = half_coset1_lo_rev.reverse(); + up_sweep_val(&mut half_coset0_lo, half_coset1_lo); + up_sweep_val(&mut half_coset0_hi, half_coset1_hi); + *l0 = half_coset0_lo; + *l1 = half_coset0_hi; + *r0 = half_coset1_lo; + *r1 = half_coset1_hi; + } + let half_coset0_sums: &mut [PackedBaseField] = l_half; + for i in 0..half_coset0_sums.len() / 2 { + let lo_index = i * 2; + let hi_index = half_coset0_sums.len() - 1 - i * 2; + let hi = half_coset0_sums[hi_index]; + up_sweep_val(&mut half_coset0_sums[lo_index], hi) + } + // Handle remaining up sweep rounds. + let mut chunk_size = half_coset0_sums.len() / 2; + while chunk_size > 1 { + let (lows, highs) = half_coset0_sums.split_at_mut(chunk_size); + zip(lows.array_chunks_mut(), highs.array_chunks()) + .for_each(|([lo, _], [hi, _])| up_sweep_val(lo, *hi)); + chunk_size /= 2; + } + // Up sweep the last SIMD vector. + let mut first_vec = half_coset0_sums.first().unwrap().to_array(); + let mut chunk_size = first_vec.len() / 2; + while chunk_size > 0 { + let (lows, highs) = first_vec.split_at_mut(chunk_size); + zip(lows, highs).for_each(|(lo, hi)| up_sweep_val(lo, *hi)); + chunk_size /= 2; + } + + // Down Sweep + // ========== + // Down sweep the last SIMD vector. + let mut chunk_size = 1; + while chunk_size < first_vec.len() { + let (lows, highs) = first_vec.split_at_mut(chunk_size); + zip(lows, highs).for_each(|(lo, hi)| down_sweep_val(lo, hi)); + chunk_size *= 2; + } + // Re-insert the SIMD vector. + *half_coset0_sums.first_mut().unwrap() = first_vec.into(); + // Handle remaining down sweep rounds (except first two). 
+ let mut chunk_size = 2; + while chunk_size < half_coset0_sums.len() { + let (lows, highs) = half_coset0_sums.split_at_mut(chunk_size); + zip(lows.array_chunks_mut(), highs.array_chunks_mut()) + .for_each(|([lo, _], [hi, _])| down_sweep_val(lo, hi)); + chunk_size *= 2; + } + // Handle last two down sweep rounds manually. + // Required due different ordering of `CircleDomain` and `Coset`. + // Evaluations must be returned in bit-reversed `CircleDomain` order. + for i in 0..half_coset0_sums.len() / 2 { + let lo_index = i * 2; + let hi_index = half_coset0_sums.len() - 1 - i * 2; + let (mut lo, mut hi) = (half_coset0_sums[lo_index], half_coset0_sums[hi_index]); + down_sweep_val(&mut lo, &mut hi); + (half_coset0_sums[lo_index], half_coset0_sums[hi_index]) = (lo, hi); + } + for ([l0, l1], [r0, r1]) in izip!(l_half.array_chunks_mut(), r_half.array_chunks_mut().rev()) { + let mut half_coset0_lo = *l0; + let mut half_coset1_lo = *r0; + down_sweep_val(&mut half_coset0_lo, &mut half_coset1_lo); + let mut half_coset0_hi = *l1; + let mut half_coset1_hi = *r1; + down_sweep_val(&mut half_coset0_hi, &mut half_coset1_hi); + (*l0, *l1) = half_coset0_lo.interleave(half_coset1_hi.reverse()); + (*r0, *r1) = half_coset0_hi.interleave(half_coset1_lo.reverse()); + } + + res +} + +fn up_sweep_val(lo: &mut F, hi: F) { + *lo += hi; +} + +fn down_sweep_val + Copy>(lo: &mut F, hi: &mut F) { + (*lo, *hi) = (*lo - *hi, *lo) +} + +fn inclusive_prefix_sum_slow( + bit_rev_circle_domain_evals: Col, +) -> Col { + // Obtain values in coset order. 
+ let mut coset_order_eval = bit_rev_circle_domain_evals.into_cpu_vec(); + bit_reverse(&mut coset_order_eval); + coset_order_eval = circle_domain_order_to_coset_order(&coset_order_eval); + let coset_order_prefix_sum = coset_order_eval + .into_iter() + .scan(BaseField::zero(), |acc, v| { + *acc += v; + Some(*acc) + }) + .collect_vec(); + let mut circle_domain_order_eval = coset_order_to_circle_domain_order(&coset_order_prefix_sum); + bit_reverse(&mut circle_domain_order_eval); + circle_domain_order_eval.into_iter().collect() +} + +#[cfg(test)] +mod tests { + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + use test_log::test; + + use super::inclusive_prefix_sum; + use crate::core::backend::simd::column::BaseColumn; + use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum_slow; + use crate::core::backend::Column; + + #[test] + fn exclusive_prefix_sum_simd_with_log_size_3_works() { + const LOG_N: u32 = 3; + let mut rng = SmallRng::seed_from_u64(0); + let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect(); + let expected = inclusive_prefix_sum_slow(evals.clone()); + + let res = inclusive_prefix_sum(evals); + + assert_eq!(res.to_cpu(), expected.to_cpu()); + } + + #[test] + fn exclusive_prefix_sum_simd_with_log_size_6_works() { + const LOG_N: u32 = 6; + let mut rng = SmallRng::seed_from_u64(0); + let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect(); + let expected = inclusive_prefix_sum_slow(evals.clone()); + + let res = inclusive_prefix_sum(evals); + + assert_eq!(res.to_cpu(), expected.to_cpu()); + } + + #[test] + fn exclusive_prefix_sum_simd_with_log_size_8_works() { + const LOG_N: u32 = 8; + let mut rng = SmallRng::seed_from_u64(0); + let evals: BaseColumn = (0..1 << LOG_N).map(|_| rng.gen()).collect(); + let expected = inclusive_prefix_sum_slow(evals.clone()); + + let res = inclusive_prefix_sum(evals); + + assert_eq!(res.to_cpu(), expected.to_cpu()); + } +} diff --git 
a/Stwo_wrapper/crates/prover/src/core/backend/simd/qm31.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/qm31.rs new file mode 100644 index 0000000..13d03ce --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/qm31.rs @@ -0,0 +1,357 @@ +use std::array; +use std::iter::Sum; +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +use bytemuck::{Pod, Zeroable}; +use num_traits::{One, Zero}; +use rand::distributions::{Distribution, Standard}; + +use super::cm31::PackedCM31; +use super::m31::{PackedM31, N_LANES}; +use crate::core::fields::qm31::QM31; +use crate::core::fields::FieldExpOps; + +pub type PackedSecureField = PackedQM31; + +/// SIMD implementation of [`QM31`]. +#[derive(Copy, Clone, Debug)] +pub struct PackedQM31(pub [PackedCM31; 2]); + +impl PackedQM31 { + /// Constructs a new instance with all vector elements set to `value`. + pub fn broadcast(value: QM31) -> Self { + Self([ + PackedCM31::broadcast(value.0), + PackedCM31::broadcast(value.1), + ]) + } + + /// Returns all `a` values such that each vector element is represented as `a + bu`. + pub fn a(&self) -> PackedCM31 { + self.0[0] + } + + /// Returns all `b` values such that each vector element is represented as `a + bu`. + pub fn b(&self) -> PackedCM31 { + self.0[1] + } + + pub fn to_array(&self) -> [QM31; N_LANES] { + let a = self.a().to_array(); + let b = self.b().to_array(); + array::from_fn(|i| QM31(a[i], b[i])) + } + + pub fn from_array(values: [QM31; N_LANES]) -> Self { + let a = values.map(|v| v.0); + let b = values.map(|v| v.1); + Self([PackedCM31::from_array(a), PackedCM31::from_array(b)]) + } + + /// Interleaves two vectors. + pub fn interleave(self, other: Self) -> (Self, Self) { + let Self([a_evens, b_evens]) = self; + let Self([a_odds, b_odds]) = other; + let (a_lhs, a_rhs) = a_evens.interleave(a_odds); + let (b_lhs, b_rhs) = b_evens.interleave(b_odds); + (Self([a_lhs, b_lhs]), Self([a_rhs, b_rhs])) + } + + /// Deinterleaves two vectors. 
+ pub fn deinterleave(self, other: Self) -> (Self, Self) { + let Self([a_lhs, b_lhs]) = self; + let Self([a_rhs, b_rhs]) = other; + let (a_evens, a_odds) = a_lhs.deinterleave(a_rhs); + let (b_evens, b_odds) = b_lhs.deinterleave(b_rhs); + (Self([a_evens, b_evens]), Self([a_odds, b_odds])) + } + + /// Sums all the elements in the vector. + pub fn pointwise_sum(self) -> QM31 { + self.to_array().into_iter().sum() + } + + /// Doubles each element in the vector. + pub fn double(self) -> Self { + let Self([a, b]) = self; + Self([a.double(), b.double()]) + } + + /// Returns vectors `a, b, c, d` such that element `i` is represented as + /// `QM31(a_i, b_i, c_i, d_i)`. + pub fn into_packed_m31s(self) -> [PackedM31; 4] { + let Self([PackedCM31([a, b]), PackedCM31([c, d])]) = self; + [a, b, c, d] + } + + /// Creates an instance from vectors `a, b, c, d` such that element `i` + /// is represented as `QM31(a_i, b_i, c_i, d_i)`. + pub fn from_packed_m31s([a, b, c, d]: [PackedM31; 4]) -> Self { + Self([PackedCM31([a, b]), PackedCM31([c, d])]) + } +} + +impl Add for PackedQM31 { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self([self.a() + rhs.a(), self.b() + rhs.b()]) + } +} + +impl Sub for PackedQM31 { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self([self.a() - rhs.a(), self.b() - rhs.b()]) + } +} + +impl Mul for PackedQM31 { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + // Compute using Karatsuba. + // (a + ub) * (c + ud) = + // (ac + (2+i)bd) + (ad + bc)u = + // ac + 2bd + ibd + (ad + bc)u. + let ac = self.a() * rhs.a(); + let bd = self.b() * rhs.b(); + let bd_times_1_plus_i = PackedCM31([bd.a() - bd.b(), bd.a() + bd.b()]); + // Computes ac + bd. + let ac_p_bd = ac + bd; + // Computes ad + bc. 
+ let ad_p_bc = (self.a() + self.b()) * (rhs.a() + rhs.b()) - ac_p_bd; + // ac + 2bd + ibd = + // ac + bd + bd + ibd + let l = PackedCM31([ + ac_p_bd.a() + bd_times_1_plus_i.a(), + ac_p_bd.b() + bd_times_1_plus_i.b(), + ]); + Self([l, ad_p_bc]) + } +} + +impl Zero for PackedQM31 { + fn zero() -> Self { + Self([PackedCM31::zero(), PackedCM31::zero()]) + } + + fn is_zero(&self) -> bool { + self.a().is_zero() && self.b().is_zero() + } +} + +impl One for PackedQM31 { + fn one() -> Self { + Self([PackedCM31::one(), PackedCM31::zero()]) + } +} + +impl AddAssign for PackedQM31 { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } +} + +impl MulAssign for PackedQM31 { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } +} + +impl FieldExpOps for PackedQM31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + // (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2). + let b2 = self.b().square(); + let ib2 = PackedCM31([-b2.b(), b2.a()]); + let denom = self.a().square() - (b2 + b2 + ib2); + let denom_inverse = denom.inverse(); + Self([self.a() * denom_inverse, -self.b() * denom_inverse]) + } +} + +impl Add for PackedQM31 { + type Output = Self; + + fn add(self, rhs: PackedM31) -> Self::Output { + Self([self.a() + rhs, self.b()]) + } +} + +impl Mul for PackedQM31 { + type Output = Self; + + fn mul(self, rhs: PackedM31) -> Self::Output { + let Self([a, b]) = self; + Self([a * rhs, b * rhs]) + } +} + +impl Mul for PackedQM31 { + type Output = Self; + + fn mul(self, rhs: PackedCM31) -> Self::Output { + let Self([a, b]) = self; + Self([a * rhs, b * rhs]) + } +} + +impl Sub for PackedQM31 { + type Output = Self; + + fn sub(self, rhs: PackedM31) -> Self::Output { + let Self([a, b]) = self; + Self([a - rhs, b]) + } +} + +impl Add for PackedQM31 { + type Output = Self; + + fn add(self, rhs: QM31) -> Self::Output { + self + PackedQM31::broadcast(rhs) + } +} + +impl Sub for PackedQM31 { + type Output = Self; + + fn sub(self, rhs: QM31) -> 
Self::Output { + self - PackedQM31::broadcast(rhs) + } +} + +impl Mul for PackedQM31 { + type Output = Self; + + fn mul(self, rhs: QM31) -> Self::Output { + self * PackedQM31::broadcast(rhs) + } +} + +impl SubAssign for PackedQM31 { + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +unsafe impl Pod for PackedQM31 {} + +unsafe impl Zeroable for PackedQM31 { + fn zeroed() -> Self { + unsafe { core::mem::zeroed() } + } +} + +impl Sum for PackedQM31 { + fn sum(mut iter: I) -> Self + where + I: Iterator, + { + let first = iter.next().unwrap_or_else(Self::zero); + iter.fold(first, |a, b| a + b) + } +} + +impl<'a> Sum<&'a Self> for PackedQM31 { + fn sum(iter: I) -> Self + where + I: Iterator, + { + iter.copied().sum() + } +} + +impl Neg for PackedQM31 { + type Output = Self; + + fn neg(self) -> Self::Output { + let Self([a, b]) = self; + Self([-a, -b]) + } +} + +impl Distribution for Standard { + fn sample(&self, rng: &mut R) -> PackedQM31 { + PackedQM31::from_array(rng.gen()) + } +} + +impl From for PackedQM31 { + fn from(value: PackedM31) -> Self { + PackedQM31::from_packed_m31s([ + value, + PackedM31::zero(), + PackedM31::zero(), + PackedM31::zero(), + ]) + } +} + +impl From for PackedQM31 { + fn from(value: QM31) -> Self { + PackedQM31::broadcast(value) + } +} + +#[cfg(test)] +mod tests { + use std::array; + + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::core::backend::simd::qm31::PackedQM31; + + #[test] + fn addition_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedQM31::from_array(lhs); + let packed_rhs = PackedQM31::from_array(rhs); + + let res = packed_lhs + packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] + rhs[i])); + } + + #[test] + fn subtraction_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedQM31::from_array(lhs); + let packed_rhs = 
PackedQM31::from_array(rhs); + + let res = packed_lhs - packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] - rhs[i])); + } + + #[test] + fn multiplication_works() { + let mut rng = SmallRng::seed_from_u64(0); + let lhs = rng.gen(); + let rhs = rng.gen(); + let packed_lhs = PackedQM31::from_array(lhs); + let packed_rhs = PackedQM31::from_array(rhs); + + let res = packed_lhs * packed_rhs; + + assert_eq!(res.to_array(), array::from_fn(|i| lhs[i] * rhs[i])); + } + + #[test] + fn negation_works() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen(); + let packed_values = PackedQM31::from_array(values); + + let res = -packed_values; + + assert_eq!(res.to_array(), values.map(|v| -v)); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/quotients.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/quotients.rs new file mode 100644 index 0000000..3cb664a --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/quotients.rs @@ -0,0 +1,314 @@ +use itertools::{izip, zip_eq, Itertools}; +use num_traits::Zero; +use tracing::{span, Level}; + +use super::cm31::PackedCM31; +use super::column::CM31Column; +use super::domain::CircleDomainBitRevIterator; +use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; +use super::qm31::PackedSecureField; +use super::SimdBackend; +use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs}; +use crate::core::backend::Column; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; +use crate::core::fields::FieldExpOps; +use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; +use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::bit_reverse; + +pub struct QuotientConstants { + pub line_coeffs: Vec>, + pub 
batch_random_coeffs: Vec, + pub denominator_inverses: Vec, +} + +impl QuotientOps for SimdBackend { + fn accumulate_quotients( + domain: CircleDomain, + columns: &[&CircleEvaluation], + random_coeff: SecureField, + sample_batches: &[ColumnSampleBatch], + log_blowup_factor: u32, + ) -> SecureEvaluation { + // Split the domain into a subdomain and a shift coset. + // TODO(spapini): Move to the caller when Columns support slices. + let (subdomain, mut subdomain_shifts) = domain.split(log_blowup_factor); + + // Bit reverse the shifts. + // Since we traverse the domain in bit-reversed order, we need bit-reverse the shifts. + // To see why, consider the index of a point in the natural order of the domain + // (least to most): + // b0 b1 b2 b3 b4 b5 + // b0 adds G, b1 adds 2G, etc.. (b5 is special and flips the sign of the point). + // Splitting the domain to 4 parts yields: + // subdomain: b2 b3 b4 b5, shifts: b0 b1. + // b2 b3 b4 b5 is indeed a circle domain, with a bigger jump. + // Traversing the domain in bit-reversed order, after we finish with b5, b4, b3, b2, + // we need to change b1 and then b0. This is the bit reverse of the shift b0 b1. + bit_reverse(&mut subdomain_shifts); + + let (span, mut extended_eval, subeval_polys) = accumulate_quotients_on_subdomain( + subdomain, + sample_batches, + random_coeff, + columns, + domain, + ); + + // Extend the evaluation to the full domain. + // TODO(spapini): Try to optimize out all these copies. + for (ci, &c) in subdomain_shifts.iter().enumerate() { + let subdomain = subdomain.shift(c); + + let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset); + #[allow(clippy::needless_range_loop)] + for i in 0..SECURE_EXTENSION_DEGREE { + // Sanity check. 
+ let eval = subeval_polys[i].evaluate_with_twiddles(subdomain, &twiddles); + extended_eval.columns[i].data[(ci * eval.data.len())..((ci + 1) * eval.data.len())] + .copy_from_slice(&eval.data); + } + } + span.exit(); + + SecureEvaluation::new(domain, extended_eval) + } +} + +fn accumulate_quotients_on_subdomain( + subdomain: CircleDomain, + sample_batches: &[ColumnSampleBatch], + random_coeff: SecureField, + columns: &[&CircleEvaluation], + domain: CircleDomain, +) -> ( + span::EnteredSpan, + SecureColumnByCoords, + [crate::core::poly::circle::CirclePoly; 4], +) { + assert!(subdomain.log_size() >= LOG_N_LANES + 2); + let mut values = + unsafe { SecureColumnByCoords::::uninitialized(subdomain.size()) }; + let quotient_constants = quotient_constants(sample_batches, random_coeff, subdomain); + + let span = span!(Level::INFO, "Quotient accumulation").entered(); + for (quad_row, points) in CircleDomainBitRevIterator::new(subdomain) + .array_chunks::<4>() + .enumerate() + { + // TODO(spapini): Use optimized domain iteration. + let (y01, _) = points[0].y.deinterleave(points[1].y); + let (y23, _) = points[2].y.deinterleave(points[3].y); + let (spaced_ys, _) = y01.deinterleave(y23); + let row_accumulator = accumulate_row_quotients( + sample_batches, + columns, + "ient_constants, + quad_row, + spaced_ys, + ); + #[allow(clippy::needless_range_loop)] + for i in 0..4 { + unsafe { values.set_packed((quad_row << 2) + i, row_accumulator[i]) }; + } + } + span.exit(); + let span = span!(Level::INFO, "Quotient extension").entered(); + + // Extend the evaluation to the full domain. 
+ let extended_eval = + unsafe { SecureColumnByCoords::::uninitialized(domain.size()) }; + + let mut i = 0; + let values = values.columns; + let twiddles = SimdBackend::precompute_twiddles(subdomain.half_coset); + let subeval_polys = values.map(|c| { + i += 1; + CircleEvaluation::::new(subdomain, c) + .interpolate_with_twiddles(&twiddles) + }); + (span, extended_eval, subeval_polys) +} + +/// Accumulates the quotients for 4 * N_LANES rows at a time. +/// spaced_ys - y values for N_LANES points in the domain, in jumps of 4. +pub fn accumulate_row_quotients( + sample_batches: &[ColumnSampleBatch], + columns: &[&CircleEvaluation], + quotient_constants: &QuotientConstants, + quad_row: usize, + spaced_ys: PackedBaseField, +) -> [PackedSecureField; 4] { + let mut row_accumulator = [PackedSecureField::zero(); 4]; + for (sample_batch, line_coeffs, batch_coeff, denominator_inverses) in izip!( + sample_batches, + "ient_constants.line_coeffs, + "ient_constants.batch_random_coeffs, + "ient_constants.denominator_inverses + ) { + let mut numerator = [PackedSecureField::zero(); 4]; + for ((column_index, _), (a, b, c)) in zip_eq(&sample_batch.columns_and_values, line_coeffs) + { + let column = &columns[*column_index]; + let cvalues: [_; 4] = std::array::from_fn(|i| { + PackedSecureField::broadcast(*c) * column.data[(quad_row << 2) + i] + }); + + // The numerator is the line equation: + // c * value - a * point.y - b; + // Note that a, b, c were already multilpied by random_coeff^i. + // See [column_line_coeffs()] for more details. + // This is why we only add here. + // 4 consecutive point in the domain in bit reversed order are: + // P, -P, P + H, -P + H. + // H being the half point (-1,0). The y values for these are + // P.y, -P.y, -P.y, P.y. + // We use this fact to save multiplications. + // spaced_ys are the y value in jumps of 4: + // P0.y, P1.y, P2.y, ... + let spaced_ay = PackedSecureField::broadcast(*a) * spaced_ys; + // t0:t1 = a*P0.y, -a*P0.y, a*P1.y, -a*P1.y, ... 
+ let (t0, t1) = spaced_ay.interleave(-spaced_ay); + // t2:t3:t4:t5 = a*P0.y, -a*P0.y, -a*P0.y, a*P0.y, a*P1.y, -a*P1.y, ... + let (t2, t3) = t0.interleave(-t0); + let (t4, t5) = t1.interleave(-t1); + let ay = [t2, t3, t4, t5]; + for i in 0..4 { + numerator[i] += cvalues[i] - ay[i] - PackedSecureField::broadcast(*b); + } + } + + for i in 0..4 { + row_accumulator[i] = row_accumulator[i] * PackedSecureField::broadcast(*batch_coeff) + + numerator[i] * denominator_inverses.data[(quad_row << 2) + i]; + } + } + row_accumulator +} + +fn denominator_inverses( + sample_batches: &[ColumnSampleBatch], + domain: CircleDomain, +) -> Vec { + // We want a P to be on a line that passes through a point Pr + uPi in QM31^2, and its conjugate + // Pr - uPi. Thus, Pr - P is parallel to Pi. Or, (Pr - P).x * Pi.y - (Pr - P).y * Pi.x = 0. + let flat_denominators: CM31Column = sample_batches + .iter() + .flat_map(|sample_batch| { + // Extract Pr, Pi. + let prx = PackedCM31::broadcast(sample_batch.point.x.0); + let pry = PackedCM31::broadcast(sample_batch.point.y.0); + let pix = PackedCM31::broadcast(sample_batch.point.x.1); + let piy = PackedCM31::broadcast(sample_batch.point.y.1); + + // Line equation through pr +-u pi. 
+ // (p-pr)* + CircleDomainBitRevIterator::new(domain) + .map(|points| (prx - points.x) * piy - (pry - points.y) * pix) + .collect_vec() + }) + .collect(); + + let mut flat_denominator_inverses = + unsafe { CM31Column::uninitialized(flat_denominators.len()) }; + FieldExpOps::batch_inverse( + &flat_denominators.data, + &mut flat_denominator_inverses.data[..], + ); + + flat_denominator_inverses + .data + .chunks(domain.size() / N_LANES) + .map(|denominator_inverses| denominator_inverses.iter().copied().collect()) + .collect() +} + +fn quotient_constants( + sample_batches: &[ColumnSampleBatch], + random_coeff: SecureField, + domain: CircleDomain, +) -> QuotientConstants { + let _span = span!(Level::INFO, "Quotient constants").entered(); + let line_coeffs = column_line_coeffs(sample_batches, random_coeff); + let batch_random_coeffs = batch_random_coeffs(sample_batches, random_coeff); + let denominator_inverses = denominator_inverses(sample_batches, domain); + QuotientConstants { + line_coeffs, + batch_random_coeffs, + denominator_inverses, + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use crate::core::backend::simd::column::BaseColumn; + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Column, CpuBackend}; + use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; + use crate::core::fields::m31::BaseField; + use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; + use crate::core::poly::BitReversedOrder; + use crate::qm31; + + #[test] + fn test_accumulate_quotients() { + const LOG_SIZE: u32 = 8; + const LOG_BLOWUP_FACTOR: u32 = 1; + let small_domain = CanonicCoset::new(LOG_SIZE).circle_domain(); + let domain = CanonicCoset::new(LOG_SIZE + LOG_BLOWUP_FACTOR).circle_domain(); + let e0: BaseColumn = (0..small_domain.size()).map(BaseField::from).collect(); + let e1: BaseColumn = (0..small_domain.size()) + .map(|i| BaseField::from(2 * i)) + .collect(); + 
let polys = vec![ + CircleEvaluation::::new(small_domain, e0) + .interpolate(), + CircleEvaluation::::new(small_domain, e1) + .interpolate(), + ]; + let columns = vec![polys[0].evaluate(domain), polys[1].evaluate(domain)]; + let random_coeff = qm31!(1, 2, 3, 4); + let a = polys[0].eval_at_point(SECURE_FIELD_CIRCLE_GEN); + let b = polys[1].eval_at_point(SECURE_FIELD_CIRCLE_GEN); + let samples = vec![ColumnSampleBatch { + point: SECURE_FIELD_CIRCLE_GEN, + columns_and_values: vec![(0, a), (1, b)], + }]; + let cpu_columns = columns + .iter() + .map(|c| { + CircleEvaluation::::new( + c.domain, + c.values.to_cpu(), + ) + }) + .collect::>(); + let cpu_result = CpuBackend::accumulate_quotients( + domain, + &cpu_columns.iter().collect_vec(), + random_coeff, + &samples, + LOG_BLOWUP_FACTOR, + ) + .values + .to_vec(); + + let res = SimdBackend::accumulate_quotients( + domain, + &columns.iter().collect_vec(), + random_coeff, + &samples, + LOG_BLOWUP_FACTOR, + ) + .values + .to_vec(); + + assert_eq!(res, cpu_result); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/utils.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/utils.rs new file mode 100644 index 0000000..87dfd22 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/utils.rs @@ -0,0 +1,52 @@ +use std::simd::Swizzle; + +/// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. +pub struct InterleaveEvens; + +impl Swizzle for InterleaveEvens { + const INDEX: [usize; N] = parity_interleave(false); +} + +/// Used with [`Swizzle::concat_swizzle`] to interleave the odd values of two vectors. 
+pub struct InterleaveOdds; + +impl Swizzle for InterleaveOdds { + const INDEX: [usize; N] = parity_interleave(true); +} + +const fn parity_interleave(odd: bool) -> [usize; N] { + let mut res = [0; N]; + let mut i = 0; + while i < N { + res[i] = (i % 2) * N + (i / 2) * 2 + if odd { 1 } else { 0 }; + i += 1; + } + res +} + +#[cfg(test)] +mod tests { + use std::simd::{u32x4, Swizzle}; + + use super::{InterleaveEvens, InterleaveOdds}; + + #[test] + fn interleave_evens() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); + + let res = InterleaveEvens::concat_swizzle(lo, hi); + + assert_eq!(res, u32x4::from_array([0, 4, 2, 6])); + } + + #[test] + fn interleave_odds() { + let lo = u32x4::from_array([0, 1, 2, 3]); + let hi = u32x4::from_array([4, 5, 6, 7]); + + let res = InterleaveOdds::concat_swizzle(lo, hi); + + assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/backend/simd/very_packed_m31.rs b/Stwo_wrapper/crates/prover/src/core/backend/simd/very_packed_m31.rs new file mode 100644 index 0000000..2e344b8 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/backend/simd/very_packed_m31.rs @@ -0,0 +1,222 @@ +use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; + +use bytemuck::{Pod, Zeroable}; +use num_traits::{One, Zero}; + +use super::cm31::PackedCM31; +use super::m31::{PackedM31, N_LANES}; +use super::qm31::PackedQM31; +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::{pow2147483645, M31}; +use crate::core::fields::qm31::QM31; +use crate::core::fields::FieldExpOps; + +pub const LOG_N_VERY_PACKED_ELEMS: u32 = 1; +pub const N_VERY_PACKED_ELEMS: usize = 1 << LOG_N_VERY_PACKED_ELEMS; + +#[derive(Copy, Clone, Debug)] +#[repr(transparent)] +pub struct Vectorized(pub [A; N]); + +impl Vectorized { + pub fn from_fn(cb: F) -> Self + where + F: FnMut(usize) -> A, + { + Vectorized(std::array::from_fn(cb)) + } +} + +unsafe impl Zeroable for Vectorized { + fn 
zeroed() -> Self { + unsafe { core::mem::zeroed() } + } +} +unsafe impl Pod for Vectorized {} + +pub type VeryPackedM31 = Vectorized; +pub type VeryPackedCM31 = Vectorized; +pub type VeryPackedQM31 = Vectorized; +pub type VeryPackedBaseField = VeryPackedM31; +pub type VeryPackedSecureField = VeryPackedQM31; + +impl VeryPackedM31 { + pub fn broadcast(value: M31) -> Self { + Self::from_fn(|_| PackedM31::broadcast(value)) + } + + pub fn from_array(values: [M31; N_LANES * N_VERY_PACKED_ELEMS]) -> VeryPackedM31 { + Self::from_fn(|i| { + let start = i * N_LANES; + let end = start + N_LANES; + PackedM31::from_array(values[start..end].try_into().unwrap()) + }) + } + + pub fn to_array(&self) -> [M31; N_LANES * N_VERY_PACKED_ELEMS] { + // Safety: We are transmuting &[A; N_VERY_PACKED_ELEMS] into &[i32; N_LANES * + // N_VERY_PACKED_ELEMS] because we know that A contains [i32; N_LANES] and the + // memory layout is contiguous. + unsafe { + std::slice::from_raw_parts(self.0.as_ptr() as *const M31, N_LANES * N_VERY_PACKED_ELEMS) + .try_into() + .unwrap() + } + } +} + +impl VeryPackedCM31 { + pub fn broadcast(value: CM31) -> Self { + Self::from_fn(|_| PackedCM31::broadcast(value)) + } +} + +impl VeryPackedQM31 { + pub fn broadcast(value: QM31) -> Self { + Self::from_fn(|_| PackedQM31::broadcast(value)) + } + + pub fn from_very_packed_m31s([a, b, c, d]: [VeryPackedM31; 4]) -> Self { + Self::from_fn(|i| PackedQM31::from_packed_m31s([a.0[i], b.0[i], c.0[i], d.0[i]])) + } +} +impl From for VeryPackedM31 { + fn from(v: M31) -> Self { + Self::broadcast(v) + } +} + +impl From for VeryPackedQM31 { + fn from(value: VeryPackedM31) -> Self { + VeryPackedQM31::from_very_packed_m31s([ + value, + VeryPackedM31::zero(), + VeryPackedM31::zero(), + VeryPackedM31::zero(), + ]) + } +} + +impl From for VeryPackedQM31 { + fn from(value: QM31) -> Self { + VeryPackedQM31::broadcast(value) + } +} + +trait Scalar {} +impl Scalar for M31 {} +impl Scalar for CM31 {} +impl Scalar for QM31 {} +impl Scalar 
for PackedM31 {} +impl Scalar for PackedCM31 {} +impl Scalar for PackedQM31 {} + +impl + Copy, B: Copy, const N: usize> Add> for Vectorized { + type Output = Vectorized; + + fn add(self, other: Vectorized) -> Self::Output { + Vectorized::from_fn(|i| self.0[i] + other.0[i]) + } +} + +impl + Copy, B: Scalar + Copy, const N: usize> Add for Vectorized { + type Output = Vectorized; + + fn add(self, other: B) -> Self::Output { + Vectorized::from_fn(|i| self.0[i] + other) + } +} + +impl + Copy, B: Copy, const N: usize> Sub> for Vectorized { + type Output = Vectorized; + + fn sub(self, other: Vectorized) -> Self::Output { + Vectorized::from_fn(|i| self.0[i] - other.0[i]) + } +} + +impl + Copy, B: Scalar + Copy, const N: usize> Sub for Vectorized { + type Output = Vectorized; + + fn sub(self, other: B) -> Self::Output { + Vectorized::from_fn(|i| self.0[i] - other) + } +} + +impl + Copy, B: Copy, const N: usize> Mul> for Vectorized { + type Output = Vectorized; + + fn mul(self, other: Vectorized) -> Self::Output { + Vectorized::from_fn(|i| self.0[i] * other.0[i]) + } +} + +impl + Copy, B: Scalar + Copy, const N: usize> Mul for Vectorized { + type Output = Vectorized; + + fn mul(self, other: B) -> Self::Output { + Vectorized::from_fn(|i| self.0[i] * other) + } +} + +impl + Copy, B: Copy, const N: usize> AddAssign> + for Vectorized +{ + fn add_assign(&mut self, other: Vectorized) { + for i in 0..N { + self.0[i] += other.0[i]; + } + } +} + +impl + Copy, B: Scalar + Copy, const N: usize> AddAssign for Vectorized { + fn add_assign(&mut self, other: B) { + for i in 0..N { + self.0[i] += other; + } + } +} + +impl + Copy, B: Copy, const N: usize> MulAssign> + for Vectorized +{ + fn mul_assign(&mut self, other: Vectorized) { + for i in 0..N { + self.0[i] *= other.0[i]; + } + } +} + +impl Neg for Vectorized { + type Output = Vectorized; + + #[inline(always)] + fn neg(self) -> Self::Output { + Vectorized::from_fn(|i| self.0[i].neg()) + } +} + +impl Zero for Vectorized { + fn zero() -> 
Self { + Vectorized::from_fn(|_| A::zero()) + } + + fn is_zero(&self) -> bool { + self.0.iter().all(A::is_zero) + } +} + +impl One for Vectorized { + fn one() -> Self { + Vectorized::from_fn(|_| A::one()) + } +} + +impl FieldExpOps for Vectorized { + fn inverse(&self) -> Self { + Vectorized::from_fn(|i| { + assert!(!self.0[i].is_zero(), "0 has no inverse"); + pow2147483645(self.0[i]) + }) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/channel/blake2s.rs b/Stwo_wrapper/crates/prover/src/core/channel/blake2s.rs new file mode 100644 index 0000000..9861862 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/channel/blake2s.rs @@ -0,0 +1,186 @@ +use std::iter; + +use super::{Channel, ChannelTime}; +use crate::core::fields::m31::{BaseField, N_BYTES_FELT, P}; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::fields::IntoSlice; +use crate::core::vcs::blake2_hash::{Blake2sHash, Blake2sHasher}; +use crate::core::vcs::blake2s_ref::compress; + +pub const BLAKE_BYTES_PER_HASH: usize = 32; +pub const FELTS_PER_HASH: usize = 8; + +/// A channel that can be used to draw random elements from a [Blake2sHash] digest. +#[derive(Default, Clone)] +pub struct Blake2sChannel { + digest: Blake2sHash, + pub channel_time: ChannelTime, +} + +impl Blake2sChannel { + pub fn digest(&self) -> Blake2sHash { + self.digest + } + pub fn update_digest(&mut self, new_digest: Blake2sHash) { + self.digest = new_digest; + self.channel_time.inc_challenges(); + } + /// Generates a uniform random vector of BaseField elements. + fn draw_base_felts(&mut self) -> [BaseField; FELTS_PER_HASH] { + // Repeats hashing with an increasing counter until getting a good result. + // Retry probability for each round is ~ 2^(-28). + loop { + let u32s: [u32; FELTS_PER_HASH] = self + .draw_random_bytes() + .chunks_exact(N_BYTES_FELT) // 4 bytes per u32. 
+ .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) + .collect::>() + .try_into() + .unwrap(); + + // Retry if not all the u32 are in the range [0, 2P). + if u32s.iter().all(|x| *x < 2 * P) { + return u32s + .into_iter() + .map(|x| BaseField::reduce(x as u64)) + .collect::>() + .try_into() + .unwrap(); + } + } + } +} + +impl Channel for Blake2sChannel { + const BYTES_PER_HASH: usize = BLAKE_BYTES_PER_HASH; + + fn trailing_zeros(&self) -> u32 { + u128::from_le_bytes(std::array::from_fn(|i| self.digest.0[i])).trailing_zeros() + } + + fn mix_felts(&mut self, felts: &[SecureField]) { + let mut hasher = Blake2sHasher::new(); + hasher.update(self.digest.as_ref()); + hasher.update(IntoSlice::::into_slice(felts)); + + self.update_digest(hasher.finalize()); + } + + fn mix_nonce(&mut self, nonce: u64) { + let digest: [u32; 8] = unsafe { std::mem::transmute(self.digest) }; + let mut msg = [0; 16]; + msg[0] = nonce as u32; + msg[1] = (nonce >> 32) as u32; + let res = compress(std::array::from_fn(|i| digest[i]), msg, 0, 0, 0, 0); + + self.update_digest(unsafe { std::mem::transmute(res) }); + } + + fn draw_felt(&mut self) -> SecureField { + let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts(); + SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap()) + } + + fn draw_felts(&mut self, n_felts: usize) -> Vec { + let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten(); + let secure_felts = iter::from_fn(|| { + Some(SecureField::from_m31_array([ + felts.next()?, + felts.next()?, + felts.next()?, + felts.next()?, + ])) + }); + secure_felts.take(n_felts).collect() + } + + fn draw_random_bytes(&mut self) -> Vec { + let mut hash_input = self.digest.as_ref().to_vec(); + + // Pad the counter to 32 bytes. 
+ let mut padded_counter = [0; BLAKE_BYTES_PER_HASH]; + let counter_bytes = self.channel_time.n_sent.to_le_bytes(); + padded_counter[0..counter_bytes.len()].copy_from_slice(&counter_bytes); + + hash_input.extend_from_slice(&padded_counter); + + // TODO(spapini): Are we worried about this drawing hash colliding with mix_digest? + + self.channel_time.inc_sent(); + Blake2sHasher::hash(&hash_input).into() + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use crate::core::channel::blake2s::Blake2sChannel; + use crate::core::channel::Channel; + use crate::core::fields::qm31::SecureField; + use crate::m31; + + #[test] + fn test_channel_time() { + let mut channel = Blake2sChannel::default(); + + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 0); + + channel.draw_random_bytes(); + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 1); + + channel.draw_felts(9); + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 6); + } + + #[test] + fn test_draw_random_bytes() { + let mut channel = Blake2sChannel::default(); + + let first_random_bytes = channel.draw_random_bytes(); + + // Assert that next random bytes are different. + assert_ne!(first_random_bytes, channel.draw_random_bytes()); + } + + #[test] + pub fn test_draw_felt() { + let mut channel = Blake2sChannel::default(); + + let first_random_felt = channel.draw_felt(); + + // Assert that next random felt is different. + assert_ne!(first_random_felt, channel.draw_felt()); + } + + #[test] + pub fn test_draw_felts() { + let mut channel = Blake2sChannel::default(); + + let mut random_felts = channel.draw_felts(5); + random_felts.extend(channel.draw_felts(4)); + + // Assert that all the random felts are unique. 
+ assert_eq!( + random_felts.len(), + random_felts.iter().collect::>().len() + ); + } + + #[test] + pub fn test_mix_felts() { + let mut channel = Blake2sChannel::default(); + let initial_digest = channel.digest; + let felts: Vec = (0..2) + .map(|i| SecureField::from(m31!(i + 1923782))) + .collect(); + + channel.mix_felts(felts.as_slice()); + + assert_ne!(initial_digest, channel.digest); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/channel/mod.rs b/Stwo_wrapper/crates/prover/src/core/channel/mod.rs new file mode 100644 index 0000000..3d85b8d --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/channel/mod.rs @@ -0,0 +1,57 @@ +use super::fields::qm31::SecureField; +use super::vcs::ops::MerkleHasher; + +#[cfg(not(target_arch = "wasm32"))] +mod poseidon252; +#[cfg(not(target_arch = "wasm32"))] +pub use poseidon252::Poseidon252Channel; + +mod blake2s; +pub use blake2s::Blake2sChannel; + +#[cfg(not(target_arch = "wasm32"))] +mod poseidon_bls; +#[cfg(not(target_arch = "wasm32"))] +pub use poseidon_bls::PoseidonBLSChannel; + +pub const EXTENSION_FELTS_PER_HASH: usize = 2; + +#[derive(Clone, Default)] +pub struct ChannelTime { + pub n_challenges: usize, + n_sent: usize, +} + +impl ChannelTime { + fn inc_sent(&mut self) { + self.n_sent += 1; + } + + fn inc_challenges(&mut self) { + self.n_challenges += 1; + self.n_sent = 0; + } +} + +pub trait Channel: Default + Clone { + const BYTES_PER_HASH: usize; + + fn trailing_zeros(&self) -> u32; + + // Mix functions. + fn mix_felts(&mut self, felts: &[SecureField]); + fn mix_nonce(&mut self, nonce: u64); + + // Draw functions. + fn draw_felt(&mut self) -> SecureField; + /// Generates a uniform random vector of SecureField elements. + fn draw_felts(&mut self, n_felts: usize) -> Vec; + /// Returns a vector of random bytes of length `BYTES_PER_HASH`. 
+ fn draw_random_bytes(&mut self) -> Vec; +} + +pub trait MerkleChannel: Default { + type C: Channel; + type H: MerkleHasher; + fn mix_root(channel: &mut Self::C, root: ::Hash); +} diff --git a/Stwo_wrapper/crates/prover/src/core/channel/poseidon252.rs b/Stwo_wrapper/crates/prover/src/core/channel/poseidon252.rs new file mode 100644 index 0000000..195d1fc --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/channel/poseidon252.rs @@ -0,0 +1,190 @@ +use std::iter; + +use starknet_crypto::{poseidon_hash, poseidon_hash_many}; +use starknet_ff::FieldElement as FieldElement252; + +use super::{Channel, ChannelTime}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; + +pub const BYTES_PER_FELT252: usize = 31; +pub const FELTS_PER_HASH: usize = 8; + +/// A channel that can be used to draw random elements from a Poseidon252 hash. +#[derive(Clone, Default)] +pub struct Poseidon252Channel { + digest: FieldElement252, + pub channel_time: ChannelTime, +} + +impl Poseidon252Channel { + pub fn digest(&self) -> FieldElement252 { + self.digest + } + pub fn update_digest(&mut self, new_digest: FieldElement252) { + self.digest = new_digest; + self.channel_time.inc_challenges(); + } + fn draw_felt252(&mut self) -> FieldElement252 { + let res = poseidon_hash(self.digest, self.channel_time.n_sent.into()); + self.channel_time.inc_sent(); + res + } + + // TODO(spapini): Understand if we really need uniformity here. + /// Generates a close-to uniform random vector of BaseField elements. 
+ fn draw_base_felts(&mut self) -> [BaseField; 8] { + let shift = (1u64 << 31).into(); + + let mut cur = self.draw_felt252(); + let u32s: [u32; 8] = std::array::from_fn(|_| { + let next = cur.floor_div(shift); + let res = cur - next * shift; + cur = next; + res.try_into().unwrap() + }); + + u32s.into_iter() + .map(|x| BaseField::reduce(x as u64)) + .collect::>() + .try_into() + .unwrap() + } +} + +impl Channel for Poseidon252Channel { + const BYTES_PER_HASH: usize = BYTES_PER_FELT252; + + fn trailing_zeros(&self) -> u32 { + let bytes = self.digest.to_bytes_be(); + u128::from_le_bytes(std::array::from_fn(|i| bytes[i])).trailing_zeros() + } + + // TODO(spapini): Optimize. + fn mix_felts(&mut self, felts: &[SecureField]) { + let shift = (1u64 << 31).into(); + let mut res = Vec::with_capacity(felts.len() / 2 + 2); + res.push(self.digest); + for chunk in felts.chunks(2) { + res.push( + chunk + .iter() + .flat_map(|x| x.to_m31_array()) + .fold(FieldElement252::default(), |cur, y| { + cur * shift + y.0.into() + }), + ); + } + + // TODO(spapini): do we need length padding? 
+ self.update_digest(poseidon_hash_many(&res)); + } + + fn mix_nonce(&mut self, nonce: u64) { + self.update_digest(poseidon_hash(self.digest, nonce.into())); + } + + fn draw_felt(&mut self) -> SecureField { + let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts(); + SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap()) + } + + fn draw_felts(&mut self, n_felts: usize) -> Vec { + let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten(); + let secure_felts = iter::from_fn(|| { + Some(SecureField::from_m31_array([ + felts.next()?, + felts.next()?, + felts.next()?, + felts.next()?, + ])) + }); + secure_felts.take(n_felts).collect() + } + + fn draw_random_bytes(&mut self) -> Vec { + let shift = (1u64 << 8).into(); + let mut cur = self.draw_felt252(); + let bytes: [u8; 31] = std::array::from_fn(|_| { + let next = cur.floor_div(shift); + let res = cur - next * shift; + cur = next; + res.try_into().unwrap() + }); + bytes.to_vec() + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use crate::core::channel::poseidon252::Poseidon252Channel; + use crate::core::channel::Channel; + use crate::core::fields::qm31::SecureField; + use crate::m31; + + #[test] + fn test_channel_time() { + let mut channel = Poseidon252Channel::default(); + + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 0); + + channel.draw_random_bytes(); + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 1); + + channel.draw_felts(9); + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 6); + } + + #[test] + fn test_draw_random_bytes() { + let mut channel = Poseidon252Channel::default(); + + let first_random_bytes = channel.draw_random_bytes(); + + // Assert that next random bytes are different. 
+ assert_ne!(first_random_bytes, channel.draw_random_bytes()); + } + + #[test] + pub fn test_draw_felt() { + let mut channel = Poseidon252Channel::default(); + + let first_random_felt = channel.draw_felt(); + + // Assert that next random felt is different. + assert_ne!(first_random_felt, channel.draw_felt()); + } + + #[test] + pub fn test_draw_felts() { + let mut channel = Poseidon252Channel::default(); + + let mut random_felts = channel.draw_felts(5); + random_felts.extend(channel.draw_felts(4)); + + // Assert that all the random felts are unique. + assert_eq!( + random_felts.len(), + random_felts.iter().collect::>().len() + ); + } + + #[test] + pub fn test_mix_felts() { + let mut channel = Poseidon252Channel::default(); + let initial_digest = channel.digest; + let felts: Vec = (0..2) + .map(|i| SecureField::from(m31!(i + 1923782))) + .collect(); + + channel.mix_felts(felts.as_slice()); + + assert_ne!(initial_digest, channel.digest); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/channel/poseidon_bls.rs b/Stwo_wrapper/crates/prover/src/core/channel/poseidon_bls.rs new file mode 100644 index 0000000..65fdce5 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/channel/poseidon_bls.rs @@ -0,0 +1,590 @@ +use std::iter; + +use ark_bls12_381::Fr as BlsFr; +use ark_ff::{BigInteger, Field, PrimeField}; +use crypto_bigint::{Encoding, NonZero, U256}; + +use super::{Channel, ChannelTime}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; + +pub const BYTES_PER_FELT252: usize = 32; +pub const FELTS_PER_HASH: usize = 8; + +//Optimize constant to be real constants (no conversion) and merge duplicated code in VCS poseidono +fn poseidon_comp_consts(idx: usize) -> BlsFr { + match idx { + 0 => BlsFr::from_be_bytes_mod_order(&[ + 111, 0, 122, 85, 17, 86, 179, 164, 73, 228, 73, 54, 183, 192, 147, 100, 74, 14, 211, + 63, 51, 234, 204, 198, 40, 233, 66, 232, 54, 193, 168, 
117, + ]), + 1 => BlsFr::from_be_bytes_mod_order(&[ + 54, 13, 116, 112, 97, 30, 71, 61, 53, 63, 98, 143, 118, 209, 16, 243, 78, 113, 22, 47, + 49, 0, 59, 112, 87, 83, 140, 37, 150, 66, 99, 3, + ]), + 2 => BlsFr::from_be_bytes_mod_order(&[ + 75, 95, 236, 58, 160, 115, 223, 68, 1, 144, 145, 240, 7, 164, 76, 169, 150, 72, 73, + 101, 247, 3, 109, 206, 62, 157, 9, 119, 237, 205, 192, 246, + ]), + 3 => BlsFr::from_be_bytes_mod_order(&[ + 103, 207, 24, 104, 175, 99, 150, 192, 184, 76, 206, 113, 94, 83, 159, 132, 158, 6, 205, + 28, 56, 58, 197, 176, 97, 0, 199, 107, 204, 151, 58, 17, + ]), + 4 => BlsFr::from_be_bytes_mod_order(&[ + 85, 93, 180, 209, 220, 237, 129, 159, 93, 61, 231, 15, 222, 131, 241, 199, 211, 232, + 201, 137, 104, 229, 22, 162, 58, 119, 26, 92, 156, 130, 87, 170, + ]), + 5 => BlsFr::from_be_bytes_mod_order(&[ + 43, 171, 148, 215, 174, 34, 45, 19, 93, 195, 198, 197, 254, 191, 170, 49, 73, 8, 172, + 47, 18, 235, 224, 111, 189, 183, 66, 19, 191, 99, 24, 139, + ]), + 6 => BlsFr::from_be_bytes_mod_order(&[ + 102, 244, 75, 229, 41, 102, 130, 196, 250, 120, 130, 121, 157, 109, 208, 73, 182, 215, + 210, 201, 80, 204, 249, 140, 242, 229, 13, 109, 30, 187, 119, 194, + ]), + 7 => BlsFr::from_be_bytes_mod_order(&[ + 21, 12, 147, 254, 246, 82, 251, 28, 43, 240, 62, 26, 41, 170, 135, 31, 239, 119, 231, + 215, 54, 118, 108, 93, 9, 57, 217, 39, 83, 204, 93, 200, + ]), + 8 => BlsFr::from_be_bytes_mod_order(&[ + 50, 112, 102, 30, 104, 146, 139, 58, 149, 93, 85, 219, 86, 220, 87, 193, 3, 204, 10, + 96, 20, 30, 137, 78, 20, 37, 157, 206, 83, 119, 130, 178, + ]), + 9 => BlsFr::from_be_bytes_mod_order(&[ + 7, 63, 17, 111, 4, 18, 46, 37, 160, 183, 175, 228, 226, 5, 114, 153, 180, 7, 195, 112, + 242, 181, 161, 204, 206, 159, 185, 255, 195, 69, 175, 179, + ]), + 10 => BlsFr::from_be_bytes_mod_order(&[ + 64, 159, 218, 34, 85, 140, 254, 77, 61, 216, 220, 226, 79, 105, 231, 111, 140, 42, 174, + 177, 221, 15, 9, 214, 94, 101, 76, 113, 243, 42, 162, 63, + ]), + 11 => 
BlsFr::from_be_bytes_mod_order(&[ + 42, 50, 236, 92, 78, 229, 177, 131, 122, 255, 208, 156, 31, 83, 245, 253, 85, 201, 205, + 32, 97, 174, 147, 202, 142, 186, 215, 111, 199, 21, 84, 216, + ]), + 12 => BlsFr::from_be_bytes_mod_order(&[ + 88, 72, 235, 235, 89, 35, 233, 37, 85, 183, 18, 79, 255, 186, 93, 107, 213, 113, 198, + 249, 132, 25, 94, 185, 207, 211, 163, 232, 235, 85, 177, 212, + ]), + 13 => BlsFr::from_be_bytes_mod_order(&[ + 39, 3, 38, 238, 3, 157, 241, 158, 101, 30, 44, 252, 116, 6, 40, 202, 99, 77, 36, 252, + 110, 37, 89, 242, 45, 140, 203, 226, 146, 239, 238, 173, + ]), + 14 => BlsFr::from_be_bytes_mod_order(&[ + 39, 198, 100, 42, 198, 51, 188, 102, 220, 16, 15, 231, 252, 250, 84, 145, 138, 248, + 149, 188, 224, 18, 241, 130, 160, 104, 252, 55, 193, 130, 226, 116, + ]), + 15 => BlsFr::from_be_bytes_mod_order(&[ + 27, 223, 216, 176, 20, 1, 199, 10, 210, 127, 87, 57, 105, 137, 18, 157, 113, 14, 31, + 182, 171, 151, 106, 69, 156, 161, 134, 130, 226, 109, 127, 249, + ]), + 16 => BlsFr::from_be_bytes_mod_order(&[ + 73, 27, 155, 166, 152, 59, 207, 159, 5, 254, 71, 148, 173, 180, 74, 48, 135, 155, 248, + 40, 150, 98, 225, 245, 125, 144, 246, 114, 65, 78, 138, 74, + ]), + 17 => BlsFr::from_be_bytes_mod_order(&[ + 22, 42, 20, 198, 47, 154, 137, 184, 20, 185, 214, 169, 200, 77, 214, 120, 244, 246, + 251, 63, 144, 84, 211, 115, 200, 50, 216, 36, 38, 26, 53, 234, + ]), + 18 => BlsFr::from_be_bytes_mod_order(&[ + 45, 25, 62, 15, 118, 222, 88, 107, 42, 246, 247, 158, 49, 39, 254, 234, 172, 10, 31, + 199, 30, 44, 240, 192, 247, 152, 36, 102, 123, 91, 107, 236, + ]), + 19 => BlsFr::from_be_bytes_mod_order(&[ + 70, 239, 216, 169, 162, 98, 214, 216, 253, 201, 202, 92, 4, 176, 152, 47, 36, 221, 204, + 110, 152, 99, 136, 90, 106, 115, 42, 57, 6, 160, 123, 149, + ]), + 20 => BlsFr::from_be_bytes_mod_order(&[ + 80, 151, 23, 224, 194, 0, 227, 201, 45, 141, 202, 41, 115, 179, 219, 69, 240, 120, 130, + 148, 53, 26, 208, 122, 231, 92, 187, 120, 6, 147, 167, 152, + ]), + 21 => 
BlsFr::from_be_bytes_mod_order(&[ + 114, 153, 178, 132, 100, 168, 201, 79, 185, 212, 223, 97, 56, 15, 57, 192, 220, 169, + 194, 192, 20, 17, 135, 137, 226, 39, 37, 40, 32, 240, 27, 252, + ]), + 22 => BlsFr::from_be_bytes_mod_order(&[ + 4, 76, 163, 204, 74, 133, 215, 59, 129, 105, 110, 241, 16, 78, 103, 79, 79, 239, 248, + 41, 132, 153, 15, 248, 93, 11, 245, 141, 200, 164, 170, 148, + ]), + 23 => BlsFr::from_be_bytes_mod_order(&[ + 28, 186, 242, 179, 113, 218, 198, 168, 29, 4, 83, 65, 109, 62, 35, 92, 184, 217, 226, + 212, 243, 20, 244, 111, 97, 152, 120, 95, 12, 214, 185, 175, + ]), + 24 => BlsFr::from_be_bytes_mod_order(&[ + 29, 91, 39, 119, 105, 44, 32, 91, 14, 108, 73, 208, 97, 182, 181, 244, 41, 60, 74, 176, + 56, 253, 187, 220, 52, 62, 7, 97, 15, 63, 237, 229, + ]), + 25 => BlsFr::from_be_bytes_mod_order(&[ + 86, 174, 124, 122, 82, 147, 189, 194, 62, 133, 225, 105, 140, 129, 199, 127, 138, 216, + 140, 75, 51, 165, 120, 4, 55, 173, 4, 124, 110, 219, 89, 186, + ]), + 26 => BlsFr::from_be_bytes_mod_order(&[ + 46, 155, 219, 186, 61, 211, 75, 255, 170, 48, 83, 91, 221, 116, 154, 126, 6, 169, 173, + 176, 193, 230, 249, 98, 246, 14, 151, 27, 141, 115, 176, 79, + ]), + 27 => BlsFr::from_be_bytes_mod_order(&[ + 45, 225, 24, 134, 177, 128, 17, 202, 139, 213, 186, 227, 105, 105, 41, 159, 222, 64, + 251, 226, 109, 4, 123, 5, 3, 90, 19, 102, 31, 34, 65, 139, + ]), + 28 => BlsFr::from_be_bytes_mod_order(&[ + 46, 7, 222, 23, 128, 184, 167, 13, 13, 91, 74, 63, 24, 65, 220, 216, 42, 185, 57, 92, + 68, 155, 233, 71, 188, 153, 136, 132, 186, 150, 167, 33, + ]), + 29 => BlsFr::from_be_bytes_mod_order(&[ + 15, 105, 241, 133, 77, 32, 202, 12, 187, 219, 99, 219, 213, 45, 173, 22, 37, 4, 64, + 169, 157, 107, 138, 243, 130, 94, 76, 43, 183, 73, 37, 202, + ]), + 30 => BlsFr::from_be_bytes_mod_order(&[ + 93, 201, 135, 49, 142, 110, 89, 193, 175, 184, 123, 101, 93, 213, 140, 193, 210, 46, + 81, 58, 5, 131, 140, 212, 88, 93, 4, 177, 53, 185, 87, 202, + ]), + 31 => 
BlsFr::from_be_bytes_mod_order(&[ + 72, 183, 37, 117, 133, 113, 201, 223, 108, 1, 220, 99, 154, 133, 240, 114, 151, 105, + 107, 27, 182, 120, 99, 58, 41, 220, 145, 222, 149, 239, 83, 246, + ]), + 32 => BlsFr::from_be_bytes_mod_order(&[ + 94, 86, 94, 8, 192, 130, 16, 153, 37, 107, 86, 73, 14, 174, 225, 213, 115, 175, 209, + 11, 182, 209, 125, 19, 202, 78, 92, 97, 27, 42, 55, 24, + ]), + 33 => BlsFr::from_be_bytes_mod_order(&[ + 46, 177, 178, 84, 23, 254, 23, 103, 13, 19, 93, 198, 57, 251, 9, 164, 108, 229, 17, 53, + 7, 249, 109, 233, 129, 108, 5, 148, 34, 220, 112, 94, + ]), + 34 => BlsFr::from_be_bytes_mod_order(&[ + 17, 92, 208, 160, 100, 60, 251, 152, 140, 36, 203, 68, 195, 250, 180, 138, 255, 54, + 198, 97, 210, 108, 196, 45, 184, 177, 189, 244, 149, 59, 216, 44, + ]), + 35 => BlsFr::from_be_bytes_mod_order(&[ + 38, 202, 41, 63, 123, 44, 70, 45, 6, 109, 115, 120, 185, 153, 134, 139, 187, 87, 221, + 241, 78, 15, 149, 138, 222, 128, 22, 18, 49, 29, 4, 205, + ]), + 36 => BlsFr::from_be_bytes_mod_order(&[ + 65, 71, 64, 13, 142, 26, 172, 207, 49, 26, 107, 91, 118, 32, 17, 171, 62, 69, 50, 110, + 77, 75, 157, 226, 105, 146, 129, 107, 153, 197, 40, 172, + ]), + 37 => BlsFr::from_be_bytes_mod_order(&[ + 107, 13, 183, 220, 204, 75, 161, 178, 104, 246, 189, 204, 77, 55, 40, 72, 212, 167, 41, + 118, 194, 104, 234, 48, 81, 154, 47, 115, 230, 219, 77, 85, + ]), + 38 => BlsFr::from_be_bytes_mod_order(&[ + 23, 191, 27, 147, 196, 199, 224, 26, 42, 131, 10, 161, 98, 65, 44, 217, 15, 22, 11, + 249, 247, 30, 150, 127, 245, 32, 157, 20, 178, 72, 32, 202, + ]), + 39 => BlsFr::from_be_bytes_mod_order(&[ + 75, 67, 28, 217, 239, 237, 188, 148, 207, 30, 202, 111, 158, 156, 24, 57, 208, 230, + 106, 139, 255, 168, 200, 70, 76, 172, 129, 163, 157, 60, 248, 241, + ]), + 40 => BlsFr::from_be_bytes_mod_order(&[ + 53, 180, 26, 122, 196, 243, 197, 113, 162, 79, 132, 86, 54, 156, 133, 223, 224, 60, 3, + 84, 189, 140, 253, 56, 5, 200, 111, 46, 125, 194, 147, 197, + ]), + 41 => 
BlsFr::from_be_bytes_mod_order(&[ + 59, 20, 128, 8, 5, 35, 196, 57, 67, 89, 39, 153, 72, 73, 190, 169, 100, 225, 77, 59, + 235, 45, 221, 222, 114, 172, 21, 106, 244, 53, 208, 158, + ]), + 42 => BlsFr::from_be_bytes_mod_order(&[ + 44, 198, 129, 0, 49, 220, 27, 13, 73, 80, 133, 109, 201, 7, 213, 117, 8, 226, 134, 68, + 42, 45, 62, 178, 39, 22, 24, 216, 116, 177, 76, 109, + ]), + 43 => BlsFr::from_be_bytes_mod_order(&[ + 111, 65, 65, 200, 64, 28, 90, 57, 91, 166, 121, 14, 253, 113, 199, 12, 4, 175, 234, 6, + 195, 201, 40, 38, 188, 171, 221, 92, 181, 71, 125, 81, + ]), + 44 => BlsFr::from_be_bytes_mod_order(&[ + 37, 189, 187, 237, 161, 189, 232, 193, 5, 150, 24, 226, 175, 210, 239, 153, 158, 81, + 122, 169, 59, 120, 52, 29, 145, 243, 24, 192, 159, 12, 181, 102, + ]), + 45 => BlsFr::from_be_bytes_mod_order(&[ + 57, 42, 74, 135, 88, 224, 110, 232, 185, 95, 51, 194, 93, 222, 138, 192, 42, 94, 208, + 162, 123, 97, 146, 108, 198, 49, 52, 135, 7, 63, 127, 123, + ]), + 46 => BlsFr::from_be_bytes_mod_order(&[ + 39, 42, 85, 135, 138, 8, 68, 43, 154, 166, 17, 31, 77, 224, 9, 72, 94, 106, 111, 209, + 93, 184, 147, 101, 231, 187, 206, 240, 46, 181, 134, 108, + ]), + 47 => BlsFr::from_be_bytes_mod_order(&[ + 99, 30, 193, 214, 210, 141, 217, 232, 36, 238, 137, 163, 7, 48, 174, 247, 171, 70, 58, + 207, 201, 209, 132, 179, 85, 170, 5, 253, 105, 56, 234, 181, + ]), + 48 => BlsFr::from_be_bytes_mod_order(&[ + 78, 182, 253, 161, 15, 208, 251, 222, 2, 199, 68, 155, 251, 221, 195, 91, 205, 130, 37, + 231, 229, 195, 131, 58, 8, 24, 161, 0, 64, 157, 198, 242, + ]), + 49 => BlsFr::from_be_bytes_mod_order(&[ + 45, 91, 48, 139, 12, 240, 44, 223, 239, 161, 60, 78, 96, 226, 98, 57, 166, 235, 186, 1, + 22, 148, 221, 18, 155, 146, 91, 60, 91, 33, 224, 226, + ]), + 50 => BlsFr::from_be_bytes_mod_order(&[ + 22, 84, 159, 198, 175, 47, 59, 114, 221, 93, 41, 61, 114, 226, 229, 242, 68, 223, 244, + 47, 24, 180, 108, 86, 239, 56, 197, 124, 49, 22, 115, 172, + ]), + 51 => BlsFr::from_be_bytes_mod_order(&[ 
+ 66, 51, 38, 119, 255, 53, 156, 94, 141, 184, 54, 217, 245, 251, 84, 130, 46, 57, 189, + 94, 34, 52, 11, 185, 186, 151, 91, 161, 169, 43, 227, 130, + ]), + 52 => BlsFr::from_be_bytes_mod_order(&[ + 73, 215, 210, 192, 180, 73, 229, 23, 155, 197, 204, 195, 180, 76, 96, 117, 217, 132, + 155, 86, 16, 70, 95, 9, 234, 114, 93, 220, 151, 114, 58, 148, + ]), + 53 => BlsFr::from_be_bytes_mod_order(&[ + 100, 194, 15, 185, 13, 122, 0, 56, 49, 117, 124, 196, 198, 34, 111, 110, 73, 133, 252, + 158, 203, 65, 107, 159, 104, 76, 160, 53, 29, 150, 121, 4, + ]), + 54 => BlsFr::from_be_bytes_mod_order(&[ + 89, 207, 244, 13, 232, 59, 82, 180, 27, 196, 67, 215, 151, 149, 16, 215, 113, 201, 64, + 185, 117, 140, 168, 32, 254, 115, 181, 200, 213, 88, 9, 52, + ]), + 55 => BlsFr::from_be_bytes_mod_order(&[ + 83, 219, 39, 49, 115, 12, 57, 176, 78, 221, 135, 95, 227, 183, 200, 130, 128, 130, 133, + 205, 188, 98, 29, 122, 244, 248, 13, 213, 62, 187, 113, 176, + ]), + 56 => BlsFr::from_be_bytes_mod_order(&[ + 27, 16, 187, 122, 130, 175, 206, 57, 250, 105, 195, 162, 173, 82, 247, 109, 118, 57, + 130, 101, 52, 66, 3, 17, 155, 113, 38, 217, 180, 104, 96, 223, + ]), + 57 => BlsFr::from_be_bytes_mod_order(&[ + 86, 27, 96, 18, 214, 102, 191, 225, 121, 196, 221, 127, 132, 205, 209, 83, 21, 150, + 211, 170, 199, 197, 112, 12, 235, 49, 159, 145, 4, 106, 99, 201, + ]), + 58 => BlsFr::from_be_bytes_mod_order(&[ + 15, 30, 117, 5, 235, 217, 29, 47, 199, 156, 45, 247, 220, 152, 163, 190, 209, 179, 105, + 104, 186, 4, 5, 192, 144, 210, 127, 106, 0, 183, 223, 200, + ]), + 59 => BlsFr::from_be_bytes_mod_order(&[ + 47, 49, 63, 175, 13, 63, 97, 135, 83, 122, 116, 151, 163, 180, 63, 70, 121, 127, 214, + 227, 241, 142, 177, 202, 255, 69, 119, 86, 184, 25, 187, 32, + ]), + 60 => BlsFr::from_be_bytes_mod_order(&[ + 58, 92, 187, 109, 228, 80, 180, 129, 250, 60, 166, 28, 14, 209, 91, 197, 92, 173, 17, + 235, 240, 247, 206, 184, 240, 188, 62, 115, 46, 203, 38, 246, + ]), + 61 => BlsFr::from_be_bytes_mod_order(&[ + 104, 
29, 147, 65, 27, 248, 206, 99, 246, 113, 106, 239, 189, 14, 36, 80, 100, 84, 192, + 52, 142, 227, 143, 171, 235, 38, 71, 2, 113, 76, 207, 148, + ]), + 62 => BlsFr::from_be_bytes_mod_order(&[ + 81, 120, 233, 64, 245, 0, 4, 49, 38, 70, 180, 54, 114, 127, 14, 128, 167, 184, 242, + 233, 238, 31, 220, 103, 124, 72, 49, 167, 103, 39, 119, 251, + ]), + 63 => BlsFr::from_be_bytes_mod_order(&[ + 61, 171, 84, 188, 155, 239, 104, 141, 217, 32, 134, 226, 83, 180, 57, 214, 81, 186, + 166, 226, 15, 137, 43, 98, 134, 85, 39, 203, 202, 145, 89, 130, + ]), + 64 => BlsFr::from_be_bytes_mod_order(&[ + 75, 60, 231, 83, 17, 33, 143, 154, 233, 5, 248, 78, 170, 91, 43, 56, 24, 68, 139, 191, + 57, 114, 225, 170, 214, 157, 227, 33, 0, 144, 21, 208, + ]), + 65 => BlsFr::from_be_bytes_mod_order(&[ + 6, 219, 251, 66, 185, 121, 136, 77, 226, 128, 211, 22, 112, 18, 63, 116, 76, 36, 179, + 59, 65, 15, 239, 212, 54, 128, 69, 172, 242, 183, 26, 227, + ]), + 66 => BlsFr::from_be_bytes_mod_order(&[ + 6, 141, 107, 70, 8, 170, 232, 16, 198, 240, 57, 234, 25, 115, 166, 62, 184, 210, 222, + 114, 227, 210, 201, 236, 167, 252, 50, 210, 47, 24, 185, 211, + ]), + 67 => BlsFr::from_be_bytes_mod_order(&[ + 76, 92, 37, 69, 137, 169, 42, 54, 8, 74, 87, 211, 177, 217, 100, 39, 138, 204, 126, 79, + 232, 246, 159, 41, 85, 149, 79, 39, 167, 156, 235, 239, + ]), + 68 => BlsFr::from_be_bytes_mod_order(&[ + 108, 186, 197, 225, 112, 9, 132, 235, 195, 45, 161, 91, 75, 185, 104, 63, 170, 186, + 181, 95, 103, 204, 196, 247, 29, 149, 96, 179, 71, 90, 119, 235, + ]), + 69 => BlsFr::from_be_bytes_mod_order(&[ + 70, 3, 196, 3, 187, 250, 154, 23, 115, 138, 92, 98, 120, 234, 171, 28, 55, 236, 48, + 176, 115, 122, 162, 64, 159, 196, 137, 128, 105, 235, 152, 60, + ]), + 70 => BlsFr::from_be_bytes_mod_order(&[ + 104, 148, 231, 226, 43, 44, 29, 92, 112, 167, 18, 166, 52, 90, 230, 177, 146, 169, 200, + 51, 169, 35, 76, 49, 197, 106, 172, 209, 107, 194, 241, 0, + ]), + 71 => BlsFr::from_be_bytes_mod_order(&[ + 91, 226, 203, 188, 68, 
5, 58, 208, 138, 250, 77, 30, 171, 199, 243, 210, 49, 238, 167, + 153, 185, 63, 34, 110, 144, 91, 125, 77, 101, 197, 142, 187, + ]), + 72 => BlsFr::from_be_bytes_mod_order(&[ + 88, 229, 95, 40, 123, 69, 58, 152, 8, 98, 74, 140, 42, 53, 61, 82, 141, 160, 247, 231, + 19, 165, 198, 208, 215, 113, 30, 71, 6, 63, 166, 17, + ]), + 73 => BlsFr::from_be_bytes_mod_order(&[ + 54, 110, 191, 175, 163, 173, 56, 28, 14, 226, 88, 201, 184, 253, 252, 205, 184, 104, + 167, 215, 225, 241, 246, 154, 43, 93, 252, 197, 87, 37, 85, 223, + ]), + 74 => BlsFr::from_be_bytes_mod_order(&[ + 69, 118, 106, 183, 40, 150, 140, 100, 47, 144, 217, 124, 207, 85, 4, 221, 193, 5, 24, + 168, 25, 235, 188, 196, 208, 156, 63, 93, 120, 77, 103, 206, + ]), + 75 => BlsFr::from_be_bytes_mod_order(&[ + 57, 103, 143, 101, 81, 47, 30, 228, 4, 219, 48, 36, 244, 29, 63, 86, 126, 246, 109, + 137, 208, 68, 208, 34, 230, 188, 34, 158, 149, 188, 118, 177, + ]), + 76 => BlsFr::from_be_bytes_mod_order(&[ + 70, 58, 237, 29, 47, 31, 149, 94, 48, 120, 190, 91, 247, 191, 196, 111, 192, 235, 140, + 81, 85, 25, 6, 168, 134, 143, 24, 255, 174, 48, 207, 79, + ]), + 77 => BlsFr::from_be_bytes_mod_order(&[ + 33, 102, 143, 1, 106, 128, 99, 192, 213, 139, 119, 80, 163, 188, 47, 225, 207, 130, + 194, 95, 153, 220, 1, 164, 229, 52, 200, 143, 229, 61, 133, 254, + ]), + 78 => BlsFr::from_be_bytes_mod_order(&[ + 57, 208, 9, 148, 168, 165, 4, 106, 27, 199, 73, 54, 62, 152, 167, 104, 227, 77, 234, + 86, 67, 159, 225, 149, 75, 239, 66, 155, 197, 51, 22, 8, + ]), + 79 => BlsFr::from_be_bytes_mod_order(&[ + 77, 127, 93, 205, 120, 236, 233, 169, 51, 152, 77, 227, 44, 11, 72, 250, 194, 187, 169, + 31, 38, 25, 150, 184, 233, 209, 2, 23, 115, 189, 7, 204, + ]), + _ => BlsFr::ZERO, + } +} + +/// A channel that can be used to draw random elements from a PoseidonBLS hash. 
+#[derive(Clone, Default)] +pub struct PoseidonBLSChannel { + digest: BlsFr, + pub channel_time: ChannelTime, +} + +pub fn poseidon_hash_bls(x: BlsFr, y: BlsFr) -> BlsFr { + let mut state = [x, y, BlsFr::ZERO]; + poseidon_permute_comp_bls(&mut state); + state[0] + x +} + +pub fn poseidon_permute_comp_bls(state: &mut [BlsFr; 3]) { + let mut idx = 0; + mix(state); + + // Full rounds + for _ in 0..4 { + round_comp(state, idx, true); + idx += 3; + } + + // Partial rounds + for _ in 0..56 { + round_comp(state, idx, false); + idx += 1; + } + + // Full rounds + for _ in 0..4 { + round_comp(state, idx, true); + idx += 3; + } +} + +#[inline] +fn round_comp(state: &mut [BlsFr; 3], idx: usize, full: bool) { + if full { + state[0] += poseidon_comp_consts(idx); + state[1] += poseidon_comp_consts(idx + 1); + state[2] += poseidon_comp_consts(idx + 2); + // Optimize multiplication + state[0] = state[0] * state[0] * state[0] * state[0] * state[0]; + state[1] = state[1] * state[1] * state[1] * state[1] * state[1]; + state[2] = state[2] * state[2] * state[2] * state[2] * state[2]; + } else { + state[0] += poseidon_comp_consts(idx); + state[2] = state[2] * state[2] * state[2] * state[2] * state[2]; + } + mix(state); +} + +#[inline(always)] +fn mix(state: &mut [BlsFr; 3]) { + state[0] = state[0] + state[1] + state[2]; + state[1] = state[0] + state[1]; + state[2] = state[0] + state[2]; +} + +pub fn poseidon_hash_many_bls(msgs: &[BlsFr]) -> BlsFr { + let mut state = [BlsFr::ZERO, BlsFr::ZERO, BlsFr::ZERO]; + let mut iter = msgs.chunks_exact(2); + + for msg in iter.by_ref() { + state[0] += msg[0]; + state[1] += msg[1]; + poseidon_permute_comp_bls(&mut state); + } + let r = iter.remainder(); + if r.len() == 1 { + state[0] += r[0]; + } + state[r.len()] += BlsFr::ONE; + poseidon_permute_comp_bls(&mut state); + + state[0] +} + +impl PoseidonBLSChannel { + pub fn digest(&self) -> BlsFr { + self.digest + } + pub fn update_digest(&mut self, new_digest: BlsFr) { + self.digest = new_digest; + 
self.channel_time.inc_challenges(); + } + fn draw_felt252(&mut self) -> BlsFr { + let res = poseidon_hash_bls(self.digest, BlsFr::from(self.channel_time.n_sent as u64)); + self.channel_time.inc_sent(); + res + } + + // TODO(spapini): Understand if we really need uniformity here. + /// Generates a close-to uniform random vector of BaseField elements. + fn draw_base_felts(&mut self) -> [BaseField; 8] { + let shift = NonZero::new(U256::from_u64(1u64 << 31)).unwrap(); + + let mut cur = self.draw_felt252(); + let u32s: [u32; 8usize] = std::array::from_fn(|_| { + let (quotient, reminder) = + U256::from_be_slice(&cur.into_bigint().to_bytes_be()).div_rem(&shift); + cur = BlsFr::from_be_bytes_mod_order("ient.to_be_bytes()); + u32::from_str_radix(&reminder.to_string(),16).unwrap() + }); + + u32s.into_iter() + .map(|x| BaseField::reduce(x as u64)) + .collect::>() + .try_into() + .unwrap() + } +} + +impl Channel for PoseidonBLSChannel { + const BYTES_PER_HASH: usize = BYTES_PER_FELT252; + + fn trailing_zeros(&self) -> u32 { + let bytes = self.digest.into_bigint().to_bytes_be(); + u128::from_le_bytes(std::array::from_fn(|i| bytes[i])).trailing_zeros() + } + + // TODO(spapini): Optimize. + fn mix_felts(&mut self, felts: &[SecureField]) { + let shift = BlsFr::from(1u64 << 31); + let mut res = Vec::with_capacity(felts.len() / 2 + 2); + res.push(self.digest); + for chunk in felts.chunks(2) { + res.push( + chunk + .iter() + .flat_map(|x| x.to_m31_array()) + .fold(BlsFr::default(), |cur, y| { + cur * shift + BlsFr::from_be_bytes_mod_order(&y.0.to_be_bytes()) + }), + ); + } + + // TODO(spapini): do we need length padding? 
+ self.update_digest(poseidon_hash_many_bls(&res)); + } + + fn mix_nonce(&mut self, nonce: u64) { + self.update_digest(poseidon_hash_bls(self.digest, nonce.into())); + } + + fn draw_felt(&mut self) -> SecureField { + let felts: [BaseField; FELTS_PER_HASH] = self.draw_base_felts(); + SecureField::from_m31_array(felts[..SECURE_EXTENSION_DEGREE].try_into().unwrap()) + } + + fn draw_felts(&mut self, n_felts: usize) -> Vec { + let mut felts = iter::from_fn(|| Some(self.draw_base_felts())).flatten(); + let secure_felts = iter::from_fn(|| { + Some(SecureField::from_m31_array([ + felts.next()?, + felts.next()?, + felts.next()?, + felts.next()?, + ])) + }); + secure_felts.take(n_felts).collect() + } + + fn draw_random_bytes(&mut self) -> Vec { + let shift = NonZero::new(U256::from_u64(1u64 << 8)).unwrap(); + let mut cur = self.draw_felt252(); + let bytes: [u8; 31] = std::array::from_fn(|_| { + let (quotient, reminder) = + U256::from_be_slice(&cur.into_bigint().to_bytes_be()).div_rem(&shift); + cur = BlsFr::from_be_bytes_mod_order("ient.to_be_bytes()); + u8::from_str_radix(&reminder.to_string(),16).unwrap() + }); + bytes.to_vec() + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use crate::core::channel::poseidon_bls::PoseidonBLSChannel; + use crate::core::channel::Channel; + use crate::core::fields::qm31::SecureField; + use crate::m31; + + #[test] + fn test_channel_time() { + let mut channel = PoseidonBLSChannel::default(); + + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 0); + + channel.draw_random_bytes(); + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 1); + + channel.draw_felts(9); + assert_eq!(channel.channel_time.n_challenges, 0); + assert_eq!(channel.channel_time.n_sent, 6); + } + + #[test] + fn test_draw_random_bytes() { + let mut channel = PoseidonBLSChannel::default(); + + let first_random_bytes = channel.draw_random_bytes(); + + // Assert that next 
random bytes are different. + assert_ne!(first_random_bytes, channel.draw_random_bytes()); + } + + #[test] + pub fn test_draw_felt() { + let mut channel = PoseidonBLSChannel::default(); + + let first_random_felt = channel.draw_felt(); + + // Assert that next random felt is different. + assert_ne!(first_random_felt, channel.draw_felt()); + } + + #[test] + pub fn test_draw_felts() { + let mut channel = PoseidonBLSChannel::default(); + + let mut random_felts = channel.draw_felts(5); + random_felts.extend(channel.draw_felts(4)); + + // Assert that all the random felts are unique. + assert_eq!( + random_felts.len(), + random_felts.iter().collect::>().len() + ); + } + + #[test] + pub fn test_mix_felts() { + let mut channel = PoseidonBLSChannel::default(); + let initial_digest = channel.digest; + let felts: Vec = (0..2) + .map(|i| SecureField::from(m31!(i + 1923782))) + .collect(); + + channel.mix_felts(felts.as_slice()); + + assert_ne!(initial_digest, channel.digest); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/circle.rs b/Stwo_wrapper/crates/prover/src/core/circle.rs new file mode 100644 index 0000000..8804840 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/circle.rs @@ -0,0 +1,561 @@ +use std::ops::{Add, Div, Mul, Neg, Sub}; + +use num_traits::{One, Zero}; + +use super::fields::m31::{BaseField, M31}; +use super::fields::qm31::SecureField; +use super::fields::{ComplexConjugate, Field, FieldExpOps}; +use crate::core::channel::Channel; +use crate::core::fields::qm31::P4; +use crate::math::utils::egcd; + +/// A point on the complex circle. Treated as an additive group. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct CirclePoint { + pub x: F, + pub y: F, +} + +impl + FieldExpOps + Sub + Neg> CirclePoint { + pub fn zero() -> Self { + Self { + x: F::one(), + y: F::zero(), + } + } + + pub fn double(&self) -> Self { + *self + *self + } + + /// Applies the circle's x-coordinate doubling map. 
+ /// + /// # Examples + /// + /// ``` + /// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN}; + /// use stwo_prover::core::fields::m31::M31; + /// let p = M31_CIRCLE_GEN.mul(17); + /// assert_eq!(CirclePoint::double_x(p.x), (p + p).x); + /// ``` + pub fn double_x(x: F) -> F { + let sx = x.square(); + sx + sx - F::one() + } + + /// Returns the log order of a point. + /// + /// All points have an order of the form `2^k`. + /// + /// # Examples + /// + /// ``` + /// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN, M31_CIRCLE_LOG_ORDER}; + /// use stwo_prover::core::fields::m31::M31; + /// assert_eq!(M31_CIRCLE_GEN.log_order(), M31_CIRCLE_LOG_ORDER); + /// ``` + pub fn log_order(&self) -> u32 + where + F: PartialEq + Eq, + { + // we only need the x-coordinate to check order since the only point + // with x=1 is the circle's identity + let mut res = 0; + let mut cur = self.x; + while cur != F::one() { + cur = Self::double_x(cur); + res += 1; + } + res + } + + pub fn mul(&self, mut scalar: u128) -> CirclePoint { + let mut res = Self::zero(); + let mut cur = *self; + while scalar > 0 { + if scalar & 1 == 1 { + res = res + cur; + } + cur = cur.double(); + scalar >>= 1; + } + res + } + + pub fn repeated_double(&self, n: u32) -> Self { + let mut res = *self; + for _ in 0..n { + res = res.double(); + } + res + } + + pub fn conjugate(&self) -> CirclePoint { + Self { + x: self.x, + y: -self.y, + } + } + + pub fn antipode(&self) -> CirclePoint { + Self { + x: -self.x, + y: -self.y, + } + } + + pub fn into_ef>(&self) -> CirclePoint { + CirclePoint { + x: self.x.into(), + y: self.y.into(), + } + } + + pub fn mul_signed(&self, off: isize) -> CirclePoint { + if off > 0 { + self.mul(off as u128) + } else { + self.conjugate().mul(-off as u128) + } + } +} + +impl + FieldExpOps + Sub + Neg> Add + for CirclePoint +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + let x = self.x * rhs.x - self.y * rhs.y; + let y = self.x * rhs.y + self.y * 
rhs.x; + Self { x, y } + } +} + +impl + FieldExpOps + Sub + Neg> Neg + for CirclePoint +{ + type Output = Self; + + fn neg(self) -> Self::Output { + self.conjugate() + } +} + +impl + FieldExpOps + Sub + Neg> Sub + for CirclePoint +{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + self + (-rhs) + } +} + +impl ComplexConjugate for CirclePoint { + fn complex_conjugate(&self) -> Self { + Self { + x: self.x.complex_conjugate(), + y: self.y.complex_conjugate(), + } + } +} + +impl CirclePoint { + pub fn get_point(index: u128) -> Self { + assert!(index < SECURE_FIELD_CIRCLE_ORDER); + SECURE_FIELD_CIRCLE_GEN.mul(index) + } + + pub fn get_random_point(channel: &mut C) -> Self { + let t = channel.draw_felt(); + let t_square = t.square(); + + let one_plus_tsquared_inv = t_square.add(SecureField::one()).inverse(); + + let x = SecureField::one() + .add(t_square.neg()) + .mul(one_plus_tsquared_inv); + let y = t.double().mul(one_plus_tsquared_inv); + + Self { x, y } + } +} + +/// A generator for the circle group over [M31]. +/// +/// # Examples +/// +/// ``` +/// use stwo_prover::core::circle::{CirclePoint, M31_CIRCLE_GEN}; +/// use stwo_prover::core::fields::m31::M31; +/// +/// // Adding a generator to itself (2^30) times should NOT yield the identity. +/// let circle_point = M31_CIRCLE_GEN.repeated_double(30); +/// assert_ne!(circle_point, CirclePoint::zero()); +/// +/// // Shown above ord(M31_CIRCLE_GEN) > 2^30 . Group order is 2^31. +/// // Ord(M31_CIRCLE_GEN) must be a divisor of it, Hence ord(M31_CIRCLE_GEN) = 2^31. +/// // Adding the generator to itself (2^31) times should yield the identity. +/// let circle_point = M31_CIRCLE_GEN.repeated_double(31); +/// assert_eq!(circle_point, CirclePoint::zero()); +/// ``` +pub const M31_CIRCLE_GEN: CirclePoint = CirclePoint { + x: M31::from_u32_unchecked(2), + y: M31::from_u32_unchecked(1268011823), +}; + +/// Order of [M31_CIRCLE_GEN]. 
+pub const M31_CIRCLE_LOG_ORDER: u32 = 31; + +/// A generator for the circle group over [SecureField]. +pub const SECURE_FIELD_CIRCLE_GEN: CirclePoint = CirclePoint { + x: SecureField::from_u32_unchecked(1, 0, 478637715, 513582971), + y: SecureField::from_u32_unchecked(992285211, 649143431, 740191619, 1186584352), +}; + +/// Order of [SECURE_FIELD_CIRCLE_GEN]. +pub const SECURE_FIELD_CIRCLE_ORDER: u128 = P4 - 1; + +/// Integer i that represent the circle point i * CIRCLE_GEN. Treated as an +/// additive ring modulo `1 << M31_CIRCLE_LOG_ORDER`. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Ord, PartialOrd)] +pub struct CirclePointIndex(pub usize); + +impl CirclePointIndex { + pub fn zero() -> Self { + Self(0) + } + + pub fn generator() -> Self { + Self(1) + } + + pub fn reduce(self) -> Self { + Self(self.0 & ((1 << M31_CIRCLE_LOG_ORDER) - 1)) + } + + pub fn subgroup_gen(log_size: u32) -> Self { + assert!(log_size <= M31_CIRCLE_LOG_ORDER); + Self(1 << (M31_CIRCLE_LOG_ORDER - log_size)) + } + + pub fn to_point(self) -> CirclePoint { + M31_CIRCLE_GEN.mul(self.0 as u128) + } + + pub fn half(self) -> Self { + assert!(self.0 & 1 == 0); + Self(self.0 >> 1) + } + + pub fn try_div(&self, rhs: CirclePointIndex) -> Option { + // Find x s.t. x * rhs.0 = self.0 (mod CIRCLE_ORDER). 
+ let (s, _t, g) = egcd(rhs.0 as isize, 1 << M31_CIRCLE_LOG_ORDER); + if self.0 as isize % g != 0 { + return None; + } + let res = s * self.0 as isize / g; + Some(res) + } +} + +impl Add for CirclePointIndex { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self(self.0 + rhs.0).reduce() + } +} + +impl Sub for CirclePointIndex { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self(self.0 + (1 << M31_CIRCLE_LOG_ORDER) - rhs.0).reduce() + } +} + +impl Mul for CirclePointIndex { + type Output = Self; + + fn mul(self, rhs: usize) -> Self::Output { + Self(self.0.wrapping_mul(rhs)).reduce() + } +} + +impl Div for CirclePointIndex { + type Output = isize; + + fn div(self, rhs: Self) -> Self::Output { + self.try_div(rhs).unwrap() + } +} + +impl Neg for CirclePointIndex { + type Output = Self; + + fn neg(self) -> Self::Output { + Self((1 << M31_CIRCLE_LOG_ORDER) - self.0).reduce() + } +} + +/// Represents the coset initial + \. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Coset { + pub initial_index: CirclePointIndex, + pub initial: CirclePoint, + pub step_size: CirclePointIndex, + pub step: CirclePoint, + pub log_size: u32, +} + +impl Coset { + pub fn new(initial_index: CirclePointIndex, log_size: u32) -> Self { + assert!(log_size <= M31_CIRCLE_LOG_ORDER); + let step_size = CirclePointIndex::subgroup_gen(log_size); + Self { + initial_index, + initial: initial_index.to_point(), + step: step_size.to_point(), + step_size, + log_size, + } + } + + /// Creates a coset of the form . + /// For example, for n=8, we get the point indices \[0,1,2,3,4,5,6,7\]. + pub fn subgroup(log_size: u32) -> Self { + Self::new(CirclePointIndex::zero(), log_size) + } + + /// Creates a coset of the form G_2n + \. + /// For example, for n=8, we get the point indices \[1,3,5,7,9,11,13,15\]. + pub fn odds(log_size: u32) -> Self { + Self::new(CirclePointIndex::subgroup_gen(log_size + 1), log_size) + } + + /// Creates a coset of the form G_4n + . 
+ /// For example, for n=8, we get the point indices \[1,5,9,13,17,21,25,29\]. + /// Its conjugate will be \[3,7,11,15,19,23,27,31\]. + pub fn half_odds(log_size: u32) -> Self { + Self::new(CirclePointIndex::subgroup_gen(log_size + 2), log_size) + } + + /// Returns the size of the coset. + pub fn size(&self) -> usize { + 1 << self.log_size() + } + + /// Returns the log size of the coset. + pub fn log_size(&self) -> u32 { + self.log_size + } + + pub fn iter(&self) -> CosetIterator> { + CosetIterator { + cur: self.initial, + step: self.step, + remaining: self.size(), + } + } + + pub fn iter_indices(&self) -> CosetIterator { + CosetIterator { + cur: self.initial_index, + step: self.step_size, + remaining: self.size(), + } + } + + /// Returns a new coset comprising of all points in current coset doubled. + pub fn double(&self) -> Self { + assert!(self.log_size > 0); + Self { + initial_index: self.initial_index * 2, + initial: self.initial.double(), + step: self.step.double(), + step_size: self.step_size * 2, + log_size: self.log_size.saturating_sub(1), + } + } + + pub fn repeated_double(&self, n_doubles: u32) -> Self { + (0..n_doubles).fold(*self, |coset, _| coset.double()) + } + + pub fn is_doubling_of(&self, other: Self) -> bool { + self.log_size <= other.log_size + && *self == other.repeated_double(other.log_size - self.log_size) + } + + pub fn initial(&self) -> CirclePoint { + self.initial + } + + pub fn index_at(&self, index: usize) -> CirclePointIndex { + self.initial_index + self.step_size.mul(index) + } + + pub fn at(&self, index: usize) -> CirclePoint { + self.index_at(index).to_point() + } + + pub fn shift(&self, shift_size: CirclePointIndex) -> Self { + let initial_index = self.initial_index + shift_size; + Self { + initial_index, + initial: initial_index.to_point(), + ..*self + } + } + + /// Creates the conjugate coset: -initial -\. 
+ pub fn conjugate(&self) -> Self { + let initial_index = -self.initial_index; + let step_size = -self.step_size; + Self { + initial_index, + initial: initial_index.to_point(), + step_size, + step: step_size.to_point(), + log_size: self.log_size, + } + } + + pub fn find(&self, i: CirclePointIndex) -> Option { + let res = (i - self.initial_index).try_div(self.step_size)?; + Some(res.rem_euclid(self.size() as isize) as usize) + } +} + +impl IntoIterator for Coset { + type Item = CirclePoint; + type IntoIter = CosetIterator>; + + /// Iterates over the points in the coset. + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +#[derive(Clone)] +pub struct CosetIterator { + pub cur: T, + pub step: T, + pub remaining: usize, +} + +impl + Copy> Iterator for CosetIterator { + type Item = T; + + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + self.remaining -= 1; + let res = self.cur; + self.cur = self.cur + self.step; + Some(res) + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeSet; + + use num_traits::{One, Pow}; + + use super::{CirclePointIndex, Coset}; + use crate::core::channel::Blake2sChannel; + use crate::core::circle::{CirclePoint, SECURE_FIELD_CIRCLE_GEN}; + use crate::core::fields::qm31::{SecureField, P4}; + use crate::core::fields::FieldExpOps; + use crate::core::poly::circle::CanonicCoset; + + #[test] + fn test_iterator() { + let coset = Coset::new(CirclePointIndex(1), 3); + let actual_indices: Vec<_> = coset.iter_indices().collect(); + let expected_indices = vec![ + CirclePointIndex(1), + CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 1, + CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 2, + CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 3, + CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 4, + CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 5, + CirclePointIndex(1) + CirclePointIndex::subgroup_gen(3) * 6, + CirclePointIndex(1) + 
CirclePointIndex::subgroup_gen(3) * 7, + ]; + assert_eq!(actual_indices, expected_indices); + + let actual_points = coset.iter().collect::>(); + let expected_points: Vec<_> = expected_indices.iter().map(|i| i.to_point()).collect(); + assert_eq!(actual_points, expected_points); + } + + #[test] + fn test_coset_is_half_coset_with_conjugate() { + let canonic_coset = CanonicCoset::new(8); + let coset_points = BTreeSet::from_iter(canonic_coset.coset().iter()); + + let half_coset_points = BTreeSet::from_iter(canonic_coset.half_coset().iter()); + let half_coset_conjugate_points = + BTreeSet::from_iter(canonic_coset.half_coset().conjugate().iter()); + + assert!((&half_coset_points & &half_coset_conjugate_points).is_empty()); + assert_eq!( + coset_points, + &half_coset_points | &half_coset_conjugate_points + ) + } + + #[test] + pub fn test_get_random_circle_point() { + let mut channel = Blake2sChannel::default(); + + let first_random_circle_point = CirclePoint::get_random_point(&mut channel); + + // Assert that the next random circle point is different. 
+ assert_ne!( + first_random_circle_point, + CirclePoint::get_random_point(&mut channel) + ); + } + + #[test] + pub fn test_secure_field_circle_gen() { + let prime_factors = [ + (2, 33), + (3, 2), + (5, 1), + (7, 1), + (11, 1), + (31, 1), + (151, 1), + (331, 1), + (733, 1), + (1709, 1), + (368140581013, 1), + ]; + + assert_eq!( + prime_factors + .iter() + .map(|(p, e)| p.pow(*e as u32)) + .product::(), + P4 - 1 + ); + assert_eq!( + SECURE_FIELD_CIRCLE_GEN.x.square() + SECURE_FIELD_CIRCLE_GEN.y.square(), + SecureField::one() + ); + assert_eq!(SECURE_FIELD_CIRCLE_GEN.mul(P4 - 1), CirclePoint::zero()); + for (p, _) in prime_factors.iter() { + assert_ne!( + SECURE_FIELD_CIRCLE_GEN.mul((P4 - 1) / *p), + CirclePoint::zero() + ); + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/constraints.rs b/Stwo_wrapper/crates/prover/src/core/constraints.rs new file mode 100644 index 0000000..776951f --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/constraints.rs @@ -0,0 +1,251 @@ +use num_traits::One; + +use super::circle::{CirclePoint, Coset}; +use super::fields::m31::BaseField; +use super::fields::qm31::SecureField; +use super::fields::ExtensionOf; +use super::pcs::quotients::PointSample; +use crate::core::fields::ComplexConjugate; + +/// Evaluates a vanishing polynomial of the coset at a point. +pub fn coset_vanishing>(coset: Coset, mut p: CirclePoint) -> F { + // Doubling a point `log_order - 1` times and taking the x coordinate is + // essentially evaluating a polynomial in x of degree `2^(log_order - 1)`. If + // the entire `2^log_order` points of the coset are roots (i.e. yield 0), then + // this is a vanishing polynomial of these points. + + // Rotating the coset -coset.initial + step / 2 yields a canonic coset: + // `step/2 + .` + // Doubling this coset log_order - 1 times yields the coset +-G_4. + // The polynomial x vanishes on these points. + // ```text + // X + // . . 
+ // X + // ``` + p = p - coset.initial.into_ef() + coset.step_size.half().to_point().into_ef(); + //println!("p - coset.initial.into_ef() + coset.step_size.half().to_point().into_ef() = {:?} - {:?} + {:?}",p,coset.initial,coset.step_size.half().0); + let mut x = p.x; + + // The formula for the x coordinate of the double of a point. + for _ in 1..coset.log_size { + x = CirclePoint::double_x(x); + } + x +} + +/// Evaluates the polynomial that is used to exclude the excluded point at point +/// p. Note that this polynomial has a zero of multiplicity 2 at the excluded +/// point. +pub fn point_excluder>( + excluded: CirclePoint, + p: CirclePoint, +) -> F { + (p - excluded.into_ef()).x - BaseField::one() +} + +// A vanishing polynomial on 2 circle points. +pub fn pair_vanishing>( + excluded0: CirclePoint, + excluded1: CirclePoint, + p: CirclePoint, +) -> F { + // The algorithm check computes the area of the triangle formed by the + // 3 points. This is done using the determinant of: + // | p.x p.y 1 | + // | e0.x e0.y 1 | + // | e1.x e1.y 1 | + // This is a polynomial of degree 1 in p.x and p.y, and thus it is a line. + // It vanishes at e0 and e1. + (excluded0.y - excluded1.y) * p.x + + (excluded1.x - excluded0.x) * p.y + + (excluded0.x * excluded1.y - excluded0.y * excluded1.x) +} + +/// Evaluates a vanishing polynomial of the vanish_point at a point. +/// Note that this function has a pole on the antipode of the vanish_point. +pub fn point_vanishing, EF: ExtensionOf>( + vanish_point: CirclePoint, + p: CirclePoint, +) -> EF { + let h = p - vanish_point.into_ef(); + h.y / (EF::one() + h.x) +} + +/// Evaluates a point on a line between a point and its complex conjugate. +/// Relies on the fact that every polynomial F over the base field holds: +/// F(p*) == F(p)* (* being the complex conjugate). 
+pub fn complex_conjugate_line( + point: CirclePoint, + value: SecureField, + p: CirclePoint, +) -> SecureField { + // TODO(AlonH): This assertion will fail at a probability of 1 to 2^62. Use a better solution. + assert_ne!( + point.y, + point.y.complex_conjugate(), + "Cannot evaluate a line with a single point ({point:?})." + ); + value + + (value.complex_conjugate() - value) * (-point.y + p.y) + / (point.complex_conjugate().y - point.y) +} + +/// Evaluates the coefficients of a line between a point and its complex conjugate. Specifically, +/// `a, b, and c, s.t. a*x + b -c*y = 0` for (x,y) being (sample.y, sample.value) and +/// (conj(sample.y), conj(sample.value)). +/// Relies on the fact that every polynomial F over the base +/// field holds: F(p*) == F(p)* (* being the complex conjugate). +pub fn complex_conjugate_line_coeffs( + sample: &PointSample, + alpha: SecureField, +) -> (SecureField, SecureField, SecureField) { + // TODO(AlonH): This assertion will fail at a probability of 1 to 2^62. Use a better solution. 
+ assert_ne!( + sample.point.y, + sample.point.y.complex_conjugate(), + "Cannot evaluate a line with a single point ({:?}).", + sample.point + ); + let a = sample.value.complex_conjugate() - sample.value; + let c = sample.point.complex_conjugate().y - sample.point.y; + let b = sample.value * c - a * sample.point.y; + (alpha * a, alpha * b, alpha * c) +} + +#[cfg(test)] +mod tests { + use num_traits::Zero; + + use super::{coset_vanishing, point_excluder, point_vanishing}; + use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; + use crate::core::circle::{CirclePoint, CirclePointIndex, Coset}; + use crate::core::constraints::{complex_conjugate_line, pair_vanishing}; + use crate::core::fields::m31::{BaseField, M31}; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::{ComplexConjugate, FieldExpOps}; + use crate::core::poly::circle::CanonicCoset; + use crate::core::poly::NaturalOrder; + use crate::core::test_utils::secure_eval_to_base_eval; + use crate::m31; + + #[test] + fn test_coset_vanishing() { + let cosets = [ + Coset::half_odds(5), + Coset::odds(5), + Coset::new(CirclePointIndex::zero(), 5), + Coset::half_odds(5).conjugate(), + ]; + for c0 in cosets.iter() { + for el in c0.iter() { + assert_eq!(coset_vanishing(*c0, el), BaseField::zero()); + for c1 in cosets.iter() { + if c0 == c1 { + continue; + } + assert_ne!(coset_vanishing(*c1, el), BaseField::zero()); + } + } + } + } + + #[test] + fn test_point_excluder() { + let excluded = Coset::half_odds(5).at(10); + let point = (CirclePointIndex::generator() * 4).to_point(); + + let num = point_excluder(excluded, point) * point_excluder(excluded.conjugate(), point); + let denom = (point.x - excluded.x).pow(2); + + assert_eq!(num, denom); + } + + #[test] + fn test_pair_excluder() { + let excluded0 = Coset::half_odds(5).at(10); + let excluded1 = Coset::half_odds(5).at(13); + let point = (CirclePointIndex::generator() * 4).to_point(); + + assert_ne!(pair_vanishing(excluded0, excluded1, 
point), M31::zero()); + assert_eq!(pair_vanishing(excluded0, excluded1, excluded0), M31::zero()); + assert_eq!(pair_vanishing(excluded0, excluded1, excluded1), M31::zero()); + } + + #[test] + fn test_point_vanishing_success() { + let coset = Coset::odds(5); + let vanish_point = coset.at(2); + for el in coset.iter() { + if el == vanish_point { + assert_eq!(point_vanishing(vanish_point, el), BaseField::zero()); + continue; + } + if el == vanish_point.antipode() { + continue; + } + assert_ne!(point_vanishing(vanish_point, el), BaseField::zero()); + } + } + + #[test] + #[should_panic(expected = "0 has no inverse")] + fn test_point_vanishing_failure() { + let coset = Coset::half_odds(6); + let point = coset.at(4); + point_vanishing(point, point.antipode()); + } + + #[test] + fn test_complex_conjugate_symmetry() { + // Create a polynomial over a base circle domain. + let polynomial = CpuCirclePoly::new((0..1 << 7).map(|i| m31!(i)).collect()); + let oods_point = CirclePoint::get_point(9834759221); + + // Assert that the base field polynomial is complex conjugate symmetric. + assert_eq!( + polynomial.eval_at_point(oods_point.complex_conjugate()), + polynomial.eval_at_point(oods_point).complex_conjugate() + ); + } + + #[test] + fn test_point_vanishing_degree() { + // Create a polynomial over a circle domain. + let log_domain_size = 7; + let domain_size = 1 << log_domain_size; + let polynomial = CpuCirclePoly::new((0..domain_size).map(|i| m31!(i)).collect()); + + // Create a larger domain. + let log_large_domain_size = log_domain_size + 1; + let large_domain_size = 1 << log_large_domain_size; + let large_domain = CanonicCoset::new(log_large_domain_size).circle_domain(); + + // Create a vanish point that is not in the large domain. + let vanish_point = CirclePoint::get_point(97); + let vanish_point_value = polynomial.eval_at_point(vanish_point); + + // Compute the quotient polynomial. 
+ let mut quotient_polynomial_values = Vec::with_capacity(large_domain_size as usize); + for point in large_domain.iter() { + let line = complex_conjugate_line(vanish_point, vanish_point_value, point); + let mut value = polynomial.eval_at_point(point.into_ef()) - line; + value /= pair_vanishing( + vanish_point, + vanish_point.complex_conjugate(), + point.into_ef(), + ); + quotient_polynomial_values.push(value); + } + let quotient_evaluation = CpuCircleEvaluation::::new( + large_domain, + quotient_polynomial_values, + ); + let quotient_polynomial = secure_eval_to_base_eval("ient_evaluation) + .bit_reverse() + .interpolate(); + + // Check that the quotient polynomial is indeed in the wanted fft space. + assert!(quotient_polynomial.is_in_fft_space(log_domain_size)); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/fft.rs b/Stwo_wrapper/crates/prover/src/core/fft.rs new file mode 100644 index 0000000..630fbe7 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/fft.rs @@ -0,0 +1,21 @@ +use std::ops::{Add, AddAssign, Mul, Sub}; + +use super::fields::m31::BaseField; + +pub fn butterfly(v0: &mut F, v1: &mut F, twid: BaseField) +where + F: Copy + AddAssign + Sub + Mul, +{ + let tmp = *v1 * twid; + *v1 = *v0 - tmp; + *v0 += tmp; +} + +pub fn ibutterfly(v0: &mut F, v1: &mut F, itwid: BaseField) +where + F: Copy + AddAssign + Add + Sub + Mul, +{ + let tmp = *v0; + *v0 = tmp + *v1; + *v1 = (tmp - *v1) * itwid; +} diff --git a/Stwo_wrapper/crates/prover/src/core/fields/cm31.rs b/Stwo_wrapper/crates/prover/src/core/fields/cm31.rs new file mode 100644 index 0000000..6f1b6c2 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/fields/cm31.rs @@ -0,0 +1,137 @@ +use std::fmt::{Debug, Display}; +use std::ops::{ + Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, +}; + +use serde::{Deserialize, Serialize}; + +use super::{ComplexConjugate, FieldExpOps}; +use crate::core::fields::m31::M31; +use crate::{impl_extension_field, impl_field}; 
+pub const P2: u64 = 4611686014132420609; // (2 ** 31 - 1) ** 2 + +/// Complex extension field of M31. +/// Equivalent to M31\[x\] over (x^2 + 1) as the irreducible polynomial. +/// Represented as (a, b) of a + bi. +#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +pub struct CM31(pub M31, pub M31); + +impl_field!(CM31, P2); +impl_extension_field!(CM31, M31); + +impl CM31 { + pub const fn from_u32_unchecked(a: u32, b: u32) -> CM31 { + Self(M31::from_u32_unchecked(a), M31::from_u32_unchecked(b)) + } + + pub fn from_m31(a: M31, b: M31) -> CM31 { + Self(a, b) + } +} + +impl Display for CM31 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} + {}i", self.0, self.1) + } +} + +impl Debug for CM31 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} + {}i", self.0, self.1) + } +} + +impl Mul for CM31 { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i. + Self( + self.0 * rhs.0 - self.1 * rhs.1, + self.0 * rhs.1 + self.1 * rhs.0, + ) + } +} + +impl TryInto for CM31 { + type Error = (); + + fn try_into(self) -> Result { + if self.1 != M31::zero() { + return Err(()); + } + Ok(self.0) + } +} + +impl FieldExpOps for CM31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + // 1 / (a + bi) = (a - bi) / (a^2 + b^2). + Self(self.0, -self.1) * (self.0.square() + self.1.square()).inverse() + } +} + +#[cfg(test)] +#[macro_export] +macro_rules! 
cm31 { + ($m0:expr, $m1:expr) => { + CM31::from_u32_unchecked($m0, $m1) + }; +} + +#[cfg(test)] +mod tests { + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::CM31; + use crate::core::fields::m31::P; + use crate::core::fields::{FieldExpOps, IntoSlice}; + use crate::m31; + + #[test] + fn test_inverse() { + let cm = cm31!(1, 2); + let cm_inv = cm.inverse(); + assert_eq!(cm * cm_inv, cm31!(1, 0)); + } + + #[test] + fn test_ops() { + let cm0 = cm31!(1, 2); + let cm1 = cm31!(4, 5); + let m = m31!(8); + let cm = CM31::from(m); + let cm0_x_cm1 = cm31!(P - 6, 13); + + assert_eq!(cm0 + cm1, cm31!(5, 7)); + assert_eq!(cm1 + m, cm1 + cm); + assert_eq!(cm0 * cm1, cm0_x_cm1); + assert_eq!(cm1 * m, cm1 * cm); + assert_eq!(-cm0, cm31!(P - 1, P - 2)); + assert_eq!(cm0 - cm1, cm31!(P - 3, P - 3)); + assert_eq!(cm1 - m, cm1 - cm); + assert_eq!(cm0_x_cm1 / cm1, cm31!(1, 2)); + assert_eq!(cm1 / m, cm1 / cm); + } + + #[test] + fn test_into_slice() { + let mut rng = SmallRng::seed_from_u64(0); + let x = (0..100).map(|_| rng.gen()).collect::>(); + + let slice = CM31::into_slice(&x); + + for i in 0..100 { + let corresponding_sub_slice = &slice[i * 8..(i + 1) * 8]; + assert_eq!( + x[i], + cm31!( + u32::from_le_bytes(corresponding_sub_slice[..4].try_into().unwrap()), + u32::from_le_bytes(corresponding_sub_slice[4..].try_into().unwrap()) + ) + ) + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/fields/m31.rs b/Stwo_wrapper/crates/prover/src/core/fields/m31.rs new file mode 100644 index 0000000..852f959 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/fields/m31.rs @@ -0,0 +1,258 @@ +use std::fmt::Display; +use std::ops::{ + Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, +}; + +use bytemuck::{Pod, Zeroable}; +use rand::distributions::{Distribution, Standard}; +use serde::{Deserialize, Serialize}; + +use super::{ComplexConjugate, FieldExpOps}; +use crate::impl_field; +pub const MODULUS_BITS: u32 = 31; +pub const 
N_BYTES_FELT: usize = 4; +pub const P: u32 = 2147483647; // 2 ** 31 - 1 + +#[repr(transparent)] +#[derive( + Copy, + Clone, + Debug, + Default, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Pod, + Zeroable, + Serialize, + Deserialize, +)] +pub struct M31(pub u32); +pub type BaseField = M31; + +impl_field!(M31, P); + +impl M31 { + /// Returns `val % P` when `val` is in the range `[0, 2P)`. + /// + /// ``` + /// use stwo_prover::core::fields::m31::{M31, P}; + /// + /// let val = 2 * P - 19; + /// assert_eq!(M31::partial_reduce(val), M31::from(P - 19)); + /// ``` + pub fn partial_reduce(val: u32) -> Self { + Self(val.checked_sub(P).unwrap_or(val)) + } + + /// Returns `val % P` when `val` is in the range `[0, P^2)`. + /// + /// ``` + /// use stwo_prover::core::fields::m31::{M31, P}; + /// + /// let val = (P as u64).pow(2) - 19; + /// assert_eq!(M31::reduce(val), M31::from(P - 19)); + /// ``` + pub fn reduce(val: u64) -> Self { + Self((((((val >> MODULUS_BITS) + val + 1) >> MODULUS_BITS) + val) & (P as u64)) as u32) + } + + pub const fn from_u32_unchecked(arg: u32) -> Self { + Self(arg) + } +} + +impl Display for M31 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl Add for M31 { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self::partial_reduce(self.0 + rhs.0) + } +} + +impl Neg for M31 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self::partial_reduce(P - self.0) + } +} + +impl Sub for M31 { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self::partial_reduce(self.0 + P - rhs.0) + } +} + +impl Mul for M31 { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self::reduce((self.0 as u64) * (rhs.0 as u64)) + } +} + +impl FieldExpOps for M31 { + /// ``` + /// use num_traits::One; + /// use stwo_prover::core::fields::m31::BaseField; + /// use stwo_prover::core::fields::FieldExpOps; + /// + /// let v = BaseField::from(19); + /// 
assert_eq!(v.inverse() * v, BaseField::one()); + /// ``` + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + pow2147483645(*self) + } +} + +impl ComplexConjugate for M31 { + fn complex_conjugate(&self) -> Self { + *self + } +} + +impl One for M31 { + fn one() -> Self { + Self(1) + } +} + +impl Zero for M31 { + fn zero() -> Self { + Self(0) + } + + fn is_zero(&self) -> bool { + *self == Self::zero() + } +} + +impl From for M31 { + fn from(value: usize) -> Self { + M31::reduce(value.try_into().unwrap()) + } +} + +impl From for M31 { + fn from(value: u32) -> Self { + M31::reduce(value.into()) + } +} + +impl From for M31 { + fn from(value: i32) -> Self { + M31::reduce(value.try_into().unwrap()) + } +} + +impl Distribution for Standard { + // Not intended for cryptographic use. Should only be used in tests and benchmarks. + fn sample(&self, rng: &mut R) -> M31 { + M31(rng.gen_range(0..P)) + } +} + +#[cfg(test)] +#[macro_export] +macro_rules! m31 { + ($m:expr) => { + $crate::core::fields::m31::M31::from_u32_unchecked($m) + }; +} + +/// Computes `v^((2^31-1)-2)`. +/// +/// Computes the multiplicative inverse of [`M31`] elements with 37 multiplications vs naive 60 +/// multiplications. Made generic to support both vectorized and non-vectorized implementations. +/// Multiplication tree found with [addchain](https://github.com/mmcloughlin/addchain). +/// +/// ``` +/// use stwo_prover::core::fields::m31::{pow2147483645, BaseField}; +/// use stwo_prover::core::fields::FieldExpOps; +/// +/// let v = BaseField::from(19); +/// assert_eq!(pow2147483645(v), v.pow(2147483645)); +/// ``` +pub fn pow2147483645(v: T) -> T { + let t0 = sqn::<2, T>(v) * v; + let t1 = sqn::<1, T>(t0) * t0; + let t2 = sqn::<3, T>(t1) * t0; + let t3 = sqn::<1, T>(t2) * t0; + let t4 = sqn::<8, T>(t3) * t3; + let t5 = sqn::<8, T>(t4) * t3; + sqn::<7, T>(t5) * t2 +} + +/// Computes `v^(2*n)`. 
+fn sqn(mut v: T) -> T { + for _ in 0..N { + v = v.square(); + } + v +} + +#[cfg(test)] +mod tests { + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::{M31, P}; + use crate::core::fields::IntoSlice; + + fn mul_p(a: u32, b: u32) -> u32 { + ((a as u64 * b as u64) % P as u64) as u32 + } + + fn add_p(a: u32, b: u32) -> u32 { + (a + b) % P + } + + fn neg_p(a: u32) -> u32 { + if a == 0 { + 0 + } else { + P - a + } + } + + #[test] + fn test_basic_ops() { + let mut rng = SmallRng::seed_from_u64(0); + for _ in 0..10000 { + let x: u32 = rng.gen::() % P; + let y: u32 = rng.gen::() % P; + assert_eq!(m31!(add_p(x, y)), m31!(x) + m31!(y)); + assert_eq!(m31!(mul_p(x, y)), m31!(x) * m31!(y)); + assert_eq!(m31!(neg_p(x)), -m31!(x)); + } + } + + #[test] + fn test_into_slice() { + let mut rng = SmallRng::seed_from_u64(0); + let x = (0..100).map(|_| rng.gen()).collect::>(); + + let slice = M31::into_slice(&x); + + for i in 0..100 { + assert_eq!( + x[i], + m31!(u32::from_le_bytes( + slice[i * 4..(i + 1) * 4].try_into().unwrap() + )) + ); + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/fields/mod.rs b/Stwo_wrapper/crates/prover/src/core/fields/mod.rs new file mode 100644 index 0000000..fbeefbb --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/fields/mod.rs @@ -0,0 +1,489 @@ +use std::fmt::{Debug, Display}; +use std::iter::{Product, Sum}; +use std::ops::{Mul, MulAssign, Neg}; + +use num_traits::{NumAssign, NumAssignOps, NumOps, One}; + +use super::backend::ColumnOps; + +pub mod cm31; +pub mod m31; +pub mod qm31; +pub mod secure_column; + +pub trait FieldOps: ColumnOps { + // TODO(Ohad): change to use a mutable slice. 
+ fn batch_inverse(column: &Self::Column, dst: &mut Self::Column); +} + +pub trait FieldExpOps: Mul + MulAssign + Sized + One + Copy { + fn square(&self) -> Self { + (*self) * (*self) + } + + fn pow(&self, exp: u128) -> Self { + let mut res = Self::one(); + let mut base = *self; + let mut exp = exp; + while exp > 0 { + if exp & 1 == 1 { + res *= base; + } + base = base.square(); + exp >>= 1; + } + res + } + + fn inverse(&self) -> Self; + + /// Inverts a batch of elements using Montgomery's trick. + fn batch_inverse(column: &[Self], dst: &mut [Self]) { + const WIDTH: usize = 4; + let n = column.len(); + debug_assert!(dst.len() >= n); + + if n <= WIDTH || n % WIDTH != 0 { + batch_inverse_classic(column, dst); + return; + } + + // First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing + // instruction dependency and allowing better pipelining. + let mut cum_prod: [Self; WIDTH] = [Self::one(); WIDTH]; + dst[..WIDTH].copy_from_slice(&cum_prod); + for i in 0..n { + cum_prod[i % WIDTH] *= column[i]; + dst[i] = cum_prod[i % WIDTH]; + } + + // Inverse cumulative products. + // Use classic batch inversion. + let mut tail_inverses = [Self::one(); WIDTH]; + batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses); + + // Second pass. + for i in (WIDTH..n).rev() { + dst[i] = dst[i - WIDTH] * tail_inverses[i % WIDTH]; + tail_inverses[i % WIDTH] *= column[i]; + } + dst[0..WIDTH].copy_from_slice(&tail_inverses); + } +} + +/// Assumes dst is initialized and of the same length as column. +fn batch_inverse_classic(column: &[T], dst: &mut [T]) { + let n = column.len(); + debug_assert!(dst.len() >= n); + + dst[0] = column[0]; + // First pass. + for i in 1..n { + dst[i] = dst[i - 1] * column[i]; + } + + // Inverse cumulative product. + let mut curr_inverse = dst[n - 1].inverse(); + + // Second pass. 
+ for i in (1..n).rev() { + dst[i] = dst[i - 1] * curr_inverse; + curr_inverse *= column[i]; + } + dst[0] = curr_inverse; +} + +pub trait Field: + NumAssign + + Neg + + ComplexConjugate + + Copy + + Default + + Debug + + Display + + PartialOrd + + Ord + + Send + + Sync + + Sized + + FieldExpOps + + Product + + for<'a> Product<&'a Self> + + Sum + + for<'a> Sum<&'a Self> +{ + fn double(&self) -> Self { + (*self) + (*self) + } +} + +/// # Safety +/// +/// Do not use unless you are aware of the endianess in the platform you are compiling for, and the +/// Field element's representation in memory. +// TODO(Ohad): Do not compile on non-le targets. +pub unsafe trait IntoSlice: Sized { + fn into_slice(sl: &[Self]) -> &[T] { + unsafe { + std::slice::from_raw_parts( + sl.as_ptr() as *const T, + std::mem::size_of_val(sl) / std::mem::size_of::(), + ) + } + } +} + +unsafe impl IntoSlice for F {} + +pub trait ComplexConjugate { + /// # Example + /// + /// ``` + /// use stwo_prover::core::fields::m31::P; + /// use stwo_prover::core::fields::qm31::QM31; + /// use stwo_prover::core::fields::ComplexConjugate; + /// + /// let x = QM31::from_u32_unchecked(1, 2, 3, 4); + /// assert_eq!( + /// x.complex_conjugate(), + /// QM31::from_u32_unchecked(1, 2, P - 3, P - 4) + /// ); + /// ``` + fn complex_conjugate(&self) -> Self; +} + +pub trait ExtensionOf: Field + From + NumOps + NumAssignOps { + const EXTENSION_DEGREE: usize; +} + +impl ExtensionOf for F { + const EXTENSION_DEGREE: usize = 1; +} + +#[macro_export] +macro_rules! 
impl_field { + ($field_name: ty, $field_size: ident) => { + use std::iter::{Product, Sum}; + + use num_traits::{Num, One, Zero}; + use $crate::core::fields::Field; + + impl Num for $field_name { + type FromStrRadixErr = Box; + + fn from_str_radix(_str: &str, _radix: u32) -> Result { + unimplemented!( + "Num::from_str_radix is not implemented for {}", + stringify!($field_name) + ); + } + } + + impl Field for $field_name {} + + impl AddAssign for $field_name { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } + } + + impl SubAssign for $field_name { + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } + } + + impl MulAssign for $field_name { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } + } + + impl Div for $field_name { + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Self) -> Self::Output { + self * rhs.inverse() + } + } + + impl DivAssign for $field_name { + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } + } + + impl Rem for $field_name { + type Output = Self; + + fn rem(self, _rhs: Self) -> Self::Output { + unimplemented!("Rem is not implemented for {}", stringify!($field_name)); + } + } + + impl RemAssign for $field_name { + fn rem_assign(&mut self, _rhs: Self) { + unimplemented!( + "RemAssign is not implemented for {}", + stringify!($field_name) + ); + } + } + + impl Product for $field_name { + fn product(mut iter: I) -> Self + where + I: Iterator, + { + let first = iter.next().unwrap_or_else(Self::one); + iter.fold(first, |a, b| a * b) + } + } + + impl<'a> Product<&'a Self> for $field_name { + fn product(iter: I) -> Self + where + I: Iterator, + { + iter.map(|&v| v).product() + } + } + + impl Sum for $field_name { + fn sum(mut iter: I) -> Self + where + I: Iterator, + { + let first = iter.next().unwrap_or_else(Self::zero); + iter.fold(first, |a, b| a + b) + } + } + + impl<'a> Sum<&'a Self> for $field_name { + fn sum(iter: I) -> Self + where + I: 
Iterator, + { + iter.map(|&v| v).sum() + } + } + }; +} + +/// Used to extend a field (with characteristic M31) by 2. +#[macro_export] +macro_rules! impl_extension_field { + ($field_name: ident, $extended_field_name: ty) => { + use rand::distributions::{Distribution, Standard}; + use $crate::core::fields::ExtensionOf; + + impl ExtensionOf for $field_name { + const EXTENSION_DEGREE: usize = + <$extended_field_name as ExtensionOf>::EXTENSION_DEGREE * 2; + } + + impl Add for $field_name { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self(self.0 + rhs.0, self.1 + rhs.1) + } + } + + impl Neg for $field_name { + type Output = Self; + + fn neg(self) -> Self::Output { + Self(-self.0, -self.1) + } + } + + impl Sub for $field_name { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self(self.0 - rhs.0, self.1 - rhs.1) + } + } + + impl One for $field_name { + fn one() -> Self { + Self( + <$extended_field_name>::one(), + <$extended_field_name>::zero(), + ) + } + } + + impl Zero for $field_name { + fn zero() -> Self { + Self( + <$extended_field_name>::zero(), + <$extended_field_name>::zero(), + ) + } + + fn is_zero(&self) -> bool { + *self == Self::zero() + } + } + + impl Add for $field_name { + type Output = Self; + + fn add(self, rhs: M31) -> Self::Output { + Self(self.0 + rhs, self.1) + } + } + + impl Add<$field_name> for M31 { + type Output = $field_name; + + fn add(self, rhs: $field_name) -> Self::Output { + rhs + self + } + } + + impl Sub for $field_name { + type Output = Self; + + fn sub(self, rhs: M31) -> Self::Output { + Self(self.0 - rhs, self.1) + } + } + + impl Sub<$field_name> for M31 { + type Output = $field_name; + + fn sub(self, rhs: $field_name) -> Self::Output { + -rhs + self + } + } + + impl Mul for $field_name { + type Output = Self; + + fn mul(self, rhs: M31) -> Self::Output { + Self(self.0 * rhs, self.1 * rhs) + } + } + + impl Mul<$field_name> for M31 { + type Output = $field_name; + + fn mul(self, rhs: 
$field_name) -> Self::Output { + rhs * self + } + } + + impl Div for $field_name { + type Output = Self; + + fn div(self, rhs: M31) -> Self::Output { + Self(self.0 / rhs, self.1 / rhs) + } + } + + impl Div<$field_name> for M31 { + type Output = $field_name; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: $field_name) -> Self::Output { + rhs.inverse() * self + } + } + + impl ComplexConjugate for $field_name { + fn complex_conjugate(&self) -> Self { + Self(self.0, -self.1) + } + } + + impl From for $field_name { + fn from(x: M31) -> Self { + Self(x.into(), <$extended_field_name>::zero()) + } + } + + impl AddAssign for $field_name { + fn add_assign(&mut self, rhs: M31) { + *self = *self + rhs; + } + } + + impl SubAssign for $field_name { + fn sub_assign(&mut self, rhs: M31) { + *self = *self - rhs; + } + } + + impl MulAssign for $field_name { + fn mul_assign(&mut self, rhs: M31) { + *self = *self * rhs; + } + } + + impl DivAssign for $field_name { + fn div_assign(&mut self, rhs: M31) { + *self = *self / rhs; + } + } + + impl Rem for $field_name { + type Output = Self; + + fn rem(self, _rhs: M31) -> Self::Output { + unimplemented!("Rem is not implemented for {}", stringify!($field_name)); + } + } + + impl RemAssign for $field_name { + fn rem_assign(&mut self, _rhs: M31) { + unimplemented!( + "RemAssign is not implemented for {}", + stringify!($field_name) + ); + } + } + + impl Distribution<$field_name> for Standard { + // Not intended for cryptographic use. Should only be used in tests and benchmarks. 
+ fn sample(&self, rng: &mut R) -> $field_name { + $field_name(rng.gen(), rng.gen()) + } + } + }; +} + +#[cfg(test)] +mod tests { + use num_traits::Zero; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use crate::core::fields::m31::M31; + use crate::core::fields::FieldExpOps; + + #[test] + fn test_slice_batch_inverse() { + let mut rng = SmallRng::seed_from_u64(0); + let elements: [M31; 16] = rng.gen(); + let expected = elements.iter().map(|e| e.inverse()).collect::>(); + let mut dst = [M31::zero(); 16]; + + M31::batch_inverse(&elements, &mut dst); + + assert_eq!(expected, dst); + } + + #[test] + #[should_panic] + fn test_slice_batch_inverse_wrong_dst_size() { + let mut rng = SmallRng::seed_from_u64(0); + let elements: [M31; 16] = rng.gen(); + let mut dst = [M31::zero(); 15]; + + M31::batch_inverse(&elements, &mut dst); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/fields/qm31.rs b/Stwo_wrapper/crates/prover/src/core/fields/qm31.rs new file mode 100644 index 0000000..6da19a3 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/fields/qm31.rs @@ -0,0 +1,195 @@ +use std::fmt::{Debug, Display}; +use std::ops::{ + Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, +}; + +use serde::{Deserialize, Serialize}; + +use super::secure_column::SECURE_EXTENSION_DEGREE; +use super::{ComplexConjugate, FieldExpOps}; +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::M31; +use crate::{impl_extension_field, impl_field}; + +pub const P4: u128 = 21267647892944572736998860269687930881; // (2 ** 31 - 1) ** 4 +pub const R: CM31 = CM31::from_u32_unchecked(2, 1); + +/// Extension field of CM31. +/// Equivalent to CM31\[x\] over (x^2 - 2 - i) as the irreducible polynomial. +/// Represented as ((a, b), (c, d)) of (a + bi) + (c + di)u. 
+#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] +pub struct QM31(pub CM31, pub CM31); +pub type SecureField = QM31; + +impl_field!(QM31, P4); +impl_extension_field!(QM31, CM31); + +impl QM31 { + pub const fn from_u32_unchecked(a: u32, b: u32, c: u32, d: u32) -> Self { + Self( + CM31::from_u32_unchecked(a, b), + CM31::from_u32_unchecked(c, d), + ) + } + + pub fn from_m31(a: M31, b: M31, c: M31, d: M31) -> Self { + Self(CM31::from_m31(a, b), CM31::from_m31(c, d)) + } + + pub fn from_m31_array(array: [M31; SECURE_EXTENSION_DEGREE]) -> Self { + Self::from_m31(array[0], array[1], array[2], array[3]) + } + + pub fn to_m31_array(self) -> [M31; SECURE_EXTENSION_DEGREE] { + [self.0 .0, self.0 .1, self.1 .0, self.1 .1] + } + + /// Returns the combined value, given the values of its composing base field polynomials at that + /// point. + pub fn from_partial_evals(evals: [Self; SECURE_EXTENSION_DEGREE]) -> Self { + let mut res = evals[0]; + res += evals[1] * Self::from_u32_unchecked(0, 1, 0, 0); + res += evals[2] * Self::from_u32_unchecked(0, 0, 1, 0); + res += evals[3] * Self::from_u32_unchecked(0, 0, 0, 1); + res + } + + // Note: Adding this as a Mul impl drives rust insane, and it tries to infer Qm31*Qm31 as + // QM31*CM31. + pub fn mul_cm31(self, rhs: CM31) -> Self { + Self(self.0 * rhs, self.1 * rhs) + } +} + +impl Display for QM31 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}) + ({})u", self.0, self.1) + } +} + +impl Debug for QM31 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "({}) + ({})u", self.0, self.1) + } +} + +impl Mul for QM31 { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + // (a + bu) * (c + du) = (ac + rbd) + (ad + bc)u. 
+ Self( + self.0 * rhs.0 + R * self.1 * rhs.1, + self.0 * rhs.1 + self.1 * rhs.0, + ) + } +} + +impl From for QM31 { + fn from(value: usize) -> Self { + M31::from(value).into() + } +} + +impl From for QM31 { + fn from(value: u32) -> Self { + M31::from(value).into() + } +} + +impl From for QM31 { + fn from(value: i32) -> Self { + M31::from(value).into() + } +} + +impl TryInto for QM31 { + type Error = (); + + fn try_into(self) -> Result { + if self.1 != CM31::zero() { + return Err(()); + } + self.0.try_into() + } +} + +impl FieldExpOps for QM31 { + fn inverse(&self) -> Self { + assert!(!self.is_zero(), "0 has no inverse"); + // (a + bu)^-1 = (a - bu) / (a^2 - (2+i)b^2). + let b2 = self.1.square(); + let ib2 = CM31(-b2.1, b2.0); + let denom = self.0.square() - (b2 + b2 + ib2); + let denom_inverse = denom.inverse(); + Self(self.0 * denom_inverse, -self.1 * denom_inverse) + } +} + +#[cfg(test)] +#[macro_export] +macro_rules! qm31 { + ($m0:expr, $m1:expr, $m2:expr, $m3:expr) => {{ + use $crate::core::fields::qm31::QM31; + QM31::from_u32_unchecked($m0, $m1, $m2, $m3) + }}; +} + +#[cfg(test)] +mod tests { + use num_traits::One; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::QM31; + use crate::core::fields::m31::P; + use crate::core::fields::{FieldExpOps, IntoSlice}; + use crate::m31; + + #[test] + fn test_inverse() { + let qm = qm31!(1, 2, 3, 4); + let qm_inv = qm.inverse(); + assert_eq!(qm * qm_inv, QM31::one()); + } + + #[test] + fn test_ops() { + let qm0 = qm31!(1, 2, 3, 4); + let qm1 = qm31!(4, 5, 6, 7); + let m = m31!(8); + let qm = QM31::from(m); + let qm0_x_qm1 = qm31!(P - 71, 93, P - 16, 50); + + assert_eq!(qm0 + qm1, qm31!(5, 7, 9, 11)); + assert_eq!(qm1 + m, qm1 + qm); + assert_eq!(qm0 * qm1, qm0_x_qm1); + assert_eq!(qm1 * m, qm1 * qm); + assert_eq!(-qm0, qm31!(P - 1, P - 2, P - 3, P - 4)); + assert_eq!(qm0 - qm1, qm31!(P - 3, P - 3, P - 3, P - 3)); + assert_eq!(qm1 - m, qm1 - qm); + assert_eq!(qm0_x_qm1 / qm1, qm31!(1, 2, 3, 4)); + 
assert_eq!(qm1 / m, qm1 / qm); + } + + #[test] + fn test_into_slice() { + let mut rng = SmallRng::seed_from_u64(0); + let x = (0..100).map(|_| rng.gen()).collect::>(); + + let slice = QM31::into_slice(&x); + + for i in 0..100 { + let corresponding_sub_slice = &slice[i * 16..(i + 1) * 16]; + assert_eq!( + x[i], + qm31!( + u32::from_le_bytes(corresponding_sub_slice[..4].try_into().unwrap()), + u32::from_le_bytes(corresponding_sub_slice[4..8].try_into().unwrap()), + u32::from_le_bytes(corresponding_sub_slice[8..12].try_into().unwrap()), + u32::from_le_bytes(corresponding_sub_slice[12..16].try_into().unwrap()) + ) + ) + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/fields/secure_column.rs b/Stwo_wrapper/crates/prover/src/core/fields/secure_column.rs new file mode 100644 index 0000000..073d21b --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/fields/secure_column.rs @@ -0,0 +1,111 @@ +use std::array; +use std::iter::zip; + +use super::m31::BaseField; +use super::qm31::SecureField; +use super::{ExtensionOf, FieldOps}; +use crate::core::backend::{Col, Column, CpuBackend}; + +pub const SECURE_EXTENSION_DEGREE: usize = + >::EXTENSION_DEGREE; + +/// A column major array of `SECURE_EXTENSION_DEGREE` base field columns, that represents a column +/// of secure field element coordinates. +#[derive(Clone, Debug)] +pub struct SecureColumnByCoords> { + pub columns: [Col; SECURE_EXTENSION_DEGREE], +} +impl SecureColumnByCoords { + // TODO(spapini): Remove when we no longer use CircleEvaluation. 
+ pub fn to_vec(&self) -> Vec { + (0..self.len()).map(|i| self.at(i)).collect() + } +} +impl> SecureColumnByCoords { + pub fn at(&self, index: usize) -> SecureField { + SecureField::from_m31_array(std::array::from_fn(|i| self.columns[i].at(index))) + } + + pub fn zeros(len: usize) -> Self { + Self { + columns: std::array::from_fn(|_| Col::::zeros(len)), + } + } + + /// # Safety + pub unsafe fn uninitialized(len: usize) -> Self { + Self { + columns: std::array::from_fn(|_| Col::::uninitialized(len)), + } + } + + pub fn len(&self) -> usize { + self.columns[0].len() + } + + pub fn is_empty(&self) -> bool { + self.columns[0].is_empty() + } + + pub fn to_cpu(&self) -> SecureColumnByCoords { + SecureColumnByCoords { + columns: self.columns.clone().map(|c| c.to_cpu()), + } + } + + pub fn set(&mut self, index: usize, value: SecureField) { + let values = value.to_m31_array(); + #[allow(clippy::needless_range_loop)] + for i in 0..SECURE_EXTENSION_DEGREE { + self.columns[i].set(index, values[i]); + } + } +} + +pub struct SecureColumnByCoordsIter<'a> { + column: &'a SecureColumnByCoords, + index: usize, +} +impl Iterator for SecureColumnByCoordsIter<'_> { + type Item = SecureField; + + fn next(&mut self) -> Option { + if self.index < self.column.len() { + let value = self.column.at(self.index); + self.index += 1; + Some(value) + } else { + None + } + } +} +impl<'a> IntoIterator for &'a SecureColumnByCoords { + type Item = SecureField; + type IntoIter = SecureColumnByCoordsIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + SecureColumnByCoordsIter { + column: self, + index: 0, + } + } +} +impl FromIterator for SecureColumnByCoords { + fn from_iter>(iter: I) -> Self { + let values = iter.into_iter(); + let (lower_bound, _) = values.size_hint(); + let mut columns = array::from_fn(|_| Vec::with_capacity(lower_bound)); + + for value in values { + let coords = value.to_m31_array(); + zip(&mut columns, coords).for_each(|(col, coord)| col.push(coord)); + } + + SecureColumnByCoords 
{ columns } + } +} +impl From> for Vec { + fn from(column: SecureColumnByCoords) -> Self { + column.into_iter().collect() + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/fri.rs b/Stwo_wrapper/crates/prover/src/core/fri.rs new file mode 100644 index 0000000..db3e37d --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/fri.rs @@ -0,0 +1,1425 @@ +use std::cmp::Reverse; +use std::collections::BTreeMap; +use std::fmt::Debug; +use std::iter::zip; +use std::ops::RangeInclusive; + +use itertools::Itertools; +use num_traits::Zero; +use thiserror::Error; +use tracing::{span, Level}; + +use super::backend::CpuBackend; +use super::channel::{Channel, MerkleChannel}; +use super::fields::m31::BaseField; +use super::fields::qm31::SecureField; +use super::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; +use super::fields::FieldOps; +use super::poly::circle::{CircleEvaluation, PolyOps, SecureEvaluation}; +use super::poly::line::{LineEvaluation, LinePoly}; +use super::poly::twiddles::TwiddleTree; +use super::poly::BitReversedOrder; +// TODO(andrew): Create fri/ directory, move queries.rs there and split this file up. +use super::queries::{Queries, SparseSubCircleDomain}; +use crate::core::circle::Coset; +use crate::core::fft::ibutterfly; +use crate::core::fields::FieldExpOps; +use crate::core::poly::line::LineDomain; +use crate::core::utils::bit_reverse_index; +use crate::core::vcs::ops::{MerkleHasher, MerkleOps}; +use crate::core::vcs::prover::{MerkleDecommitment, MerkleProver}; +use crate::core::vcs::verifier::{MerkleVerificationError, MerkleVerifier}; + +/// FRI proof config +// TODO(andrew): Support different step sizes. +#[derive(Debug, Clone, Copy)] +pub struct FriConfig { + pub log_blowup_factor: u32, + pub log_last_layer_degree_bound: u32, + pub n_queries: usize, + // TODO(andrew): fold_steps. 
+} + +impl FriConfig { + const LOG_MIN_LAST_LAYER_DEGREE_BOUND: u32 = 0; + const LOG_MAX_LAST_LAYER_DEGREE_BOUND: u32 = 10; + const LOG_LAST_LAYER_DEGREE_BOUND_RANGE: RangeInclusive = + Self::LOG_MIN_LAST_LAYER_DEGREE_BOUND..=Self::LOG_MAX_LAST_LAYER_DEGREE_BOUND; + + const LOG_MIN_BLOWUP_FACTOR: u32 = 1; + const LOG_MAX_BLOWUP_FACTOR: u32 = 16; + const LOG_BLOWUP_FACTOR_RANGE: RangeInclusive = + Self::LOG_MIN_BLOWUP_FACTOR..=Self::LOG_MAX_BLOWUP_FACTOR; + + /// Creates a new FRI configuration. + /// + /// # Panics + /// + /// Panics if: + /// * `log_last_layer_degree_bound` is greater than 10. + /// * `log_blowup_factor` is equal to zero or greater than 16. + pub fn new(log_last_layer_degree_bound: u32, log_blowup_factor: u32, n_queries: usize) -> Self { + assert!(Self::LOG_LAST_LAYER_DEGREE_BOUND_RANGE.contains(&log_last_layer_degree_bound)); + assert!(Self::LOG_BLOWUP_FACTOR_RANGE.contains(&log_blowup_factor)); + Self { + log_blowup_factor, + log_last_layer_degree_bound, + n_queries, + } + } + + fn last_layer_domain_size(&self) -> usize { + 1 << (self.log_last_layer_degree_bound + self.log_blowup_factor) + } +} + +pub trait FriOps: FieldOps + PolyOps + Sized + FieldOps { + /// Folds a degree `d` polynomial into a degree `d/2` polynomial. + /// + /// Let `eval` be a polynomial evaluated on a [LineDomain] `E`, `alpha` be a random field + /// element and `pi(x) = 2x^2 - 1` be the circle's x-coordinate doubling map. This function + /// returns `f' = f0 + alpha * f1` evaluated on `pi(E)` such that `2f(x) = f0(pi(x)) + x * + /// f1(pi(x))`. + /// + /// # Panics + /// + /// Panics if there are less than two evaluations. + fn fold_line( + eval: &LineEvaluation, + alpha: SecureField, + twiddles: &TwiddleTree, + ) -> LineEvaluation; + + /// Folds and accumulates a degree `d` circle polynomial into a degree `d/2` univariate + /// polynomial. + /// + /// Let `src` be the evaluation of a circle polynomial `f` on a + /// [`CircleDomain`] `E`. 
This function computes evaluations of `f' = f0 + /// + alpha * f1` on the x-coordinates of `E` such that `2f(p) = f0(px) + py * f1(px)`. The + /// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + + /// f'`. + /// + /// # Panics + /// + /// Panics if `src` is not double the length of `dst`. + /// + /// [`CircleDomain`]: super::poly::circle::CircleDomain + // TODO(andrew): Make folding factor generic. + // TODO(andrew): Fold directly into FRI layer to prevent allocation. + fn fold_circle_into_line( + dst: &mut LineEvaluation, + src: &SecureEvaluation, + alpha: SecureField, + twiddles: &TwiddleTree, + ); + + /// Decomposes a FRI-space polynomial into a polynomial inside the fft-space and the + /// remainder term. + /// FRI-space: polynomials of total degree n/2. + /// Based on lemma #12 from the CircleStark paper: f(P) = g(P)+ lambda * alternating(P), + /// where lambda is the coset diff of eval, and g is a polynomial in the fft-space. + fn decompose( + eval: &SecureEvaluation, + ) -> (SecureEvaluation, SecureField); +} +/// A FRI prover that applies the FRI protocol to prove a set of polynomials are of low degree. +pub struct FriProver, MC: MerkleChannel> { + config: FriConfig, + inner_layers: Vec>, + last_layer_poly: LinePoly, + /// Unique sizes of committed columns sorted in descending order. + column_log_sizes: Vec, +} + +impl, MC: MerkleChannel> FriProver { + /// Commits to multiple [CircleEvaluation]s. + /// + /// `columns` must be provided in descending order by size. + /// + /// Mixed degree STARKs involve polynomials evaluated on multiple domains of different size. + /// Combining evaluations on different sized domains into an evaluation of a single polynomial + /// on a single domain for the purpose of commitment is inefficient. Instead, commit to multiple + /// polynomials so combining of evaluations can be taken care of efficiently at the appropriate + /// FRI layer. 
All evaluations must be taken over canonic [`CircleDomain`]s. + /// + /// # Panics + /// + /// Panics if: + /// * `columns` is empty or not sorted in ascending order by domain size. + /// * An evaluation is not from a sufficiently low degree circle polynomial. + /// * An evaluation's domain is smaller than the last layer. + /// * An evaluation's domain is not a canonic circle domain. + /// + /// [`CircleDomain`]: super::poly::circle::CircleDomain + // TODO(andrew): Add docs for all evaluations needing to be from canonic domains. + pub fn commit( + channel: &mut MC::C, + config: FriConfig, + columns: &[SecureEvaluation], + twiddles: &TwiddleTree, + ) -> Self { + let _span = span!(Level::INFO, "FRI commitment").entered(); + assert!(!columns.is_empty(), "no columns"); + assert!(columns.is_sorted_by_key(|e| Reverse(e.len())), "not sorted"); + assert!(columns.iter().all(|e| e.domain.is_canonic()), "not canonic"); + let (inner_layers, last_layer_evaluation) = + Self::commit_inner_layers(channel, config, columns, twiddles); + let last_layer_poly = Self::commit_last_layer(channel, config, last_layer_evaluation); + + let column_log_sizes = columns + .iter() + .map(|e| e.domain.log_size()) + .dedup() + .collect(); + Self { + config, + inner_layers, + last_layer_poly, + column_log_sizes, + } + } + + /// Builds and commits to the inner FRI layers (all layers except the last layer). + /// + /// All `columns` must be provided in descending order by size. + /// + /// Returns all inner layers and the evaluation of the last layer. + fn commit_inner_layers( + channel: &mut MC::C, + config: FriConfig, + columns: &[SecureEvaluation], + twiddles: &TwiddleTree, + ) -> (Vec>, LineEvaluation) { + // Returns the length of the [LineEvaluation] a [CircleEvaluation] gets folded into. 
+ let folded_len = + |e: &SecureEvaluation| e.len() >> CIRCLE_TO_LINE_FOLD_STEP; + + let first_layer_size = folded_len(&columns[0]); + let first_layer_domain = LineDomain::new(Coset::half_odds(first_layer_size.ilog2())); + let mut layer_evaluation = LineEvaluation::new_zero(first_layer_domain); + + let mut columns = columns.iter().peekable(); + + let mut layers = Vec::new(); + + // Circle polynomials can all be folded with the same alpha. + let circle_poly_alpha = channel.draw_felt(); + + while layer_evaluation.len() > config.last_layer_domain_size() { + // Check for any columns (circle poly evaluations) that should be combined. + while let Some(column) = columns.next_if(|c| folded_len(c) == layer_evaluation.len()) { + B::fold_circle_into_line( + &mut layer_evaluation, + column, + circle_poly_alpha, + twiddles, + ); + } + + let layer = FriLayerProver::new(layer_evaluation); + MC::mix_root(channel, layer.merkle_tree.root()); + let folding_alpha = channel.draw_felt(); + let folded_layer_evaluation = B::fold_line(&layer.evaluation, folding_alpha, twiddles); + + layer_evaluation = folded_layer_evaluation; + layers.push(layer); + } + + // Check all columns have been consumed. + assert!(columns.is_empty()); + + (layers, layer_evaluation) + } + + /// Builds and commits to the last layer. + /// + /// The layer is committed to by sending the verifier all the coefficients of the remaining + /// polynomial. + /// + /// # Panics + /// + /// Panics if: + /// * The evaluation domain size exceeds the maximum last layer domain size. + /// * The evaluation is not of sufficiently low degree. 
+ fn commit_last_layer( + channel: &mut MC::C, + config: FriConfig, + evaluation: LineEvaluation, + ) -> LinePoly { + assert_eq!(evaluation.len(), config.last_layer_domain_size()); + + let evaluation = evaluation.to_cpu(); + let mut coeffs = evaluation.interpolate().into_ordered_coefficients(); + + let last_layer_degree_bound = 1 << config.log_last_layer_degree_bound; + let zeros = coeffs.split_off(last_layer_degree_bound); + assert!(zeros.iter().all(SecureField::is_zero), "invalid degree"); + + let last_layer_poly = LinePoly::from_ordered_coefficients(coeffs); + channel.mix_felts(&last_layer_poly); + + last_layer_poly + } + + /// Generates a FRI proof and returns it with the opening positions for the committed columns. + /// + /// Returned column opening positions are mapped by their log size. + pub fn decommit( + self, + channel: &mut MC::C, + ) -> (FriProof, BTreeMap) { + let max_column_log_size = self.column_log_sizes[0]; + let queries = Queries::generate(channel, max_column_log_size, self.config.n_queries); + let positions = get_opening_positions(&queries, &self.column_log_sizes); + let proof = self.decommit_on_queries(&queries); + (proof, positions) + } + + /// # Panics + /// + /// Panics if the queries were sampled on the wrong domain size. + fn decommit_on_queries(self, queries: &Queries) -> FriProof { + let max_column_log_size = self.column_log_sizes[0]; + assert_eq!(queries.log_domain_size, max_column_log_size); + let first_layer_queries = queries.fold(CIRCLE_TO_LINE_FOLD_STEP); + let inner_layers = self + .inner_layers + .into_iter() + .scan(first_layer_queries, |layer_queries, layer| { + let layer_proof = layer.decommit(layer_queries); + *layer_queries = layer_queries.fold(FOLD_STEP); + Some(layer_proof) + }) + .collect(); + + let last_layer_poly = self.last_layer_poly; + + FriProof { + inner_layers, + last_layer_poly, + } + } +} + +pub struct FriVerifier { + config: FriConfig, + /// Alpha used to fold all circle polynomials to univariate polynomials. 
+ circle_poly_alpha: SecureField, + /// Domain size queries should be sampled from. + expected_query_log_domain_size: u32, + /// The list of degree bounds of all committed circle polynomials. + column_bounds: Vec, + inner_layers: Vec>, + last_layer_domain: LineDomain, + last_layer_poly: LinePoly, + /// The queries used for decommitment. Initialized when calling + /// [`FriVerifier::column_opening_positions`]. + queries: Option, +} + +impl FriVerifier { + /// Verifies the commitment stage of FRI. + /// + /// `column_bounds` should be the committed circle polynomial degree bounds in descending order. + /// + /// # Errors + /// + /// An `Err` will be returned if: + /// * The proof contains an invalid number of FRI layers. + /// * The degree of the last layer polynomial is too high. + /// + /// # Panics + /// + /// Panics if: + /// * There are no degree bounds. + /// * The degree bounds are not sorted in descending order. + /// * A degree bound is less than or equal to the last layer's degree bound. + pub fn commit( + channel: &mut MC::C, + config: FriConfig, + proof: FriProof, + column_bounds: Vec, + ) -> Result { + assert!(column_bounds.is_sorted_by_key(|b| Reverse(*b))); + + let max_column_bound = column_bounds[0]; + let expected_query_log_domain_size = + max_column_bound.log_degree_bound + config.log_blowup_factor; + + // Circle polynomials can all be folded with the same alpha. 
+ let circle_poly_alpha = channel.draw_felt(); + + let mut inner_layers = Vec::new(); + let mut layer_bound = max_column_bound.fold_to_line(); + let mut layer_domain = LineDomain::new(Coset::half_odds( + layer_bound.log_degree_bound + config.log_blowup_factor, + )); + + for (layer_index, proof) in proof.inner_layers.into_iter().enumerate() { + MC::mix_root(channel, proof.commitment); + + let folding_alpha = channel.draw_felt(); + + inner_layers.push(FriLayerVerifier { + degree_bound: layer_bound, + domain: layer_domain, + folding_alpha, + layer_index, + proof, + }); + + layer_bound = layer_bound + .fold(FOLD_STEP) + .ok_or(FriVerificationError::InvalidNumFriLayers)?; + layer_domain = layer_domain.double(); + } + + if layer_bound.log_degree_bound != config.log_last_layer_degree_bound { + return Err(FriVerificationError::InvalidNumFriLayers); + } + + let last_layer_domain = layer_domain; + let last_layer_poly = proof.last_layer_poly; + + if last_layer_poly.len() > (1 << config.log_last_layer_degree_bound) { + return Err(FriVerificationError::LastLayerDegreeInvalid); + } + + channel.mix_felts(&last_layer_poly); + + Ok(Self { + config, + circle_poly_alpha, + column_bounds, + expected_query_log_domain_size, + inner_layers, + last_layer_domain, + last_layer_poly, + queries: None, + }) + } + + /// Verifies the decommitment stage of FRI. + /// + /// The decommitment values need to be provided in the same order as their commitment. + /// + /// # Panics + /// + /// Panics if: + /// * The queries were not yet sampled. + /// * The queries were sampled on the wrong domain size. + /// * There aren't the same number of decommitted values as degree bounds. + // TODO(andrew): Finish docs. 
+ pub fn decommit( + mut self, + decommitted_values: Vec, + ) -> Result<(), FriVerificationError> { + let queries = self.queries.take().expect("queries not sampled"); + self.decommit_on_queries(&queries, decommitted_values) + } + + fn decommit_on_queries( + self, + queries: &Queries, + decommitted_values: Vec, + ) -> Result<(), FriVerificationError> { + assert_eq!(queries.log_domain_size, self.expected_query_log_domain_size); + assert_eq!(decommitted_values.len(), self.column_bounds.len()); + + let (last_layer_queries, last_layer_query_evals) = + self.decommit_inner_layers(queries, decommitted_values)?; + + self.decommit_last_layer(last_layer_queries, last_layer_query_evals) + } + + /// Verifies all inner layer decommitments. + /// + /// Returns the queries and query evaluations needed for verifying the last FRI layer. + fn decommit_inner_layers( + &self, + queries: &Queries, + decommitted_values: Vec, + ) -> Result<(Queries, Vec), FriVerificationError> { + let circle_poly_alpha = self.circle_poly_alpha; + let circle_poly_alpha_sq = circle_poly_alpha * circle_poly_alpha; + + let mut decommitted_values = decommitted_values.into_iter(); + let mut column_bounds = self.column_bounds.iter().copied().peekable(); + let mut layer_queries = queries.fold(CIRCLE_TO_LINE_FOLD_STEP); + let mut layer_query_evals = vec![SecureField::zero(); layer_queries.len()]; + + for layer in self.inner_layers.iter() { + // Check for column evals that need to folded into this layer. 
+ while column_bounds + .next_if(|b| b.fold_to_line() == layer.degree_bound) + .is_some() + { + let sparse_evaluation = decommitted_values.next().unwrap(); + let folded_evals = sparse_evaluation.fold(circle_poly_alpha); + assert_eq!(folded_evals.len(), layer_query_evals.len()); + + for (layer_eval, folded_eval) in zip(&mut layer_query_evals, folded_evals) { + *layer_eval = *layer_eval * circle_poly_alpha_sq + folded_eval; + } + } + + (layer_queries, layer_query_evals) = + layer.verify_and_fold(layer_queries, layer_query_evals)?; + } + + // Check all values have been consumed. + assert!(column_bounds.is_empty()); + assert!(decommitted_values.is_empty()); + + Ok((layer_queries, layer_query_evals)) + } + + /// Verifies the last layer. + fn decommit_last_layer( + self, + queries: Queries, + query_evals: Vec, + ) -> Result<(), FriVerificationError> { + let Self { + last_layer_domain: domain, + last_layer_poly, + .. + } = self; + + for (&query, query_eval) in zip(&*queries, query_evals) { + let x = domain.at(bit_reverse_index(query, domain.log_size())); + + if query_eval != last_layer_poly.eval_at_point(x.into()) { + return Err(FriVerificationError::LastLayerEvaluationsInvalid); + } + } + + Ok(()) + } + + /// Samples queries and returns the opening positions for each unique column size. + /// + /// The order of the opening positions corresponds to the order of the column commitment. + pub fn column_query_positions( + &mut self, + channel: &mut MC::C, + ) -> BTreeMap { + let column_log_sizes = self + .column_bounds + .iter() + .dedup() + .map(|b| b.log_degree_bound + self.config.log_blowup_factor) + .collect_vec(); + let queries = Queries::generate(channel, column_log_sizes[0], self.config.n_queries); + let positions = get_opening_positions(&queries, &column_log_sizes); + self.queries = Some(queries); + positions + } +} + +/// Returns the column opening positions needed for verification. +/// +/// The column log sizes must be unique and in descending order. 
Returned +/// column opening positions are mapped by their log size. +fn get_opening_positions( + queries: &Queries, + column_log_sizes: &[u32], +) -> BTreeMap { + let mut prev_log_size = column_log_sizes[0]; + assert!(prev_log_size == queries.log_domain_size); + let mut prev_queries = queries.clone(); + let mut positions = BTreeMap::new(); + positions.insert(prev_log_size, prev_queries.opening_positions(FOLD_STEP)); + for log_size in column_log_sizes.iter().skip(1) { + let n_folds = prev_log_size - log_size; + let queries = prev_queries.fold(n_folds); + positions.insert(*log_size, queries.opening_positions(FOLD_STEP)); + prev_log_size = *log_size; + prev_queries = queries; + } + positions +} + +#[derive(Clone, Copy, Debug, Error)] +pub enum FriVerificationError { + #[error("proof contains an invalid number of FRI layers")] + InvalidNumFriLayers, + #[error("queries do not resolve to their commitment in layer {layer}")] + InnerLayerCommitmentInvalid { + layer: usize, + error: MerkleVerificationError, + }, + #[error("evaluations are invalid in layer {layer}")] + InnerLayerEvaluationsInvalid { layer: usize }, + #[error("degree of last layer is invalid")] + LastLayerDegreeInvalid, + #[error("evaluations in the last layer are invalid")] + LastLayerEvaluationsInvalid, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct CirclePolyDegreeBound { + log_degree_bound: u32, +} + +impl CirclePolyDegreeBound { + pub fn new(log_degree_bound: u32) -> Self { + Self { log_degree_bound } + } + + /// Maps a circle polynomial's degree bound to the degree bound of the univariate (line) + /// polynomial it gets folded into. 
+ fn fold_to_line(&self) -> LinePolyDegreeBound { + LinePolyDegreeBound { + log_degree_bound: self.log_degree_bound - CIRCLE_TO_LINE_FOLD_STEP, + } + } +} + +impl PartialOrd for CirclePolyDegreeBound { + fn partial_cmp(&self, other: &LinePolyDegreeBound) -> Option { + Some(self.log_degree_bound.cmp(&other.log_degree_bound)) + } +} + +impl PartialEq for CirclePolyDegreeBound { + fn eq(&self, other: &LinePolyDegreeBound) -> bool { + self.log_degree_bound == other.log_degree_bound + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +struct LinePolyDegreeBound { + log_degree_bound: u32, +} + +impl LinePolyDegreeBound { + /// Returns [None] if the unfolded degree bound is smaller than the folding factor. + fn fold(self, n_folds: u32) -> Option { + if self.log_degree_bound < n_folds { + return None; + } + + let log_degree_bound = self.log_degree_bound - n_folds; + Some(Self { log_degree_bound }) + } +} + +/// A FRI proof. +#[derive(Debug)] +pub struct FriProof { + pub inner_layers: Vec>, + pub last_layer_poly: LinePoly, +} + +/// Number of folds for univariate polynomials. +// TODO(andrew): Support different step sizes. +pub const FOLD_STEP: u32 = 1; + +/// Number of folds when folding a circle polynomial to univariate polynomial. +pub const CIRCLE_TO_LINE_FOLD_STEP: u32 = 1; + +/// Stores a subset of evaluations in a fri layer with their corresponding merkle decommitments. +/// +/// The subset corresponds to the set of evaluations needed by a FRI verifier. +#[derive(Debug)] +pub struct FriLayerProof { + /// The subset stored corresponds to the set of evaluations the verifier doesn't have but needs + /// to fold and verify the merkle decommitment. 
+ pub evals_subset: Vec, + pub decommitment: MerkleDecommitment, + pub commitment: H::Hash, +} + +struct FriLayerVerifier { + degree_bound: LinePolyDegreeBound, + domain: LineDomain, + folding_alpha: SecureField, + layer_index: usize, + proof: FriLayerProof, +} + +impl FriLayerVerifier { + /// Verifies the layer's merkle decommitment and returns the folded queries and query evals. + /// + /// # Errors + /// + /// An `Err` will be returned if: + /// * The proof doesn't store enough evaluations. + /// * The merkle decommitment is invalid. + /// + /// # Panics + /// + /// Panics if the number of queries doesn't match the number of evals. + fn verify_and_fold( + &self, + queries: Queries, + evals_at_queries: Vec, + ) -> Result<(Queries, Vec), FriVerificationError> { + let decommitment = self.proof.decommitment.clone(); + let commitment = self.proof.commitment; + + // Extract the evals needed for decommitment and folding. + let sparse_evaluation = self.extract_evaluation(&queries, &evals_at_queries)?; + + // TODO: When leaf values are removed from the decommitment, also remove this block. + let actual_decommitment_evals: SecureColumnByCoords = sparse_evaluation + .subline_evals + .iter() + .flat_map(|e| e.values.into_iter()) + .collect(); + + let folded_queries = queries.fold(FOLD_STEP); + + // Positions of all the decommitment evals. + let decommitment_positions = folded_queries + .iter() + .flat_map(|folded_query| { + let start = folded_query << FOLD_STEP; + let end = start + (1 << FOLD_STEP); + start..end + }) + .collect::>(); + + let merkle_verifier = MerkleVerifier::new( + commitment, + vec![self.domain.log_size(); SECURE_EXTENSION_DEGREE], + ); + // TODO(spapini): Propagate error. 
+ merkle_verifier + .verify( + [(self.domain.log_size(), decommitment_positions)] + .into_iter() + .collect(), + actual_decommitment_evals.columns.to_vec(), + decommitment, + ) + .map_err(|e| FriVerificationError::InnerLayerCommitmentInvalid { + layer: self.layer_index, + error: e, + })?; + + let evals_at_folded_queries = sparse_evaluation.fold(self.folding_alpha); + + Ok((folded_queries, evals_at_folded_queries)) + } + + /// Returns the evaluations needed for decommitment. + /// + /// # Errors + /// + /// Returns an `Err` if the proof doesn't store enough evaluations. + /// + /// # Panics + /// + /// Panics if the number of queries doesn't match the number of evals. + fn extract_evaluation( + &self, + queries: &Queries, + evals_at_queries: &[SecureField], + ) -> Result { + // Evals provided by the verifier. + let mut evals_at_queries = evals_at_queries.iter().copied(); + + // Evals stored in the proof. + let mut proof_evals = self.proof.evals_subset.iter().copied(); + + let mut all_subline_evals = Vec::new(); + + // Group queries by the subline they reside in. + for subline_queries in queries.group_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { + let subline_start = (subline_queries[0] >> FOLD_STEP) << FOLD_STEP; + let subline_end = subline_start + (1 << FOLD_STEP); + + let mut subline_evals = Vec::new(); + let mut subline_queries = subline_queries.iter().peekable(); + + // Insert the evals. + for eval_position in subline_start..subline_end { + let eval = match subline_queries.next_if_eq(&&eval_position) { + Some(_) => evals_at_queries.next().unwrap(), + None => proof_evals.next().ok_or( + FriVerificationError::InnerLayerEvaluationsInvalid { + layer: self.layer_index, + }, + )?, + }; + + subline_evals.push(eval); + } + + // Construct the domain. + // TODO(andrew): Create a constructor for LineDomain. 
+ let subline_initial_index = bit_reverse_index(subline_start, self.domain.log_size()); + let subline_initial = self.domain.coset().index_at(subline_initial_index); + let subline_domain = LineDomain::new(Coset::new(subline_initial, FOLD_STEP)); + + all_subline_evals.push(LineEvaluation::new( + subline_domain, + subline_evals.into_iter().collect(), + )); + } + + // Check all proof evals have been consumed. + if !proof_evals.is_empty() { + return Err(FriVerificationError::InnerLayerEvaluationsInvalid { + layer: self.layer_index, + }); + } + + Ok(SparseLineEvaluation::new(all_subline_evals)) + } +} + +/// A FRI layer comprises of a merkle tree that commits to evaluations of a polynomial. +/// +/// The polynomial evaluations are viewed as evaluation of a polynomial on multiple distinct cosets +/// of size two. Each leaf of the merkle tree commits to a single coset evaluation. +// TODO(andrew): Support different step sizes. +struct FriLayerProver, H: MerkleHasher> { + evaluation: LineEvaluation, + merkle_tree: MerkleProver, +} + +impl, H: MerkleHasher> FriLayerProver { + fn new(evaluation: LineEvaluation) -> Self { + // TODO(spapini): Commit on slice. + // TODO(spapini): Merkle tree in backend. + let merkle_tree = MerkleProver::commit(evaluation.values.columns.iter().collect_vec()); + #[allow(unreachable_code)] + FriLayerProver { + evaluation, + merkle_tree, + } + } + + /// Generates a decommitment of the subline evaluations at the specified positions. + fn decommit(self, queries: &Queries) -> FriLayerProof { + let mut decommit_positions = Vec::new(); + let mut evals_subset = Vec::new(); + + // Group queries by the subline they reside in. + // TODO(andrew): Explain what a "subline" is at the top of the module. 
+ for query_group in queries.group_by(|a, b| a >> FOLD_STEP == b >> FOLD_STEP) { + let subline_start = (query_group[0] >> FOLD_STEP) << FOLD_STEP; + let subline_end = subline_start + (1 << FOLD_STEP); + + let mut subline_queries = query_group.iter().peekable(); + + for eval_position in subline_start..subline_end { + // Add decommitment position. + decommit_positions.push(eval_position); + + // Skip evals the verifier can calculate. + if subline_queries.next_if_eq(&&eval_position).is_some() { + continue; + } + + let eval = self.evaluation.values.at(eval_position); + evals_subset.push(eval); + } + } + + let commitment = self.merkle_tree.root(); + // TODO(spapini): Use _evals. + let (_evals, decommitment) = self.merkle_tree.decommit( + [(self.evaluation.len().ilog2(), decommit_positions)] + .into_iter() + .collect(), + self.evaluation.values.columns.iter().collect_vec(), + ); + + FriLayerProof { + evals_subset, + decommitment, + commitment, + } + } +} + +/// Holds a foldable subset of circle polynomial evaluations. +#[derive(Debug, Clone)] +pub struct SparseCircleEvaluation { + subcircle_evals: Vec>, +} + +impl SparseCircleEvaluation { + /// # Panics + /// + /// Panics if the evaluation domain sizes don't equal the folding factor. 
+ pub fn new( + subcircle_evals: Vec>, + ) -> Self { + let folding_factor = 1 << CIRCLE_TO_LINE_FOLD_STEP; + assert!(subcircle_evals.iter().all(|e| e.len() == folding_factor)); + Self { subcircle_evals } + } + + fn fold(self, alpha: SecureField) -> Vec { + self.subcircle_evals + .into_iter() + .map(|e| { + let buffer_domain = LineDomain::new(e.domain.half_coset); + let mut buffer = LineEvaluation::new_zero(buffer_domain); + fold_circle_into_line( + &mut buffer, + &SecureEvaluation::new(e.domain, e.values.into_iter().collect()), + alpha, + ); + buffer.values.at(0) + }) + .collect() + } +} + +impl<'a> IntoIterator for &'a mut SparseCircleEvaluation { + type Item = &'a mut CircleEvaluation; + type IntoIter = + std::slice::IterMut<'a, CircleEvaluation>; + + fn into_iter(self) -> Self::IntoIter { + self.subcircle_evals.iter_mut() + } +} + +/// Holds a small foldable subset of univariate SecureField polynomial evaluations. +/// Evaluation is held at the CPU backend. +#[derive(Debug, Clone)] +struct SparseLineEvaluation { + subline_evals: Vec>, +} + +impl SparseLineEvaluation { + /// # Panics + /// + /// Panics if the evaluation domain sizes don't equal the folding factor. + fn new(subline_evals: Vec>) -> Self { + let folding_factor = 1 << FOLD_STEP; + assert!(subline_evals.iter().all(|e| e.len() == folding_factor)); + Self { subline_evals } + } + + fn fold(self, alpha: SecureField) -> Vec { + self.subline_evals + .into_iter() + .map(|e| fold_line(&e, alpha).values.at(0)) + .collect() + } +} + +/// Folds a degree `d` polynomial into a degree `d/2` polynomial. +/// See [`FriOps::fold_line`]. +pub fn fold_line( + eval: &LineEvaluation, + alpha: SecureField, +) -> LineEvaluation { + let n = eval.len(); + assert!(n >= 2, "Evaluation too small"); + + let domain = eval.domain(); + + let folded_values = eval + .values + .into_iter() + .array_chunks() + .enumerate() + .map(|(i, [f_x, f_neg_x])| { + // TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer. 
+ let x = domain.at(bit_reverse_index(i << FOLD_STEP, domain.log_size())); + + let (mut f0, mut f1) = (f_x, f_neg_x); + ibutterfly(&mut f0, &mut f1, x.inverse()); + f0 + alpha * f1 + }) + .collect(); + + LineEvaluation::new(domain.double(), folded_values) +} + +/// Folds and accumulates a degree `d` circle polynomial into a degree `d/2` univariate +/// polynomial. +/// See [`FriOps::fold_circle_into_line`]. +pub fn fold_circle_into_line( + dst: &mut LineEvaluation, + src: &SecureEvaluation, + alpha: SecureField, +) { + assert_eq!(src.len() >> CIRCLE_TO_LINE_FOLD_STEP, dst.len()); + + let domain = src.domain; + let alpha_sq = alpha * alpha; + + src.into_iter() + .array_chunks() + .enumerate() + .for_each(|(i, [f_p, f_neg_p])| { + // TODO(andrew): Inefficient. Update when domain twiddles get stored in a buffer. + let p = domain.at(bit_reverse_index( + i << CIRCLE_TO_LINE_FOLD_STEP, + domain.log_size(), + )); + + // Calculate `f0(px)` and `f1(px)` such that `2f(p) = f0(px) + py * f1(px)`. + let (mut f0_px, mut f1_px) = (f_p, f_neg_p); + ibutterfly(&mut f0_px, &mut f1_px, p.y.inverse()); + let f_prime = alpha * f1_px + f0_px; + + dst.values.set(i, dst.values.at(i) * alpha_sq + f_prime); + }); +} + +#[cfg(test)] +mod tests { + use std::iter::zip; + + use itertools::Itertools; + use num_traits::{One, Zero}; + + use super::{get_opening_positions, FriVerificationError, SparseCircleEvaluation}; + use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; + use crate::core::backend::{ColumnOps, CpuBackend}; + use crate::core::circle::{CirclePointIndex, Coset}; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::Field; + use crate::core::fri::{ + fold_circle_into_line, fold_line, CirclePolyDegreeBound, FriConfig, + CIRCLE_TO_LINE_FOLD_STEP, + }; + use crate::core::poly::circle::{CircleDomain, PolyOps, SecureEvaluation}; + use crate::core::poly::line::{LineDomain, LineEvaluation, LinePoly}; + use 
crate::core::poly::{BitReversedOrder, NaturalOrder}; + use crate::core::queries::{Queries, SparseSubCircleDomain}; + use crate::core::test_utils::test_channel; + use crate::core::utils::bit_reverse_index; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + + /// Default blowup factor used for tests. + const LOG_BLOWUP_FACTOR: u32 = 2; + + type FriProver = super::FriProver; + type FriVerifier = super::FriVerifier; + + #[test] + fn fold_line_works() { + const DEGREE: usize = 8; + // Coefficients are bit-reversed. + let even_coeffs: [SecureField; DEGREE / 2] = [1, 2, 1, 3] + .map(BaseField::from_u32_unchecked) + .map(SecureField::from); + let odd_coeffs: [SecureField; DEGREE / 2] = [3, 5, 4, 1] + .map(BaseField::from_u32_unchecked) + .map(SecureField::from); + let poly = LinePoly::new([even_coeffs, odd_coeffs].concat()); + let even_poly = LinePoly::new(even_coeffs.to_vec()); + let odd_poly = LinePoly::new(odd_coeffs.to_vec()); + let alpha = BaseField::from_u32_unchecked(19283).into(); + let domain = LineDomain::new(Coset::half_odds(DEGREE.ilog2())); + let drp_domain = domain.double(); + let mut values = domain + .iter() + .map(|p| poly.eval_at_point(p.into())) + .collect(); + CpuBackend::bit_reverse_column(&mut values); + let evals = LineEvaluation::new(domain, values.into_iter().collect()); + + let drp_evals = fold_line(&evals, alpha); + let mut drp_evals = drp_evals.values.into_iter().collect_vec(); + CpuBackend::bit_reverse_column(&mut drp_evals); + + assert_eq!(drp_evals.len(), DEGREE / 2); + for (i, (&drp_eval, x)) in zip(&drp_evals, drp_domain).enumerate() { + let f_e: SecureField = even_poly.eval_at_point(x.into()); + let f_o: SecureField = odd_poly.eval_at_point(x.into()); + assert_eq!(drp_eval, (f_e + alpha * f_o).double(), "mismatch at {i}"); + } + } + + #[test] + fn fold_circle_to_line_works() { + const LOG_DEGREE: u32 = 4; + let circle_evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); + let alpha = SecureField::one(); + let 
folded_domain = LineDomain::new(circle_evaluation.domain.half_coset); + + let mut folded_evaluation = LineEvaluation::new_zero(folded_domain); + fold_circle_into_line(&mut folded_evaluation, &circle_evaluation, alpha); + + assert_eq!( + log_degree_bound(folded_evaluation), + LOG_DEGREE - CIRCLE_TO_LINE_FOLD_STEP + ); + } + + #[test] + #[should_panic = "invalid degree"] + fn committing_high_degree_polynomial_fails() { + const LOG_EXPECTED_BLOWUP_FACTOR: u32 = LOG_BLOWUP_FACTOR; + const LOG_INVALID_BLOWUP_FACTOR: u32 = LOG_BLOWUP_FACTOR - 1; + let config = FriConfig::new(2, LOG_EXPECTED_BLOWUP_FACTOR, 3); + let evaluation = polynomial_evaluation(6, LOG_INVALID_BLOWUP_FACTOR); + + FriProver::commit( + &mut test_channel(), + config, + &[evaluation.clone()], + &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + } + + #[test] + #[should_panic = "not canonic"] + fn committing_evaluation_from_invalid_domain_fails() { + let invalid_domain = CircleDomain::new(Coset::new(CirclePointIndex::generator(), 3)); + assert!(!invalid_domain.is_canonic(), "must be an invalid domain"); + let evaluation = SecureEvaluation::new( + invalid_domain, + vec![SecureField::one(); 1 << 4].into_iter().collect(), + ); + + FriProver::commit( + &mut test_channel(), + FriConfig::new(2, 2, 3), + &[evaluation.clone()], + &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + } + + #[test] + fn valid_proof_passes_verification() -> Result<(), FriVerificationError> { + const LOG_DEGREE: u32 = 3; + let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); + let log_domain_size = evaluation.domain.log_size(); + let queries = Queries::from_positions(vec![5], log_domain_size); + let config = FriConfig::new(1, LOG_BLOWUP_FACTOR, queries.len()); + let decommitment_value = query_polynomial(&evaluation, &queries); + let prover = FriProver::commit( + &mut test_channel(), + config, + &[evaluation.clone()], + 
&CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + let proof = prover.decommit_on_queries(&queries); + let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; + let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); + + verifier.decommit_on_queries(&queries, vec![decommitment_value]) + } + + #[test] + fn valid_proof_with_constant_last_layer_passes_verification() -> Result<(), FriVerificationError> + { + const LOG_DEGREE: u32 = 3; + const LAST_LAYER_LOG_BOUND: u32 = 0; + let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); + let log_domain_size = evaluation.domain.log_size(); + let queries = Queries::from_positions(vec![5], log_domain_size); + let config = FriConfig::new(LAST_LAYER_LOG_BOUND, LOG_BLOWUP_FACTOR, queries.len()); + let decommitment_value = query_polynomial(&evaluation, &queries); + let prover = FriProver::commit( + &mut test_channel(), + config, + &[evaluation.clone()], + &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + let proof = prover.decommit_on_queries(&queries); + let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; + let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); + + verifier.decommit_on_queries(&queries, vec![decommitment_value]) + } + + #[test] + fn valid_mixed_degree_proof_passes_verification() -> Result<(), FriVerificationError> { + const LOG_DEGREES: [u32; 3] = [6, 5, 4]; + let evaluations = LOG_DEGREES.map(|log_d| polynomial_evaluation(log_d, LOG_BLOWUP_FACTOR)); + let log_domain_size = evaluations[0].domain.log_size(); + let queries = Queries::from_positions(vec![7, 70], log_domain_size); + let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); + let prover = FriProver::commit( + &mut test_channel(), + config, + &evaluations, + &CpuBackend::precompute_twiddles(evaluations[0].domain.half_coset), + ); + let decommitment_values = evaluations.map(|p| query_polynomial(&p, &queries)).to_vec(); 
+ let proof = prover.decommit_on_queries(&queries); + let bounds = LOG_DEGREES.map(CirclePolyDegreeBound::new).to_vec(); + let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bounds).unwrap(); + + verifier.decommit_on_queries(&queries, decommitment_values) + } + + #[test] + fn valid_mixed_degree_end_to_end_proof_passes_verification() -> Result<(), FriVerificationError> + { + const LOG_DEGREES: [u32; 3] = [6, 5, 4]; + let evaluations = LOG_DEGREES.map(|log_d| polynomial_evaluation(log_d, LOG_BLOWUP_FACTOR)); + let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, 3); + let prover = FriProver::commit( + &mut test_channel(), + config, + &evaluations, + &CpuBackend::precompute_twiddles(evaluations[0].domain.half_coset), + ); + let (proof, prover_opening_positions) = prover.decommit(&mut test_channel()); + let decommitment_values = zip(&evaluations, prover_opening_positions.values().rev()) + .map(|(poly, positions)| open_polynomial(poly, positions)) + .collect(); + let bounds = LOG_DEGREES.map(CirclePolyDegreeBound::new).to_vec(); + + let mut verifier = FriVerifier::commit(&mut test_channel(), config, proof, bounds).unwrap(); + let verifier_opening_positions = verifier.column_query_positions(&mut test_channel()); + + assert_eq!(prover_opening_positions, verifier_opening_positions); + verifier.decommit(decommitment_values) + } + + #[test] + fn proof_with_removed_layer_fails_verification() { + const LOG_DEGREE: u32 = 6; + let evaluation = polynomial_evaluation(6, LOG_BLOWUP_FACTOR); + let log_domain_size = evaluation.domain.log_size(); + let queries = Queries::from_positions(vec![1], log_domain_size); + let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); + let prover = FriProver::commit( + &mut test_channel(), + config, + &[evaluation.clone()], + &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + let proof = prover.decommit_on_queries(&queries); + let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; + // Set verifier's 
config to expect one extra layer than prover config. + let mut invalid_config = config; + invalid_config.log_last_layer_degree_bound -= 1; + + let verifier = FriVerifier::commit(&mut test_channel(), invalid_config, proof, bound); + + assert!(matches!( + verifier, + Err(FriVerificationError::InvalidNumFriLayers) + )); + } + + #[test] + fn proof_with_added_layer_fails_verification() { + const LOG_DEGREE: u32 = 6; + let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); + let log_domain_size = evaluation.domain.log_size(); + let queries = Queries::from_positions(vec![1], log_domain_size); + let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); + let prover = FriProver::commit( + &mut test_channel(), + config, + &[evaluation.clone()], + &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + let proof = prover.decommit_on_queries(&queries); + let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; + // Set verifier's config to expect one less layer than prover config. 
+ let mut invalid_config = config; + invalid_config.log_last_layer_degree_bound += 1; + + let verifier = FriVerifier::commit(&mut test_channel(), invalid_config, proof, bound); + + assert!(matches!( + verifier, + Err(FriVerificationError::InvalidNumFriLayers) + )); + } + + #[test] + fn proof_with_invalid_inner_layer_evaluation_fails_verification() { + const LOG_DEGREE: u32 = 6; + let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); + let log_domain_size = evaluation.domain.log_size(); + let queries = Queries::from_positions(vec![5], log_domain_size); + let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); + let decommitment_value = query_polynomial(&evaluation, &queries); + let prover = FriProver::commit( + &mut test_channel(), + config, + &[evaluation.clone()], + &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; + let mut proof = prover.decommit_on_queries(&queries); + // Remove an evaluation from the second layer's proof. 
+ proof.inner_layers[1].evals_subset.pop(); + let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); + + let verification_result = verifier.decommit_on_queries(&queries, vec![decommitment_value]); + + assert!(matches!( + verification_result, + Err(FriVerificationError::InnerLayerEvaluationsInvalid { layer: 1 }) + )); + } + + #[test] + fn proof_with_invalid_inner_layer_decommitment_fails_verification() { + const LOG_DEGREE: u32 = 6; + let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); + let log_domain_size = evaluation.domain.log_size(); + let queries = Queries::from_positions(vec![5], log_domain_size); + let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); + let decommitment_value = query_polynomial(&evaluation, &queries); + let prover = FriProver::commit( + &mut test_channel(), + config, + &[evaluation.clone()], + &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; + let mut proof = prover.decommit_on_queries(&queries); + // Modify the committed values in the second layer. + proof.inner_layers[1].evals_subset[0] += BaseField::one(); + let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); + + let verification_result = verifier.decommit_on_queries(&queries, vec![decommitment_value]); + + assert!(matches!( + verification_result, + Err(FriVerificationError::InnerLayerCommitmentInvalid { layer: 1, .. 
}) + )); + } + + #[test] + fn proof_with_invalid_last_layer_degree_fails_verification() { + const LOG_DEGREE: u32 = 6; + const LOG_MAX_LAST_LAYER_DEGREE: u32 = 2; + let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); + let log_domain_size = evaluation.domain.log_size(); + let queries = Queries::from_positions(vec![1, 7, 8], log_domain_size); + let config = FriConfig::new(LOG_MAX_LAST_LAYER_DEGREE, LOG_BLOWUP_FACTOR, queries.len()); + let prover = FriProver::commit( + &mut test_channel(), + config, + &[evaluation.clone()], + &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; + let mut proof = prover.decommit_on_queries(&queries); + let bad_last_layer_coeffs = vec![One::one(); 1 << (LOG_MAX_LAST_LAYER_DEGREE + 1)]; + proof.last_layer_poly = LinePoly::new(bad_last_layer_coeffs); + + let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound); + + assert!(matches!( + verifier, + Err(FriVerificationError::LastLayerDegreeInvalid) + )); + } + + #[test] + fn proof_with_invalid_last_layer_fails_verification() { + const LOG_DEGREE: u32 = 6; + let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); + let log_domain_size = evaluation.domain.log_size(); + let queries = Queries::from_positions(vec![1, 7, 8], log_domain_size); + let config = FriConfig::new(2, LOG_BLOWUP_FACTOR, queries.len()); + let decommitment_value = query_polynomial(&evaluation, &queries); + let prover = FriProver::commit( + &mut test_channel(), + config, + &[evaluation.clone()], + &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; + let mut proof = prover.decommit_on_queries(&queries); + // Compromise the last layer polynomial's first coefficient. 
+ proof.last_layer_poly[0] += BaseField::one(); + let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); + + let verification_result = verifier.decommit_on_queries(&queries, vec![decommitment_value]); + + assert!(matches!( + verification_result, + Err(FriVerificationError::LastLayerEvaluationsInvalid) + )); + } + + #[test] + #[should_panic] + fn decommit_queries_on_invalid_domain_fails_verification() { + const LOG_DEGREE: u32 = 3; + let evaluation = polynomial_evaluation(LOG_DEGREE, LOG_BLOWUP_FACTOR); + let log_domain_size = evaluation.domain.log_size(); + let queries = Queries::from_positions(vec![5], log_domain_size); + let config = FriConfig::new(1, LOG_BLOWUP_FACTOR, queries.len()); + let decommitment_value = query_polynomial(&evaluation, &queries); + let prover = FriProver::commit( + &mut test_channel(), + config, + &[evaluation.clone()], + &CpuBackend::precompute_twiddles(evaluation.domain.half_coset), + ); + let proof = prover.decommit_on_queries(&queries); + let bound = vec![CirclePolyDegreeBound::new(LOG_DEGREE)]; + let verifier = FriVerifier::commit(&mut test_channel(), config, proof, bound).unwrap(); + // Simulate the verifier sampling queries on a smaller domain. + let mut invalid_queries = queries.clone(); + invalid_queries.log_domain_size -= 1; + + let _ = verifier.decommit_on_queries(&invalid_queries, vec![decommitment_value]); + } + + /// Returns an evaluation of a random polynomial with degree `2^log_degree`. + /// + /// The evaluation domain size is `2^(log_degree + log_blowup_factor)`. 
+ fn polynomial_evaluation( + log_degree: u32, + log_blowup_factor: u32, + ) -> SecureEvaluation { + let poly = CpuCirclePoly::new(vec![BaseField::one(); 1 << log_degree]); + let coset = Coset::half_odds(log_degree + log_blowup_factor - 1); + let domain = CircleDomain::new(coset); + let values = poly.evaluate(domain); + SecureEvaluation::new(domain, values.into_iter().map(SecureField::from).collect()) + } + + /// Returns the log degree bound of a polynomial. + fn log_degree_bound(polynomial: LineEvaluation) -> u32 { + let coeffs = polynomial.interpolate().into_ordered_coefficients(); + let degree = coeffs.into_iter().rposition(|c| !c.is_zero()).unwrap_or(0); + (degree + 1).ilog2() + } + + // TODO: Remove after SubcircleDomain integration. + fn query_polynomial( + polynomial: &SecureEvaluation, + queries: &Queries, + ) -> SparseCircleEvaluation { + let polynomial_log_size = polynomial.domain.log_size(); + let positions = + get_opening_positions(queries, &[queries.log_domain_size, polynomial_log_size]); + open_polynomial(polynomial, &positions[&polynomial_log_size]) + } + + fn open_polynomial( + polynomial: &SecureEvaluation, + positions: &SparseSubCircleDomain, + ) -> SparseCircleEvaluation { + let coset_evals = positions + .iter() + .map(|position| { + let coset_domain = position.to_circle_domain(&polynomial.domain); + let evals = coset_domain + .iter_indices() + .map(|p| { + polynomial.at(bit_reverse_index( + polynomial.domain.find(p).unwrap(), + polynomial.domain.log_size(), + )) + }) + .collect(); + let coset_eval = + CpuCircleEvaluation::::new(coset_domain, evals); + coset_eval.bit_reverse() + }) + .collect(); + + SparseCircleEvaluation::new(coset_evals) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/gkr_prover.rs b/Stwo_wrapper/crates/prover/src/core/lookups/gkr_prover.rs new file mode 100644 index 0000000..6e6ed25 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/lookups/gkr_prover.rs @@ -0,0 +1,566 @@ +//! 
GKR batch prover for Grand Product and LogUp lookup arguments. +use std::borrow::Cow; +use std::iter::{successors, zip}; +use std::ops::Deref; + +use educe::Educe; +use itertools::Itertools; +use num_traits::{One, Zero}; +use thiserror::Error; + +use super::gkr_verifier::{GkrArtifact, GkrBatchProof, GkrMask}; +use super::mle::{Mle, MleOps}; +use super::sumcheck::MultivariatePolyOracle; +use super::utils::{eq, random_linear_combination, UnivariatePoly}; +use crate::core::backend::{Col, Column, ColumnOps, CpuBackend}; +use crate::core::channel::Channel; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::{Field, FieldExpOps}; +use crate::core::lookups::sumcheck; + +pub trait GkrOps: MleOps + MleOps { + /// Returns evaluations `eq(x, y) * v` for all `x` in `{0, 1}^n`. + /// + /// Note [`Mle`] stores values in bit-reversed order. + /// + /// [`eq(x, y)`]: crate::core::lookups::utils::eq + fn gen_eq_evals(y: &[SecureField], v: SecureField) -> Mle; + + /// Generates the next GKR layer from the current one. + fn next_layer(layer: &Layer) -> Layer; + + /// Returns univariate polynomial `f(t) = sum_x h(t, x)` for all `x` in the boolean hypercube. + /// + /// `claim` equals `f(0) + f(1)`. + /// + /// For more context see docs of [`MultivariatePolyOracle::sum_as_poly_in_first_variable()`]. + fn sum_as_poly_in_first_variable( + h: &GkrMultivariatePolyOracle<'_, Self>, + claim: SecureField, + ) -> UnivariatePoly; +} + +/// Stores evaluations of [`eq(x, y)`] on all boolean hypercube points of the form +/// `x = (0, x_1, ..., x_{n-1})`. +/// +/// Evaluations are stored in bit-reversed order i.e. `evals[0] = eq((0, ..., 0, 0), y)`, +/// `evals[1] = eq((0, ..., 0, 1), y)`, etc. 
+/// +/// [`eq(x, y)`]: crate::core::lookups::utils::eq +#[derive(Educe)] +#[educe(Debug, Clone)] +pub struct EqEvals> { + y: Vec, + evals: Mle, +} + +impl EqEvals { + pub fn generate(y: &[SecureField]) -> Self { + let y = y.to_vec(); + + if y.is_empty() { + let evals = Mle::new([SecureField::one()].into_iter().collect()); + return Self { evals, y }; + } + + let evals = B::gen_eq_evals(&y[1..], eq(&[SecureField::zero()], &[y[0]])); + assert_eq!(evals.len(), 1 << (y.len() - 1)); + Self { evals, y } + } + + /// Returns the fixed vector `y` used to generate the evaluations. + pub fn y(&self) -> &[SecureField] { + &self.y + } +} + +impl> Deref for EqEvals { + type Target = Col; + + fn deref(&self) -> &Col { + &self.evals + } +} + +/// Represents a layer in a binary tree structured GKR circuit. +/// +/// Layers can contain multiple columns, for example [LogUp] which has separate columns for +/// numerators and denominators. +/// +/// [LogUp]: https://eprint.iacr.org/2023/1284.pdf +#[derive(Educe)] +#[educe(Debug, Clone)] +pub enum Layer { + GrandProduct(Mle), + LogUpGeneric { + numerators: Mle, + denominators: Mle, + }, + LogUpMultiplicities { + numerators: Mle, + denominators: Mle, + }, + /// All numerators implicitly equal "1". + LogUpSingles { + denominators: Mle, + }, +} + +impl Layer { + /// Returns the number of variables used to interpolate the layer's gate values. + pub fn n_variables(&self) -> usize { + match self { + Self::GrandProduct(mle) + | Self::LogUpSingles { denominators: mle } + | Self::LogUpMultiplicities { + denominators: mle, .. + } + | Self::LogUpGeneric { + denominators: mle, .. + } => mle.n_variables(), + } + } + + fn is_output_layer(&self) -> bool { + self.n_variables() == 0 + } + + /// Produces the next layer from the current layer. + /// + /// The next layer is strictly half the size of the current layer. + /// Returns [`None`] if called on an output layer. 
+ pub fn next_layer(&self) -> Option { + if self.is_output_layer() { + return None; + } + + Some(B::next_layer(self)) + } + + /// Returns each column output if the layer is an output layer, otherwise returns an `Err`. + fn try_into_output_layer_values(self) -> Result, NotOutputLayerError> { + if !self.is_output_layer() { + return Err(NotOutputLayerError); + } + + Ok(match self { + Layer::LogUpSingles { denominators } => { + let numerator = SecureField::one(); + let denominator = denominators.at(0); + vec![numerator, denominator] + } + Layer::LogUpMultiplicities { + numerators, + denominators, + } => { + let numerator = numerators.at(0).into(); + let denominator = denominators.at(0); + vec![numerator, denominator] + } + Layer::LogUpGeneric { + numerators, + denominators, + } => { + let numerator = numerators.at(0); + let denominator = denominators.at(0); + vec![numerator, denominator] + } + Layer::GrandProduct(col) => { + vec![col.at(0)] + } + }) + } + + /// Returns a transformed layer with the first variable of each column fixed to `assignment`. + fn fix_first_variable(self, x0: SecureField) -> Self { + if self.n_variables() == 0 { + return self; + } + + match self { + Self::GrandProduct(mle) => Self::GrandProduct(mle.fix_first_variable(x0)), + Self::LogUpGeneric { + numerators, + denominators, + } => Self::LogUpGeneric { + numerators: numerators.fix_first_variable(x0), + denominators: denominators.fix_first_variable(x0), + }, + Self::LogUpMultiplicities { + numerators, + denominators, + } => Self::LogUpGeneric { + numerators: numerators.fix_first_variable(x0), + denominators: denominators.fix_first_variable(x0), + }, + Self::LogUpSingles { denominators } => Self::LogUpSingles { + denominators: denominators.fix_first_variable(x0), + }, + } + } + + /// Represents the next GKR layer evaluation as a multivariate polynomial which uses this GKR + /// layer as input. 
+ /// + /// Layers can contain multiple columns `c_0, ..., c_{n-1}` with multivariate polynomial `g_i` + /// representing[^note] `c_i` in the next layer. These polynomials must be combined with + /// `lambda` into a single polynomial `h = g_0 + lambda * g_1 + ... + lambda^(n-1) * + /// g_{n-1}`. The oracle for `h` should be returned. + /// + /// # Optimization: precomputed [`eq(x, y)`] evals + /// + /// Let `y` be a fixed vector of length `m` and let `z` be a subvector comprising of the + /// last `k` elements of `y`. `h(x)` **must** equal some multivariate polynomial of the form + /// `eq(x, z) * p(x)`. A common operation will be computing the univariate polynomial `f(t) = + /// sum_x h(t, x)` for `x` in the boolean hypercube `{0, 1}^(k-1)`. + /// + /// `eq_evals` stores evaluations of `eq((0, x), y)` for `x` in a potentially extended boolean + /// hypercube `{0, 1}^{m-1}`. These evaluations, on the extended hypercube, can be used directly + /// in computing the sums of `h(x)`, however a correction factor must be applied to the final + /// sum which is handled by [`correct_sum_as_poly_in_first_variable()`] in `O(m)`. + /// + /// Being able to compute sums of `h(x)` using `eq_evals` in this way leads to a more efficient + /// implementation because the prover only has to generate `eq_evals` once for an entire batch + /// of multiple GKR layer instances. + /// + /// [`eq(x, y)`]: crate::core::lookups::utils::eq + /// [^note]: By "representing" we mean `g_i` agrees with the next layer's `c_i` on the boolean + /// hypercube that interpolates `c_i`. + fn into_multivariate_poly( + self, + lambda: SecureField, + eq_evals: &EqEvals, + ) -> GkrMultivariatePolyOracle<'_, B> { + GkrMultivariatePolyOracle { + eq_evals: Cow::Borrowed(eq_evals), + input_layer: self, + eq_fixed_var_correction: SecureField::one(), + lambda, + } + } + + /// Returns a copy of this layer with the [`CpuBackend`]. 
+ /// + /// This operation is expensive but can be useful for small traces that are difficult to handle + /// depending on the backend. For example, the SIMD backend offloads to the CPU backend when + /// trace length becomes smaller than the SIMD lane count. + pub fn to_cpu(&self) -> Layer { + match self { + Layer::GrandProduct(mle) => Layer::GrandProduct(Mle::new(mle.to_cpu())), + Layer::LogUpGeneric { + numerators, + denominators, + } => Layer::LogUpGeneric { + numerators: Mle::new(numerators.to_cpu()), + denominators: Mle::new(denominators.to_cpu()), + }, + Layer::LogUpMultiplicities { + numerators, + denominators, + } => Layer::LogUpMultiplicities { + numerators: Mle::new(numerators.to_cpu()), + denominators: Mle::new(denominators.to_cpu()), + }, + Layer::LogUpSingles { denominators } => Layer::LogUpSingles { + denominators: Mle::new(denominators.to_cpu()), + }, + } + } +} + +#[derive(Debug)] +struct NotOutputLayerError; + +/// Multivariate polynomial `P` that expresses the relation between two consecutive GKR layers. +/// +/// When the input layer is [`Layer::GrandProduct`] (represented by multilinear column `inp`) +/// the polynomial represents: +/// +/// ```text +/// P(x) = eq(x, y) * inp(x, 0) * inp(x, 1) +/// ``` +/// +/// When the input layer is LogUp (represented by multilinear columns `inp_numer` and +/// `inp_denom`) the polynomial represents: +/// +/// ```text +/// numer(x) = inp_numer(x, 0) * inp_denom(x, 1) + inp_numer(x, 1) * inp_denom(x, 0) +/// denom(x) = inp_denom(x, 0) * inp_denom(x, 1) +/// +/// P(x) = eq(x, y) * (numer(x) + lambda * denom(x)) +/// ``` +pub struct GkrMultivariatePolyOracle<'a, B: GkrOps> { + /// `eq_evals` passed by `Layer::into_multivariate_poly()`. + pub eq_evals: Cow<'a, EqEvals>, + pub input_layer: Layer, + pub eq_fixed_var_correction: SecureField, + /// Used by LogUp to perform a random linear combination of the numerators and denominators. 
+ pub lambda: SecureField, +} + +impl<'a, B: GkrOps> MultivariatePolyOracle for GkrMultivariatePolyOracle<'a, B> { + fn n_variables(&self) -> usize { + self.input_layer.n_variables() - 1 + } + + fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly { + B::sum_as_poly_in_first_variable(self, claim) + } + + fn fix_first_variable(self, challenge: SecureField) -> Self { + if self.is_constant() { + return self; + } + + let z0 = self.eq_evals.y()[self.eq_evals.y().len() - self.n_variables()]; + let eq_fixed_var_correction = self.eq_fixed_var_correction * eq(&[challenge], &[z0]); + + Self { + eq_evals: self.eq_evals, + eq_fixed_var_correction, + input_layer: self.input_layer.fix_first_variable(challenge), + lambda: self.lambda, + } + } +} + +impl<'a, B: GkrOps> GkrMultivariatePolyOracle<'a, B> { + fn is_constant(&self) -> bool { + self.n_variables() == 0 + } + + /// Returns all input layer columns restricted to a line. + /// + /// Let `l` be the line satisfying `l(0) = b*` and `l(1) = c*`. Oracles that represent constants + /// are expressed by values `c_i(b*)` and `c_i(c*)` where `c_i` represents the input GKR layer's + /// `i`th column (for binary tree GKR `b* = (r, 0)`, `c* = (r, 1)`). + /// + /// If this oracle represents a constant, then each `c_i` restricted to `l` is returned. + /// Otherwise, an [`Err`] is returned. + /// + /// For more context see page 64. + fn try_into_mask(self) -> Result { + if !self.is_constant() { + return Err(NotConstantPolyError); + } + + let columns = match self.input_layer { + Layer::GrandProduct(mle) => vec![mle.to_cpu().try_into().unwrap()], + Layer::LogUpGeneric { + numerators, + denominators, + } => { + let numerators = numerators.to_cpu().try_into().unwrap(); + let denominators = denominators.to_cpu().try_into().unwrap(); + vec![numerators, denominators] + } + // Should never get called. + Layer::LogUpMultiplicities { .. 
} => unimplemented!(), + Layer::LogUpSingles { denominators } => { + let numerators = [SecureField::one(); 2]; + let denominators = denominators.to_cpu().try_into().unwrap(); + vec![numerators, denominators] + } + }; + + Ok(GkrMask::new(columns)) + } + + /// Returns a copy of this oracle with the [`CpuBackend`]. + /// + /// This operation is expensive but can be useful for small oracles that are difficult to handle + /// depending on the backend. For example, the SIMD backend offloads to the CPU backend when + /// trace length becomes smaller than the SIMD lane count. + pub fn to_cpu(&self) -> GkrMultivariatePolyOracle<'a, CpuBackend> { + // TODO(andrew): This block is not ideal. + let n_eq_evals = 1 << (self.n_variables() - 1); + let eq_evals = Cow::Owned(EqEvals { + evals: Mle::new((0..n_eq_evals).map(|i| self.eq_evals.at(i)).collect()), + y: self.eq_evals.y.to_vec(), + }); + + GkrMultivariatePolyOracle { + eq_evals, + eq_fixed_var_correction: self.eq_fixed_var_correction, + input_layer: self.input_layer.to_cpu(), + lambda: self.lambda, + } + } +} + +/// Error returned when a polynomial is expected to be constant but it is not. +#[derive(Debug, Error)] +#[error("polynomial is not constant")] +pub struct NotConstantPolyError; + +/// Batch proves lookup circuits with GKR. +/// +/// The input layers should be committed to the channel before calling this function. +// GKR algorithm: (page 64) +pub fn prove_batch( + channel: &mut impl Channel, + input_layer_by_instance: Vec>, +) -> (GkrBatchProof, GkrArtifact) { + let n_instances = input_layer_by_instance.len(); + let n_layers_by_instance = input_layer_by_instance + .iter() + .map(|l| l.n_variables()) + .collect_vec(); + let n_layers = *n_layers_by_instance.iter().max().unwrap(); + + // Evaluate all instance circuits and collect the layer values. 
    // `gen_layers` returns input-to-output order; reversing gives output-first iteration, which
    // matches the verifier walking from the output layer down.
    let mut layers_by_instance = input_layer_by_instance
        .into_iter()
        .map(|input_layer| gen_layers(input_layer).into_iter().rev())
        .collect_vec();

    let mut output_claims_by_instance = vec![None; n_instances];
    let mut layer_masks_by_instance = (0..n_instances).map(|_| Vec::new()).collect_vec();
    let mut sumcheck_proofs = Vec::new();

    let mut ood_point = Vec::new();
    let mut claims_to_verify_by_instance = vec![None; n_instances];

    // Walk the batch layer-by-layer from the output layer towards the input layer. Instances with
    // fewer layers join the batch once `n_remaining_layers` matches their layer count.
    for layer in 0..n_layers {
        let n_remaining_layers = n_layers - layer;

        // Check all the instances for output layers.
        for (instance, layers) in layers_by_instance.iter_mut().enumerate() {
            if n_layers_by_instance[instance] == n_remaining_layers {
                let output_layer = layers.next().unwrap();
                let output_layer_values = output_layer.try_into_output_layer_values().unwrap();
                claims_to_verify_by_instance[instance] = Some(output_layer_values.clone());
                output_claims_by_instance[instance] = Some(output_layer_values);
            }
        }

        // Seed the channel with layer claims.
        for claims_to_verify in claims_to_verify_by_instance.iter().flatten() {
            channel.mix_felts(claims_to_verify);
        }

        // Shared by every oracle built for this layer; challenges are drawn after the claims were
        // mixed so the transcript binds them.
        let eq_evals = EqEvals::generate(&ood_point);
        let sumcheck_alpha = channel.draw_felt();
        let instance_lambda = channel.draw_felt();

        let mut sumcheck_oracles = Vec::new();
        let mut sumcheck_claims = Vec::new();
        let mut sumcheck_instances = Vec::new();

        // Create the multivariate polynomial oracles used with sumcheck.
        for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() {
            if let Some(claims_to_verify) = claims_to_verify {
                let layer = layers_by_instance[instance].next().unwrap();
                sumcheck_oracles.push(layer.into_multivariate_poly(instance_lambda, &eq_evals));
                sumcheck_claims.push(random_linear_combination(claims_to_verify, instance_lambda));
                sumcheck_instances.push(instance);
            }
        }

        let (sumcheck_proof, sumcheck_ood_point, constant_poly_oracles, _) =
            sumcheck::prove_batch(sumcheck_claims, sumcheck_oracles, sumcheck_alpha, channel);

        sumcheck_proofs.push(sumcheck_proof);

        let masks = constant_poly_oracles
            .into_iter()
            .map(|oracle| oracle.try_into_mask().unwrap())
            .collect_vec();

        // Seed the channel with the layer masks.
        for (&instance, mask) in zip(&sumcheck_instances, &masks) {
            channel.mix_felts(mask.columns().flatten());
            layer_masks_by_instance[instance].push(mask.clone());
        }

        let challenge = channel.draw_felt();
        ood_point = sumcheck_ood_point;
        ood_point.push(challenge);

        // Set the claims to prove in the layer above.
        for (instance, mask) in zip(sumcheck_instances, masks) {
            claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge));
        }
    }

    // Every instance must have produced its output layer by now, so the unwraps are safe.
    let output_claims_by_instance = output_claims_by_instance
        .into_iter()
        .map(Option::unwrap)
        .collect();

    let claims_to_verify_by_instance = claims_to_verify_by_instance
        .into_iter()
        .map(Option::unwrap)
        .collect();

    let proof = GkrBatchProof {
        sumcheck_proofs,
        layer_masks_by_instance,
        output_claims_by_instance,
    };

    let artifact = GkrArtifact {
        ood_point,
        claims_to_verify_by_instance,
        n_variables_by_instance: n_layers_by_instance,
    };

    (proof, artifact)
}

/// Executes the GKR circuit on the input layer and returns all the circuit's layers.
+fn gen_layers(input_layer: Layer) -> Vec> { + let n_variables = input_layer.n_variables(); + let layers = successors(Some(input_layer), |layer| layer.next_layer()).collect_vec(); + assert_eq!(layers.len(), n_variables + 1); + layers +} + +/// Computes `r(t) = sum_x eq((t, x), y[-k:]) * p(t, x)` from evaluations of +/// `f(t) = sum_x eq(({0}^(n - k), 0, x), y) * p(t, x)`. +/// +/// Note `claim` must equal `r(0) + r(1)` and `r` must have degree <= 3. +/// +/// For more context see `Layer::into_multivariate_poly()` docs. +/// See also (section 3.2). +pub fn correct_sum_as_poly_in_first_variable( + f_at_0: SecureField, + f_at_2: SecureField, + claim: SecureField, + y: &[SecureField], + k: usize, +) -> UnivariatePoly { + assert_ne!(k, 0); + let n = y.len(); + assert!(k <= n); + + // We evaluated `f(0)` and `f(2)` - the inputs. + // We want to compute `r(t) = f(t) * eq(t, y[n - k]) / eq(0, y[:n - k + 1])`. + let a_const = eq(&vec![SecureField::zero(); n - k + 1], &y[..n - k + 1]).inverse(); + + // Find the additional root of `r(t)`, by finding the root of `eq(t, y[n - k])`: + // 0 = eq(t, y[n - k]) + // = t * y[n - k] + (1 - t)(1 - y[n - k]) + // = 1 - y[n - k] - t(1 - 2 * y[n - k]) + // => t = (1 - y[n - k]) / (1 - 2 * y[n - k]) + // = b + let b_const = (SecureField::one() - y[n - k]) / (SecureField::one() - y[n - k].double()); + + // We get that `r(t) = f(t) * eq(t, y[n - k]) * a`. + let r_at_0 = f_at_0 * eq(&[SecureField::zero()], &[y[n - k]]) * a_const; + let r_at_1 = claim - r_at_0; + let r_at_2 = f_at_2 * eq(&[BaseField::from(2).into()], &[y[n - k]]) * a_const; + let r_at_b = SecureField::zero(); + + // Interpolate. 
+ UnivariatePoly::interpolate_lagrange( + &[ + SecureField::zero(), + SecureField::one(), + SecureField::from(BaseField::from(2)), + b_const, + ], + &[r_at_0, r_at_1, r_at_2, r_at_b], + ) +} diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/gkr_verifier.rs b/Stwo_wrapper/crates/prover/src/core/lookups/gkr_verifier.rs new file mode 100644 index 0000000..b65ceb1 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/lookups/gkr_verifier.rs @@ -0,0 +1,357 @@ +//! GKR batch verifier for Grand Product and LogUp lookup arguments. +use thiserror::Error; + +use super::sumcheck::{SumcheckError, SumcheckProof}; +use super::utils::{eq, fold_mle_evals, random_linear_combination}; +use crate::core::channel::Channel; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::sumcheck; +use crate::core::lookups::utils::Fraction; + +/// Partially verifies a batch GKR proof. +/// +/// On successful verification the function returns a [`GkrArtifact`] which stores the out-of-domain +/// point and claimed evaluations in the input layer columns for each instance at the OOD point. +/// These claimed evaluations are not checked in this function - hence partial verification. 
+pub fn partially_verify_batch( + gate_by_instance: Vec, + proof: &GkrBatchProof, + channel: &mut impl Channel, +) -> Result { + let GkrBatchProof { + sumcheck_proofs, + layer_masks_by_instance, + output_claims_by_instance, + } = proof; + + if layer_masks_by_instance.len() != output_claims_by_instance.len() { + return Err(GkrError::MalformedProof); + } + + let n_instances = layer_masks_by_instance.len(); + let instance_n_layers = |instance: usize| layer_masks_by_instance[instance].len(); + let n_layers = (0..n_instances).map(instance_n_layers).max().unwrap(); + + if n_layers != sumcheck_proofs.len() { + return Err(GkrError::MalformedProof); + } + + if gate_by_instance.len() != n_instances { + return Err(GkrError::NumInstancesMismatch { + given: gate_by_instance.len(), + proof: n_instances, + }); + } + + let mut ood_point = vec![]; + let mut claims_to_verify_by_instance = vec![None; n_instances]; + + for (layer, sumcheck_proof) in sumcheck_proofs.iter().enumerate() { + let n_remaining_layers = n_layers - layer; + + // Check for output layers. + for instance in 0..n_instances { + if instance_n_layers(instance) == n_remaining_layers { + let output_claims = output_claims_by_instance[instance].clone(); + claims_to_verify_by_instance[instance] = Some(output_claims); + } + } + + // Seed the channel with layer claims. + for claims_to_verify in claims_to_verify_by_instance.iter().flatten() { + channel.mix_felts(claims_to_verify); + } + + let sumcheck_alpha = channel.draw_felt(); + let instance_lambda = channel.draw_felt(); + + let mut sumcheck_claims = Vec::new(); + let mut sumcheck_instances = Vec::new(); + + // Prepare the sumcheck claim. 
+ for (instance, claims_to_verify) in claims_to_verify_by_instance.iter().enumerate() { + if let Some(claims_to_verify) = claims_to_verify { + let n_unused_variables = n_layers - instance_n_layers(instance); + let doubling_factor = BaseField::from(1 << n_unused_variables); + let claim = + random_linear_combination(claims_to_verify, instance_lambda) * doubling_factor; + sumcheck_claims.push(claim); + sumcheck_instances.push(instance); + } + } + + let sumcheck_claim = random_linear_combination(&sumcheck_claims, sumcheck_alpha); + let (sumcheck_ood_point, sumcheck_eval) = + sumcheck::partially_verify(sumcheck_claim, sumcheck_proof, channel) + .map_err(|source| GkrError::InvalidSumcheck { layer, source })?; + + let mut layer_evals = Vec::new(); + + // Evaluate the circuit locally at sumcheck OOD point. + for &instance in &sumcheck_instances { + let n_unused = n_layers - instance_n_layers(instance); + let mask = &layer_masks_by_instance[instance][layer - n_unused]; + let gate = &gate_by_instance[instance]; + let gate_output = gate.eval(mask).map_err(|InvalidNumMaskColumnsError| { + let instance_layer = instance_n_layers(layer) - n_remaining_layers; + GkrError::InvalidMask { + instance, + instance_layer, + } + })?; + // TODO: Consider simplifying the code by just using the same eq eval for all instances + // regardless of size. + let eq_eval = eq(&ood_point[n_unused..], &sumcheck_ood_point[n_unused..]); + layer_evals.push(eq_eval * random_linear_combination(&gate_output, instance_lambda)); + } + + let layer_eval = random_linear_combination(&layer_evals, sumcheck_alpha); + + if sumcheck_eval != layer_eval { + return Err(GkrError::CircuitCheckFailure { + claim: sumcheck_eval, + output: layer_eval, + layer, + }); + } + + // Seed the channel with the layer masks. 
+ for &instance in &sumcheck_instances { + let n_unused = n_layers - instance_n_layers(instance); + let mask = &layer_masks_by_instance[instance][layer - n_unused]; + channel.mix_felts(mask.columns().flatten()); + } + + // Set the OOD evaluation point for layer above. + let challenge = channel.draw_felt(); + ood_point = sumcheck_ood_point; + ood_point.push(challenge); + + // Set the claims to verify in the layer above. + for instance in sumcheck_instances { + let n_unused = n_layers - instance_n_layers(instance); + let mask = &layer_masks_by_instance[instance][layer - n_unused]; + claims_to_verify_by_instance[instance] = Some(mask.reduce_at_point(challenge)); + } + } + + let claims_to_verify_by_instance = claims_to_verify_by_instance + .into_iter() + .map(Option::unwrap) + .collect(); + + Ok(GkrArtifact { + ood_point, + claims_to_verify_by_instance, + n_variables_by_instance: (0..n_instances).map(instance_n_layers).collect(), + }) +} + +/// Batch GKR proof. +pub struct GkrBatchProof { + /// Sum-check proof for each layer. + pub sumcheck_proofs: Vec, + /// Mask for each layer for each instance. + pub layer_masks_by_instance: Vec>, + /// Column circuit outputs for each instance. + pub output_claims_by_instance: Vec>, +} + +/// Values of interest obtained from the execution of the GKR protocol. +pub struct GkrArtifact { + /// Out-of-domain (OOD) point for evaluating columns in the input layer. + pub ood_point: Vec, + /// The claimed evaluation at `ood_point` for each column in the input layer of each instance. + pub claims_to_verify_by_instance: Vec>, + /// The number of variables that interpolate the input layer of each instance. + pub n_variables_by_instance: Vec, +} + +/// Defines how a circuit operates locally on two input rows to produce a single output row. +/// This local 2-to-1 constraint is what gives the whole circuit its "binary tree" structure. 
+/// +/// Binary tree structured circuits have a highly regular wiring pattern that fit the structure of +/// the circuits defined in [Thaler13] which allow for efficient linear time (linear in size of the +/// circuit) GKR prover implementations. +/// +/// [Thaler13]: https://eprint.iacr.org/2013/351.pdf +#[derive(Debug, Clone, Copy)] +pub enum Gate { + LogUp, + GrandProduct, +} + +impl Gate { + /// Returns the output after applying the gate to the mask. + fn eval(&self, mask: &GkrMask) -> Result, InvalidNumMaskColumnsError> { + Ok(match self { + Self::LogUp => { + if mask.columns().len() != 2 { + return Err(InvalidNumMaskColumnsError); + } + + let [numerator_a, numerator_b] = mask.columns()[0]; + let [denominator_a, denominator_b] = mask.columns()[1]; + + let a = Fraction::new(numerator_a, denominator_a); + let b = Fraction::new(numerator_b, denominator_b); + let res = a + b; + + vec![res.numerator, res.denominator] + } + Self::GrandProduct => { + if mask.columns().len() != 1 { + return Err(InvalidNumMaskColumnsError); + } + + let [a, b] = mask.columns()[0]; + vec![a * b] + } + }) + } +} + +/// Mask has an invalid number of columns +#[derive(Debug)] +struct InvalidNumMaskColumnsError; + +/// Stores two evaluations of each column in a GKR layer. +#[derive(Debug, Clone)] +pub struct GkrMask { + columns: Vec<[SecureField; 2]>, +} + +impl GkrMask { + pub fn new(columns: Vec<[SecureField; 2]>) -> Self { + Self { columns } + } + + pub fn to_rows(&self) -> [Vec; 2] { + self.columns.iter().map(|[a, b]| (a, b)).unzip().into() + } + + pub fn columns(&self) -> &[[SecureField; 2]] { + &self.columns + } + + /// Returns all `p_i(x)` where `p_i` interpolates column `i` of the mask on `{0, 1}`. + pub fn reduce_at_point(&self, x: SecureField) -> Vec { + self.columns + .iter() + .map(|&[v0, v1]| fold_mle_evals(x, v0, v1)) + .collect() + } +} + +/// Error encountered during GKR protocol verification. +#[derive(Error, Debug)] +pub enum GkrError { + /// The proof is malformed. 
+ #[error("proof data is invalid")] + MalformedProof, + /// Mask has an invalid number of columns. + #[error("mask in layer {instance_layer} of instance {instance} is invalid")] + InvalidMask { + instance: usize, + /// Layer of the instance (but not necessarily the batch). + instance_layer: LayerIndex, + }, + /// There is a mismatch between the number of instances in the proof and the number of + /// instances passed for verification. + #[error("provided an invalid number of instances (given {given}, proof expects {proof})")] + NumInstancesMismatch { given: usize, proof: usize }, + /// There was an error with one of the sumcheck proofs. + #[error("sum-check invalid in layer {layer}: {source}")] + InvalidSumcheck { + layer: LayerIndex, + source: SumcheckError, + }, + /// The circuit polynomial the verifier evaluated doesn't match claim from sumcheck. + #[error("circuit check failed in layer {layer} (calculated {output}, claim {claim})")] + CircuitCheckFailure { + claim: SecureField, + output: SecureField, + layer: LayerIndex, + }, +} + +/// GKR layer index where 0 corresponds to the output layer. 
pub type LayerIndex = usize;

#[cfg(test)]
mod tests {
    use super::{partially_verify_batch, Gate, GkrArtifact, GkrError};
    use crate::core::backend::CpuBackend;
    use crate::core::channel::Channel;
    use crate::core::fields::qm31::SecureField;
    use crate::core::lookups::gkr_prover::{prove_batch, Layer};
    use crate::core::lookups::mle::Mle;
    use crate::core::test_utils::test_channel;

    // NOTE(review): turbofish type arguments below were lost in transcription and restored.
    #[test]
    fn prove_batch_works() -> Result<(), GkrError> {
        const LOG_N: usize = 5;
        let mut channel = test_channel();
        let col0 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N));
        let col1 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N));
        let product0 = col0.iter().product::<SecureField>();
        let product1 = col1.iter().product::<SecureField>();
        let input_layers = vec![
            Layer::GrandProduct(col0.clone()),
            Layer::GrandProduct(col1.clone()),
        ];
        let (proof, _) = prove_batch(&mut test_channel(), input_layers);

        let GkrArtifact {
            ood_point,
            claims_to_verify_by_instance,
            n_variables_by_instance,
        } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?;

        assert_eq!(n_variables_by_instance, [LOG_N, LOG_N]);
        assert_eq!(proof.output_claims_by_instance.len(), 2);
        assert_eq!(claims_to_verify_by_instance.len(), 2);
        assert_eq!(proof.output_claims_by_instance[0], &[product0]);
        assert_eq!(proof.output_claims_by_instance[1], &[product1]);
        let claim0 = &claims_to_verify_by_instance[0];
        let claim1 = &claims_to_verify_by_instance[1];
        assert_eq!(claim0, &[col0.eval_at_point(&ood_point)]);
        assert_eq!(claim1, &[col1.eval_at_point(&ood_point)]);
        Ok(())
    }

    #[test]
    fn prove_batch_with_different_sizes_works() -> Result<(), GkrError> {
        const LOG_N0: usize = 5;
        const LOG_N1: usize = 7;
        let mut channel = test_channel();
        let col0 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N0));
        let col1 = Mle::<CpuBackend, SecureField>::new(channel.draw_felts(1 << LOG_N1));
        let product0 = col0.iter().product::<SecureField>();
        let product1 = col1.iter().product::<SecureField>();
        let input_layers = vec![
            Layer::GrandProduct(col0.clone()),
            Layer::GrandProduct(col1.clone()),
        ];
        let (proof, _) = prove_batch(&mut test_channel(), input_layers);

        let GkrArtifact {
            ood_point,
            claims_to_verify_by_instance,
            n_variables_by_instance,
        } = partially_verify_batch(vec![Gate::GrandProduct; 2], &proof, &mut test_channel())?;

        assert_eq!(n_variables_by_instance, [LOG_N0, LOG_N1]);
        assert_eq!(proof.output_claims_by_instance.len(), 2);
        assert_eq!(claims_to_verify_by_instance.len(), 2);
        assert_eq!(proof.output_claims_by_instance[0], &[product0]);
        assert_eq!(proof.output_claims_by_instance[1], &[product1]);
        let claim0 = &claims_to_verify_by_instance[0];
        let claim1 = &claims_to_verify_by_instance[1];
        let n_vars = ood_point.len();
        // Smaller instances are evaluated on the suffix of the OOD point matching their size.
        assert_eq!(claim0, &[col0.eval_at_point(&ood_point[n_vars - LOG_N0..])]);
        assert_eq!(claim1, &[col1.eval_at_point(&ood_point[n_vars - LOG_N1..])]);
        Ok(())
    }
}
diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/mle.rs b/Stwo_wrapper/crates/prover/src/core/lookups/mle.rs
new file mode 100644
index 0000000..7449f40
--- /dev/null
+++ b/Stwo_wrapper/crates/prover/src/core/lookups/mle.rs
@@ -0,0 +1,106 @@
use std::ops::{Deref, DerefMut};

use educe::Educe;

use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::Field;

pub trait MleOps<F: Field>: ColumnOps<F> + Sized {
    /// Returns a transformed [`Mle`] where the first variable is fixed to `assignment`.
    fn fix_first_variable(mle: Mle<Self, F>, assignment: SecureField) -> Mle<Self, SecureField>
    where
        Self: MleOps<SecureField>;
}

/// Multilinear Extension stored as evaluations of a multilinear polynomial over the boolean
/// hypercube in bit-reversed order.
#[derive(Educe)]
#[educe(Debug, Clone)]
pub struct Mle<B: ColumnOps<F>, F: Field> {
    evals: Col<B, F>,
}

impl<B: ColumnOps<F>, F: Field> Mle<B, F> {
    /// Creates a [`Mle`] from evaluations of a multilinear polynomial on the boolean hypercube.
+ /// + /// # Panics + /// + /// Panics if the number of evaluations is not a power of two. + pub fn new(evals: Col) -> Self { + assert!(evals.len().is_power_of_two()); + Self { evals } + } + + pub fn into_evals(self) -> Col { + self.evals + } + + /// Returns a transformed polynomial where the first variable is fixed to `assignment`. + pub fn fix_first_variable(self, assignment: SecureField) -> Mle + where + B: MleOps, + { + B::fix_first_variable(self, assignment) + } + + /// Returns the number of variables in the polynomial. + pub fn n_variables(&self) -> usize { + self.evals.len().ilog2() as usize + } +} + +impl, F: Field> Deref for Mle { + type Target = Col; + + fn deref(&self) -> &Col { + &self.evals + } +} + +impl, F: Field> DerefMut for Mle { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.evals + } +} + +#[cfg(test)] +mod test { + use super::{Mle, MleOps}; + use crate::core::backend::Column; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::{ExtensionOf, Field}; + + impl Mle + where + F: Field, + SecureField: ExtensionOf, + B: MleOps, + { + /// Evaluates the multilinear polynomial at `point`. + pub(crate) fn eval_at_point(&self, point: &[SecureField]) -> SecureField { + pub fn eval(mle_evals: &[SecureField], p: &[SecureField]) -> SecureField { + match p { + [] => mle_evals[0], + &[p_i, ref p @ ..] => { + let (lhs, rhs) = mle_evals.split_at(mle_evals.len() / 2); + let lhs_eval = eval(lhs, p); + let rhs_eval = eval(rhs, p); + // Equivalent to `eq(0, p_i) * lhs_eval + eq(1, p_i) * rhs_eval`. 
+ p_i * (rhs_eval - lhs_eval) + lhs_eval + } + } + } + + let mle_evals = self + .clone() + .into_evals() + .to_cpu() + .into_iter() + .map(|v| v.into()) + .collect::>(); + + eval(&mle_evals, point) + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/mod.rs b/Stwo_wrapper/crates/prover/src/core/lookups/mod.rs new file mode 100644 index 0000000..8f7351a --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/lookups/mod.rs @@ -0,0 +1,5 @@ +pub mod gkr_prover; +pub mod gkr_verifier; +pub mod mle; +pub mod sumcheck; +pub mod utils; diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/sumcheck.rs b/Stwo_wrapper/crates/prover/src/core/lookups/sumcheck.rs new file mode 100644 index 0000000..1df2451 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/lookups/sumcheck.rs @@ -0,0 +1,292 @@ +//! Sum-check protocol that proves and verifies claims about `sum_x g(x)` for all x in `{0, 1}^n`. +//! +//! [`MultivariatePolyOracle`] provides methods for evaluating sums and making transformations on +//! `g` in the context of the protocol. It is intended to be used in conjunction with +//! [`prove_batch()`] to generate proofs. + +use std::iter::zip; + +use itertools::Itertools; +use num_traits::{One, Zero}; +use thiserror::Error; + +use super::utils::UnivariatePoly; +use crate::core::channel::Channel; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; + +/// Something that can be seen as a multivariate polynomial `g(x_0, ..., x_{n-1})`. +pub trait MultivariatePolyOracle: Sized { + /// Returns the number of variables in `g`. + fn n_variables(&self) -> usize; + + /// Computes the sum of `g(x_0, x_1, ..., x_{n-1})` over all `(x_1, ..., x_{n-1})` in + /// `{0, 1}^(n-1)`, effectively reducing the sum over `g` to a univariate polynomial in `x_0`. + /// + /// `claim` equals the claimed sum of `g(x_0, x_2, ..., x_{n-1})` over all `(x_0, ..., x_{n-1})` + /// in `{0, 1}^n`. 
Knowing the claim can help optimize the implementation: Let `f` denote the + /// univariate polynomial we want to return. Note that `claim = f(0) + f(1)` so knowing `claim` + /// and either `f(0)` or `f(1)` allows determining the other. + fn sum_as_poly_in_first_variable(&self, claim: SecureField) -> UnivariatePoly; + + /// Returns a transformed oracle where the first variable of `g` is fixed to `challenge`. + /// + /// The returned oracle represents the multivariate polynomial `g'`, defined as + /// `g'(x_1, ..., x_{n-1}) = g(challenge, x_1, ..., x_{n-1})`. + fn fix_first_variable(self, challenge: SecureField) -> Self; +} + +/// Performs sum-check on a random linear combinations of multiple multivariate polynomials. +/// +/// Let the multivariate polynomials be `g_0, ..., g_{n-1}`. A single sum-check is performed on +/// multivariate polynomial `h = g_0 + lambda * g_1 + ... + lambda^(n-1) * g_{n-1}`. The `g_i`s do +/// not need to have the same number of variables. `g_i`s with less variables are folded in the +/// latest possible round of the protocol. For instance with `g_0(x, y, z)` and `g_1(x, y)` +/// sum-check is performed on `h(x, y, z) = g_0(x, y, z) + lambda * g_1(y, z)`. Claim `c_i` should +/// equal the claimed sum of `g_i(x_0, ..., x_{j-1})` over all `(x_0, ..., x_{j-1})` in `{0, 1}^j`. +/// +/// The degree of each `g_i` should not exceed [`MAX_DEGREE`] in any variable. The sum-check proof +/// of `h`, list of challenges (variable assignment) and the constant oracles (i.e. the `g_i` with +/// all variables fixed to the their corresponding challenges) are returned. +/// +/// Output is of the form: `(proof, variable_assignment, constant_poly_oracles, claimed_evals)` +/// +/// # Panics +/// +/// Panics if: +/// - No multivariate polynomials are provided. +/// - There aren't the same number of multivariate polynomials and claims. +/// - The degree of any multivariate polynomial exceeds [`MAX_DEGREE`] in any variable. 
+/// - The round polynomials are inconsistent with their corresponding claimed sum on `0` and `1`. +// TODO: Consider returning constant oracles as separate type. +pub fn prove_batch( + mut claims: Vec, + mut multivariate_polys: Vec, + lambda: SecureField, + channel: &mut impl Channel, +) -> (SumcheckProof, Vec, Vec, Vec) { + let n_variables = multivariate_polys.iter().map(O::n_variables).max().unwrap(); + assert_eq!(claims.len(), multivariate_polys.len()); + + let mut round_polys = Vec::new(); + let mut assignment = Vec::new(); + + // Update the claims for the sum over `h`'s hypercube. + for (claim, multivariate_poly) in zip(&mut claims, &multivariate_polys) { + let n_unused_variables = n_variables - multivariate_poly.n_variables(); + *claim *= BaseField::from(1 << n_unused_variables); + } + + // Prove sum-check rounds + for round in 0..n_variables { + let n_remaining_rounds = n_variables - round; + + let this_round_polys = zip(&multivariate_polys, &claims) + .enumerate() + .map(|(i, (multivariate_poly, &claim))| { + let round_poly = if n_remaining_rounds == multivariate_poly.n_variables() { + multivariate_poly.sum_as_poly_in_first_variable(claim) + } else { + (claim / BaseField::from(2)).into() + }; + + let eval_at_0 = round_poly.eval_at_point(SecureField::zero()); + let eval_at_1 = round_poly.eval_at_point(SecureField::one()); + assert_eq!(eval_at_0 + eval_at_1, claim, "i={i}, round={round}"); + assert!(round_poly.degree() <= MAX_DEGREE, "i={i}, round={round}"); + + round_poly + }) + .collect_vec(); + + let round_poly = random_linear_combination(&this_round_polys, lambda); + + channel.mix_felts(&round_poly); + + let challenge = channel.draw_felt(); + + claims = this_round_polys + .iter() + .map(|round_poly| round_poly.eval_at_point(challenge)) + .collect(); + + multivariate_polys = multivariate_polys + .into_iter() + .map(|multivariate_poly| { + if n_remaining_rounds != multivariate_poly.n_variables() { + return multivariate_poly; + } + + 
multivariate_poly.fix_first_variable(challenge) + }) + .collect(); + + round_polys.push(round_poly); + assignment.push(challenge); + } + + let proof = SumcheckProof { round_polys }; + + (proof, assignment, multivariate_polys, claims) +} + +/// Returns `p_0 + alpha * p_1 + ... + alpha^(n-1) * p_{n-1}`. +fn random_linear_combination( + polys: &[UnivariatePoly], + alpha: SecureField, +) -> UnivariatePoly { + polys + .iter() + .rfold(Zero::zero(), |acc, poly| acc * alpha + poly.clone()) +} + +/// Partially verifies a sum-check proof. +/// +/// Only "partial" since it does not fully verify the prover's claimed evaluation on the variable +/// assignment but checks if the sum of the round polynomials evaluated on `0` and `1` matches the +/// claim for each round. If the proof passes these checks, the variable assignment and the prover's +/// claimed evaluation are returned for the caller to validate otherwise an [`Err`] is returned. +/// +/// Output is of the form `(variable_assignment, claimed_eval)`. +pub fn partially_verify( + mut claim: SecureField, + proof: &SumcheckProof, + channel: &mut impl Channel, +) -> Result<(Vec, SecureField), SumcheckError> { + let mut assignment = Vec::new(); + + for (round, round_poly) in proof.round_polys.iter().enumerate() { + if round_poly.degree() > MAX_DEGREE { + return Err(SumcheckError::DegreeInvalid { round }); + } + + // TODO: optimize this by sending one less coefficient, and computing it from the + // claim, instead of checking the claim. (Can also be done by quotienting). 
+ let sum = round_poly.eval_at_point(Zero::zero()) + round_poly.eval_at_point(One::one()); + + if claim != sum { + return Err(SumcheckError::SumInvalid { claim, sum, round }); + } + + channel.mix_felts(round_poly); + let challenge = channel.draw_felt(); + claim = round_poly.eval_at_point(challenge); + assignment.push(challenge); + } + + Ok((assignment, claim)) +} + +#[derive(Debug, Clone)] +pub struct SumcheckProof { + pub round_polys: Vec>, +} + +/// Max degree of polynomials the verifier accepts in each round of the protocol. +pub const MAX_DEGREE: usize = 3; + +/// Sum-check protocol verification error. +#[derive(Error, Debug)] +pub enum SumcheckError { + #[error("degree of the polynomial in round {round} is too high")] + DegreeInvalid { round: RoundIndex }, + #[error("sum does not match the claim in round {round} (sum {sum}, claim {claim})")] + SumInvalid { + claim: SecureField, + sum: SecureField, + round: RoundIndex, + }, +} + +/// Sum-check round index where 0 corresponds to the first round. +pub type RoundIndex = usize; + +#[cfg(test)] +mod tests { + + use num_traits::One; + + use crate::core::backend::CpuBackend; + use crate::core::channel::{Blake2sChannel, Channel}; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::Field; + use crate::core::lookups::mle::Mle; + use crate::core::lookups::sumcheck::{partially_verify, prove_batch}; + + #[test] + fn sumcheck_works() { + let values = test_channel().draw_felts(32); + let claim = values.iter().sum(); + let mle = Mle::::new(values); + let lambda = SecureField::one(); + let (proof, ..) 
= prove_batch(vec![claim], vec![mle.clone()], lambda, &mut test_channel()); + + let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap(); + + assert_eq!(eval, mle.eval_at_point(&assignment)); + } + + #[test] + fn batch_sumcheck_works() { + let mut channel = test_channel(); + let values0 = channel.draw_felts(32); + let values1 = channel.draw_felts(32); + let claim0 = values0.iter().sum(); + let claim1 = values1.iter().sum(); + let mle0 = Mle::::new(values0.clone()); + let mle1 = Mle::::new(values1.clone()); + let lambda = channel.draw_felt(); + let claims = vec![claim0, claim1]; + let mles = vec![mle0.clone(), mle1.clone()]; + let (proof, ..) = prove_batch(claims, mles, lambda, &mut test_channel()); + + let claim = claim0 + lambda * claim1; + let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap(); + + let eval0 = mle0.eval_at_point(&assignment); + let eval1 = mle1.eval_at_point(&assignment); + assert_eq!(eval, eval0 + lambda * eval1); + } + + #[test] + fn batch_sumcheck_with_different_n_variables() { + let mut channel = test_channel(); + let values0 = channel.draw_felts(64); + let values1 = channel.draw_felts(32); + let claim0 = values0.iter().sum(); + let claim1 = values1.iter().sum(); + let mle0 = Mle::::new(values0.clone()); + let mle1 = Mle::::new(values1.clone()); + let lambda = channel.draw_felt(); + let claims = vec![claim0, claim1]; + let mles = vec![mle0.clone(), mle1.clone()]; + let (proof, ..) 
= prove_batch(claims, mles, lambda, &mut test_channel()); + + let claim = claim0 + lambda * claim1.double(); + let (assignment, eval) = partially_verify(claim, &proof, &mut test_channel()).unwrap(); + + let eval0 = mle0.eval_at_point(&assignment); + let eval1 = mle1.eval_at_point(&assignment[1..]); + assert_eq!(eval, eval0 + lambda * eval1); + } + + #[test] + fn invalid_sumcheck_proof_fails() { + let values = test_channel().draw_felts(8); + let claim = values.iter().sum::(); + let lambda = SecureField::one(); + // Compromise the first value. + let mut invalid_values = values; + invalid_values[0] += SecureField::one(); + let invalid_claim = vec![invalid_values.iter().sum::()]; + let invalid_mle = vec![Mle::::new(invalid_values.clone())]; + let (invalid_proof, ..) = + prove_batch(invalid_claim, invalid_mle, lambda, &mut test_channel()); + + assert!(partially_verify(claim, &invalid_proof, &mut test_channel()).is_err()); + } + + fn test_channel() -> Blake2sChannel { + Blake2sChannel::default() + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/lookups/utils.rs b/Stwo_wrapper/crates/prover/src/core/lookups/utils.rs new file mode 100644 index 0000000..85ea4c3 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/lookups/utils.rs @@ -0,0 +1,356 @@ +use std::iter::{zip, Sum}; +use std::ops::{Add, Deref, Mul, Neg, Sub}; + +use num_traits::{One, Zero}; + +use crate::core::fields::qm31::SecureField; +use crate::core::fields::{ExtensionOf, Field}; + +/// Univariate polynomial stored as coefficients in the monomial basis. 
+#[derive(Debug, Clone)] +pub struct UnivariatePoly(Vec); + +impl UnivariatePoly { + pub fn new(coeffs: Vec) -> Self { + let mut polynomial = Self(coeffs); + polynomial.truncate_leading_zeros(); + polynomial + } + + pub fn eval_at_point(&self, x: F) -> F { + horner_eval(&self.0, x) + } + + // + pub fn interpolate_lagrange(xs: &[F], ys: &[F]) -> Self { + assert_eq!(xs.len(), ys.len()); + + let mut coeffs = Self::zero(); + + for (i, (&xi, &yi)) in zip(xs, ys).enumerate() { + let mut prod = yi; + + for (j, &xj) in xs.iter().enumerate() { + if i != j { + prod /= xi - xj; + } + } + + let mut term = Self::new(vec![prod]); + + for (j, &xj) in xs.iter().enumerate() { + if i != j { + term = term * (Self::x() - Self::new(vec![xj])); + } + } + + coeffs = coeffs + term; + } + + coeffs.truncate_leading_zeros(); + + coeffs + } + + pub fn degree(&self) -> usize { + let mut coeffs = self.0.iter().rev(); + let _ = (&mut coeffs).take_while(|v| v.is_zero()); + coeffs.len().saturating_sub(1) + } + + fn x() -> Self { + Self(vec![F::zero(), F::one()]) + } + + fn truncate_leading_zeros(&mut self) { + while self.0.last() == Some(&F::zero()) { + self.0.pop(); + } + } +} + +impl From for UnivariatePoly { + fn from(value: F) -> Self { + Self::new(vec![value]) + } +} + +impl Mul for UnivariatePoly { + type Output = Self; + + fn mul(mut self, rhs: F) -> Self { + self.0.iter_mut().for_each(|coeff| *coeff *= rhs); + self + } +} + +impl Mul for UnivariatePoly { + type Output = Self; + + fn mul(mut self, mut rhs: Self) -> Self { + if self.is_zero() || rhs.is_zero() { + return Self::zero(); + } + + self.truncate_leading_zeros(); + rhs.truncate_leading_zeros(); + + let mut res = vec![F::zero(); self.0.len() + rhs.0.len() - 1]; + + for (i, coeff_a) in self.0.into_iter().enumerate() { + for (j, &coeff_b) in rhs.0.iter().enumerate() { + res[i + j] += coeff_a * coeff_b; + } + } + + Self::new(res) + } +} + +impl Add for UnivariatePoly { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + let n 
= self.0.len().max(rhs.0.len()); + let mut res = Vec::new(); + + for i in 0..n { + res.push(match (self.0.get(i), rhs.0.get(i)) { + (Some(&a), Some(&b)) => a + b, + (Some(&a), None) | (None, Some(&a)) => a, + _ => unreachable!(), + }) + } + + Self(res) + } +} + +impl Sub for UnivariatePoly { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + self + (-rhs) + } +} + +impl Neg for UnivariatePoly { + type Output = Self; + + fn neg(self) -> Self { + Self(self.0.into_iter().map(|v| -v).collect()) + } +} + +impl Zero for UnivariatePoly { + fn zero() -> Self { + Self(vec![]) + } + + fn is_zero(&self) -> bool { + self.0.iter().all(F::is_zero) + } +} + +impl Deref for UnivariatePoly { + type Target = [F]; + + fn deref(&self) -> &[F] { + &self.0 + } +} + +/// Evaluates univariate polynomial using [Horner's method]. +/// +/// [Horner's method]: https://en.wikipedia.org/wiki/Horner%27s_method +pub fn horner_eval(coeffs: &[F], x: F) -> F { + coeffs + .iter() + .rfold(F::zero(), |acc, &coeff| acc * x + coeff) +} + +/// Returns `v_0 + alpha * v_1 + ... + alpha^(n-1) * v_{n-1}`. +pub fn random_linear_combination(v: &[SecureField], alpha: SecureField) -> SecureField { + horner_eval(v, alpha) +} + +/// Evaluates the lagrange kernel of the boolean hypercube. +/// +/// The lagrange kernel of the boolean hypercube is a multilinear extension of the function that +/// when given `x, y` in `{0, 1}^n` evaluates to 1 if `x = y`, and evaluates to 0 otherwise. +pub fn eq(x: &[F], y: &[F]) -> F { + assert_eq!(x.len(), y.len()); + zip(x, y) + .map(|(&xi, &yi)| xi * yi + (F::one() - xi) * (F::one() - yi)) + .product() +} + +/// Computes `eq(0, assignment) * eval0 + eq(1, assignment) * eval1`. +pub fn fold_mle_evals(assignment: SecureField, eval0: F, eval1: F) -> SecureField +where + F: Field, + SecureField: ExtensionOf, +{ + assignment * (eval1 - eval0) + eval0 +} + +/// Projective fraction. 
+#[derive(Debug, Clone, Copy)] +pub struct Fraction { + pub numerator: N, + pub denominator: D, +} + +impl Fraction { + pub fn new(numerator: N, denominator: D) -> Self { + Self { + numerator, + denominator, + } + } +} + +impl + Add + Mul + Mul + Copy> Add + for Fraction +{ + type Output = Fraction; + + fn add(self, rhs: Self) -> Fraction { + Fraction { + numerator: rhs.denominator * self.numerator + self.denominator * rhs.numerator, + denominator: self.denominator * rhs.denominator, + } + } +} + +impl Zero for Fraction +where + Self: Add, +{ + fn zero() -> Self { + Self { + numerator: N::zero(), + denominator: D::one(), + } + } + + fn is_zero(&self) -> bool { + self.numerator.is_zero() && !self.denominator.is_zero() + } +} + +impl Sum for Fraction +where + Self: Zero, +{ + fn sum>(mut iter: I) -> Self { + let first = iter.next().unwrap_or_else(Self::zero); + iter.fold(first, |a, b| a + b) + } +} + +/// Represents the fraction `1 / x` +pub struct Reciprocal { + x: T, +} + +impl Reciprocal { + pub fn new(x: T) -> Self { + Self { x } + } +} + +impl + Mul + Copy> Add for Reciprocal { + type Output = Fraction; + + fn add(self, rhs: Self) -> Fraction { + // `1/a + 1/b = (a + b)/(a * b)` + Fraction { + numerator: self.x + rhs.x, + denominator: self.x * rhs.x, + } + } +} + +#[cfg(test)] +mod tests { + use std::iter::zip; + + use num_traits::{One, Zero}; + + use super::{horner_eval, UnivariatePoly}; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::FieldExpOps; + use crate::core::lookups::utils::{eq, Fraction}; + + #[test] + fn lagrange_interpolation_works() { + let xs = [5, 1, 3, 9].map(BaseField::from); + let ys = [1, 2, 3, 4].map(BaseField::from); + + let poly = UnivariatePoly::interpolate_lagrange(&xs, &ys); + + for (x, y) in zip(xs, ys) { + assert_eq!(poly.eval_at_point(x), y, "mismatch for x={x}"); + } + } + + #[test] + fn horner_eval_works() { + let coeffs = [BaseField::from(9), BaseField::from(2), 
BaseField::from(3)]; + let x = BaseField::from(7); + + let eval = horner_eval(&coeffs, x); + + assert_eq!(eval, coeffs[0] + coeffs[1] * x + coeffs[2] * x.square()); + } + + #[test] + fn eq_identical_hypercube_points_returns_one() { + let zero = SecureField::zero(); + let one = SecureField::one(); + let a = &[one, zero, one]; + + let eq_eval = eq(a, a); + + assert_eq!(eq_eval, one); + } + + #[test] + fn eq_different_hypercube_points_returns_zero() { + let zero = SecureField::zero(); + let one = SecureField::one(); + let a = &[one, zero, one]; + let b = &[one, zero, zero]; + + let eq_eval = eq(a, b); + + assert_eq!(eq_eval, zero); + } + + #[test] + #[should_panic] + fn eq_different_size_points() { + let zero = SecureField::zero(); + let one = SecureField::one(); + + eq(&[zero, one], &[zero]); + } + + #[test] + fn fraction_addition_works() { + let a = Fraction::new(BaseField::from(1), BaseField::from(3)); + let b = Fraction::new(BaseField::from(2), BaseField::from(6)); + + let Fraction { + numerator, + denominator, + } = a + b; + + assert_eq!( + numerator / denominator, + BaseField::from(2) / BaseField::from(3) + ); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/mod.rs b/Stwo_wrapper/crates/prover/src/core/mod.rs new file mode 100644 index 0000000..a00aad6 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/mod.rs @@ -0,0 +1,59 @@ +use std::ops::{Deref, DerefMut}; + +pub mod air; +pub mod backend; +pub mod channel; +pub mod circle; +pub mod constraints; +pub mod fft; +pub mod fields; +pub mod fri; +pub mod lookups; +pub mod pcs; +pub mod poly; +pub mod proof_of_work; +pub mod prover; +pub mod queries; +#[cfg(test)] +pub mod test_utils; +pub mod utils; +pub mod vcs; + +/// A vector in which each element relates (by index) to a column in the trace. +pub type ColumnVec = Vec; + +/// A vector of [ColumnVec]s. Each [ColumnVec] relates (by index) to a component in the air. 
+#[derive(Debug, Clone)] +pub struct ComponentVec(pub Vec>); + +impl ComponentVec { + pub fn flatten(self) -> ColumnVec { + self.0.into_iter().flatten().collect() + } +} + +impl ComponentVec> { + pub fn flatten_cols(self) -> Vec { + self.0.into_iter().flatten().flatten().collect() + } +} + +impl Default for ComponentVec { + fn default() -> Self { + Self(Vec::new()) + } +} + +impl Deref for ComponentVec { + type Target = Vec>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ComponentVec { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/mod.rs b/Stwo_wrapper/crates/prover/src/core/pcs/mod.rs new file mode 100644 index 0000000..d9acf52 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/pcs/mod.rs @@ -0,0 +1,40 @@ +//! Implements a FRI polynomial commitment scheme. +//! This is a protocol where the prover can commit on a set of polynomials and then prove their +//! opening on a set of points. +//! Note: This implementation is not really a polynomial commitment scheme, because we are not in +//! the unique decoding regime. This is enough for a STARK proof though, where we only want to imply +//! the existence of such polynomials, and are ok with having a small decoding list. +//! Note: Opened points cannot come from the commitment domain. 
+ +mod prover; +pub mod quotients; +mod utils; +mod verifier; + +pub use self::prover::{ + CommitmentSchemeProof, CommitmentSchemeProver, CommitmentTreeProver, TreeBuilder, +}; +pub use self::utils::TreeVec; +pub use self::verifier::CommitmentSchemeVerifier; +use super::fri::FriConfig; + +#[derive(Copy, Debug, Clone, PartialEq, Eq)] +pub struct TreeSubspan { + pub tree_index: usize, + pub col_start: usize, + pub col_end: usize, +} + +#[derive(Debug, Clone, Copy)] +pub struct PcsConfig { + pub pow_bits: u32, + pub fri_config: FriConfig, +} +impl Default for PcsConfig { + fn default() -> Self { + Self { + pow_bits: 5, + fri_config: FriConfig::new(0, 1, 3), + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/prover.rs b/Stwo_wrapper/crates/prover/src/core/pcs/prover.rs new file mode 100644 index 0000000..ed45ffc --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/pcs/prover.rs @@ -0,0 +1,256 @@ +use std::collections::BTreeMap; + +use itertools::Itertools; +use tracing::{span, Level}; + +use super::super::circle::CirclePoint; +use super::super::fields::m31::BaseField; +use super::super::fields::qm31::SecureField; +use super::super::fri::{FriProof, FriProver}; +use super::super::poly::circle::CanonicCoset; +use super::super::poly::BitReversedOrder; +use super::super::ColumnVec; +use super::quotients::{compute_fri_quotients, PointSample}; +use super::utils::TreeVec; +use super::{PcsConfig, TreeSubspan}; +use crate::core::air::Trace; +use crate::core::backend::BackendForChannel; +use crate::core::channel::{Channel, MerkleChannel}; +use crate::core::poly::circle::{CircleEvaluation, CirclePoly}; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::vcs::ops::MerkleHasher; +use crate::core::vcs::prover::{MerkleDecommitment, MerkleProver}; + +/// The prover side of a FRI polynomial commitment scheme. See [super]. 
+pub struct CommitmentSchemeProver<'a, B: BackendForChannel, MC: MerkleChannel> { + pub trees: TreeVec>, + pub config: PcsConfig, + twiddles: &'a TwiddleTree, +} + +impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, B, MC> { + pub fn new(config: PcsConfig, twiddles: &'a TwiddleTree) -> Self { + CommitmentSchemeProver { + trees: TreeVec::default(), + config, + twiddles, + } + } + + fn commit(&mut self, polynomials: ColumnVec>, channel: &mut MC::C) { + let _span = span!(Level::INFO, "Commitment").entered(); + let tree = CommitmentTreeProver::new( + polynomials, + self.config.fri_config.log_blowup_factor, + channel, + self.twiddles, + ); + self.trees.push(tree); + } + + pub fn tree_builder(&mut self) -> TreeBuilder<'_, 'a, B, MC> { + TreeBuilder { + tree_index: self.trees.len(), + commitment_scheme: self, + polys: Vec::default(), + } + } + + pub fn roots(&self) -> TreeVec<::Hash> { + self.trees.as_ref().map(|tree| tree.commitment.root()) + } + + pub fn polynomials(&self) -> TreeVec>> { + self.trees + .as_ref() + .map(|tree| tree.polynomials.iter().collect()) + } + + pub fn evaluations( + &self, + ) -> TreeVec>> { + self.trees + .as_ref() + .map(|tree| tree.evaluations.iter().collect()) + } + + pub fn trace(&self) -> Trace<'_, B> { + let polys = self.polynomials(); + let evals = self.evaluations(); + Trace { polys, evals } + } + + pub fn prove_values( + &self, + sampled_points: TreeVec>>>, + channel: &mut MC::C, + ) -> CommitmentSchemeProof { + // Evaluate polynomials on open points. 
+ let span = span!(Level::INFO, "Evaluate columns out of domain").entered(); + let samples = self + .polynomials() + .zip_cols(&sampled_points) + .map_cols(|(poly, points)| { + points + .iter() + .map(|&point| PointSample { + point, + value: poly.eval_at_point(point), + }) + .collect_vec() + }); + span.exit(); + let sampled_values = samples + .as_cols_ref() + .map_cols(|x| x.iter().map(|o| o.value).collect()); + channel.mix_felts(&sampled_values.clone().flatten_cols()); + + // Compute oods quotients for boundary constraints on the sampled points. + let columns = self.evaluations().flatten(); + let quotients = compute_fri_quotients( + &columns, + &samples.flatten(), + channel.draw_felt(), + self.config.fri_config.log_blowup_factor, + ); + + // Run FRI commitment phase on the oods quotients. + let fri_prover = + FriProver::::commit(channel, self.config.fri_config, "ients, self.twiddles); + + // Proof of work. + let span1 = span!(Level::INFO, "Grind").entered(); + let proof_of_work = B::grind(channel, self.config.pow_bits); + span1.exit(); + channel.mix_nonce(proof_of_work); + + // FRI decommitment phase. + let (fri_proof, fri_query_domains) = fri_prover.decommit(channel); + + // Decommit the FRI queries on the merkle trees. 
+ let decommitment_results = self.trees.as_ref().map(|tree| { + let queries = fri_query_domains + .iter() + .map(|(&log_size, domain)| (log_size, domain.flatten())) + .collect(); + tree.decommit(queries) + }); + + let queried_values = decommitment_results.as_ref().map(|(v, _)| v.clone()); + let decommitments = decommitment_results.map(|(_, d)| d); + + CommitmentSchemeProof { + sampled_values, + decommitments, + queried_values, + proof_of_work, + fri_proof, + } + } +} + +#[derive(Debug)] +pub struct CommitmentSchemeProof { + pub sampled_values: TreeVec>>, + pub decommitments: TreeVec>, + pub queried_values: TreeVec>>, + pub proof_of_work: u64, + pub fri_proof: FriProof, +} + +pub struct TreeBuilder<'a, 'b, B: BackendForChannel, MC: MerkleChannel> { + tree_index: usize, + commitment_scheme: &'a mut CommitmentSchemeProver<'b, B, MC>, + polys: ColumnVec>, +} +impl<'a, 'b, B: BackendForChannel, MC: MerkleChannel> TreeBuilder<'a, 'b, B, MC> { + pub fn extend_evals( + &mut self, + columns: ColumnVec>, + ) -> TreeSubspan { + let span = span!(Level::INFO, "Interpolation for commitment").entered(); + let col_start = self.polys.len(); + let polys = columns + .into_iter() + .map(|eval| eval.interpolate_with_twiddles(self.commitment_scheme.twiddles)) + .collect_vec(); + span.exit(); + self.polys.extend(polys); + TreeSubspan { + tree_index: self.tree_index, + col_start, + col_end: self.polys.len(), + } + } + + pub fn extend_polys(&mut self, polys: ColumnVec>) -> TreeSubspan { + let col_start = self.polys.len(); + self.polys.extend(polys); + TreeSubspan { + tree_index: self.tree_index, + col_start, + col_end: self.polys.len(), + } + } + + pub fn commit(self, channel: &mut MC::C) { + let _span = span!(Level::INFO, "Commitment").entered(); + self.commitment_scheme.commit(self.polys, channel); + } +} + +/// Prover data for a single commitment tree in a commitment scheme. The commitment scheme allows to +/// commit on a set of polynomials at a time. This corresponds to such a set. 
+pub struct CommitmentTreeProver, MC: MerkleChannel> { + pub polynomials: ColumnVec>, + pub evaluations: ColumnVec>, + pub commitment: MerkleProver, +} + +impl, MC: MerkleChannel> CommitmentTreeProver { + pub fn new( + polynomials: ColumnVec>, + log_blowup_factor: u32, + channel: &mut MC::C, + twiddles: &TwiddleTree, + ) -> Self { + let span = span!(Level::INFO, "Extension").entered(); + let evaluations = polynomials + .iter() + .map(|poly| { + poly.evaluate_with_twiddles( + CanonicCoset::new(poly.log_size() + log_blowup_factor).circle_domain(), + twiddles, + ) + }) + .collect_vec(); + + span.exit(); + + let _span = span!(Level::INFO, "Merkle").entered(); + let tree = MerkleProver::commit(evaluations.iter().map(|eval| &eval.values).collect()); + MC::mix_root(channel, tree.root()); + + CommitmentTreeProver { + polynomials, + evaluations, + commitment: tree, + } + } + + /// Decommits the merkle tree on the given query positions. + /// Returns the values at the queried positions and the decommitment. + /// The queries are given as a mapping from the log size of the layer size to the queried + /// positions on each column of that size. 
+ fn decommit( + &self, + queries: BTreeMap>, + ) -> (ColumnVec>, MerkleDecommitment) { + let eval_vec = self + .evaluations + .iter() + .map(|eval| &eval.values) + .collect_vec(); + self.commitment.decommit(queries, eval_vec) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs b/Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs new file mode 100644 index 0000000..1a41e83 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs @@ -0,0 +1,218 @@ +use std::cmp::Reverse; +use std::collections::BTreeMap; +use std::iter::zip; + +use itertools::{izip, multiunzip, Itertools}; +use tracing::{span, Level}; + +use crate::core::backend::cpu::quotients::{accumulate_row_quotients, quotient_constants}; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fri::SparseCircleEvaluation; +use crate::core::poly::circle::{ + CanonicCoset, CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation, +}; +use crate::core::poly::BitReversedOrder; +use crate::core::prover::VerificationError; +use crate::core::queries::SparseSubCircleDomain; +use crate::core::utils::bit_reverse_index; + +pub trait QuotientOps: PolyOps { + /// Accumulates the quotients of the columns at the given domain. + /// For a column f(x), and a point sample (p,v), the quotient is + /// (f(x) - V0(x))/V1(x) + /// where V0(p)=v, V0(conj(p))=conj(v), and V1 is a vanishing polynomial for p,conj(p). + /// This ensures that if f(p)=v, then the quotient is a polynomial. + /// The result is a linear combination of the quotients using powers of random_coeff. + fn accumulate_quotients( + domain: CircleDomain, + columns: &[&CircleEvaluation], + random_coeff: SecureField, + sample_batches: &[ColumnSampleBatch], + log_blowup_factor: u32, + ) -> SecureEvaluation; +} + +/// A batch of column samplings at a point. +pub struct ColumnSampleBatch { + /// The point at which the columns are sampled. 
+ pub point: CirclePoint, + /// The sampled column indices and their values at the point. + pub columns_and_values: Vec<(usize, SecureField)>, +} + +impl ColumnSampleBatch { + /// Groups column samples by sampled point. + /// # Arguments + /// samples: For each column, a vector of samples. + pub fn new_vec(samples: &[&Vec]) -> Vec { + // Group samples by point, and create a ColumnSampleBatch for each point. + // This should keep a stable ordering. + let mut grouped_samples = BTreeMap::new(); + for (column_index, samples) in samples.iter().enumerate() { + for sample in samples.iter() { + grouped_samples + .entry(sample.point) + .or_insert_with(Vec::new) + .push((column_index, sample.value)); + } + } + grouped_samples + .into_iter() + .map(|(point, columns_and_values)| ColumnSampleBatch { + point, + columns_and_values, + }) + .collect() + } +} + +pub struct PointSample { + pub point: CirclePoint, + pub value: SecureField, +} + +pub fn compute_fri_quotients( + columns: &[&CircleEvaluation], + samples: &[Vec], + random_coeff: SecureField, + log_blowup_factor: u32, +) -> Vec> { + let _span = span!(Level::INFO, "Compute FRI quotients").entered(); + zip(columns, samples) + .sorted_by_key(|(c, _)| Reverse(c.domain.log_size())) + .group_by(|(c, _)| c.domain.log_size()) + .into_iter() + .map(|(log_size, tuples)| { + let (columns, samples): (Vec<_>, Vec<_>) = tuples.unzip(); + let domain = CanonicCoset::new(log_size).circle_domain(); + // TODO: slice. 
+ let sample_batches = ColumnSampleBatch::new_vec(&samples); + B::accumulate_quotients( + domain, + &columns, + random_coeff, + &sample_batches, + log_blowup_factor, + ) + }) + .collect() +} + +pub fn fri_answers( + column_log_sizes: Vec, + samples: &[Vec], + random_coeff: SecureField, + query_domain_per_log_size: BTreeMap, + queried_values_per_column: &[Vec], +) -> Result, VerificationError> { + izip!(column_log_sizes, samples, queried_values_per_column) + .sorted_by_key(|(log_size, ..)| Reverse(*log_size)) + .group_by(|(log_size, ..)| *log_size) + .into_iter() + .map(|(log_size, tuples)| { + let (_, samples, queried_valued_per_column): (Vec<_>, Vec<_>, Vec<_>) = + multiunzip(tuples); + fri_answers_for_log_size( + log_size, + &samples, + random_coeff, + &query_domain_per_log_size[&log_size], + &queried_valued_per_column, + ) + }) + .collect() +} + +pub fn fri_answers_for_log_size( + log_size: u32, + samples: &[&Vec], + random_coeff: SecureField, + query_domain: &SparseSubCircleDomain, + queried_values_per_column: &[&Vec], +) -> Result { + let commitment_domain = CanonicCoset::new(log_size).circle_domain(); + let sample_batches = ColumnSampleBatch::new_vec(samples); + for queried_values in queried_values_per_column { + if queried_values.len() != query_domain.flatten().len() { + return Err(VerificationError::InvalidStructure( + "Insufficient number of queried values".to_string(), + )); + } + } + let mut queried_values_per_column = queried_values_per_column + .iter() + .map(|q| q.iter()) + .collect_vec(); + + let mut evals = Vec::new(); + for subdomain in query_domain.iter() { + let domain = subdomain.to_circle_domain(&commitment_domain); + let quotient_constants = quotient_constants(&sample_batches, random_coeff, domain); + let mut column_evals = Vec::new(); + for queried_values in queried_values_per_column.iter_mut() { + let eval = CircleEvaluation::new( + domain, + queried_values.take(domain.size()).copied().collect_vec(), + ); + column_evals.push(eval); + } + + 
let mut values = Vec::new(); + for row in 0..domain.size() { + let domain_point = domain.at(bit_reverse_index(row, log_size)); + let value = accumulate_row_quotients( + &sample_batches, + &column_evals.iter().collect_vec(), + "ient_constants, + row, + domain_point, + ); + values.push(value); + } + let eval = CircleEvaluation::new(domain, values); + evals.push(eval); + } + + let res = SparseCircleEvaluation::new(evals); + if !queried_values_per_column.iter().all(|x| x.is_empty()) { + return Err(VerificationError::InvalidStructure( + "Too many queried values".to_string(), + )); + } + Ok(res) +} + +#[cfg(test)] +mod tests { + use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly}; + use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; + use crate::core::pcs::quotients::{compute_fri_quotients, PointSample}; + use crate::core::poly::circle::CanonicCoset; + use crate::{m31, qm31}; + + #[test] + fn test_quotients_are_low_degree() { + const LOG_SIZE: u32 = 7; + const LOG_BLOWUP_FACTOR: u32 = 1; + let polynomial = CpuCirclePoly::new((0..1 << LOG_SIZE).map(|i| m31!(i)).collect()); + let eval_domain = CanonicCoset::new(LOG_SIZE + 1).circle_domain(); + let eval = polynomial.evaluate(eval_domain); + let point = SECURE_FIELD_CIRCLE_GEN; + let value = polynomial.eval_at_point(point); + let coeff = qm31!(1, 2, 3, 4); + let quot_eval = compute_fri_quotients( + &[&eval], + &[vec![PointSample { point, value }]], + coeff, + LOG_BLOWUP_FACTOR, + ) + .pop() + .unwrap(); + let quot_poly_base_field = + CpuCircleEvaluation::new(eval_domain, quot_eval.values.columns[0].clone()) + .interpolate(); + assert!(quot_poly_base_field.is_in_fri_space(LOG_SIZE)); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/utils.rs b/Stwo_wrapper/crates/prover/src/core/pcs/utils.rs new file mode 100644 index 0000000..bfdbdb5 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/pcs/utils.rs @@ -0,0 +1,158 @@ +use std::collections::BTreeSet; +use std::ops::{Deref, DerefMut}; + +use 
itertools::zip_eq; +use serde::{Deserialize, Serialize}; + +use super::TreeSubspan; +use crate::core::ColumnVec; + +/// A container that holds an element for each commitment tree. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TreeVec(pub Vec); + +impl TreeVec { + pub fn new(vec: Vec) -> TreeVec { + TreeVec(vec) + } + pub fn map U>(self, f: F) -> TreeVec { + TreeVec(self.0.into_iter().map(f).collect()) + } + pub fn zip(self, other: impl Into>) -> TreeVec<(T, U)> { + let other = other.into(); + TreeVec(self.0.into_iter().zip(other.0).collect()) + } + pub fn zip_eq(self, other: impl Into>) -> TreeVec<(T, U)> { + let other = other.into(); + TreeVec(zip_eq(self.0, other.0).collect()) + } + pub fn as_ref(&self) -> TreeVec<&T> { + TreeVec(self.iter().collect()) + } + pub fn as_mut(&mut self) -> TreeVec<&mut T> { + TreeVec(self.iter_mut().collect()) + } +} + +/// Converts `&TreeVec` to `TreeVec<&T>`. +impl<'a, T> From<&'a TreeVec> for TreeVec<&'a T> { + fn from(val: &'a TreeVec) -> Self { + val.as_ref() + } +} + +impl Deref for TreeVec { + type Target = Vec; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for TreeVec { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Default for TreeVec { + fn default() -> Self { + TreeVec(Vec::new()) + } +} + +impl TreeVec> { + pub fn map_cols U>(self, mut f: F) -> TreeVec> { + TreeVec( + self.0 + .into_iter() + .map(|column| column.into_iter().map(&mut f).collect()) + .collect(), + ) + } + + /// Zips two [`TreeVec>`] with the same structure (number of columns in each tree). + /// The resulting [`TreeVec>`] has the same structure, with each value being a tuple + /// of the corresponding values from the input [`TreeVec>`]. 
+ pub fn zip_cols( + self, + other: impl Into>>, + ) -> TreeVec> { + let other = other.into(); + TreeVec( + zip_eq(self.0, other.0) + .map(|(column1, column2)| zip_eq(column1, column2).collect()) + .collect(), + ) + } + + pub fn as_cols_ref(&self) -> TreeVec> { + TreeVec(self.iter().map(|column| column.iter().collect()).collect()) + } + + /// Flattens the [`TreeVec>`] into a single [`ColumnVec`] with all the columns + /// combined. + pub fn flatten(self) -> ColumnVec { + self.0.into_iter().flatten().collect() + } + + /// Appends the columns of another [`TreeVec>`] to this one. + pub fn append_cols(&mut self, mut other: TreeVec>) { + let n_trees = self.0.len().max(other.0.len()); + self.0.resize_with(n_trees, Default::default); + for (self_col, other_col) in self.0.iter_mut().zip(other.0.iter_mut()) { + self_col.append(other_col); + } + } + + /// Concatenates the columns of multiple [`TreeVec>`] into a single + /// [`TreeVec>`]. + pub fn concat_cols( + trees: impl Iterator>>, + ) -> TreeVec> { + let mut result = TreeVec::default(); + for tree in trees { + result.append_cols(tree); + } + result + } + + /// Extracts a sub-tree based on the specified locations. + /// + /// # Panics + /// + /// If two or more locations have the same tree index. + pub fn sub_tree(&self, locations: &[TreeSubspan]) -> TreeVec> { + let tree_indicies: BTreeSet = locations.iter().map(|l| l.tree_index).collect(); + assert_eq!(tree_indicies.len(), locations.len()); + let max_tree_index = tree_indicies.iter().max().unwrap_or(&0); + let mut res = TreeVec(vec![Vec::new(); max_tree_index + 1]); + + for &location in locations { + // TODO(andrew): Throwing error here might be better instead. 
+ let chunk = self.get_chunk(location).unwrap(); + res[location.tree_index] = chunk; + } + + res + } + + fn get_chunk(&self, location: TreeSubspan) -> Option> { + let tree = self.0.get(location.tree_index)?; + let chunk = tree.get(location.col_start..location.col_end)?; + Some(chunk.iter().collect()) + } +} + +impl<'a, T> From<&'a TreeVec>> for TreeVec> { + fn from(val: &'a TreeVec>) -> Self { + val.as_cols_ref() + } +} + +impl TreeVec>> { + /// Flattens a [`TreeVec>`] of [Vec]s into a single [Vec] with all the elements + /// combined. + pub fn flatten_cols(self) -> Vec { + self.0.into_iter().flatten().flatten().collect() + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs b/Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs new file mode 100644 index 0000000..812137f --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs @@ -0,0 +1,132 @@ +use std::iter::zip; + +use itertools::Itertools; + +use super::super::circle::CirclePoint; +use super::super::fields::qm31::SecureField; +use super::super::fri::{CirclePolyDegreeBound, FriVerifier}; +use super::quotients::{fri_answers, PointSample}; +use super::utils::TreeVec; +use super::{CommitmentSchemeProof, PcsConfig}; +use crate::core::channel::{Channel, MerkleChannel}; +use crate::core::prover::VerificationError; +use crate::core::vcs::ops::MerkleHasher; +use crate::core::vcs::verifier::MerkleVerifier; +use crate::core::ColumnVec; + +/// The verifier side of a FRI polynomial commitment scheme. See [super]. +#[derive(Default)] +pub struct CommitmentSchemeVerifier { + pub trees: TreeVec>, + pub config: PcsConfig, +} + +impl CommitmentSchemeVerifier { + pub fn new(config: PcsConfig) -> Self { + Self { + trees: TreeVec::default(), + config, + } + } + + /// A [TreeVec] of the log sizes of each column in each commitment tree. + fn column_log_sizes(&self) -> TreeVec> { + self.trees + .as_ref() + .map(|tree| tree.column_log_sizes.clone()) + } + + /// Reads a commitment from the prover. 
+ pub fn commit( + &mut self, + commitment: ::Hash, + log_sizes: &[u32], + channel: &mut MC::C, + ) { + MC::mix_root(channel, commitment); + let extended_log_sizes = log_sizes + .iter() + .map(|&log_size| log_size + self.config.fri_config.log_blowup_factor) + .collect(); + let verifier = MerkleVerifier::new(commitment, extended_log_sizes); + self.trees.push(verifier); + } + + pub fn verify_values( + &self, + sampled_points: TreeVec>>>, + proof: CommitmentSchemeProof, + channel: &mut MC::C, + ) -> Result<(), VerificationError> { + channel.mix_felts(&proof.sampled_values.clone().flatten_cols()); + let random_coeff = channel.draw_felt(); + + let bounds = self + .column_log_sizes() + .zip_cols(&sampled_points) + .map_cols(|(log_size, sampled_points)| { + vec![ + CirclePolyDegreeBound::new(log_size - self.config.fri_config.log_blowup_factor); + sampled_points.len() + ] + }) + .flatten_cols() + .into_iter() + .sorted() + .rev() + .dedup() + .collect_vec(); + + // FRI commitment phase on OODS quotients. + let mut fri_verifier = + FriVerifier::::commit(channel, self.config.fri_config, proof.fri_proof, bounds)?; + + // Verify proof of work. + channel.mix_nonce(proof.proof_of_work); + if channel.trailing_zeros() < self.config.pow_bits { + return Err(VerificationError::ProofOfWork); + } + + // Get FRI query domains. + let fri_query_domains = fri_verifier.column_query_positions(channel); + + // Verify merkle decommitments. + self.trees + .as_ref() + .zip_eq(proof.decommitments) + .zip_eq(proof.queried_values.clone()) + .map(|((tree, decommitment), queried_values)| { + let queries = fri_query_domains + .iter() + .map(|(&log_size, domain)| (log_size, domain.flatten())) + .collect(); + tree.verify(queries, queried_values, decommitment) + }) + .0 + .into_iter() + .collect::>()?; + + // Answer FRI queries. 
+ let samples = sampled_points + .zip_cols(proof.sampled_values) + .map_cols(|(sampled_points, sampled_values)| { + zip(sampled_points, sampled_values) + .map(|(point, value)| PointSample { point, value }) + .collect_vec() + }) + .flatten(); + + // TODO(spapini): Properly define column log size and distinguish between poly and + // commitment. + let fri_answers = fri_answers( + self.column_log_sizes().flatten().into_iter().collect(), + &samples, + random_coeff, + fri_query_domains, + &proof.queried_values.flatten(), + )?; + + fri_verifier.decommit(fri_answers)?; + Ok(()) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/canonic.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/canonic.rs new file mode 100644 index 0000000..837e648 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/circle/canonic.rs @@ -0,0 +1,77 @@ +use super::CircleDomain; +use crate::core::circle::{CirclePoint, CirclePointIndex, Coset}; +use crate::core::fields::m31::BaseField; + +/// A coset of the form G_{2n} + , where G_n is the generator of the +/// subgroup of order n. The ordering on this coset is G_2n + i * G_n. +/// These cosets can be used as a [CircleDomain], and be interpolated on. +/// Note that this changes the ordering on the coset to be like [CircleDomain], +/// which is G_2n + i * G_n/2 and then -G_2n -i * G_n/2. +/// For example, the Xs below are a canonic coset with n=8. +/// ```text +/// X O X +/// O O +/// X X +/// O O +/// X X +/// O O +/// X O X +/// ``` +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct CanonicCoset { + pub coset: Coset, +} + +impl CanonicCoset { + pub fn new(log_size: u32) -> Self { + assert!(log_size > 0); + Self { + coset: Coset::odds(log_size), + } + } + + /// Gets the full coset represented G_{2n} + .
+ pub fn coset(&self) -> Coset { + self.coset + } + + /// Gets half of the coset (its conjugate complements to the whole coset), G_{2n} + + pub fn half_coset(&self) -> Coset { + Coset::half_odds(self.log_size() - 1) + } + + /// Gets the [CircleDomain] representing the same point set (in another order). + pub fn circle_domain(&self) -> CircleDomain { + CircleDomain::new(self.half_coset()) + } + + /// Returns the log size of the coset. + pub fn log_size(&self) -> u32 { + self.coset.log_size + } + + /// Returns the size of the coset. + pub fn size(&self) -> usize { + self.coset.size() + } + + pub fn initial_index(&self) -> CirclePointIndex { + self.coset.initial_index + } + + pub fn step_size(&self) -> CirclePointIndex { + self.coset.step_size + } + + pub fn step(&self) -> CirclePoint { + self.coset.step + } + + pub fn index_at(&self, index: usize) -> CirclePointIndex { + self.coset.index_at(index) + } + + pub fn at(&self, i: usize) -> CirclePoint { + self.coset.at(i) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/domain.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/domain.rs new file mode 100644 index 0000000..fba2bc3 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/circle/domain.rs @@ -0,0 +1,188 @@ +use std::iter::Chain; + +use itertools::Itertools; + +use crate::core::circle::{ + CirclePoint, CirclePointIndex, Coset, CosetIterator, M31_CIRCLE_LOG_ORDER, +}; +use crate::core::fields::m31::BaseField; + +pub const MAX_CIRCLE_DOMAIN_LOG_SIZE: u32 = M31_CIRCLE_LOG_ORDER - 1; + +/// A valid domain for circle polynomial interpolation and evaluation. +/// Valid domains are a disjoint union of two conjugate cosets: +-C + . +/// The ordering defined on this domain is C + iG_n, and then -C - iG_n. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct CircleDomain { + pub half_coset: Coset, +} + +impl CircleDomain { + /// Given a coset C + , constructs the circle domain +-C + (i.e., + /// this coset and its conjugate). 
+ pub fn new(half_coset: Coset) -> Self { + Self { half_coset } + } + + pub fn iter(&self) -> CircleDomainIterator { + self.half_coset + .iter() + .chain(self.half_coset.conjugate().iter()) + } + + /// Iterates over point indices. + pub fn iter_indices(&self) -> CircleDomainIndexIterator { + self.half_coset + .iter_indices() + .chain(self.half_coset.conjugate().iter_indices()) + } + + /// Returns the size of the domain. + pub fn size(&self) -> usize { + 1 << self.log_size() + } + + /// Returns the log size of the domain. + pub fn log_size(&self) -> u32 { + self.half_coset.log_size + 1 + } + + /// Returns the `i` th domain element. + pub fn at(&self, i: usize) -> CirclePoint { + self.index_at(i).to_point() + } + + /// Returns the [CirclePointIndex] of the `i`th domain element. + pub fn index_at(&self, i: usize) -> CirclePointIndex { + if i < self.half_coset.size() { + self.half_coset.index_at(i) + } else { + -self.half_coset.index_at(i - self.half_coset.size()) + } + } + + pub fn find(&self, i: CirclePointIndex) -> Option { + if let Some(d) = self.half_coset.find(i) { + return Some(d); + } + if let Some(d) = self.half_coset.conjugate().find(i) { + return Some(self.half_coset.size() + d); + } + None + } + + /// Returns true if the domain is canonic. + /// + /// Canonic domains are domains with elements that are the entire set of points defined by + /// `G_2n + ` where `G_n` and `G_2n` are obtained by repeatedly doubling + /// [crate::core::circle::M31_CIRCLE_GEN]. + pub fn is_canonic(&self) -> bool { + self.half_coset.initial_index * 4 == self.half_coset.step_size + } + + /// Splits a circle domain into a smaller [CircleDomain]s, shifted by offsets. 
+ pub fn split(&self, log_parts: u32) -> (CircleDomain, Vec) { + assert!(log_parts <= self.half_coset.log_size); + let subdomain = CircleDomain::new(Coset::new( + self.half_coset.initial_index, + self.half_coset.log_size - log_parts, + )); + let shifts = (0..1 << log_parts) + .map(|i| self.half_coset.step_size * i) + .collect_vec(); + (subdomain, shifts) + } + + pub fn shift(&self, shift: CirclePointIndex) -> CircleDomain { + CircleDomain::new(self.half_coset.shift(shift)) + } +} + +impl IntoIterator for CircleDomain { + type Item = CirclePoint; + type IntoIter = CircleDomainIterator; + + /// Iterates over the points in the domain. + fn into_iter(self) -> CircleDomainIterator { + self.iter() + } +} + +/// An iterator over points in a circle domain. +/// +/// Let the domain be `+-c + `. The first iterated points are `c + `, then `-c + <-G>`. +pub type CircleDomainIterator = + Chain>, CosetIterator>>; + +/// Like [CircleDomainIterator] but returns corresponding [CirclePointIndex]s. +type CircleDomainIndexIterator = + Chain, CosetIterator>; + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use super::CircleDomain; + use crate::core::circle::{CirclePointIndex, Coset}; + use crate::core::poly::circle::CanonicCoset; + + #[test] + fn test_circle_domain_iterator() { + let domain = CircleDomain::new(Coset::new(CirclePointIndex::generator(), 2)); + for (i, point) in domain.iter().enumerate() { + if i < 4 { + assert_eq!( + point, + (CirclePointIndex::generator() + CirclePointIndex::subgroup_gen(2) * i) + .to_point() + ); + } else { + assert_eq!( + point, + (-(CirclePointIndex::generator() + CirclePointIndex::subgroup_gen(2) * i)) + .to_point() + ); + } + } + } + + #[test] + fn is_canonic_invalid_domain() { + let half_coset = Coset::new(CirclePointIndex::generator(), 4); + let not_canonic_domain = CircleDomain::new(half_coset); + + assert!(!not_canonic_domain.is_canonic()); + } + + #[test] + fn test_at_circle_domain() { + let domain = 
CanonicCoset::new(7).circle_domain(); + let half_domain_size = domain.size() / 2; + + for i in 0..half_domain_size { + assert_eq!(domain.index_at(i), -domain.index_at(i + half_domain_size)); + assert_eq!(domain.at(i), domain.at(i + half_domain_size).conjugate()); + } + } + + #[test] + fn test_domain_split() { + let domain = CanonicCoset::new(5).circle_domain(); + let (subdomain, shifts) = domain.split(2); + + let domain_points = domain.iter().collect::>(); + let points_for_each_domain = shifts + .iter() + .map(|&shift| (subdomain.shift(shift)).iter().collect_vec()) + .collect::>(); + // Interleave the points from each subdomain. + let extended_points = (0..(1 << 3)) + .flat_map(|point_ind| { + (0..(1 << 2)) + .map(|shift_ind| points_for_each_domain[shift_ind][point_ind]) + .collect_vec() + }) + .collect_vec(); + assert_eq!(domain_points, extended_points); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/evaluation.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/evaluation.rs new file mode 100644 index 0000000..4cf23b9 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/circle/evaluation.rs @@ -0,0 +1,218 @@ +use std::marker::PhantomData; +use std::ops::{Deref, Index}; + +use educe::Educe; + +use super::{CanonicCoset, CircleDomain, CirclePoly, PolyOps}; +use crate::core::backend::cpu::CpuCircleEvaluation; +use crate::core::backend::{Col, Column}; +use crate::core::circle::{CirclePointIndex, Coset}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::{ExtensionOf, FieldOps}; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::{BitReversedOrder, NaturalOrder}; +use crate::core::utils::bit_reverse_index; + +/// An evaluation defined on a [CircleDomain]. +/// The values are ordered according to the [CircleDomain] ordering. 
+#[derive(Educe)] +#[educe(Clone, Debug)] +pub struct CircleEvaluation, F: ExtensionOf, EvalOrder = NaturalOrder> { + pub domain: CircleDomain, + pub values: Col, + _eval_order: PhantomData, +} + +impl, F: ExtensionOf, EvalOrder> CircleEvaluation { + pub fn new(domain: CircleDomain, values: Col) -> Self { + assert_eq!(domain.size(), values.len()); + Self { + domain, + values, + _eval_order: PhantomData, + } + } +} + +// Note: The concrete implementation of the poly operations is in the specific backend used. +// For example, the CPU backend implementation is in `src/core/backend/cpu/poly.rs`. +impl, B: FieldOps> CircleEvaluation { + // TODO(spapini): Remove. Is this even used. + pub fn get_at(&self, point_index: CirclePointIndex) -> F { + self.values + .at(self.domain.find(point_index).expect("Not in domain")) + } + + pub fn bit_reverse(mut self) -> CircleEvaluation { + B::bit_reverse_column(&mut self.values); + CircleEvaluation::new(self.domain, self.values) + } +} + +impl> CpuCircleEvaluation { + pub fn fetch_eval_on_coset(&self, coset: Coset) -> CosetSubEvaluation<'_, F> { + assert!(coset.log_size() <= self.domain.half_coset.log_size()); + if let Some(offset) = self.domain.half_coset.find(coset.initial_index) { + return CosetSubEvaluation::new( + &self.values[..self.domain.half_coset.size()], + offset, + coset.step_size / self.domain.half_coset.step_size, + ); + } + if let Some(offset) = self.domain.half_coset.conjugate().find(coset.initial_index) { + return CosetSubEvaluation::new( + &self.values[self.domain.half_coset.size()..], + offset, + (-coset.step_size) / self.domain.half_coset.step_size, + ); + } + panic!("Coset not found in domain"); + } +} + +impl CircleEvaluation { + /// Creates a [CircleEvaluation] from values ordered according to + /// [CanonicCoset]. For example, the canonic coset might look like this: + /// G_8, G_8 + G_4, G_8 + 2G_4, G_8 + 3G_4. + /// The circle domain will be ordered like this: + /// G_8, G_8 + 2G_4, -G_8, -G_8 - 2G_4. 
+ pub fn new_canonical_ordered(coset: CanonicCoset, values: Col) -> Self { + B::new_canonical_ordered(coset, values) + } + + /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation. + pub fn interpolate(self) -> CirclePoly { + let coset = self.domain.half_coset; + B::interpolate(self, &B::precompute_twiddles(coset)) + } + + /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation, using + /// precomputed twiddles. + pub fn interpolate_with_twiddles(self, twiddles: &TwiddleTree) -> CirclePoly { + B::interpolate(self, twiddles) + } +} + +impl, F: ExtensionOf> CircleEvaluation { + pub fn bit_reverse(mut self) -> CircleEvaluation { + B::bit_reverse_column(&mut self.values); + CircleEvaluation::new(self.domain, self.values) + } + + pub fn get_at(&self, point_index: CirclePointIndex) -> F { + self.values.at(bit_reverse_index( + self.domain.find(point_index).expect("Not in domain"), + self.domain.log_size(), + )) + } +} + +impl, F: ExtensionOf, EvalOrder> Deref + for CircleEvaluation +{ + type Target = Col; + + fn deref(&self) -> &Self::Target { + &self.values + } +} + +/// A part of a [CircleEvaluation], for a specific coset that is a subset of the circle domain. 
+pub struct CosetSubEvaluation<'a, F: ExtensionOf> { + evaluation: &'a [F], + offset: usize, + step: isize, +} + +impl<'a, F: ExtensionOf> CosetSubEvaluation<'a, F> { + fn new(evaluation: &'a [F], offset: usize, step: isize) -> Self { + assert!(evaluation.len().is_power_of_two()); + Self { + evaluation, + offset, + step, + } + } +} + +impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { + type Output = F; + + fn index(&self, index: isize) -> &Self::Output { + let index = + ((self.offset as isize) + index * self.step) & ((self.evaluation.len() - 1) as isize); + &self.evaluation[index as usize] + } +} + +impl<'a, F: ExtensionOf> Index for CosetSubEvaluation<'a, F> { + type Output = F; + + fn index(&self, index: usize) -> &Self::Output { + &self[index as isize] + } +} + +#[cfg(test)] +mod tests { + use crate::core::backend::cpu::CpuCircleEvaluation; + use crate::core::circle::Coset; + use crate::core::fields::m31::BaseField; + use crate::core::poly::circle::CanonicCoset; + use crate::core::poly::NaturalOrder; + use crate::m31; + + #[test] + fn test_interpolate_non_canonic() { + let domain = CanonicCoset::new(3).circle_domain(); + assert_eq!(domain.log_size(), 3); + let evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new( + domain, + (0..8).map(BaseField::from_u32_unchecked).collect(), + ) + .bit_reverse(); + let poly = evaluation.interpolate(); + for (i, point) in domain.iter().enumerate() { + assert_eq!(poly.eval_at_point(point.into_ef()), m31!(i as u32).into()); + } + } + + #[test] + fn test_interpolate_canonic() { + let coset = CanonicCoset::new(3); + let evaluation = CpuCircleEvaluation::new_canonical_ordered( + coset, + (0..8).map(BaseField::from_u32_unchecked).collect(), + ); + let poly = evaluation.interpolate(); + for (i, point) in Coset::odds(3).iter().enumerate() { + assert_eq!(poly.eval_at_point(point.into_ef()), m31!(i as u32).into()); + } + } + + #[test] + pub fn test_get_at_circle_evaluation() { + let domain = 
CanonicCoset::new(7).circle_domain(); + let values = (0..domain.size()).map(|i| m31!(i as u32)).collect(); + let circle_evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new(domain, values); + let bit_reversed_circle_evaluation = circle_evaluation.clone().bit_reverse(); + for index in domain.iter_indices() { + assert_eq!( + circle_evaluation.get_at(index), + bit_reversed_circle_evaluation.get_at(index) + ); + } + } + + #[test] + fn test_sub_evaluation() { + let domain = CanonicCoset::new(7).circle_domain(); + let values = (0..domain.size()).map(|i| m31!(i as u32)).collect(); + let circle_evaluation = CpuCircleEvaluation::new(domain, values); + let coset = Coset::new(domain.index_at(17), 3); + let sub_eval = circle_evaluation.fetch_eval_on_coset(coset); + for i in 0..coset.size() { + assert_eq!(sub_eval[i], circle_evaluation.get_at(coset.index_at(i))); + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/mod.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/mod.rs new file mode 100644 index 0000000..f2532d5 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/circle/mod.rs @@ -0,0 +1,56 @@ +mod canonic; +mod domain; +mod evaluation; +mod ops; +mod poly; +mod secure_poly; + +pub use canonic::CanonicCoset; +pub use domain::{CircleDomain, MAX_CIRCLE_DOMAIN_LOG_SIZE}; +pub use evaluation::{CircleEvaluation, CosetSubEvaluation}; +pub use ops::PolyOps; +pub use poly::CirclePoly; +pub use secure_poly::{SecureCirclePoly, SecureEvaluation}; + +#[cfg(test)] +mod tests { + use super::CanonicCoset; + use crate::core::backend::cpu::CpuCircleEvaluation; + use crate::core::fields::m31::BaseField; + use crate::core::utils::bit_reverse_index; + + #[test] + fn test_interpolate_and_eval() { + let domain = CanonicCoset::new(3).circle_domain(); + assert_eq!(domain.log_size(), 3); + let evaluation = + CpuCircleEvaluation::new(domain, (0..8).map(BaseField::from_u32_unchecked).collect()); + let poly = evaluation.clone().interpolate(); + let evaluation2 = 
poly.evaluate(domain); + assert_eq!(evaluation.values, evaluation2.values); + } + + #[test] + fn is_canonic_valid_domain() { + let canonic_domain = CanonicCoset::new(4).circle_domain(); + + assert!(canonic_domain.is_canonic()); + } + + #[test] + pub fn test_bit_reverse_indices() { + let log_domain_size = 7; + let log_small_domain_size = 5; + let domain = CanonicCoset::new(log_domain_size); + let small_domain = CanonicCoset::new(log_small_domain_size); + let n_folds = log_domain_size - log_small_domain_size; + for i in 0..2usize.pow(log_domain_size) { + let point = domain.at(bit_reverse_index(i, log_domain_size)); + let small_point = small_domain.at(bit_reverse_index( + i / 2usize.pow(n_folds), + log_small_domain_size, + )); + assert_eq!(point.repeated_double(n_folds), small_point); + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/ops.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/ops.rs new file mode 100644 index 0000000..40b86cb --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/circle/ops.rs @@ -0,0 +1,48 @@ +use super::{CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly}; +use crate::core::backend::Col; +use crate::core::circle::{CirclePoint, Coset}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldOps; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::BitReversedOrder; + +/// Operations on BaseField polynomials. +pub trait PolyOps: FieldOps + Sized { + // TODO(spapini): Use a column instead of this type. + /// The type for precomputed twiddles. + type Twiddles; + + /// Creates a [CircleEvaluation] from values ordered according to [CanonicCoset]. + /// Used by the [`CircleEvaluation::new_canonical_ordered()`] function. + fn new_canonical_ordered( + coset: CanonicCoset, + values: Col, + ) -> CircleEvaluation; + + /// Computes a minimal [CirclePoly] that evaluates to the same values as this evaluation. 
+ /// Used by the [`CircleEvaluation::interpolate()`] function. + fn interpolate( + eval: CircleEvaluation, + itwiddles: &TwiddleTree, + ) -> CirclePoly; + + /// Evaluates the polynomial at a single point. + /// Used by the [`CirclePoly::eval_at_point()`] function. + fn eval_at_point(poly: &CirclePoly, point: CirclePoint) -> SecureField; + + /// Extends the polynomial to a larger degree bound. + /// Used by the [`CirclePoly::extend()`] function. + fn extend(poly: &CirclePoly, log_size: u32) -> CirclePoly; + + /// Evaluates the polynomial at all points in the domain. + /// Used by the [`CirclePoly::evaluate()`] function. + fn evaluate( + poly: &CirclePoly, + domain: CircleDomain, + twiddles: &TwiddleTree, + ) -> CircleEvaluation; + + /// Precomputes twiddles for a given coset. + fn precompute_twiddles(coset: Coset) -> TwiddleTree; +} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/poly.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/poly.rs new file mode 100644 index 0000000..c10fc5e --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/circle/poly.rs @@ -0,0 +1,118 @@ +use super::{CircleDomain, CircleEvaluation, PolyOps}; +use crate::core::backend::{Col, Column}; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldOps; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::BitReversedOrder; + +/// A polynomial defined on a [CircleDomain]. +#[derive(Clone, Debug)] +pub struct CirclePoly> { + /// Coefficients of the polynomial in the FFT basis. + /// Note: These are not the coefficients of the polynomial in the standard + /// monomial basis. The FFT basis is a tensor product of the twiddles: + /// y, x, pi(x), pi^2(x), ..., pi^{log_size-2}(x). + /// pi(x) := 2x^2 - 1. + pub coeffs: Col, + /// The number of coefficients stored as `log2(len(coeffs))`. 
+ log_size: u32, +} + +impl CirclePoly { + /// Creates a new circle polynomial. + /// + /// Coefficients must be in the circle IFFT algorithm's basis stored in bit-reversed order. + /// + /// # Panics + /// + /// Panics if the number of coefficients isn't a power of two. + pub fn new(coeffs: Col) -> Self { + assert!(coeffs.len().is_power_of_two()); + let log_size = coeffs.len().ilog2(); + Self { log_size, coeffs } + } + + pub fn log_size(&self) -> u32 { + self.log_size + } + + /// Evaluates the polynomial at a single point. + pub fn eval_at_point(&self, point: CirclePoint) -> SecureField { + B::eval_at_point(self, point) + } + + /// Extends the polynomial to a larger degree bound. + pub fn extend(&self, log_size: u32) -> Self { + B::extend(self, log_size) + } + + /// Evaluates the polynomial at all points in the domain. + pub fn evaluate( + &self, + domain: CircleDomain, + ) -> CircleEvaluation { + B::evaluate(self, domain, &B::precompute_twiddles(domain.half_coset)) + } + + /// Evaluates the polynomial at all points in the domain, using precomputed twiddles. + pub fn evaluate_with_twiddles( + &self, + domain: CircleDomain, + twiddles: &TwiddleTree, + ) -> CircleEvaluation { + B::evaluate(self, domain, twiddles) + } +} + +#[cfg(test)] +impl crate::core::backend::cpu::CpuCirclePoly { + pub fn is_in_fft_space(&self, log_fft_size: u32) -> bool { + use num_traits::Zero; + + let mut coeffs = self.coeffs.clone(); + while coeffs.last() == Some(&BaseField::zero()) { + coeffs.pop(); + } + + // The highest degree monomial in a fft-space polynomial is x^{(n/2) - 1}y. + // And it is at offset (n-1). x^{(n/2)} is at offset `n`, and is not allowed. + let highest_degree_allowed_monomial_offset = 1 << log_fft_size; + coeffs.len() <= highest_degree_allowed_monomial_offset + } + + /// Fri space is the space of polynomials of total degree n/2. + /// Highest degree monomials are x^{n/2} and x^{(n/2)-1}y. 
+ pub fn is_in_fri_space(&self, log_fft_size: u32) -> bool { + use num_traits::Zero; + + let mut coeffs = self.coeffs.clone(); + while coeffs.last() == Some(&BaseField::zero()) { + coeffs.pop(); + } + + // x^{n/2} is at offset `n`, and is the last offset allowed to be non-zero. + let highest_degree_monomial_offset = (1 << log_fft_size) + 1; + coeffs.len() <= highest_degree_monomial_offset + } +} + +#[cfg(test)] +mod tests { + use crate::core::backend::cpu::CpuCirclePoly; + use crate::core::circle::CirclePoint; + use crate::core::fields::m31::BaseField; + + #[test] + fn test_circle_poly_extend() { + let poly = CpuCirclePoly::new((0..16).map(BaseField::from_u32_unchecked).collect()); + let extended = poly.clone().extend(8); + let random_point = CirclePoint::get_point(21903); + + assert_eq!( + poly.eval_at_point(random_point), + extended.eval_at_point(random_point) + ); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/circle/secure_poly.rs b/Stwo_wrapper/crates/prover/src/core/poly/circle/secure_poly.rs new file mode 100644 index 0000000..a503bd2 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/circle/secure_poly.rs @@ -0,0 +1,118 @@ +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; + +use super::{CircleDomain, CircleEvaluation, CirclePoly, PolyOps}; +use crate::core::backend::CpuBackend; +use crate::core::circle::CirclePoint; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; +use crate::core::fields::FieldOps; +use crate::core::poly::twiddles::TwiddleTree; +use crate::core::poly::BitReversedOrder; + +pub struct SecureCirclePoly>(pub [CirclePoly; SECURE_EXTENSION_DEGREE]); + +impl SecureCirclePoly { + pub fn eval_at_point(&self, point: CirclePoint) -> SecureField { + SecureField::from_partial_evals(self.eval_columns_at_point(point)) + } + + pub fn eval_columns_at_point( + &self, + point: CirclePoint, + ) 
-> [SecureField; SECURE_EXTENSION_DEGREE] { + [ + self[0].eval_at_point(point), + self[1].eval_at_point(point), + self[2].eval_at_point(point), + self[3].eval_at_point(point), + ] + } + + pub fn log_size(&self) -> u32 { + self[0].log_size() + } + + pub fn evaluate_with_twiddles( + &self, + domain: CircleDomain, + twiddles: &TwiddleTree, + ) -> SecureEvaluation { + let polys = self.0.each_ref(); + let columns = polys.map(|poly| poly.evaluate_with_twiddles(domain, twiddles).values); + SecureEvaluation::new(domain, SecureColumnByCoords { columns }) + } +} + +impl> Deref for SecureCirclePoly { + type Target = [CirclePoly; SECURE_EXTENSION_DEGREE]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +/// A [`SecureField`] evaluation defined on a [CircleDomain]. +/// +/// The evaluation is stored as a column major array of [`SECURE_EXTENSION_DEGREE`] many base field +/// evaluations. The evaluations are ordered according to the [CircleDomain] ordering. +#[derive(Clone)] +pub struct SecureEvaluation, EvalOrder> { + pub domain: CircleDomain, + pub values: SecureColumnByCoords, + _eval_order: PhantomData, +} + +impl, EvalOrder> SecureEvaluation { + pub fn new(domain: CircleDomain, values: SecureColumnByCoords) -> Self { + assert_eq!(domain.size(), values.len()); + Self { + domain, + values, + _eval_order: PhantomData, + } + } + + pub fn into_coordinate_evals( + self, + ) -> [CircleEvaluation; SECURE_EXTENSION_DEGREE] { + let Self { domain, values, .. } = self; + values.columns.map(|c| CircleEvaluation::new(domain, c)) + } +} + +impl, EvalOrder> Deref for SecureEvaluation { + type Target = SecureColumnByCoords; + + fn deref(&self) -> &Self::Target { + &self.values + } +} + +impl, EvalOrder> DerefMut for SecureEvaluation { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.values + } +} + +impl SecureEvaluation { + /// Computes a minimal [`SecureCirclePoly`] that evaluates to the same values as this + /// evaluation, using precomputed twiddles. 
+ pub fn interpolate_with_twiddles(self, twiddles: &TwiddleTree) -> SecureCirclePoly { + let domain = self.domain; + let cols = self.values.columns; + SecureCirclePoly(cols.map(|c| { + CircleEvaluation::::new(domain, c) + .interpolate_with_twiddles(twiddles) + })) + } +} + +impl From> + for SecureEvaluation +{ + fn from(evaluation: CircleEvaluation) -> Self { + Self::new(evaluation.domain, evaluation.values.into_iter().collect()) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/line.rs b/Stwo_wrapper/crates/prover/src/core/poly/line.rs new file mode 100644 index 0000000..4dac73e --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/line.rs @@ -0,0 +1,408 @@ +use std::cmp::Ordering; +use std::fmt::Debug; +use std::iter::Map; +use std::ops::{Deref, DerefMut}; + +use itertools::Itertools; +use num_traits::Zero; +use serde::{Deserialize, Serialize}; + +use super::circle::CircleDomain; +use super::utils::fold; +use crate::core::backend::{ColumnOps, CpuBackend}; +use crate::core::circle::{CirclePoint, Coset, CosetIterator}; +use crate::core::fft::ibutterfly; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SecureColumnByCoords; +use crate::core::fields::{ExtensionOf, FieldExpOps, FieldOps}; +use crate::core::utils::bit_reverse; + +/// Domain comprising of the x-coordinates of points in a [Coset]. +/// +/// For use with univariate polynomials. +#[derive(Copy, Clone, Debug)] +pub struct LineDomain { + coset: Coset, +} + +impl LineDomain { + /// Returns a domain comprising of the x-coordinates of points in a coset. + /// + /// # Panics + /// + /// Panics if the coset items don't have unique x-coordinates. + pub fn new(coset: Coset) -> Self { + match coset.size().cmp(&2) { + Ordering::Less => {} + Ordering::Equal => { + // If the coset with two points contains (0, y) then the coset is {(0, y), (0, -y)}. 
+ assert!(!coset.initial.x.is_zero(), "coset x-coordinates not unique"); + } + Ordering::Greater => { + // Let our coset be `E = c + ` with `|E| > 2` then: + // 1. if `ord(c) <= ord(G)` the coset contains two points at x=0 + // 2. if `ord(c) = 2 * ord(G)` then `c` and `-c` are in our coset + assert!( + coset.initial.log_order() >= coset.step.log_order() + 2, + "coset x-coordinates not unique" + ); + } + } + Self { coset } + } + + /// Returns the `i`th domain element. + pub fn at(&self, i: usize) -> BaseField { + self.coset.at(i).x + } + + /// Returns the size of the domain. + pub fn size(&self) -> usize { + self.coset.size() + } + + /// Returns the log size of the domain. + pub fn log_size(&self) -> u32 { + self.coset.log_size() + } + + /// Returns an iterator over elements in the domain. + pub fn iter(&self) -> LineDomainIterator { + self.coset.iter().map(|p| p.x) + } + + /// Returns a new domain comprising of all points in current domain doubled. + pub fn double(&self) -> Self { + Self { + coset: self.coset.double(), + } + } + + /// Returns the domain's underlying coset. + pub fn coset(&self) -> Coset { + self.coset + } +} + +impl IntoIterator for LineDomain { + type Item = BaseField; + type IntoIter = LineDomainIterator; + + /// Returns an iterator over elements in the domain. + fn into_iter(self) -> LineDomainIterator { + self.iter() + } +} + +impl From for LineDomain { + fn from(domain: CircleDomain) -> Self { + Self { + coset: domain.half_coset, + } + } +} + +/// An iterator over the x-coordinates of points in a coset. +type LineDomainIterator = + Map>, fn(CirclePoint) -> BaseField>; + +/// A univariate polynomial defined on a [LineDomain]. +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize)] +pub struct LinePoly { + /// Coefficients of the polynomial in [line_ifft] algorithm's basis. + /// + /// The coefficients are stored in bit-reversed order. + pub coeffs: Vec, + /// The number of coefficients stored as `log2(len(coeffs))`. 
+ log_size: u32, +} + +impl LinePoly { + /// Creates a new line polynomial from bit reversed coefficients. + /// + /// # Panics + /// + /// Panics if the number of coefficients is not a power of two. + pub fn new(coeffs: Vec) -> Self { + assert!(coeffs.len().is_power_of_two()); + let log_size = coeffs.len().ilog2(); + Self { coeffs, log_size } + } + + /// Evaluates the polynomial at a single point. + pub fn eval_at_point(&self, mut x: SecureField) -> SecureField { + let mut doublings = Vec::new(); + for _ in 0..self.log_size { + doublings.push(x); + x = CirclePoint::double_x(x); + } + fold(&self.coeffs, &doublings) + } + + /// Returns the number of coefficients. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + // `.len().ilog2()` is a common operation. By returning the length like so the compiler + // optimizes `.len().ilog2()` to a load of `log_size` instead of a branch and a bit count. + debug_assert_eq!(self.coeffs.len(), 1 << self.log_size); + 1 << self.log_size + } + + /// Returns the polynomial's coefficients in their natural order. + pub fn into_ordered_coefficients(mut self) -> Vec { + bit_reverse(&mut self.coeffs); + self.coeffs + } + + /// Creates a new line polynomial from coefficients in their natural order. + /// + /// # Panics + /// + /// Panics if the number of coefficients is not a power of two. + pub fn from_ordered_coefficients(mut coeffs: Vec) -> Self { + bit_reverse(&mut coeffs); + Self::new(coeffs) + } +} + +impl Deref for LinePoly { + type Target = [SecureField]; + + fn deref(&self) -> &[SecureField] { + &self.coeffs + } +} + +impl DerefMut for LinePoly { + fn deref_mut(&mut self) -> &mut [SecureField] { + &mut self.coeffs + } +} + +/// Evaluations of a univariate polynomial on a [LineDomain]. +// TODO(andrew): Remove EvalOrder. Bit-reversed evals are only necessary since LineEvaluation is +// only used by FRI where evaluations are in bit-reversed order. +// TODO(spapini): Remove pub. 
+#[derive(Clone, Debug)] +pub struct LineEvaluation> { + /// Evaluations of a univariate polynomial on `domain`. + pub values: SecureColumnByCoords, + domain: LineDomain, +} + +impl> LineEvaluation { + /// Creates new [LineEvaluation] from a set of polynomial evaluations over a [LineDomain]. + /// + /// # Panics + /// + /// Panics if the number of evaluations does not match the size of the domain. + pub fn new(domain: LineDomain, values: SecureColumnByCoords) -> Self { + assert_eq!(values.len(), domain.size()); + Self { values, domain } + } + + pub fn new_zero(domain: LineDomain) -> Self { + Self::new(domain, SecureColumnByCoords::zeros(domain.size())) + } + + /// Returns the number of evaluations. + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + 1 << self.domain.log_size() + } + + pub fn domain(&self) -> LineDomain { + self.domain + } + + /// Clones the values into a new line evaluation in the CPU. + pub fn to_cpu(&self) -> LineEvaluation { + LineEvaluation::new(self.domain, self.values.to_cpu()) + } +} + +impl LineEvaluation { + /// Interpolates the polynomial as evaluations on `domain`. + pub fn interpolate(self) -> LinePoly { + let mut values = self.values.into_iter().collect_vec(); + CpuBackend::bit_reverse_column(&mut values); + line_ifft(&mut values, self.domain); + // Normalize the coefficients. + let len_inv = BaseField::from(values.len()).inverse(); + values.iter_mut().for_each(|v| *v *= len_inv); + LinePoly::new(values) + } +} + +/// Performs a univariate IFFT on a polynomial's evaluation over a [LineDomain]. +/// +/// This is not the standard univariate IFFT, because [LineDomain] is not a cyclic group. +/// +/// The transform happens in-place. `values` should be the evaluations of a polynomial over `domain` +/// in their natural order. After the transformation `values` becomes the coefficients of the +/// polynomial stored in bit-reversed order. 
+/// +/// For performance reasons and flexibility the normalization of the coefficients is omitted. The +/// normalized coefficients can be obtained by scaling all coefficients by `1 / len(values)`. +/// +/// This algorithm does not return coefficients in the standard monomial basis but rather returns +/// coefficients in a basis relating to the circle's x-coordinate doubling map `pi(x) = 2x^2 - 1` +/// i.e. +/// +/// ```text +/// B = { 1 } * { x } * { pi(x) } * { pi(pi(x)) } * ... +/// = { 1, x, pi(x), pi(x) * x, pi(pi(x)), pi(pi(x)) * x, pi(pi(x)) * pi(x), ... } +/// ``` +/// +/// # Panics +/// +/// Panics if the number of values doesn't match the size of the domain. +fn line_ifft>(values: &mut [F], mut domain: LineDomain) { + assert_eq!(values.len(), domain.size()); + while domain.size() > 1 { + for chunk in values.chunks_exact_mut(domain.size()) { + let (l, r) = chunk.split_at_mut(domain.size() / 2); + for (i, x) in domain.iter().take(domain.size() / 2).enumerate() { + ibutterfly(&mut l[i], &mut r[i], x.inverse()); + } + } + domain = domain.double(); + } +} + +#[cfg(test)] +mod tests { + type B = CpuBackend; + + use itertools::Itertools; + + use super::LineDomain; + use crate::core::backend::{ColumnOps, CpuBackend}; + use crate::core::circle::{CirclePoint, Coset}; + use crate::core::fields::m31::BaseField; + use crate::core::poly::line::{LineEvaluation, LinePoly}; + use crate::core::utils::bit_reverse_index; + + #[test] + #[should_panic] + fn bad_line_domain() { + // This coset doesn't have points with unique x-coordinates. 
+ let coset = Coset::odds(2); + + LineDomain::new(coset); + } + + #[test] + fn line_domain_of_size_two_works() { + const LOG_SIZE: u32 = 1; + let coset = Coset::subgroup(LOG_SIZE); + + LineDomain::new(coset); + } + + #[test] + fn line_domain_of_size_one_works() { + const LOG_SIZE: u32 = 0; + let coset = Coset::subgroup(LOG_SIZE); + + LineDomain::new(coset); + } + + #[test] + fn line_domain_size_is_correct() { + const LOG_SIZE: u32 = 8; + let coset = Coset::half_odds(LOG_SIZE); + let domain = LineDomain::new(coset); + + let size = domain.size(); + + assert_eq!(size, 1 << LOG_SIZE); + } + + #[test] + fn line_domain_coset_returns_the_coset() { + let coset = Coset::half_odds(5); + let domain = LineDomain::new(coset); + + assert_eq!(domain.coset(), coset); + } + + #[test] + fn line_domain_double_works() { + const LOG_SIZE: u32 = 8; + let coset = Coset::half_odds(LOG_SIZE); + let domain = LineDomain::new(coset); + + let doubled_domain = domain.double(); + + assert_eq!(doubled_domain.size(), 1 << (LOG_SIZE - 1)); + assert_eq!(doubled_domain.at(0), CirclePoint::double_x(domain.at(0))); + assert_eq!(doubled_domain.at(1), CirclePoint::double_x(domain.at(1))); + } + + #[test] + fn line_domain_iter_works() { + const LOG_SIZE: u32 = 8; + let coset = Coset::half_odds(LOG_SIZE); + let domain = LineDomain::new(coset); + + let elements = domain.iter().collect::>(); + + assert_eq!(elements.len(), domain.size()); + for (i, element) in elements.into_iter().enumerate() { + assert_eq!(element, domain.at(i), "mismatch at {i}"); + } + } + + #[test] + fn line_evaluation_interpolation() { + let poly = LinePoly::new(vec![ + BaseField::from(7).into(), // 7 * 1 + BaseField::from(9).into(), // 9 * pi(x) + BaseField::from(5).into(), // 5 * x + BaseField::from(3).into(), // 3 * pi(x)*x + ]); + let coset = Coset::half_odds(poly.len().ilog2()); + let domain = LineDomain::new(coset); + let mut values = domain + .iter() + .map(|x| { + let pi_x = CirclePoint::double_x(x); + poly.coeffs[0] + + 
poly.coeffs[1] * pi_x + + poly.coeffs[2] * x + + poly.coeffs[3] * pi_x * x + }) + .collect_vec(); + CpuBackend::bit_reverse_column(&mut values); + let evals = LineEvaluation::::new(domain, values.into_iter().collect()); + + let interpolated_poly = evals.interpolate(); + + assert_eq!(interpolated_poly.coeffs, poly.coeffs); + } + + #[test] + fn line_polynomial_eval_at_point() { + const LOG_SIZE: u32 = 2; + let coset = Coset::half_odds(LOG_SIZE); + let domain = LineDomain::new(coset); + let evals = LineEvaluation::::new( + domain, + (0..1 << LOG_SIZE) + .map(BaseField::from) + .map(|x| x.into()) + .collect(), + ); + let poly = evals.clone().interpolate(); + + for (i, x) in domain.iter().enumerate() { + assert_eq!( + poly.eval_at_point(x.into()), + evals.values.at(bit_reverse_index(i, domain.log_size())), + "mismatch at {i}" + ); + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/mod.rs b/Stwo_wrapper/crates/prover/src/core/poly/mod.rs new file mode 100644 index 0000000..301c698 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/mod.rs @@ -0,0 +1,14 @@ +pub mod circle; +pub mod line; +// TODO(spapini): Remove pub, when LinePoly moved to the backend as well, and we can move the fold +// function there. +pub mod twiddles; +pub mod utils; + +/// Bit-reversed evaluation ordering. +#[derive(Copy, Clone, Debug)] +pub struct BitReversedOrder; + +/// Natural evaluation ordering (same order as domain). +#[derive(Copy, Clone, Debug)] +pub struct NaturalOrder; diff --git a/Stwo_wrapper/crates/prover/src/core/poly/twiddles.rs b/Stwo_wrapper/crates/prover/src/core/poly/twiddles.rs new file mode 100644 index 0000000..53ea476 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/twiddles.rs @@ -0,0 +1,13 @@ +use super::circle::PolyOps; +use crate::core::circle::Coset; + +/// Precomputed twiddles for a specific coset tower. +/// A coset tower is every repeated doubling of a `root_coset`. 
+/// The largest CircleDomain that can be ffted using these twiddles is one with `root_coset` as +/// its `half_coset`. +pub struct TwiddleTree { + pub root_coset: Coset, + // TODO(spapini): Represent a slice, and grabbing, in a generic way + pub twiddles: B::Twiddles, + pub itwiddles: B::Twiddles, +} diff --git a/Stwo_wrapper/crates/prover/src/core/poly/utils.rs b/Stwo_wrapper/crates/prover/src/core/poly/utils.rs new file mode 100644 index 0000000..bc0dece --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/poly/utils.rs @@ -0,0 +1,115 @@ +use super::line::LineDomain; +use crate::core::fields::{ExtensionOf, Field}; + +/// Folds values recursively in `O(n)` by a hierarchical application of folding factors. +/// +/// i.e. folding `n = 8` values with `folding_factors = [x, y, z]`: +/// +/// ```text +/// n2=n1+x*n2 +/// / \ +/// n1=n3+y*n4 n2=n5+y*n6 +/// / \ / \ +/// n3=a+z*b n4=c+z*d n5=e+z*f n6=g+z*h +/// / \ / \ / \ / \ +/// a b c d e f g h +/// ``` +/// +/// # Panics +/// +/// Panics if the number of values is not a power of two or if an incorrect number of of folding +/// factors is provided. +// TODO(Andrew): Can be made to run >10x faster by unrolling lower layers of recursion +pub fn fold>(values: &[F], folding_factors: &[E]) -> E { + let n = values.len(); + assert_eq!(n, 1 << folding_factors.len()); + if n == 1 { + return values[0].into(); + } + let (lhs_values, rhs_values) = values.split_at(n / 2); + let (&folding_factor, folding_factors) = folding_factors.split_first().unwrap(); + let lhs_val = fold(lhs_values, folding_factors); + let rhs_val = fold(rhs_values, folding_factors); + lhs_val + rhs_val * folding_factor +} + +/// Repeats each value sequentially `duplicity` many times. 
+/// +/// # Examples +/// +/// ```rust +/// # use stwo_prover::core::poly::utils::repeat_value; +/// assert_eq!(repeat_value(&[1, 2, 3], 2), vec![1, 1, 2, 2, 3, 3]); +/// ``` +pub fn repeat_value(values: &[T], duplicity: usize) -> Vec { + let n = values.len(); + let mut res: Vec = Vec::with_capacity(n * duplicity); + + // Fill each chunk with its corresponding value. + for &v in values { + for _ in 0..duplicity { + res.push(v) + } + } + + res +} + +/// Computes the line twiddles for a [`CircleDomain`] or a [`LineDomain`] from the precomputed +/// twiddles tree. +/// +/// [`CircleDomain`]: super::circle::CircleDomain +pub fn domain_line_twiddles_from_tree( + domain: impl Into, + twiddle_buffer: &[T], +) -> Vec<&[T]> { + let domain = domain.into(); + debug_assert!(domain.coset().size() <= twiddle_buffer.len()); + (0..domain.coset().log_size()) + .map(|i| { + let len = 1 << i; + &twiddle_buffer[twiddle_buffer.len() - len * 2..twiddle_buffer.len() - len] + }) + .rev() + .collect() +} + +#[cfg(test)] +mod tests { + use super::repeat_value; + use crate::core::poly::circle::CanonicCoset; + use crate::core::poly::line::LineDomain; + use crate::core::poly::utils::domain_line_twiddles_from_tree; + + #[test] + fn repeat_value_0_times_works() { + assert!(repeat_value(&[1, 2, 3], 0).is_empty()); + } + + #[test] + fn repeat_value_2_times_works() { + assert_eq!(repeat_value(&[1, 2, 3], 2), vec![1, 1, 2, 2, 3, 3]); + } + + #[test] + fn repeat_value_3_times_works() { + assert_eq!(repeat_value(&[1, 2], 3), vec![1, 1, 1, 2, 2, 2]); + } + + #[test] + fn test_domain_line_twiddles_works() { + let domain: LineDomain = CanonicCoset::new(4).circle_domain().into(); + let twiddles = domain_line_twiddles_from_tree(domain, &[0, 1, 2, 3, 4, 5, 6, 7]); + assert_eq!(twiddles.len(), 3); + assert_eq!(twiddles[0], &[0, 1, 2, 3]); + assert_eq!(twiddles[1], &[4, 5]); + assert_eq!(twiddles[2], &[6]); + } + + #[test] + #[should_panic] + fn test_domain_line_twiddles_fails() { + let domain: LineDomain = 
CanonicCoset::new(5).circle_domain().into(); + domain_line_twiddles_from_tree(domain, &[0, 1, 2, 3, 4, 5, 6, 7]); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/proof_of_work.rs b/Stwo_wrapper/crates/prover/src/core/proof_of_work.rs new file mode 100644 index 0000000..1c61ad8 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/proof_of_work.rs @@ -0,0 +1,7 @@ +use crate::core::channel::Channel; + +pub trait GrindOps { + /// Searches for a nonce s.t. mixing it to the channel makes the digest have `pow_bits` leading + /// zero bits. + fn grind(channel: &C, pow_bits: u32) -> u64; +} diff --git a/Stwo_wrapper/crates/prover/src/core/prover/mod.rs b/Stwo_wrapper/crates/prover/src/core/prover/mod.rs new file mode 100644 index 0000000..30493fc --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/prover/mod.rs @@ -0,0 +1,186 @@ +use std::array; + +use thiserror::Error; +use tracing::{span, Level}; + +use super::air::{Component, ComponentProver, ComponentProvers, Components}; +use super::backend::BackendForChannel; +use super::channel::MerkleChannel; +use super::fields::secure_column::SECURE_EXTENSION_DEGREE; +use super::fri::FriVerificationError; +use super::pcs::{CommitmentSchemeProof, TreeVec}; +use super::vcs::ops::MerkleHasher; +use crate::core::backend::CpuBackend; +use crate::core::channel::Channel; +use crate::core::circle::CirclePoint; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier}; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; +use crate::core::vcs::verifier::MerkleVerificationError; + +#[derive(Debug)] +pub struct StarkProof { + pub commitments: TreeVec, + pub commitment_scheme_proof: CommitmentSchemeProof, +} + +#[derive(Debug)] +pub struct AdditionalProofData { + pub composition_polynomial_oods_value: SecureField, + pub composition_polynomial_random_coeff: SecureField, + pub oods_point: CirclePoint, + pub oods_quotients: Vec>, +} 
+ +pub fn prove, MC: MerkleChannel>( + components: &[&dyn ComponentProver], + channel: &mut MC::C, + commitment_scheme: &mut CommitmentSchemeProver<'_, B, MC>, +) -> Result, ProvingError> { + let component_provers = ComponentProvers(components.to_vec()); + let trace = commitment_scheme.trace(); + + // Evaluate and commit on composition polynomial. (Hash(commitment[0]) + let random_coeff = channel.draw_felt(); + + let span = span!(Level::INFO, "Composition").entered(); + let span1 = span!(Level::INFO, "Generation").entered(); + // Linear random combination of the trace polynomials + let composition_polynomial_poly = + component_provers.compute_composition_polynomial(random_coeff, &trace); + span1.exit(); + + // Before tree_builder is only column polynomial + let mut tree_builder = commitment_scheme.tree_builder(); + // Extend it with composition polynomial + tree_builder.extend_polys(composition_polynomial_poly.to_vec()); + // Reevaluate every polynomial (including composition) and absorb the root (Transcript <-- root) + tree_builder.commit(channel); + span.exit(); + + // Draw OODS point. + let oods_point = CirclePoint::::get_random_point(channel); + + // Get mask sample points relative to oods point. + let mut sample_points = component_provers.components().mask_points(oods_point); + // Add the composition polynomial mask points. + sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]); + + // Prove the trace and composition OODS values, and retrieve them. + let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel); + + let sampled_oods_values = &commitment_scheme_proof.sampled_values; + let composition_oods_eval = extract_composition_eval(sampled_oods_values).unwrap(); + + // Evaluate composition polynomial at OODS point and check that it matches the trace OODS + // values. This is a sanity check. 
+ if composition_oods_eval + != component_provers + .components() + .eval_composition_polynomial_at_point(oods_point, sampled_oods_values, random_coeff) + { + return Err(ProvingError::ConstraintsNotSatisfied); + } + + Ok(StarkProof { + commitments: commitment_scheme.roots(), + commitment_scheme_proof, + }) +} + +pub fn verify( + components: &[&dyn Component], + channel: &mut MC::C, + commitment_scheme: &mut CommitmentSchemeVerifier, + proof: StarkProof, +) -> Result<(), VerificationError> { + let components = Components(components.to_vec()); + let random_coeff = channel.draw_felt(); + + // Read composition polynomial commitment. + commitment_scheme.commit( + *proof.commitments.last().unwrap(), + &[components.composition_log_degree_bound(); SECURE_EXTENSION_DEGREE], + channel, + ); + + // Draw OODS point z. + let oods_point = CirclePoint::::get_random_point(channel); + + // Get mask sample points relative to oods point. + let mut sample_points = components.mask_points(oods_point); + // Add the composition polynomial mask points. + sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]); + + let sampled_oods_values = &proof.commitment_scheme_proof.sampled_values; + // Compute h0(z) + i * h1(z) + u * h2(z) + u*i* h3(z) + let composition_oods_eval = extract_composition_eval(sampled_oods_values).map_err(|_| { + VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string()) + })?; + + if composition_oods_eval + // Compute + != components.eval_composition_polynomial_at_point( + oods_point, + sampled_oods_values, + random_coeff, + ) + { + return Err(VerificationError::OodsNotMatching); + } + + commitment_scheme.verify_values(sample_points, proof.commitment_scheme_proof, channel) +} + +/// Extracts the composition trace evaluation from the mask. 
+fn extract_composition_eval( + mask: &TreeVec>>, +) -> Result { + // Last part of the proof + let mut composition_cols = mask.last().into_iter().flatten(); + + let coordinate_evals = array::try_from_fn(|_| { + //Each Secure element of the sampled values + let col = &**composition_cols.next().ok_or(InvalidOodsSampleStructure)?; + let [eval] = col.try_into().map_err(|_| InvalidOodsSampleStructure)?; + Ok(eval) + })?; + + // Too many columns. + if composition_cols.next().is_some() { + return Err(InvalidOodsSampleStructure); + } + // Computes [0] + [1] * i + [2] * u + [3] * i * u + + Ok(SecureField::from_partial_evals(coordinate_evals)) +} + +/// Error when the sampled values have an invalid structure. +#[derive(Clone, Copy, Debug)] +pub struct InvalidOodsSampleStructure; + +#[derive(Clone, Copy, Debug, Error)] +pub enum ProvingError { + #[error("Constraints not satisfied.")] + ConstraintsNotSatisfied, +} + +#[derive(Clone, Debug, Error)] +pub enum VerificationError { + #[error("Proof has invalid structure: {0}.")] + InvalidStructure(String), + #[error("{0} lookup values do not match.")] + InvalidLookup(String), + #[error(transparent)] + Merkle(#[from] MerkleVerificationError), + #[error( + "The composition polynomial OODS value does not match the trace OODS values + (DEEP-ALI failure)." 
+ )] + OodsNotMatching, + #[error(transparent)] + Fri(#[from] FriVerificationError), + #[error("Proof of work verification failed.")] + ProofOfWork, +} diff --git a/Stwo_wrapper/crates/prover/src/core/queries.rs b/Stwo_wrapper/crates/prover/src/core/queries.rs new file mode 100644 index 0000000..934edfd --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/queries.rs @@ -0,0 +1,237 @@ +use std::collections::BTreeSet; +use std::ops::Deref; + +use itertools::Itertools; + +use super::channel::Channel; +use super::circle::Coset; +use super::poly::circle::CircleDomain; +use super::utils::bit_reverse_index; + +pub const UPPER_BOUND_QUERY_BYTES: usize = 4; + +/// An ordered set of query indices over a bit reversed [CircleDomain]. +#[derive(Debug, Clone)] +pub struct Queries { + pub positions: Vec, + pub log_domain_size: u32, +} + +impl Queries { + /// Randomizes a set of query indices uniformly over the range [0, 2^`log_query_size`). + pub fn generate(channel: &mut impl Channel, log_domain_size: u32, n_queries: usize) -> Self { + let mut queries = BTreeSet::new(); + let mut query_cnt = 0; + let max_query = (1 << log_domain_size) - 1; + loop { + let random_bytes = channel.draw_random_bytes(); + for chunk in random_bytes.chunks_exact(UPPER_BOUND_QUERY_BYTES) { + let query_bits = u32::from_le_bytes(chunk.try_into().unwrap()); + let quotient_query = query_bits & max_query; + queries.insert(quotient_query as usize); + query_cnt += 1; + if query_cnt == n_queries { + return Self { + positions: queries.into_iter().collect(), + log_domain_size, + }; + } + } + } + } + + // TODO docs + #[allow(clippy::missing_safety_doc)] + pub fn from_positions(positions: Vec, log_domain_size: u32) -> Self { + assert!(positions.is_sorted()); + assert!(positions.iter().all(|p| *p < (1 << log_domain_size))); + Self { + positions, + log_domain_size, + } + } + + /// Calculates the matching query indices in a folded domain (i.e each domain point is doubled) + /// given `self` (the queries of the 
original domain) and the number of folds between domains. + pub fn fold(&self, n_folds: u32) -> Self { + assert!(n_folds <= self.log_domain_size); + Self { + positions: self.iter().map(|q| q >> n_folds).dedup().collect(), + log_domain_size: self.log_domain_size - n_folds, + } + } + + pub fn opening_positions(&self, fri_step_size: u32) -> SparseSubCircleDomain { + assert!(fri_step_size > 0); + SparseSubCircleDomain { + domains: self + .iter() + .map(|q| SubCircleDomain { + coset_index: q >> fri_step_size, + log_size: fri_step_size, + }) + .dedup() + .collect(), + large_domain_log_size: self.log_domain_size, + } + } +} + +impl Deref for Queries { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.positions + } +} + +#[derive(Debug, Eq, PartialEq)] +pub struct SparseSubCircleDomain { + pub domains: Vec, + pub large_domain_log_size: u32, +} + +impl SparseSubCircleDomain { + pub fn flatten(&self) -> Vec { + self.iter() + .flat_map(|sub_circle_domain| sub_circle_domain.to_decommitment_positions()) + .collect() + } +} + +impl Deref for SparseSubCircleDomain { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.domains + } +} + +/// Represents a circle domain relative to a larger circle domain. The `initial_index` is the bit +/// reversed query index in the larger domain. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct SubCircleDomain { + pub coset_index: usize, + pub log_size: u32, +} + +impl SubCircleDomain { + /// Calculates the decommitment positions needed for each query given the fri step size. + pub fn to_decommitment_positions(&self) -> Vec { + (self.coset_index << self.log_size..(self.coset_index + 1) << self.log_size).collect() + } + + /// Returns the represented [CircleDomain]. 
+ pub fn to_circle_domain(&self, query_domain: &CircleDomain) -> CircleDomain { + let query = bit_reverse_index(self.coset_index << self.log_size, query_domain.log_size()); + let initial_index = query_domain.index_at(query); + let half_coset = Coset::new(initial_index, self.log_size - 1); + CircleDomain::new(half_coset) + } +} + +#[cfg(test)] +mod tests { + use crate::core::channel::Blake2sChannel; + use crate::core::poly::circle::CanonicCoset; + use crate::core::queries::Queries; + use crate::core::utils::bit_reverse; + + #[test] + fn test_generate_queries() { + let channel = &mut Blake2sChannel::default(); + let log_query_size = 31; + let n_queries = 100; + + let queries = Queries::generate(channel, log_query_size, n_queries); + + assert!(queries.len() == n_queries); + for query in queries.iter() { + assert!(*query < 1 << log_query_size); + } + } + + #[test] + pub fn test_folded_queries() { + let log_domain_size = 7; + let domain = CanonicCoset::new(log_domain_size).circle_domain(); + let mut values = domain.iter().collect::>(); + bit_reverse(&mut values); + + let log_folded_domain_size = 5; + let folded_domain = CanonicCoset::new(log_folded_domain_size).circle_domain(); + let mut folded_values = folded_domain.iter().collect::>(); + bit_reverse(&mut folded_values); + + // Generate all possible queries. + let queries = Queries { + positions: (0..1 << log_domain_size).collect(), + log_domain_size, + }; + let n_folds = log_domain_size - log_folded_domain_size; + let ratio = 1 << n_folds; + + let folded_queries = queries.fold(n_folds); + let repeated_folded_queries = folded_queries + .iter() + .flat_map(|q| std::iter::repeat(q).take(ratio)); + for (query, folded_query) in queries.iter().zip(repeated_folded_queries) { + // Check only the x coordinate since folding might give you the conjugate point. 
+ assert_eq!( + values[*query].repeated_double(n_folds).x, + folded_values[*folded_query].x + ); + } + } + + #[test] + pub fn test_conjugate_queries() { + let channel = &mut Blake2sChannel::default(); + let log_domain_size = 7; + let domain = CanonicCoset::new(log_domain_size).circle_domain(); + let mut values = domain.iter().collect::>(); + bit_reverse(&mut values); + + // Test random queries one by one because the conjugate queries are sorted. + for _ in 0..100 { + let query = Queries::generate(channel, log_domain_size, 1); + let conjugate_query = query[0] ^ 1; + let query_and_conjugate = query.opening_positions(1).flatten(); + let mut expected_query_and_conjugate = vec![query[0], conjugate_query]; + expected_query_and_conjugate.sort(); + assert_eq!(query_and_conjugate, expected_query_and_conjugate); + assert_eq!(values[query[0]], values[conjugate_query].conjugate()); + } + } + + #[test] + pub fn test_decommitment_positions() { + let channel = &mut Blake2sChannel::default(); + let log_domain_size = 31; + let n_queries = 100; + let fri_step_size = 3; + + let queries = Queries::generate(channel, log_domain_size, n_queries); + let queries_with_added_positions = queries.opening_positions(fri_step_size).flatten(); + + assert!(queries_with_added_positions.is_sorted()); + assert_eq!( + queries_with_added_positions.len(), + n_queries * (1 << fri_step_size) + ); + } + + #[test] + pub fn test_dedup_decommitment_positions() { + let log_domain_size = 7; + + // Generate all possible queries. 
+ let queries = Queries { + positions: (0..1 << log_domain_size).collect(), + log_domain_size, + }; + let queries_with_conjugates = queries.opening_positions(log_domain_size - 2).flatten(); + + assert_eq!(*queries, *queries_with_conjugates); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/test_utils.rs b/Stwo_wrapper/crates/prover/src/core/test_utils.rs new file mode 100644 index 0000000..5ebaeaf --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/test_utils.rs @@ -0,0 +1,17 @@ +use super::backend::cpu::CpuCircleEvaluation; +use super::channel::Blake2sChannel; +use super::fields::m31::BaseField; +use super::fields::qm31::SecureField; + +pub fn secure_eval_to_base_eval( + eval: &CpuCircleEvaluation, +) -> CpuCircleEvaluation { + CpuCircleEvaluation::new( + eval.domain, + eval.values.iter().map(|x| x.to_m31_array()[0]).collect(), + ) +} + +pub fn test_channel() -> Blake2sChannel { + Blake2sChannel::default() +} diff --git a/Stwo_wrapper/crates/prover/src/core/utils.rs b/Stwo_wrapper/crates/prover/src/core/utils.rs new file mode 100644 index 0000000..334edb7 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/utils.rs @@ -0,0 +1,327 @@ +use std::iter::Peekable; +use std::ops::{Add, Mul, Sub}; + +use num_traits::{One, Zero}; + +use super::circle::CirclePoint; +use super::constraints::point_vanishing; +use super::fields::m31::BaseField; +use super::fields::qm31::SecureField; +use super::fields::{Field, FieldExpOps}; +use super::poly::circle::CircleDomain; + +pub trait IteratorMutExt<'a, T: 'a>: Iterator { + fn assign(self, other: impl IntoIterator) + where + Self: Sized, + { + self.zip(other).for_each(|(a, b)| *a = b); + } +} + +impl<'a, T: 'a, I: Iterator> IteratorMutExt<'a, T> for I {} + +/// An iterator that takes elements from the underlying [Peekable] while the predicate is true. +/// Used to implement [PeekableExt::peek_take_while]. 
+pub struct PeekTakeWhile<'a, I: Iterator, P: FnMut(&I::Item) -> bool> { + iter: &'a mut Peekable, + predicate: P, +} +impl<'a, I: Iterator, P: FnMut(&I::Item) -> bool> Iterator for PeekTakeWhile<'a, I, P> { + type Item = I::Item; + + fn next(&mut self) -> Option { + self.iter.next_if(&mut self.predicate) + } +} +pub trait PeekableExt<'a, I: Iterator> { + /// Returns an iterator that takes elements from the underlying [Peekable] while the predicate + /// is true. + /// Unlike [Iterator::take_while], this iterator does not consume the first element that does + /// not satisfy the predicate. + fn peek_take_while bool>( + &'a mut self, + predicate: P, + ) -> PeekTakeWhile<'a, I, P>; +} +impl<'a, I: Iterator> PeekableExt<'a, I> for Peekable { + fn peek_take_while bool>( + &'a mut self, + predicate: P, + ) -> PeekTakeWhile<'a, I, P> { + PeekTakeWhile { + iter: self, + predicate, + } + } +} + +/// Returns the bit reversed index of `i` which is represented by `log_size` bits. +pub fn bit_reverse_index(i: usize, log_size: u32) -> usize { + if log_size == 0 { + return i; + } + i.reverse_bits() >> (usize::BITS - log_size) +} + +/// Returns the index of the previous element in a bit reversed +/// [super::poly::circle::CircleEvaluation] of log size `eval_log_size` relative to a smaller domain +/// of size `domain_log_size`. +pub fn previous_bit_reversed_circle_domain_index( + i: usize, + domain_log_size: u32, + eval_log_size: u32, +) -> usize { + offset_bit_reversed_circle_domain_index(i, domain_log_size, eval_log_size, -1) +} + +/// Returns the index of the offset element in a bit reversed +/// [super::poly::circle::CircleEvaluation] of log size `eval_log_size` relative to a smaller domain +/// of size `domain_log_size`. 
+pub fn offset_bit_reversed_circle_domain_index( + i: usize, + domain_log_size: u32, + eval_log_size: u32, + offset: isize, +) -> usize { + let mut prev_index = bit_reverse_index(i, eval_log_size); + let half_size = 1 << (eval_log_size - 1); + let step_size = offset * (1 << (eval_log_size - domain_log_size - 1)) as isize; + if prev_index < half_size { + prev_index = (prev_index as isize + step_size).rem_euclid(half_size as isize) as usize; + } else { + prev_index = + ((prev_index as isize - step_size).rem_euclid(half_size as isize) as usize) + half_size; + } + bit_reverse_index(prev_index, eval_log_size) +} + +// TODO(AlonH): Pair both functions below with bit reverse. Consider removing both and calculating +// the indices instead. +pub(crate) fn circle_domain_order_to_coset_order(values: &[BaseField]) -> Vec { + let n = values.len(); + let mut coset_order = vec![]; + for i in 0..(n / 2) { + coset_order.push(values[i]); + coset_order.push(values[n - 1 - i]); + } + coset_order +} + +pub(crate) fn coset_order_to_circle_domain_order(values: &[F]) -> Vec { + let mut circle_domain_order = Vec::with_capacity(values.len()); + let n = values.len(); + let half_len = n / 2; + for i in 0..half_len { + circle_domain_order.push(values[i << 1]); + } + for i in 0..half_len { + circle_domain_order.push(values[n - 1 - (i << 1)]); + } + circle_domain_order +} + +/// Converts an index within a [`Coset`] to the corresponding index in a [`CircleDomain`]. +/// +/// [`CircleDomain`]: crate::core::poly::circle::CircleDomain +/// [`Coset`]: crate::core::circle::Coset +pub fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_size: u32) -> usize { + if coset_index % 2 == 0 { + coset_index / 2 + } else { + ((2 << log_domain_size) - coset_index) / 2 + } +} + +/// Performs a naive bit-reversal permutation inplace. +/// +/// # Panics +/// +/// Panics if the length of the slice is not a power of two. +// TODO: Implement cache friendly implementation. 
+// TODO(spapini): Move this to the cpu backend. +pub fn bit_reverse(v: &mut [T]) { + let n = v.len(); + assert!(n.is_power_of_two()); + let log_n = n.ilog2(); + for i in 0..n { + let j = bit_reverse_index(i, log_n); + if j > i { + v.swap(i, j); + } + } +} + +pub fn generate_secure_powers(felt: SecureField, n_powers: usize) -> Vec { + (0..n_powers) + .scan(SecureField::one(), |acc, _| { + let res = *acc; + *acc *= felt; + Some(res) + }) + .collect() +} + +/// Securely combines the given values using the given random alpha and z. +/// Alpha and z should be secure field elements for soundness. +pub fn shifted_secure_combination(values: &[F], alpha: EF, z: EF) -> EF +where + EF: Copy + Zero + Mul + Add + Sub, +{ + let res = values + .iter() + .fold(EF::zero(), |acc, &value| acc * alpha + value); + res - z +} + +pub fn point_vanish_denominator_inverses( + domain: CircleDomain, + vanish_point: CirclePoint, +) -> Vec { + let mut denoms = vec![]; + for point in domain.iter() { + // TODO(AlonH): Use `point_vanishing_fraction` instead of `point_vanishing` everywhere. 
+ denoms.push(point_vanishing(vanish_point, point)); + } + bit_reverse(&mut denoms); + let mut denom_inverses = vec![BaseField::zero(); 1 << (domain.log_size())]; + BaseField::batch_inverse(&denoms, &mut denom_inverses); + denom_inverses +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use num_traits::One; + + use super::{ + offset_bit_reversed_circle_domain_index, previous_bit_reversed_circle_domain_index, + }; + use crate::core::backend::cpu::CpuCircleEvaluation; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::FieldExpOps; + use crate::core::poly::circle::CanonicCoset; + use crate::core::poly::NaturalOrder; + use crate::core::utils::bit_reverse; + use crate::{m31, qm31}; + + #[test] + fn bit_reverse_works() { + let mut data = [0, 1, 2, 3, 4, 5, 6, 7]; + bit_reverse(&mut data); + assert_eq!(data, [0, 4, 2, 6, 1, 5, 3, 7]); + } + + #[test] + #[should_panic] + fn bit_reverse_non_power_of_two_size_fails() { + let mut data = [0, 1, 2, 3, 4, 5]; + bit_reverse(&mut data); + } + + #[test] + fn generate_secure_powers_works() { + let felt = qm31!(1, 2, 3, 4); + let n_powers = 10; + + let powers = super::generate_secure_powers(felt, n_powers); + + assert_eq!(powers.len(), n_powers); + assert_eq!(powers[0], SecureField::one()); + assert_eq!(powers[1], felt); + assert_eq!(powers[7], felt.pow(7)); + } + + #[test] + fn generate_empty_secure_powers_works() { + let felt = qm31!(1, 2, 3, 4); + let max_log_size = 0; + + let powers = super::generate_secure_powers(felt, max_log_size); + + assert_eq!(powers, vec![]); + } + + #[test] + fn test_offset_bit_reversed_circle_domain_index() { + let domain_log_size = 3; + let eval_log_size = 6; + let initial_index = 5; + + let actual = offset_bit_reversed_circle_domain_index( + initial_index, + domain_log_size, + eval_log_size, + -2, + ); + let expected_prev = previous_bit_reversed_circle_domain_index( + initial_index, + domain_log_size, + eval_log_size, + ); + let expected_prev2 = 
previous_bit_reversed_circle_domain_index( + expected_prev, + domain_log_size, + eval_log_size, + ); + assert_eq!(actual, expected_prev2); + } + + #[test] + fn test_previous_bit_reversed_circle_domain_index() { + let log_size = 4; + let n = 1 << log_size; + let domain = CanonicCoset::new(log_size).circle_domain(); + let values = (0..n).map(|i| m31!(i as u32)).collect_vec(); + let evaluation = CpuCircleEvaluation::<_, NaturalOrder>::new(domain, values.clone()); + let bit_reversed_evaluation = evaluation.clone().bit_reverse(); + + // 2 · 14 + // · | · + // 13 | 1 + // · | · + // 3 | 15 + // · | · + // 12 | 0 + // ·--------------|---------------· + // 4 | 8 + // · | · + // 11 | 7 + // · | · + // 5 | 9 + // · | · + // 10 · 6 + let neighbor_pairs = (0..n) + .map(|index| { + let prev_index = + previous_bit_reversed_circle_domain_index(index, log_size - 3, log_size); + ( + bit_reversed_evaluation[index], + bit_reversed_evaluation[prev_index], + ) + }) + .sorted() + .collect_vec(); + let mut expected_neighbor_pairs = vec![ + (m31!(0), m31!(4)), + (m31!(15), m31!(11)), + (m31!(1), m31!(5)), + (m31!(14), m31!(10)), + (m31!(2), m31!(6)), + (m31!(13), m31!(9)), + (m31!(3), m31!(7)), + (m31!(12), m31!(8)), + (m31!(4), m31!(0)), + (m31!(11), m31!(15)), + (m31!(5), m31!(1)), + (m31!(10), m31!(14)), + (m31!(6), m31!(2)), + (m31!(9), m31!(13)), + (m31!(7), m31!(3)), + (m31!(8), m31!(12)), + ]; + expected_neighbor_pairs.sort(); + + assert_eq!(neighbor_pairs, expected_neighbor_pairs); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/blake2_hash.rs b/Stwo_wrapper/crates/prover/src/core/vcs/blake2_hash.rs new file mode 100644 index 0000000..b702fcd --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/blake2_hash.rs @@ -0,0 +1,139 @@ +use std::fmt; + +use blake2::{Blake2s256, Digest}; +use bytemuck::{Pod, Zeroable}; +use serde::{Deserialize, Serialize}; + +// Wrapper for the blake2s hash type. 
+#[repr(C, align(32))] +#[derive(Clone, Copy, PartialEq, Default, Eq, Pod, Zeroable, Deserialize, Serialize)] +pub struct Blake2sHash(pub [u8; 32]); + +impl From for Vec { + fn from(value: Blake2sHash) -> Self { + Vec::from(value.0) + } +} + +impl From> for Blake2sHash { + fn from(value: Vec) -> Self { + Self( + value + .try_into() + .expect("Failed converting Vec to Blake2Hash type"), + ) + } +} + +impl From<&[u8]> for Blake2sHash { + fn from(value: &[u8]) -> Self { + Self( + value + .try_into() + .expect("Failed converting &[u8] to Blake2sHash Type!"), + ) + } +} + +impl AsRef<[u8]> for Blake2sHash { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl From for [u8; 32] { + fn from(val: Blake2sHash) -> Self { + val.0 + } +} + +impl fmt::Display for Blake2sHash { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&hex::encode(self.0)) + } +} + +impl fmt::Debug for Blake2sHash { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} + +impl super::hash::Hash for Blake2sHash { + fn to_bytes(&self) -> Vec { + self.0.to_vec() + } + + fn from_bytes(bytes: &[u8]) -> Self { + bytes.into() + } +} + +// Wrapper for the blake2s Hashing functionalities. 
+#[derive(Clone, Debug, Default)] +pub struct Blake2sHasher { + state: Blake2s256, +} + +impl Blake2sHasher { + pub fn new() -> Self { + Self { + state: Blake2s256::new(), + } + } + + pub fn update(&mut self, data: &[u8]) { + blake2::Digest::update(&mut self.state, data); + } + + pub fn finalize(self) -> Blake2sHash { + Blake2sHash(self.state.finalize().into()) + } + + pub fn concat_and_hash(v1: &Blake2sHash, v2: &Blake2sHash) -> Blake2sHash { + let mut hasher = Self::new(); + hasher.update(v1.as_ref()); + hasher.update(v2.as_ref()); + hasher.finalize() + } + + pub fn hash(data: &[u8]) -> Blake2sHash { + let mut hasher = Self::new(); + hasher.update(data); + hasher.finalize() + } +} + +#[cfg(test)] +mod tests { + use blake2::Digest; + + use super::{Blake2sHash, Blake2sHasher}; + + impl Blake2sHasher { + fn finalize_reset(&mut self) -> Blake2sHash { + Blake2sHash(self.state.finalize_reset().into()) + } + } + + #[test] + fn single_hash_test() { + let hash_a = Blake2sHasher::hash(b"a"); + assert_eq!( + hash_a.to_string(), + "4a0d129873403037c2cd9b9048203687f6233fb6738956e0349bd4320fec3e90" + ); + } + + #[test] + fn hash_state_test() { + let mut state = Blake2sHasher::new(); + state.update(b"a"); + state.update(b"b"); + let hash = state.finalize_reset(); + let hash_empty = state.finalize(); + + assert_eq!(hash.to_string(), Blake2sHasher::hash(b"ab").to_string()); + assert_eq!(hash_empty.to_string(), Blake2sHasher::hash(b"").to_string()); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/blake2_merkle.rs b/Stwo_wrapper/crates/prover/src/core/vcs/blake2_merkle.rs new file mode 100644 index 0000000..293ed4a --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/blake2_merkle.rs @@ -0,0 +1,148 @@ +use num_traits::Zero; +use serde::{Deserialize, Serialize}; + +use super::blake2_hash::Blake2sHash; +use super::blake2s_ref::compress; +use super::ops::MerkleHasher; +use crate::core::channel::{Blake2sChannel, MerkleChannel}; +use 
crate::core::fields::m31::BaseField; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Deserialize, Serialize)] +pub struct Blake2sMerkleHasher; +impl MerkleHasher for Blake2sMerkleHasher { + type Hash = Blake2sHash; + + fn hash_node( + children_hashes: Option<(Self::Hash, Self::Hash)>, + column_values: &[BaseField], + ) -> Self::Hash { + let mut state = [0; 8]; + if let Some((left, right)) = children_hashes { + state = compress( + state, + unsafe { std::mem::transmute([left, right]) }, + 0, + 0, + 0, + 0, + ); + } + let rem = 15 - ((column_values.len() + 15) % 16); + let padded_values = column_values + .iter() + .copied() + .chain(std::iter::repeat(BaseField::zero()).take(rem)); + for chunk in padded_values.array_chunks::<16>() { + state = compress(state, unsafe { std::mem::transmute(chunk) }, 0, 0, 0, 0); + } + state.map(|x| x.to_le_bytes()).flatten().into() + } +} + +#[derive(Default)] +pub struct Blake2sMerkleChannel; + +impl MerkleChannel for Blake2sMerkleChannel { + type C = Blake2sChannel; + type H = Blake2sMerkleHasher; + + fn mix_root(channel: &mut Self::C, root: ::Hash) { + channel.update_digest(super::blake2_hash::Blake2sHasher::concat_and_hash( + &channel.digest(), + &root, + )); + } +} + +#[cfg(test)] +mod tests { + use num_traits::Zero; + + use super::Blake2sMerkleChannel; + use crate::core::channel::{Blake2sChannel, MerkleChannel}; + use crate::core::fields::m31::BaseField; + use crate::core::vcs::blake2_merkle::{Blake2sHash, Blake2sMerkleHasher}; + use crate::core::vcs::test_utils::prepare_merkle; + use crate::core::vcs::verifier::MerkleVerificationError; + + #[test] + fn test_merkle_success() { + let (queries, decommitment, values, verifier) = prepare_merkle::(); + + verifier.verify(queries, values, decommitment).unwrap(); + } + + #[test] + fn test_merkle_invalid_witness() { + let (queries, mut decommitment, values, verifier) = prepare_merkle::(); + decommitment.hash_witness[4] = Blake2sHash::default(); + + assert_eq!( + 
verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::RootMismatch + ); + } + + #[test] + fn test_merkle_invalid_value() { + let (queries, decommitment, mut values, verifier) = prepare_merkle::(); + values[3][2] = BaseField::zero(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::RootMismatch + ); + } + + #[test] + fn test_merkle_witness_too_short() { + let (queries, mut decommitment, values, verifier) = prepare_merkle::(); + decommitment.hash_witness.pop(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::WitnessTooShort + ); + } + + #[test] + fn test_merkle_witness_too_long() { + let (queries, mut decommitment, values, verifier) = prepare_merkle::(); + decommitment.hash_witness.push(Blake2sHash::default()); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::WitnessTooLong + ); + } + + #[test] + fn test_merkle_column_values_too_long() { + let (queries, decommitment, mut values, verifier) = prepare_merkle::(); + values[3].push(BaseField::zero()); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::ColumnValuesTooLong + ); + } + + #[test] + fn test_merkle_column_values_too_short() { + let (queries, decommitment, mut values, verifier) = prepare_merkle::(); + values[3].pop(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::ColumnValuesTooShort + ); + } + + #[test] + fn test_merkle_channel() { + let mut channel = Blake2sChannel::default(); + let (_queries, _decommitment, _values, verifier) = prepare_merkle::(); + Blake2sMerkleChannel::mix_root(&mut channel, verifier.root); + assert_eq!(channel.channel_time.n_challenges, 1); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/blake2s_ref.rs b/Stwo_wrapper/crates/prover/src/core/vcs/blake2s_ref.rs new 
file mode 100644 index 0000000..ab32ea6 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/blake2s_ref.rs @@ -0,0 +1,217 @@ +//! An AVX512 implementation of the BLAKE2s compression function. +//! Based on . + +pub const IV: [u32; 8] = [ + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19, +]; + +pub const SIGMA: [[u8; 16]; 10] = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3], + [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4], + [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8], + [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13], + [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9], + [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11], + [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10], + [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5], + [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0], +]; + +#[inline(always)] +fn add(a: u32, b: u32) -> u32 { + a.wrapping_add(b) +} + +#[inline(always)] +fn xor(a: u32, b: u32) -> u32 { + a ^ b +} + +#[inline(always)] +fn rot16(x: u32) -> u32 { + (x >> 16) | (x << (32 - 16)) +} + +#[inline(always)] +fn rot12(x: u32) -> u32 { + (x >> 12) | (x << (32 - 12)) +} + +#[inline(always)] +fn rot8(x: u32) -> u32 { + (x >> 8) | (x << (32 - 8)) +} + +#[inline(always)] +fn rot7(x: u32) -> u32 { + (x >> 7) | (x << (32 - 7)) +} + +#[inline(always)] +fn round(v: &mut [u32; 16], m: [u32; 16], r: usize) { + v[0] = add(v[0], m[SIGMA[r][0] as usize]); + v[1] = add(v[1], m[SIGMA[r][2] as usize]); + v[2] = add(v[2], m[SIGMA[r][4] as usize]); + v[3] = add(v[3], m[SIGMA[r][6] as usize]); + v[0] = add(v[0], v[4]); + v[1] = add(v[1], v[5]); + v[2] = add(v[2], v[6]); + v[3] = add(v[3], v[7]); + v[12] = xor(v[12], v[0]); + v[13] = xor(v[13], v[1]); + v[14] = xor(v[14], v[2]); + v[15] = xor(v[15], v[3]); + v[12] = rot16(v[12]); + v[13] = rot16(v[13]); + v[14] = rot16(v[14]); + v[15] = 
rot16(v[15]); + v[8] = add(v[8], v[12]); + v[9] = add(v[9], v[13]); + v[10] = add(v[10], v[14]); + v[11] = add(v[11], v[15]); + v[4] = xor(v[4], v[8]); + v[5] = xor(v[5], v[9]); + v[6] = xor(v[6], v[10]); + v[7] = xor(v[7], v[11]); + v[4] = rot12(v[4]); + v[5] = rot12(v[5]); + v[6] = rot12(v[6]); + v[7] = rot12(v[7]); + v[0] = add(v[0], m[SIGMA[r][1] as usize]); + v[1] = add(v[1], m[SIGMA[r][3] as usize]); + v[2] = add(v[2], m[SIGMA[r][5] as usize]); + v[3] = add(v[3], m[SIGMA[r][7] as usize]); + v[0] = add(v[0], v[4]); + v[1] = add(v[1], v[5]); + v[2] = add(v[2], v[6]); + v[3] = add(v[3], v[7]); + v[12] = xor(v[12], v[0]); + v[13] = xor(v[13], v[1]); + v[14] = xor(v[14], v[2]); + v[15] = xor(v[15], v[3]); + v[12] = rot8(v[12]); + v[13] = rot8(v[13]); + v[14] = rot8(v[14]); + v[15] = rot8(v[15]); + v[8] = add(v[8], v[12]); + v[9] = add(v[9], v[13]); + v[10] = add(v[10], v[14]); + v[11] = add(v[11], v[15]); + v[4] = xor(v[4], v[8]); + v[5] = xor(v[5], v[9]); + v[6] = xor(v[6], v[10]); + v[7] = xor(v[7], v[11]); + v[4] = rot7(v[4]); + v[5] = rot7(v[5]); + v[6] = rot7(v[6]); + v[7] = rot7(v[7]); + + v[0] = add(v[0], m[SIGMA[r][8] as usize]); + v[1] = add(v[1], m[SIGMA[r][10] as usize]); + v[2] = add(v[2], m[SIGMA[r][12] as usize]); + v[3] = add(v[3], m[SIGMA[r][14] as usize]); + v[0] = add(v[0], v[5]); + v[1] = add(v[1], v[6]); + v[2] = add(v[2], v[7]); + v[3] = add(v[3], v[4]); + v[15] = xor(v[15], v[0]); + v[12] = xor(v[12], v[1]); + v[13] = xor(v[13], v[2]); + v[14] = xor(v[14], v[3]); + v[15] = rot16(v[15]); + v[12] = rot16(v[12]); + v[13] = rot16(v[13]); + v[14] = rot16(v[14]); + v[10] = add(v[10], v[15]); + v[11] = add(v[11], v[12]); + v[8] = add(v[8], v[13]); + v[9] = add(v[9], v[14]); + v[5] = xor(v[5], v[10]); + v[6] = xor(v[6], v[11]); + v[7] = xor(v[7], v[8]); + v[4] = xor(v[4], v[9]); + v[5] = rot12(v[5]); + v[6] = rot12(v[6]); + v[7] = rot12(v[7]); + v[4] = rot12(v[4]); + v[0] = add(v[0], m[SIGMA[r][9] as usize]); + v[1] = add(v[1], m[SIGMA[r][11] as 
usize]); + v[2] = add(v[2], m[SIGMA[r][13] as usize]); + v[3] = add(v[3], m[SIGMA[r][15] as usize]); + v[0] = add(v[0], v[5]); + v[1] = add(v[1], v[6]); + v[2] = add(v[2], v[7]); + v[3] = add(v[3], v[4]); + v[15] = xor(v[15], v[0]); + v[12] = xor(v[12], v[1]); + v[13] = xor(v[13], v[2]); + v[14] = xor(v[14], v[3]); + v[15] = rot8(v[15]); + v[12] = rot8(v[12]); + v[13] = rot8(v[13]); + v[14] = rot8(v[14]); + v[10] = add(v[10], v[15]); + v[11] = add(v[11], v[12]); + v[8] = add(v[8], v[13]); + v[9] = add(v[9], v[14]); + v[5] = xor(v[5], v[10]); + v[6] = xor(v[6], v[11]); + v[7] = xor(v[7], v[8]); + v[4] = xor(v[4], v[9]); + v[5] = rot7(v[5]); + v[6] = rot7(v[6]); + v[7] = rot7(v[7]); + v[4] = rot7(v[4]); +} + +/// Performs a Blake2s compression. +pub fn compress( + h_vecs: [u32; 8], + msg_vecs: [u32; 16], + count_low: u32, + count_high: u32, + lastblock: u32, + lastnode: u32, +) -> [u32; 8] { + let mut v = [ + h_vecs[0], + h_vecs[1], + h_vecs[2], + h_vecs[3], + h_vecs[4], + h_vecs[5], + h_vecs[6], + h_vecs[7], + IV[0], + IV[1], + IV[2], + IV[3], + xor(IV[4], count_low), + xor(IV[5], count_high), + xor(IV[6], lastblock), + xor(IV[7], lastnode), + ]; + + round(&mut v, msg_vecs, 0); + round(&mut v, msg_vecs, 1); + round(&mut v, msg_vecs, 2); + round(&mut v, msg_vecs, 3); + round(&mut v, msg_vecs, 4); + round(&mut v, msg_vecs, 5); + round(&mut v, msg_vecs, 6); + round(&mut v, msg_vecs, 7); + round(&mut v, msg_vecs, 8); + round(&mut v, msg_vecs, 9); + + [ + xor(xor(h_vecs[0], v[0]), v[8]), + xor(xor(h_vecs[1], v[1]), v[9]), + xor(xor(h_vecs[2], v[2]), v[10]), + xor(xor(h_vecs[3], v[3]), v[11]), + xor(xor(h_vecs[4], v[4]), v[12]), + xor(xor(h_vecs[5], v[5]), v[13]), + xor(xor(h_vecs[6], v[6]), v[14]), + xor(xor(h_vecs[7], v[7]), v[15]), + ] +} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/blake3_hash.rs b/Stwo_wrapper/crates/prover/src/core/vcs/blake3_hash.rs new file mode 100644 index 0000000..e9b9d0b --- /dev/null +++ 
b/Stwo_wrapper/crates/prover/src/core/vcs/blake3_hash.rs @@ -0,0 +1,132 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +use crate::core::vcs::hash::Hash; + +// Wrapper for the blake3 hash type. +#[derive(Clone, Copy, PartialEq, Default, Eq, Serialize, Deserialize)] +pub struct Blake3Hash([u8; 32]); + +impl From for Vec { + fn from(value: Blake3Hash) -> Self { + Vec::from(value.0) + } +} + +impl From> for Blake3Hash { + fn from(value: Vec) -> Self { + Self( + value + .try_into() + .expect("Failed converting Vec to Blake3Hash Type!"), + ) + } +} + +impl From<&[u8]> for Blake3Hash { + fn from(value: &[u8]) -> Self { + Self( + value + .try_into() + .expect("Failed converting &[u8] to Blake3Hash Type!"), + ) + } +} + +impl AsRef<[u8]> for Blake3Hash { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +impl fmt::Display for Blake3Hash { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&hex::encode(self.0)) + } +} + +impl fmt::Debug for Blake3Hash { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} + +impl Hash for Blake3Hash { + fn to_bytes(&self) -> Vec { + self.0.to_vec() + } + + fn from_bytes(bytes: &[u8]) -> Self { + bytes.into() + } +} + +// Wrapper for the blake3 Hashing functionalities. 
+#[derive(Clone, Default)] +pub struct Blake3Hasher { + state: blake3::Hasher, +} + +impl Blake3Hasher { + pub fn new() -> Self { + Self { + state: blake3::Hasher::new(), + } + } + pub fn update(&mut self, data: &[u8]) { + self.state.update(data); + } + + pub fn finalize(self) -> Blake3Hash { + Blake3Hash(self.state.finalize().into()) + } + + pub fn concat_and_hash(v1: &Blake3Hash, v2: &Blake3Hash) -> Blake3Hash { + let mut hasher = Self::new(); + hasher.update(v1.as_ref()); + hasher.update(v2.as_ref()); + hasher.finalize() + } + + pub fn hash(data: &[u8]) -> Blake3Hash { + let mut hasher = Self::new(); + hasher.update(data); + hasher.finalize() + } +} + +#[cfg(test)] +impl Blake3Hasher { + fn finalize_reset(&mut self) -> Blake3Hash { + let res = Blake3Hash(self.state.finalize().into()); + self.state.reset(); + res + } +} + +#[cfg(test)] +mod tests { + use crate::core::vcs::blake3_hash::Blake3Hasher; + + #[test] + fn single_hash_test() { + let hash_a = Blake3Hasher::hash(b"a"); + assert_eq!( + hash_a.to_string(), + "17762fddd969a453925d65717ac3eea21320b66b54342fde15128d6caf21215f" + ); + } + + #[test] + fn hash_state_test() { + let mut state = Blake3Hasher::new(); + state.update(b"a"); + state.update(b"b"); + let hash = state.finalize_reset(); + let hash_empty = state.finalize(); + + assert_eq!(hash.to_string(), Blake3Hasher::hash(b"ab").to_string()); + assert_eq!(hash_empty.to_string(), Blake3Hasher::hash(b"").to_string()) + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/hash.rs b/Stwo_wrapper/crates/prover/src/core/vcs/hash.rs new file mode 100644 index 0000000..066a5d1 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/hash.rs @@ -0,0 +1,15 @@ +use std::fmt::{Debug, Display}; + +pub trait Hash: + Copy + + Default + + Display + + Debug + + Eq + + Send + + Sync + + 'static +{ + fn to_bytes(&self) -> Vec; + fn from_bytes(bytes: &[u8]) -> Self; +} \ No newline at end of file diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/mod.rs 
b/Stwo_wrapper/crates/prover/src/core/vcs/mod.rs new file mode 100644 index 0000000..7e19129 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/mod.rs @@ -0,0 +1,20 @@ +//! Vector commitment scheme (VCS) module. + +pub mod blake2_hash; +pub mod blake2_merkle; +pub mod blake2s_ref; +pub mod blake3_hash; +pub mod hash; +pub mod ops; +#[cfg(not(target_arch = "wasm32"))] +pub mod poseidon252_merkle; + +#[cfg(not(target_arch = "wasm32"))] +pub mod poseidon_bls_merkle; +pub mod prover; +mod utils; +pub mod verifier; + +#[cfg(test)] +mod test_utils; + diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/ops.rs b/Stwo_wrapper/crates/prover/src/core/vcs/ops.rs new file mode 100644 index 0000000..14093e5 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/ops.rs @@ -0,0 +1,47 @@ +use std::fmt::Debug; + +use serde::{Deserialize, Serialize}; + +use crate::core::backend::{Col, ColumnOps}; +use crate::core::fields::m31::BaseField; +use crate::core::vcs::hash::Hash; + +/// A Merkle node hash is a hash of: +/// [left_child_hash, right_child_hash], column0_value, column1_value, ... +/// "[]" denotes optional values. +/// The largest Merkle layer has no left and right child hashes. The rest of the layers have +/// children hashes. +/// At each layer, the tree may have multiple columns of the same length as the layer. +/// Each node in that layer contains one value from each column. +pub trait MerkleHasher: Debug + Default + Clone { + type Hash: Hash; + /// Hashes a single Merkle node. See [MerkleHasher] for more details. + fn hash_node( + children_hashes: Option<(Self::Hash, Self::Hash)>, + column_values: &[BaseField], + ) -> Self::Hash; +} + +/// Trait for performing Merkle operations on a commitment scheme. +pub trait MerkleOps: + ColumnOps + ColumnOps + for<'de> Deserialize<'de> + Serialize +{ + /// Commits on an entire layer of the Merkle tree. + /// See [MerkleHasher] for more details. + /// + /// The layer has 2^`log_size` nodes that need to be hashed. 
The topmost layer has 1 node, + /// which is a hash of 2 children and some columns. + /// + /// `prev_layer` is the previous layer of the Merkle tree, if this is not the leaf layer. + /// That layer is assumed to have 2^(`log_size`+1) nodes. + /// + /// `columns` are the extra columns that need to be hashed in each node. + /// They are assumed to be of size 2^`log_size`. + /// + /// Returns the next Merkle layer hashes. + fn commit_on_layer( + log_size: u32, + prev_layer: Option<&Col>, + columns: &[&Col], + ) -> Col; +} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/poseidon252_merkle.rs b/Stwo_wrapper/crates/prover/src/core/vcs/poseidon252_merkle.rs new file mode 100644 index 0000000..6441b71 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/poseidon252_merkle.rs @@ -0,0 +1,182 @@ +use num_traits::Zero; +use serde::{Deserialize, Serialize}; +use starknet_crypto::{poseidon_hash, poseidon_hash_many}; +use starknet_ff::FieldElement as FieldElement252; + +use super::ops::MerkleHasher; +use crate::core::channel::{MerkleChannel, Poseidon252Channel}; +use crate::core::fields::m31::BaseField; +use crate::core::vcs::hash::Hash; + +const ELEMENTS_IN_BLOCK: usize = 8; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Deserialize, Serialize)] +pub struct Poseidon252MerkleHasher; +impl MerkleHasher for Poseidon252MerkleHasher { + type Hash = FieldElement252; + + fn hash_node( + children_hashes: Option<(Self::Hash, Self::Hash)>, + column_values: &[BaseField], + ) -> Self::Hash { + let n_column_blocks = column_values.len().div_ceil(ELEMENTS_IN_BLOCK); + let values_len = 2 + n_column_blocks; + let mut values = Vec::with_capacity(values_len); + + if let Some((left, right)) = children_hashes { + values.push(left); + values.push(right); + } + + let padding_length = ELEMENTS_IN_BLOCK * n_column_blocks - column_values.len(); + let padded_values = column_values + .iter() + .copied() + .chain(std::iter::repeat(BaseField::zero()).take(padding_length)); + for chunk 
in padded_values.array_chunks::() { + let mut word = FieldElement252::default(); + for x in chunk { + word = word * FieldElement252::from(2u64.pow(31)) + FieldElement252::from(x.0); + } + values.push(word); + } + poseidon_hash_many(&values) + } +} + +impl Hash for FieldElement252 { + fn to_bytes(&self) -> Vec { + self.to_bytes_be().to_vec() + } + + fn from_bytes(bytes: &[u8]) -> Self { + let mut bytes_array = [0u8; 32]; + bytes_array.copy_from_slice(bytes); + FieldElement252::from_bytes_be(&bytes_array).unwrap() + } +} + +#[derive(Default)] +pub struct Poseidon252MerkleChannel; + +impl MerkleChannel for Poseidon252MerkleChannel { + type C = Poseidon252Channel; + type H = Poseidon252MerkleHasher; + + fn mix_root(channel: &mut Self::C, root: ::Hash) { + channel.update_digest(poseidon_hash(channel.digest(), root)); + } +} + +#[cfg(test)] +mod tests { + use num_traits::Zero; + use starknet_ff::FieldElement as FieldElement252; + + use crate::core::fields::m31::BaseField; + use crate::core::vcs::ops::MerkleHasher; + use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleHasher; + use crate::core::vcs::test_utils::prepare_merkle; + use crate::core::vcs::verifier::MerkleVerificationError; + use crate::m31; + + #[test] + fn test_vector() { + assert_eq!( + Poseidon252MerkleHasher::hash_node(None, &[m31!(0), m31!(1)]), + FieldElement252::from_dec_str( + "2552053700073128806553921687214114320458351061521275103654266875084493044716" + ) + .unwrap() + ); + + assert_eq!( + Poseidon252MerkleHasher::hash_node( + Some((FieldElement252::from(1u32), FieldElement252::from(2u32))), + &[m31!(3)] + ), + FieldElement252::from_dec_str( + "159358216886023795422515519110998391754567506678525778721401012606792642769" + ) + .unwrap() + ); + } + + #[test] + fn test_merkle_success() { + let (queries, decommitment, values, verifier) = prepare_merkle::(); + verifier.verify(queries, values, decommitment).unwrap(); + } + + #[test] + fn test_merkle_invalid_witness() { + let (queries, mut 
decommitment, values, verifier) = + prepare_merkle::(); + decommitment.hash_witness[4] = FieldElement252::default(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::RootMismatch + ); + } + + #[test] + fn test_merkle_invalid_value() { + let (queries, decommitment, mut values, verifier) = + prepare_merkle::(); + values[3][2] = BaseField::zero(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::RootMismatch + ); + } + + #[test] + fn test_merkle_witness_too_short() { + let (queries, mut decommitment, values, verifier) = + prepare_merkle::(); + decommitment.hash_witness.pop(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::WitnessTooShort + ); + } + + #[test] + fn test_merkle_witness_too_long() { + let (queries, mut decommitment, values, verifier) = + prepare_merkle::(); + decommitment.hash_witness.push(FieldElement252::default()); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::WitnessTooLong + ); + } + + #[test] + fn test_merkle_column_values_too_long() { + let (queries, decommitment, mut values, verifier) = + prepare_merkle::(); + values[3].push(BaseField::zero()); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::ColumnValuesTooLong + ); + } + + #[test] + fn test_merkle_column_values_too_short() { + let (queries, decommitment, mut values, verifier) = + prepare_merkle::(); + values[3].pop(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::ColumnValuesTooShort + ); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/poseidon_bls_merkle.rs b/Stwo_wrapper/crates/prover/src/core/vcs/poseidon_bls_merkle.rs new file mode 100644 index 0000000..196e734 --- /dev/null +++ 
b/Stwo_wrapper/crates/prover/src/core/vcs/poseidon_bls_merkle.rs @@ -0,0 +1,581 @@ +use num_traits::Zero; +use serde::{Deserialize, Serialize}; +use ark_bls12_381::Fr as BlsFr; +use ark_ff::{BigInteger, Field, PrimeField}; + + +use super::ops::MerkleHasher; +use crate::core::channel::{MerkleChannel, PoseidonBLSChannel}; +use crate::core::fields::m31::BaseField; +use crate::core::vcs::hash::Hash; + +const ELEMENTS_IN_BLOCK: usize = 8; + +//Optimize constant to be real constants (no conversion) and merge duplicated code in VCS poseidono +fn poseidon_comp_consts(idx: usize) -> BlsFr { + match idx { + 0 => BlsFr::from_be_bytes_mod_order(&[ + 111, 0, 122, 85, 17, 86, 179, 164, 73, 228, 73, 54, 183, 192, 147, 100, 74, 14, 211, + 63, 51, 234, 204, 198, 40, 233, 66, 232, 54, 193, 168, 117, + ]), + 1 => BlsFr::from_be_bytes_mod_order(&[ + 54, 13, 116, 112, 97, 30, 71, 61, 53, 63, 98, 143, 118, 209, 16, 243, 78, 113, 22, 47, + 49, 0, 59, 112, 87, 83, 140, 37, 150, 66, 99, 3, + ]), + 2 => BlsFr::from_be_bytes_mod_order(&[ + 75, 95, 236, 58, 160, 115, 223, 68, 1, 144, 145, 240, 7, 164, 76, 169, 150, 72, 73, + 101, 247, 3, 109, 206, 62, 157, 9, 119, 237, 205, 192, 246, + ]), + 3 => BlsFr::from_be_bytes_mod_order(&[ + 103, 207, 24, 104, 175, 99, 150, 192, 184, 76, 206, 113, 94, 83, 159, 132, 158, 6, 205, + 28, 56, 58, 197, 176, 97, 0, 199, 107, 204, 151, 58, 17, + ]), + 4 => BlsFr::from_be_bytes_mod_order(&[ + 85, 93, 180, 209, 220, 237, 129, 159, 93, 61, 231, 15, 222, 131, 241, 199, 211, 232, + 201, 137, 104, 229, 22, 162, 58, 119, 26, 92, 156, 130, 87, 170, + ]), + 5 => BlsFr::from_be_bytes_mod_order(&[ + 43, 171, 148, 215, 174, 34, 45, 19, 93, 195, 198, 197, 254, 191, 170, 49, 73, 8, 172, + 47, 18, 235, 224, 111, 189, 183, 66, 19, 191, 99, 24, 139, + ]), + 6 => BlsFr::from_be_bytes_mod_order(&[ + 102, 244, 75, 229, 41, 102, 130, 196, 250, 120, 130, 121, 157, 109, 208, 73, 182, 215, + 210, 201, 80, 204, 249, 140, 242, 229, 13, 109, 30, 187, 119, 194, + ]), + 7 => 
BlsFr::from_be_bytes_mod_order(&[ + 21, 12, 147, 254, 246, 82, 251, 28, 43, 240, 62, 26, 41, 170, 135, 31, 239, 119, 231, + 215, 54, 118, 108, 93, 9, 57, 217, 39, 83, 204, 93, 200, + ]), + 8 => BlsFr::from_be_bytes_mod_order(&[ + 50, 112, 102, 30, 104, 146, 139, 58, 149, 93, 85, 219, 86, 220, 87, 193, 3, 204, 10, + 96, 20, 30, 137, 78, 20, 37, 157, 206, 83, 119, 130, 178, + ]), + 9 => BlsFr::from_be_bytes_mod_order(&[ + 7, 63, 17, 111, 4, 18, 46, 37, 160, 183, 175, 228, 226, 5, 114, 153, 180, 7, 195, 112, + 242, 181, 161, 204, 206, 159, 185, 255, 195, 69, 175, 179, + ]), + 10 => BlsFr::from_be_bytes_mod_order(&[ + 64, 159, 218, 34, 85, 140, 254, 77, 61, 216, 220, 226, 79, 105, 231, 111, 140, 42, 174, + 177, 221, 15, 9, 214, 94, 101, 76, 113, 243, 42, 162, 63, + ]), + 11 => BlsFr::from_be_bytes_mod_order(&[ + 42, 50, 236, 92, 78, 229, 177, 131, 122, 255, 208, 156, 31, 83, 245, 253, 85, 201, 205, + 32, 97, 174, 147, 202, 142, 186, 215, 111, 199, 21, 84, 216, + ]), + 12 => BlsFr::from_be_bytes_mod_order(&[ + 88, 72, 235, 235, 89, 35, 233, 37, 85, 183, 18, 79, 255, 186, 93, 107, 213, 113, 198, + 249, 132, 25, 94, 185, 207, 211, 163, 232, 235, 85, 177, 212, + ]), + 13 => BlsFr::from_be_bytes_mod_order(&[ + 39, 3, 38, 238, 3, 157, 241, 158, 101, 30, 44, 252, 116, 6, 40, 202, 99, 77, 36, 252, + 110, 37, 89, 242, 45, 140, 203, 226, 146, 239, 238, 173, + ]), + 14 => BlsFr::from_be_bytes_mod_order(&[ + 39, 198, 100, 42, 198, 51, 188, 102, 220, 16, 15, 231, 252, 250, 84, 145, 138, 248, + 149, 188, 224, 18, 241, 130, 160, 104, 252, 55, 193, 130, 226, 116, + ]), + 15 => BlsFr::from_be_bytes_mod_order(&[ + 27, 223, 216, 176, 20, 1, 199, 10, 210, 127, 87, 57, 105, 137, 18, 157, 113, 14, 31, + 182, 171, 151, 106, 69, 156, 161, 134, 130, 226, 109, 127, 249, + ]), + 16 => BlsFr::from_be_bytes_mod_order(&[ + 73, 27, 155, 166, 152, 59, 207, 159, 5, 254, 71, 148, 173, 180, 74, 48, 135, 155, 248, + 40, 150, 98, 225, 245, 125, 144, 246, 114, 65, 78, 138, 74, + ]), + 17 => 
BlsFr::from_be_bytes_mod_order(&[ + 22, 42, 20, 198, 47, 154, 137, 184, 20, 185, 214, 169, 200, 77, 214, 120, 244, 246, + 251, 63, 144, 84, 211, 115, 200, 50, 216, 36, 38, 26, 53, 234, + ]), + 18 => BlsFr::from_be_bytes_mod_order(&[ + 45, 25, 62, 15, 118, 222, 88, 107, 42, 246, 247, 158, 49, 39, 254, 234, 172, 10, 31, + 199, 30, 44, 240, 192, 247, 152, 36, 102, 123, 91, 107, 236, + ]), + 19 => BlsFr::from_be_bytes_mod_order(&[ + 70, 239, 216, 169, 162, 98, 214, 216, 253, 201, 202, 92, 4, 176, 152, 47, 36, 221, 204, + 110, 152, 99, 136, 90, 106, 115, 42, 57, 6, 160, 123, 149, + ]), + 20 => BlsFr::from_be_bytes_mod_order(&[ + 80, 151, 23, 224, 194, 0, 227, 201, 45, 141, 202, 41, 115, 179, 219, 69, 240, 120, 130, + 148, 53, 26, 208, 122, 231, 92, 187, 120, 6, 147, 167, 152, + ]), + 21 => BlsFr::from_be_bytes_mod_order(&[ + 114, 153, 178, 132, 100, 168, 201, 79, 185, 212, 223, 97, 56, 15, 57, 192, 220, 169, + 194, 192, 20, 17, 135, 137, 226, 39, 37, 40, 32, 240, 27, 252, + ]), + 22 => BlsFr::from_be_bytes_mod_order(&[ + 4, 76, 163, 204, 74, 133, 215, 59, 129, 105, 110, 241, 16, 78, 103, 79, 79, 239, 248, + 41, 132, 153, 15, 248, 93, 11, 245, 141, 200, 164, 170, 148, + ]), + 23 => BlsFr::from_be_bytes_mod_order(&[ + 28, 186, 242, 179, 113, 218, 198, 168, 29, 4, 83, 65, 109, 62, 35, 92, 184, 217, 226, + 212, 243, 20, 244, 111, 97, 152, 120, 95, 12, 214, 185, 175, + ]), + 24 => BlsFr::from_be_bytes_mod_order(&[ + 29, 91, 39, 119, 105, 44, 32, 91, 14, 108, 73, 208, 97, 182, 181, 244, 41, 60, 74, 176, + 56, 253, 187, 220, 52, 62, 7, 97, 15, 63, 237, 229, + ]), + 25 => BlsFr::from_be_bytes_mod_order(&[ + 86, 174, 124, 122, 82, 147, 189, 194, 62, 133, 225, 105, 140, 129, 199, 127, 138, 216, + 140, 75, 51, 165, 120, 4, 55, 173, 4, 124, 110, 219, 89, 186, + ]), + 26 => BlsFr::from_be_bytes_mod_order(&[ + 46, 155, 219, 186, 61, 211, 75, 255, 170, 48, 83, 91, 221, 116, 154, 126, 6, 169, 173, + 176, 193, 230, 249, 98, 246, 14, 151, 27, 141, 115, 176, 79, + ]), + 27 => 
BlsFr::from_be_bytes_mod_order(&[ + 45, 225, 24, 134, 177, 128, 17, 202, 139, 213, 186, 227, 105, 105, 41, 159, 222, 64, + 251, 226, 109, 4, 123, 5, 3, 90, 19, 102, 31, 34, 65, 139, + ]), + 28 => BlsFr::from_be_bytes_mod_order(&[ + 46, 7, 222, 23, 128, 184, 167, 13, 13, 91, 74, 63, 24, 65, 220, 216, 42, 185, 57, 92, + 68, 155, 233, 71, 188, 153, 136, 132, 186, 150, 167, 33, + ]), + 29 => BlsFr::from_be_bytes_mod_order(&[ + 15, 105, 241, 133, 77, 32, 202, 12, 187, 219, 99, 219, 213, 45, 173, 22, 37, 4, 64, + 169, 157, 107, 138, 243, 130, 94, 76, 43, 183, 73, 37, 202, + ]), + 30 => BlsFr::from_be_bytes_mod_order(&[ + 93, 201, 135, 49, 142, 110, 89, 193, 175, 184, 123, 101, 93, 213, 140, 193, 210, 46, + 81, 58, 5, 131, 140, 212, 88, 93, 4, 177, 53, 185, 87, 202, + ]), + 31 => BlsFr::from_be_bytes_mod_order(&[ + 72, 183, 37, 117, 133, 113, 201, 223, 108, 1, 220, 99, 154, 133, 240, 114, 151, 105, + 107, 27, 182, 120, 99, 58, 41, 220, 145, 222, 149, 239, 83, 246, + ]), + 32 => BlsFr::from_be_bytes_mod_order(&[ + 94, 86, 94, 8, 192, 130, 16, 153, 37, 107, 86, 73, 14, 174, 225, 213, 115, 175, 209, + 11, 182, 209, 125, 19, 202, 78, 92, 97, 27, 42, 55, 24, + ]), + 33 => BlsFr::from_be_bytes_mod_order(&[ + 46, 177, 178, 84, 23, 254, 23, 103, 13, 19, 93, 198, 57, 251, 9, 164, 108, 229, 17, 53, + 7, 249, 109, 233, 129, 108, 5, 148, 34, 220, 112, 94, + ]), + 34 => BlsFr::from_be_bytes_mod_order(&[ + 17, 92, 208, 160, 100, 60, 251, 152, 140, 36, 203, 68, 195, 250, 180, 138, 255, 54, + 198, 97, 210, 108, 196, 45, 184, 177, 189, 244, 149, 59, 216, 44, + ]), + 35 => BlsFr::from_be_bytes_mod_order(&[ + 38, 202, 41, 63, 123, 44, 70, 45, 6, 109, 115, 120, 185, 153, 134, 139, 187, 87, 221, + 241, 78, 15, 149, 138, 222, 128, 22, 18, 49, 29, 4, 205, + ]), + 36 => BlsFr::from_be_bytes_mod_order(&[ + 65, 71, 64, 13, 142, 26, 172, 207, 49, 26, 107, 91, 118, 32, 17, 171, 62, 69, 50, 110, + 77, 75, 157, 226, 105, 146, 129, 107, 153, 197, 40, 172, + ]), + 37 => BlsFr::from_be_bytes_mod_order(&[ 
+ 107, 13, 183, 220, 204, 75, 161, 178, 104, 246, 189, 204, 77, 55, 40, 72, 212, 167, 41, + 118, 194, 104, 234, 48, 81, 154, 47, 115, 230, 219, 77, 85, + ]), + 38 => BlsFr::from_be_bytes_mod_order(&[ + 23, 191, 27, 147, 196, 199, 224, 26, 42, 131, 10, 161, 98, 65, 44, 217, 15, 22, 11, + 249, 247, 30, 150, 127, 245, 32, 157, 20, 178, 72, 32, 202, + ]), + 39 => BlsFr::from_be_bytes_mod_order(&[ + 75, 67, 28, 217, 239, 237, 188, 148, 207, 30, 202, 111, 158, 156, 24, 57, 208, 230, + 106, 139, 255, 168, 200, 70, 76, 172, 129, 163, 157, 60, 248, 241, + ]), + 40 => BlsFr::from_be_bytes_mod_order(&[ + 53, 180, 26, 122, 196, 243, 197, 113, 162, 79, 132, 86, 54, 156, 133, 223, 224, 60, 3, + 84, 189, 140, 253, 56, 5, 200, 111, 46, 125, 194, 147, 197, + ]), + 41 => BlsFr::from_be_bytes_mod_order(&[ + 59, 20, 128, 8, 5, 35, 196, 57, 67, 89, 39, 153, 72, 73, 190, 169, 100, 225, 77, 59, + 235, 45, 221, 222, 114, 172, 21, 106, 244, 53, 208, 158, + ]), + 42 => BlsFr::from_be_bytes_mod_order(&[ + 44, 198, 129, 0, 49, 220, 27, 13, 73, 80, 133, 109, 201, 7, 213, 117, 8, 226, 134, 68, + 42, 45, 62, 178, 39, 22, 24, 216, 116, 177, 76, 109, + ]), + 43 => BlsFr::from_be_bytes_mod_order(&[ + 111, 65, 65, 200, 64, 28, 90, 57, 91, 166, 121, 14, 253, 113, 199, 12, 4, 175, 234, 6, + 195, 201, 40, 38, 188, 171, 221, 92, 181, 71, 125, 81, + ]), + 44 => BlsFr::from_be_bytes_mod_order(&[ + 37, 189, 187, 237, 161, 189, 232, 193, 5, 150, 24, 226, 175, 210, 239, 153, 158, 81, + 122, 169, 59, 120, 52, 29, 145, 243, 24, 192, 159, 12, 181, 102, + ]), + 45 => BlsFr::from_be_bytes_mod_order(&[ + 57, 42, 74, 135, 88, 224, 110, 232, 185, 95, 51, 194, 93, 222, 138, 192, 42, 94, 208, + 162, 123, 97, 146, 108, 198, 49, 52, 135, 7, 63, 127, 123, + ]), + 46 => BlsFr::from_be_bytes_mod_order(&[ + 39, 42, 85, 135, 138, 8, 68, 43, 154, 166, 17, 31, 77, 224, 9, 72, 94, 106, 111, 209, + 93, 184, 147, 101, 231, 187, 206, 240, 46, 181, 134, 108, + ]), + 47 => BlsFr::from_be_bytes_mod_order(&[ + 99, 30, 193, 214, 210, 
141, 217, 232, 36, 238, 137, 163, 7, 48, 174, 247, 171, 70, 58, + 207, 201, 209, 132, 179, 85, 170, 5, 253, 105, 56, 234, 181, + ]), + 48 => BlsFr::from_be_bytes_mod_order(&[ + 78, 182, 253, 161, 15, 208, 251, 222, 2, 199, 68, 155, 251, 221, 195, 91, 205, 130, 37, + 231, 229, 195, 131, 58, 8, 24, 161, 0, 64, 157, 198, 242, + ]), + 49 => BlsFr::from_be_bytes_mod_order(&[ + 45, 91, 48, 139, 12, 240, 44, 223, 239, 161, 60, 78, 96, 226, 98, 57, 166, 235, 186, 1, + 22, 148, 221, 18, 155, 146, 91, 60, 91, 33, 224, 226, + ]), + 50 => BlsFr::from_be_bytes_mod_order(&[ + 22, 84, 159, 198, 175, 47, 59, 114, 221, 93, 41, 61, 114, 226, 229, 242, 68, 223, 244, + 47, 24, 180, 108, 86, 239, 56, 197, 124, 49, 22, 115, 172, + ]), + 51 => BlsFr::from_be_bytes_mod_order(&[ + 66, 51, 38, 119, 255, 53, 156, 94, 141, 184, 54, 217, 245, 251, 84, 130, 46, 57, 189, + 94, 34, 52, 11, 185, 186, 151, 91, 161, 169, 43, 227, 130, + ]), + 52 => BlsFr::from_be_bytes_mod_order(&[ + 73, 215, 210, 192, 180, 73, 229, 23, 155, 197, 204, 195, 180, 76, 96, 117, 217, 132, + 155, 86, 16, 70, 95, 9, 234, 114, 93, 220, 151, 114, 58, 148, + ]), + 53 => BlsFr::from_be_bytes_mod_order(&[ + 100, 194, 15, 185, 13, 122, 0, 56, 49, 117, 124, 196, 198, 34, 111, 110, 73, 133, 252, + 158, 203, 65, 107, 159, 104, 76, 160, 53, 29, 150, 121, 4, + ]), + 54 => BlsFr::from_be_bytes_mod_order(&[ + 89, 207, 244, 13, 232, 59, 82, 180, 27, 196, 67, 215, 151, 149, 16, 215, 113, 201, 64, + 185, 117, 140, 168, 32, 254, 115, 181, 200, 213, 88, 9, 52, + ]), + 55 => BlsFr::from_be_bytes_mod_order(&[ + 83, 219, 39, 49, 115, 12, 57, 176, 78, 221, 135, 95, 227, 183, 200, 130, 128, 130, 133, + 205, 188, 98, 29, 122, 244, 248, 13, 213, 62, 187, 113, 176, + ]), + 56 => BlsFr::from_be_bytes_mod_order(&[ + 27, 16, 187, 122, 130, 175, 206, 57, 250, 105, 195, 162, 173, 82, 247, 109, 118, 57, + 130, 101, 52, 66, 3, 17, 155, 113, 38, 217, 180, 104, 96, 223, + ]), + 57 => BlsFr::from_be_bytes_mod_order(&[ + 86, 27, 96, 18, 214, 102, 191, 225, 
121, 196, 221, 127, 132, 205, 209, 83, 21, 150, + 211, 170, 199, 197, 112, 12, 235, 49, 159, 145, 4, 106, 99, 201, + ]), + 58 => BlsFr::from_be_bytes_mod_order(&[ + 15, 30, 117, 5, 235, 217, 29, 47, 199, 156, 45, 247, 220, 152, 163, 190, 209, 179, 105, + 104, 186, 4, 5, 192, 144, 210, 127, 106, 0, 183, 223, 200, + ]), + 59 => BlsFr::from_be_bytes_mod_order(&[ + 47, 49, 63, 175, 13, 63, 97, 135, 83, 122, 116, 151, 163, 180, 63, 70, 121, 127, 214, + 227, 241, 142, 177, 202, 255, 69, 119, 86, 184, 25, 187, 32, + ]), + 60 => BlsFr::from_be_bytes_mod_order(&[ + 58, 92, 187, 109, 228, 80, 180, 129, 250, 60, 166, 28, 14, 209, 91, 197, 92, 173, 17, + 235, 240, 247, 206, 184, 240, 188, 62, 115, 46, 203, 38, 246, + ]), + 61 => BlsFr::from_be_bytes_mod_order(&[ + 104, 29, 147, 65, 27, 248, 206, 99, 246, 113, 106, 239, 189, 14, 36, 80, 100, 84, 192, + 52, 142, 227, 143, 171, 235, 38, 71, 2, 113, 76, 207, 148, + ]), + 62 => BlsFr::from_be_bytes_mod_order(&[ + 81, 120, 233, 64, 245, 0, 4, 49, 38, 70, 180, 54, 114, 127, 14, 128, 167, 184, 242, + 233, 238, 31, 220, 103, 124, 72, 49, 167, 103, 39, 119, 251, + ]), + 63 => BlsFr::from_be_bytes_mod_order(&[ + 61, 171, 84, 188, 155, 239, 104, 141, 217, 32, 134, 226, 83, 180, 57, 214, 81, 186, + 166, 226, 15, 137, 43, 98, 134, 85, 39, 203, 202, 145, 89, 130, + ]), + 64 => BlsFr::from_be_bytes_mod_order(&[ + 75, 60, 231, 83, 17, 33, 143, 154, 233, 5, 248, 78, 170, 91, 43, 56, 24, 68, 139, 191, + 57, 114, 225, 170, 214, 157, 227, 33, 0, 144, 21, 208, + ]), + 65 => BlsFr::from_be_bytes_mod_order(&[ + 6, 219, 251, 66, 185, 121, 136, 77, 226, 128, 211, 22, 112, 18, 63, 116, 76, 36, 179, + 59, 65, 15, 239, 212, 54, 128, 69, 172, 242, 183, 26, 227, + ]), + 66 => BlsFr::from_be_bytes_mod_order(&[ + 6, 141, 107, 70, 8, 170, 232, 16, 198, 240, 57, 234, 25, 115, 166, 62, 184, 210, 222, + 114, 227, 210, 201, 236, 167, 252, 50, 210, 47, 24, 185, 211, + ]), + 67 => BlsFr::from_be_bytes_mod_order(&[ + 76, 92, 37, 69, 137, 169, 42, 54, 8, 74, 87, 211, 
177, 217, 100, 39, 138, 204, 126, 79, + 232, 246, 159, 41, 85, 149, 79, 39, 167, 156, 235, 239, + ]), + 68 => BlsFr::from_be_bytes_mod_order(&[ + 108, 186, 197, 225, 112, 9, 132, 235, 195, 45, 161, 91, 75, 185, 104, 63, 170, 186, + 181, 95, 103, 204, 196, 247, 29, 149, 96, 179, 71, 90, 119, 235, + ]), + 69 => BlsFr::from_be_bytes_mod_order(&[ + 70, 3, 196, 3, 187, 250, 154, 23, 115, 138, 92, 98, 120, 234, 171, 28, 55, 236, 48, + 176, 115, 122, 162, 64, 159, 196, 137, 128, 105, 235, 152, 60, + ]), + 70 => BlsFr::from_be_bytes_mod_order(&[ + 104, 148, 231, 226, 43, 44, 29, 92, 112, 167, 18, 166, 52, 90, 230, 177, 146, 169, 200, + 51, 169, 35, 76, 49, 197, 106, 172, 209, 107, 194, 241, 0, + ]), + 71 => BlsFr::from_be_bytes_mod_order(&[ + 91, 226, 203, 188, 68, 5, 58, 208, 138, 250, 77, 30, 171, 199, 243, 210, 49, 238, 167, + 153, 185, 63, 34, 110, 144, 91, 125, 77, 101, 197, 142, 187, + ]), + 72 => BlsFr::from_be_bytes_mod_order(&[ + 88, 229, 95, 40, 123, 69, 58, 152, 8, 98, 74, 140, 42, 53, 61, 82, 141, 160, 247, 231, + 19, 165, 198, 208, 215, 113, 30, 71, 6, 63, 166, 17, + ]), + 73 => BlsFr::from_be_bytes_mod_order(&[ + 54, 110, 191, 175, 163, 173, 56, 28, 14, 226, 88, 201, 184, 253, 252, 205, 184, 104, + 167, 215, 225, 241, 246, 154, 43, 93, 252, 197, 87, 37, 85, 223, + ]), + 74 => BlsFr::from_be_bytes_mod_order(&[ + 69, 118, 106, 183, 40, 150, 140, 100, 47, 144, 217, 124, 207, 85, 4, 221, 193, 5, 24, + 168, 25, 235, 188, 196, 208, 156, 63, 93, 120, 77, 103, 206, + ]), + 75 => BlsFr::from_be_bytes_mod_order(&[ + 57, 103, 143, 101, 81, 47, 30, 228, 4, 219, 48, 36, 244, 29, 63, 86, 126, 246, 109, + 137, 208, 68, 208, 34, 230, 188, 34, 158, 149, 188, 118, 177, + ]), + 76 => BlsFr::from_be_bytes_mod_order(&[ + 70, 58, 237, 29, 47, 31, 149, 94, 48, 120, 190, 91, 247, 191, 196, 111, 192, 235, 140, + 81, 85, 25, 6, 168, 134, 143, 24, 255, 174, 48, 207, 79, + ]), + 77 => BlsFr::from_be_bytes_mod_order(&[ + 33, 102, 143, 1, 106, 128, 99, 192, 213, 139, 119, 80, 163, 188, 
47, 225, 207, 130, + 194, 95, 153, 220, 1, 164, 229, 52, 200, 143, 229, 61, 133, 254, + ]), + 78 => BlsFr::from_be_bytes_mod_order(&[ + 57, 208, 9, 148, 168, 165, 4, 106, 27, 199, 73, 54, 62, 152, 167, 104, 227, 77, 234, + 86, 67, 159, 225, 149, 75, 239, 66, 155, 197, 51, 22, 8, + ]), + 79 => BlsFr::from_be_bytes_mod_order(&[ + 77, 127, 93, 205, 120, 236, 233, 169, 51, 152, 77, 227, 44, 11, 72, 250, 194, 187, 169, + 31, 38, 25, 150, 184, 233, 209, 2, 23, 115, 189, 7, 204, + ]), + _ => BlsFr::ZERO, + } +} + +pub fn poseidon_hash_bls(x: BlsFr, y: BlsFr) -> BlsFr { + let mut state = [x, y, BlsFr::ZERO]; + poseidon_permute_comp_bls(&mut state); + state[0] + x +} + +pub fn poseidon_permute_comp_bls(state: &mut [BlsFr; 3]) { + let mut idx = 0; + mix(state); + + // Full rounds + for _ in 0..4 { + round_comp(state, idx, true); + idx += 3; + } + + // Partial rounds + for _ in 0..56 { + round_comp(state, idx, false); + idx += 1; + } + + // Full rounds + for _ in 0..4 { + round_comp(state, idx, true); + idx += 3; + } +} + +#[inline] +fn round_comp(state: &mut [BlsFr; 3], idx: usize, full: bool) { + if full { + state[0] += poseidon_comp_consts(idx); + state[1] += poseidon_comp_consts(idx + 1); + state[2] += poseidon_comp_consts(idx + 2); + // Optimize multiplication + state[0] = state[0] * state[0] * state[0] * state[0] * state[0]; + state[1] = state[1] * state[1] * state[1] * state[1] * state[1]; + state[2] = state[2] * state[2] * state[2] * state[2] * state[2]; + } else { + state[0] += poseidon_comp_consts(idx); + state[2] = state[2] * state[2] * state[2] * state[2] * state[2]; + } + mix(state); +} + +#[inline(always)] +fn mix(state: &mut [BlsFr; 3]) { + state[0] = state[0] + state[1] + state[2]; + state[1] = state[0] + state[1]; + state[2] = state[0] + state[2]; +} + +pub fn poseidon_hash_many_bls(msgs: &[BlsFr]) -> BlsFr { + let mut state = [BlsFr::ZERO, BlsFr::ZERO, BlsFr::ZERO]; + let mut iter = msgs.chunks_exact(2); + + for msg in iter.by_ref() { + state[0] += msg[0]; + 
state[1] += msg[1]; + poseidon_permute_comp_bls(&mut state); + } + let r = iter.remainder(); + if r.len() == 1 { + state[0] += r[0]; + } + state[r.len()] += BlsFr::ONE; + poseidon_permute_comp_bls(&mut state); + + state[0] +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Deserialize, Serialize)] +pub struct PoseidonBLSMerkleHasher; +impl MerkleHasher for PoseidonBLSMerkleHasher { + type Hash = BlsFr; + + fn hash_node( + children_hashes: Option<(Self::Hash, Self::Hash)>, + column_values: &[BaseField], + ) -> Self::Hash { + let n_column_blocks = column_values.len().div_ceil(ELEMENTS_IN_BLOCK); + let values_len = 2 + n_column_blocks; + let mut values = Vec::with_capacity(values_len); + + if let Some((left, right)) = children_hashes { + values.push(left); + values.push(right); + } + + let padding_length = ELEMENTS_IN_BLOCK * n_column_blocks - column_values.len(); + let padded_values = column_values + .iter() + .copied() + .chain(std::iter::repeat(BaseField::zero()).take(padding_length)); + for chunk in padded_values.array_chunks::() { + let mut word = BlsFr::default(); + for x in chunk { + word = word * BlsFr::from(2u64.pow(31)) + BlsFr::from(x.0); + } + values.push(word); + } + poseidon_hash_many_bls(&values) + } +} + +impl Hash for BlsFr { + fn to_bytes(&self) -> Vec { + self.into_bigint().to_bytes_be() + } + + fn from_bytes(bytes: &[u8]) -> Self { + BlsFr::from_be_bytes_mod_order(bytes) + } +} + +#[derive(Default)] +pub struct PoseidonBLSMerkleChannel; + +impl MerkleChannel for PoseidonBLSMerkleChannel { + type C = PoseidonBLSChannel; + type H = PoseidonBLSMerkleHasher; + + fn mix_root(channel: &mut Self::C, root: ::Hash) { + channel.update_digest(poseidon_hash_bls(channel.digest(), root)); + } +} + +#[cfg(test)] +mod tests { + use num_traits::Zero; + use ark_bls12_381::Fr as BlsFr; + + use crate::core::fields::m31::BaseField; + //use crate::core::vcs::ops::MerkleHasher; + use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher; + use 
crate::core::vcs::test_utils::prepare_merkle; + use crate::core::vcs::verifier::MerkleVerificationError; + //use crate::m31; + + // TODO: Redo test vectors for bls + /*#[test] + fn test_vector() { + assert_eq!( + PoseidonBLSMerkleHasher::hash_node(None, &[m31!(0), m31!(1)]), + BlsFr::from_str( + "2552053700073128806553921687214114320458351061521275103654266875084493044716" + ) + .unwrap() + ); + + assert_eq!( + Poseidon252MerkleHasher::hash_node( + Some((FieldElement252::from(1u32), FieldElement252::from(2u32))), + &[m31!(3)] + ), + FieldElement252::from_dec_str( + "159358216886023795422515519110998391754567506678525778721401012606792642769" + ) + .unwrap() + ); + }*/ + + #[test] + fn test_merkle_success() { + let (queries, decommitment, values, verifier) = prepare_merkle::(); + verifier.verify(queries, values, decommitment).unwrap(); + } + + #[test] + fn test_merkle_invalid_witness() { + let (queries, mut decommitment, values, verifier) = + prepare_merkle::(); + decommitment.hash_witness[4] = BlsFr::default(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::RootMismatch + ); + } + + #[test] + fn test_merkle_invalid_value() { + let (queries, decommitment, mut values, verifier) = + prepare_merkle::(); + values[3][2] = BaseField::zero(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::RootMismatch + ); + } + + #[test] + fn test_merkle_witness_too_short() { + let (queries, mut decommitment, values, verifier) = + prepare_merkle::(); + decommitment.hash_witness.pop(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::WitnessTooShort + ); + } + + #[test] + fn test_merkle_witness_too_long() { + let (queries, mut decommitment, values, verifier) = + prepare_merkle::(); + decommitment.hash_witness.push(BlsFr::default()); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + 
MerkleVerificationError::WitnessTooLong + ); + } + + #[test] + fn test_merkle_column_values_too_long() { + let (queries, decommitment, mut values, verifier) = + prepare_merkle::(); + values[3].push(BaseField::zero()); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::ColumnValuesTooLong + ); + } + + #[test] + fn test_merkle_column_values_too_short() { + let (queries, decommitment, mut values, verifier) = + prepare_merkle::(); + values[3].pop(); + + assert_eq!( + verifier.verify(queries, values, decommitment).unwrap_err(), + MerkleVerificationError::ColumnValuesTooShort + ); + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/prover.rs b/Stwo_wrapper/crates/prover/src/core/vcs/prover.rs new file mode 100644 index 0000000..c2fd63d --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/prover.rs @@ -0,0 +1,223 @@ +use std::cmp::Reverse; +use std::collections::BTreeMap; + +use itertools::Itertools; + +use super::ops::{MerkleHasher, MerkleOps}; +use super::utils::{next_decommitment_node, option_flatten_peekable}; +use crate::core::backend::{Col, Column}; +use crate::core::fields::m31::BaseField; +use crate::core::utils::PeekableExt; +use crate::core::ColumnVec; + +pub struct MerkleProver, H: MerkleHasher> { + /// Layers of the Merkle tree. + /// The first layer is the root layer. + /// The last layer is the largest layer. + /// See [MerkleOps::commit_on_layer] for more details. + pub layers: Vec>, +} +/// The MerkleProver struct represents a prover for a Merkle commitment scheme. +/// It is generic over the types `B` and `H`, which represent the Merkle operations and Merkle +/// hasher respectively. +impl, H: MerkleHasher> MerkleProver { + /// Commits to columns. + /// Columns must be of power of 2 sizes. + /// + /// # Arguments + /// + /// * `columns` - A vector of references to columns. 
+ /// + /// # Panics + /// + /// This function will panic if the columns are not sorted in descending order or if the columns + /// vector is empty. + /// + /// # Returns + /// + /// A new instance of `MerkleProver` with the committed layers. + pub fn commit(columns: Vec<&Col>) -> Self { + assert!(!columns.is_empty()); + + let columns = &mut columns + .into_iter() + .sorted_by_key(|c| Reverse(c.len())) + .peekable(); + let mut layers: Vec> = Vec::new(); + + let max_log_size = columns.peek().unwrap().len().ilog2(); + for log_size in (0..=max_log_size).rev() { + // Take columns of the current log_size. + let layer_columns = columns + .peek_take_while(|column| column.len().ilog2() == log_size) + .collect_vec(); + + layers.push(B::commit_on_layer(log_size, layers.last(), &layer_columns)); + } + layers.reverse(); + Self { layers } + } + + /// Decommits to columns on the given queries. + /// Queries are given as indices to the largest column. + /// + /// # Arguments + /// + /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that + /// log_size. + /// * `columns` - A vector of references to columns. + /// + /// # Returns + /// + /// A tuple containing: + /// * A vector of vectors of queried values for each column, in the order of the input columns. + /// * A `MerkleDecommitment` containing the hash and column witnesses. + pub fn decommit( + &self, + queries_per_log_size: BTreeMap>, + columns: Vec<&Col>, + ) -> (ColumnVec>, MerkleDecommitment) { + // Check that queries are sorted and deduped. + // TODO(andrew): Consider using a Queries struct to prevent this. + for queries in queries_per_log_size.values() { + assert!( + queries.windows(2).all(|w| w[0] < w[1]), + "Queries are not sorted." + ); + } + + // Prepare output buffers. + let mut queried_values_by_layer = vec![]; + let mut decommitment = MerkleDecommitment::empty(); + + // Sort columns by layer. 
+ let mut columns_by_layer = columns + .iter() + .sorted_by_key(|c| Reverse(c.len())) + .peekable(); + + let mut last_layer_queries = vec![]; + for layer_log_size in (0..self.layers.len() as u32).rev() { + // Prepare write buffer for queried values to the current layer. + let mut layer_queried_values = vec![]; + + // Prepare write buffer for queries to the current layer. This will propagate to the + // next layer. + let mut layer_total_queries = vec![]; + + // Each layer node is a hash of column values as previous layer hashes. + // Prepare the relevant columns and previous layer hashes to read from. + let layer_columns = columns_by_layer + .peek_take_while(|column| column.len().ilog2() == layer_log_size) + .collect_vec(); + let previous_layer_hashes = self.layers.get(layer_log_size as usize + 1); + + // Queries to this layer come from queried node in the previous layer and queried + // columns in this one. + let mut prev_layer_queries = last_layer_queries.into_iter().peekable(); + let mut layer_column_queries = + option_flatten_peekable(queries_per_log_size.get(&layer_log_size)); + + // Merge previous layer queries and column queries. + while let Some(node_index) = + next_decommitment_node(&mut prev_layer_queries, &mut layer_column_queries) + { + if let Some(previous_layer_hashes) = previous_layer_hashes { + // If the left child was not computed, add it to the witness. + if prev_layer_queries.next_if_eq(&(2 * node_index)).is_none() { + decommitment + .hash_witness + .push(previous_layer_hashes.at(2 * node_index)); + } + + // If the right child was not computed, add it to the witness. + if prev_layer_queries + .next_if_eq(&(2 * node_index + 1)) + .is_none() + { + decommitment + .hash_witness + .push(previous_layer_hashes.at(2 * node_index + 1)); + } + } + + // If the column values were queried, return them. 
+ let node_values = layer_columns.iter().map(|c| c.at(node_index)); + if layer_column_queries.next_if_eq(&node_index).is_some() { + layer_queried_values.push(node_values.collect_vec()); + } else { + // Otherwise, add them to the witness. + decommitment.column_witness.extend(node_values); + } + + layer_total_queries.push(node_index); + } + + queried_values_by_layer.push(layer_queried_values); + + // Propagate queries to the next layer. + last_layer_queries = layer_total_queries; + } + queried_values_by_layer.reverse(); + + // Rearrange returned queried values according to input, and not by layer. + let queried_values = Self::rearrange_queried_values(queried_values_by_layer, columns); + + (queried_values, decommitment) + } + + /// Given queried values by layer, rearranges in the order of input columns. + fn rearrange_queried_values( + queried_values_by_layer: Vec>>, + columns: Vec<&Col>, + ) -> Vec> { + // Turn each column queried values into an iterator. + let mut queried_values_by_layer = queried_values_by_layer + .into_iter() + .map(|layer_results| { + layer_results + .into_iter() + .map(|x| x.into_iter()) + .collect_vec() + }) + .collect_vec(); + + // For each input column, fetch the queried values from the corresponding layer. + let queried_values = columns + .iter() + .map(|column| { + queried_values_by_layer + .get_mut(column.len().ilog2() as usize) + .unwrap() + .iter_mut() + .map(|x| x.next().unwrap()) + .collect_vec() + }) + .collect_vec(); + queried_values + } + + pub fn root(&self) -> H::Hash { + self.layers.first().unwrap().at(0) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd)] +pub struct MerkleDecommitment { + /// Hash values that the verifier needs but cannot deduce from previous computations, in the + /// order they are needed. + pub hash_witness: Vec, + /// Column values that the verifier needs but cannot deduce from previous computations, in the + /// order they are needed. + /// This complements the column values that were queried. 
These must be supplied directly to + /// the verifier. + pub column_witness: Vec, +} +impl MerkleDecommitment { + fn empty() -> Self { + Self { + hash_witness: Vec::new(), + column_witness: Vec::new(), + } + } +} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/test_utils.rs b/Stwo_wrapper/crates/prover/src/core/vcs/test_utils.rs new file mode 100644 index 0000000..8fc535b --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/test_utils.rs @@ -0,0 +1,60 @@ +use std::collections::BTreeMap; + +use itertools::Itertools; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; + +use super::ops::{MerkleHasher, MerkleOps}; +use super::prover::MerkleDecommitment; +use super::verifier::MerkleVerifier; +use crate::core::backend::CpuBackend; +use crate::core::fields::m31::BaseField; +use crate::core::vcs::prover::MerkleProver; + +pub type TestData = ( + BTreeMap>, + MerkleDecommitment, + Vec>, + MerkleVerifier, +); + +pub fn prepare_merkle() -> TestData +where + CpuBackend: MerkleOps, +{ + const N_COLS: usize = 10; + const N_QUERIES: usize = 3; + let log_size_range = 3..5; + + let mut rng = SmallRng::seed_from_u64(0); + let log_sizes = (0..N_COLS) + .map(|_| rng.gen_range(log_size_range.clone())) + .collect_vec(); + let cols = log_sizes + .iter() + .map(|&log_size| { + (0..(1 << log_size)) + .map(|_| BaseField::from(rng.gen_range(0..(1 << 30)))) + .collect_vec() + }) + .collect_vec(); + let merkle = MerkleProver::::commit(cols.iter().collect_vec()); + + let mut queries = BTreeMap::>::new(); + for log_size in log_size_range.rev() { + let layer_queries = (0..N_QUERIES) + .map(|_| rng.gen_range(0..(1 << log_size))) + .sorted() + .dedup() + .collect_vec(); + queries.insert(log_size, layer_queries); + } + + let (values, decommitment) = merkle.decommit(queries.clone(), cols.iter().collect_vec()); + + let verifier = MerkleVerifier { + root: merkle.root(), + column_log_sizes: log_sizes, + }; + (queries, decommitment, values, verifier) +} diff --git 
a/Stwo_wrapper/crates/prover/src/core/vcs/utils.rs b/Stwo_wrapper/crates/prover/src/core/vcs/utils.rs new file mode 100644 index 0000000..2f89f40 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/utils.rs @@ -0,0 +1,20 @@ +use std::iter::Peekable; + +/// Fetches the next node that needs to be decommited in the current Merkle layer. +pub fn next_decommitment_node( + prev_queries: &mut Peekable>, + layer_queries: &mut Peekable>, +) -> Option { + prev_queries + .peek() + .map(|q| *q / 2) + .into_iter() + .chain(layer_queries.peek().into_iter().copied()) + .min() +} + +pub fn option_flatten_peekable<'a, I: IntoIterator>( + a: Option, +) -> Peekable as IntoIterator>::IntoIter>>> { + a.into_iter().flatten().copied().peekable() +} diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs b/Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs new file mode 100644 index 0000000..53346bb --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs @@ -0,0 +1,194 @@ +use std::cmp::Reverse; +use std::collections::BTreeMap; + +use itertools::Itertools; +use thiserror::Error; + +use super::ops::MerkleHasher; +use super::prover::MerkleDecommitment; +use super::utils::{next_decommitment_node, option_flatten_peekable}; +use crate::core::fields::m31::BaseField; +use crate::core::utils::PeekableExt; +use crate::core::ColumnVec; + +// TODO(spapini): This struct is not necessary. Make it a function on decommitment? +pub struct MerkleVerifier { + pub root: H::Hash, + pub column_log_sizes: Vec, +} +impl MerkleVerifier { + pub fn new(root: H::Hash, column_log_sizes: Vec) -> Self { + Self { + root, + column_log_sizes, + } + } + /// Verifies the decommitment of the columns. + /// + /// # Arguments + /// + /// * `queries_per_log_size` - A map from log_size to a vector of queries for columns of that + /// log_size. + /// * `queried_values` - A vector of vectors of queried values. For each column, there is a + /// vector of queried values to that column. 
+ /// * `decommitment` - The decommitment object containing the witness and column values. + /// + /// # Errors + /// + /// Returns an error if any of the following conditions are met: + /// + /// * The witness is too long (not fully consumed). + /// * The witness is too short (missing values). + /// * The column values are too long (not fully consumed). + /// * The column values are too short (missing values). + /// * The computed root does not match the expected root. + /// + /// # Panics + /// + /// This function will panic if the `values` vector is not sorted in descending order based on + /// the `log_size` of the columns. + /// + /// # Returns + /// + /// Returns `Ok(())` if the decommitment is successfully verified. + pub fn verify( + &self, + queries_per_log_size: BTreeMap>, + queried_values: ColumnVec>, + decommitment: MerkleDecommitment, + ) -> Result<(), MerkleVerificationError> { + let max_log_size = self.column_log_sizes.iter().max().copied().unwrap_or(0); + + // Prepare read buffers. + let mut queried_values_by_layer = self + .column_log_sizes + .iter() + .copied() + .zip( + queried_values + .into_iter() + .map(|column_values| column_values.into_iter()), + ) + .sorted_by_key(|(log_size, _)| Reverse(*log_size)) + .peekable(); + let mut hash_witness = decommitment.hash_witness.into_iter(); + let mut column_witness = decommitment.column_witness.into_iter(); + + let mut last_layer_hashes: Option> = None; + for layer_log_size in (0..=max_log_size).rev() { + // Prepare read buffer for queried values to the current layer. + let mut layer_queried_values = queried_values_by_layer + .peek_take_while(|(log_size, _)| *log_size == layer_log_size) + .collect_vec(); + let n_columns_in_layer = layer_queried_values.len(); + + // Prepare write buffer for queries to the current layer. This will propagate to the + // next layer. 
+ let mut layer_total_queries = vec![]; + + // Queries to this layer come from queried node in the previous layer and queried + // columns in this one. + let mut prev_layer_queries = last_layer_hashes + .iter() + .flatten() + .map(|(q, _)| *q) + .collect_vec() + .into_iter() + .peekable(); + let mut prev_layer_hashes = last_layer_hashes.as_ref().map(|x| x.iter().peekable()); + let mut layer_column_queries = + option_flatten_peekable(queries_per_log_size.get(&layer_log_size)); + + // Merge previous layer queries and column queries. + while let Some(node_index) = + next_decommitment_node(&mut prev_layer_queries, &mut layer_column_queries) + { + prev_layer_queries + .peek_take_while(|q| q / 2 == node_index) + .for_each(drop); + + let node_hashes = prev_layer_hashes + .as_mut() + .map(|prev_layer_hashes| { + { + // If the left child was not computed, read it from the witness. + let left_hash = prev_layer_hashes + .next_if(|(index, _)| *index == 2 * node_index) + .map(|(_, hash)| Ok(*hash)) + .unwrap_or_else(|| { + hash_witness + .next() + .ok_or(MerkleVerificationError::WitnessTooShort) + })?; + + // If the right child was not computed, read it to from the witness. + let right_hash = prev_layer_hashes + .next_if(|(index, _)| *index == 2 * node_index + 1) + .map(|(_, hash)| Ok(*hash)) + .unwrap_or_else(|| { + hash_witness + .next() + .ok_or(MerkleVerificationError::WitnessTooShort) + })?; + Ok((left_hash, right_hash)) + } + }) + .transpose()?; + + // If the column values were queried, read them from `queried_value`. + let node_values = if layer_column_queries.next_if_eq(&node_index).is_some() { + layer_queried_values + .iter_mut() + .map(|(_, ref mut column_queries)| { + column_queries + .next() + .ok_or(MerkleVerificationError::ColumnValuesTooShort) + }) + .collect::, _>>()? + } else { + // Otherwise, read them from the witness. 
+ (&mut column_witness).take(n_columns_in_layer).collect_vec() + }; + if node_values.len() != n_columns_in_layer { + return Err(MerkleVerificationError::WitnessTooShort); + } + + layer_total_queries.push((node_index, H::hash_node(node_hashes, &node_values))); + } + + if !layer_queried_values.iter().all(|(_, c)| c.is_empty()) { + return Err(MerkleVerificationError::ColumnValuesTooLong); + } + last_layer_hashes = Some(layer_total_queries); + } + + // Check that all witnesses and values have been consumed. + if !hash_witness.is_empty() { + return Err(MerkleVerificationError::WitnessTooLong); + } + if !column_witness.is_empty() { + return Err(MerkleVerificationError::WitnessTooLong); + } + + let [(_, computed_root)] = last_layer_hashes.unwrap().try_into().unwrap(); + if computed_root != self.root { + return Err(MerkleVerificationError::RootMismatch); + } + + Ok(()) + } +} + +#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)] +pub enum MerkleVerificationError { + #[error("Witness is too short.")] + WitnessTooShort, + #[error("Witness is too long.")] + WitnessTooLong, + #[error("Column values are too long.")] + ColumnValuesTooLong, + #[error("Column values are too short.")] + ColumnValuesTooShort, + #[error("Root mismatch.")] + RootMismatch, +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/air.rs b/Stwo_wrapper/crates/prover/src/examples/blake/air.rs new file mode 100644 index 0000000..e655ee6 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/air.rs @@ -0,0 +1,483 @@ +use std::simd::u32x16; + +use itertools::{chain, multiunzip, Itertools}; +use num_traits::Zero; +use serde::Serialize; +use tracing::{span, Level}; + +use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval}; +use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval}; +use super::xor_table::{XorTableComponent, XorTableEval}; +use crate::constraint_framework::TraceLocationAllocator; +use crate::core::air::{Component, ComponentProver}; +use 
crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::BackendForChannel; +use crate::core::channel::{Channel, MerkleChannel}; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; +use crate::core::poly::circle::{CanonicCoset, PolyOps}; +use crate::core::prover::{prove, verify, StarkProof, VerificationError}; +use crate::core::vcs::ops::MerkleHasher; +use crate::examples::blake::round::RoundElements; +use crate::examples::blake::scheduler::{self, blake_scheduler_info, BlakeElements, BlakeInput}; +use crate::examples::blake::{ + round, xor_table, BlakeXorElements, XorAccums, N_ROUNDS, ROUND_LOG_SPLIT, +}; + +#[derive(Serialize)] +pub struct BlakeStatement0 { + log_size: u32, +} +impl BlakeStatement0 { + fn log_sizes(&self) -> TreeVec> { + let mut sizes = vec![]; + sizes.push( + blake_scheduler_info() + .mask_offsets + .as_cols_ref() + .map_cols(|_| self.log_size), + ); + for l in ROUND_LOG_SPLIT { + sizes.push( + blake_round_info() + .mask_offsets + .as_cols_ref() + .map_cols(|_| self.log_size + l), + ); + } + sizes.push(xor_table::trace_sizes::<12, 4>()); + sizes.push(xor_table::trace_sizes::<9, 2>()); + sizes.push(xor_table::trace_sizes::<8, 2>()); + sizes.push(xor_table::trace_sizes::<7, 2>()); + sizes.push(xor_table::trace_sizes::<4, 0>()); + + TreeVec::concat_cols(sizes.into_iter()) + } + fn mix_into(&self, channel: &mut impl Channel) { + // TODO(spapini): Do this better. 
+ channel.mix_nonce(self.log_size as u64); + } +} + +pub struct AllElements { + blake_elements: BlakeElements, + round_elements: RoundElements, + xor_elements: BlakeXorElements, +} +impl AllElements { + pub fn draw(channel: &mut impl Channel) -> Self { + Self { + blake_elements: BlakeElements::draw(channel), + round_elements: RoundElements::draw(channel), + xor_elements: BlakeXorElements::draw(channel), + } + } +} + +pub struct BlakeStatement1 { + scheduler_claimed_sum: SecureField, + round_claimed_sums: Vec, + xor12_claimed_sum: SecureField, + xor9_claimed_sum: SecureField, + xor8_claimed_sum: SecureField, + xor7_claimed_sum: SecureField, + xor4_claimed_sum: SecureField, +} +impl BlakeStatement1 { + fn mix_into(&self, channel: &mut impl Channel) { + channel.mix_felts( + &chain![ + [ + self.scheduler_claimed_sum, + self.xor12_claimed_sum, + self.xor9_claimed_sum, + self.xor8_claimed_sum, + self.xor7_claimed_sum, + self.xor4_claimed_sum + ], + self.round_claimed_sums.clone() + ] + .collect_vec(), + ) + } +} + +pub struct BlakeProof { + stmt0: BlakeStatement0, + stmt1: BlakeStatement1, + stark_proof: StarkProof, +} + +pub struct BlakeComponents { + scheduler_component: BlakeSchedulerComponent, + round_components: Vec, + xor12: XorTableComponent<12, 4>, + xor9: XorTableComponent<9, 2>, + xor8: XorTableComponent<8, 2>, + xor7: XorTableComponent<7, 2>, + xor4: XorTableComponent<4, 0>, +} +impl BlakeComponents { + fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self { + let tree_span_provider = &mut TraceLocationAllocator::default(); + Self { + scheduler_component: BlakeSchedulerComponent::new( + tree_span_provider, + BlakeSchedulerEval { + log_size: stmt0.log_size, + blake_lookup_elements: all_elements.blake_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + claimed_sum: stmt1.scheduler_claimed_sum, + }, + ), + round_components: ROUND_LOG_SPLIT + .iter() + .zip(stmt1.round_claimed_sums.clone()) + 
.map(|(l, claimed_sum)| { + BlakeRoundComponent::new( + tree_span_provider, + BlakeRoundEval { + log_size: stmt0.log_size + l, + xor_lookup_elements: all_elements.xor_elements.clone(), + round_lookup_elements: all_elements.round_elements.clone(), + claimed_sum, + }, + ) + }) + .collect(), + xor12: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor12.clone(), + claimed_sum: stmt1.xor12_claimed_sum, + }, + ), + xor9: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor9.clone(), + claimed_sum: stmt1.xor9_claimed_sum, + }, + ), + xor8: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor8.clone(), + claimed_sum: stmt1.xor8_claimed_sum, + }, + ), + xor7: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor7.clone(), + claimed_sum: stmt1.xor7_claimed_sum, + }, + ), + xor4: XorTableComponent::new( + tree_span_provider, + XorTableEval { + lookup_elements: all_elements.xor_elements.xor4.clone(), + claimed_sum: stmt1.xor4_claimed_sum, + }, + ), + } + } + fn components(&self) -> Vec<&dyn Component> { + chain![ + [&self.scheduler_component as &dyn Component], + self.round_components.iter().map(|c| c as &dyn Component), + [ + &self.xor12 as &dyn Component, + &self.xor9 as &dyn Component, + &self.xor8 as &dyn Component, + &self.xor7 as &dyn Component, + &self.xor4 as &dyn Component, + ] + ] + .collect() + } + + fn component_provers(&self) -> Vec<&dyn ComponentProver> { + chain![ + [&self.scheduler_component as &dyn ComponentProver], + self.round_components + .iter() + .map(|c| c as &dyn ComponentProver), + [ + &self.xor12 as &dyn ComponentProver, + &self.xor9 as &dyn ComponentProver, + &self.xor8 as &dyn ComponentProver, + &self.xor7 as &dyn ComponentProver, + &self.xor4 as &dyn ComponentProver, + ] + ] + .collect() + } +} + +#[allow(unused)] 
+pub fn prove_blake(log_size: u32, config: PcsConfig) -> (BlakeProof) +where + SimdBackend: BackendForChannel, +{ + assert!(log_size >= LOG_N_LANES); + assert_eq!( + ROUND_LOG_SPLIT.map(|x| (1 << x)).into_iter().sum::() as usize, + N_ROUNDS + ); + + // Precompute twiddles. + let span = span!(Level::INFO, "Precompute twiddles").entered(); + const XOR_TABLE_MAX_LOG_SIZE: u32 = 16; + let log_max_rows = + (log_size + *ROUND_LOG_SPLIT.iter().max().unwrap()).max(XOR_TABLE_MAX_LOG_SIZE); + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(log_max_rows + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + span.exit(); + + // Prepare inputs. + let blake_inputs = (0..(1 << (log_size - LOG_N_LANES))) + .map(|i| { + let v = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j) as u32)); 16]; + let m = [u32x16::from_array(std::array::from_fn(|j| (i + 2 * j + 1) as u32)); 16]; + BlakeInput { v, m } + }) + .collect_vec(); + + // Setup protocol. + let channel = &mut MC::C::default(); + let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles); + + let span = span!(Level::INFO, "Trace").entered(); + + // Scheduler. + let (scheduler_trace, scheduler_lookup_data, round_inputs) = + scheduler::gen_trace(log_size, &blake_inputs); + + // Rounds. + let mut xor_accums = XorAccums::default(); + let mut rest = &round_inputs[..]; + // Split round inputs to components, according to [ROUND_LOG_SPLIT]. + let (round_traces, round_lookup_datas): (Vec<_>, Vec<_>) = + multiunzip(ROUND_LOG_SPLIT.map(|l| { + let (cur_inputs, r) = rest.split_at(1 << (log_size - LOG_N_LANES + l)); + rest = r; + round::generate_trace(log_size + l, cur_inputs, &mut xor_accums) + })); + + // Xor tables. 
+ let (xor_trace12, xor_lookup_data12) = xor_table::generate_trace(xor_accums.xor12); + let (xor_trace9, xor_lookup_data9) = xor_table::generate_trace(xor_accums.xor9); + let (xor_trace8, xor_lookup_data8) = xor_table::generate_trace(xor_accums.xor8); + let (xor_trace7, xor_lookup_data7) = xor_table::generate_trace(xor_accums.xor7); + let (xor_trace4, xor_lookup_data4) = xor_table::generate_trace(xor_accums.xor4); + + // Statement0. + let stmt0 = BlakeStatement0 { log_size }; + stmt0.mix_into(channel); + + // Trace commitment. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals( + chain![ + scheduler_trace, + round_traces.into_iter().flatten(), + xor_trace12, + xor_trace9, + xor_trace8, + xor_trace7, + xor_trace4, + ] + .collect_vec(), + ); + tree_builder.commit(channel); + span.exit(); + + // Draw lookup element. + let all_elements = AllElements::draw(channel); + + // Interaction trace. + let span = span!(Level::INFO, "Interaction").entered(); + let (scheduler_trace, scheduler_claimed_sum) = scheduler::gen_interaction_trace( + log_size, + scheduler_lookup_data, + &all_elements.round_elements, + &all_elements.blake_elements, + ); + + let (round_traces, round_claimed_sums): (Vec<_>, Vec<_>) = multiunzip( + ROUND_LOG_SPLIT + .iter() + .zip(round_lookup_datas) + .map(|(l, lookup_data)| { + round::generate_interaction_trace( + log_size + l, + lookup_data, + &all_elements.xor_elements, + &all_elements.round_elements, + ) + }), + ); + + let (xor_trace12, xor12_claimed_sum) = + xor_table::generate_interaction_trace(xor_lookup_data12, &all_elements.xor_elements.xor12); + let (xor_trace9, xor9_claimed_sum) = + xor_table::generate_interaction_trace(xor_lookup_data9, &all_elements.xor_elements.xor9); + let (xor_trace8, xor8_claimed_sum) = + xor_table::generate_interaction_trace(xor_lookup_data8, &all_elements.xor_elements.xor8); + let (xor_trace7, xor7_claimed_sum) = + xor_table::generate_interaction_trace(xor_lookup_data7, 
&all_elements.xor_elements.xor7); + let (xor_trace4, xor4_claimed_sum) = + xor_table::generate_interaction_trace(xor_lookup_data4, &all_elements.xor_elements.xor4); + + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals( + chain![ + scheduler_trace, + round_traces.into_iter().flatten(), + xor_trace12, + xor_trace9, + xor_trace8, + xor_trace7, + xor_trace4, + ] + .collect_vec(), + ); + + // Statement1. + let stmt1 = BlakeStatement1 { + scheduler_claimed_sum, + round_claimed_sums, + xor12_claimed_sum, + xor9_claimed_sum, + xor8_claimed_sum, + xor7_claimed_sum, + xor4_claimed_sum, + }; + stmt1.mix_into(channel); + tree_builder.commit(channel); + span.exit(); + + // Constant trace. + let span = span!(Level::INFO, "Constant Trace").entered(); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals( + chain![ + xor_table::generate_constant_trace::<12, 4>(), + xor_table::generate_constant_trace::<9, 2>(), + xor_table::generate_constant_trace::<8, 2>(), + xor_table::generate_constant_trace::<7, 2>(), + xor_table::generate_constant_trace::<4, 0>(), + ] + .collect_vec(), + ); + tree_builder.commit(channel); + span.exit(); + + assert_eq!( + commitment_scheme + .polynomials() + .as_cols_ref() + .map_cols(|c| c.log_size()) + .0, + stmt0.log_sizes().0 + ); + + // Prove constraints. + let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); + let stark_proof = + prove::(&components.component_provers(), channel, commitment_scheme) + .unwrap(); + + BlakeProof { + stmt0, + stmt1, + stark_proof, + } +} + +#[allow(unused)] +pub fn verify_blake( + BlakeProof { + stmt0, + stmt1, + stark_proof, + }: BlakeProof, + config: PcsConfig, +) -> Result<(), VerificationError> { + let channel = &mut MC::C::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + + let log_sizes = stmt0.log_sizes(); + + // Trace. 
+ stmt0.mix_into(channel); + commitment_scheme.commit(stark_proof.commitments[0], &log_sizes[0], channel); + + // Draw interaction elements. + let all_elements = AllElements::draw(channel); + + // Interaction trace. + stmt1.mix_into(channel); + commitment_scheme.commit(stark_proof.commitments[1], &log_sizes[1], channel); + + // Constant trace. + commitment_scheme.commit(stark_proof.commitments[2], &log_sizes[2], channel); + + let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); + + // Check that all sums are correct. + let total_sum = stmt1.scheduler_claimed_sum + + stmt1.round_claimed_sums.iter().sum::() + + stmt1.xor12_claimed_sum + + stmt1.xor9_claimed_sum + + stmt1.xor8_claimed_sum + + stmt1.xor7_claimed_sum + + stmt1.xor4_claimed_sum; + + // TODO(spapini): Add inputs to sum, and constraint them. + assert_eq!(total_sum, SecureField::zero()); + + verify( + &components.components(), + channel, + commitment_scheme, + stark_proof, + ) +} + +#[cfg(test)] +mod tests { + use std::env; + + use crate::core::pcs::PcsConfig; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::blake::air::{prove_blake, verify_blake}; + + // Note: this test is slow. Only run in release. + #[cfg_attr(not(feature = "slow-tests"), ignore)] + #[test_log::test] + fn test_simd_blake_prove() { + // Note: To see time measurement, run test with + // LOG_N_INSTANCES=16 RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info RUSTFLAGS=" + // -C target-cpu=native -C target-feature=+avx512f" cargo test --release + // test_simd_blake_prove -- --nocapture --ignored + + // Get from environment variable: + let log_n_instances = env::var("LOG_N_INSTANCES") + .unwrap_or_else(|_| "6".to_string()) + .parse::() + .unwrap(); + let config = PcsConfig::default(); + + // Prove. + let proof = prove_blake::(log_n_instances, config); + + // Verify. 
+ verify_blake::(proof, config).unwrap(); + } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/mod.rs b/Stwo_wrapper/crates/prover/src/examples/blake/mod.rs new file mode 100644 index 0000000..6fbe6d8 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/mod.rs @@ -0,0 +1,126 @@ +//! AIR for blake2s and blake3. +//! See + +use std::fmt::Debug; +use std::ops::{Add, AddAssign, Mul, Sub}; +use std::simd::u32x16; + +use xor_table::{XorAccumulator, XorElements}; + +use crate::core::backend::simd::m31::PackedBaseField; +use crate::core::channel::Channel; +use crate::core::fields::m31::BaseField; +use crate::core::fields::FieldExpOps; + +mod air; +mod round; +mod scheduler; +mod xor_table; + +const STATE_SIZE: usize = 16; +const MESSAGE_SIZE: usize = 16; +const N_FELTS_IN_U32: usize = 2; +const N_ROUND_INPUT_FELTS: usize = (STATE_SIZE + STATE_SIZE + MESSAGE_SIZE) * N_FELTS_IN_U32; + +// Parameters for Blake2s. Change these for blake3. +const N_ROUNDS: usize = 10; +/// A splitting N_ROUNDS into several powers of 2. 
+const ROUND_LOG_SPLIT: [u32; 2] = [3, 1]; + +#[derive(Default)] +struct XorAccums { + xor12: XorAccumulator<12, 4>, + xor9: XorAccumulator<9, 2>, + xor8: XorAccumulator<8, 2>, + xor7: XorAccumulator<7, 2>, + xor4: XorAccumulator<4, 0>, +} +impl XorAccums { + fn add_input(&mut self, w: u32, a: u32x16, b: u32x16) { + match w { + 12 => self.xor12.add_input(a, b), + 9 => self.xor9.add_input(a, b), + 8 => self.xor8.add_input(a, b), + 7 => self.xor7.add_input(a, b), + 4 => self.xor4.add_input(a, b), + _ => panic!("Invalid w"), + } + } +} + +#[derive(Clone)] +pub struct BlakeXorElements { + xor12: XorElements, + xor9: XorElements, + xor8: XorElements, + xor7: XorElements, + xor4: XorElements, +} +impl BlakeXorElements { + fn draw(channel: &mut impl Channel) -> Self { + Self { + xor12: XorElements::draw(channel), + xor9: XorElements::draw(channel), + xor8: XorElements::draw(channel), + xor7: XorElements::draw(channel), + xor4: XorElements::draw(channel), + } + } + fn dummy() -> Self { + Self { + xor12: XorElements::dummy(), + xor9: XorElements::dummy(), + xor8: XorElements::dummy(), + xor7: XorElements::dummy(), + xor4: XorElements::dummy(), + } + } + fn get(&self, w: u32) -> &XorElements { + match w { + 12 => &self.xor12, + 9 => &self.xor9, + 8 => &self.xor8, + 7 => &self.xor7, + 4 => &self.xor4, + _ => panic!("Invalid w"), + } + } +} + +/// Utility for representing a u32 as two field elements, for constraint evaluation. +#[derive(Clone, Copy, Debug)] +struct Fu32 +where + F: FieldExpOps + + Copy + + Debug + + AddAssign + + Add + + Sub + + Mul, +{ + l: F, + h: F, +} +impl Fu32 +where + F: FieldExpOps + + Copy + + Debug + + AddAssign + + Add + + Sub + + Mul, +{ + fn to_felts(self) -> [F; 2] { + [self.l, self.h] + } +} + +/// Utility for splitting a u32 into 2 field elements in trace generation. 
+fn to_felts(x: &u32x16) -> [PackedBaseField; 2] { + [ + unsafe { PackedBaseField::from_simd_unchecked(x & u32x16::splat(0xffff)) }, + unsafe { PackedBaseField::from_simd_unchecked(x >> 16) }, + ] +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/round/constraints.rs b/Stwo_wrapper/crates/prover/src/examples/blake/round/constraints.rs new file mode 100644 index 0000000..9440944 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/round/constraints.rs @@ -0,0 +1,164 @@ +use itertools::{chain, Itertools}; +use num_traits::One; + +use super::{BlakeXorElements, RoundElements}; +use crate::constraint_framework::logup::LogupAtRow; +use crate::constraint_framework::EvalAtRow; +use crate::core::fields::m31::BaseField; +use crate::core::lookups::utils::Fraction; +use crate::examples::blake::{Fu32, STATE_SIZE}; + +const INV16: BaseField = BaseField::from_u32_unchecked(1 << 15); +const TWO: BaseField = BaseField::from_u32_unchecked(2); + +pub struct BlakeRoundEval<'a, E: EvalAtRow> { + pub eval: E, + pub xor_lookup_elements: &'a BlakeXorElements, + pub round_lookup_elements: &'a RoundElements, + pub logup: LogupAtRow<2, E>, +} +impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { + pub fn eval(mut self) -> E { + let mut v: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); + let input_v = v; + let m: [Fu32; STATE_SIZE] = std::array::from_fn(|_| self.next_u32()); + + self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]); + self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]); + self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]); + self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]); + self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]); + self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]); + self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]); + self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]); + + // Yield `Round(input_v, output_v, message)`. 
+ self.logup.push_lookup( + &mut self.eval, + -E::EF::one(), + &chain![ + input_v.iter().copied().flat_map(Fu32::to_felts), + v.iter().copied().flat_map(Fu32::to_felts), + m.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(), + self.round_lookup_elements, + ); + + self.logup.finalize(&mut self.eval); + self.eval + } + fn next_u32(&mut self) -> Fu32 { + let l = self.eval.next_trace_mask(); + let h = self.eval.next_trace_mask(); + Fu32 { l, h } + } + fn g(&mut self, v: [&mut Fu32; 4], m0: Fu32, m1: Fu32) { + let [a, b, c, d] = v; + + *a = self.add3_u32_unchecked(*a, *b, m0); + *d = self.xor_rotr16_u32(*a, *d); + *c = self.add2_u32_unchecked(*c, *d); + *b = self.xor_rotr_u32(*b, *c, 12); + *a = self.add3_u32_unchecked(*a, *b, m1); + *d = self.xor_rotr_u32(*a, *d, 8); + *c = self.add2_u32_unchecked(*c, *d); + *b = self.xor_rotr_u32(*b, *c, 7); + } + + /// Adds two u32s, returning the sum. + /// Assumes a, b are properly range checked. + /// The caller is responsible for checking: + /// res.{l,h} not in [2^16, 2^17) or in [-2^16,0) + fn add2_u32_unchecked(&mut self, a: Fu32, b: Fu32) -> Fu32 { + let sl = self.eval.next_trace_mask(); + let sh = self.eval.next_trace_mask(); + + let carry_l = (a.l + b.l - sl) * E::F::from(INV16); + self.eval.add_constraint(carry_l * carry_l - carry_l); + + let carry_h = (a.h + b.h + carry_l - sh) * E::F::from(INV16); + self.eval.add_constraint(carry_h * carry_h - carry_h); + + Fu32 { l: sl, h: sh } + } + + /// Adds three u32s, returning the sum. + /// Assumes a, b, c are properly range checked. 
+ /// Caller is responsible for checking: + /// res.{l,h} not in [2^16, 3*2^16) or in [-2^17,0) + fn add3_u32_unchecked(&mut self, a: Fu32, b: Fu32, c: Fu32) -> Fu32 { + let sl = self.eval.next_trace_mask(); + let sh = self.eval.next_trace_mask(); + + let carry_l = (a.l + b.l + c.l - sl) * E::F::from(INV16); + self.eval + .add_constraint(carry_l * (carry_l - E::F::one()) * (carry_l - E::F::from(TWO))); + + let carry_h = (a.h + b.h + c.h + carry_l - sh) * E::F::from(INV16); + self.eval + .add_constraint(carry_h * (carry_h - E::F::one()) * (carry_h - E::F::from(TWO))); + + Fu32 { l: sl, h: sh } + } + + /// Splits a felt at r. + /// Caller is responsible for checking that the ranges of h * 2^r and l don't overlap. + fn split_unchecked(&mut self, a: E::F, r: u32) -> (E::F, E::F) { + let h = self.eval.next_trace_mask(); + let l = a - h * E::F::from(BaseField::from_u32_unchecked(1 << r)); + (l, h) + } + + /// Checks that a, b are in range, and computes their xor rotated right by `r` bits. + /// Guarantees that all elements are in range. + fn xor_rotr_u32(&mut self, a: Fu32, b: Fu32, r: u32) -> Fu32 { + let (all, alh) = self.split_unchecked(a.l, r); + let (ahl, ahh) = self.split_unchecked(a.h, r); + let (bll, blh) = self.split_unchecked(b.l, r); + let (bhl, bhh) = self.split_unchecked(b.h, r); + + // These also guarantee that all elements are in range. + let [xorll, xorhl] = self.xor2(r, [all, ahl], [bll, bhl]); + let [xorlh, xorhh] = self.xor2(16 - r, [alh, ahh], [blh, bhh]); + + Fu32 { + l: xorhl * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorlh, + h: xorll * E::F::from(BaseField::from_u32_unchecked(1 << (16 - r))) + xorhh, + } + } + + /// Checks that a, b are in range, and computes their xor rotated right by 16 bits. + /// Guarantees that all elements are in range. 
+ fn xor_rotr16_u32(&mut self, a: Fu32, b: Fu32) -> Fu32 { + let (all, alh) = self.split_unchecked(a.l, 8); + let (ahl, ahh) = self.split_unchecked(a.h, 8); + let (bll, blh) = self.split_unchecked(b.l, 8); + let (bhl, bhh) = self.split_unchecked(b.h, 8); + + // These also guarantee that all elements are in range. + let [xorll, xorhl] = self.xor2(8, [all, ahl], [bll, bhl]); + let [xorlh, xorhh] = self.xor2(8, [alh, ahh], [blh, bhh]); + + Fu32 { + l: xorhh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorhl, + h: xorlh * E::F::from(BaseField::from_u32_unchecked(1 << 8)) + xorll, + } + } + + /// Checks that a, b are in [0, 2^w) and computes their xor. + fn xor2(&mut self, w: u32, a: [E::F; 2], b: [E::F; 2]) -> [E::F; 2] { + // TODO: Separate lookups by w. + let c = [self.eval.next_trace_mask(), self.eval.next_trace_mask()]; + let lookup_elements = self.xor_lookup_elements.get(w); + let comb0 = lookup_elements.combine::(&[a[0], b[0], c[0]]); + let comb1 = lookup_elements.combine::(&[a[1], b[1], c[1]]); + let frac = Fraction { + numerator: comb0 + comb1, + denominator: comb0 * comb1, + }; + + self.logup.add_frac(&mut self.eval, frac); + c + } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/round/gen.rs b/Stwo_wrapper/crates/prover/src/examples/blake/round/gen.rs new file mode 100644 index 0000000..ba9933b --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/round/gen.rs @@ -0,0 +1,281 @@ +use std::simd::u32x16; +use std::vec; + +use itertools::{chain, Itertools}; +use num_traits::One; +use tracing::{span, Level}; + +use super::{BlakeXorElements, RoundElements}; +use crate::constraint_framework::logup::LogupTraceGenerator; +use crate::core::backend::simd::column::BaseColumn; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Col, Column}; +use crate::core::fields::m31::BaseField; +use 
crate::core::fields::qm31::SecureField; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::ColumnVec; +use crate::examples::blake::round::blake_round_info; +use crate::examples::blake::{to_felts, XorAccums, N_ROUND_INPUT_FELTS, STATE_SIZE}; + +pub struct BlakeRoundLookupData { + /// A vector of (w, [a_col, b_col, c_col]) for each xor lookup. + /// w is the xor width. c_col is the xor col of a_col and b_col. + xor_lookups: Vec<(u32, [BaseColumn; 3])>, + /// A column of round lookup values (v_in, v_out, m). + round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], +} + +pub struct TraceGenerator { + log_size: u32, + trace: Vec, + xor_lookups: Vec<(u32, [BaseColumn; 3])>, + round_lookup: [BaseColumn; N_ROUND_INPUT_FELTS], +} +impl TraceGenerator { + fn new(log_size: u32) -> Self { + assert!(log_size >= LOG_N_LANES); + let trace = (0..blake_round_info().mask_offsets[0].len()) + .map(|_| unsafe { Col::::uninitialized(1 << log_size) }) + .collect_vec(); + Self { + log_size, + trace, + xor_lookups: vec![], + round_lookup: std::array::from_fn(|_| unsafe { + BaseColumn::uninitialized(1 << log_size) + }), + } + } + + fn gen_row(&mut self, vec_row: usize) -> TraceGeneratorRow<'_> { + TraceGeneratorRow { + gen: self, + col_index: 0, + vec_row, + xor_lookups_index: 0, + } + } +} + +/// Trace generator for the constraints defined at [`super::constraints::BlakeRoundEval`] +struct TraceGeneratorRow<'a> { + gen: &'a mut TraceGenerator, + col_index: usize, + vec_row: usize, + xor_lookups_index: usize, +} +impl<'a> TraceGeneratorRow<'a> { + fn append_felt(&mut self, val: u32x16) { + self.gen.trace[self.col_index].data[self.vec_row] = + unsafe { PackedBaseField::from_simd_unchecked(val) }; + self.col_index += 1; + } + + fn append_u32(&mut self, val: u32x16) { + self.append_felt(val & u32x16::splat(0xffff)); + self.append_felt(val >> 16); + } + + fn generate(&mut self, mut v: [u32x16; 16], m: [u32x16; 16]) { + let 
input_v = v; + v.iter().for_each(|s| { + self.append_u32(*s); + }); + m.iter().for_each(|s| { + self.append_u32(*s); + }); + + self.g(v.get_many_mut([0, 4, 8, 12]).unwrap(), m[0], m[1]); + self.g(v.get_many_mut([1, 5, 9, 13]).unwrap(), m[2], m[3]); + self.g(v.get_many_mut([2, 6, 10, 14]).unwrap(), m[4], m[5]); + self.g(v.get_many_mut([3, 7, 11, 15]).unwrap(), m[6], m[7]); + self.g(v.get_many_mut([0, 5, 10, 15]).unwrap(), m[8], m[9]); + self.g(v.get_many_mut([1, 6, 11, 12]).unwrap(), m[10], m[11]); + self.g(v.get_many_mut([2, 7, 8, 13]).unwrap(), m[12], m[13]); + self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]); + + chain![input_v.iter(), v.iter(), m.iter()] + .flat_map(to_felts) + .enumerate() + .for_each(|(i, felt)| self.gen.round_lookup[i].data[self.vec_row] = felt); + } + + fn g(&mut self, v: [&mut u32x16; 4], m0: u32x16, m1: u32x16) { + let [a, b, c, d] = v; + + *a = self.add3_u32s(*a, *b, m0); + *d = self.xor_rotr16_u32(*a, *d); + *c = self.add2_u32s(*c, *d); + *b = self.xor_rotr_u32(*b, *c, 12); + *a = self.add3_u32s(*a, *b, m1); + *d = self.xor_rotr_u32(*a, *d, 8); + *c = self.add2_u32s(*c, *d); + *b = self.xor_rotr_u32(*b, *c, 7); + } + + /// Adds two u32s, returning the sum. + fn add2_u32s(&mut self, a: u32x16, b: u32x16) -> u32x16 { + let s = a + b; + self.append_u32(s); + s + } + + /// Adds three u32s, returning the sum. + fn add3_u32s(&mut self, a: u32x16, b: u32x16, c: u32x16) -> u32x16 { + let s = a + b + c; + self.append_u32(s); + s + } + + /// Splits a felt at r. + fn split(&mut self, a: u32x16, r: u32) -> (u32x16, u32x16) { + let h = a >> r; + let l = a & u32x16::splat((1 << r) - 1); + self.append_felt(h); + (l, h) + } + + /// Checks that a, b are in range, and computes their xor rotated right by `r` bits. 
+ fn xor_rotr_u32(&mut self, a: u32x16, b: u32x16, r: u32) -> u32x16 { + let c = a ^ b; + let cr = (c >> r) | (c << (32 - r)); + + let (all, alh) = self.split(a & u32x16::splat(0xffff), r); + let (ahl, ahh) = self.split(a >> 16, r); + let (bll, blh) = self.split(b & u32x16::splat(0xffff), r); + let (bhl, bhh) = self.split(b >> 16, r); + + self.xor(r, all, bll); + self.xor(r, ahl, bhl); + self.xor(16 - r, alh, blh); + self.xor(16 - r, ahh, bhh); + + cr + } + + /// Checks that a, b are in range, and computes their xor rotated right by 16 bits. + fn xor_rotr16_u32(&mut self, a: u32x16, b: u32x16) -> u32x16 { + let c = a ^ b; + let cr = (c >> 16) | (c << 16); + + let (all, alh) = self.split(a & u32x16::splat(0xffff), 8); + let (ahl, ahh) = self.split(a >> 16, 8); + let (bll, blh) = self.split(b & u32x16::splat(0xffff), 8); + let (bhl, bhh) = self.split(b >> 16, 8); + + self.xor(8, all, bll); + self.xor(8, ahl, bhl); + self.xor(8, alh, blh); + self.xor(8, ahh, bhh); + + cr + } + + /// Checks that a, b are in [0, 2^w) and computes their xor. + /// a,b,a^b are assumed to fit in a single felt. 
+ fn xor(&mut self, w: u32, a: u32x16, b: u32x16) -> u32x16 { + let c = a ^ b; + self.append_felt(c); + if self.gen.xor_lookups.len() <= self.xor_lookups_index { + self.gen.xor_lookups.push(( + w, + std::array::from_fn(|_| unsafe { + BaseColumn::uninitialized(1 << self.gen.log_size) + }), + )); + } + self.gen.xor_lookups[self.xor_lookups_index].1[0].data[self.vec_row] = + unsafe { PackedBaseField::from_simd_unchecked(a) }; + self.gen.xor_lookups[self.xor_lookups_index].1[1].data[self.vec_row] = + unsafe { PackedBaseField::from_simd_unchecked(b) }; + self.gen.xor_lookups[self.xor_lookups_index].1[2].data[self.vec_row] = + unsafe { PackedBaseField::from_simd_unchecked(c) }; + self.xor_lookups_index += 1; + c + } +} + +#[derive(Copy, Clone, Default)] +pub struct BlakeRoundInput { + pub v: [u32x16; STATE_SIZE], + pub m: [u32x16; STATE_SIZE], +} + +pub fn generate_trace( + log_size: u32, + inputs: &[BlakeRoundInput], + xor_accum: &mut XorAccums, +) -> ( + ColumnVec>, + BlakeRoundLookupData, +) { + let _span = span!(Level::INFO, "Round Generation").entered(); + let mut generator = TraceGenerator::new(log_size); + + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let mut row_gen = generator.gen_row(vec_row); + let BlakeRoundInput { v, m } = inputs.get(vec_row).copied().unwrap_or_default(); + row_gen.generate(v, m); + for (w, [a, b, _c]) in &generator.xor_lookups { + let a = a.data[vec_row].into_simd(); + let b = b.data[vec_row].into_simd(); + xor_accum.add_input(*w, a, b); + } + } + let domain = CanonicCoset::new(log_size).circle_domain(); + ( + generator + .trace + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec(), + BlakeRoundLookupData { + xor_lookups: generator.xor_lookups, + round_lookup: generator.round_lookup, + }, + ) +} + +pub fn generate_interaction_trace( + log_size: u32, + lookup_data: BlakeRoundLookupData, + xor_lookup_elements: &BlakeXorElements, + round_lookup_elements: &RoundElements, +) -> ( + ColumnVec>, + 
SecureField, +) { + let _span = span!(Level::INFO, "Generate round interaction trace").entered(); + let mut logup_gen = LogupTraceGenerator::new(log_size); + + for [(w0, l0), (w1, l1)] in lookup_data.xor_lookups.array_chunks::<2>() { + let mut col_gen = logup_gen.new_col(); + + #[allow(clippy::needless_range_loop)] + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let p0: PackedSecureField = xor_lookup_elements + .get(*w0) + .combine(&l0.each_ref().map(|l| l.data[vec_row])); + let p1: PackedSecureField = xor_lookup_elements + .get(*w1) + .combine(&l1.each_ref().map(|l| l.data[vec_row])); + col_gen.write_frac(vec_row, p0 + p1, p0 * p1); + } + + col_gen.finalize_col(); + } + + let mut col_gen = logup_gen.new_col(); + #[allow(clippy::needless_range_loop)] + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let p = round_lookup_elements + .combine(&lookup_data.round_lookup.each_ref().map(|l| l.data[vec_row])); + col_gen.write_frac(vec_row, -PackedSecureField::one(), p); + } + col_gen.finalize_col(); + + logup_gen.finalize() +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/round/mod.rs b/Stwo_wrapper/crates/prover/src/examples/blake/round/mod.rs new file mode 100644 index 0000000..cf83113 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/round/mod.rs @@ -0,0 +1,110 @@ +mod constraints; +mod gen; + +pub use gen::{generate_interaction_trace, generate_trace, BlakeRoundInput}; +use num_traits::Zero; + +use super::{BlakeXorElements, N_ROUND_INPUT_FELTS}; +use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; +use crate::core::fields::qm31::SecureField; + +pub type BlakeRoundComponent = FrameworkComponent; + +pub type RoundElements = LookupElements; + +pub struct BlakeRoundEval { + pub log_size: u32, + pub xor_lookup_elements: BlakeXorElements, + pub round_lookup_elements: RoundElements, + pub claimed_sum: SecureField, +} + 
+impl FrameworkEval for BlakeRoundEval { + fn log_size(&self) -> u32 { + self.log_size + } + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size + 1 + } + fn evaluate(&self, eval: E) -> E { + let blake_eval = constraints::BlakeRoundEval { + eval, + xor_lookup_elements: &self.xor_lookup_elements, + round_lookup_elements: &self.round_lookup_elements, + logup: LogupAtRow::new(1, self.claimed_sum, self.log_size), + }; + blake_eval.eval() + } +} + +pub fn blake_round_info() -> InfoEvaluator { + let component = BlakeRoundEval { + log_size: 1, + xor_lookup_elements: BlakeXorElements::dummy(), + round_lookup_elements: RoundElements::dummy(), + claimed_sum: SecureField::zero(), + }; + component.evaluate(InfoEvaluator::default()) +} + +#[cfg(test)] +mod tests { + use std::simd::Simd; + + use itertools::Itertools; + + use crate::constraint_framework::constant_columns::gen_is_first; + use crate::constraint_framework::FrameworkEval; + use crate::core::poly::circle::CanonicCoset; + use crate::examples::blake::round::r#gen::{ + generate_interaction_trace, generate_trace, BlakeRoundInput, + }; + use crate::examples::blake::round::{BlakeRoundEval, RoundElements}; + use crate::examples::blake::{BlakeXorElements, XorAccums}; + + #[test] + fn test_blake_round() { + use crate::core::pcs::TreeVec; + + const LOG_SIZE: u32 = 10; + + let mut xor_accum = XorAccums::default(); + let (trace, lookup_data) = generate_trace( + LOG_SIZE, + &(0..(1 << LOG_SIZE)) + .map(|_| BlakeRoundInput { + v: std::array::from_fn(|i| Simd::splat(i as u32)), + m: std::array::from_fn(|i| Simd::splat((i + 1) as u32)), + }) + .collect_vec(), + &mut xor_accum, + ); + + let xor_lookup_elements = BlakeXorElements::dummy(); + let round_lookup_elements = RoundElements::dummy(); + let (interaction_trace, claimed_sum) = generate_interaction_trace( + LOG_SIZE, + lookup_data, + &xor_lookup_elements, + &round_lookup_elements, + ); + + let trace = TreeVec::new(vec![trace, interaction_trace, 
vec![gen_is_first(LOG_SIZE)]]); + let trace_polys = trace.map_cols(|c| c.interpolate()); + + let component = BlakeRoundEval { + log_size: LOG_SIZE, + xor_lookup_elements, + round_lookup_elements, + claimed_sum, + }; + crate::constraint_framework::assert_constraints( + &trace_polys, + CanonicCoset::new(LOG_SIZE), + |eval| { + component.evaluate(eval); + }, + ) + } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/constraints.rs b/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/constraints.rs new file mode 100644 index 0000000..63b3cf6 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -0,0 +1,64 @@ +use itertools::{chain, Itertools}; +use num_traits::{One, Zero}; + +use super::BlakeElements; +use crate::constraint_framework::logup::LogupAtRow; +use crate::constraint_framework::EvalAtRow; +use crate::core::vcs::blake2s_ref::SIGMA; +use crate::examples::blake::round::RoundElements; +use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE}; + +pub fn eval_blake_scheduler_constraints( + eval: &mut E, + blake_lookup_elements: &BlakeElements, + round_lookup_elements: &RoundElements, + mut logup: LogupAtRow<2, E>, +) { + let messages: [Fu32; STATE_SIZE] = std::array::from_fn(|_| eval_next_u32(eval)); + let states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1] = + std::array::from_fn(|_| std::array::from_fn(|_| eval_next_u32(eval))); + + // Schedule. + for i in 0..N_ROUNDS { + let input_state = &states[i]; + let output_state = &states[i + 1]; + let round_messages = SIGMA[i].map(|j| messages[j as usize]); + // Use triplet in round lookup. 
+ logup.push_lookup( + eval, + E::EF::one(), + &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + round_messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(), + round_lookup_elements, + ) + } + + let input_state = &states[0]; + let output_state = &states[N_ROUNDS]; + + // TODO(spapini): Support multiplicities. + // TODO(spapini): Change to -1. + logup.push_lookup( + eval, + E::EF::zero(), + &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(), + blake_lookup_elements, + ); + + logup.finalize(eval); +} + +fn eval_next_u32(eval: &mut E) -> Fu32 { + let l = eval.next_trace_mask(); + let h = eval.next_trace_mask(); + Fu32 { l, h } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/gen.rs b/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/gen.rs new file mode 100644 index 0000000..0581b2f --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/gen.rs @@ -0,0 +1,171 @@ +use std::simd::u32x16; + +use itertools::{chain, Itertools}; +use num_traits::Zero; +use tracing::{span, Level}; + +use super::{blake_scheduler_info, BlakeElements}; +use crate::constraint_framework::logup::LogupTraceGenerator; +use crate::core::backend::simd::column::BaseColumn; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::{blake2s, SimdBackend}; +use crate::core::backend::Column; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::ColumnVec; +use crate::examples::blake::round::{BlakeRoundInput, RoundElements}; +use crate::examples::blake::{to_felts, N_ROUNDS, 
N_ROUND_INPUT_FELTS, STATE_SIZE}; + +#[derive(Copy, Clone, Default)] +pub struct BlakeInput { + pub v: [u32x16; STATE_SIZE], + pub m: [u32x16; STATE_SIZE], +} + +pub struct BlakeSchedulerLookupData { + pub round_lookups: [[BaseColumn; N_ROUND_INPUT_FELTS]; N_ROUNDS], + pub blake_lookups: [BaseColumn; N_ROUND_INPUT_FELTS], +} +impl BlakeSchedulerLookupData { + fn new(log_size: u32) -> Self { + Self { + round_lookups: std::array::from_fn(|_| { + std::array::from_fn(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) + }), + blake_lookups: std::array::from_fn(|_| unsafe { + BaseColumn::uninitialized(1 << log_size) + }), + } + } +} + +pub fn gen_trace( + log_size: u32, + inputs: &[BlakeInput], +) -> ( + ColumnVec>, + BlakeSchedulerLookupData, + Vec, +) { + let _span = span!(Level::INFO, "Scheduler Generation").entered(); + let mut lookup_data = BlakeSchedulerLookupData::new(log_size); + let mut round_inputs = Vec::with_capacity(inputs.len() * N_ROUNDS); + + let mut trace = (0..blake_scheduler_info().mask_offsets[0].len()) + .map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) + .collect_vec(); + + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let mut col_index = 0; + + let mut write_u32_array = |x: [u32x16; STATE_SIZE], col_index: &mut usize| { + x.iter().for_each(|x| { + to_felts(x).iter().for_each(|x| { + trace[*col_index].data[vec_row] = *x; + *col_index += 1; + }); + }); + }; + + let BlakeInput { mut v, m } = inputs.get(vec_row).copied().unwrap_or_default(); + let initial_v = v; + write_u32_array(m, &mut col_index); + write_u32_array(v, &mut col_index); + + for r in 0..N_ROUNDS { + let prev_v = v; + blake2s::round(&mut v, m, r); + write_u32_array(v, &mut col_index); + + let round_m = blake2s::SIGMA[r].map(|i| m[i as usize]); + round_inputs.push(BlakeRoundInput { + v: prev_v, + m: round_m, + }); + + chain![ + prev_v.iter().flat_map(to_felts), + v.iter().flat_map(to_felts), + round_m.iter().flat_map(to_felts) + ] + .enumerate() + .for_each(|(i, 
val)| lookup_data.round_lookups[r][i].data[vec_row] = val); + } + + chain![ + initial_v.iter().flat_map(to_felts), + v.iter().flat_map(to_felts), + m.iter().flat_map(to_felts) + ] + .enumerate() + .for_each(|(i, val)| lookup_data.blake_lookups[i].data[vec_row] = val); + } + + let domain = CanonicCoset::new(log_size).circle_domain(); + let trace = trace + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec(); + + (trace, lookup_data, round_inputs) +} +pub fn gen_interaction_trace( + log_size: u32, + lookup_data: BlakeSchedulerLookupData, + round_lookup_elements: &RoundElements, + blake_lookup_elements: &BlakeElements, +) -> ( + ColumnVec>, + SecureField, +) { + let _span = span!(Level::INFO, "Generate scheduler interaction trace").entered(); + + let mut logup_gen = LogupTraceGenerator::new(log_size); + + for [l0, l1] in lookup_data.round_lookups.array_chunks::<2>() { + let mut col_gen = logup_gen.new_col(); + + #[allow(clippy::needless_range_loop)] + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let p0: PackedSecureField = + round_lookup_elements.combine(&l0.each_ref().map(|l| l.data[vec_row])); + let p1: PackedSecureField = + round_lookup_elements.combine(&l1.each_ref().map(|l| l.data[vec_row])); + #[allow(clippy::eq_op)] + col_gen.write_frac(vec_row, p0 + p1, p0 * p1); + } + + col_gen.finalize_col(); + } + + // Last pair. If the number of round is odd (as in blake3), we combine that last round lookup + // with the entire blake lookup. 
+ let mut col_gen = logup_gen.new_col(); + #[allow(clippy::needless_range_loop)] + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let p_blake: PackedSecureField = blake_lookup_elements.combine( + &lookup_data + .blake_lookups + .each_ref() + .map(|l| l.data[vec_row]), + ); + if N_ROUNDS % 2 == 1 { + let p_round: PackedSecureField = round_lookup_elements.combine( + &lookup_data.round_lookups[N_ROUNDS - 1] + .each_ref() + .map(|l| l.data[vec_row]), + ); + // TODO(spapini): Change blake numerator to p_blake - p_round. + col_gen.write_frac(vec_row, p_blake, p_round * p_blake); + } else { + // TODO(spapini): Change numerator to -1. + col_gen.write_frac(vec_row, PackedSecureField::zero(), p_blake); + } + } + col_gen.finalize_col(); + + logup_gen.finalize() +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/mod.rs b/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/mod.rs new file mode 100644 index 0000000..e8a8c32 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/scheduler/mod.rs @@ -0,0 +1,106 @@ +mod constraints; +mod gen; + +use constraints::eval_blake_scheduler_constraints; +pub use gen::{gen_interaction_trace, gen_trace, BlakeInput}; +use num_traits::Zero; + +use super::round::RoundElements; +use super::N_ROUND_INPUT_FELTS; +use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; +use crate::core::fields::qm31::SecureField; + +pub type BlakeSchedulerComponent = FrameworkComponent; + +pub type BlakeElements = LookupElements; + +pub struct BlakeSchedulerEval { + pub log_size: u32, + pub blake_lookup_elements: BlakeElements, + pub round_lookup_elements: RoundElements, + pub claimed_sum: SecureField, +} +impl FrameworkEval for BlakeSchedulerEval { + fn log_size(&self) -> u32 { + self.log_size + } + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size + 1 + } + fn evaluate(&self, mut eval: E) -> E { 
+ eval_blake_scheduler_constraints( + &mut eval, + &self.blake_lookup_elements, + &self.round_lookup_elements, + LogupAtRow::new(1, self.claimed_sum, self.log_size), + ); + eval + } +} + +pub fn blake_scheduler_info() -> InfoEvaluator { + let component = BlakeSchedulerEval { + log_size: 1, + blake_lookup_elements: BlakeElements::dummy(), + round_lookup_elements: RoundElements::dummy(), + claimed_sum: SecureField::zero(), + }; + component.evaluate(InfoEvaluator::default()) +} + +#[cfg(test)] +mod tests { + use std::simd::Simd; + + use itertools::Itertools; + + use crate::constraint_framework::FrameworkEval; + use crate::core::poly::circle::CanonicCoset; + use crate::examples::blake::round::RoundElements; + use crate::examples::blake::scheduler::r#gen::{gen_interaction_trace, gen_trace, BlakeInput}; + use crate::examples::blake::scheduler::{BlakeElements, BlakeSchedulerEval}; + + #[test] + fn test_blake_scheduler() { + use crate::core::pcs::TreeVec; + + const LOG_SIZE: u32 = 10; + + let (trace, lookup_data, _round_inputs) = gen_trace( + LOG_SIZE, + &(0..(1 << LOG_SIZE)) + .map(|_| BlakeInput { + v: std::array::from_fn(|i| Simd::splat(i as u32)), + m: std::array::from_fn(|i| Simd::splat((i + 1) as u32)), + }) + .collect_vec(), + ); + + let round_lookup_elements = RoundElements::dummy(); + let blake_lookup_elements = BlakeElements::dummy(); + let (interaction_trace, claimed_sum) = gen_interaction_trace( + LOG_SIZE, + lookup_data, + &round_lookup_elements, + &blake_lookup_elements, + ); + + let trace = TreeVec::new(vec![trace, interaction_trace]); + let trace_polys = trace.map_cols(|c| c.interpolate()); + + let component = BlakeSchedulerEval { + log_size: LOG_SIZE, + blake_lookup_elements, + round_lookup_elements, + claimed_sum, + }; + crate::constraint_framework::assert_constraints( + &trace_polys, + CanonicCoset::new(LOG_SIZE), + |eval| { + component.evaluate(eval); + }, + ) + } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/constraints.rs 
b/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/constraints.rs new file mode 100644 index 0000000..00a6583 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -0,0 +1,52 @@ +use super::{limb_bits, XorElements}; +use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; +use crate::constraint_framework::EvalAtRow; +use crate::core::fields::m31::BaseField; + +/// Constraints for the xor table. +pub struct XorTableEval<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> { + pub eval: E, + pub lookup_elements: &'a XorElements, + pub logup: LogupAtRow<2, E>, +} +impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> + XorTableEval<'a, E, ELEM_BITS, EXPAND_BITS> +{ + pub fn eval(mut self) -> E { + // al, bl are the constant columns for the inputs: All pairs of elements in [0, + // 2^LIMB_BITS). + // cl is the constant column for the xor: al ^ bl. + let [al] = self.eval.next_interaction_mask(2, [0]); + let [bl] = self.eval.next_interaction_mask(2, [0]); + let [cl] = self.eval.next_interaction_mask(2, [0]); + for i in 0..1 << EXPAND_BITS { + for j in 0..1 << EXPAND_BITS { + let multiplicity = self.eval.next_trace_mask(); + + let a = al + + E::F::from(BaseField::from_u32_unchecked( + i << limb_bits::(), + )); + let b = bl + + E::F::from(BaseField::from_u32_unchecked( + j << limb_bits::(), + )); + let c = cl + + E::F::from(BaseField::from_u32_unchecked( + (i ^ j) << limb_bits::(), + )); + + // Add with negative multiplicity. Consumers should lookup with positive + // multiplicity. 
+ self.logup.push_lookup( + &mut self.eval, + (-multiplicity).into(), + &[a, b, c], + self.lookup_elements, + ); + } + } + self.logup.finalize(&mut self.eval); + self.eval + } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/gen.rs b/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/gen.rs new file mode 100644 index 0000000..195a6ca --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/gen.rs @@ -0,0 +1,168 @@ +use std::simd::u32x16; + +use itertools::Itertools; +use tracing::{span, Level}; + +use super::{column_bits, limb_bits, XorAccumulator, XorElements}; +use crate::constraint_framework::logup::{LogupTraceGenerator, LookupElements}; +use crate::core::backend::simd::column::BaseColumn; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::ColumnVec; + +pub struct XorTableLookupData { + pub xor_accum: XorAccumulator, +} + +pub fn generate_trace( + xor_accum: XorAccumulator, +) -> ( + ColumnVec>, + XorTableLookupData, +) { + ( + xor_accum + .mults + .iter() + .map(|mult| { + CircleEvaluation::new( + CanonicCoset::new(column_bits::()).circle_domain(), + mult.clone(), + ) + }) + .collect_vec(), + XorTableLookupData { xor_accum }, + ) +} + +/// Generates the interaction trace for the xor table. +/// Returns the interaction trace, the constant trace, and the claimed sum. 
+#[allow(clippy::type_complexity)] +pub fn generate_interaction_trace( + lookup_data: XorTableLookupData, + lookup_elements: &XorElements, +) -> ( + ColumnVec>, + SecureField, +) { + let limb_bits = limb_bits::(); + let _span = span!(Level::INFO, "Xor interaction trace").entered(); + let offsets_vec = u32x16::from_array(std::array::from_fn(|i| i as u32)); + let mut logup_gen = LogupTraceGenerator::new(column_bits::()); + + // Iterate each pair of columns, to batch their lookup together. + // There are 2^(2*EXPAND_BITS) column, for each combination of ah, bh. + let mut iter = lookup_data + .xor_accum + .mults + .iter() + .enumerate() + .array_chunks::<2>(); + for [(i0, mults0), (i1, mults1)] in &mut iter { + let mut col_gen = logup_gen.new_col(); + + // Extract ah, bh from column index. + let ah0 = i0 as u32 >> EXPAND_BITS; + let bh0 = i0 as u32 & ((1 << EXPAND_BITS) - 1); + let ah1 = i1 as u32 >> EXPAND_BITS; + let bh1 = i1 as u32 & ((1 << EXPAND_BITS) - 1); + + // Each column has 2^(2*LIMB_BITS) rows, packed in N_LANES. + #[allow(clippy::needless_range_loop)] + for vec_row in 0..(1 << (column_bits::() - LOG_N_LANES)) { + // vec_row is LIMB_BITS of al and LIMB_BITS - LOG_N_LANES of bl. + // Extract al, blh from vec_row. + let al = vec_row >> (limb_bits - LOG_N_LANES); + let blh = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1); + + // Construct the 3 vectors a, b, c. + let a0 = u32x16::splat((ah0 << limb_bits) | al); + let a1 = u32x16::splat((ah1 << limb_bits) | al); + // bll is just the consecutive numbers 0 .. N_LANES-1. 
+ let b0 = u32x16::splat((bh0 << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec; + let b1 = u32x16::splat((bh1 << limb_bits) | (blh << LOG_N_LANES)) | offsets_vec; + + let c0 = a0 ^ b0; + let c1 = a1 ^ b1; + + let p0: PackedSecureField = lookup_elements + .combine(&[a0, b0, c0].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) })); + let p1: PackedSecureField = lookup_elements + .combine(&[a1, b1, c1].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) })); + + let num = p1 * mults0.data[vec_row as usize] + p0 * mults1.data[vec_row as usize]; + let denom = p0 * p1; + col_gen.write_frac(vec_row as usize, -num, denom); + } + col_gen.finalize_col(); + } + + // If there is an odd number of lookup expressions, handle the last one. + if let Some(rem) = iter.into_remainder() { + if let Some((i, mults)) = rem.collect_vec().pop() { + let mut col_gen = logup_gen.new_col(); + let ah = i as u32 >> EXPAND_BITS; + let bh = i as u32 & ((1 << EXPAND_BITS) - 1); + + #[allow(clippy::needless_range_loop)] + for vec_row in 0..(1 << (column_bits::() - LOG_N_LANES)) { + // vec_row is LIMB_BITS of a, and LIMB_BITS - LOG_N_LANES of b. + let al = vec_row >> (limb_bits - LOG_N_LANES); + let a = u32x16::splat((ah << limb_bits) | al); + let bm = vec_row & ((1 << (limb_bits - LOG_N_LANES)) - 1); + let b = u32x16::splat((bh << limb_bits) | (bm << LOG_N_LANES)) | offsets_vec; + + let c = a ^ b; + + let p: PackedSecureField = lookup_elements.combine( + &[a, b, c].map(|x| unsafe { PackedBaseField::from_simd_unchecked(x) }), + ); + + let num = mults.data[vec_row as usize]; + let denom = p; + col_gen.write_frac(vec_row as usize, PackedSecureField::from(-num), denom); + } + col_gen.finalize_col(); + } + } + + let (interaction_trace, claimed_sum) = logup_gen.finalize(); + (interaction_trace, claimed_sum) +} + +/// Generates the constant trace for the xor table. +/// Returns the constant trace, the constant trace, and the claimed sum. 
+#[allow(clippy::type_complexity)] +pub fn generate_constant_trace( +) -> ColumnVec> { + let limb_bits = limb_bits::(); + let _span = span!(Level::INFO, "Xor constant trace").entered(); + + // Generate the constant columns. In reality, these should be generated before the proof + // even began. + let a_col: BaseColumn = (0..(1 << (column_bits::()))) + .map(|i| BaseField::from_u32_unchecked((i >> limb_bits) as u32)) + .collect(); + let b_col: BaseColumn = (0..(1 << (column_bits::()))) + .map(|i| BaseField::from_u32_unchecked((i & ((1 << limb_bits) - 1)) as u32)) + .collect(); + let c_col: BaseColumn = (0..(1 << (column_bits::()))) + .map(|i| { + BaseField::from_u32_unchecked(((i >> limb_bits) ^ (i & ((1 << limb_bits) - 1))) as u32) + }) + .collect(); + + [a_col, b_col, c_col] + .map(|x| { + CircleEvaluation::new( + CanonicCoset::new(column_bits::()).circle_domain(), + x, + ) + }) + .to_vec() +} diff --git a/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/mod.rs b/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/mod.rs new file mode 100644 index 0000000..877a651 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/blake/xor_table/mod.rs @@ -0,0 +1,158 @@ +#![allow(unused)] +//! Xor table component. +//! Generic on `ELEM_BITS` and `EXPAND_BITS`. +//! The table has all triplets of (a, b, a^b), where a, b are in the range [0,2^ELEM_BITS). +//! a,b are split into high and low parts, of size `EXPAND_BITS` and `ELEM_BITS - EXPAND_BITS` +//! respectively. +//! The component itself will hold 2^(2*EXPAND_BITS) multiplicity columns, each of size +//! 2^(ELEM_BITS - EXPAND_BITS). +//! The constant columns correspond only to the smaller table of the lower `ELEM_BITS - EXPAND_BITS` +//! xors: (a_l, b_l, a_l^b_l). +//! The rest of the lookups are computed based on these constant columns. 
+ +mod constraints; +mod gen; + +use std::simd::u32x16; + +use itertools::Itertools; +use num_traits::Zero; +pub use r#gen::{generate_constant_trace, generate_interaction_trace, generate_trace}; + +use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; +use crate::core::backend::simd::column::BaseColumn; +use crate::core::backend::Column; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::{TreeSubspan, TreeVec}; + +pub fn trace_sizes() -> TreeVec> { + let component = XorTableEval:: { + lookup_elements: LookupElements::<3>::dummy(), + claimed_sum: SecureField::zero(), + }; + let info = component.evaluate(InfoEvaluator::default()); + info.mask_offsets + .as_cols_ref() + .map_cols(|_| column_bits::()) +} + +const fn limb_bits() -> u32 { + ELEM_BITS - EXPAND_BITS +} +pub const fn column_bits() -> u32 { + 2 * limb_bits::() +} + +/// Accumulator that keeps track of the number of times each input has been used. +pub struct XorAccumulator { + /// 2^(2*EXPAND_BITS) multiplicity columns. Index (al, bl) of column (ah, bh) is the number of + /// times ah||al ^ bh||bl has been used. + pub mults: Vec, +} +impl Default + for XorAccumulator +{ + fn default() -> Self { + Self { + mults: (0..(1 << (2 * EXPAND_BITS))) + .map(|_| BaseColumn::zeros(1 << column_bits::())) + .collect_vec(), + } + } +} +impl XorAccumulator { + pub fn add_input(&mut self, a: u32x16, b: u32x16) { + // Split a and b into high and low parts, according to ELEMENT_BITS and EXPAND_BITS. + // The high part is the index of the multiplicity column. + // The low part is the index of the element in that column. 
+ let al = a & u32x16::splat((1 << limb_bits::()) - 1); + let ah = a >> limb_bits::(); + let bl = b & u32x16::splat((1 << limb_bits::()) - 1); + let bh = b >> limb_bits::(); + let column_idx = (ah << EXPAND_BITS) + bh; + let offset = (al << limb_bits::()) + bl; + + // Since the indices may collide, we cannot use scatter simd operations here. + // Instead, loop over packed values. + for (column_idx, offset) in column_idx.as_array().iter().zip(offset.as_array().iter()) { + self.mults[*column_idx as usize].as_mut_slice()[*offset as usize].0 += 1; + } + } +} + +/// Component that evaluates the xor table. +pub type XorTableComponent = + FrameworkComponent>; + +pub type XorElements = LookupElements<3>; + +/// Evaluates the xor table. +pub struct XorTableEval { + pub lookup_elements: XorElements, + pub claimed_sum: SecureField, +} + +impl FrameworkEval + for XorTableEval +{ + fn log_size(&self) -> u32 { + column_bits::() + } + fn max_constraint_log_degree_bound(&self) -> u32 { + column_bits::() + 1 + } + fn evaluate(&self, mut eval: E) -> E { + let xor_eval = constraints::XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> { + eval, + lookup_elements: &self.lookup_elements, + logup: LogupAtRow::new(1, self.claimed_sum, self.log_size()), + }; + xor_eval.eval() + } +} + +#[cfg(test)] +mod tests { + use std::simd::u32x16; + + use crate::constraint_framework::logup::LookupElements; + use crate::constraint_framework::{assert_constraints, FrameworkEval}; + use crate::core::poly::circle::CanonicCoset; + use crate::examples::blake::xor_table::r#gen::{ + generate_constant_trace, generate_interaction_trace, generate_trace, + }; + use crate::examples::blake::xor_table::{column_bits, XorAccumulator, XorTableEval}; + + #[test] + fn test_xor_table() { + use crate::core::pcs::TreeVec; + + const ELEM_BITS: u32 = 9; + const EXPAND_BITS: u32 = 2; + + let mut xor_accum = XorAccumulator::::default(); + xor_accum.add_input(u32x16::splat(1), u32x16::splat(2)); + + let (trace, lookup_data) = 
generate_trace(xor_accum); + let lookup_elements = crate::examples::blake::xor_table::XorElements::dummy(); + let (interaction_trace, claimed_sum) = + generate_interaction_trace(lookup_data, &lookup_elements); + let constant_trace = generate_constant_trace::(); + + let trace = TreeVec::new(vec![trace, interaction_trace, constant_trace]); + let trace_polys = trace.map_cols(|c| c.interpolate()); + + let component = XorTableEval:: { + lookup_elements, + claimed_sum, + }; + assert_constraints( + &trace_polys, + CanonicCoset::new(column_bits::()), + |eval| { + component.evaluate(eval); + }, + ) + } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/mod.rs b/Stwo_wrapper/crates/prover/src/examples/mod.rs new file mode 100644 index 0000000..330662d --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/mod.rs @@ -0,0 +1,5 @@ +pub mod blake; +pub mod plonk; +pub mod poseidon; +pub mod wide_fibonacci; +pub mod xor; diff --git a/Stwo_wrapper/crates/prover/src/examples/plonk/mod.rs b/Stwo_wrapper/crates/prover/src/examples/plonk/mod.rs new file mode 100644 index 0000000..58248a0 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/plonk/mod.rs @@ -0,0 +1,300 @@ +use itertools::{chain, Itertools}; +use num_traits::One; +use tracing::{span, Level}; + +use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; +use crate::constraint_framework::{ + assert_constraints, EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator, +}; +use crate::core::backend::simd::column::BaseColumn; +use crate::core::backend::simd::m31::LOG_N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Column; +use crate::core::channel::Blake2sChannel; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeSubspan}; +use crate::core::poly::circle::{CanonicCoset, 
CircleEvaluation, PolyOps}; +use crate::core::poly::BitReversedOrder; +use crate::core::prover::{prove, StarkProof}; +use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; +use crate::core::ColumnVec; + +pub type PlonkComponent = FrameworkComponent; + +#[derive(Clone)] +pub struct PlonkEval { + pub log_n_rows: u32, + pub lookup_elements: LookupElements<2>, + pub claimed_sum: SecureField, + pub base_trace_location: TreeSubspan, + pub interaction_trace_location: TreeSubspan, + pub constants_trace_location: TreeSubspan, +} + +impl FrameworkEval for PlonkEval { + fn log_size(&self) -> u32 { + self.log_n_rows + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_n_rows + 1 + } + + fn evaluate(&self, mut eval: E) -> E { + let mut logup = LogupAtRow::<2, _>::new(1, self.claimed_sum, self.log_n_rows); + + let [a_wire] = eval.next_interaction_mask(2, [0]); + let [b_wire] = eval.next_interaction_mask(2, [0]); + // Note: c_wire could also be implicit: (self.eval.point() - M31_CIRCLE_GEN.into_ef()).x. + // A constant column is easier though. 
+ let [c_wire] = eval.next_interaction_mask(2, [0]); + let [op] = eval.next_interaction_mask(2, [0]); + + let mult = eval.next_trace_mask(); + let a_val = eval.next_trace_mask(); + let b_val = eval.next_trace_mask(); + let c_val = eval.next_trace_mask(); + + eval.add_constraint(c_val - op * (a_val + b_val) + (E::F::one() - op) * a_val * b_val); + + logup.push_lookup( + &mut eval, + E::EF::one(), + &[a_wire, a_val], + &self.lookup_elements, + ); + logup.push_lookup( + &mut eval, + E::EF::one(), + &[b_wire, b_val], + &self.lookup_elements, + ); + logup.push_lookup( + &mut eval, + E::EF::from(-mult), + &[c_wire, c_val], + &self.lookup_elements, + ); + + logup.finalize(&mut eval); + eval + } +} + +#[derive(Clone)] +pub struct PlonkCircuitTrace { + pub mult: BaseColumn, + pub a_wire: BaseColumn, + pub b_wire: BaseColumn, + pub c_wire: BaseColumn, + pub op: BaseColumn, + pub a_val: BaseColumn, + pub b_val: BaseColumn, + pub c_val: BaseColumn, +} +pub fn gen_trace( + log_size: u32, + circuit: &PlonkCircuitTrace, +) -> ColumnVec> { + let _span = span!(Level::INFO, "Generation").entered(); + + let domain = CanonicCoset::new(log_size).circle_domain(); + [ + &circuit.mult, + &circuit.a_val, + &circuit.b_val, + &circuit.c_val, + ] + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval.clone())) + .collect_vec() +} + +pub fn gen_interaction_trace( + log_size: u32, + circuit: &PlonkCircuitTrace, + lookup_elements: &LookupElements<2>, +) -> ( + ColumnVec>, + SecureField, +) { + let _span = span!(Level::INFO, "Generate interaction trace").entered(); + let mut logup_gen = LogupTraceGenerator::new(log_size); + + let mut col_gen = logup_gen.new_col(); + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let q0: PackedSecureField = + lookup_elements.combine(&[circuit.a_wire.data[vec_row], circuit.a_val.data[vec_row]]); + let q1: PackedSecureField = + lookup_elements.combine(&[circuit.b_wire.data[vec_row], circuit.b_val.data[vec_row]]); + col_gen.write_frac(vec_row, q0 + 
q1, q0 * q1); + } + col_gen.finalize_col(); + + let mut col_gen = logup_gen.new_col(); + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + let p = -circuit.mult.data[vec_row]; + let q: PackedSecureField = + lookup_elements.combine(&[circuit.c_wire.data[vec_row], circuit.c_val.data[vec_row]]); + col_gen.write_frac(vec_row, p.into(), q); + } + col_gen.finalize_col(); + + logup_gen.finalize() +} + +#[allow(unused)] +pub fn prove_fibonacci_plonk( + log_n_rows: u32, + config: PcsConfig, +) -> (PlonkComponent, StarkProof) { + assert!(log_n_rows >= LOG_N_LANES); + + // Prepare a fibonacci circuit. + let mut fib_values = vec![BaseField::one(), BaseField::one()]; + for _ in 0..(1 << log_n_rows) { + fib_values.push(fib_values[fib_values.len() - 1] + fib_values[fib_values.len() - 2]); + } + let range = 0..(1 << log_n_rows); + let mut circuit = PlonkCircuitTrace { + mult: range.clone().map(|_| 2.into()).collect(), + a_wire: range.clone().map(|i| i.into()).collect(), + b_wire: range.clone().map(|i| (i + 1).into()).collect(), + c_wire: range.clone().map(|i| (i + 2).into()).collect(), + op: range.clone().map(|_| 1.into()).collect(), + a_val: range.clone().map(|i| fib_values[i]).collect(), + b_val: range.clone().map(|i| fib_values[i + 1]).collect(), + c_val: range.clone().map(|i| fib_values[i + 2]).collect(), + }; + circuit.mult.set((1 << log_n_rows) - 1, 0.into()); + circuit.mult.set((1 << log_n_rows) - 2, 1.into()); + + // Precompute twiddles. + let span = span!(Level::INFO, "Precompute twiddles").entered(); + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(log_n_rows + config.fri_config.log_blowup_factor + 1) + .circle_domain() + .half_coset, + ); + span.exit(); + + // Setup protocol. + let channel = &mut Blake2sChannel::default(); + let commitment_scheme = + &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); + + // Trace. 
+ let span = span!(Level::INFO, "Trace").entered(); + let trace = gen_trace(log_n_rows, &circuit); + let mut tree_builder = commitment_scheme.tree_builder(); + let base_trace_location = tree_builder.extend_evals(trace); + tree_builder.commit(channel); + span.exit(); + + // Draw lookup element. + let lookup_elements = LookupElements::draw(channel); + + // Interaction trace. + let span = span!(Level::INFO, "Interaction").entered(); + let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, &circuit, &lookup_elements); + let mut tree_builder = commitment_scheme.tree_builder(); + let interaction_trace_location = tree_builder.extend_evals(trace); + tree_builder.commit(channel); + span.exit(); + + // Constant trace. + let span = span!(Level::INFO, "Constant").entered(); + let mut tree_builder = commitment_scheme.tree_builder(); + let constants_trace_location = tree_builder.extend_evals( + chain!([circuit.a_wire, circuit.b_wire, circuit.c_wire, circuit.op] + .into_iter() + .map(|col| { + CircleEvaluation::::new( + CanonicCoset::new(log_n_rows).circle_domain(), + col, + ) + })) + .collect_vec(), + ); + tree_builder.commit(channel); + span.exit(); + + // Prove constraints. + let component = PlonkComponent::new( + &mut TraceLocationAllocator::default(), + PlonkEval { + log_n_rows, + lookup_elements, + claimed_sum, + base_trace_location, + interaction_trace_location, + constants_trace_location, + }, + ); + + // Sanity check. Remove for production. 
+ let trace_polys = commitment_scheme + .trees + .as_ref() + .map(|t| t.polynomials.iter().cloned().collect_vec()); + assert_constraints(&trace_polys, CanonicCoset::new(log_n_rows), |mut eval| { + component.evaluate(eval); + }); + + let proof = prove::(&[&component], channel, commitment_scheme).unwrap(); + + (component, proof) +} + +#[cfg(test)] +mod tests { + use std::env; + + use crate::constraint_framework::logup::LookupElements; + use crate::core::air::Component; + use crate::core::channel::Blake2sChannel; + use crate::core::fri::FriConfig; + use crate::core::pcs::{CommitmentSchemeVerifier, PcsConfig}; + use crate::core::prover::verify; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::plonk::prove_fibonacci_plonk; + + #[test_log::test] + fn test_simd_plonk_prove() { + // Get from environment variable: + let log_n_instances = env::var("LOG_N_INSTANCES") + .unwrap_or_else(|_| "10".to_string()) + .parse::() + .unwrap(); + let config = PcsConfig { + pow_bits: 10, + fri_config: FriConfig::new(5, 4, 64), + }; + + // Prove. + let (component, proof) = prove_fibonacci_plonk(log_n_instances, config); + + // Verify. + // TODO: Create Air instance independently. + let channel = &mut Blake2sChannel::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + + // Decommit. + // Retrieve the expected column sizes in each commitment interaction, from the AIR. + let sizes = component.trace_log_degree_bounds(); + // Trace columns. + commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); + // Draw lookup element. + let lookup_elements = LookupElements::<2>::draw(channel); + assert_eq!(lookup_elements, component.lookup_elements); + // TODO(spapini): Check claimed sum against first and last instances. + // Interaction columns. + commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); + // Constant columns. 
+ commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); + + verify(&[&component], channel, commitment_scheme, proof).unwrap(); + } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/poseidon/mod.rs b/Stwo_wrapper/crates/prover/src/examples/poseidon/mod.rs new file mode 100644 index 0000000..c94f0ba --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/poseidon/mod.rs @@ -0,0 +1,508 @@ +//! AIR for Poseidon2 hash function from . + +use std::ops::{Add, AddAssign, Mul, Sub}; + +use itertools::Itertools; +use num_traits::One; +use tracing::{span, Level}; + +use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; +use crate::constraint_framework::{ + EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator, +}; +use crate::core::backend::simd::column::BaseColumn; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Col, Column}; +use crate::core::channel::Blake2sChannel; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; +use crate::core::pcs::{CommitmentSchemeProver, PcsConfig}; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; +use crate::core::poly::BitReversedOrder; +use crate::core::prover::{prove, StarkProof}; +use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; +use crate::core::ColumnVec; + +const N_LOG_INSTANCES_PER_ROW: usize = 3; +const N_INSTANCES_PER_ROW: usize = 1 << N_LOG_INSTANCES_PER_ROW; +const N_STATE: usize = 16; +const N_PARTIAL_ROUNDS: usize = 14; +const N_HALF_FULL_ROUNDS: usize = 4; +const FULL_ROUNDS: usize = 2 * N_HALF_FULL_ROUNDS; +const N_COLUMNS_PER_REP: usize = N_STATE * (1 + FULL_ROUNDS) + N_PARTIAL_ROUNDS; +const N_COLUMNS: usize = N_INSTANCES_PER_ROW * N_COLUMNS_PER_REP; +const LOG_EXPAND: u32 
= 2; +// TODO(spapini): Pick better constants. +const EXTERNAL_ROUND_CONSTS: [[BaseField; N_STATE]; 2 * N_HALF_FULL_ROUNDS] = + [[BaseField::from_u32_unchecked(1234); N_STATE]; 2 * N_HALF_FULL_ROUNDS]; +const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = + [BaseField::from_u32_unchecked(1234); N_PARTIAL_ROUNDS]; + +pub type PoseidonComponent = FrameworkComponent; + +pub type PoseidonElements = LookupElements<{ N_STATE * 2 }>; + +#[derive(Clone)] +pub struct PoseidonEval { + pub log_n_rows: u32, + pub lookup_elements: PoseidonElements, + pub claimed_sum: SecureField, +} +impl FrameworkEval for PoseidonEval { + fn log_size(&self) -> u32 { + self.log_n_rows + } + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_n_rows + LOG_EXPAND + } + fn evaluate(&self, mut eval: E) -> E { + let logup = LogupAtRow::new(1, self.claimed_sum, self.log_n_rows); + eval_poseidon_constraints(&mut eval, logup, &self.lookup_elements); + eval + } +} + +#[inline(always)] +/// Applies the M4 MDS matrix described in 5.1. +fn apply_m4(x: [F; 4]) -> [F; 4] +where + F: Copy + AddAssign + Add + Sub + Mul, +{ + let t0 = x[0] + x[1]; + let t02 = t0 + t0; + let t1 = x[2] + x[3]; + let t12 = t1 + t1; + let t2 = x[1] + x[1] + t1; + let t3 = x[3] + x[3] + t0; + let t4 = t12 + t12 + t3; + let t5 = t02 + t02 + t2; + let t6 = t3 + t5; + let t7 = t2 + t4; + [t6, t5, t7, t4] +} + +/// Applies the external round matrix. +/// See 5.1 and Appendix B. +fn apply_external_round_matrix(state: &mut [F; 16]) +where + F: Copy + AddAssign + Add + Sub + Mul, +{ + // Applies circ(2M4, M4, M4, M4). + for i in 0..4 { + [ + state[4 * i], + state[4 * i + 1], + state[4 * i + 2], + state[4 * i + 3], + ] = apply_m4([ + state[4 * i], + state[4 * i + 1], + state[4 * i + 2], + state[4 * i + 3], + ]); + } + for j in 0..4 { + let s = state[j] + state[j + 4] + state[j + 8] + state[j + 12]; + for i in 0..4 { + state[4 * i + j] += s; + } + } +} + +// Applies the internal round matrix. +// mu_i = 2^{i+1} + 1. 
+// See 5.2. +fn apply_internal_round_matrix(state: &mut [F; 16]) +where + F: Copy + AddAssign + Add + Sub + Mul, +{ + // TODO(spapini): Check that these coefficients are good according to section 5.3 of Poseidon2 + // paper. + let sum = state[1..].iter().fold(state[0], |acc, s| acc + *s); + state.iter_mut().enumerate().for_each(|(i, s)| { + // TODO(spapini): Change to rotations. + *s = *s * BaseField::from_u32_unchecked(1 << (i + 1)) + sum; + }); +} + +fn pow5(x: F) -> F { + let x2 = x * x; + let x4 = x2 * x2; + x4 * x +} + +pub fn eval_poseidon_constraints( + eval: &mut E, + mut logup: LogupAtRow<2, E>, + lookup_elements: &PoseidonElements, +) { + for _ in 0..N_INSTANCES_PER_ROW { + let mut state: [_; N_STATE] = std::array::from_fn(|_| eval.next_trace_mask()); + + // Require state lookup. + logup.push_lookup(eval, E::EF::one(), &state, lookup_elements); + + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round][i]; + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = eval.next_trace_mask(); + eval.add_constraint(*s - m); + *s = m; + }); + }); + + // Partial rounds. + (0..N_PARTIAL_ROUNDS).for_each(|round| { + state[0] += INTERNAL_ROUND_CONSTS[round]; + apply_internal_round_matrix(&mut state); + state[0] = pow5(state[0]); + let m = eval.next_trace_mask(); + eval.add_constraint(state[0] - m); + state[0] = m; + }); + + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i]; + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter_mut().for_each(|s| { + let m = eval.next_trace_mask(); + eval.add_constraint(*s - m); + *s = m; + }); + }); + + // Provide state lookup. 
+ logup.push_lookup(eval, -E::EF::one(), &state, lookup_elements); + } + + logup.finalize(eval); +} + +pub struct LookupData { + initial_state: [[BaseColumn; N_STATE]; N_INSTANCES_PER_ROW], + final_state: [[BaseColumn; N_STATE]; N_INSTANCES_PER_ROW], +} +pub fn gen_trace( + log_size: u32, +) -> ( + ColumnVec>, + LookupData, +) { + let _span = span!(Level::INFO, "Generation").entered(); + assert!(log_size >= LOG_N_LANES); + let mut trace = (0..N_COLUMNS) + .map(|_| Col::::zeros(1 << log_size)) + .collect_vec(); + let mut lookup_data = LookupData { + initial_state: std::array::from_fn(|_| { + std::array::from_fn(|_| BaseColumn::zeros(1 << log_size)) + }), + final_state: std::array::from_fn(|_| { + std::array::from_fn(|_| BaseColumn::zeros(1 << log_size)) + }), + }; + + for vec_index in 0..(1 << (log_size - LOG_N_LANES)) { + // Initial state. + let mut col_index = 0; + for rep_i in 0..N_INSTANCES_PER_ROW { + let mut state: [_; N_STATE] = std::array::from_fn(|state_i| { + PackedBaseField::from_array(std::array::from_fn(|i| { + BaseField::from_u32_unchecked((vec_index * 16 + i + state_i + rep_i) as u32) + })) + }); + state.iter().copied().for_each(|s| { + trace[col_index].data[vec_index] = s; + col_index += 1; + }); + lookup_data.initial_state[rep_i] + .iter_mut() + .zip(state) + .for_each(|(res, state_i)| res.data[vec_index] = state_i); + + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += PackedBaseField::broadcast(EXTERNAL_ROUND_CONSTS[round][i]); + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter().copied().for_each(|s| { + trace[col_index].data[vec_index] = s; + col_index += 1; + }); + }); + + // Partial rounds. 
+ (0..N_PARTIAL_ROUNDS).for_each(|round| { + state[0] += PackedBaseField::broadcast(INTERNAL_ROUND_CONSTS[round]); + apply_internal_round_matrix(&mut state); + state[0] = pow5(state[0]); + trace[col_index].data[vec_index] = state[0]; + col_index += 1; + }); + + // 4 full rounds. + (0..N_HALF_FULL_ROUNDS).for_each(|round| { + (0..N_STATE).for_each(|i| { + state[i] += PackedBaseField::broadcast( + EXTERNAL_ROUND_CONSTS[round + N_HALF_FULL_ROUNDS][i], + ); + }); + apply_external_round_matrix(&mut state); + state = std::array::from_fn(|i| pow5(state[i])); + state.iter().copied().for_each(|s| { + trace[col_index].data[vec_index] = s; + col_index += 1; + }); + }); + + lookup_data.final_state[rep_i] + .iter_mut() + .zip(state) + .for_each(|(res, state_i)| res.data[vec_index] = state_i); + } + } + let domain = CanonicCoset::new(log_size).circle_domain(); + let trace = trace + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec(); + (trace, lookup_data) +} + +pub fn gen_interaction_trace( + log_size: u32, + lookup_data: LookupData, + lookup_elements: &PoseidonElements, +) -> ( + ColumnVec>, + SecureField, +) { + let _span = span!(Level::INFO, "Generate interaction trace").entered(); + let mut logup_gen = LogupTraceGenerator::new(log_size); + + #[allow(clippy::needless_range_loop)] + for rep_i in 0..N_INSTANCES_PER_ROW { + let mut col_gen = logup_gen.new_col(); + for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { + // Batch the 2 lookups together. + let denom0: PackedSecureField = lookup_elements.combine( + &lookup_data.initial_state[rep_i] + .each_ref() + .map(|s| s.data[vec_row]), + ); + let denom1: PackedSecureField = lookup_elements.combine( + &lookup_data.final_state[rep_i] + .each_ref() + .map(|s| s.data[vec_row]), + ); + // (1 / denom1) - (1 / denom1) = (denom1 - denom0) / (denom0 * denom1). 
+ col_gen.write_frac(vec_row, denom1 - denom0, denom0 * denom1); + } + col_gen.finalize_col(); + } + + logup_gen.finalize() +} + +pub fn prove_poseidon( + log_n_instances: u32, + config: PcsConfig, +) -> (PoseidonComponent, StarkProof) { + assert!(log_n_instances >= N_LOG_INSTANCES_PER_ROW as u32); + let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32; + + // Precompute twiddles. + let span = span!(Level::INFO, "Precompute twiddles").entered(); + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(log_n_rows + LOG_EXPAND + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + span.exit(); + + // Setup protocol. + let channel = &mut Blake2sChannel::default(); + let commitment_scheme = + &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); + + // Trace. + let span = span!(Level::INFO, "Trace").entered(); + let (trace, lookup_data) = gen_trace(log_n_rows); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace); + tree_builder.commit(channel); + span.exit(); + + // Draw lookup elements. + let lookup_elements = PoseidonElements::draw(channel); + + // Interaction trace. + let span = span!(Level::INFO, "Interaction").entered(); + let (trace, claimed_sum) = gen_interaction_trace(log_n_rows, lookup_data, &lookup_elements); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace); + tree_builder.commit(channel); + span.exit(); + + // Prove constraints. 
+ let component = PoseidonComponent::new( + &mut TraceLocationAllocator::default(), + PoseidonEval { + log_n_rows, + lookup_elements, + claimed_sum, + }, + ); + let proof = prove::(&[&component], channel, commitment_scheme).unwrap(); + (component, proof) +} + +#[cfg(test)] +mod tests { + use std::env; + + use itertools::Itertools; + use num_traits::One; + + use crate::constraint_framework::assert_constraints; + use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; + use crate::core::air::Component; + use crate::core::channel::Blake2sChannel; + use crate::core::fields::m31::BaseField; + use crate::core::fri::FriConfig; + use crate::core::pcs::{CommitmentSchemeVerifier, PcsConfig, TreeVec}; + use crate::core::poly::circle::CanonicCoset; + use crate::core::prover::verify; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::poseidon::{ + apply_internal_round_matrix, apply_m4, eval_poseidon_constraints, gen_interaction_trace, + gen_trace, prove_poseidon, PoseidonElements, + }; + use crate::math::matrix::{RowMajorMatrix, SquareMatrix}; + + #[cfg(all(target_family = "wasm", not(target_os = "wasi")))] + #[wasm_bindgen_test::wasm_bindgen_test] + fn test_poseidon_prove_wasm() { + const LOG_N_INSTANCES: u32 = 10; + let config = PcsConfig { + pow_bits: 10, + fri_config: FriConfig::new(5, 1, 64), + }; + + // Prove. 
+ prove_poseidon(LOG_N_INSTANCES, config); + } + + #[test] + fn test_apply_m4() { + let m4 = RowMajorMatrix::::new( + [5, 7, 1, 3, 4, 6, 1, 1, 1, 3, 5, 7, 1, 1, 4, 6] + .map(BaseField::from_u32_unchecked) + .into_iter() + .collect_vec(), + ); + let state = (0..4) + .map(BaseField::from_u32_unchecked) + .collect_vec() + .try_into() + .unwrap(); + + assert_eq!(apply_m4(state), m4.mul(state)); + } + + #[test] + fn test_apply_internal() { + let mut state: [BaseField; 16] = (0..16) + .map(|i| BaseField::from_u32_unchecked(i * 3 + 187)) + .collect_vec() + .try_into() + .unwrap(); + let mut internal_matrix = [[BaseField::one(); 16]; 16]; + #[allow(clippy::needless_range_loop)] + for i in 0..16 { + internal_matrix[i][i] += BaseField::from_u32_unchecked(1 << (i + 1)); + } + let matrix = RowMajorMatrix::::new(internal_matrix.flatten().to_vec()); + + let expected_state = matrix.mul(state); + apply_internal_round_matrix(&mut state); + + assert_eq!(state, expected_state); + } + + #[test] + fn test_poseidon_constraints() { + const LOG_N_ROWS: u32 = 8; + + // Trace. 
+ let (trace0, interaction_data) = gen_trace(LOG_N_ROWS); + let lookup_elements = LookupElements::dummy(); + let (trace1, claimed_sum) = + gen_interaction_trace(LOG_N_ROWS, interaction_data, &lookup_elements); + + let traces = TreeVec::new(vec![trace0, trace1]); + let trace_polys = + traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); + assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |mut eval| { + eval_poseidon_constraints( + &mut eval, + LogupAtRow::new(1, claimed_sum, LOG_N_ROWS), + &lookup_elements, + ); + }); + } + + #[test_log::test] + fn test_simd_poseidon_prove() { + // Note: To see time measurement, run test with + // RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info RUST_BACKTRACE=1 RUSTFLAGS=" + // -C target-cpu=native -C target-feature=+avx512f -C opt-level=3" cargo test + // test_simd_poseidon_prove -- --nocapture + + // Get from environment variable: + let log_n_instances = env::var("LOG_N_INSTANCES") + .unwrap_or_else(|_| "10".to_string()) + .parse::() + .unwrap(); + let config = PcsConfig { + pow_bits: 10, + fri_config: FriConfig::new(5, 1, 64), + }; + + // Prove. + let (component, proof) = prove_poseidon(log_n_instances, config); + + // Verify. + // TODO: Create Air instance independently. + let channel = &mut Blake2sChannel::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + + // Decommit. + // Retrieve the expected column sizes in each commitment interaction, from the AIR. + let sizes = component.trace_log_degree_bounds(); + // Trace columns. + commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); + // Draw lookup element. + let lookup_elements = PoseidonElements::draw(channel); + assert_eq!(lookup_elements, component.lookup_elements); + // TODO(spapini): Check claimed sum against first and last instances. + // Interaction columns. 
+ commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); + + verify(&[&component], channel, commitment_scheme, proof).unwrap(); + } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs b/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs new file mode 100644 index 0000000..2bd0381 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -0,0 +1,619 @@ +use itertools::Itertools; +use std::fs::File; +use std::io::{Error, Write}; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::{Col, Column}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::FieldExpOps; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::ColumnVec; +use crate::core::fields::qm31::SecureField; +use crate::core::prover::StarkProof; +//use crate::core::vcs::hash::Hash; +use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleHasher; + +pub type WideFibonacciComponent = FrameworkComponent>; + +pub struct FibInput { + a: PackedBaseField, + b: PackedBaseField, +} + +/// A component that enforces the Fibonacci sequence. +/// Each row contains a seperate Fibonacci sequence of length `N`. 
+#[derive(Clone)] +pub struct WideFibonacciEval { + pub log_n_rows: u32, +} +impl FrameworkEval for WideFibonacciEval { + fn log_size(&self) -> u32 { + self.log_n_rows + } + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_n_rows + 1 + } + fn evaluate(&self, mut eval: E) -> E { + let mut a = eval.next_trace_mask(); + let mut b = eval.next_trace_mask(); + for _ in 2..N { + let c = eval.next_trace_mask(); + eval.add_constraint(c - (a.square() + b.square())); + a = b; + b = c; + } + eval + } +} + +pub fn generate_trace( + log_size: u32, + inputs: &[FibInput], +) -> ColumnVec> { + assert!(log_size >= LOG_N_LANES); + assert_eq!(inputs.len(), 1 << (log_size - LOG_N_LANES)); + let mut trace = (0..N) + .map(|_| Col::::zeros(1 << log_size)) + .collect_vec(); + for (vec_index, input) in inputs.iter().enumerate() { + let mut a = input.a; + let mut b = input.b; + trace[0].data[vec_index] = a; + trace[1].data[vec_index] = b; + trace.iter_mut().skip(2).for_each(|col| { + (a, b) = (b, a.square() + b.square()); + col.data[vec_index] = b; + }); + } + let domain = CanonicCoset::new(log_size).circle_domain(); + trace + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec() +} + +pub fn save_secure_field_element(file: &mut File, q31_element: SecureField) -> Result<(), Error> { + file.write_all(b"[\"")?; + file.write_all(&q31_element.0.0.to_string().into_bytes())?; + file.write_all(b"\",\"")?; + file.write_all(&q31_element.0.1.to_string().into_bytes())?; + file.write_all(b"\",\"")?; + file.write_all(&q31_element.1.0.to_string().into_bytes())?; + file.write_all(b"\",\"")?; + file.write_all(&q31_element.1.1.to_string().into_bytes())?; + file.write_all(b"\"]")?; + Ok(()) +} + +pub fn pretty_save_poseidon_bls_proof(proof: &StarkProof) -> Result<(),Error> { + let mut file = File::create("proof.json")?; + file.write_all(b"{\n")?; + //commitments + file.write_all(b"\t\"commitments\" : \n\t\t[")?; + for i in 0..proof.commitments.len() { + 
file.write_all(b"\"")?; + file.write_all(&proof.commitments[i].to_string().into_bytes())?; + file.write_all(b"\"")?; + if proof.commitments.len() != i+1 { + file.write_all(b",\n\t\t")?; + } + } + file.write_all(b"],\n\n")?; + // Sampled Values + file.write_all(b"\t\"sampled_values_0\" : \n\t\t[")?; + for i in 0..proof.commitment_scheme_proof.sampled_values.0[0].len() { + save_secure_field_element(&mut file,proof.commitment_scheme_proof.sampled_values.0[0][i][0])?; + if proof.commitment_scheme_proof.sampled_values.0[0].len() != i+1 { + file.write_all(b",\n\t\t")?; + } + } + file.write_all(b"],\n\n")?; + file.write_all(b"\t\"sampled_values_1\" : \n\t\t[")?; + for i in 0..proof.commitment_scheme_proof.sampled_values.0[1].len() { + save_secure_field_element(&mut file,proof.commitment_scheme_proof.sampled_values.0[1][i][0])?; + if proof.commitment_scheme_proof.sampled_values.0[1].len() != i+1 { + file.write_all(b",\n\t\t")?; + } + } + file.write_all(b"],\n\n")?; + + //decommitments + file.write_all(b"\t\"decommitment_0\" : \n\t\t[")?; + for i in 0..proof.commitment_scheme_proof.decommitments.0[0].hash_witness.len() { + file.write_all(b"\"")?; + file.write_all(&proof.commitment_scheme_proof.decommitments.0[0].hash_witness[i].to_string().into_bytes())?; + file.write_all(b"\"")?; + if proof.commitment_scheme_proof.decommitments.0[0].hash_witness.len() != i+1 { + file.write_all(b",\n\t\t")?; + } + } + file.write_all(b"],\n\n")?; + file.write_all(b"\t\"decommitment_1\" : \n\t\t[")?; + for i in 0..proof.commitment_scheme_proof.decommitments.0[1].hash_witness.len() { + file.write_all(b"\"")?; + file.write_all(&proof.commitment_scheme_proof.decommitments.0[1].hash_witness[i].to_string().into_bytes())?; + file.write_all(b"\"")?; + if proof.commitment_scheme_proof.decommitments.0[1].hash_witness.len() != i+1 { + file.write_all(b",\n\t\t")?; + } + } + file.write_all(b"],\n\n")?; + + //Queried_values + file.write_all(b"\t\"queried_values_0\" : \n\t\t[")?; + for i in 
0..proof.commitment_scheme_proof.queried_values.0[0].len() { + file.write_all(b"[")?; + for j in 0..6 { + file.write_all(b"\"")?; + file.write_all(&proof.commitment_scheme_proof.queried_values.0[0][i][j].to_string().into_bytes())?; + file.write_all(b"\"")?; + if j != 5 { + file.write_all(b",")?; + } else { file.write_all(b"]")?; } + } + if proof.commitment_scheme_proof.queried_values.0[0].len() != i+1 { + file.write_all(b",\n\t\t")?; + } + } + file.write_all(b"],\n\n")?; + file.write_all(b"\t\"queried_values_1\" : [")?; + for i in 0..proof.commitment_scheme_proof.queried_values.0[1].len() { + file.write_all(b"[")?; + for j in 0..6 { + file.write_all(b"\"")?; + file.write_all(&proof.commitment_scheme_proof.queried_values.0[1][i][j].to_string().into_bytes())?; + file.write_all(b"\"")?; + if j != 5 { + file.write_all(b",")?; + } else { file.write_all(b"]")?; } + } + if proof.commitment_scheme_proof.queried_values.0[1].len() != i+1 { + file.write_all(b",\n\t\t")?; + } + } + file.write_all(b"],\n\n")?; + + //proof of work + file.write_all(b"\t\"proof of work\" : \"")?; + file.write_all(&proof.commitment_scheme_proof.proof_of_work.to_string().into_bytes())?; + file.write_all(b"\",\n\n")?; + + //last FRI layer coeffs + file.write_all(b"\t\"coeffs\" : ")?; + save_secure_field_element(&mut file,proof.commitment_scheme_proof.fri_proof.last_layer_poly.coeffs[0] )?; + file.write_all(b",\n\n")?; + + //intermediate FRI layers + for i in 0..6 { + //commitment + file.write_all(b"\t\"inner_commitment_")?; + file.write_all(&i.to_string().into_bytes())?; + file.write_all(b"\" : \"")?; + file.write_all(&proof.commitment_scheme_proof.fri_proof.inner_layers[i].commitment.0.to_string().into_bytes())?; + file.write_all(b"\",\n\n")?; + + //decommitment + file.write_all(b"\t\"inner_decommitment_")?; + file.write_all(&i.to_string().into_bytes())?; + file.write_all(b"\" : \n\t\t[")?; + for j in 0..proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness.len() { + 
file.write_all(b"\"")?; + file.write_all(&proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness[j].to_string().into_bytes())?; + file.write_all(b"\"")?; + if proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness.len() != j+1 { + file.write_all(b",\n\t\t")?; + } + } + file.write_all(b"],\n\n")?; + + //evals_subset + file.write_all(b"\t\"inner_evals_subset_")?; + file.write_all(&i.to_string().into_bytes())?; + file.write_all(b"\" : \n\t\t[")?; + for j in 0..proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset.len() { + save_secure_field_element(&mut file,proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset[j])?; + if proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset.len() != j+1 { + file.write_all(b",\n\t\t")?; + } + } + if i != 5 { + file.write_all(b"],\n\n")?; + } else { + file.write_all(b"]\n}")?; + } + } + Ok(()) +} + +pub fn compressed_save_poseidon_bls_proof(proof: &StarkProof) -> Result<(),Error> { + let mut file = File::create("proof.json")?; + file.write_all(b"{")?; + //commitments + file.write_all(b"\"commitments\":[")?; + for i in 0..proof.commitments.len() { + file.write_all(b"\"")?; + file.write_all(&proof.commitments[i].to_string().into_bytes())?; + file.write_all(b"\"")?; + if proof.commitments.len() != i+1 { + file.write_all(b",")?; + } + } + file.write_all(b"],")?; + // Sampled Values + file.write_all(b"\"sampled_values_0\":[")?; + for i in 0..proof.commitment_scheme_proof.sampled_values.0[0].len() { + save_secure_field_element(&mut file,proof.commitment_scheme_proof.sampled_values.0[0][i][0])?; + if proof.commitment_scheme_proof.sampled_values.0[0].len() != i+1 { + file.write_all(b",")?; + } + } + file.write_all(b"],")?; + file.write_all(b"\"sampled_values_1\":[")?; + for i in 0..proof.commitment_scheme_proof.sampled_values.0[1].len() { + save_secure_field_element(&mut file,proof.commitment_scheme_proof.sampled_values.0[1][i][0])?; + if 
proof.commitment_scheme_proof.sampled_values.0[1].len() != i+1 { + file.write_all(b",")?; + } + } + file.write_all(b"],")?; + + //decommitments + file.write_all(b"\"decommitment_0\":[")?; + for i in 0..proof.commitment_scheme_proof.decommitments.0[0].hash_witness.len() { + file.write_all(b"\"")?; + file.write_all(&proof.commitment_scheme_proof.decommitments.0[0].hash_witness[i].to_string().into_bytes())?; + file.write_all(b"\"")?; + if proof.commitment_scheme_proof.decommitments.0[0].hash_witness.len() != i+1 { + file.write_all(b",")?; + } + } + file.write_all(b"],")?; + file.write_all(b"\"decommitment_1\":[")?; + for i in 0..proof.commitment_scheme_proof.decommitments.0[1].hash_witness.len() { + file.write_all(b"\"")?; + file.write_all(&proof.commitment_scheme_proof.decommitments.0[1].hash_witness[i].to_string().into_bytes())?; + file.write_all(b"\"")?; + if proof.commitment_scheme_proof.decommitments.0[1].hash_witness.len() != i+1 { + file.write_all(b",")?; + } + } + file.write_all(b"],")?; + + //Queried_values + file.write_all(b"\"queried_values_0\":[")?; + for i in 0..proof.commitment_scheme_proof.queried_values.0[0].len() { + file.write_all(b"[")?; + for j in 0..6 { + file.write_all(b"\"")?; + file.write_all(&proof.commitment_scheme_proof.queried_values.0[0][i][j].to_string().into_bytes())?; + file.write_all(b"\"")?; + if j != 5 { + file.write_all(b",")?; + } else { file.write_all(b"]")?; } + } + if proof.commitment_scheme_proof.queried_values.0[0].len() != i+1 { + file.write_all(b",")?; + } + } + file.write_all(b"],")?; + file.write_all(b"\"queried_values_1\":[")?; + for i in 0..proof.commitment_scheme_proof.queried_values.0[1].len() { + file.write_all(b"[")?; + for j in 0..6 { + file.write_all(b"\"")?; + file.write_all(&proof.commitment_scheme_proof.queried_values.0[1][i][j].to_string().into_bytes())?; + file.write_all(b"\"")?; + if j != 5 { + file.write_all(b",")?; + } else { file.write_all(b"]")?; } + } + if 
proof.commitment_scheme_proof.queried_values.0[1].len() != i+1 { + file.write_all(b",")?; + } + } + file.write_all(b"],")?; + + //proof of work + file.write_all(b"\"proof of work\":\"")?; + file.write_all(&proof.commitment_scheme_proof.proof_of_work.to_string().into_bytes())?; + file.write_all(b"\",")?; + + //last FRI layer coeffs + file.write_all(b"\"coeffs\":")?; + save_secure_field_element(&mut file,proof.commitment_scheme_proof.fri_proof.last_layer_poly.coeffs[0] )?; + file.write_all(b",")?; + + //intermediate FRI layers + for i in 0..6 { + //commitment + file.write_all(b"\"inner_commitment_")?; + file.write_all(&i.to_string().into_bytes())?; + file.write_all(b"\":\"")?; + file.write_all(&proof.commitment_scheme_proof.fri_proof.inner_layers[i].commitment.0.to_string().into_bytes())?; + file.write_all(b"\",")?; + + //decommitment + file.write_all(b"\"inner_decommitment_")?; + file.write_all(&i.to_string().into_bytes())?; + file.write_all(b"\":[")?; + for j in 0..proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness.len() { + file.write_all(b"\"")?; + file.write_all(&proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness[j].to_string().into_bytes())?; + file.write_all(b"\"")?; + if proof.commitment_scheme_proof.fri_proof.inner_layers[i].decommitment.hash_witness.len() != j+1 { + file.write_all(b",")?; + } + } + file.write_all(b"],")?; + + //evals_subset + file.write_all(b"\"inner_evals_subset_")?; + file.write_all(&i.to_string().into_bytes())?; + file.write_all(b"\":[")?; + for j in 0..proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset.len() { + save_secure_field_element(&mut file,proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset[j])?; + if proof.commitment_scheme_proof.fri_proof.inner_layers[i].evals_subset.len() != j+1 { + file.write_all(b",")?; + } + } + if i != 5 { + file.write_all(b"],")?; + } else { + file.write_all(b"]}")?; + } + } + + + + + Ok(()) +} + 
+#[cfg(test)] +mod tests { + use itertools::Itertools; + use num_traits::One; + + use super::{pretty_save_poseidon_bls_proof, WideFibonacciEval}; + use crate::constraint_framework::{ + assert_constraints, AssertEvaluator, FrameworkEval, TraceLocationAllocator, + }; + use crate::core::air::Component; + use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::Column; + use crate::core::channel::Blake2sChannel; + #[cfg(not(target_arch = "wasm32"))] + use crate::core::channel::Poseidon252Channel; + #[cfg(not(target_arch = "wasm32"))] + use crate::core::channel::PoseidonBLSChannel; + use crate::core::fields::m31::BaseField; + use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; + use crate::core::poly::BitReversedOrder; + use crate::core::prover::{prove, verify}; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + #[cfg(not(target_arch = "wasm32"))] + use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel; + #[cfg(not(target_arch = "wasm32"))] + use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel; + use crate::core::ColumnVec; + use crate::examples::wide_fibonacci::{generate_trace, FibInput, WideFibonacciComponent}; + + const FIB_SEQUENCE_LENGTH: usize = 100; + + fn generate_test_trace( + log_n_instances: u32, + ) -> ColumnVec> { + let inputs = (0..(1 << (log_n_instances - LOG_N_LANES))) + .map(|i| FibInput { + a: PackedBaseField::one(), + b: PackedBaseField::from_array(std::array::from_fn(|j| { + BaseField::from_u32_unchecked((i * 16 + j) as u32) + })), + }) + .collect_vec(); + generate_trace::(log_n_instances, &inputs) + } + + fn fibonacci_constraint_evaluator(eval: AssertEvaluator<'_>) { + WideFibonacciEval:: { log_n_rows: N }.evaluate(eval); + } + + #[test] + fn test_wide_fibonacci_constraints() { + const LOG_N_INSTANCES: u32 = 
6; + let traces = TreeVec::new(vec![generate_test_trace(LOG_N_INSTANCES)]); + let trace_polys = + traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); + + assert_constraints( + &trace_polys, + CanonicCoset::new(LOG_N_INSTANCES), + fibonacci_constraint_evaluator::, + ); + } + + #[test] + #[should_panic] + fn test_wide_fibonacci_constraints_fails() { + const LOG_N_INSTANCES: u32 = 6; + + let mut trace = generate_test_trace(LOG_N_INSTANCES); + // Modify the trace such that a constraint fail. + trace[17].values.set(2, BaseField::one()); + let traces = TreeVec::new(vec![trace]); + let trace_polys = + traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); + + assert_constraints( + &trace_polys, + CanonicCoset::new(LOG_N_INSTANCES), + fibonacci_constraint_evaluator::, + ); + } + + #[test_log::test] + fn test_wide_fib_prove() { + const LOG_N_INSTANCES: u32 = 6; + let config = PcsConfig::default(); + // Precompute twiddles. + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(LOG_N_INSTANCES + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + + // Setup protocol. + let prover_channel = &mut Blake2sChannel::default(); + let commitment_scheme = + &mut CommitmentSchemeProver::::new( + config, &twiddles, + ); + + // Trace. + let trace = generate_test_trace(LOG_N_INSTANCES); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(trace); + tree_builder.commit(prover_channel); + + // Prove constraints. + let component = WideFibonacciComponent::new( + &mut TraceLocationAllocator::default(), + WideFibonacciEval:: { + log_n_rows: LOG_N_INSTANCES, + }, + ); + + let proof = prove::( + &[&component], + prover_channel, + commitment_scheme, + ) + .unwrap(); + + // Verify. 
+ let verifier_channel = &mut Blake2sChannel::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + + // Retrieve the expected column sizes in each commitment interaction, from the AIR. + let sizes = component.trace_log_degree_bounds(); + commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); + verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap(); + } + + #[test] + #[cfg(not(target_arch = "wasm32"))] + fn test_wide_fib_prove_with_poseidon() { + const LOG_N_INSTANCES: u32 = 6; + + let config = PcsConfig::default(); + // Precompute twiddles. + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(LOG_N_INSTANCES + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + + // Setup protocol. + let prover_channel = &mut Poseidon252Channel::default(); + let commitment_scheme = + &mut CommitmentSchemeProver::::new( + config, &twiddles, + ); + + // Trace. + let trace = generate_test_trace(LOG_N_INSTANCES); + //Initialize the parameters of the trace + let mut tree_builder = commitment_scheme.tree_builder(); + // Interpolation of the columns + tree_builder.extend_evals(trace); + // Compute the evaluations of the polynomials, build a Merkle tree of the evaluation and + // update the channel (Fiat-Shamir) with the root of the tree (Transcript <-- root) + tree_builder.commit(prover_channel); + + // Prove constraints. + let component = WideFibonacciComponent::new( + &mut TraceLocationAllocator::default(), + WideFibonacciEval:: { + log_n_rows: LOG_N_INSTANCES, + }, + ); + let proof = prove::( + &[&component], + prover_channel, + commitment_scheme, + ) + .unwrap(); + + // Verify. + let verifier_channel = &mut Poseidon252Channel::default(); + let commitment_scheme = + &mut CommitmentSchemeVerifier::::new(config); + + // Retrieve the expected column sizes in each commitment interaction, from the AIR. 
+ let sizes = component.trace_log_degree_bounds(); + commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); + verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap(); + } + + #[test] + #[cfg(not(target_arch = "wasm32"))] + fn test_wide_fib_prove_with_poseidon_bls() { + + const LOG_N_INSTANCES: u32 = 6; + + let config = PcsConfig::default(); + // Precompute twiddles. + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(LOG_N_INSTANCES + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + + // Setup protocol. + let prover_channel = &mut PoseidonBLSChannel::default(); + let commitment_scheme = + &mut CommitmentSchemeProver::::new( + config, &twiddles, + ); + + // Trace. + let trace = generate_test_trace(LOG_N_INSTANCES); + //Initialize the parameters of the trace + let mut tree_builder = commitment_scheme.tree_builder(); + // Interpolation of the columns + tree_builder.extend_evals(trace); + // Compute the evaluations of the polynomials, build a Merkle tree of the evaluation and + // update the channel (Fiat-Shamir) with the root of the tree (Transcript <-- root) + tree_builder.commit(prover_channel); + + // Prove constraints. + let component = WideFibonacciComponent::new( + &mut TraceLocationAllocator::default(), + WideFibonacciEval:: { + log_n_rows: LOG_N_INSTANCES, + }, + ); + let proof = prove::( + &[&component], + prover_channel, + commitment_scheme, + ) + .unwrap(); + _ = pretty_save_poseidon_bls_proof(&proof); + + // Verify. + let verifier_channel = &mut PoseidonBLSChannel::default(); + let commitment_scheme = + &mut CommitmentSchemeVerifier::::new(config); + + // Retrieve the expected column sizes in each commitment interaction, from the AIR. 
+ let sizes = component.trace_log_degree_bounds(); + commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); + verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap(); + } +} + diff --git a/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs b/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs new file mode 100644 index 0000000..53ae956 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/accumulation.rs @@ -0,0 +1,186 @@ +use std::iter::zip; +use std::ops::{AddAssign, Mul}; + +use educe::Educe; +use num_traits::One; + +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Backend; +use crate::core::circle::M31_CIRCLE_LOG_ORDER; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::lookups::mle::Mle; +use crate::core::utils::generate_secure_powers; + +pub const MIN_LOG_BLOWUP_FACTOR: u32 = 1; + +/// Max number of variables for multilinear polynomials that get compiled into a univariate +/// IOP for multilinear eval at point. +pub const MAX_MLE_N_VARIABLES: u32 = M31_CIRCLE_LOG_ORDER - MIN_LOG_BLOWUP_FACTOR; + +/// Accumulates [`Mle`]s grouped by their number of variables. +pub struct MleCollection { + mles_by_n_variables: Vec>>>, +} + +impl MleCollection { + /// Appends an [`Mle`] to the collection. + pub fn push(&mut self, mle: impl Into>) { + let mle = mle.into(); + let mles = self.mles_by_n_variables[mle.n_variables()].get_or_insert(Vec::new()); + mles.push(mle); + } +} + +impl MleCollection { + /// Performs a random linear combination of all MLEs, grouped by their number of variables. + /// + /// MLEs are returned in ascending order by number of variables. 
+ pub fn random_linear_combine_by_n_variables( + self, + alpha: SecureField, + ) -> Vec> { + self.mles_by_n_variables + .into_iter() + .flatten() + .map(|mles| mle_random_linear_combination(mles, alpha)) + .collect() + } +} + +/// # Panics +/// +/// Panics if `mles` is empty or all MLEs don't have the same number of variables. +fn mle_random_linear_combination( + mles: Vec>, + alpha: SecureField, +) -> Mle { + assert!(!mles.is_empty()); + let n_variables = mles[0].n_variables(); + assert!(mles.iter().all(|mle| mle.n_variables() == n_variables)); + let alpha_powers = generate_secure_powers(alpha, mles.len()).into_iter().rev(); + let mut mle_and_coeff = zip(mles, alpha_powers); + + // The last value can initialize the accumulator. + let (mle, coeff) = mle_and_coeff.next_back().unwrap(); + assert!(coeff.is_one()); + let mut acc_mle = mle.into_secure_mle(); + + for (mle, coeff) in mle_and_coeff { + match mle { + DynMle::Base(mle) => combine(&mut acc_mle.data, &mle.data, coeff.into()), + DynMle::Secure(mle) => combine(&mut acc_mle.data, &mle.data, coeff.into()), + } + } + + acc_mle +} + +/// Computes all `acc[i] += alpha * v[i]`. +pub fn combine + Copy, F: Copy>( + acc: &mut [EF], + v: &[F], + alpha: EF, +) { + assert_eq!(acc.len(), v.len()); + zip(acc, v).for_each(|(acc, &v)| *acc += alpha * v); +} + +impl Default for MleCollection { + fn default() -> Self { + Self { + mles_by_n_variables: vec![None; MAX_MLE_N_VARIABLES as usize + 1], + } + } +} + +/// Dynamic dispatch for [`Mle`] types. 
+#[derive(Educe)] +#[educe(Debug, Clone)] +pub enum DynMle { + Base(Mle), + Secure(Mle), +} + +impl DynMle { + fn n_variables(&self) -> usize { + match self { + DynMle::Base(mle) => mle.n_variables(), + DynMle::Secure(mle) => mle.n_variables(), + } + } +} + +impl From> for DynMle { + fn from(mle: Mle) -> Self { + DynMle::Secure(mle) + } +} + +impl From> for DynMle { + fn from(mle: Mle) -> Self { + DynMle::Base(mle) + } +} + +impl DynMle { + fn into_secure_mle(self) -> Mle { + match self { + Self::Base(mle) => Mle::new(mle.into_evals().into_secure_column()), + Self::Secure(mle) => mle, + } + } +} + +#[cfg(test)] +mod tests { + use std::iter::repeat; + + use num_traits::Zero; + + use crate::core::backend::simd::SimdBackend; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::Field; + use crate::core::lookups::mle::{Mle, MleOps}; + use crate::examples::xor::gkr_lookups::accumulation::MleCollection; + + #[test] + fn random_linear_combine_by_n_variables() { + const SMALL_N_VARS: usize = 4; + const LARGE_N_VARS: usize = 6; + let alpha = SecureField::from(10); + let mut mle_collection = MleCollection::::default(); + mle_collection.push(const_mle(SMALL_N_VARS, BaseField::from(1))); + mle_collection.push(const_mle(SMALL_N_VARS, SecureField::from(2))); + mle_collection.push(const_mle(LARGE_N_VARS, BaseField::from(3))); + mle_collection.push(const_mle(LARGE_N_VARS, SecureField::from(4))); + mle_collection.push(const_mle(LARGE_N_VARS, SecureField::from(5))); + let small_eval_point = [SecureField::zero(); SMALL_N_VARS]; + let large_eval_point = [SecureField::zero(); LARGE_N_VARS]; + + let [small_mle, large_mle] = mle_collection + .random_linear_combine_by_n_variables(alpha) + .try_into() + .unwrap(); + + assert_eq!(small_mle.n_variables(), SMALL_N_VARS); + assert_eq!(large_mle.n_variables(), LARGE_N_VARS); + assert_eq!( + small_mle.eval_at_point(&small_eval_point), + SecureField::from(1) * alpha + 
SecureField::from(2) + ); + assert_eq!( + large_mle.eval_at_point(&large_eval_point), + (SecureField::from(3) * alpha + SecureField::from(4)) * alpha + SecureField::from(5) + ); + } + + fn const_mle(n_variables: usize, v: F) -> Mle + where + B: MleOps, + F: Field, + { + Mle::new(repeat(v).take(1 << n_variables).collect()) + } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs new file mode 100644 index 0000000..5a5d605 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -0,0 +1,571 @@ +//! Multilinear extension (MLE) eval at point constraints. +// TODO(andrew): Remove in downstream PR. +#![allow(dead_code)] + +use std::array; + +use itertools::Itertools; +use num_traits::{One, Zero}; + +use crate::constraint_framework::EvalAtRow; +use crate::core::backend::simd::SimdBackend; +use crate::core::circle::{CirclePoint, Coset}; +use crate::core::constraints::{coset_vanishing, point_vanishing}; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SecureColumnByCoords; +use crate::core::fields::{Field, FieldExpOps}; +use crate::core::lookups::utils::eq; +use crate::core::poly::circle::{CanonicCoset, SecureEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; + +/// Evaluates constraints that guarantee an MLE evaluates to a claim at a given point. +/// +/// `mle_coeffs_col_eval` should be the evaluation of the column containing the coefficients of the +/// MLE in the multilinear Lagrange basis. `mle_claim_shift` should equal `claim / 2^N_VARIABLES`. 
+pub fn eval_mle_eval_constraints( + mle_interaction: usize, + const_interaction: usize, + eval: &mut E, + mle_coeffs_col_eval: E::EF, + mle_eval_point: MleEvalPoint, + mle_claim_shift: SecureField, + carry_quotients_col_eval: E::EF, +) { + let eq_col_eval = eval_eq_constraints( + mle_interaction, + const_interaction, + eval, + mle_eval_point, + carry_quotients_col_eval, + ); + let terms_col_eval = mle_coeffs_col_eval * eq_col_eval; + eval_prefix_sum_constraints(mle_interaction, eval, terms_col_eval, mle_claim_shift) +} + +#[derive(Debug, Clone, Copy)] +pub struct MleEvalPoint { + // Equals `eq({0}^|p|, p)`. + eq_0_p: SecureField, + // Equals `eq({1}^|p|, p)`. + eq_1_p: SecureField, + // Index `i` stores `eq(({1}^|i|, 0), p[0..i+1]) / eq(({0}^|i|, 1), p[0..i+1])`. + eq_carry_quotients: [SecureField; N_VARIABLES], + // Point `p`. + p: [SecureField; N_VARIABLES], +} + +impl MleEvalPoint { + /// Creates new metadata from point `p`. + pub fn new(p: [SecureField; N_VARIABLES]) -> Self { + let zero = SecureField::zero(); + let one = SecureField::one(); + + Self { + eq_0_p: eq(&[zero; N_VARIABLES], &p), + eq_1_p: eq(&[one; N_VARIABLES], &p), + eq_carry_quotients: array::from_fn(|i| { + let mut numer_assignment = vec![one; i + 1]; + numer_assignment[i] = zero; + let mut denom_assignment = vec![zero; i + 1]; + denom_assignment[i] = one; + eq(&numer_assignment, &p[..i + 1]) / eq(&denom_assignment, &p[..i + 1]) + }), + p, + } + } +} + +/// Evaluates EqEvals constraints on a column. +/// +/// Returns the evaluation at offset 0 on the column. +/// +/// Given a column `c(P)` defined on a circle domain `D`, and an MLE eval point `(r0, r1, ...)` +/// evaluates constraints that guarantee: `c(D[b0, b1, ...]) = eq((b0, b1, ...), (r0, r1, ...))`. +/// +/// See (Section 5.1). 
+fn eval_eq_constraints( + eq_interaction: usize, + const_interaction: usize, + eval: &mut E, + mle_eval_point: MleEvalPoint, + carry_quotients_col_eval: E::EF, +) -> E::EF { + let [curr, next_next] = eval.next_extension_interaction_mask(eq_interaction, [0, 2]); + let [is_first, is_second] = eval.next_interaction_mask(const_interaction, [0, -1]); + + // Check the initial value on half_coset0 and final value on half_coset1. + // Combining these constraints is safe because `is_first` and `is_second` are never + // non-zero at the same time on the trace. + let half_coset0_initial_check = (curr - mle_eval_point.eq_0_p) * is_first; + let half_coset1_final_check = (curr - mle_eval_point.eq_1_p) * is_second; + eval.add_constraint(half_coset0_initial_check + half_coset1_final_check); + + // Check all the steps. + eval.add_constraint(curr - next_next * carry_quotients_col_eval); + + curr +} + +/// Evaluates inclusive prefix sum constraints on a column. +/// +/// Note the column values must be shifted by `cumulative_sum_shift` so the last value equals zero. +/// `cumulative_sum_shift` should equal `cumulative_sum / column_size`. +fn eval_prefix_sum_constraints( + interaction: usize, + eval: &mut E, + row_diff: E::EF, + cumulative_sum_shift: SecureField, +) { + let [curr, prev] = eval.next_extension_interaction_mask(interaction, [0, -1]); + eval.add_constraint(curr - prev - row_diff + cumulative_sum_shift); +} + +/// Returns succinct Eq carry quotients column. +/// +/// Given column `c(P)` defined on a [`CircleDomain`] `D = +-C`, and an MLE eval point +/// `(r0, r1, ...)` let `c(D[b0, b1, ...]) = eq((b0, b1, ...), (r0, r1, ...))`. This function +/// returns column `q(P)` such that all `c(C[i]) = c(C[i + 1]) * q(C[i])` and +/// `c(-C[i]) = c(-C[i + 1]) * q(-C[i])`. 
+/// +/// [`CircleDomain`]: crate::core::poly::circle::CircleDomain +fn gen_carry_quotient_col( + eval_point: &MleEvalPoint, +) -> SecureEvaluation { + let (half_coset0_carry_quotients, half_coset1_carry_quotients) = + gen_half_coset_carry_quotients(eval_point); + + let log_size = N_VARIABLES as u32; + let size = 1 << log_size; + let half_coset_size = size / 2; + let mut col = SecureColumnByCoords::::zeros(size); + + // TODO(andrew): Optimize. + for i in 0..half_coset_size { + let half_coset0_index = coset_index_to_circle_domain_index(i * 2, log_size); + let half_coset1_index = coset_index_to_circle_domain_index(i * 2 + 1, log_size); + let half_coset0_index_bit_rev = bit_reverse_index(half_coset0_index, log_size); + let half_coset1_index_bit_rev = bit_reverse_index(half_coset1_index, log_size); + + let n_trailing_ones = i.trailing_ones() as usize; + let half_coset0_carry_quotient = half_coset0_carry_quotients[n_trailing_ones]; + let half_coset1_carry_quotient = half_coset1_carry_quotients[n_trailing_ones]; + + col.set(half_coset0_index_bit_rev, half_coset0_carry_quotient); + col.set(half_coset1_index_bit_rev, half_coset1_carry_quotient); + } + + let domain = CanonicCoset::new(log_size).circle_domain(); + SecureEvaluation::new(domain, col) +} + +/// Evaluates the succinct Eq carry quotients column at point `p`. +/// +/// See [`gen_carry_quotient_col`]. +// TODO(andrew): Optimize further. Inline `eval_step_selector` and get runtime down to +// O(N_VARIABLES) vs current O(N_VARIABLES^2). Can also use vanishing evals to compute +// half_coset0_last half_coset1_first. 
+fn eval_carry_quotient_col( + eval_point: &MleEvalPoint, + p: CirclePoint, +) -> SecureField { + let log_size = N_VARIABLES as u32; + let coset = CanonicCoset::new(log_size).coset(); + + let (half_coset0_carry_quotients, half_coset1_carry_quotients) = + gen_half_coset_carry_quotients(eval_point); + + let mut eval = SecureField::zero(); + + for variable_i in 0..N_VARIABLES.saturating_sub(1) { + let log_step = variable_i as u32 + 2; + let offset = (1 << (log_step - 1)) - 2; + let half_coset0_selector = eval_step_selector_with_offset(coset, offset, log_step, p); + let half_coset1_selector = eval_step_selector_with_offset(coset, offset + 1, log_step, p); + let half_coset0_carry_quotient = half_coset0_carry_quotients[variable_i]; + let half_coset1_carry_quotient = half_coset1_carry_quotients[variable_i]; + eval += half_coset0_selector * half_coset0_carry_quotient; + eval += half_coset1_selector * half_coset1_carry_quotient; + } + + let half_coset0_last = eval_is_first(coset, p + coset.step.double().into_ef()); + let half_coset1_first = eval_is_first(coset, p + coset.step.into_ef()); + eval += *half_coset0_carry_quotients.last().unwrap() * half_coset0_last; + eval += *half_coset1_carry_quotients.last().unwrap() * half_coset1_first; + + eval +} + +/// Evaluates a polynomial that's `1` every `2^log_step` coset points, shifted by an offset, and `0` +/// elsewhere on coset. +fn eval_step_selector_with_offset( + coset: Coset, + offset: usize, + log_step: u32, + p: CirclePoint, +) -> SecureField { + let offset_step = coset.step.mul(offset as u128); + eval_step_selector(coset, log_step, p - offset_step.into_ef()) +} + +/// Evaluates a polynomial that's `1` every `2^log_step` coset points and `0` elsewhere on coset. +fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint) -> SecureField { + if log_step == 0 { + return SecureField::one(); + } + + // Rotate the coset to have points on the `x` axis. 
+ let p = p - coset.initial.into_ef(); + let mut vanish_at_log_step = (0..coset.log_size) + .scan(p, |p, _| { + let res = *p; + *p = p.double(); + Some(res.y) + }) + .collect_vec(); + vanish_at_log_step.reverse(); + // We only need the first `log_step` many values. + vanish_at_log_step.truncate(log_step as usize); + let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()]; + SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv); + + let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square(); + let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::(); + (half_coset_selector_dbl + vanish_at_log_step[0] * vanish_substep_inv_sum.double()) + / BaseField::from(1 << (log_step + 1)) +} + +fn eval_is_first(coset: Coset, p: CirclePoint) -> SecureField { + coset_vanishing(coset, p) + / (point_vanishing(coset.initial, p) * BaseField::from(1 << coset.log_size)) +} + +/// Output of the form: `(half_coset0_carry_quotients, half_coset1_carry_quotients)`. 
+fn gen_half_coset_carry_quotients( + eval_point: &MleEvalPoint, +) -> ([SecureField; N_VARIABLES], [SecureField; N_VARIABLES]) { + let last_variable = *eval_point.p.last().unwrap(); + let mut half_coset0_carry_quotients = eval_point.eq_carry_quotients; + *half_coset0_carry_quotients.last_mut().unwrap() *= + eq(&[SecureField::one()], &[last_variable]) / eq(&[SecureField::zero()], &[last_variable]); + let half_coset1_carry_quotients = half_coset0_carry_quotients.map(|v| v.inverse()); + (half_coset0_carry_quotients, half_coset1_carry_quotients) +} + +#[cfg(test)] +mod tests { + use std::array; + use std::iter::{repeat, zip}; + + use itertools::{chain, zip_eq, Itertools}; + use num_traits::One; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::{ + eval_carry_quotient_col, eval_eq_constraints, eval_mle_eval_constraints, + eval_prefix_sum_constraints, gen_carry_quotient_col, MleEvalPoint, + }; + use crate::constraint_framework::constant_columns::{gen_is_first, gen_is_step_with_offset}; + use crate::constraint_framework::{assert_constraints, EvalAtRow}; + use crate::core::backend::simd::column::SecureColumn; + use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; + use crate::core::backend::simd::qm31::PackedSecureField; + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Col, Column}; + use crate::core::circle::SECURE_FIELD_CIRCLE_GEN; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::secure_column::SecureColumnByCoords; + use crate::core::lookups::gkr_prover::GkrOps; + use crate::core::lookups::mle::Mle; + use crate::core::pcs::TreeVec; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; + use crate::core::poly::BitReversedOrder; + use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; + use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset; + + #[test] + fn 
test_mle_eval_constraints_with_log_size_5() { + const N_VARIABLES: usize = 5; + const EVAL_TRACE: usize = 0; + const CARRY_QUOTIENTS_TRACE: usize = 1; + const CONST_TRACE: usize = 2; + let mut rng = SmallRng::seed_from_u64(0); + let log_size = N_VARIABLES as u32; + let size = 1 << log_size; + let mle = Mle::new((0..size).map(|_| rng.gen::()).collect()); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::new(eval_point); + let base_trace = gen_base_trace(&mle, &eval_point); + let claim = mle.eval_at_point(&eval_point); + let claim_shift = claim / BaseField::from(size); + let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) + .into_coordinate_evals() + .to_vec(); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); + let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); + let trace_domain = CanonicCoset::new(log_size); + + assert_constraints(&trace_polys, trace_domain, |mut eval| { + let [mle_coeff_col_eval] = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + eval_mle_eval_constraints( + EVAL_TRACE, + CONST_TRACE, + &mut eval, + mle_coeff_col_eval, + mle_eval_point, + claim_shift, + carry_quotients_col_eval, + ) + }); + } + + #[test] + #[ignore = "SimdBackend `MIN_FFT_LOG_SIZE` is 5"] + fn eq_constraints_with_4_variables() { + const N_VARIABLES: usize = 4; + const EVAL_TRACE: usize = 0; + const CARRY_QUOTIENTS_TRACE: usize = 1; + const CONST_TRACE: usize = 2; + let mut rng = SmallRng::seed_from_u64(0); + let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::new(eval_point); + let base_trace = gen_base_trace(&mle, &eval_point); 
+ let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) + .into_coordinate_evals() + .to_vec(); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); + let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); + let trace_domain = CanonicCoset::new(eval_point.len() as u32); + + assert_constraints(&trace_polys, trace_domain, |mut eval| { + let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + eval_eq_constraints( + EVAL_TRACE, + CONST_TRACE, + &mut eval, + mle_eval_point, + carry_quotients_col_eval, + ); + }); + } + + #[test] + fn eq_constraints_with_5_variables() { + const N_VARIABLES: usize = 5; + const EVAL_TRACE: usize = 0; + const CARRY_QUOTIENTS_TRACE: usize = 1; + const CONST_TRACE: usize = 2; + let mut rng = SmallRng::seed_from_u64(0); + let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::new(eval_point); + let base_trace = gen_base_trace(&mle, &eval_point); + let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) + .into_coordinate_evals() + .to_vec(); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); + let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); + let trace_domain = CanonicCoset::new(eval_point.len() as u32); + + assert_constraints(&trace_polys, trace_domain, |mut eval| { + let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + eval_eq_constraints( + EVAL_TRACE, + CONST_TRACE, + &mut eval, + 
mle_eval_point, + carry_quotients_col_eval, + ); + }); + } + + #[test] + fn eq_constraints_with_8_variables() { + const N_VARIABLES: usize = 8; + const EVAL_TRACE: usize = 0; + const CARRY_QUOTIENTS_TRACE: usize = 1; + const CONST_TRACE: usize = 2; + let mut rng = SmallRng::seed_from_u64(0); + let mle = Mle::new(repeat(SecureField::one()).take(1 << N_VARIABLES).collect()); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let mle_eval_point = MleEvalPoint::new(eval_point); + let base_trace = gen_base_trace(&mle, &eval_point); + let carry_quotients_col = gen_carry_quotient_col(&mle_eval_point) + .into_coordinate_evals() + .to_vec(); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, carry_quotients_col, constants_trace]); + let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); + let trace_domain = CanonicCoset::new(eval_point.len() as u32); + + assert_constraints(&trace_polys, trace_domain, |mut eval| { + let _mle_coeffs_col_eval = eval.next_extension_interaction_mask(EVAL_TRACE, [0]); + let [carry_quotients_col_eval] = + eval.next_extension_interaction_mask(CARRY_QUOTIENTS_TRACE, [0]); + eval_eq_constraints( + EVAL_TRACE, + CONST_TRACE, + &mut eval, + mle_eval_point, + carry_quotients_col_eval, + ); + }); + } + + #[test] + fn inclusive_prefix_sum_constraints_with_log_size_5() { + const LOG_SIZE: u32 = 5; + let mut rng = SmallRng::seed_from_u64(0); + let vals = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec(); + let cumulative_sum = vals.iter().sum::(); + let cumulative_sum_shift = cumulative_sum / BaseField::from(vals.len()); + let trace = TreeVec::new(vec![gen_prefix_sum_trace(vals)]); + let trace_polys = trace.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); + let trace_domain = CanonicCoset::new(LOG_SIZE); + + assert_constraints(&trace_polys, trace_domain, |mut eval| { + let [row_diff] = eval.next_extension_interaction_mask(0, [0]); 
+ eval_prefix_sum_constraints(0, &mut eval, row_diff, cumulative_sum_shift) + }); + } + + #[test] + fn eval_step_selector_with_offset_works() { + const LOG_SIZE: u32 = 5; + const OFFSET: usize = 1; + const LOG_STEP: u32 = 2; + let coset = CanonicCoset::new(LOG_SIZE).coset(); + let col_eval = gen_is_step_with_offset::(LOG_SIZE, LOG_STEP, OFFSET); + let col_poly = col_eval.interpolate(); + let p = SECURE_FIELD_CIRCLE_GEN; + + let eval = eval_step_selector_with_offset(coset, OFFSET, LOG_STEP, p); + + assert_eq!(eval, col_poly.eval_at_point(p)); + } + + #[test] + fn eval_carry_quotient_col_works() { + const N_VARIABLES: usize = 5; + let mut rng = SmallRng::seed_from_u64(0); + let mle_eval_point = MleEvalPoint::::new(array::from_fn(|_| rng.gen())); + let col_eval = gen_carry_quotient_col(&mle_eval_point); + let twiddles = SimdBackend::precompute_twiddles(col_eval.domain.half_coset); + let col_poly = col_eval.interpolate_with_twiddles(&twiddles); + let p = SECURE_FIELD_CIRCLE_GEN; + + let eval = eval_carry_quotient_col(&mle_eval_point, p); + + assert_eq!(eval, col_poly.eval_at_point(p)); + } + + /// Generates a trace. 
+ /// + /// Trace structure: + /// + /// ```text + /// ------------------------------------------------------------------------------------- + /// | MLE coeffs | EqEvals (basis) | MLE terms (prefix sum) | + /// ------------------------------------------------------------------------------------- + /// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | c9 | c9 | c10 | c11 | + /// ------------------------------------------------------------------------------------- + /// ``` + fn gen_base_trace( + mle: &Mle, + eval_point: &[SecureField], + ) -> Vec> { + let mle_coeffs = mle.clone().into_evals(); + let eq_evals = SimdBackend::gen_eq_evals(eval_point, SecureField::one()).into_evals(); + let mle_terms = hadamard_product(&mle_coeffs, &eq_evals); + + let mle_coeff_cols = mle_coeffs.into_secure_column_by_coords().columns; + let eq_evals_cols = eq_evals.into_secure_column_by_coords().columns; + let mle_terms_cols = mle_terms.into_secure_column_by_coords().columns; + + let claim = mle.eval_at_point(eval_point); + let shift = claim / BaseField::from(mle.len()); + let packed_shifts = PackedSecureField::broadcast(shift).into_packed_m31s(); + let mut shifted_mle_terms_cols = mle_terms_cols.clone(); + zip(&mut shifted_mle_terms_cols, packed_shifts) + .for_each(|(col, shift)| col.data.iter_mut().for_each(|v| *v -= shift)); + let shifted_prefix_sum_cols = shifted_mle_terms_cols.map(inclusive_prefix_sum); + + let log_trace_domain_size = mle.n_variables() as u32; + let trace_domain = CanonicCoset::new(log_trace_domain_size).circle_domain(); + + chain![mle_coeff_cols, eq_evals_cols, shifted_prefix_sum_cols] + .map(|c| CircleEvaluation::new(trace_domain, c)) + .collect() + } + + /// Generates a trace. 
+ /// + /// Trace structure: + /// + /// ```text + /// --------------------------------------------------------- + /// | Values | Values prefix sum | + /// --------------------------------------------------------- + /// | c0 | c1 | c2 | c3 | c4 | c5 | c6 | c7 | + /// --------------------------------------------------------- + /// ``` + fn gen_prefix_sum_trace( + values: Vec, + ) -> Vec> { + assert!(values.len().is_power_of_two()); + + let vals_circle_domain_order = coset_order_to_circle_domain_order(&values); + let mut vals_bit_rev_circle_domain_order = vals_circle_domain_order; + bit_reverse(&mut vals_bit_rev_circle_domain_order); + let vals_secure_col: SecureColumnByCoords = + vals_bit_rev_circle_domain_order.into_iter().collect(); + let vals_cols = vals_secure_col.columns; + + let cumulative_sum = values.iter().sum::(); + let cumulative_sum_shift = cumulative_sum / BaseField::from(values.len()); + let packed_cumulative_sum_shift = PackedSecureField::broadcast(cumulative_sum_shift); + let packed_shifts = packed_cumulative_sum_shift.into_packed_m31s(); + let mut shifted_cols = vals_cols.clone(); + zip(&mut shifted_cols, packed_shifts) + .for_each(|(col, packed_shift)| col.data.iter_mut().for_each(|v| *v -= packed_shift)); + let shifted_prefix_sum_cols = shifted_cols.map(inclusive_prefix_sum); + + let log_size = values.len().ilog2(); + let trace_domain = CanonicCoset::new(log_size).circle_domain(); + + chain![vals_cols, shifted_prefix_sum_cols] + .map(|c| CircleEvaluation::new(trace_domain, c)) + .collect() + } + + /// Returns the element-wise product of `a` and `b`. 
+ fn hadamard_product( + a: &Col, + b: &Col, + ) -> Col { + assert_eq!(a.len(), b.len()); + SecureColumn { + data: zip_eq(&a.data, &b.data).map(|(&a, &b)| a * b).collect(), + length: a.len(), + } + } + + fn gen_constants_trace( + ) -> Vec> { + let log_size = N_VARIABLES as u32; + vec![gen_is_first(log_size)] + } +} diff --git a/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mod.rs b/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mod.rs new file mode 100644 index 0000000..6ee603e --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/xor/gkr_lookups/mod.rs @@ -0,0 +1,2 @@ +pub mod accumulation; +pub mod mle_eval; diff --git a/Stwo_wrapper/crates/prover/src/examples/xor/mod.rs b/Stwo_wrapper/crates/prover/src/examples/xor/mod.rs new file mode 100644 index 0000000..34e702a --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/examples/xor/mod.rs @@ -0,0 +1 @@ +pub mod gkr_lookups; diff --git a/Stwo_wrapper/crates/prover/src/lib.rs b/Stwo_wrapper/crates/prover/src/lib.rs new file mode 100644 index 0000000..1e9c3be --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/lib.rs @@ -0,0 +1,23 @@ +#![allow(incomplete_features)] +#![feature( + array_chunks, + array_methods, + array_try_from_fn, + assert_matches, + exact_size_is_empty, + generic_const_exprs, + get_many_mut, + int_roundings, + is_sorted, + iter_array_chunks, + new_uninit, + portable_simd, + slice_first_last_chunk, + slice_flatten, + slice_group_by, + stdsimd +)] +pub mod constraint_framework; +pub mod core; +pub mod examples; +pub mod math; diff --git a/Stwo_wrapper/crates/prover/src/math/matrix.rs b/Stwo_wrapper/crates/prover/src/math/matrix.rs new file mode 100644 index 0000000..697f862 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/math/matrix.rs @@ -0,0 +1,67 @@ +use crate::core::fields::m31::BaseField; +use crate::core::fields::ExtensionOf; + +pub trait SquareMatrix, const N: usize> { + fn get_at(&self, i: usize, j: usize) -> F; + fn mul(&self, v: [F; N]) -> [F; N] { + (0..N) + 
.map(|i| { + (0..N) + .map(|j| self.get_at(i, j) * v[j]) + .fold(F::zero(), |acc, x| acc + x) + }) + .collect::>() + .try_into() + .unwrap() + } +} + +pub struct RowMajorMatrix, const N: usize> { + values: [[F; N]; N], +} + +impl, const N: usize> RowMajorMatrix { + pub fn new(values: Vec) -> Self { + assert_eq!(values.len(), N * N); + Self { + values: values + .chunks(N) + .map(|chunk| chunk.try_into().unwrap()) + .collect::>() + .try_into() + .unwrap(), + } + } +} + +impl, const N: usize> SquareMatrix for RowMajorMatrix { + fn get_at(&self, i: usize, j: usize) -> F { + self.values[i][j] + } +} + +#[cfg(test)] +mod tests { + use crate::core::fields::m31::M31; + use crate::m31; + use crate::math::matrix::{RowMajorMatrix, SquareMatrix}; + + #[test] + fn test_matrix_multiplication() { + let matrix = RowMajorMatrix::::new((0..9).map(|x| m31!(x + 1)).collect::>()); + let vector = (0..3) + .map(|x| m31!(x + 1)) + .collect::>() + .try_into() + .unwrap(); + let expected_result = [ + m31!(14), // 1 * 1 + 2 * 2 + 3 * 3 + m31!(32), // 4 * 1 + 5 * 2 + 6 * 3 + m31!(50), // 7 * 1 + 8 * 2 + 9 * 3 + ]; + + let result = matrix.mul(vector); + + assert_eq!(result, expected_result); + } +} diff --git a/Stwo_wrapper/crates/prover/src/math/mod.rs b/Stwo_wrapper/crates/prover/src/math/mod.rs new file mode 100644 index 0000000..42c0c38 --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/math/mod.rs @@ -0,0 +1,2 @@ +pub mod matrix; +pub mod utils; diff --git a/Stwo_wrapper/crates/prover/src/math/utils.rs b/Stwo_wrapper/crates/prover/src/math/utils.rs new file mode 100644 index 0000000..fd177fa --- /dev/null +++ b/Stwo_wrapper/crates/prover/src/math/utils.rs @@ -0,0 +1,24 @@ +/// Returns s, t, g such that g = gcd(x,y), sx + ty = g. 
+pub fn egcd(x: isize, y: isize) -> (isize, isize, isize) { + if x == 0 { + return (0, 1, y); + } + let k = y / x; + let (s, t, g) = egcd(y % x, x); + (t - s * k, s, g) +} + +#[cfg(test)] +mod tests { + use crate::math::utils::egcd; + + #[test] + fn test_egcd() { + let pairs = [(17, 5, 1), (6, 4, 2), (7, 7, 7)]; + for (x, y, res) in pairs.into_iter() { + let (a, b, gcd) = egcd(x, y); + assert_eq!(gcd, res); + assert_eq!(a * x + b * y, gcd); + } + } +} diff --git a/Stwo_wrapper/poseidon_benchmark.sh b/Stwo_wrapper/poseidon_benchmark.sh new file mode 100755 index 0000000..c796985 --- /dev/null +++ b/Stwo_wrapper/poseidon_benchmark.sh @@ -0,0 +1,3 @@ +LOG_N_INSTANCES=18 RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info \ + RUSTFLAGS="-C target-cpu=native -C opt-level=3" \ + cargo test test_simd_poseidon_prove -- --nocapture diff --git a/Stwo_wrapper/resources/img/logo.png b/Stwo_wrapper/resources/img/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..07d6055eea4b8eb2d100b5cf54aadfd564b29266 GIT binary patch literal 19360 zcmd3Ng;yL;ur;#4;_j}&H3XN%0|a+p+}+(hNO1Sy?(PyG5S-u+!QJKC-+S+m_|94F zIny)KRb98L@9iih1!+_yA|xm%DAcbq5-Lzo&_5u*Pl51|PlI)e56Bmyql~r-6cjS{ ze{X20%xnTkBeaW(v=~&)6v;oxAAkku8wd)jJ|6iM3Fx5{F6ff1P2040SdbGHq zVLPt2p(NR%K{xp4&fdVw{OQ-j5KH_ix`e`(}W+FN^>UUp0 zJ%HE(W#nx=+AY`Mx+c_vDMU9^i9l#;VP7&#oZ0zLSl~5_q8SxDX zS`perRmqS7OZ>g;04`E{yP-AgWa_eD{B&)NM@|&ufu4qDNN9KnTe>uGaE_h`r8rR3 z2z;CQ1H(8K(UUa0nGg$a#~?T~RBOqy5|-B|oWs2?KD3B(RxC3`q-UwJ(D{D0#9}UIIjewnVs*6R8K}hN%@RNOM&JC^8Z$9K)ayr6)N3Il{4=|ZmE12#790M_{E`7`A~7m|(dOLE{8=0VQ)o-zW= zV3!VszPvuglup5bZV4c#mU`MVzcZRDIr%G6g}QoF=TLay&&Wag$qm#P8*{vagSI7# zw5NnLszMqkDVbl+3_cHnFI3u;#_=-7D!DCR6cDMVpmz84O-`((g?k4ulrf<)D9oHBQ=JZDpqMrK zx7aeW+|I{4E>R3b5FTP8kcR36ID-DeOpN9$8juPa>FGFTHlX2L!qlY(M!=b=j-eHv zI0tF39a(&9+y+`q_#BDu;@1(3V=R~GU$?!m{$K$B@9J>u_-;U>vwT6 zyW80rZ2vnGYebff_~kDhvYJ}dK~x6T6PN?C^eh2xSilS%#8WhlqHA4KNFs#yuLh`2 ziMx?~q%Y{VMJL2i-&F|p{zX)FU_^?kBLkzOeQWAMhdS*L5g(`B-rnZ@E*{mr`c7=% 
z9*vrzT8qn)*WV|bMr~O&)zFmWa0`hchK~EEiRmR4J-QnP_n2{ë)D+-`0fP-Z| zp~gToG}>Xt(j3xa!a#~Yxu|5pXZhku&XOEoQnr?WCs*>ZW~A06;|hm}ID_eru*s7j z@9As`?GK(%_w+TLA-aW4gu1MT+jPe0v3z`h<+!XyF{^*v*%C4o!S}6{5Er6|YoJg8ujwUXUdJbru8G$Z#{fJM4RQoS7&7A4 zeN&8pQ6<6_)at%w_369j6@=j6zwV4+m@OcNM3&Ms+sv2M7XNdWqY(I5k?w!+Bz{Mn zuk$0iN?YpO`MX9e6s{}`nBd73w4RBtn`rkBVWNQ>i9_G zZ*BRJ_uG4s!uNj9ZzCRIE&jw1scYG)&W(XaJyigE{?kTwRHbkfznKG1vvX?FMcRa! z5d*#dX5j>*VB##?)k|N0SZ!xRhhDw4AusqsLI5(?ShzT@PQiTL7a4LkYC5F7ZEw_% zvybX*`M=x7-%214mi34F7?3njj(s- zQUp45OBxZ|?#-nJhbmrS7p0zvGmH>V*BOY3fr=jEy@itBe#?s23nnM;JZmpbC-3{R zrjfs{7~Y9QE*;7SJa&2$47fzQzZ!;3s?amE72`=qOY}TkoQC^&3Ob3+sD%!GOSXyO znD~a*w)y%*avst$rd2}~kS+$lOZss&Nk~549(KaB@ya8L!6GUq)-=@eBJ80OXQlcF zHg4@h{INObVlc8vhTZvl2#ZGk;ocdg2u6*nefOgzK^Aam0D}lK5<_zuPpV7Y5SD9~ zkUxK_S}R8L39UzTdDUizlRJ1cHeFT%xc?MkrsS{o{+%GQ_I7AtSh&~jMKQb6|Lr#x zh6MIlv*hOzP1b5sai&WivL2vfEVGya{7>Xb?09dZi65&eBx@BY7ww zT(hAC2R-`51RcEKiX~3?qiwz!u4>kEu%hX+~=n>Ao^JBYpT~)Nis0wvR4D^s_>ZR zq6zWjhW2r_kbX)Iyq=|&Z0iaf7r^5sUh$uvO}3_nTChC9gb-TswXZfb$aChrB*w&8 zB0>OxvfI(>ZS9f!Q6Tm#@y!6uJ4g++rd!z(?@TUjZHw}fCD5X{^IL6sh;N#Db z=vb$f#WY9}2*MIxi`c&a-|vg?!7lcwKa36?33G`$VZHWEXCsY{79WI8`p64yf{*!q z$qd8>-U11warzSn1^VGl40|zD0O(;BMmk!)bxc=_1?seM8C1*2q8O+Uv^yS<;GVl% z{HQ>w5PA6H;~e|VJzrw!PrO4L^5EoME4xt;39zVPJz!o9Nxj~_mkjgO!lcL6oP_Lx zU9JTHSV-z|<~Ttd;O%0M`&ArvFo4I?$hf+7O&C=Mv%Vzh2 z@qP?@A+D33E^psPUE2+-QQYwh`qeD6C#=w=|M?ML#V7iV(c zrsiW|KSARz9cG9a0V9<>>MWtIIrDIvnJ^o#isO>ce4-+`3+L${+E4gx#he%t0Ep|P z##UjHgvOtGl9LfIil++dXr%`4j0pdP-E<#`YcrX*xevarjEqwuid?ezq1p(ed z$w%(K!0I3JJG5$=V*lNpnBD_#F-lxeTgMi>fij_5_ue~YQ+T@iE#+-l)j5)!PH4)? 
z$tJJ^dG(Gg$VN_rC7#`{ZPEX-#Su?O()=!}#tK|NeNEJ3YLZgN&>#!M#=;Wg;Nufw z9j&ev9p}xK>&Cj=8>70-{N9dKC&O4fP$6%sB(Sr3g+XeHZjF+WQEw(R%#Y8NYnQu> zLyIY-ESi1RPhkt6#9BCBE}960+#=4t^Ig77I|nuV65+#|5ITxQRypoyLU#9IUzH^f zBOrvBYNQr7ye|YxQojI=|2NOtCub!exz8^Xe{fWnL+e>sM~r_=4D=3f5z4~ReWsxa zjgaXr#Y53TdK1hop8C}DItKqp4A7Wx+ zG8L)`oE@#nhbzH7SlEphDWBZmDePOAIfg+JyW~G#1Ti3h%jGassB zhJ8ho~LxydG-R zrSg}m5Ygn+^;Unk2nHwX2VnHJd_mFN-yd9nS*zJwP*KbFat@hxS0W9q^fvUeA@Hu` z)40kN)My@`e@j8qWVv>EA12z#SW+1Ntd!2Cprh=nC$;I~*P6Pr7kua886f_QhR{!d zhzJQ~`MOMIpCT`D>T&@iooQO zZDgni*p-R{@wT8qdZ#N%It<*+#lFLe0)5ZYfwwyBdg}Okx_D)7@jWqs)$ntV^=+B% zhIqOS$x*KR=u@iPH{EAfPol`62xFR%xGaBKfw_82vS`GP7H<|ho)Wk1@16X8XpW_s zAps)2s8f1gI2$?yni_KSRJjBIG4PsnrGz%k^T7ARvQa_I8RMbvCVR!{v0r%zYB)Fx zgFjGijkQ2!QlpOPq~ceLVpr$UG93*T{l$6RlEYB|ZJH1?IS;Aw5z#4@7(lv3fp4!b z`Rn5t%MmX5SIrc6lWLj%^!9Gj4ehlpU9EgNJ?FE*NbWIrxb*(*zR336G1|!GNw#^(e%t zW2RuAkk(lY2O5o31YoaSCo-I}`W4zaa@1kRHJ~5yT@6!*5+n4UWzLUVM$PkhKRBuIhlpj?;sI-dGrF;HBLvyWH74NJmXRfD9j4yepy zHSgEbq#k@;y{S(>z~BucJt|0@_pp;k;I3i$X!}n6~Kph6BLGtaCO$r;yPl7<{c zdfG!PdUfA>{VMyUBe#TI(6Z+hhVo6|SRbQPzAR*z|L1a^U*AUKO|)1?DMU_jDd4$j z!}Gh^8?fL14%m+cZ@zAQeeBM95GR_{ymr*q;l*9@*hVyzPcK94Skj^JCmF_wXSd&} zSf!p#xzG51V~xqwD}_0k$_Vac!Sl(XQ(^ro15beU1DY0D6EGDT5l({piIDVnr50j5 z4xKgV$s)#Q%aF|vX?E`-;4}2c{UVx7lDbunp4J9^HK{pJ%+ijRd4x#$r45v-Wt0f` zH_Y57bkKG=mt{Y2ptsN}8e7F;gPLs1jETBSCD8ceNI*sgKAEHvhZ0*bO+x``e0jO< zdLMzCHgV<0T#N{OZ(|d}4|BrnOJYqCgpcc_vjjMGLYH>>;h_o;blLZ)I+?_x?VVX@ zA3QC3aI}*_M{7^rKloUHeNE(kVh^+mEZiCbG303^(d5eCE`KYuGrCM%9m_~bN0h`c zS=_Bf>(KQ!cH#H5KN7Z`$BAKeAwLWn=7~UUc5tF-t0Wpb`&$)_XEv;>OKle_XCEz; zwWS^{|Lb3oy|Oo{+UrXtX8*aSHm?NsI(H)ahy5u-}8q9$>>kuQ6ZVX9U=)jp2BVn8*ujtb8$YIOq3u zDg*dDenOq@L=ksHBn#wOHt0gu zqC%Cw8Ja|SrX>2S%778a7YlIHNjgzC4VGcJ;UjC8zjGs3VK1#_n!)lFj;6c7HmS_Y1DpL zfzj>fyUyHY08|`INvGZ>{yq1lW%6`&zFNtDyYW6tqZ4}?X263=@j7)K z+8_bB&JJeK75^jS67JHOM!LQa4ERpK1gyCYySuaCA;fOegpt|%bRyosGHzh%6ERF4*GS#rWzpk(pv+f2jolSwoSy;^c=+>Ag zL=e9>b>^HnCHiTMvj%K)*&J}(b2Etk6@qRbbhi4Xd=&)O)gVy_1H;jJCfZ$-Zm2~s 
zF5=YsL$$jDCJpEP&Bzt!AhF3vgx#S8%7z2;XJjbuP2Pqh0vV}6aK2gkSjAT)jw47& zqS$wSUX0|s(bp~qd?zkSjlxU4dKuZ;j<^W&?LB;O*aqwI6=u_QK zQ&2~16L&DxKn5uX8e+U&Dz=ol zRDnjWL*zOjRQ)=h^0mBs@NcmQtM`RJjIt z?Uy4`JtGU&zmlM}1c#u<#jnc-8y*$~Wn&WwC`0kFD3wO9*RvHkVd}If``1KmHN8~% zY-Hxr+XIdYI-gA{SSGR@%r`T82t61##MHgJu{tkOSSUV$ZUm_m32)6_KR-vDPO z4^mMi=IKGadBnDUnB>B;Q~5Y$=CvG^STS^aOXO`L?BBChIay{H%~UUt2Y#yweb+&Y z#qYfcXC-=+{-589#-Eq#3bvvml`9)M8sOi?eoUO%GEs-f;S(F*j2I0|1n6qHU1 zK}7rowZoebm2hnJ*05osA4@L@N7F$$n>t8cEm*0n$7Y{LQ^ICeo0MaF3ZU z%5x%W7tBGhx|S{d>%Py@JmYsGxXlpWX!a^Bc4FahjVNt~czkAmU(02=AUdW{8RFduujP!N?Q&N6Ba6FjiG(7fF@2c z4PYX_X(9?mm{|){6!30$ zd?Jwk$Ky!25w9U#oW8fm9UOt46FLbrO4Rj$z{tCQA1K`Ukwi*|0+lV9;SI%sS?>hibWj&H#3P zW#!=I{eb!7`;IvSXM3HS>}Ob%0k?0EsDU<>BN>D9l=)oLPMF9CDO93KG&r%798ZlM zPZnX;BrAu2v8yoP&aMF{L3FsVe~-~ug65t5!guAai)I($-)GfJ^p%WmSFRamhs*Y!^QDB!XFga5d*R!SxUEgQ3J zkq=7X(%riDias2O6!;Kw2ofB^K(cdVpW`;yX5R46dpa%6&s(ODq!|q!UI_678jnb) zZ*J1m$@)4$u@6bwzj!u7ZFYBRYcsmr;5M2e^-&MX{cDlrz>rV)UvN>{Yy0sRM&joO z`9;b^_(gfa=B_w9LO;ImofwcpdtzzI& z!31s}M0c`Sw}7+v74dn>?-YW^J(-wLj8^|xn_KnjMPfmsp_v^jy>@#`7mkq&#*N0= zOFw$ez!btL_~s`^Cq_M;4^!s$K0fQSG1g*gooFn3Gr{QZk9#@Aex8s1-Y|otLsr{C zTr9tsA#610RkAhsRS)~5rPUN(W;W)5*>&2-OBKY_3s%A!7QBw$;1pc(dM^(vdH8ea z_sM1bN)3>+Nw3q*dDI(2+g7$U`*YyZ>8bwEbBiE1I|GD%NMcxEl zJ=)0GGZquBmM;B!GXpz+&p@o$=(?gXP*Ly-s_W$(_lOq#VB*+;Y?T)YJ1--U?LCZb zU}O3;OxdfKj9kx`ob~WMpKI;SUW&jh+Ca;m|NU4aXcTu0Ody z^OGlsN|NVex!pGme0U9}#__ir{{YARYmnDYpoUgBJAHL#O>|UM^Cdb#Cud`{} zKX8%*u047XTWf|8qvI;Ke)O2IGgs$DC3gCMi^AVpDrTTcf7WHpp@9sKQy_s%i``4t z>r9N53^Wcp$xvecT_+n7@2OWRtmY~{o<=fvkWHIFklS`>ebUwd#i1ipl3Z&MPnK{v z%P}c#)2M;o6D7ii*CFf}k*n3)K%NiI`$+i*oUp*1HUZIjRg2ll)S06w()d`o^5b@Z z7k)*h0b6}vm{~E(rq;yZ+mrfoLS;pdzDI@lW%cSRh27p$uglwrhXO1|0tD_vGqMD< zWtkL$ceXozHW#gH=h_TsiA30^%B2a1KgcZwGqbLdm|750nRc{yiFPWh4*dQM7PvW? 
z{#c72oPIxz2Q9ujjQXP~ycYGc?ozB$BL4LMhrbo@R{eq8GM*kyONHT%I2r&gpHk<1Xl-uS(CJ_ zju#Jy**r*)4HDEKwdrzp2u@_p*PUVMuD^}i*<*tGh@L(zG%QJqRi6z7s+Z*P@AZt< z=#5Q5YXFpqFodf#!L5lmjc1Nzr-!;Ii2QI4m(+TMIUgP+3%`g&s} z`>0BIXnJWiHOytf{}sVXv8I84Ti!~^5~!a*V^AYI!G}gsP&&9vwVTTI38VsFB5GKB z`AoV3C5~mX|6tYYf!n?@PUbz`QrCg2DAuhsjqIwCzFz;$^CanczcUOYf2o(Iw3GT0 z*S{O!PVRew{tny5X>P1O7}6royXNM?-pPw*6&!RQlV;#e78w*4Y8v>JYb!3Dp|4q6 zYa0&s0^~$F_=Pd1uAu=k#chKN3bIp5_8(c<&gY{e&*5zm<4u~huVRQEPQ5UC9*4W@ znwpSk3A6s;HHqbByu4fg6zTbox`-l=6w=94V(=0+^`vo@i3KNgD3qLepuueC;fI#baas}-0%pO zAbKsy7=HL2T}lHvuTe{uX`4`~UU-sMkq_0*?L$)vefG^lQKe7fcy+vi>DHCe0Ue`RJl8!?`=mL`t|J{n^q zrDFdh>av;x`fej*U&i!4dM$TefN24|z}gy(`C-6P_9oXAh3`JwJK*(ssRz(|)%#I+ z%o)qdsVCrnviLoLTl}f3B0b9Azd_{%xq-iLgcB#z_B0=ZKk`Pz+4k(lAIQ;(@eg(Q@>0#3bGhnHkeQ zQmO9y()|YKYMCW8(mfcX-eBM`|HQ_sylU!WD?}HD38|m^$YIj}g}KOX^$R0%34Ss@ z-uA~6xj(5!iF)CNgGfF2_x|N{tw?0lQaLWQvG3l%a9S%A#~-D|#Xwd1K5ZFNa4|G- zhA#6X+uBbD;3{_uJ;LpU{{zA$Lcpca<5tK?U-ecerVp4|MR$koh40S7(eWRS_R#5Zc(81h}F zkNg84U3Xq#=CkQtSq%RnP5u`iU3LOI8uyK{c}W5F3mmL*$VD9G7ahm(eeV4pv27h4 z14P(o1=AGJ6X(s-2HS~l_8P*>WVMaWzgm4m0~ZLxfrSNQz=!^ERHW;uu7kw8yxHmoYx|=<8e+}<)Rz&X;)IBAS<)s$ zL!DcE8k&`LU0uu}2e?{V^n>dH8(7W!3L#6|a`*2UdpF(F+Ju|!gVuX(G_a?wi{UEY zgU27v-K^(jC?MD$s@iVyDwsYr;T zMQq#5p`TC&Gv%z91tVZSerUdXHGC|+9@`}l{S2i}XYS;X>qd%#h`GsD!V}=-=jl6Ip;;ja!~tvbe89O zGpf+`d~SA^X8dz8MSXHnw$*)awj^o(EbTwu2Ik;E;dimVYH)_KQAbqll|njfSNrb< zo+b?fXe66Hi}xh?Cw4`wc1*9a$uX?Oo&gz9i(+i~K=(mQZ_Nw&X1k&eHACl$o)Zyp zATTtv4!i>>2I|mgNf4-Y#QbTU4u+FUr={qZLht%@Lq;1hE)y15)q;6H*Jo+w9c-E9O#R}{Dllee#I887D~ zUEOVnvy<~GZT}QLN8Ul!3Y3xM0_yt4UBN?gq3)z`chvncR~M`GDMWsVRX|Lqps?qH zkk^80O6y#5k|2(g&;4>H=8nQREVe$IOz*c3??Tx=4shpcsg#oMNj&wtD! 
zc9xOfXL#OIbaZ&!#)v47A^L@(q#$*b^0opwnF#q%a0Y18KbN3C(NR5iG36K9Ge>;x-cBWBNb0_+(gMHpQoUnxhohbA#tPR8D zjdP_~v*oX72rMH8oDB2K#>Y^Ye*ZsKJ?Ho)0U8;3f^noz>-Si7-Lnfwjm?#LFwRP} zT4Riu8Q5i0DT%N>nBNhgr@^EEMj4Ka2RIK1f+q)t?`9%+wlaFJWsh-BBVQXWei2a; zLAas${Tz=T4yMr&dB$`nIuVgxsAlsZk;N?7uB$d$R#4j@n;6Ot;XoOyDF{B*0|is9 z34sI4YtwC=LW4>o77kmPnQN_~NlT%*1t9rwH1naTO{3lScuYwV2@lb`pc0a8k7`E? zwaFhGD$C>b$A%u^==$7mh7zPD3n#3nF*PZiib?3|RoxniF9a`{XKVCI*`j5Hr{{Z; z@IPibzQ0B!z+F3dS=paw(f(hT%0g|SJ^;?0gWqs%UM7^xGO$gSs;DH(9Yc!XD{`ruPV@rL7Q3fftt8RUNQhBY%X*9i{DbPLzC_S(?GhAyHSD9P~SCT;S& zpuN9{_hytw(x>6UEl1%|Qc~(r+2G>`SF3IKJw!WuP%iRQf{aK-TbZIHfpxG_ch8bVl0*>+*)LgGWrx>~su)B55)yeX~x zlqBMX?td2WPWL)P{u^{qxA1n+<8zFJArbQ@^Grh52ybb;;iV~CL{ThU50ap^+OrHm za;=CQfl*l1KKa4#M5N#XDs(+ngEiRGz;?Hog>;;OjbFNtj4cOYf4Nzuh;d1IqL3oI zK1 zr4iWEf`>l2+B~yku*RbG7@eXT)tDF$>xdG(h$}7Fw?NPR)$MpHj{He34f`>2GLw%R zT;AP{g~_(?u)iycVFwAHO~2RHe*G}ApN^jY!~e7^(gStT?V(<>e7x+DPBuyP4fjkP zFC+mGIpmYZVp{4CosdCK!7gp$jZ3(%p@l_;F$U7K^u@_Yw+C5iYEl{p+S*@m@G1Gj zjcFvL^?KjUtxj5|*Hfo5^CHNcAZs*w-QK#~0c&*yP~u+FBh)1OzXsB!2wt`}zI=); zI^vCqwvh?y;{|f_eb(3+mb&cOCAh-=5X=qiLhild2_WwA zKYs9Gppf}GB8jG?2br%BZHbBfn&-{xYRd=B)Ld!4v{4THVvg_YmEc-?BZ}?_k0{u) zObUZU#B!=IMEGVX&?*Ql6{wG#Xo4Jz4S4pdYbc^+^(L2-7!VcesC$+W6>X>gd#tDo z=^P#&Fhh%tJ*prV|C{U2r?r3&bI;c+gsrGg9%8Jcl!q4Iz$2QhhxBsLY}TCt-&YsT zvRi96!l#|uT6>FfPqc|1wWhAfmPVak4^MVa-Y$u<_}=dwJ}6(40$!_gS%cMKcM*y! 
zo#BCn$dIWnam2#`x_-(Yzg}rR@rSG+RdKo6xXm84YAdbQ(Y0fzI^UjbYAxOWn{ zS^}&|(!RZOt)FoU3Y=!tuZ$6S3(( zYw0IMkaeJw`FmF8iFLtAN#FQ$6al@>IJ%c*a#SP*zRDvkbXRNEP7OZvLNs#k5B@-h zUO(;j)n_~`Oe)YA6(yF5v$MBe^e^H!8oKYx5SD7BHMV~WVFWypidyHkDn&p{99};I z6Kh9sucyPucAq%0ul8Eh)$o(?j6?SnLY6qbC3Pw>>Z^i2hKD(;kYywyLm>eJtCW0Q zFBClrmgqKNxi%Z{zrtil$rVSMR7E)a}IwYQF=pqo55E)Ew-b_6T3zIPvb3mUUN`} zRXefdRw^$7J3EnP+ zMTB|-h!qxvd^{A%akj{kHxEj|Z(5AA*MWpk!s3v_8<4~248C$^vi^;wd+VhIib@>l z(6z9KT>G5av=bq1X^`CU+)7C+y-Z=e_Br{oDy5XdEZGvOi$(8WgY#q6TBwFJB5;Qc zk{A*2-k)XVy6O{12RZ3$?2$zKXG`ku#T&bXv`xIc-*v6k(o~}~G0V=*(*N3>9N(Qv zvcv#@{y%z=q44KxWs^+{q%e#;;jmv3nAob2S?Z0&IBj`pibip8u<6~mIId?bSlj>n zpDhOTr8Xfv7T!Z6W^64UaWk;fRK>K)jT3aWN1cuI_vnwl9>$p z(_2asRgb&z|K=jlKMNA%1)W@xNZu>Dv_o&{)z`-hXv{s|x2~j${L&rsL;|knwJFZ6 zI5B=;P%0$lV}zj8KJ*Qp7qQ7;_Ji7i|5RK(&sp4hLgSW>r~q=a6;wbz0dPcER)EcP zt_5UA0(^qrniO{#mp@5^M6O~k{@8+KG5jt9l)j4H0AFk;Yb=*?PBM1lSF2GG)I>x) zlcRexw(*Ng^Jg^d5UKA8w-OScue+QzJYELEG!DbL2kqIX+8-r`_!IM{7)AVSGlG z&;IRNSenN#a!|DqL~=RyNCiRxwn)y>DA_6f*twd$j&|wzsvn7D%@hQqIklBR<4YpsV z&C!u7KseRJoon{~fJQOXB7TnrwT5dI2IP2BWkkG8H z2=plh#*lX~B9j5to?;BMjWLE2$B%0oM~;l(%Jz0uE+2PtpYB9YaMj+g zPoEiI$VleINW- zJYC-b`tVM<@OdB+_;*KGi9Qm?@AmTT)c;<+JXh7t1H;(-e23{9OMA!Yrc0GmN^2mO z3%(giWs-@AldCMde}apr6nf#KpRu+54!5F_KRR6Q;CGVEX)nql?N&yE-p7sZ87m|A z`@ODLgQ?8!qd$y6U%ak=)qdj-jxbKHOx~$DP=6^bI0jYOB;vgmYfc0k;--(9m+#{q zD*fp9f6aD+@WG2STQrn7@_q_@O4+bWIT=0=^C&bDl^4=K$|h%IYsxM(GX(PH_e%f1 z(0{OT7&&};6P8lm%+niw_FojpJ9Zwpdx#~R#-x2B4G=o)_WZI@b%JN2P{=hVM#-aV z;GUJ!w7dmESrdV`+d+4At6V)%@0kd4)<9Zt`4cDx@}uYagP0qMb0T!w`OJ@n)pWWy z-?r@diBk3todGgwRf&-*ir|$c?j0Yyd-~xrn8BIWaX8ASF!2*wG0M+deLfjMIJ+Ik zL&-`GYd4_C!`!p|^2bz8>L&7?A9p}ct|b?}C?WZYtW06|Js>m*CX%o8d6o1%_*owN zSW@H0KeB*OMxYX#OFHK=JI<)IU~rRs_TVLQcdUQQFS!@hanbw-o+2%k!2{N}uLkZJ zFX_EbcjMM^RAEYSq-z0J-&VufTYp&EU}kqeHGh&m*7vNLRpeU8Fc&DRI7GXIe;y%p zc_ef@c9_08R-4P&!^hPD1mOI)`6u&%_Vg#VtEbG?>4H#ao}bxpdGp$~OcAKuE`EbSWHjH&jeJBbh^407~-wU&s>kmWXGPeI=2fqP!mvh+IgANjv zYvk)oX{R^&zJqvAyo0>p({dYQHGMUVxgU~*n>`^^4TbP8=F 
z0$u|H`)pbmljKOU5gqE^pUrb5D~`{+UA~t0zA=rY>jbJj8&ov0pEJsolsxub@TY2Hu?IPrWgu)A>`nqK4()Jha~puqSlA?!f(jQY@_!>E$2ZQF&gJ# zwBaZ-u*UtV1HHAQ*C8?T=K&k6teL5Qkl7n4e^yS6#-=V^ZaCUjRXwuO2&0a(F41~Q zOahstZD6Fo69vq?&xPC07rs))lWu#g&I8u2$;;r@_Q3~_Sv0AD5f!N2jSrXQVVL-U z_kjuBz4qv9PFl7kPs4>M4ky4)r#6n0#`>?r=nZ5qw~@8f15X+RrHL>Wy(=u&KoNq` zC>e%pVTI?5U9~+w<^@63<}I&#bs5asZxDT5M@ZBK!wRiQ|6?YhLFLkT4z_`r!1Jhf zBOyoDO83iAe+=we@OpgG=arI3I*%9VDGj(Wu79h1J*|C7tCyopz4Nu=*qm6H7k6XY zU4zcG>vE}Im)g#qoeGMhmb>pD?j0RoJIMvl0o<-*G*luC1SJd+G z(ICyKTj!QYb2nFAxfsIs9huUFrH)#Goijc||QHi36Pv-(W_GCrP{G4WMZAFf+`*N*fQU z?hPxvPhhIQ(LT?to^3jjFH3d>0xoCh)N!d6ddpcA`x<=LYgXn*G~Yh2YwCbm zhVwE>=3ZVnB>Ag<=}vBGvffz9KsktFaV`EU=52XvclKY|NU&8D4q#>?%91@(;9sC& z;0MWWxu~AAf;|Ar!P!Q1X^8|X*3uFQqB7yq*>it#2@TYQMWuD62dzdwh|j)ba*tSaU?Nz2&U8Y?a(nIBWQ zvOg8ca+qa)4zr*RCwyJ^A$oYHgz8If`D3+yCRN}PE}(zt=8nv!*?@MGahR?%+vf90 zEX!m?$^17IYWGz2khB6=ze$5;1E0Na<3Ed{^O4t?Ny%m=m=?HlhQ5p(=a>~*Mvy^y0Thn7N@x;6sj*q;lg85g1X$qDYS9bweTkPg%rsI z+6^kx)ycC*1e)+3%6bid8SYZKCpffoaqg;qC$m5Abts@Ra23&N!OF8)KstfL?)+Hc zy5{oYXlMu2c@zEE*7HU-{hP;u_bE@2VR+M~k0>fd%}Y1WW=*1I-DC)>xl7>D0J(0r z=-Qw0qR4vC8YlT{H*Ju(|(s%6z;CTh*&8Y@AexoZ{Ryl&7lJ!Wyl_kMw%`3O^FohGMn7 z*j4*GiLHZzA90stg$$4pi34mn-vL~)DD!6kmPcheM8Ht3yO}vH=Fvq|&a_8?9o@Sj zuV#taUSu(w@zQS;ONh%21v9%xjFx!#&e_%>kC7rwmN?4}zq0KFAHQ&|XK4-%`eNC= zzm4yQTQ%f-7LiBpU0QBBF4NOx3(WhF`vpIuy~Hqdht5J4vz8b1Im)+Ge});Z3olkx z3ie^Wz15oeqD5cD$!X@ZwO_)~W6}G2s9tZ*doD4nSYdQDYmcv1N`EOU{VVU6GhOtc zCygB%|4&;TpEJ)+pgafr{IOkH!nD-pGAtTnIaBv$VG6C#>#0mp{PC|4qV0s(f2*{c zm{p+@Uq01T8UIc))!Peh)iHK<{gctIUa0LSIP5|h>>&aiX_o(IF1e6t8iIyJOrktD z-vVI36u4r@PIKhkFIg**k50V*`cFI2Nv%@y2~!~1tzce@)~5z8vcE|u05Tc znFR!6l=zL(&=2#ObffLPFj3Kv1)$!lYHb=A55MTEM&=J&Ig0zc1LyPoX6i(>Jk}4F zc2_RqW7sM1w@Gho%#>@Q5E7spma}xGi;e zI;LW~YjK$SO?IXITgK&28Pcs!NRiQ2@L7v@d(Q~;MvML;rR6>u1Pk{$e7t?>FrB*& zEB$7Z6sBgmY+k0WrXT&6bC?M}JIEcPI{wFQ!w_CtpNj?`S`$hmn_U-Kp<^T7KCr@@ zOafFV&*;JZ(PpOpt?7bi5w~cEnEBkQ#Ig?A!oXU4g`od>wPgNrH_PqS43RCXEAD7M zI1`~Z#T`V^O9@#x(+k)YH5z1FR~qf?kozIA@Eo=}z`!4bM&EL(r018yb4 
zVKSx*-@I#N1EuAXY?U6Lvwsea&-udh6yp2pa*P8DMwsQl(6wECPAM5r?ltm$wVEAw zxbV=>DE(nCpm2*uLOe-OIHPcADu*CbNQ+b;qb?&~73f!aqT`uC!1bpTQZ=W1ECdMt zoK`Oa_0(wpZqN2=uZPhiP{D04J-fLM@7W=GAOuyZ7AEoa<|xRCQH32c{l5U%2`BcZ z@2M)Bjux=0gLKuVTa#nqq?>KmjLd{Yqgx`kBqppsi_kQU+)hYbC*VeHQR?-*!gWx} z%kXOKlRO{qX?t5;*I8n<9|X}4U@@z-@Qvi0))s&B+U8ro+fu^6i}zLO#DLIN)>cw@ zonKpO5LxW^x)Qg+tRuGKw0>KZx@|BT)vd&}OKH;aH@jIFW*8uPCM1Xo zBCm5QP2vrlo9!3O+ci7RK*IyouUj$C(Wv!`dIiv0|DTrD{btEkn~N6KWuZWMviKF! z7M&Cac1yRGCc1v@Nolig(cQO1X-3k;+Wk%rnsI4KjwH*c=!MOFl zn$~S9=ZglyMWS_Lk&eK83Ywk2&E<3(QIo;R+FxsBy~3kkq&RS}NPO|BXtT7~T~bI@ zWvmDZ6FVC*z#pfmlSzVN*=nEN|Q!$66rp84(jye>Od@G_xG&3 z`ZIeWL}xSlI%&s|b9Xe+Q2y!`tP#rbPmj1djArSx9TROylJGztUC+E`!K;*8qQtEy zpN2O_O~N^QRbPb5lRrnB@YqFS;whq-?A?8|E1x)uNX+`yakTvUft0eAyOtvrb?2PT~Hz zx?GDQ($SZ!l_gsh30hYMGtefOvC^(swpz$pb1q)#$Rfe_RJH4Fj+$tosoiWx_1$T; z-us^9bqfxF#$srOm3g2WAx9Az&c{p|9|Xy9f{Ss zfK5;AJ~}^l@l;n9nCH!jSamGrCD2qqj{U=|&Gl0{HA&m~8<8ZE7NjM5s0~Y~$T&QQ z-+|yeAq57vwIL#M*VO|+WP)$$MhDkoH`Y74VNgWycV<1%#YC|c{1sb`^RX9d+-fOU zL*WZ>j@0~C7(Xt>q+`pZPDr3{CR$dDl){K}2zt8$T?q$|Ze~Cj@an|%lLW0~z*%(o zWo{)c*Mzes%L2CmT3j+da|p76i6i9F%9isNUY?GbOe*VLxwBgo6tqHc&}71Wt42h+ zoG7+J?HEuqc{!KncS~hEikPOfQcU*je)hF?V~J8kUWcviYA&nZ`SrBc!lSBLnx1I{D{v>9`y zE1QqUw7lthqSy**4A_0FzWXCa^golP!i#hI0sgeQV;tI3YvGF&;oXUao@AS_^tEOoq4pMSvvLAz-kQMS=7cRcM zKdCR6a?@q``h6PJc|_^f-PDE~?Oc*e!oKs)H)f^fwI9;4Tub-^Ara{v_#K|R@XCRD zzaKB&>ASLr^OyC-;kma|I%*wPHUn|Cyu$ZqWo0t%o{ZbB5?nG0PUH7jPRi?ed8AOw z_Y2>*!mxV_*LTMY-o|a71rqB9R=*Ma9YP;bxBIX1y)nKk4-SNIW&gpWQ zpG^CU<X zx0ql{lvc5p8E=0aRpaGy2IFNnur~LHefy>-MHTx-u_=ELD?*H~UirEKZlkM5x^bjA zrUKD3Ev<8rGI0w`aSrqHb9cJNj-#$yodo$lLanHj Date: Wed, 25 Sep 2024 09:40:17 +0200 Subject: [PATCH 02/60] Addvances in the verifier script of 09/24/24 --- Stwo_wrapper/crates/prover/Cargo.toml | 1 - .../crates/prover/src/core/prover/mod.rs | 2 +- .../prover/src/examples/wide_fibonacci/mod.rs | 2 + ...{Untitled1.ipynb => verifier_script.ipynb} | 226 +++++++++++++----- 4 files changed, 172 insertions(+), 59 deletions(-) rename 
Stwo_wrapper/{Untitled1.ipynb => verifier_script.ipynb} (67%) diff --git a/Stwo_wrapper/crates/prover/Cargo.toml b/Stwo_wrapper/crates/prover/Cargo.toml index f316cf7..587e655 100644 --- a/Stwo_wrapper/crates/prover/Cargo.toml +++ b/Stwo_wrapper/crates/prover/Cargo.toml @@ -28,7 +28,6 @@ thiserror.workspace = true tracing.workspace = true rayon = { version = "1.10.0", optional = true } serde = { version = "1.0", features = ["derive"] } -light-poseidon = {path = "../../../../../save/rust/snarks/light-poseidon/light-poseidon"} crypto-bigint = "0.5.5" ark-serialize = "0.4.0" serde_json = "1.0.116" diff --git a/Stwo_wrapper/crates/prover/src/core/prover/mod.rs b/Stwo_wrapper/crates/prover/src/core/prover/mod.rs index 30493fc..ce414aa 100644 --- a/Stwo_wrapper/crates/prover/src/core/prover/mod.rs +++ b/Stwo_wrapper/crates/prover/src/core/prover/mod.rs @@ -118,7 +118,7 @@ pub fn verify( let composition_oods_eval = extract_composition_eval(sampled_oods_values).map_err(|_| { VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string()) })?; - + println!("composition_oods_eval = {:?}",composition_oods_eval); if composition_oods_eval // Compute != components.eval_composition_polynomial_at_point( diff --git a/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs b/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs index 2bd0381..50137b4 100644 --- a/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -400,6 +400,7 @@ mod tests { #[cfg(not(target_arch = "wasm32"))] use crate::core::vcs::poseidon_bls_merkle::PoseidonBLSMerkleChannel; use crate::core::ColumnVec; + use crate::core::fields::qm31::QM31; use crate::examples::wide_fibonacci::{generate_trace, FibInput, WideFibonacciComponent}; const FIB_SEQUENCE_LENGTH: usize = 100; @@ -604,6 +605,7 @@ mod tests { ) .unwrap(); _ = pretty_save_poseidon_bls_proof(&proof); + println!(" 0 1 0 0 = 
{:?}",QM31::from_u32_unchecked(0,1,0,0)); // Verify. let verifier_channel = &mut PoseidonBLSChannel::default(); diff --git a/Stwo_wrapper/Untitled1.ipynb b/Stwo_wrapper/verifier_script.ipynb similarity index 67% rename from Stwo_wrapper/Untitled1.ipynb rename to Stwo_wrapper/verifier_script.ipynb index e504f91..7d4b991 100644 --- a/Stwo_wrapper/Untitled1.ipynb +++ b/Stwo_wrapper/verifier_script.ipynb @@ -174,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 128, "id": "67a09953", "metadata": {}, "outputs": [], @@ -195,7 +195,33 @@ " remainder = cur % shift\n", " cur = quotient\n", " u32s.append(M31(remainder))\n", - " return u32s" + " return u32s\n", + "\n", + "def draw_random_bytes(digest, n_sent):\n", + " shift = 1 << 8\n", + " cur = int(draw_felt252(digest, n_sent))\n", + " byte = []\n", + " for i in range(31):\n", + " quotient = cur // shift\n", + " remainder = cur % shift\n", + " cur = quotient\n", + " byte.append(remainder)\n", + " return byte\n", + "\n", + "def hash_node(has_children, left, right, column_values):\n", + " n_column_blocks = ceil(len(column_values) / 8.0);\n", + " values = []\n", + " if has_children:\n", + " values.append(left)\n", + " values.append(right)\n", + " padding_length = 8 * n_column_blocks - len(column_values)\n", + " padded_values = column_values + [F(0) for i in range(padding_length)]\n", + " for i in range(int(len(padded_values) / 8)):\n", + " word = F(0)\n", + " for j in range(8):\n", + " word = word * F(2**31) + F(padded_values[i*8+j])\n", + " values.append(word)\n", + " return poseidon_hash_many_bls(values)" ] }, { @@ -296,7 +322,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 312, "id": "d9371b02", "metadata": {}, "outputs": [], @@ -326,83 +352,169 @@ "point = oods_point\n", "mask_values = proof[\"sampled_values_0\"] + proof[\"sampled_values_1\"]\n", "\n", - "accumulator = random_coeff\n", + "\n", + "\n", "\n", "# Evaluate random point on the vanishing polynomial of 
the coset\n", "evaluation = point.x\n", "for i in range(5):\n", " evaluation = CirclePoint.double_x(evaluation)\n", - "evaluation_inverse = evaluation ** (-1)" + "evaluation_inverse = evaluation ** (-1)\n", + "\n", + "# Compute evaluation using the mask values\n", + "accumulator = QM31(0)\n", + "a = QM31([mask_values[0][:2],mask_values[0][2:4]])\n", + "b = QM31([mask_values[1][:2],mask_values[1][2:4]])\n", + "for i in range(len(proof[\"sampled_values_0\"])-2):\n", + " c = QM31([mask_values[2+i][:2],mask_values[2+i][2:4]])\n", + " accumulator = (c - (a**2 + b**2))*evaluation_inverse + accumulator * random_coeff\n", + " a = b\n", + " b = c\n", + "\n", + "expected_value = QM31([proof[\"sampled_values_1\"][0][:2],proof[\"sampled_values_1\"][0][2:4]]) + QM31([proof[\"sampled_values_1\"][1][:2],proof[\"sampled_values_1\"][1][2:4]]) * QM31([[0,1],[0,0]]) + QM31([proof[\"sampled_values_1\"][2][:2],proof[\"sampled_values_1\"][2][2:4]]) * QM31([[0,0],[1,0]]) + QM31([proof[\"sampled_values_1\"][3][:2],proof[\"sampled_values_1\"][3][2:4]]) * QM31([[0,0],[0,1]])\n", + "\n", + "assert(expected_value == accumulator)\n", + "\n", + "\n", + "####FRI part\n", + "# Fiat-Shamir\n", + "shift = F(1 << 31)\n", + "res = [digest]\n", + "for i in range(len(mask_values) / 2):\n", + " cur = F(0)\n", + " for j in range(2):\n", + " for k in range(4):\n", + " cur = cur * shift + F(mask_values[2*i+j][k])\n", + " res.append(cur)\n", + "digest = poseidon_hash_many_bls(res)\n", + "random_coeff = QM31([draw_base_felts(digest,0)[:2],draw_base_felts(digest,0)[2:4]])\n", + "\n", + "# Verify commitment stage\n", + "circle_poly_alpha = QM31([draw_base_felts(digest,1)[:2],draw_base_felts(digest,1)[2:4]])\n", + "folding_alpha = []\n", + "for i in range(6):\n", + " digest = poseidon_hash_bls(digest,proof[\"inner_commitment_\"+str(i)])\n", + " folding_alpha.append(QM31([draw_base_felts(digest,0)[:2],draw_base_felts(digest,0)[2:4]]))\n", + "last_layer_poly = 
QM31([proof[\"coeffs\"][:2],proof[\"coeffs\"][2:4]])\n", + "\n", + "#Check proof of work\n", + "res = [digest]\n", + "cur = F(0)\n", + "for i in range(4):\n", + " cur = cur * shift + F(proof[\"coeffs\"][i])\n", + "res.append(cur)\n", + "digest = poseidon_hash_many_bls(res)\n", + "digest = poseidon_hash_bls(digest, proof[\"proof of work\"])\n", + "# The first 5 bits following the first 1 must be 0 for proof of work\n", + "for i in range(5):\n", + " assert(bin(digest)[3+i] == '0')\n", + "\n", + "# Compute openings positions\n", + "querries = []\n", + "max_querry = (1<<8)-1\n", + "random_bytes = draw_random_bytes(digest,0)\n", + "# The number of required querries is 3 so we need 3*8 bytes wich is enough with the 31 bytes of random_bytes\n", + "for i in range(3):\n", + " querry_bits = int().from_bytes(random_bytes[i*4:i*4+4],byteorder=\"little\")\n", + " querries.append(querry_bits & max_querry)\n", + "positions = []\n", + "# For each collumn\n", + "for i in range(2):\n", + " # For each querry\n", + " for j in range(3):\n", + " positions.append(querries[j] >> (1+i))\n", + "positions_sav = positions.copy()\n", + "\n", + "\n", + "#Verify decommitments (this part must be simplified repeting duplicated inputs)\n", + "# First tree\n", + "nodes = {}\n", + "for i in range(6):\n", + " leaf = []\n", + " for j in range(100):\n", + " leaf.append(M31(proof[\"queried_values_0\"][j][i]))\n", + " nodes[\"level_0_node_\"+str(positions[i//2] // 2 * 2 + i % 2) ] = hash_node(false,0,0,leaf)\n", + "counter = 0\n", + "for level in range(7):\n", + " for i in range(3):\n", + " if \"level_\"+str(level)+\"_node_\"+str(positions[len(positions)-3] * 2) in nodes:\n", + " lhs = nodes[\"level_\"+str(level)+\"_node_\"+str(positions[len(positions)-3] * 2)]\n", + " else:\n", + " lhs = F(proof[\"decommitment_0\"][counter])\n", + " counter += 1\n", + " if \"level_\"+str(level)+\"_node_\"+str(positions[len(positions) - 3] * 2 + 1) in nodes:\n", + " rhs = 
nodes[\"level_\"+str(level)+\"_node_\"+str(positions[len(positions) - 3] * 2 + 1)]\n", + " else:\n", + " rhs = F(proof[\"decommitment_0\"][counter])\n", + " counter += 1\n", + " nodes[\"level_\"+str(level+1)+\"_node_\"+str(positions[len(positions)-3])] = hash_node(true,lhs,rhs,[])\n", + " positions.append(positions[len(positions)-3] // 2)\n", + "assert(nodes[\"level_7_node_0\"] == F(proof[\"commitments\"][0]))\n", + "\n", + "#Second tree\n", + "positions = positions_sav\n", + "for i in range(len(positions)):\n", + " positions[i] = positions[i] * 2\n", + "nodes = {}\n", + "for i in range(6):\n", + " leaf = []\n", + " for j in range(4):\n", + " leaf.append(M31(proof[\"queried_values_1\"][j][i]))\n", + " nodes[\"level_0_node_\"+str(positions[i//2] // 2 * 2 + i % 2) ] = hash_node(false,0,0,leaf)\n", + "counter = 0\n", + "for level in range(8):\n", + " for i in range(3):\n", + " if \"level_\"+str(level)+\"_node_\"+str(positions[len(positions)-3] * 2) in nodes:\n", + " lhs = nodes[\"level_\"+str(level)+\"_node_\"+str(positions[len(positions)-3] * 2)]\n", + " else:\n", + " lhs = F(proof[\"decommitment_1\"][counter])\n", + " counter += 1\n", + " if \"level_\"+str(level)+\"_node_\"+str(positions[len(positions) - 3] * 2 + 1) in nodes:\n", + " rhs = nodes[\"level_\"+str(level)+\"_node_\"+str(positions[len(positions) - 3] * 2 + 1)]\n", + " else:\n", + " rhs = F(proof[\"decommitment_1\"][counter])\n", + " counter += 1\n", + " nodes[\"level_\"+str(level+1)+\"_node_\"+str(positions[len(positions)-3])] = hash_node(true,lhs,rhs,[])\n", + " if level == 0:\n", + " nodes[\"level_1_node_\"+str(positions[len(positions)-3] + 1)] = hash_node(true,nodes[\"level_0_node_\"+str((positions[len(positions)-3] + 1)* 2)],nodes[\"level_0_node_\"+str((positions[len(positions)-3] + 1)* 2 + 1)],[])\n", + " positions.append(positions[len(positions)-3] // 2)\n", + "assert(nodes[\"level_8_node_0\"] == F(proof[\"commitments\"][1]))\n", + "\n", + "\n", + "# Check querries answer\n" ] }, { "cell_type": 
"code", "execution_count": null, - "id": "3adb682e", + "id": "56e4c016", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", - "execution_count": 11, - "id": "59b44767", + "execution_count": null, + "id": "6b89599c", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(1001989877*i + 1100649138)*u + 462165474*i + 673026348" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "evaluation_inverse" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "7790a0f9", - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'PolynomialQuotientRing_field_with_category.element_class' object has no attribute 'double'", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/tmp/ipykernel_10017/940279797.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[0ma\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mQM31\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m \u001b[0mb\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mQM31\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2129160320\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1109509513\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m787887008\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mInteger\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1676461964\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0ma\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdouble\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mb\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdouble\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32m/usr/lib/python3/dist-packages/sage/structure/element.pyx\u001b[0m in \u001b[0;36msage.structure.element.Element.__getattr__ (build/cythonized/sage/structure/element.c:4827)\u001b[1;34m()\u001b[0m\n\u001b[0;32m 492\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;34m'LeftZeroSemigroup_with_category.element_class'\u001b[0m \u001b[0mobject\u001b[0m \u001b[0mhas\u001b[0m \u001b[0mno\u001b[0m \u001b[0mattribute\u001b[0m \u001b[1;34m'blah_blah'\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 493\u001b[0m \"\"\"\n\u001b[1;32m--> 494\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgetattr_from_category\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 495\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 496\u001b[0m 
\u001b[0mcdef\u001b[0m \u001b[0mgetattr_from_category\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m/usr/lib/python3/dist-packages/sage/structure/element.pyx\u001b[0m in \u001b[0;36msage.structure.element.Element.getattr_from_category (build/cythonized/sage/structure/element.c:4939)\u001b[1;34m()\u001b[0m\n\u001b[0;32m 505\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 506\u001b[0m \u001b[0mcls\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mP\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_abstract_element_class\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 507\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mgetattr_from_other_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcls\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 508\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 509\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0m__dir__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m/usr/lib/python3/dist-packages/sage/cpython/getattr.pyx\u001b[0m in \u001b[0;36msage.cpython.getattr.getattr_from_other_class (build/cythonized/sage/cpython/getattr.c:2636)\u001b[1;34m()\u001b[0m\n\u001b[0;32m 354\u001b[0m \u001b[0mdummy_error_message\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcls\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtype\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 355\u001b[0m \u001b[0mdummy_error_message\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mname\u001b[0m 
\u001b[1;33m=\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 356\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mAttributeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdummy_error_message\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 357\u001b[0m \u001b[0mcdef\u001b[0m \u001b[0mPyObject\u001b[0m\u001b[1;33m*\u001b[0m \u001b[0mattr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minstance_getattr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcls\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 358\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mattr\u001b[0m \u001b[1;32mis\u001b[0m \u001b[0mNULL\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mAttributeError\u001b[0m: 'PolynomialQuotientRing_field_with_category.element_class' object has no attribute 'double'" - ] - } - ], - "source": [ - "a = QM31([[1,0],[0,0]])\n", - "b = QM31([[2129160320,1109509513],[787887008,1676461964]])\n", - "a.double() + b.double()\n" - ] + "outputs": [], + "source": [] }, { "cell_type": "code", "execution_count": null, - "id": "5f95a975", + "id": "e486cb3a", "metadata": {}, "outputs": [], - "source": [ - "\n", - "[[\"1\",\"0\",\"0\",\"0\"],\n", - "\t\t[\"2129160320\",\"1109509513\",\"787887008\",\"1676461964\"],\n", - "\t\t[\"262908602\",\"915488457\",\"1893945291\",\"1774327476\"]" - ] + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85dd52b3", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 5c7ff0fad5e29e6f147a9df532904f378cc86a28 Mon Sep 17 00:00:00 2001 From: thomaslavaur Date: Wed, 25 Sep 2024 14:14:48 +0200 Subject: [PATCH 03/60] Last update --- Stwo_wrapper/crates/prover/proof.json | 12 ++--- Stwo_wrapper/crates/prover/src/core/fri.rs | 1 - .../crates/prover/src/core/pcs/quotients.rs | 6 +-- 
.../crates/prover/src/core/pcs/verifier.rs | 4 +- .../crates/prover/src/core/prover/mod.rs | 2 +- .../crates/prover/src/core/vcs/verifier.rs | 3 -- .../prover/src/examples/wide_fibonacci/mod.rs | 6 +-- Stwo_wrapper/verifier_script.ipynb | 52 ++++++++++++++++--- 8 files changed, 61 insertions(+), 25 deletions(-) diff --git a/Stwo_wrapper/crates/prover/proof.json b/Stwo_wrapper/crates/prover/proof.json index aecce7d..eef777e 100644 --- a/Stwo_wrapper/crates/prover/proof.json +++ b/Stwo_wrapper/crates/prover/proof.json @@ -258,7 +258,7 @@ "coeffs" : ["329725079","667313404","2083859876","1645693780"], - "inner_commitment_0" : "19048435553851756854966583714228494720706220237110109487981675465058006706934", + "inner_commitment_0" : "45555755014923146766476222823122654194153582923386372098787021809602091298670", "inner_decommitment_0" : ["5462985033728555575703006689913665598917262836577853690370070265073174719979", @@ -280,7 +280,7 @@ ["3310770","2003547458","1663490902","2105455978"], ["824310523","1757518542","231582441","427507918"]], - "inner_commitment_1" : "18169576490546341141767046171814645645044172101340402424909425947818564934087", + "inner_commitment_1" : "36809196688918736785151875655523363274066842107451407514854169291514772437712", "inner_decommitment_1" : ["45836684762941279488847688946861562185184567552524644394596887832754938056979", @@ -299,7 +299,7 @@ ["1550507975","1444313548","117070947","1740854590"], ["342709844","601149328","1436490544","1384381104"]], - "inner_commitment_2" : "29215190960441505077529177935027039424488733477246157993605964254003539615792", + "inner_commitment_2" : "881819876414071785116043072319901261019430342891967010427739931379717752179", "inner_decommitment_2" : ["3578061893060038632121836895066391994380789049478144565795630968916166958170", @@ -315,7 +315,7 @@ ["716686867","138132852","2024080584","392488646"], ["606958913","308986056","258114411","2075401741"]], - "inner_commitment_3" : 
"33865913108527976063513611264849688092773721821534885464172308921150897546371", + "inner_commitment_3" : "12027868599227153144742193285247060272784688895537159385948117930165143367635", "inner_decommitment_3" : ["46510393127320994984678433868779353751380750916123221099220042551249644407575", @@ -328,7 +328,7 @@ ["2038210709","1600238918","655676259","1542271403"], ["1014579052","1384080403","862591487","1941843578"]], - "inner_commitment_4" : "233063347401903348432987619086774265933828676513621676247992038308030708388", + "inner_commitment_4" : "14412699124489796400221638504502701429059474709940645969751865021865037702257", "inner_decommitment_4" : ["19324767949149751902195880760061491860991545124249052461044586503970003688610"], @@ -338,7 +338,7 @@ ["2101847208","689925082","1602280602","1942656531"], ["1952987285","1995490213","2082219584","1620868519"]], - "inner_commitment_5" : "693572572477915449222328871061568563307509948116711167874020917953971074657", + "inner_commitment_5" : "7898658461322497542615494384418597990320440781916208640366283709415937814000", "inner_decommitment_5" : [], diff --git a/Stwo_wrapper/crates/prover/src/core/fri.rs b/Stwo_wrapper/crates/prover/src/core/fri.rs index db3e37d..0934c77 100644 --- a/Stwo_wrapper/crates/prover/src/core/fri.rs +++ b/Stwo_wrapper/crates/prover/src/core/fri.rs @@ -3,7 +3,6 @@ use std::collections::BTreeMap; use std::fmt::Debug; use std::iter::zip; use std::ops::RangeInclusive; - use itertools::Itertools; use num_traits::Zero; use thiserror::Error; diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs b/Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs index 1a41e83..1034e05 100644 --- a/Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs +++ b/Stwo_wrapper/crates/prover/src/core/pcs/quotients.rs @@ -116,10 +116,10 @@ pub fn fri_answers( multiunzip(tuples); fri_answers_for_log_size( log_size, - &samples, + &samples, // Here it's the points and sampled values (in the proof) random_coeff, 
&query_domain_per_log_size[&log_size], - &queried_valued_per_column, + &queried_valued_per_column, // Here it's queried values (in the proof) ) }) .collect() @@ -141,7 +141,7 @@ pub fn fri_answers_for_log_size( )); } } - let mut queried_values_per_column = queried_values_per_column + let mut queried_values_per_column = queried_values_per_column .iter() .map(|q| q.iter()) .collect_vec(); diff --git a/Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs b/Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs index 812137f..da58a84 100644 --- a/Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs +++ b/Stwo_wrapper/crates/prover/src/core/pcs/verifier.rs @@ -58,6 +58,7 @@ impl CommitmentSchemeVerifier { proof: CommitmentSchemeProof, channel: &mut MC::C, ) -> Result<(), VerificationError> { + channel.mix_felts(&proof.sampled_values.clone().flatten_cols()); let random_coeff = channel.draw_felt(); @@ -105,9 +106,10 @@ impl CommitmentSchemeVerifier { .0 .into_iter() .collect::>()?; + println!("DONE"); // Answer FRI queries. 
- let samples = sampled_points + let samples = sampled_points // Sample point is always the same and the value is the one provided in the proof .zip_cols(proof.sampled_values) .map_cols(|(sampled_points, sampled_values)| { zip(sampled_points, sampled_values) diff --git a/Stwo_wrapper/crates/prover/src/core/prover/mod.rs b/Stwo_wrapper/crates/prover/src/core/prover/mod.rs index ce414aa..30493fc 100644 --- a/Stwo_wrapper/crates/prover/src/core/prover/mod.rs +++ b/Stwo_wrapper/crates/prover/src/core/prover/mod.rs @@ -118,7 +118,7 @@ pub fn verify( let composition_oods_eval = extract_composition_eval(sampled_oods_values).map_err(|_| { VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string()) })?; - println!("composition_oods_eval = {:?}",composition_oods_eval); + if composition_oods_eval // Compute != components.eval_composition_polynomial_at_point( diff --git a/Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs b/Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs index 53346bb..08428de 100644 --- a/Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs +++ b/Stwo_wrapper/crates/prover/src/core/vcs/verifier.rs @@ -1,6 +1,5 @@ use std::cmp::Reverse; use std::collections::BTreeMap; - use itertools::Itertools; use thiserror::Error; @@ -152,10 +151,8 @@ impl MerkleVerifier { if node_values.len() != n_columns_in_layer { return Err(MerkleVerificationError::WitnessTooShort); } - layer_total_queries.push((node_index, H::hash_node(node_hashes, &node_values))); } - if !layer_queried_values.iter().all(|(_, c)| c.is_empty()) { return Err(MerkleVerificationError::ColumnValuesTooLong); } diff --git a/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs b/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs index 50137b4..1ac82d0 100644 --- a/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/Stwo_wrapper/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -191,7 +191,7 @@ pub fn pretty_save_poseidon_bls_proof(proof: 
&StarkProof Date: Wed, 26 Mar 2025 11:46:34 +0100 Subject: [PATCH 04/60] Clean the repo and push the new PoL --- .../anemoi/anemoi_16_to_1_Jubjub.circom | 0 .../anemoi/anemoi_2_to_1_Jubjub.circom | 0 .../anemoi/anemoi_4_to_1_Jubjub.circom | 0 .../anemoi_Jubjub_16_to_1_constants.circom | 0 .../anemoi_Jubjub_2_to_1_constants.circom | 0 .../anemoi_Jubjub_4_to_1_constants.circom | 0 .../anemoi/script_setup_prover.sh | 0 .../poseidon/poseidon_16_to_1_Jubjub.circom | 0 .../poseidon/poseidon_2_to_1_Jubjub.circom | 0 .../poseidon/poseidon_4_to_1_Jubjub.circom | 0 .../poseidon_Jubjub_16_to_1_constants.circom | 0 .../poseidon_Jubjub_2_to_1_constants.circom | 0 .../poseidon_Jubjub_4_to_1_constants.circom | 0 circom_circuits/hash_bn/poseidon2_hash.circom | 19 + circom_circuits/hash_bn/poseidon2_perm.circom | 218 +++ .../hash_bn/poseidon2_sponge.circom | 127 ++ circom_circuits/ledger/merkle.circom | 69 + circom_circuits/ledger/notes.circom | 51 + circom_circuits/misc/comparator.circom | 154 ++ .../proof_of_leadership/PoL.circom | 197 +++ .../proof_of_leadership/generate_inputs.py | 336 ++++ .../Circom/proof_of_equivalence.circom | 103 -- proof_of_leadership/circom/generate_inputs.py | 1342 ---------------- .../circom/leadership_anemoi.circom | 301 ---- .../circom/leadership_anemoi_sha.circom | 317 ---- .../circom/leadership_poseidon.circom | 301 ---- .../circom/leadership_poseidon_sha.circom | 318 ---- .../circom/leadership_sha256.circom | 683 -------- .../Untitled-checkpoint.ipynb | 6 - proof_of_validator/circom/generate_inputs.py | 1383 ----------------- .../circom/validator_Caulk.circom | 209 --- .../circom/validator_anemoi.circom | 226 --- .../circom/validator_poseidon.circom | 227 --- 33 files changed, 1171 insertions(+), 5416 deletions(-) rename circom_circuits/{hash => hash_bls}/anemoi/anemoi_16_to_1_Jubjub.circom (100%) rename circom_circuits/{hash => hash_bls}/anemoi/anemoi_2_to_1_Jubjub.circom (100%) rename circom_circuits/{hash => 
hash_bls}/anemoi/anemoi_4_to_1_Jubjub.circom (100%) rename circom_circuits/{hash => hash_bls}/anemoi/anemoi_Jubjub_16_to_1_constants.circom (100%) rename circom_circuits/{hash => hash_bls}/anemoi/anemoi_Jubjub_2_to_1_constants.circom (100%) rename circom_circuits/{hash => hash_bls}/anemoi/anemoi_Jubjub_4_to_1_constants.circom (100%) rename circom_circuits/{hash => hash_bls}/anemoi/script_setup_prover.sh (100%) rename circom_circuits/{hash => hash_bls}/poseidon/poseidon_16_to_1_Jubjub.circom (100%) rename circom_circuits/{hash => hash_bls}/poseidon/poseidon_2_to_1_Jubjub.circom (100%) rename circom_circuits/{hash => hash_bls}/poseidon/poseidon_4_to_1_Jubjub.circom (100%) rename circom_circuits/{hash => hash_bls}/poseidon/poseidon_Jubjub_16_to_1_constants.circom (100%) rename circom_circuits/{hash => hash_bls}/poseidon/poseidon_Jubjub_2_to_1_constants.circom (100%) rename circom_circuits/{hash => hash_bls}/poseidon/poseidon_Jubjub_4_to_1_constants.circom (100%) create mode 100644 circom_circuits/hash_bn/poseidon2_hash.circom create mode 100644 circom_circuits/hash_bn/poseidon2_perm.circom create mode 100644 circom_circuits/hash_bn/poseidon2_sponge.circom create mode 100644 circom_circuits/ledger/merkle.circom create mode 100644 circom_circuits/ledger/notes.circom create mode 100644 circom_circuits/misc/comparator.circom create mode 100644 circom_circuits/proof_of_leadership/PoL.circom create mode 100755 circom_circuits/proof_of_leadership/generate_inputs.py delete mode 100644 proof_of_equivalence/Circom/proof_of_equivalence.circom delete mode 100755 proof_of_leadership/circom/generate_inputs.py delete mode 100644 proof_of_leadership/circom/leadership_anemoi.circom delete mode 100644 proof_of_leadership/circom/leadership_anemoi_sha.circom delete mode 100644 proof_of_leadership/circom/leadership_poseidon.circom delete mode 100644 proof_of_leadership/circom/leadership_poseidon_sha.circom delete mode 100644 proof_of_leadership/circom/leadership_sha256.circom delete mode 
100644 proof_of_validator/.ipynb_checkpoints/Untitled-checkpoint.ipynb delete mode 100755 proof_of_validator/circom/generate_inputs.py delete mode 100644 proof_of_validator/circom/validator_Caulk.circom delete mode 100644 proof_of_validator/circom/validator_anemoi.circom delete mode 100644 proof_of_validator/circom/validator_poseidon.circom diff --git a/circom_circuits/hash/anemoi/anemoi_16_to_1_Jubjub.circom b/circom_circuits/hash_bls/anemoi/anemoi_16_to_1_Jubjub.circom similarity index 100% rename from circom_circuits/hash/anemoi/anemoi_16_to_1_Jubjub.circom rename to circom_circuits/hash_bls/anemoi/anemoi_16_to_1_Jubjub.circom diff --git a/circom_circuits/hash/anemoi/anemoi_2_to_1_Jubjub.circom b/circom_circuits/hash_bls/anemoi/anemoi_2_to_1_Jubjub.circom similarity index 100% rename from circom_circuits/hash/anemoi/anemoi_2_to_1_Jubjub.circom rename to circom_circuits/hash_bls/anemoi/anemoi_2_to_1_Jubjub.circom diff --git a/circom_circuits/hash/anemoi/anemoi_4_to_1_Jubjub.circom b/circom_circuits/hash_bls/anemoi/anemoi_4_to_1_Jubjub.circom similarity index 100% rename from circom_circuits/hash/anemoi/anemoi_4_to_1_Jubjub.circom rename to circom_circuits/hash_bls/anemoi/anemoi_4_to_1_Jubjub.circom diff --git a/circom_circuits/hash/anemoi/anemoi_Jubjub_16_to_1_constants.circom b/circom_circuits/hash_bls/anemoi/anemoi_Jubjub_16_to_1_constants.circom similarity index 100% rename from circom_circuits/hash/anemoi/anemoi_Jubjub_16_to_1_constants.circom rename to circom_circuits/hash_bls/anemoi/anemoi_Jubjub_16_to_1_constants.circom diff --git a/circom_circuits/hash/anemoi/anemoi_Jubjub_2_to_1_constants.circom b/circom_circuits/hash_bls/anemoi/anemoi_Jubjub_2_to_1_constants.circom similarity index 100% rename from circom_circuits/hash/anemoi/anemoi_Jubjub_2_to_1_constants.circom rename to circom_circuits/hash_bls/anemoi/anemoi_Jubjub_2_to_1_constants.circom diff --git a/circom_circuits/hash/anemoi/anemoi_Jubjub_4_to_1_constants.circom 
b/circom_circuits/hash_bls/anemoi/anemoi_Jubjub_4_to_1_constants.circom similarity index 100% rename from circom_circuits/hash/anemoi/anemoi_Jubjub_4_to_1_constants.circom rename to circom_circuits/hash_bls/anemoi/anemoi_Jubjub_4_to_1_constants.circom diff --git a/circom_circuits/hash/anemoi/script_setup_prover.sh b/circom_circuits/hash_bls/anemoi/script_setup_prover.sh similarity index 100% rename from circom_circuits/hash/anemoi/script_setup_prover.sh rename to circom_circuits/hash_bls/anemoi/script_setup_prover.sh diff --git a/circom_circuits/hash/poseidon/poseidon_16_to_1_Jubjub.circom b/circom_circuits/hash_bls/poseidon/poseidon_16_to_1_Jubjub.circom similarity index 100% rename from circom_circuits/hash/poseidon/poseidon_16_to_1_Jubjub.circom rename to circom_circuits/hash_bls/poseidon/poseidon_16_to_1_Jubjub.circom diff --git a/circom_circuits/hash/poseidon/poseidon_2_to_1_Jubjub.circom b/circom_circuits/hash_bls/poseidon/poseidon_2_to_1_Jubjub.circom similarity index 100% rename from circom_circuits/hash/poseidon/poseidon_2_to_1_Jubjub.circom rename to circom_circuits/hash_bls/poseidon/poseidon_2_to_1_Jubjub.circom diff --git a/circom_circuits/hash/poseidon/poseidon_4_to_1_Jubjub.circom b/circom_circuits/hash_bls/poseidon/poseidon_4_to_1_Jubjub.circom similarity index 100% rename from circom_circuits/hash/poseidon/poseidon_4_to_1_Jubjub.circom rename to circom_circuits/hash_bls/poseidon/poseidon_4_to_1_Jubjub.circom diff --git a/circom_circuits/hash/poseidon/poseidon_Jubjub_16_to_1_constants.circom b/circom_circuits/hash_bls/poseidon/poseidon_Jubjub_16_to_1_constants.circom similarity index 100% rename from circom_circuits/hash/poseidon/poseidon_Jubjub_16_to_1_constants.circom rename to circom_circuits/hash_bls/poseidon/poseidon_Jubjub_16_to_1_constants.circom diff --git a/circom_circuits/hash/poseidon/poseidon_Jubjub_2_to_1_constants.circom b/circom_circuits/hash_bls/poseidon/poseidon_Jubjub_2_to_1_constants.circom similarity index 100% rename from 
circom_circuits/hash/poseidon/poseidon_Jubjub_2_to_1_constants.circom rename to circom_circuits/hash_bls/poseidon/poseidon_Jubjub_2_to_1_constants.circom diff --git a/circom_circuits/hash/poseidon/poseidon_Jubjub_4_to_1_constants.circom b/circom_circuits/hash_bls/poseidon/poseidon_Jubjub_4_to_1_constants.circom similarity index 100% rename from circom_circuits/hash/poseidon/poseidon_Jubjub_4_to_1_constants.circom rename to circom_circuits/hash_bls/poseidon/poseidon_Jubjub_4_to_1_constants.circom diff --git a/circom_circuits/hash_bn/poseidon2_hash.circom b/circom_circuits/hash_bn/poseidon2_hash.circom new file mode 100644 index 0000000..b54e0f4 --- /dev/null +++ b/circom_circuits/hash_bn/poseidon2_hash.circom @@ -0,0 +1,19 @@ +// +pragma circom 2.0.0; + +include "poseidon2_sponge.circom"; + +//------------------------------------------------------------------------------ +// Hash `n` field elements into 1, with approximately 254 bits of preimage security (?) +// (assuming bn128 scalar field. We use capacity=2, rate=1, t=3). 
+ +template Poseidon2_hash(n) { + signal input inp[n]; + signal output out; + + component sponge = PoseidonSponge(3,2,n,1); + sponge.inp <== inp; + sponge.out[0] ==> out; +} + +//------------------------------------------------------------------------------ \ No newline at end of file diff --git a/circom_circuits/hash_bn/poseidon2_perm.circom b/circom_circuits/hash_bn/poseidon2_perm.circom new file mode 100644 index 0000000..15256fa --- /dev/null +++ b/circom_circuits/hash_bn/poseidon2_perm.circom @@ -0,0 +1,218 @@ +// +pragma circom 2.0.0; + +// +// The Poseidon2 permutation for bn128 and t=3 +// + +//------------------------------------------------------------------------------ +// The S-box + +template SBox() { + signal input inp; + signal output out; + + signal x2 <== inp*inp; + signal x4 <== x2*x2; + + out <== inp*x4; +} + +//------------------------------------------------------------------------------ +// partial or internal round + +template InternalRound(i) { + signal input inp[3]; + signal output out[3]; + + var round_consts[56] = + [ 0x15ce7e5ae220e8623a40b3a3b22d441eff0c9be1ae1d32f1b777af84eea7e38c + , 0x1bf60ac8bfff0f631983c93e218ca0d4a4059c254b4299b1d9984a07edccfaf0 + , 0x0fab0c9387cb2bec9dc11b2951088b9e1e1d2978542fc131f74a8f8fdac95b40 + , 0x07d085a48750738019784663bccd460656dc62c1b18964a0d27a5bd0c27ee453 + , 0x10d57b1fad99da9d3fe16cf7f5dae05be844f67b2e7db3472a2e96e167578bc4 + , 0x0c36c40f7bd1934b7d5525031467aa39aeaea461996a70eda5a2a704e1733bb0 + , 0x0e4b65a0f3e1f9d3166a2145063c999bd08a4679676d765f4d11f97ed5c080ae + , 0x1ce5561061120d5c7ea09da2528c4c041b9ad0f05d655f38b10d79878b69f29d + , 0x2d323f651c3da8f0e0754391a10fa111b25dfa00471edf5493c44dfc3f28add6 + , 0x05a0741ee5bdc3e099fd6bdad9a0865bc9ceecd13ea4e702e536dd370b8f1953 + , 0x176a2ec4746fc0e0eca9e5e11d6facaee05524a92e5785c8b8161780a4435136 + , 0x0691faf0f42a9ed97629b1ae0dc7f1b019c06dd852cb6efe57f7eeb1aa865aef + , 0x0e46cf138dad09d61b9a7cab95a23b5c8cb276874f3715598bacb55d5ad271de + , 
0x0f18c3d95bac1ac424160d240cdffc2c44f7b6315ba65ed3ff2eff5b3e48b4f2 + , 0x2eea6af14b592ec45a4119ac1e6e6f0312ecd090a096e340d472283e543ddff7 + , 0x06b0d7a8f4ce97d049ae994139f5f71dca4899d4f1cd3dd83a32a89a58c0a8e6 + , 0x019df0b9828eed5892dd55c1ad6408196f6293d600ef4491703a1b37e119ba8e + , 0x08ca5e3c93817cdb1c2b2a12d02c779d74c1bb12b6668f3ab3ddd7837f3a4a00 + , 0x28382d747e3fd6cb2e0d8e8edd79c5313eed307a3517c11046245b1476e4f701 + , 0x0ca89aecd5675b77c8271765da98cfcb6875b3053d4742c9ff502861bd16ad28 + , 0x19046bc0b03ca90802ec83f212001e7ffd7f9224cfffae523451deb52eab3787 + , 0x036fd7dfa1c05110b3428e6abcc43e1de9abba915320c4a600f843bfb676ca51 + , 0x08f0a7abcb1a2f6595a9b7380c5028e3999db4fe5cb21892e5bb5cb11a7757ba + , 0x0b614acc1ce3fbe9048f8385e4ee24c3843deea186bacea3c904c9f6340ad8cb + , 0x00b2d98c5d988f9b41f2c98e017fc954a6ae423b2261575941f8eac8835d985c + , 0x1457f18555b7973ba5b311d57ec5d77e936980b97f5973875f1f7cc765a4fc95 + , 0x002b453debc1bee525cb751bc10641a6b86f847d696418cf1144950982591bfa + , 0x0c2af1abcc6ece77218315d2af445ccbfc6647b7af2510682882cc792c6bb8cf + , 0x0e2825d9eb84b59902a1adb49ac0c2c291dee7c45d2e8c30369a4d595039e8ad + , 0x297e2e86a8c672d39f3343b8dfce7a6f20f3571bfd5c8a28e3905aa2dcfeca44 + , 0x00d397281d902e49ec6504ba9186e806db9ad4fc8f86e7277aa7f1467eb6f9de + , 0x2fb7c89c372d7e2050e7377ed471000c73544a2b9fd66557f3577c09cac98b4b + , 0x16125247be4387a8c3e62490167f0cffdba02eda4f018d0b40639a13bb0cfef9 + , 0x2291fd9d442f2d9b97ab22f7d4d52c2a82e41f852cf620b144612650a39e26e8 + , 0x1eec61f16a275ae238540feaeeadfec56d32171b1cc393729d06f37f476fde71 + , 0x259ce871ba5dacbb48d8aed3d8513eef51558dc0b360f28c1a15dbfc5e7f6ca2 + , 0x2d3376a14ddbf95587e2f7567ff04fe13a3c7cb17363c8b9c5dd1d9262a210cb + , 0x13b843d9f65f4cddd7ce10d9cad9b8b99ac5e9a8c4269288173a91c0f3c3b084 + , 0x0b52e9b2f1aa9fd204e4a42c481cc76c704783e34114b8e93e026a50fa9764e8 + , 0x1fd083229276c7f27d3ad941476b394ff37bd44d3a1e9caca1400d9077a2056c + , 0x22743c328a6283f3ba7379af22c684c498568fd7ad9fad5151368c913197cbd9 + , 
0x043007aefd9741070d95caaaba0c1b070e4eec8eef8c1e512c8e579c6ed64f76 + , 0x17ab175144f64bc843074f6b3a0c57c5dd2c954af8723c029ee642539496a7b3 + , 0x2befcad3d53fba5eeef8cae9668fed5c1e9e596a46e8458e218f7a665fddf4eb + , 0x15151c4116d97de74bfa6ca3178f73c8fe8fe612c70c6f85a7a1551942cb71cc + , 0x2ac40bf6c3176300a6835d5fc7cc4fd5e5d299fb1baa86487268ec1b9eedfa97 + , 0x0f151de1f01b4e24ffe04279318f0a68efabb485188f191e37e6915ff6059f6e + , 0x2e43dffc34537535182aebac1ad7bf0a5533b88f65f9652f0ad584e2ffc4dd1f + , 0x2ebabc2c37ef53d8b13b24a2a2b729d536735f58956125a3876da0664c2442d7 + , 0x0dc3beceb34e49f5ad7226dd202c5cf879dffcc9a6dd32a300e8f2a4b59edf03 + , 0x2f1ddeccce83adf68779c53b639871a8f81d4d00aefe1e812efce8ec999d457d + , 0x1f63e41280ff5c021715d52b19780298ed8bd3d5eb506316b527e24149d4d4f1 + , 0x1b8c1252a5888f8cb2672effb5df49c633d3fd7183271488a1c40d0f88e7636e + , 0x0f45697130f5498e2940568ef0d5e9e16b1095a6cdbb6411df20a973c605e70b + , 0x0780ccc403cdd68983acbd34cda41cacfb2cf911a93076bc25587b4b0aed4929 + , 0x238d26ca97c691591e929f32199a643550f325f23a85d420080b289d7cecc9d4 + ]; + + component sb = SBox(); + sb.inp <== inp[0] + round_consts[i]; + + out[0] <== 2*sb.out + inp[1] + inp[2]; + out[1] <== sb.out + 2*inp[1] + inp[2]; + out[2] <== sb.out + inp[1] + 3*inp[2]; + +} + +//------------------------------------------------------------------------------ +// external rounds + +template ExternalRound(i) { + signal input inp[3]; + signal output out[3]; + + var round_consts[8][3] = + + [ [ 0x2c4c51fd1bb9567c27e99f5712b49e0574178b41b6f0a476cddc41d242cf2b43 + , 0x1c5f8d18acb9c61ec6fcbfcda5356f1b3fdee7dc22c99a5b73a2750e5b054104 + , 0x2d3c1988b4541e4c045595b8d574e98a7c2820314a82e67a4e380f1c4541ba90 + ] + , [ 0x052547dc9e6d936cab6680372f1734c39f490d0cb970e2077c82f7e4172943d3 + , 0x29d967f4002adcbb5a6037d644d36db91f591b088f69d9b4257694f5f9456bc2 + , 0x0350084b8305b91c426c25aeeecafc83fc5feec44b9636cb3b17d2121ec5b88a + ] + , [ 0x1815d1e52a8196127530cc1e79f07a0ccd815fb5d94d070631f89f6c724d4cbe + , 
0x17b5ba882530af5d70466e2b434b0ccb15b7a8c0138d64455281e7724a066272 + , 0x1c859b60226b443767b73cd1b08823620de310bc49ea48662626014cea449aee + ] + , [ 0x1b26e7f0ac7dd8b64c2f7a1904c958bb48d2635478a90d926f5ff2364effab37 + , 0x2da7f36850e6c377bdcdd380efd9e7c419555d3062b0997952dfbe5c54b1a22e + , 0x17803c56450e74bc6c7ff97275390c017f682db11f3f4ca6e1f714efdfb9bd66 + ] + , [ 0x25672a14b5d085e31a30a7e1d5675ebfab034fb04dc2ec5e544887523f98dede + , 0x0cf702434b891e1b2f1d71883506d68cdb1be36fa125674a3019647b3a98accd + , 0x1837e75235ff5d112a5eddf7a4939448748339e7b5f2de683cf0c0ae98bdfbb3 + ] + , [ 0x1cd8a14cff3a61f04197a083c6485581a7d836941f6832704837a24b2d15613a + , 0x266f6d85be0cef2ece525ba6a54b647ff789785069882772e6cac8131eecc1e4 + , 0x0538fde2183c3f5833ecd9e07edf30fe977d28dd6f246d7960889d9928b506b3 + ] + , [ 0x07a0693ff41476abb4664f3442596aa8399fdccf245d65882fce9a37c268aa04 + , 0x11eb49b07d33de2bd60ea68e7f652beda15644ed7855ee5a45763b576d216e8e + , 0x08f8887da6ce51a8c06041f64e22697895f34bacb8c0a39ec12bf597f7c67cfc + ] + , [ 0x2a912ec610191eb7662f86a52cc64c0122bd5ba762e1db8da79b5949fdd38092 + , 0x2031d7fd91b80857aa1fef64e23cfad9a9ba8fe8c8d09de92b1edb592a44c290 + , 0x0f81ebce43c47711751fa64d6c007221016d485641c28c507d04fd3dc7fba1d2 + ] + ]; + + component sb[3]; + for(var j=0; j<3; j++) { + sb[j] = SBox(); + sb[j].inp <== inp[j] + round_consts[i][j]; + } + + out[0] <== 2*sb[0].out + sb[1].out + sb[2].out; + out[1] <== sb[0].out + 2*sb[1].out + sb[2].out; + out[2] <== sb[0].out + sb[1].out + 2*sb[2].out; +} + +//------------------------------------------------------------------------------ +// the initial linear layer + +template LinearLayer() { + signal input inp[3]; + signal output out[3]; + out[0] <== 2*inp[0] + inp[1] + inp[2]; + out[1] <== inp[0] + 2*inp[1] + inp[2]; + out[2] <== inp[0] + inp[1] + 2*inp[2]; +} + +//------------------------------------------------------------------------------ +// the Poseidon2 permutation for t=3 + +template Permutation() { + signal input inp[3]; 
+ signal output out[3]; + + signal aux[65][3]; + + component ll = LinearLayer(); + for(var j=0; j<3; j++) { ll.inp[j] <== inp[j]; } + for(var j=0; j<3; j++) { ll.out[j] ==> aux[0][j]; } + + component ext[8]; + for(var k=0; k<8; k++) { ext[k] = ExternalRound(k); } + + component int[56]; + for(var k=0; k<56; k++) { int[k] = InternalRound(k); } + + // first 4 external rounds + for(var k=0; k<4; k++) { + for(var j=0; j<3; j++) { ext[k].inp[j] <== aux[k ][j]; } + for(var j=0; j<3; j++) { ext[k].out[j] ==> aux[k+1][j]; } + } + + // the 56 internal rounds + for(var k=0; k<56; k++) { + for(var j=0; j<3; j++) { int[k].inp[j] <== aux[k+4][j]; } + for(var j=0; j<3; j++) { int[k].out[j] ==> aux[k+5][j]; } + } + + // last 4 external rounds + for(var k=0; k<4; k++) { + for(var j=0; j<3; j++) { ext[k+4].inp[j] <== aux[k+60][j]; } + for(var j=0; j<3; j++) { ext[k+4].out[j] ==> aux[k+61][j]; } + } + + for(var j=0; j<3; j++) { out[j] <== aux[64][j]; } +} + +//------------------------------------------------------------------------------ +// the "compression function" takes 2 field elements as input and produces +// 1 field element as output. It is a trivial application of the permutation. + +template Compression() { + signal input inp[2]; + signal output out; + + component perm = Permutation(); + perm.inp[0] <== inp[0]; + perm.inp[1] <== inp[1]; + perm.inp[2] <== 0; + + perm.out[0] ==> out; +} + +//------------------------------------------------------------------------------ + diff --git a/circom_circuits/hash_bn/poseidon2_sponge.circom b/circom_circuits/hash_bn/poseidon2_sponge.circom new file mode 100644 index 0000000..ddbb580 --- /dev/null +++ b/circom_circuits/hash_bn/poseidon2_sponge.circom @@ -0,0 +1,127 @@ +// +pragma circom 2.0.0; + +include "poseidon2_perm.circom"; + +//------------------------------------------------------------------------------ + +function min(a,b) { + return (a <= b) ? 
a : b; +} + +//------------------------------------------------------------------------------ + +// +// Poseidon sponge construction +// +// t = size of state (currently fixed to 3) +// c = capacity (1 or 2) +// r = rate = t - c +// +// everything is measured in number of field elements +// +// we use the padding `10*` from the original Poseidon paper, +// and initial state constant zero. Note that this is different +// from the "SAFE padding" recommended in the Poseidon2 paper +// (which uses `0*` padding and a nontrivial initial state) +// + +template PoseidonSponge(t, capacity, input_len, output_len) { + + var rate = t - capacity; + + assert( t == 3); + + assert( capacity > 0 ); + assert( rate > 0 ); + assert( capacity < t ); + assert( rate < t ); + + signal input inp[ input_len]; + signal output out[output_len]; + + // round up to rate the input + 1 field element ("10*" padding) + var nblocks = ((input_len + 1) + (rate-1)) \ rate; + var nout = (output_len + (rate-1)) \ rate; + var padded_len = nblocks * rate; + + signal padded[padded_len]; + for(var i=0; i state[m+1]; + + } + + var q = min(rate, output_len); + for(var i=0; i out[i]; + } + var out_ptr = rate; + + for(var n=1; n state[nblocks+n ]; + + var q = min(rate, output_len-out_ptr); + for(var i=0; i out[out_ptr+i]; + } + out_ptr += rate; + } + +} + +//------------------------------------------------------------------------------ + +// +// sponge hash with rate=1 +// + +template Poseidon2_sponge_hash_rate_1(n) { + signal input inp[n]; + signal output out; + component sponge = PoseidonSponge(3, 2, n, 1); + sponge.inp <== inp; + sponge.out[0] ==> out; +} + +// +// sponge hash with rate=2 +// + +template Poseidon2_sponge_hash_rate_2(n) { + signal input inp[n]; + signal output out; + component sponge = PoseidonSponge(3, 1, n, 1); + sponge.inp <== inp; + sponge.out[0] ==> out; +} + +//------------------------------------------------------------------------------ diff --git a/circom_circuits/ledger/merkle.circom 
b/circom_circuits/ledger/merkle.circom new file mode 100644 index 0000000..56348ed --- /dev/null +++ b/circom_circuits/ledger/merkle.circom @@ -0,0 +1,69 @@ +//test +pragma circom 2.1.9; + +include "poseidon2_hash.circom"; +include "comparator.circom"; + +// proof of Merkle membership of depth n +// /!\ To call this function, it's important to check that each selector is a bit before!!! +template proof_of_membership(n) { + signal input nodes[n]; // The path forming the Merkle proof + signal input selector[n]; // it's the leaf's indice in big endian bits + signal input root; + signal input leaf; + + + + component compression_hash[n]; + + compression_hash[0] = Poseidon2_hash(2); + compression_hash[0].inp[0] <== leaf - selector[n-1] * (leaf - nodes[0]); + compression_hash[0].inp[1] <== nodes[0] - selector[n-1] * (nodes[0] - leaf); + + for(var i=1; i 0 ) + assert( rate > 0 ) + assert( capacity < 3 ) + assert( rate < 3 ) + + # round up to rate the input + 1 field element ("10*" padding) + nblocks = ((len(data) + 1) + (rate-1)) // rate; + nout = (output_len + (rate-1)) // rate; + padded_len = nblocks * rate; + + padded = [] + for i in range(len(data)): + padded.append(F(data[i])) + padded.append(F(1)) + for i in range(len(data)+1,padded_len): + padded.append(F(0)) + + civ = F(2**64 + 256*3 + rate) + + state = [F(0),F(0),F(civ)] + sorbed = [F(0) for j in range(rate)] + + for m in range(nblocks): + for i in range(rate): + a = state[i] + b = padded[m*rate+i] + sorbed[i] = a + b + state = Permutation(sorbed[0:rate] + state[rate:3]) + + q = min(rate, output_len) + for i in range(q): + output[i] = state[i] + out_ptr = rate + + for n in range(1,nout): + state[nblocks+n] = Permutation(state[nblocks+n-1]) + q = min(rate, output_len-out_ptr) + for i in range(q): + output[out_ptr+i] = state[nblocks+n][i] + out_ptr += rate + + return output + +R = RealField(500) #Real numbers with precision 500 bits + +if len(sys.argv) != Integer(4): + print("Usage: